├── .gitignore ├── LICENSE.txt ├── README.md ├── demo.ipynb ├── demo.py ├── download_model.sh ├── images └── demo.png ├── init_paths.py ├── lib ├── __init__.py └── utils │ ├── __init__.py │ ├── benchmark_utils.py │ ├── io.py │ ├── ops.py │ ├── queue_runner.py │ └── util.py ├── load_models.py ├── models ├── __init__.py └── exif │ ├── __init__.py │ ├── exif_net.py │ └── exif_solver.py ├── ncuts_demo.py ├── nets ├── __init__.py ├── resnet_utils.py ├── resnet_v1.py └── resnet_v2.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.npy 3 | *.jpg 4 | *.png 5 | *.ckpt 6 | *.ckpt.meta 7 | *.ckpt.index 8 | *.npy 9 | *swp* 10 | /tmp 11 | ./ipynb_checkpoints 12 | */.ipynb_checkpoints 13 | /tb 14 | /ckpt 15 | /output 16 | /results 17 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Fighting Fake News: Image Splice Detection via Learned Self-Consistency 2 | ### [[paper]](https://arxiv.org/pdf/1805.04096.pdf) [[website]](https://minyoungg.github.io/selfconsistency/) 3 | 4 | [Minyoung Huh *12](https://minyounghuh.com), [Andrew Liu *1](http://andrewhliu.github.io/), [Andrew Owens1](http://andrewowens.com/), [Alexei A. 
Efros1](https://people.eecs.berkeley.edu/~efros/) 5 | In [ECCV](https://eccv2018.org/) 2018. 6 | UC Berkeley, Berkeley AI Research1 7 | Carnegie Mellon University2 8 | ### Abstract 9 | In this paper, we introduce a self-supervised method for 10 | learning to detect visual manipulations using only unlabeled data. Given a large collection of real photographs with automatically recorded EXIF meta-data, we train a model to determine whether an image is self-consistent -- that is, whether its content could have been produced by a single imaging pipeline. 11 | 12 | ### 1) Prerequisites 13 | First, clone this repo: 14 | ```git clone --single-branch https://github.com/minyoungg/selfconsistency``` 15 | 16 | All prerequisites should be listed in requirements.txt. The code is written in TensorFlow and runs on Python 2.7; we have not verified whether it works with Python 3. The following command should install any necessary requirements: 17 | ```bash pip install -r requirements.txt``` 18 | 19 | ### 2) Downloading pretrained model 20 | To download our pretrained model, run the following script in the terminal: 21 | ```chmod 755 download_model.sh && ./download_model.sh ``` 22 | 23 | ### 3) Demo 24 | To run our model on an image, run the following command: 25 | ``` python demo.py --im_path=./images/demo.png``` 26 | 27 | We also provide a normalized-cut implementation, which can be run with: 28 | ``` python ncuts_demo.py --im_path=./images/ncuts_demo.png``` 29 | 30 | We have set up an IPython notebook demo [here](demo.ipynb). 31 | Disclaimer: Our model works best on high-resolution natural images. Frames from videos generally do not work well. 32 | 33 | ### Citation 34 | If you find our work useful, please cite: 35 | ``` 36 | @inproceedings{huh18forensics, 37 | title = {Fighting Fake News: Image Splice Detection via Learned Self-Consistency}, 38 | author = {Huh, Minyoung and Liu, Andrew and 39 | Owens, Andrew and Efros, Alexei A.}, 40 | booktitle = {ECCV}, 41 | year = {2018} 42 | } 43 | ``` 44 | 45 | ## Questions 46 | For any further questions, please contact Minyoung Huh or Andrew Liu. 47 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os, sys, numpy as np, ast 5 | import init_paths 6 | import load_models 7 | from lib.utils import benchmark_utils, util 8 | import tensorflow as tf 9 | import cv2, time, scipy, scipy.spatial, scipy.misc as scm, sklearn.cluster, skimage.io as skio, argparse 10 | import matplotlib.pyplot as plt 11 | from sklearn.cluster import DBSCAN 12 | 13 | def mean_shift(points_, heat_map, iters=5): 14 | points = np.copy(points_) 15 | kdt = scipy.spatial.cKDTree(points) 16 | eps_5 = np.percentile(scipy.spatial.distance.cdist(points, points, metric='euclidean'), 10) 17 | 18 | for epis in range(iters): 19 | for point_ind in range(points.shape[0]): 20 | point = points[point_ind] 21 | nearest_inds = kdt.query_ball_point(point, r=eps_5) 22 | points[point_ind] = np.mean(points[nearest_inds], axis=0) 23 | val = [] 24 | for i in range(points.shape[0]): 25 | val.append(kdt.count_neighbors(scipy.spatial.cKDTree(np.array([points[i]])), r=eps_5)) 26 | mode_ind = np.argmax(val) 27 | ind = np.nonzero(val == np.max(val)) 28 | return np.mean(points[ind[0]], axis=0).reshape(heat_map.shape[0], heat_map.shape[1]) 29 | 30 | def centroid_mode(heat_map): 31 | eps_thresh = np.percentile(heat_map, 10) 32 | k
= heat_map <= eps_thresh 33 | # Get's max centroid 34 | num_affinities = np.sum(k, axis=(2, 3)) 35 | x = np.nonzero(num_affinities >= np.max(num_affinities)) 36 | if type(x) is tuple: 37 | ind1 = x[0][0] 38 | ind2 = x[1][0] 39 | else: 40 | ind1 = x[0] 41 | ind2 = x[1] 42 | assert np.max(num_affinities) == num_affinities[ind1, ind2] 43 | return heat_map[ind1, ind2] 44 | 45 | def normalized_cut(res): 46 | sc = sklearn.cluster.SpectralClustering(n_clusters=2, n_jobs=-1, 47 | affinity="precomputed") 48 | out = sc.fit_predict(res.reshape((res.shape[0] * res.shape[1], -1))) 49 | vis = out.reshape((res.shape[0], res.shape[1])) 50 | return vis 51 | def process_response_no_resize(response): 52 | return 255 * plt.cm.jet(response)[:,:,:3] 53 | 54 | def process_response(response): 55 | size = get_resized_shape(response) 56 | im = 255 * plt.cm.jet(response)[:,:,:3] 57 | return scm.imresize(im, size)# , interp='nearest') 58 | 59 | def get_resized_shape(im, max_im_dim=400): 60 | ratio = float(max_im_dim) / np.max(im.shape) 61 | return (int(im.shape[0] * ratio), int(im.shape[1] * ratio), 3) 62 | 63 | def process_image(im): 64 | size = get_resized_shape(im) 65 | return scm.imresize(im, size) #, interp='nearest') 66 | 67 | def norm(response): 68 | res = response - np.min(response) 69 | return res/np.max(res) 70 | 71 | def apply_mask(im, mask): 72 | mask = scipy.misc.imresize(mask, (im.shape[0], im.shape[1])) / 255. 73 | mask = mask.reshape(im.shape[0], im.shape[1], 1) 74 | mask = mask * 0.8 + 0.2 75 | return mask * im 76 | 77 | def aff_fn(v1, v2): 78 | return np.mean((v1 * v2 + (1 - v1)*(1 - v2))) 79 | 80 | def ssd_distance(results, with_inverse=True): 81 | def ssd(x, y): 82 | # uses mean instead 83 | return np.mean(np.square(x - y)) 84 | 85 | results = np.array(results) 86 | results = np.concatenate([results, 1.0 - results], axis=0) 87 | 88 | dist_matrix = np.zeros((len(results), len(results))) 89 | for i, r_x in enumerate(results): 90 | for j, r_y in enumerate(results): 91 | score = ssd(r_x, r_y) 92 | dist_matrix[i][j] = score 93 | return dist_matrix, results 94 | 95 | def dbscan_consensus(results, eps_range=(0.1, 0.5), eps_sample=10, dbscan_sample=4): 96 | """ 97 | Slowly increases DBSCAN epsilon until a cluster is found. 98 | The distance between responses is the SSD. 99 | Best prediction is based on the spread within the cluster. 100 | Here spread is the average per-pixel variance of the output. 101 | The cluster is then combined using the median of the cluster. 102 | When no cluster is found, returns the response 103 | that has smallest median score across other responses. 
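Returns a tuple (best_pred, lowest_spread): best_pred is the selected response map (same spatial shape as the input responses) and lowest_spread is the average per-pixel standard deviation within the chosen cluster.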
104 | """ 105 | 106 | dist_matrix, results = ssd_distance(results, with_inverse=True) 107 | 108 | debug = False #True 109 | lowest_spread = 100.0 110 | best_pred = None 111 | 112 | for eps in np.linspace(eps_range[0], eps_range[1], eps_sample): 113 | db = DBSCAN(eps=eps, min_samples=dbscan_sample).fit(dist_matrix) 114 | labels = set(db.labels_) 115 | 116 | if debug: 117 | print('DBSCAN with epsilon %.3f' % eps) 118 | print('Found %i labels' % len(labels)) 119 | 120 | try: 121 | labels.remove(-1) 122 | except: 123 | pass 124 | 125 | if debug: 126 | print('%i Unique cluster' % len(labels)) 127 | labels = np.array(list(labels)) 128 | 129 | if len(labels) < 2: 130 | if debug: 131 | print('Not enough cluster found') 132 | continue 133 | 134 | clusters = {l:np.argwhere(db.labels_ == l) for l in labels} 135 | cluster_spreads = {} 136 | cluster_preds = {} 137 | 138 | for lbl, cluster_indices in clusters.items(): 139 | if debug: 140 | print('Cluster %i with %i samples' % (lbl, len(cluster_indices))) 141 | 142 | cluster_indices = np.squeeze(cluster_indices) 143 | cluster_results = [results[i] for i in cluster_indices] 144 | 145 | #mean_result = np.mean(cluster_results, axis=0) 146 | median_result = np.median(cluster_results, axis=0) 147 | 148 | # Average Per pixel deviation 149 | average_spread = np.mean(np.std(cluster_results, axis=0)) 150 | cluster_spreads[lbl] = average_spread 151 | cluster_preds[lbl] = median_result 152 | #print average_spread 153 | if average_spread < lowest_spread: 154 | lowest_spread = average_spread 155 | best_pred = median_result 156 | 157 | best_lbl, avg_spread = util.sort_dict(cluster_spreads)[0] 158 | 159 | if debug: 160 | print('Cluster spread %.3f' % avg_spread) 161 | plt.imshow(cluster_preds[best_lbl], cmap='jet', vmin=0.0, vmax=1.0) 162 | plt.show() 163 | 164 | if best_pred is None: 165 | # Uses a sample that has the median minimum distance between all predicted sample 166 | print('Failed to find DBSCAN cluster') 167 | compact_dist_matrix = dist_matrix[:len(dist_matrix)//2, :len(dist_matrix)//2] 168 | avg_dist = np.median(compact_dist_matrix, axis=0) 169 | best_pred = results[np.argmin(avg_dist)] 170 | 171 | if debug: 172 | plt.figure() 173 | plt.imshow(best_pred, cmap='jet', vmin=0.0, vmax=1.0) 174 | return best_pred, lowest_spread 175 | 176 | def run_vote_no_threads(image, solver, exif_to_use, n_anchors=1, num_per_dim=None, 177 | patch_size=None, batch_size=None, sample_ratio=3.0, override_anchor=False): 178 | """ 179 | solver: exif_solver module. Must be initialized and have a network connected. 180 | exif_to_use: exif to extract responses from. A list. If exif_to_use is None 181 | extract result from classification output cls_pred 182 | n_anchors: number of anchors to use. 183 | num_per_dim: number of patches to use along the largest dimension. 184 | patch_size: size of the patch. If None, uses the one specified in solver.net 185 | batch_size: size of the batch. If None, uses the one specified in solver.net 186 | sample_ratio: The ratio of overlap between patches. num_per_dim must be None 187 | to be useful. 
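override_anchor: False, or a (row, col) tuple giving the top-left corner of the single anchor patch to use instead of randomly sampled anchors. Returns a dict mapping each entry of exif_to_use (or 'out' when exif_to_use is None) to {'responses': vote maps of shape (n_anchors, h, w), 'anchors': the list of anchor top-left coordinates}.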
188 | """ 189 | 190 | h, w = np.shape(image)[:2] 191 | 192 | if patch_size is None: 193 | patch_size = solver.net.im_size 194 | 195 | if batch_size is None: 196 | batch_size = solver.net.batch_size 197 | 198 | if num_per_dim is None: 199 | num_per_dim = int(np.ceil(sample_ratio * (max(h,w)/float(patch_size)))) 200 | 201 | if exif_to_use is None: 202 | not_exif = True 203 | exif_to_use = ['out'] 204 | else: 205 | not_exif = False 206 | exif_map = {e: np.squeeze(np.argwhere(np.array(solver.net.train_runner.tags) == e)) for e in exif_to_use} 207 | 208 | responses = {e:np.zeros((n_anchors, h, w)) for e in exif_to_use} 209 | vote_counts = {e:1e-6 * np.ones((n_anchors, h, w)) for e in exif_to_use} 210 | 211 | if np.min(image) < 0.0: 212 | # already preprocessed 213 | processed_image = image 214 | else: 215 | processed_image = util.process_im(image) 216 | ones = np.ones((patch_size, patch_size)) 217 | 218 | anchor_indices = [] 219 | # select n anchors 220 | for anchor_idx in range(n_anchors): 221 | if override_anchor is False: 222 | _h, _w = np.random.randint(0, h - patch_size), np.random.randint(0, w - patch_size) 223 | else: 224 | assert len(override_anchor) == 2, override_anchor 225 | _h, _w = override_anchor 226 | 227 | anchor_indices.append((_h, _w)) 228 | anchor_patch = processed_image[_h:_h+patch_size, _w:_w+patch_size, :] 229 | 230 | batch_a = np.tile([anchor_patch], [batch_size, 1, 1, 1]) 231 | batch_b, batch_b_coord = [], [] 232 | 233 | prev_batch = None 234 | for i in np.linspace(0, h - patch_size, num_per_dim).astype(int): 235 | for j in np.linspace(0, w - patch_size, num_per_dim).astype(int): 236 | compare_patch = processed_image[i:i+patch_size, j:j+patch_size] 237 | batch_b.append(compare_patch) 238 | batch_b_coord.append((i,j)) 239 | 240 | if len(batch_b) == batch_size: 241 | if not_exif: 242 | pred = solver.sess.run(solver.net.cls_pred, 243 | feed_dict={solver.net.im_a:batch_a, 244 | solver.net.im_b:batch_b, 245 | solver.net.is_training:False}) 246 | else: 247 | pred = solver.sess.run(solver.net.pred, 248 | feed_dict={solver.net.im_a:batch_a, 249 | solver.net.im_b:batch_b, 250 | solver.net.is_training:False}) 251 | 252 | for p_vec, (_i, _j) in zip(pred, batch_b_coord): 253 | for e in exif_to_use: 254 | if not_exif: 255 | p = p_vec[0] 256 | else: 257 | p = p_vec[int(exif_map[e])] 258 | responses[e][anchor_idx, _i:_i+patch_size, _j:_j+patch_size] += (p * ones) 259 | vote_counts[e][anchor_idx, _i:_i+patch_size, _j:_j+patch_size] += ones 260 | prev_batch = batch_b 261 | batch_b, batch_b_coord = [], [] 262 | 263 | if len(batch_b) > 0: 264 | batch_b_len = len(batch_b) 265 | to_pad = np.array(prev_batch)[:batch_size - batch_b_len] 266 | batch_b = np.concatenate([batch_b, to_pad], axis=0) 267 | 268 | if not_exif: 269 | pred = solver.sess.run(solver.net.cls_pred, 270 | feed_dict={solver.net.im_a:batch_a, 271 | solver.net.im_b:batch_b, 272 | solver.net.is_training:False}) 273 | else: 274 | pred = solver.sess.run(solver.net.pred, 275 | feed_dict={solver.net.im_a:batch_a, 276 | solver.net.im_b:batch_b, 277 | solver.net.is_training:False}) 278 | 279 | for p_vec, (_i, _j) in zip(pred, batch_b_coord): 280 | for e in exif_to_use: 281 | if not_exif: 282 | p = p_vec[0] 283 | else: 284 | p = p_vec[int(exif_map[e])] 285 | responses[e][anchor_idx, _i:_i+patch_size, _j:_j+patch_size] += (p * ones) 286 | vote_counts[e][anchor_idx, _i:_i+patch_size, _j:_j+patch_size] += ones 287 | 288 | return {e: {'responses':(responses[e] / vote_counts[e]), 'anchors':anchor_indices} for e in exif_to_use} 289 | 290 | 
class Demo(): 291 | def __init__(self, ckpt_path='/data/scratch/minyoungg/ckpt/exif_medifor/exif_medifor.ckpt', use_gpu=0, 292 | quality=3.0, patch_size=128, num_per_dim=30): 293 | self.quality = quality # sample ratio 294 | self.solver, nc, params = load_models.initialize_exif(ckpt=ckpt_path, init=False, use_gpu=use_gpu) 295 | params["im_size"] = patch_size 296 | self.im_size = patch_size 297 | tf.reset_default_graph() 298 | im = np.zeros((256, 256, 3)) 299 | self.bu = benchmark_utils.EfficientBenchmark(self.solver, nc, params, im, auto_close_sess=False, 300 | mirror_pred=False, dense_compute=False, stride=None, n_anchors=10, 301 | patch_size=patch_size, num_per_dim=num_per_dim) 302 | return 303 | 304 | def run(self, im, gt=None, show=False, save=False, 305 | blue_high=False, use_ncuts=False): 306 | # run for every new image 307 | self.bu.reset_image(im) 308 | res = self.bu.precomputed_analysis_vote_cls(num_fts=4096) 309 | #print('result shape', np.shape(res)) 310 | ms = mean_shift(res.reshape((-1, res.shape[0] * res.shape[1])), res) 311 | 312 | if np.mean(ms > .5) > .5: 313 | # majority of the image is above .5 314 | if blue_high: 315 | ms = 1 - ms 316 | 317 | if use_ncuts: 318 | 319 | ncuts = normalized_cut(res) 320 | if np.mean(ncuts > .5) > .5: 321 | # majority of the image is white 322 | # flip so spliced is white 323 | ncuts = 1 - ncuts 324 | out_ncuts = cv2.resize(ncuts.astype(np.float32), (im.shape[1], im.shape[0]), 325 | interpolation=cv2.INTER_LINEAR) 326 | 327 | out_ms = cv2.resize(ms, (im.shape[1], im.shape[0]), interpolation=cv2.INTER_LINEAR) 328 | 329 | 330 | if use_ncuts: 331 | return out_ms, out_ncuts 332 | return out_ms 333 | 334 | def run_vote(self, im, num_per_dim=3, patch_size=128): 335 | h,w = np.shape(im)[:2] 336 | all_results = [] 337 | for hSt in np.linspace(0, h - patch_size, num_per_dim).astype(int): 338 | for wSt in np.linspace(0, w - patch_size, num_per_dim).astype(int): 339 | res = run_vote_no_threads(im, self.solver, None, n_anchors=1, num_per_dim=None, 340 | patch_size=128, batch_size=64, sample_ratio=self.quality, 341 | override_anchor=(hSt, wSt))['out']['responses'][0] 342 | all_results.append(res) 343 | 344 | return dbscan_consensus(all_results) 345 | 346 | def __call__(self, url, dense=False): 347 | """ 348 | @Args 349 | url: This can either be a web-url or directory 350 | dense: If False, runs the new DBSCAN clustering. 351 | Using dense will be low-res and low-variance. 
352 | @Returns 353 | output of the clustered response 354 | """ 355 | if type(url) is not str: 356 | im = url 357 | else: 358 | if url.startswith('http'): 359 | im = util.get(url) 360 | else: 361 | im = cv2.imread(url)[:,:,[2,1,0]] 362 | 363 | #print('Image shape:', np.shape(im)) 364 | assert min(np.shape(im)[:2]) > self.im_size, 'image dimension too small' 365 | 366 | if not dense: 367 | # Default: multi-anchor voting combined with DBSCAN consensus clustering 368 | out, _ = self.run_vote(im, num_per_dim=3, patch_size=self.im_size) 369 | else: 370 | # Dense mean-shift analysis (low-res, low-variance) 371 | out = self.run(im) 372 | return im, out 373 | 374 | if __name__ == '__main__': 375 | plt.switch_backend('agg') 376 | parser = argparse.ArgumentParser() 377 | parser.add_argument("--im_path", type=str, help="path_to_image") 378 | cfg = parser.parse_args() 379 | 380 | assert os.path.exists(cfg.im_path) 381 | 382 | imid = cfg.im_path.split('/')[-1].split('.')[0] 383 | save_path = os.path.join('./images', imid + '_result.png') 384 | 385 | ckpt_path = './ckpt/exif_final/exif_final.ckpt' 386 | exif_demo = Demo(ckpt_path=ckpt_path, use_gpu=0, quality=3.0, num_per_dim=30) 387 | 388 | print('Running image %s' % cfg.im_path) 389 | ms_st = time.time() 390 | im_path = cfg.im_path 391 | im, res = exif_demo(im_path, dense=True) 392 | print('MeanShift run time: %.3f' % (time.time() - ms_st)) 393 | 394 | plt.subplots(figsize=(16, 8)) 395 | plt.subplot(1, 3, 1) 396 | plt.title('Input Image') 397 | plt.imshow(im) 398 | plt.axis('off') 399 | 400 | plt.subplot(1, 3, 2) 401 | plt.title('Cluster w/ MeanShift') 402 | plt.axis('off') 403 | if np.mean(res > 0.5) > 0.5: 404 | res = 1.0 - res 405 | plt.imshow(res, cmap='jet', vmin=0.0, vmax=1.0) 406 | plt.savefig(save_path) 407 | print('Result saved %s' % save_path) 408 | -------------------------------------------------------------------------------- /download_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "Downloading exif_final.zip" 4 | 5 | # Google Drive link to exif_final.zip 6 | gdown https://drive.google.com/uc?id=1X6b55rwZzU68Mz1m68WIX_G2idsEw3Qh 7 | echo "Unzipping exif_final.zip to ./ckpt/" 8 | 9 | mkdir -p ./ckpt/ 10 | unzip exif_final.zip -d ./ckpt/ 11 | 12 | rm exif_final.zip -------------------------------------------------------------------------------- /images/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minyoungg/selfconsistency/c173e6a24e2aa375317310718bd0ff1d8b4079ee/images/demo.png -------------------------------------------------------------------------------- /init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | def add_path(path): 5 | if path not in sys.path: 6 | sys.path.insert(0, path) 7 | 8 | this_dir = osp.dirname(__file__) 9 | 10 | # Add lib to PYTHONPATH 11 | lib_path = osp.join(this_dir, 'lib') 12 | root_path = osp.join(this_dir) 13 | 14 | add_path(lib_path) 15 | add_path(root_path) 16 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | # :) 2 | -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | --------------------------------------------------------------------------------
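As a usage reference for demo.py above: the following is a minimal sketch (not a file from this repo) of calling the `Demo` class directly from Python instead of through the command line, mirroring the `__main__` block of demo.py. It assumes `download_model.sh` has already unpacked the checkpoint to `./ckpt/exif_final/exif_final.ckpt`, and `./images/my_image.png` is a placeholder path.
```python
# Hedged sketch: programmatic use of demo.py's Demo class.
# Assumes Python 2.7 and that it is run from the repository root.
import numpy as np
import matplotlib.pyplot as plt
from demo import Demo

plt.switch_backend('agg')

# Checkpoint path produced by download_model.sh (same path used in demo.py's __main__ block).
exif_demo = Demo(ckpt_path='./ckpt/exif_final/exif_final.ckpt',
                 use_gpu=0, quality=3.0, patch_size=128, num_per_dim=30)

# './images/my_image.png' is a placeholder path.
# dense=True runs the dense mean-shift analysis; dense=False runs the DBSCAN consensus voting.
im, response = exif_demo('./images/my_image.png', dense=True)

# response is an (H, W) map in [0, 1]; flip it (as demo.py's __main__ does) so that
# the minority, likely-spliced region reads as high.
if np.mean(response > 0.5) > 0.5:
    response = 1.0 - response

plt.imshow(response, cmap='jet', vmin=0.0, vmax=1.0)
plt.axis('off')
plt.savefig('./images/my_image_result.png')
```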
/lib/utils/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import init_paths 3 | from lib.utils import queue_runner, util 4 | import tensorflow as tf 5 | import threading 6 | import numpy as np 7 | import time 8 | import scipy.misc 9 | import cv2 10 | 11 | 12 | class EfficientBenchmark(): 13 | 14 | def __init__(self, solver, net_module_obj, net_module_obj_init_params, im, 15 | num_processes=1, num_threads=1, stride=None, max_bs=20000, n_anchors=3, 16 | patch_size=224, auto_close_sess=True, patches=None, mirror_pred=False, 17 | dense_compute=False, num_per_dim=30): 18 | """ 19 | solver: The model solver to run predictions 20 | net_module_obj: The corresponding net class 21 | net_module_obj_init_params: Dictionary that would normally be passed into net.initialize 22 | im: The image to analyze 23 | num_processes: Number of data-grabbing processes; only 1 is supported 24 | num_threads: Number of threads that transfer data from the Python queue to the TF queue 25 | stride: Distance between sampled grid patches 26 | max_bs: For precomputing, the number of index pairs evaluated per batch 27 | n_anchors: Number of anchor patches; if patches is given, it must contain this many patches 28 | patches: A numpy array of shape (n x patch_size x patch_size x 3) used as the anchor patches; 29 | n should match the n_anchors argument 30 | auto_close_sess: Whether to close the tf session after finishing the analysis 31 | num_per_dim: Number of patches sampled along the largest image dimension 32 | (deprecated): 33 | dense_compute: always leave False; precomputing handles dense evaluation faster 34 | mirror_pred: always leave False; precomputing handles mirrored predictions 35 | """ 36 | assert num_processes == 1, "Can only do single process" 37 | assert num_threads > 0, "Need at least one thread for queuing" 38 | 39 | self.use_patches = False 40 | if patches is not None: 41 | # use defined patches 42 | assert patches.shape[0] == n_anchors 43 | self.use_patches = True 44 | self.patches = patches 45 | 46 | self.mirror_pred = mirror_pred 47 | # For when we use indices for precomputed features 48 | self.max_bs = max_bs 49 | 50 | self.solver = solver 51 | self.n_anchors = n_anchors 52 | self.num_per_dim = num_per_dim 53 | self.patch_size = patch_size 54 | self.recompute_stride = False 55 | self.stride = stride 56 | if not stride: 57 | # compute stride dynamically 58 | self.recompute_stride = True 59 | self.stride = self.compute_stride(im) 60 | self.dense_compute = dense_compute 61 | if dense_compute: 62 | self.patches = None 63 | self.num_processes = num_processes 64 | self.num_threads = num_threads 65 | 66 | self.label_shape = 1 if not 'num_classes' in net_module_obj_init_params else net_module_obj_init_params['num_classes'] 67 | 68 | self.cr = self.update_queue_runner(im) 69 | self.auto_close_sess = auto_close_sess 70 | self.n_responses = self.max_h_ind * self.max_w_ind if self.dense_compute else n_anchors 71 | 72 | net_module_obj_init_params["train_runner"] = self.cr 73 | net_module_obj_init_params["use_tf_threading"] = True 74 | self.net = net_module_obj.initialize(net_module_obj_init_params) 75 | self.solver.setup_net(net=self.net) 76 | 77 | def compute_stride(self, im): 78 | return (max(im.shape[0], im.shape[1]) - self.patch_size) // self.num_per_dim 79 | 80 | def update_queue_runner(self, im): 81 | # returns a new queue_runner 82 | self.set_image(np.zeros((self.patch_size, self.patch_size, 3), dtype=np.float32)) 83 | 84 | fn = self.dense_argless if self.dense_compute else self.argless 85 | cr = queue_runner.CustomRunner(fn, n_processes=self.num_processes, 86 | n_threads=self.num_threads) 87 | 
self.original_cr_get_inputs = cr.get_inputs 88 | self.set_image(im) 89 | 90 | def new_cr(batch_size): 91 | self.anch_indices_, self.h_indices_, self.w_indices_, im_a, im_b = self.original_cr_get_inputs(batch_size) 92 | # we don't use the label since it's test time 93 | if self.label_shape == 1: 94 | return im_a, im_b, tf.zeros((batch_size), dtype=tf.int64) 95 | else: 96 | return im_a, im_b, tf.zeros((batch_size, self.label_shape), dtype=tf.float32) 97 | 98 | # Directly rewriting get_inputs since that's all the net class sees and uses 99 | cr.get_inputs = new_cr 100 | 101 | return cr 102 | 103 | 104 | def reset_image(self, im): 105 | if self.recompute_stride: 106 | # compute stride dynamically 107 | self.stride = self.compute_stride(im) 108 | 109 | fn = self.dense_argless if self.dense_compute else self.argless 110 | self.cr.kill_programs() 111 | # programs are all dead, purge tf queue 112 | 113 | while True: 114 | self.solver.sess.run(self.cr.tf_queue.dequeue_up_to(self.cr.tf_queue.size())) 115 | remain = self.solver.sess.run(self.cr.tf_queue.size()) 116 | if remain == 0: 117 | break 118 | 119 | self.set_image(im) 120 | self.cr.set_data_fn(fn) 121 | self.cr.start_p_threads(self.solver.sess) 122 | 123 | def get_patch(self, hind, wind): 124 | return self.image[hind:hind+self.patch_size, wind:wind+self.patch_size] 125 | 126 | def rand_patch(self): 127 | h = np.random.randint(self.image.shape[0] - self.patch_size + 1) 128 | w = np.random.randint(self.image.shape[1] - self.patch_size + 1) 129 | return self.image[h:h+self.patch_size, w:w+self.patch_size, :] 130 | 131 | def get_anchor_patches(self): 132 | # set seed here if want same patches 133 | # Regardless of whether use or not, should 134 | # be 0 if not dense compute 135 | self.anchor_count = 0 136 | if self.dense_compute: 137 | self.anchor_inds = self.indices.copy() 138 | return util.process_im(np.array([self.get_patch( 139 | self.anchor_inds[i][0], self.anchor_inds[i][1]) for i in range(self.anchor_count, 140 | min(self.anchor_count + self.n_anchors, 141 | self.anchor_inds.shape[0]))])) 142 | 143 | if self.use_patches: 144 | # pass existing patches 145 | return util.process_im(self.patches) 146 | return util.process_im( 147 | np.array([self.rand_patch() for i in range(self.n_anchors)], dtype=np.float32)) 148 | 149 | def set_image(self, image): 150 | # new image, need to refresh 151 | self.image = image 152 | self.max_h_ind = 1 + int(np.floor((self.image.shape[0] - self.patch_size) / float(self.stride))) 153 | self.max_w_ind = 1 + int(np.floor((self.image.shape[1] - self.patch_size) / float(self.stride))) 154 | self.indices = np.mgrid[0:self.max_h_ind, 0:self.max_w_ind].reshape((2, -1)).T # (n 2) 155 | self.anchor_patches = self.get_anchor_patches() 156 | self.count = -1 157 | 158 | def data_fn(self, hind, wind): 159 | n_anchors = self.anchor_patches.shape[0] 160 | y_ind, x_ind = hind * self.stride, wind * self.stride 161 | 162 | patch = self.image[y_ind:y_ind + self.patch_size, 163 | x_ind:x_ind + self.patch_size, 164 | :] 165 | 166 | anchor_inds = np.arange(self.anchor_count, self.anchor_count + n_anchors) 167 | h_inds = np.array([hind] * n_anchors, dtype=np.int64) 168 | w_inds = np.array([wind] * n_anchors, dtype=np.int64) 169 | batch_a = self.anchor_patches 170 | batch_b = util.process_im(np.array([patch] * n_anchors, dtype=np.float32)) 171 | # anc, y, x, bat_a, bat_b 172 | if self.mirror_pred: 173 | anchor_inds = np.vstack([anchor_inds] * 2) 174 | h_inds = np.vstack([h_inds] * 2) 175 | w_inds = np.vstack([w_inds] * 2) 176 | batch_a, 
batch_b = np.vstack([batch_a, batch_b]), np.vstack([batch_b, batch_a]) 177 | 178 | return anchor_inds, h_inds, w_inds, batch_a, batch_b 179 | 180 | def dense_argless(self): 181 | assert False, "Deprecated" 182 | if self.count >= self.indices.shape[0]: 183 | self.count = 0 184 | self.anchor_count += self.n_anchors 185 | if self.anchor_count >= self.anchor_inds.shape[0]: 186 | raise StopIteration() 187 | inds2 = self.anchor_inds[self.anchor_count] 188 | self.anchor_patches = util.process_im(np.array([self.get_patch( 189 | self.anchor_inds[i][0], self.anchor_inds[i][1]) for i in range(self.anchor_count, 190 | min(self.anchor_count + self.n_anchors, 191 | self.anchor_inds.shape[0]))])) 192 | self.n_anchors = self.anchor_patches.shape[0] 193 | inds = self.indices[self.count] 194 | self.count += 1 195 | d = self.data_fn(inds[0], inds[1]) 196 | return d 197 | 198 | 199 | def argless(self): 200 | self.count += 1 201 | if self.count >= self.indices.shape[0]: 202 | raise StopIteration() 203 | inds = self.indices[self.count] 204 | return self.data_fn(inds[0], inds[1]) 205 | 206 | def argless_extract_inds(self): 207 | iterator = np.mgrid[0:self.max_h_ind, 0:self.max_w_ind, 0:self.max_h_ind, 0:self.max_w_ind].reshape((4, -1)).T # (n 4) 208 | count = 0 209 | while True: 210 | if count * self.max_bs > len(iterator): 211 | break 212 | # each indice is a read into np.mgrid[0:self.max_h_ind, 0:self.max_w_ind] 213 | yield iterator[count * self.max_bs:(count + 1) * self.max_bs, :] # self.max_bs x 4 214 | count += 1 215 | 216 | def run_ft(self, num_fts=4096): 217 | #print("Starting Analysis") 218 | # Batch_b contains the sweeping patches, feat_b to get features of a patch 219 | # For most efficient running set n_anchors to 1 220 | responses = np.zeros((num_fts, self.max_h_ind, 221 | self.max_w_ind)) 222 | 223 | expected_num_running = self.max_h_ind * self.max_w_ind 224 | visited = np.zeros((self.max_h_ind, self.max_w_ind)) 225 | while True: 226 | try: 227 | # t0 = time.time() 228 | h_ind_, w_ind_, fts_ = self.solver.sess.run([self.h_indices_, 229 | self.w_indices_, 230 | self.solver.net.im_b_feat]) 231 | # print time.time() - t0 232 | for i in range(h_ind_.shape[0]): 233 | responses[:, h_ind_[i], w_ind_[i]] = fts_[i] 234 | visited[h_ind_[i], w_ind_[i]] = 1 235 | if np.sum(visited) == expected_num_running: 236 | raise RuntimeError("Finished") 237 | 238 | except tf.errors.OutOfRangeError as e: 239 | # TF Queue emptied, return responses 240 | if self.auto_close_sess: 241 | self.solver.sess.close() 242 | return responses 243 | except RuntimeError as e: 244 | if self.auto_close_sess: 245 | self.solver.sess.close() 246 | return responses 247 | 248 | def precomputed_analysis_vote_cls(self, num_fts=4096): 249 | #print("Starting Analysis") 250 | assert not self.auto_close_sess, "Need to keep sess open" 251 | 252 | feature_response = self.run_ft(num_fts=num_fts) 253 | 254 | flattened_features = feature_response.reshape((num_fts, -1)).T 255 | # Use np.unravel_index to recover x,y coordinate 256 | 257 | spread = max(1, self.patch_size // self.stride) 258 | 259 | responses = np.zeros((self.max_h_ind + spread - 1, self.max_w_ind + spread - 1, 260 | self.max_h_ind + spread - 1, self.max_w_ind + spread - 1), dtype=np.float32) 261 | vote_counts = np.zeros((self.max_h_ind + spread - 1, self.max_w_ind + spread - 1, 262 | self.max_h_ind + spread - 1, self.max_w_ind + spread - 1)) + 1e-4 263 | 264 | iterator = self.argless_extract_inds() 265 | while True: 266 | try: 267 | inds = next(iterator) 268 | except StopIteration as e: 269 
| if self.auto_close_sess: 270 | self.solver.sess.close() 271 | out = (responses / vote_counts) 272 | return out 273 | patch_a_inds = inds[:, :2] 274 | patch_b_inds = inds[:, 2:] 275 | 276 | a_ind = np.ravel_multi_index(patch_a_inds.T, [self.max_h_ind, self.max_w_ind]) 277 | b_ind = np.ravel_multi_index(patch_b_inds.T, [self.max_h_ind, self.max_w_ind]) 278 | 279 | # t0 = time.time() 280 | preds_ = self.solver.sess.run(self.solver.net.pc_cls_pred, 281 | feed_dict={self.net.precomputed_features:flattened_features, 282 | self.net.im_a_index: a_ind, 283 | self.net.im_b_index: b_ind}) 284 | # print preds_ 285 | # print time.time() - t0 286 | for i in range(preds_.shape[0]): 287 | responses[inds[i][0] : (inds[i][0] + spread), 288 | inds[i][1] : (inds[i][1] + spread), 289 | inds[i][2] : (inds[i][2] + spread), 290 | inds[i][3] : (inds[i][3] + spread)] += preds_[i] 291 | vote_counts[inds[i][0] : (inds[i][0] + spread), 292 | inds[i][1] : (inds[i][1] + spread), 293 | inds[i][2] : (inds[i][2] + spread), 294 | inds[i][3] : (inds[i][3] + spread)] += 1 295 | 296 | 297 | -------------------------------------------------------------------------------- /lib/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | def make_subdir(path): 5 | """ Makes subdirectories for the filepath """ 6 | subdir = '/'.join(path.split('/')[:-1]) 7 | if os.path.exists(subdir): 8 | return 9 | os.makedirs(subdir) 10 | return subdir 11 | 12 | def make_dir(dir): 13 | """ Makes subdirectories for the filepath """ 14 | if os.path.exists(dir): 15 | return dir 16 | os.makedirs(dir) 17 | return dir 18 | 19 | def read_json(path): 20 | """ 21 | loads all json line formatted file 22 | """ 23 | data = [json.loads(line) for line in open(path)] 24 | return data 25 | 26 | def to_npy(path): 27 | """ Changes to image path to npy format """ 28 | return '.'.join(path.split('.')[:-1])+'.npy' 29 | 30 | def show(args, phase, iter): 31 | """ Used to show training outputs """ 32 | print '(%s) Iterations %i' % (phase, iter) 33 | max_len = max([len(k[0]) for k in args]) 34 | for out in args: 35 | a,b = out 36 | print '\t',a.ljust(max_len),': ', b 37 | return 38 | 39 | def add_summary(writer, list_of_summary, i): 40 | """ Adds list of summary to the writer """ 41 | for s in list_of_summary: 42 | writer.add_summary(s, i) 43 | 44 | def parse_checkpoint(ckpt): 45 | """ Parses checkpoint string to get iteration """ 46 | assert type(ckpt) == str, ckpt 47 | try: 48 | i = int(ckpt.split('_')[-1].split('.')[0]) 49 | except: 50 | print 'unknown checkpoint string format %s setting iteration to 0' % ckpt 51 | i = 0 52 | return i 53 | 54 | def make_ckpt(saver, sess, save_prefix, i=None): 55 | """ Makes a checkpoint """ 56 | if i is not None: 57 | save_prefix += '_' + str(i) 58 | save_path = save_prefix + '.ckpt' 59 | saver.save(sess, save_path) 60 | print 'Saved checkpoint at %s' % save_path 61 | return 62 | -------------------------------------------------------------------------------- /lib/utils/ops.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import tensorflow.contrib.slim as slim 4 | 5 | def get_variables(finetune_ckpt_path, exclude_scopes=None): 6 | """Returns list of variables without scopes that start with exclude_scopes .""" 7 | if exclude_scopes is not None: 8 | exclusions = [scope.strip() for scope in exclude_scopes] 9 | variables_to_restore = [ var for var in slim.get_model_variables() 
if not np.any([var.op.name.startswith(ex) for ex in exclusions])] 10 | else: 11 | variables_to_restore = [ var for var in slim.get_model_variables()] 12 | return variables_to_restore 13 | 14 | def config(use_gpu=None): 15 | config = tf.ConfigProto() 16 | config.gpu_options.allow_growth = True 17 | config.allow_soft_placement = True 18 | if use_gpu: 19 | if type(use_gpu) is list: 20 | use_gpu = ','.join([str(g) for g in use_gpu]) 21 | config.gpu_options.visible_device_list = str(use_gpu) 22 | return config 23 | 24 | def tfprint(x): 25 | print x 26 | return x 27 | 28 | def extract_var(starts_with, is_not=False): 29 | if type(starts_with) is str: 30 | starts_with = [starts_with] 31 | selected_vars = [] 32 | for s in starts_with: 33 | if not is_not: 34 | selected_vars.extend([var for var in tf.trainable_variables() if var.op.name.startswith(s)]) 35 | else: 36 | selected_vars.extend([var for var in tf.trainable_variables() if not var.op.name.startswith(s)]) 37 | return selected_vars 38 | 39 | def init_solver(param): 40 | """ Initializes solver using solver param """ 41 | return param.solver(learning_rate=param.learning_rate, 42 | beta1=param.beta1, 43 | beta2=param.beta2) 44 | 45 | def multiclass_accuracy(pr, gt): 46 | """ pr is logits. computes multiclass accuracy """ 47 | correct_prediction = tf.equal(tf.round(tf.nn.sigmoid(pr)), tf.round(gt)) 48 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 49 | return accuracy 50 | 51 | def leaky_relu(input, slope=0.2): 52 | """ Leaky relu """ 53 | with tf.name_scope('leaky_relu'): 54 | return tf.maximum(slope*input, input) 55 | 56 | def batch_norm(input, is_training): 57 | """ batch normalization """ 58 | with tf.variable_scope('batch_norm'): 59 | return tf.contrib.layers.batch_norm(input, decay=0.9, scale=True, 60 | updates_collections=None, is_training=is_training) 61 | 62 | def renorm(input, is_training): 63 | return tf.layers.batch_normalization(input, training=is_training, renorm_momentum=0.9) 64 | 65 | def instance_norm(input, is_training): 66 | """ instance normalization """ 67 | with tf.variable_scope('instance_norm'): 68 | num_out = input.get_shape()[-1] 69 | scale = tf.get_variable('scale', [num_out], 70 | initializer=tf.random_normal_initializer(mean=1.0, stddev=0.02)) 71 | offset = tf.get_variable('offset', [num_out], 72 | initializer=tf.random_normal_initializer(mean=0.0, stddev=0.02)) 73 | mean, var = tf.nn.moments(input, axes=[1,2], keep_dims=True) 74 | epsilon = 1e-6 75 | inv = tf.rsqrt(var + epsilon) 76 | return scale * (input - mean) * inv + offset 77 | 78 | def fc(input, output, reuse=False, norm=None, activation=tf.nn.relu, dropout=0.7, is_training=True, name='fc'): 79 | """ FC with norm, activation, dropout support """ 80 | with tf.variable_scope(name, reuse=reuse): 81 | x = slim.fully_connected(input, output, activation_fn=activation, normalizer_fn=norm, reuse=reuse) 82 | x = tf.nn.dropout(x, dropout) 83 | return x 84 | 85 | def conv(input, output, size, stride, 86 | reuse=False, 87 | norm=instance_norm, 88 | activation=leaky_relu, 89 | dropout=1.0, 90 | padding='VALID', 91 | pad_size=None, 92 | is_training=True, 93 | name='conv'): 94 | """ 95 | Performs convolution -> batchnorm -> relu 96 | """ 97 | with tf.variable_scope(name, reuse=reuse): 98 | dropout = 1.0 if dropout is None else dropout 99 | # Pre pad the input feature map 100 | x = pad(input, pad_size) 101 | # Apply convolution 102 | x = slim.conv2d(x, output, size, stride, 103 | activation_fn=None, 104 | 
weights_initializer=tf.truncated_normal_initializer(stddev=0.02), 105 | padding=padding) 106 | # Apply dropout 107 | x = tf.nn.dropout(x, dropout) 108 | # Apply activation 109 | x = activation(x) if activation else x 110 | # Apply normalization 111 | x = norm(x, is_training) if norm else x 112 | return x 113 | 114 | def pad(input, pad_size): 115 | """ Reflect pads input by adding pad_size to h x w dimensions """ 116 | if not pad_size: 117 | return input 118 | return tf.pad(input, [[0,0],[pad_size, pad_size],[pad_size, pad_size],[0,0]], 'REFLECT') 119 | 120 | def average_gradients(grad_list): 121 | """Calculate the average gradient for each shared variable across all towers. 122 | Note that this function provides a synchronization point across all towers. 123 | Args: 124 | grad_list: List of lists of (gradient, variable) tuples. The outer list 125 | is over individual gradients. The inner list is over the gradient 126 | calculation for each tower. 127 | Returns: 128 | List of pairs of (gradient, variable) where the gradient has been averaged 129 | across all towers. 130 | """ 131 | average_grads = [] 132 | for grad_and_vars in zip(*grad_list): 133 | # Note that each grad_and_vars looks like the following: 134 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 135 | grads = [] 136 | for g, _ in grad_and_vars: 137 | # Add 0 dimension to the gradients to represent the tower. 138 | expanded_g = tf.expand_dims(g, 0) 139 | 140 | # Append on a 'tower' dimension which we will average over below. 141 | grads.append(expanded_g) 142 | 143 | # Average over the 'tower' dimension. 144 | grad = tf.concat(axis=0, values=grads) 145 | grad = tf.reduce_mean(grad, 0) 146 | 147 | # Keep in mind that the Variables are redundant because they are shared 148 | # across towers. So .. we will just return the first tower's pointer to 149 | # the Variable. 150 | v = grad_and_vars[0][1] 151 | grad_and_var = (grad, v) 152 | average_grads.append(grad_and_var) 153 | return average_grads 154 | -------------------------------------------------------------------------------- /lib/utils/queue_runner.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import time 4 | import multiprocessing as mp 5 | import threading 6 | import Queue 7 | 8 | class CustomRunner(object): 9 | """ 10 | This class manages the the background threads needed to fill 11 | a queue full of data. 12 | 13 | # Need to call the following code block after initializing everything 14 | self.sess.run(tf.global_variables_initializer()) 15 | 16 | if self.use_tf_threading: 17 | self.coord = tf.train.Coordinator() 18 | self.net.train_runner.start_p_threads(self.sess) 19 | tf.train.start_queue_runners(sess=self.sess, coord=self.coord) 20 | 21 | """ 22 | def __init__(self, arg_less_fn, override_dtypes=None, 23 | n_threads=1, n_processes=3, max_size=30): 24 | # arg_less_fn should be function that returns already ready data 25 | # in the form of numpy arrays. The shape of the output is 26 | # used to shape the output tensors. Should be ready to call at init_time 27 | # override_dtypes is the typing, default to numpy's encoding. 
28 | self.data_fn = arg_less_fn 29 | self.n_threads = n_threads 30 | self.n_processes = n_processes 31 | self.max_size = max_size 32 | self.use_pool = False 33 | 34 | # data_fn shouldn't take any argument, 35 | # just directly return the necessary data 36 | # set via the setter fn 37 | 38 | data = self.data_fn() 39 | self.inps = [] 40 | shapes, dtypes = [], [] 41 | for i, d in enumerate(data): 42 | inp = tf.placeholder(dtype=d.dtype, shape=[None] + list(d.shape[1:])) 43 | self.inps.append(inp) 44 | # remove batching index for individual element 45 | shapes.append(d.shape[1:]) 46 | dtypes.append(d.dtype) 47 | # The actual queue of data. 48 | self.tf_queue = tf.FIFOQueue(shapes=shapes, 49 | # override_dtypes or default 50 | dtypes=override_dtypes or dtypes, 51 | capacity=2000) 52 | 53 | # The symbolic operation to add data to the queue 54 | self.enqueue_op = self.tf_queue.enqueue_many(self.inps) 55 | 56 | def get_inputs(self, batch_size): 57 | """ 58 | Return's tensors containing a batch of images and labels 59 | 60 | if tf_queue has been closed this will raise a QueueBase exception 61 | killing the main process if a StopIteration is thrown in one of the 62 | data processes. 63 | """ 64 | return self.tf_queue.dequeue_up_to(tf.reduce_min([batch_size, self.tf_queue.size()])) 65 | 66 | def thread_main(self, sess, stop_event): 67 | """ 68 | Function run on alternate thread. Basically, keep adding data to the queue. 69 | """ 70 | tt_last_update = time.time() - 501 71 | count = 0 72 | tot_p_end = 0 73 | processes_all_done = False 74 | while not stop_event.isSet(): 75 | if tt_last_update + 500 < time.time(): 76 | t = time.time() 77 | # 500 seconds since last update 78 | #print("DataQueue Threading Update:") 79 | #print("TIME: " + str(t)) 80 | # MP.Queue says it is not thread safe and is not perfectly accurate. 
81 | # Just want to make sure there's no leakage and max_size 82 | # is safely hit 83 | #print("APPROX SIZE: %d" % self.queue.qsize()) 84 | #print("TOTAL FETCH ITERATIONS: %d" % count) 85 | tt_last_update = t 86 | count += 1 87 | if processes_all_done and self.queue.empty(): 88 | break 89 | try: 90 | data = self.queue.get(5) 91 | except Queue.Empty: 92 | continue 93 | 94 | if type(data) == type(StopIteration()): 95 | tot_p_end += 1 96 | if tot_p_end == self.n_processes: 97 | # Kill any processes 98 | # may need a lock here if multithreading 99 | processes_all_done = True 100 | #print("ALL PROCESSES DONE") 101 | continue 102 | 103 | fd = {} 104 | for i, d in enumerate(data): 105 | fd[self.inps[i]] = d 106 | sess.run(self.enqueue_op, feed_dict=fd) 107 | self.queue.close() 108 | 109 | def process_main(self, queue): 110 | # Scramble seed so it's not a copy of the parent's seed 111 | np.random.seed() 112 | # np.random.seed(1) 113 | try: 114 | while True: 115 | queue.put(self.data_fn()) 116 | except StopIteration as e: 117 | # Should only manually throw when want to close queue 118 | queue.put(e) 119 | #raise e 120 | return 121 | 122 | except Exception as e: 123 | queue.put(StopIteration()) 124 | #raise e 125 | return 126 | 127 | 128 | def set_data_fn(self, fn): 129 | self.data_fn = fn 130 | 131 | def start_p_threads(self, sess): 132 | """ Start background threads to feed queue """ 133 | self.processes = [] 134 | self.queue = mp.Queue(self.max_size) 135 | 136 | for n in range(self.n_processes): 137 | p = mp.Process(target=self.process_main, args=(self.queue,)) 138 | p.daemon = True # thread will close when parent quits 139 | p.start() 140 | self.processes.append(p) 141 | 142 | self.threads = [] 143 | self.thread_event_killer = [] 144 | for n in range(self.n_threads): 145 | kill_thread = threading.Event() 146 | self.thread_event_killer.append(kill_thread) 147 | 148 | t = threading.Thread(target=self.thread_main, args=(sess, kill_thread)) 149 | t.daemon = True # thread will close when parent quits 150 | t.start() 151 | self.threads.append(t) 152 | return self.processes + self.threads 153 | 154 | def kill_programs(self): 155 | # Release objects here if need to 156 | # threads should die in at least 5 seconds because 157 | # nothing blocks for more than 5 seconds 158 | 159 | # Sig term, kill first so no more data 160 | [p.terminate() for p in self.processes] 161 | [p.join() for p in self.processes] 162 | 163 | # kill second after purging 164 | [e.set() for e in self.thread_event_killer] 165 | -------------------------------------------------------------------------------- /lib/utils/util.py: -------------------------------------------------------------------------------- 1 | # Shared and common functions (declustering redundant code) 2 | import numpy as np, os 3 | import random, cv2 4 | import operator 5 | 6 | def get(link, save_as=False): 7 | import urllib 8 | base_dir = './tmp' 9 | assert type(link) == str, type(link) 10 | 11 | if not os.path.exists(base_dir): 12 | os.makedirs(base_dir) 13 | 14 | if save_as: 15 | save_path = os.path.join(base_dir, save_as) 16 | else: 17 | save_path = os.path.join(base_dir, 'tmp.png') 18 | 19 | urllib.urlretrieve(link, save_path) 20 | im = cv2.imread(save_path)[:,:,[2,1,0]] 21 | return im 22 | 23 | def softmax(X, theta = 1.0, axis = None): 24 | y = np.atleast_2d(X) 25 | if axis is None: 26 | axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1) 27 | y = y * float(theta) 28 | y = y - np.expand_dims(np.max(y, axis = axis), axis) 29 | y = np.exp(y) 30 | ax_sum = 
np.expand_dims(np.sum(y, axis = axis), axis) 31 | p = y / ax_sum 32 | if len(X.shape) == 1: p = p.flatten() 33 | return p 34 | 35 | def sort_dict(d, sort_by='value'): 36 | """ Sorts dictionary """ 37 | assert sort_by in ['value', 'key'], sort_by 38 | if sort_by == 'key': 39 | return sorted(d.items(), key=operator.itemgetter(0)) 40 | if sort_by == 'value': 41 | return sorted(d.items(), key=operator.itemgetter(1)) 42 | 43 | def random_crop(im, crop_size, return_crop_loc=False): 44 | """ Randomly crop """ 45 | h,w = np.shape(im)[:2] 46 | hSt = random.randint(0, h - crop_size[0]) 47 | wSt = random.randint(0, w - crop_size[1]) 48 | patch = im[hSt:hSt+crop_size[0], wSt:wSt+crop_size[1], :] 49 | assert tuple(np.shape(patch)[:2]) == tuple(crop_size) 50 | if return_crop_loc: 51 | return patch, (hSt, wSt) 52 | return patch 53 | 54 | def process_im(im): 55 | """ Normalizes images into the range [-1.0, 1.0] """ 56 | im = np.array(im) 57 | if np.max(im) <= 1: 58 | # PNG format 59 | im = (2.0 * im) - 1.0 60 | else: 61 | # JPEG format 62 | im = 2.0 * (im / 255.) - 1.0 63 | return im 64 | 65 | def deprocess_im(im, dtype=None): 66 | """ Map images in [-1.0, 1.0] back to [0, 255] """ 67 | im = np.array(im) 68 | return ((255.0 * (im + 1.0))/2.0).astype(dtype) 69 | 70 | def random_resize(im_a, im_b, same): 71 | valid_interps = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4, cv2.INTER_AREA] 72 | 73 | def get_param(): 74 | hr, wr = np.random.choice(np.linspace(0.5, 1.5, 11), 2) 75 | #hr, wr = np.random.uniform(low=0.5, high=1.5, size=2) 76 | interp = np.random.choice(valid_interps) 77 | return [hr, wr, interp] 78 | 79 | if same: 80 | if np.random.randint(2): 81 | a_par = get_param() 82 | im_a = cv2.resize(im_a, None, fx=a_par[0], fy=a_par[1], interpolation=a_par[2]) 83 | im_b = cv2.resize(im_b, None, fx=a_par[0], fy=a_par[1], interpolation=a_par[2]) 84 | else: 85 | a_par = get_param() 86 | im_a = cv2.resize(im_a, None, fx=a_par[0], fy=a_par[1], interpolation=a_par[2]) 87 | if np.random.randint(2): 88 | b_par = get_param() 89 | while np.all(a_par == b_par): 90 | b_par = get_param() 91 | im_b = cv2.resize(im_b, None, fx=b_par[0], fy=b_par[1], interpolation=b_par[2]) 92 | return im_a, im_b 93 | 94 | def random_jpeg(im_a, im_b, same): 95 | def get_param(): 96 | #jpeg_quality_a = np.random.randint(50, 100) # doesnt include 100 97 | return np.random.choice(np.linspace(50, 100, 11)) 98 | 99 | if same: 100 | if np.random.randint(2): 101 | a_par = get_param() 102 | _, enc_a = cv2.imencode('.jpg', im_a, [int(cv2.IMWRITE_JPEG_QUALITY), a_par]) 103 | im_a = cv2.imdecode(enc_a, 1) 104 | _, enc_b = cv2.imencode('.jpg', im_b, [int(cv2.IMWRITE_JPEG_QUALITY), a_par]) 105 | im_b = cv2.imdecode(enc_b, 1) 106 | else: 107 | a_par = get_param() 108 | _, enc_a = cv2.imencode('.jpg', im_a, [int(cv2.IMWRITE_JPEG_QUALITY), a_par]) 109 | im_a = cv2.imdecode(enc_a, 1) 110 | if np.random.randint(2): 111 | b_par = get_param() 112 | while np.all(a_par == b_par): 113 | b_par = get_param() 114 | _, enc_b = cv2.imencode('.jpg', im_b, [int(cv2.IMWRITE_JPEG_QUALITY), b_par]) 115 | im_b = cv2.imdecode(enc_b, 1) 116 | return im_a, im_b 117 | 118 | def gaussian_blur(im, kSz=None, sigma=1.0): 119 | # 5x5 kernel blur 120 | if kSz is None: 121 | kSz = np.ceil(3.0 * sigma) 122 | kSz = kSz + 1 if kSz % 2 == 0 else kSz 123 | kSz = max(kSz, 3) # minimum kernel size 124 | kSz = int(kSz) 125 | blur = cv2.GaussianBlur(im,(kSz,kSz), sigma) 126 | return blur 127 | 128 | def random_blur(im_a, im_b, same): 129 | # only square gaussian 
kernels 130 | def get_param(): 131 | kSz = (2 * np.random.randint(1, 8)) + 1 # [3, 15] 132 | sigma = np.random.choice(np.linspace(1.0, 5.0, 9)) 133 | #sigma = np.random.uniform(low=1.0, high=5.0, size=None) # 3 * sigma = kSz 134 | return [kSz, sigma] 135 | 136 | if same: 137 | if np.random.randint(2): 138 | a_par = get_param() 139 | im_a = cv2.GaussianBlur(im_a, (a_par[0], a_par[0]), a_par[1]) 140 | im_b = cv2.GaussianBlur(im_b, (a_par[0], a_par[0]), a_par[1]) 141 | else: 142 | a_par = get_param() 143 | im_a = cv2.GaussianBlur(im_a, (a_par[0], a_par[0]), a_par[1]) 144 | if np.random.randint(2): 145 | b_par = get_param() 146 | while np.all(a_par == b_par): 147 | b_par = get_param() 148 | im_b = cv2.GaussianBlur(im_b, (b_par[0], b_par[0]), b_par[1]) 149 | return im_a, im_b 150 | 151 | def random_noise(im): 152 | noise = np.random.randn(*np.shape(im)) * 10.0 153 | return np.array(np.clip(noise + im, 0, 255.0), dtype=np.uint8) 154 | -------------------------------------------------------------------------------- /load_models.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import init_paths 3 | import tensorflow as tf 4 | 5 | def initialize_exif(ckpt='', init=True, use_gpu=0): 6 | from models.exif import exif_net, exif_solver 7 | tf.reset_default_graph() 8 | net_args = {'num_classes':80+3, 9 | 'is_training':False, 10 | 'train_classifcation':True, 11 | 'freeze_base': True, 12 | 'im_size':128, 13 | 'batch_size':64, 14 | 'use_gpu':[use_gpu], 15 | 'use_tf_threading':False, 16 | 'learning_rate':1e-4} 17 | 18 | solver = exif_solver.initialize({'checkpoint':ckpt, 19 | 'use_exif_summary':False, 20 | 'init_summary':False, 21 | 'exp_name':'eval'}) 22 | if init: 23 | net = exif_net.initialize(net_args) 24 | solver.setup_net(net=net) 25 | return solver 26 | return solver, exif_net, net_args 27 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/exif/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/exif/exif_net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils import ops 3 | import copy, numpy as np 4 | from nets import resnet_v2, resnet_utils 5 | slim = tf.contrib.slim 6 | resnet_arg_scope = resnet_utils.resnet_arg_scope 7 | 8 | 9 | class EXIFNet(): 10 | """ 11 | Given a patch from an image try to classify which camera model it came from 12 | """ 13 | def __init__(self, num_classes=83, train_classifcation=False, 14 | use_tf_threading=False, train_runner=None, batch_size=None, 15 | im_size=128, is_training=True, freeze_base=False, use_gpu=0, 16 | learning_rate=1e-4, use_classify_with_feat=False): 17 | """ 18 | num_classes: Number of EXIF classes to predict 19 | classify_with_feat: If True, the classification layers use the output 20 | ResNet features along with EXIF predictions 21 | train_classifcation: Trains a classifer on top of the EXIF predictions 22 | use_tf_threading: Uses tf threading 23 | train_runner: The queue_runnner associated with tf threading 24 | batch_size: Batch size of the input variables. 
Must be specfied if using 25 | use_tf_threading and queue_runner 26 | im_size: image size to specify for the placeholder. 27 | Assumes square input for now. 28 | is_training: When False, use training statistics for normalization. 29 | This can be overwritten by the feed_dict 30 | use_gpu: List of GPUs to use to train 31 | freeze_base: Freezes all the layers except the classification. 32 | train_classifcation must be set to True to be useful. 33 | No loss is computed with self.label 34 | learning_rate: Learning rate for the optimizer 35 | """ 36 | 37 | self.use_gpu = use_gpu if type(use_gpu) is list else [use_gpu] 38 | self.im_size = im_size 39 | self.num_classes = num_classes 40 | self.use_classify_with_feat = use_classify_with_feat 41 | self.train_classifcation = train_classifcation 42 | self.freeze_base = freeze_base 43 | self._is_training = is_training # default value if not provided 44 | self.use_tf_threading = use_tf_threading 45 | self.train_runner = train_runner 46 | self.batch_size = batch_size 47 | self.learning_rate = learning_rate 48 | 49 | if self.use_tf_threading: 50 | assert self.batch_size is not None, self.batch_size 51 | assert self.batch_size % len(use_gpu) == 0, 'batch size should be modulo of the number of gpus' 52 | im_a, im_b, label = self.train_runner.get_inputs(self.batch_size) 53 | self.im_a = tf.placeholder_with_default(im_a, [None, self.im_size, self.im_size, 3]) 54 | self.im_b = tf.placeholder_with_default(im_b, [None, self.im_size, self.im_size, 3]) 55 | self.label = tf.placeholder_with_default(label, [None, self.num_classes]) 56 | self.cls_label = tf.placeholder(tf.float32, [None, 1]) 57 | else: 58 | self.im_a = tf.placeholder(tf.float32, [None, self.im_size, self.im_size, 3]) 59 | self.im_b = tf.placeholder(tf.float32, [None, self.im_size, self.im_size, 3]) 60 | self.label = tf.placeholder(tf.float32, [None, self.num_classes]) 61 | self.cls_label = tf.placeholder(tf.float32, [None, 1]) 62 | 63 | self.is_training = tf.placeholder_with_default(self._is_training, None) 64 | 65 | self.extract_features = self.extract_features_resnet50 66 | # if precomputing, need to populate via feed dict then compute with selecting indices 67 | # self.im_a_ind, self.im_b_ind 68 | self.precomputed_features = tf.placeholder(tf.float32, [None, 4096]) 69 | # Add second precompute_features_b for different patch gridding 70 | self.im_a_index = tf.placeholder(tf.int32, [None,]) 71 | self.im_b_index = tf.placeholder(tf.int32, [None,]) 72 | 73 | 74 | self.pc_im_a_feat = tf.map_fn(self.mapping_fn, self.im_a_index, dtype=tf.float32, 75 | infer_shape=False) 76 | self.pc_im_a_feat.set_shape((None, 4096)) 77 | self.pc_im_b_feat = tf.map_fn(self.mapping_fn, self.im_b_index, dtype=tf.float32, 78 | infer_shape=False) 79 | self.pc_im_b_feat.set_shape((None, 4096)) 80 | 81 | self.model() 82 | 83 | self.cls_variables = ops.extract_var(['classify']) 84 | 85 | return 86 | 87 | def get_variables(self): 88 | """ 89 | Returns only variables that are needed. If freeze_base is True, return 90 | only variables that start with 'classify' 91 | """ 92 | if self.freeze_base: 93 | var_list = ops.extract_var('classify') 94 | else: 95 | var_list = tf.trainable_variables() 96 | 97 | assert len(var_list) > 0, 'No variables are linked to the optimizer' 98 | return var_list 99 | 100 | def mapping_fn(self, v): 101 | # v is an index into precompute_features 102 | return self.precomputed_features[v] 103 | 104 | def model(self, preemptive_reuse=False): 105 | """ 106 | Initializes model to train. 107 | Supports multi-GPU. 
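The input batch is split evenly across the GPUs listed in use_gpu; each tower computes its own loss and gradients, and the per-tower gradients are averaged with ops.average_gradients before a single apply step.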
108 | Initializes the optimizer in the network graph. 109 | """ 110 | with tf.variable_scope(tf.get_variable_scope()): 111 | # Split data into n equal batches 112 | im_a_list = tf.split(self.im_a, len(self.use_gpu)) 113 | im_b_list = tf.split(self.im_b, len(self.use_gpu)) 114 | label_list = tf.split(self.label, len(self.use_gpu)) 115 | if self.train_classifcation: 116 | cls_label_list = tf.split(self.cls_label, len(self.use_gpu)) 117 | 118 | # We intialize the optimizer here 119 | self._opt = tf.train.AdamOptimizer(learning_rate=self.learning_rate) 120 | 121 | # Used to average 122 | all_grads = [] 123 | all_out = [] 124 | all_cls_out = [] 125 | all_loss = [] 126 | all_cls_loss = [] 127 | all_total_loss = [] 128 | 129 | for i, gpu_id in enumerate(self.use_gpu): 130 | print('Initializing graph on gpu %i' % gpu_id) 131 | with tf.device('/gpu:%d' % gpu_id): 132 | if preemptive_reuse: 133 | tf.get_variable_scope().reuse_variables() 134 | 135 | total_loss = 0 136 | im_a, im_b, label = im_a_list[i], im_b_list[i], label_list[i] 137 | if self.train_classifcation: 138 | cls_label = cls_label_list[i] 139 | 140 | with tf.name_scope('extract_feature_a'): 141 | im_a_feat = self.extract_features(im_a, name='feature_resnet') 142 | self.im_a_feat = im_a_feat 143 | 144 | with tf.name_scope('extract_feature_b'): 145 | im_b_feat = self.extract_features(im_b, name='feature_resnet', reuse=True) 146 | self.im_b_feat = im_b_feat 147 | 148 | with tf.name_scope('predict_same'): 149 | feat_ab = tf.concat([im_a_feat, im_b_feat], axis=-1) 150 | out = self.predict(feat_ab, name='predict') 151 | all_out.append(out) 152 | 153 | pc_feat_ab = tf.concat([self.pc_im_a_feat, self.pc_im_b_feat], axis=-1) 154 | pc_out = self.predict(pc_feat_ab, name='predict', reuse=True) 155 | 156 | if not self.freeze_base: 157 | with tf.name_scope('exif_loss'): 158 | loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=label, logits=out)) 159 | all_loss.append(loss) 160 | total_loss += loss 161 | 162 | if self.train_classifcation: 163 | with tf.name_scope('predict_same_image'): 164 | if self.use_classify_with_feat: 165 | cls_out = self.classify_with_feat(im_a_feat, im_b_feat, out, name='classify') 166 | pc_cls_out = self.classify_with_feat(pc_im_a_feat, pc_im_b_feat, pc_out, name='classify', 167 | reuse=True) 168 | else: 169 | cls_out = self.classify(out, name='classify') 170 | pc_cls_out = self.classify(pc_out, name='classify', reuse=True) 171 | all_cls_out.append(cls_out) 172 | with tf.name_scope('classification_loss'): 173 | cls_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=cls_label, logits=cls_out)) 174 | all_cls_loss.append(cls_loss) 175 | total_loss += cls_loss 176 | 177 | tf.get_variable_scope().reuse_variables() 178 | grad = self._opt.compute_gradients(total_loss, var_list=self.get_variables()) 179 | all_grads.append(grad) 180 | all_total_loss.append(total_loss) 181 | 182 | # Average the gradient and apply 183 | avg_grads = ops.average_gradients(all_grads) 184 | self.all_loss = all_loss 185 | self.avg_grads = avg_grads 186 | 187 | if not self.freeze_base: 188 | self.loss = tf.reduce_mean(all_loss) 189 | 190 | if self.train_classifcation: 191 | self.cls_loss = tf.reduce_mean(all_cls_loss) 192 | 193 | self.total_loss = tf.reduce_mean(all_total_loss) 194 | self.opt = self._opt.apply_gradients(avg_grads) # trains all variables for now 195 | 196 | # For logging results 197 | self.out = tf.concat(all_out, axis=0) 198 | self.pred = tf.sigmoid(self.out) 199 | 200 | if not self.freeze_base: 201 | 
correct_prediction = tf.equal(tf.round(self.pred), tf.round(self.label)) 202 | self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 203 | 204 | if self.train_classifcation: 205 | self.cls_out = tf.concat(all_cls_out, axis=0) 206 | self.cls_pred = tf.sigmoid(self.cls_out) 207 | self.pc_cls_pred = tf.sigmoid(pc_cls_out) 208 | self.cls_accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(self.cls_pred), tf.round(self.cls_label)), tf.float32)) 209 | return 210 | 211 | def extract_features_resnet50(self, im, name, is_training=True, reuse=False): 212 | use_global_pool = True 213 | num_classes = 4096 if use_global_pool else 512 214 | with tf.name_scope(name): 215 | with slim.arg_scope(resnet_utils.resnet_arg_scope()): 216 | out, _ = resnet_v2.resnet_v2_50(inputs=im, 217 | num_classes=num_classes, 218 | global_pool=use_global_pool, 219 | is_training=self.is_training, 220 | spatial_squeeze=True, 221 | scope='resnet_v2_50', 222 | reuse=reuse) 223 | 224 | if not use_global_pool: 225 | args = {'reuse':reuse, 'norm':None, 'activation':tf.nn.relu ,'padding':'SAME', 'is_training':is_training} 226 | out_args = copy.deepcopy(args) 227 | out_args['activation'] = None 228 | out = ops.conv(out, 1024, 3, 2, name='conv1', **args) 229 | out = slim.batch_norm(out) 230 | out = ops.conv(out, 2048, 3, 2, name='conv2', **args) 231 | out = slim.batch_norm(out) 232 | out = ops.conv(out, 4096, 3, 2, name='conv3', **out_args) 233 | out = slim.batch_norm(out) 234 | out = tf.squeeze(out, [1, 2], name='SpatialSqueeze') 235 | return out 236 | 237 | def predict(self, feat_ab, name, reuse=False): 238 | with tf.variable_scope(name, reuse=reuse): 239 | in_size = int(feat_ab.get_shape()[1]) 240 | out = slim.stack(feat_ab, slim.fully_connected, [4096, 2048, 1024], scope='fc') 241 | out = slim.fully_connected(out, self.num_classes, activation_fn=None, scope='fc_out') 242 | return out 243 | 244 | def classify_with_feat(self, im_a_feat, im_b_feat, affinity_pred, name, is_training=True, reuse=False): 245 | """ Predicts whether the 2 image patches are from the same image """ 246 | with tf.variable_scope(name, reuse=reuse): 247 | x = tf.concat([im_a_feat, im_b_feat, affinity_pred], axis=-1) 248 | x = slim.stack(x, slim.fully_connected, [4096, 1024], scope='fc') 249 | out = slim.fully_connected(x, 1, activation_fn=None, scope='fc_out') 250 | return out 251 | 252 | def classify(self, affinity_pred, name, is_training=True, reuse=False): 253 | """ Predicts whether the 2 image patches are from the same image """ 254 | with tf.variable_scope(name, reuse=reuse): 255 | x = slim.stack(affinity_pred, slim.fully_connected, [512], scope='fc') 256 | out = slim.fully_connected(x, 1, activation_fn=None, scope='fc_out') 257 | return out 258 | 259 | def initialize(args): 260 | return EXIFNet(**args) 261 | -------------------------------------------------------------------------------- /models/exif/exif_solver.py: -------------------------------------------------------------------------------- 1 | import os, sys, numpy as np, time 2 | import init_paths 3 | import tensorflow as tf 4 | import tensorflow.contrib.slim as slim 5 | from utils import ops, io 6 | import traceback 7 | from collections import deque 8 | 9 | class ExifSolver(object): 10 | def __init__(self, checkpoint=None, use_exif_summary=True, exp_name='no_name', init_summary=True): 11 | """ 12 | Args 13 | checkpoint: .ckpt file to initialize weights from 14 | use_exif_summary: EXIF accuracy are stored 15 | exp_name: ckpt and tb name prefix 16 | init_summary: will create TB 
files, will override use_exif_summary arg 17 | """ 18 | self.checkpoint = None if checkpoint in ['', None] else checkpoint 19 | self.exp_name = exp_name 20 | self._batch_size = 128 21 | self.use_exif_summary = use_exif_summary 22 | self.init_summary = init_summary 23 | self.ckpt_path = os.path.join('./ckpt', exp_name, exp_name) 24 | io.make_dir(self.ckpt_path) 25 | 26 | self.train_iterations = 10000000 27 | self.test_init = True 28 | self.show_iter = 20 29 | self.test_iter = 2000 30 | self.save_iter = 10000 31 | 32 | self.train_timer = deque(maxlen=10) 33 | return 34 | 35 | def setup_net(self, net): 36 | """ Links and setup loss and summary """ 37 | # Link network 38 | self.net = net 39 | 40 | # Initialize some basic things 41 | self.sess = tf.Session(config=ops.config(self.net.use_gpu)) 42 | if self.init_summary: 43 | self.train_writer = tf.summary.FileWriter(os.path.join('./tb', self.exp_name + '_train'), self.sess.graph) 44 | self.test_writer = tf.summary.FileWriter(os.path.join('./tb', self.exp_name + '_test')) 45 | self.setup_summary() 46 | self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=None) 47 | 48 | # Try to load checkpoint 49 | if self.checkpoint is not None: 50 | assert os.path.exists(self.checkpoint) or os.path.exists(self.checkpoint + '.index'), 'checkpoint does not exist' 51 | try: 52 | self.saver.restore(self.sess, self.checkpoint) 53 | self.i = io.parse_checkpoint(self.checkpoint) 54 | print 'Succesfully resuming from %s' % self.checkpoint 55 | except Exception: 56 | print traceback.format_exc() 57 | try: 58 | print 'Model and checkpoint did not match, attempting to restore only weights' 59 | variables_to_restore = ops.get_variables(self.checkpoint, exclude_scopes=['Adam']) 60 | restorer = tf.train.Saver(variables_to_restore) 61 | restorer.restore(self.sess, self.checkpoint) 62 | except Exception: 63 | print 'Model and checkpoint did not match, attempting to partially restore' 64 | self.sess.run(tf.global_variables_initializer()) 65 | # Make sure you correctly set exclude_scopes if you are finetuining models or extending it 66 | variables_to_restore = ops.get_variables(self.checkpoint, exclude_scopes=['classify']) #'resnet_v2_50/logits/', 'predict', 67 | restorer = tf.train.Saver(variables_to_restore) 68 | restorer.restore(self.sess, self.checkpoint) 69 | 70 | print 'Variables intitializing from scratch' 71 | for var in tf.trainable_variables(): 72 | if var not in variables_to_restore: 73 | print var 74 | print 'Succesfully restored %i variables' % len(variables_to_restore) 75 | self.i = 0 76 | else: 77 | print 'Initializing from scratch' 78 | self.i = 0 79 | self.sess.run(tf.global_variables_initializer()) 80 | self.start_i = self.i 81 | 82 | if self.net.use_tf_threading: 83 | self.coord = tf.train.Coordinator() 84 | self.net.train_runner.start_p_threads(self.sess) 85 | tf.train.start_queue_runners(sess=self.sess, coord=self.coord) 86 | return 87 | 88 | def setup_summary(self): 89 | """ Setup summary """ 90 | max_num_out = 2 91 | self.summary = [ 92 | tf.summary.image('input_a', self.net.im_a, max_outputs=max_num_out), 93 | tf.summary.image('input_b', self.net.im_b, max_outputs=max_num_out), 94 | tf.summary.scalar('total_loss', self.net.total_loss), 95 | tf.summary.scalar('learning_rate', self.net._opt._lr) 96 | ] 97 | if not self.net.freeze_base: 98 | self.summary.extend([tf.summary.scalar('exif_loss', self.net.loss), 99 | tf.summary.scalar('exif_accuracy', self.net.accuracy)]) 100 | if self.net.train_classifcation: 101 | 
self.summary.extend([tf.summary.scalar('cls_loss', self.net.cls_loss), 102 | tf.summary.scalar('cls_accuracy', self.net.cls_accuracy)]) 103 | if self.use_exif_summary: 104 | self.tag_holder = {tag:tf.placeholder(tf.float32) for tag in self.net.train_runner.tags} 105 | self.individual_summary = {tag:tf.summary.scalar('individual/' + tag, self.tag_holder[tag]) for tag in self.net.train_runner.tags} 106 | return 107 | 108 | def setup_data(self, data, data_fn=None): 109 | assert not self.net.use_tf_threading, "Using queue runner" 110 | self.data = data 111 | if data_fn is not None: 112 | self.data_fn = data_fn 113 | else: 114 | try: 115 | self.data_fn = self.data.exif_balanced_nextbatch 116 | except: 117 | self.data_fn = self.data.nextbatch 118 | 119 | assert self.data_fn is not None 120 | return 121 | 122 | def get_data(self, batch_size, split='train'): 123 | """ Make sure to pass None even if not using final classification """ 124 | assert self.data is not None 125 | if batch_size is None: 126 | batch_size = self._batch_size 127 | 128 | data_dict = self.data_fn(batch_size, split=split) 129 | 130 | args = {self.net.im_a:data_dict['im_a'], 131 | self.net.im_b:data_dict['im_b']} 132 | 133 | if 'cls_lbl' in data_dict: 134 | args[self.net.cls_label] = data_dict['cls_lbl'] 135 | 136 | if 'exif_lbl' in data_dict: 137 | args[self.net.label] = data_dict['exif_lbl'] 138 | return args 139 | 140 | def train(self): 141 | print 'Started training' 142 | while self.i < self.train_iterations: 143 | if self.test_init and self.i == self.start_i: 144 | print('Testing initialization') 145 | self.test(writer=self.test_writer) 146 | 147 | self._train() 148 | self.i += 1 149 | 150 | if self.i % self.show_iter == 0: 151 | self.show(writer=self.train_writer, phase='train') 152 | 153 | if self.i % self.test_iter == 0: 154 | self.test(writer=self.test_writer) 155 | 156 | if self.i % self.save_iter == 0 and self.i != self.start_i: 157 | io.make_ckpt(self.saver, self.sess, self.ckpt_path, self.i) 158 | return 159 | 160 | def _train(self): 161 | start_time = time.time() 162 | if self.net.use_tf_threading: 163 | self.sess.run(self.net.opt) 164 | else: 165 | args = self.get_data(self.net.batch_size, 'train') 166 | self.sess.run(self.net.opt, feed_dict=args) 167 | self.train_timer.append(time.time() - start_time) 168 | return 169 | 170 | def show(self, writer, phase='train'): 171 | if self.net.use_tf_threading: 172 | summary = self.sess.run(self.summary) 173 | else: 174 | args = self.get_data(self.net.batch_size, phase) 175 | summary = self.sess.run(self.summary, feed_dict=args) 176 | 177 | io.add_summary(writer, summary, self.i) 178 | 179 | io.show([['Train time', np.mean(list(self.train_timer))]], 180 | phase=phase, iter=self.i) 181 | return 182 | 183 | def test(self, writer): 184 | if self.use_exif_summary: 185 | exif_start = time.time() 186 | test_queue = self.net.train_runner.get_random_test(batch_size=self.net.batch_size) 187 | to_print = [] 188 | for i, (im_a_batch, im_b_batch, label_batch) in enumerate(test_queue): 189 | tag = self.net.train_runner.tags[i] 190 | output = self.sess.run(self.net.pred, feed_dict={self.net.im_a:im_a_batch, 191 | self.net.im_b:im_b_batch, 192 | self.net.label:label_batch, 193 | self.net.is_training:False}) 194 | 195 | tag_acc = 100.0 * (np.sum(np.round(output[:, i]) == label_batch[:, i])/float(self.net.batch_size)) 196 | summary = self.sess.run(self.individual_summary[tag], feed_dict={self.tag_holder[tag]:tag_acc}) 197 | io.add_summary(writer, [summary], self.i) 198 | 
to_print.append([tag, tag_acc]) 199 | io.show(to_print, phase='test', iter=self.i) 200 | print('EXIF test accuracy evaluation took %.2f seconds' % (time.time() - exif_start)) 201 | return 202 | 203 | def initialize(args): 204 | return ExifSolver(**args) 205 | -------------------------------------------------------------------------------- /ncuts_demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | import os, sys, numpy as np, ast 5 | import init_paths 6 | import load_models 7 | from lib.utils import benchmark_utils, util 8 | import tensorflow as tf 9 | import cv2, time, scipy, scipy.misc as scm, sklearn.cluster, skimage.io as skio, numpy as np, argparse 10 | import matplotlib.pyplot as plt 11 | from sklearn.cluster import DBSCAN 12 | 13 | import demo 14 | 15 | if __name__ == '__main__': 16 | plt.switch_backend('agg') 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--im_path", type=str, help="path_to_image") 19 | cfg = parser.parse_args() 20 | 21 | assert os.path.exists(cfg.im_path) 22 | 23 | imid = cfg.im_path.split('/')[-1].split('.')[0] 24 | save_path = os.path.join('./images', imid + '_ncuts_result.png') 25 | 26 | ckpt_path = './ckpt/exif_final/exif_final.ckpt' 27 | exif_demo = demo.Demo(ckpt_path=ckpt_path, use_gpu=0, quality=3.0, num_per_dim=30) 28 | 29 | print('Running image %s' % cfg.im_path) 30 | ms_st = time.time() 31 | im_path = cfg.im_path 32 | im1 = skio.imread(im_path)[:,:,:3].astype(np.float32) 33 | res = exif_demo.run(im1, use_ncuts=True, blue_high=True) 34 | print('MeanShift run time: %.3f' % (time.time() - ms_st)) 35 | 36 | plt.subplots(figsize=(16, 8)) 37 | plt.subplot(1, 3, 1) 38 | plt.title('Input Image') 39 | plt.imshow(im1.astype(np.uint8)) 40 | plt.axis('off') 41 | 42 | plt.subplot(1, 3, 2) 43 | plt.title('Cluster w/ MeanShift') 44 | plt.axis('off') 45 | plt.imshow(res[0], cmap='jet', vmin=0.0, vmax=1.0) 46 | 47 | plt.subplot(1, 3, 3) 48 | plt.title('Segment with NCuts') 49 | plt.axis('off') 50 | plt.imshow(res[1], vmin=0.0, vmax=1.0) 51 | 52 | plt.savefig(save_path) 53 | print('Result saved %s' % save_path) -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | -------------------------------------------------------------------------------- /nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 
16 | Residual networks (ResNets) were proposed in: 17 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 18 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 19 | More variants were introduced in: 20 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 21 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 22 | We can obtain different ResNet variants by changing the network depth, width, 23 | and form of residual unit. This module implements the infrastructure for 24 | building them. Concrete ResNet units and full ResNet networks are implemented in 25 | the accompanying resnet_v1.py and resnet_v2.py modules. 26 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 27 | implementation we subsample the output activations in the last residual unit of 28 | each block, instead of subsampling the input activations in the first residual 29 | unit of each block. The two implementations give identical results but our 30 | implementation is more memory efficient. 31 | """ 32 | from __future__ import absolute_import 33 | from __future__ import division 34 | from __future__ import print_function 35 | 36 | import collections 37 | import tensorflow as tf 38 | 39 | slim = tf.contrib.slim 40 | 41 | 42 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 43 | """A named tuple describing a ResNet block. 44 | Its parts are: 45 | scope: The scope of the `Block`. 46 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 47 | returns another `Tensor` with the output of the ResNet unit. 48 | args: A list of length equal to the number of units in the `Block`. The list 49 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 50 | block to serve as argument to unit_fn. 51 | """ 52 | 53 | 54 | def subsample(inputs, factor, scope=None): 55 | """Subsamples the input along the spatial dimensions. 56 | Args: 57 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 58 | factor: The subsampling factor. 59 | scope: Optional variable_scope. 60 | Returns: 61 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 62 | input, either intact (if factor == 1) or subsampled (if factor > 1). 63 | """ 64 | if factor == 1: 65 | return inputs 66 | else: 67 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 68 | 69 | 70 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 71 | """Strided 2-D convolution with 'SAME' padding. 72 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 73 | 'VALID' padding. 74 | Note that 75 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 76 | is equivalent to 77 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 78 | net = subsample(net, factor=stride) 79 | whereas 80 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 81 | is different when the input's height or width is even, which is why we add the 82 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 83 | Args: 84 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 85 | num_outputs: An integer, the number of output filters. 86 | kernel_size: An int with the kernel_size of the filters. 87 | stride: An integer, the output stride. 88 | rate: An integer, rate for atrous convolution. 89 | scope: Scope. 90 | Returns: 91 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 92 | the convolution output. 
93 | """ 94 | if stride == 1: 95 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 96 | padding='SAME', scope=scope) 97 | else: 98 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 99 | pad_total = kernel_size_effective - 1 100 | pad_beg = pad_total // 2 101 | pad_end = pad_total - pad_beg 102 | inputs = tf.pad(inputs, 103 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 104 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 105 | rate=rate, padding='VALID', scope=scope) 106 | 107 | 108 | @slim.add_arg_scope 109 | def stack_blocks_dense(net, blocks, output_stride=None, 110 | outputs_collections=None): 111 | """Stacks ResNet `Blocks` and controls output feature density. 112 | First, this function creates scopes for the ResNet in the form of 113 | 'block_name/unit_1', 'block_name/unit_2', etc. 114 | Second, this function allows the user to explicitly control the ResNet 115 | output_stride, which is the ratio of the input to output spatial resolution. 116 | This is useful for dense prediction tasks such as semantic segmentation or 117 | object detection. 118 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 119 | factor of 2 when transitioning between consecutive ResNet blocks. This results 120 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 121 | half the nominal network stride (e.g., output_stride=4), then we compute 122 | responses twice. 123 | Control of the output feature density is implemented by atrous convolution. 124 | Args: 125 | net: A `Tensor` of size [batch, height, width, channels]. 126 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 127 | element is a ResNet `Block` object describing the units in the `Block`. 128 | output_stride: If `None`, then the output will be computed at the nominal 129 | network stride. If output_stride is not `None`, it specifies the requested 130 | ratio of input to output spatial resolution, which needs to be equal to 131 | the product of unit strides from the start up to some level of the ResNet. 132 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 133 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 134 | is equivalent to output_stride=24). 135 | outputs_collections: Collection to add the ResNet block outputs. 136 | Returns: 137 | net: Output tensor with stride equal to the specified output_stride. 138 | Raises: 139 | ValueError: If the target output_stride is not valid. 140 | """ 141 | # The current_stride variable keeps track of the effective stride of the 142 | # activations. This allows us to invoke atrous convolution whenever applying 143 | # the next residual unit would result in the activations having stride larger 144 | # than the target output_stride. 145 | current_stride = 1 146 | 147 | # The atrous convolution rate parameter. 148 | rate = 1 149 | 150 | for block in blocks: 151 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 152 | for i, unit in enumerate(block.args): 153 | if output_stride is not None and current_stride > output_stride: 154 | raise ValueError('The target output_stride cannot be reached.') 155 | 156 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 157 | # If we have reached the target output_stride, then we need to employ 158 | # atrous convolution with stride=1 and multiply the atrous rate by the 159 | # current unit's stride for use in subsequent layers. 
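# Concretely: once current_stride has caught up with output_stride, every remaining unit runs with stride 1 at the accumulated atrous rate, and each unit whose nominal stride is 2 doubles that rate (1 -> 2 -> 4 -> ...), so the spatial resolution stops shrinking while the receptive field keeps growing.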
160 | if output_stride is not None and current_stride == output_stride: 161 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 162 | rate *= unit.get('stride', 1) 163 | 164 | else: 165 | net = block.unit_fn(net, rate=1, **unit) 166 | current_stride *= unit.get('stride', 1) 167 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 168 | 169 | if output_stride is not None and current_stride != output_stride: 170 | raise ValueError('The target output_stride cannot be reached.') 171 | 172 | return net 173 | 174 | 175 | def resnet_arg_scope(weight_decay=0.0001, 176 | batch_norm_decay=0.997, 177 | batch_norm_epsilon=1e-5, 178 | batch_norm_scale=True, 179 | activation_fn=tf.nn.relu, 180 | use_batch_norm=True): 181 | """Defines the default ResNet arg scope. 182 | TODO(gpapan): The batch-normalization related default values above are 183 | appropriate for use in conjunction with the reference ResNet models 184 | released at https://github.com/KaimingHe/deep-residual-networks. When 185 | training ResNets from scratch, they might need to be tuned. 186 | Args: 187 | weight_decay: The weight decay to use for regularizing the model. 188 | batch_norm_decay: The moving average decay when estimating layer activation 189 | statistics in batch normalization. 190 | batch_norm_epsilon: Small constant to prevent division by zero when 191 | normalizing activations by their variance in batch normalization. 192 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 193 | activations in the batch normalization layer. 194 | activation_fn: The activation function which is used in ResNet. 195 | use_batch_norm: Whether or not to use batch normalization. 196 | Returns: 197 | An `arg_scope` to use for the resnet models. 198 | """ 199 | batch_norm_params = { 200 | 'decay': batch_norm_decay, 201 | 'epsilon': batch_norm_epsilon, 202 | 'scale': batch_norm_scale, 203 | 'updates_collections': None, 204 | } 205 | 206 | with slim.arg_scope( 207 | [slim.conv2d], 208 | weights_regularizer=slim.l2_regularizer(weight_decay), 209 | weights_initializer=slim.variance_scaling_initializer(), 210 | activation_fn=activation_fn, 211 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 212 | normalizer_params=batch_norm_params): 213 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 214 | # The following implies padding='SAME' for pool1, which makes feature 215 | # alignment easier for dense prediction tasks. This is also used in 216 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 217 | # code of 'Deep Residual Learning for Image Recognition' uses 218 | # padding='VALID' for pool1. You can switch to that choice by setting 219 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 220 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 221 | return arg_sc 222 | -------------------------------------------------------------------------------- /nets/resnet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the original form of Residual Networks. 16 | The 'v1' residual networks (ResNets) implemented in this module were proposed 17 | by: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | Other variants were introduced in: 21 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 22 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 23 | The networks defined in this module utilize the bottleneck building block of 24 | [1] with projection shortcuts only for increasing depths. They employ batch 25 | normalization *after* every weight layer. This is the architecture used by 26 | MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and 27 | ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1' 28 | architecture and the alternative 'v2' architecture of [2] which uses batch 29 | normalization *before* every weight layer in the so-called full pre-activation 30 | units. 31 | Typical use: 32 | from tensorflow.contrib.slim.nets import resnet_v1 33 | ResNet-101 for image classification into 1000 classes: 34 | # inputs has shape [batch, 224, 224, 3] 35 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 36 | net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False) 37 | ResNet-101 for semantic segmentation into 21 classes: 38 | # inputs has shape [batch, 513, 513, 3] 39 | with slim.arg_scope(resnet_v1.resnet_arg_scope()): 40 | net, end_points = resnet_v1.resnet_v1_101(inputs, 41 | 21, 42 | is_training=False, 43 | global_pool=False, 44 | output_stride=16) 45 | """ 46 | from __future__ import absolute_import 47 | from __future__ import division 48 | from __future__ import print_function 49 | 50 | import tensorflow as tf 51 | 52 | from nets import resnet_utils 53 | 54 | 55 | resnet_arg_scope = resnet_utils.resnet_arg_scope 56 | slim = tf.contrib.slim 57 | 58 | 59 | @slim.add_arg_scope 60 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 61 | outputs_collections=None, scope=None): 62 | """Bottleneck residual unit variant with BN after convolutions. 63 | This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for 64 | its definition. Note that we use here the bottleneck variant which has an 65 | extra bottleneck layer. 66 | When putting together two consecutive ResNet blocks that use this unit, one 67 | should use stride = 2 in the last unit of the first block. 68 | Args: 69 | inputs: A tensor of size [batch, height, width, channels]. 70 | depth: The depth of the ResNet unit output. 71 | depth_bottleneck: The depth of the bottleneck layers. 72 | stride: The ResNet unit's stride. Determines the amount of downsampling of 73 | the units output compared to its input. 74 | rate: An integer, rate for atrous convolution. 75 | outputs_collections: Collection to add the ResNet unit output. 76 | scope: Optional variable_scope. 77 | Returns: 78 | The ResNet unit's output. 
79 | """ 80 | with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc: 81 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 82 | if depth == depth_in: 83 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 84 | else: 85 | shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride, 86 | activation_fn=None, scope='shortcut') 87 | 88 | residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1, 89 | scope='conv1') 90 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 91 | rate=rate, scope='conv2') 92 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 93 | activation_fn=None, scope='conv3') 94 | 95 | output = tf.nn.relu(shortcut + residual) 96 | 97 | return slim.utils.collect_named_outputs(outputs_collections, 98 | sc.original_name_scope, 99 | output) 100 | 101 | 102 | def resnet_v1(inputs, 103 | blocks, 104 | num_classes=None, 105 | is_training=True, 106 | global_pool=True, 107 | output_stride=None, 108 | include_root_block=True, 109 | spatial_squeeze=True, 110 | reuse=None, 111 | scope=None): 112 | """Generator for v1 ResNet models. 113 | This function generates a family of ResNet v1 models. See the resnet_v1_*() 114 | methods for specific model instantiations, obtained by selecting different 115 | block instantiations that produce ResNets of various depths. 116 | Training for image classification on Imagenet is usually done with [224, 224] 117 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 118 | block for the ResNets defined in [1] that have nominal stride equal to 32. 119 | However, for dense prediction tasks we advise that one uses inputs with 120 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 121 | this case the feature maps at the ResNet output will have spatial shape 122 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 123 | and corners exactly aligned with the input image corners, which greatly 124 | facilitates alignment of the features to the image. Using as input [225, 225] 125 | images results in [8, 8] feature maps at the output of the last ResNet block. 126 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 127 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 128 | have nominal stride equal to 32 and a good choice in FCN mode is to use 129 | output_stride=16 in order to increase the density of the computed features at 130 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 131 | Args: 132 | inputs: A tensor of size [batch, height_in, width_in, channels]. 133 | blocks: A list of length equal to the number of ResNet blocks. Each element 134 | is a resnet_utils.Block object describing the units in the block. 135 | num_classes: Number of predicted classes for classification tasks. If None 136 | we return the features before the logit layer. 137 | is_training: whether is training or not. 138 | global_pool: If True, we perform global average pooling before computing the 139 | logits. Set to True for image classification, False for dense prediction. 140 | output_stride: If None, then the output will be computed at the nominal 141 | network stride. If output_stride is not None, it specifies the requested 142 | ratio of input to output spatial resolution. 143 | include_root_block: If True, include the initial convolution followed by 144 | max-pooling, if False excludes it. 
145 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 146 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 147 | reuse: whether or not the network and its variables should be reused. To be 148 | able to reuse 'scope' must be given. 149 | scope: Optional variable_scope. 150 | Returns: 151 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 152 | If global_pool is False, then height_out and width_out are reduced by a 153 | factor of output_stride compared to the respective height_in and width_in, 154 | else both height_out and width_out equal one. If num_classes is None, then 155 | net is the output of the last ResNet block, potentially after global 156 | average pooling. If num_classes is not None, net contains the pre-softmax 157 | activations. 158 | end_points: A dictionary from components of the network to the corresponding 159 | activation. 160 | Raises: 161 | ValueError: If the target output_stride is not valid. 162 | """ 163 | with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc: 164 | end_points_collection = sc.name + '_end_points' 165 | with slim.arg_scope([slim.conv2d, bottleneck, 166 | resnet_utils.stack_blocks_dense], 167 | outputs_collections=end_points_collection): 168 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 169 | net = inputs 170 | if include_root_block: 171 | if output_stride is not None: 172 | if output_stride % 4 != 0: 173 | raise ValueError('The output_stride needs to be a multiple of 4.') 174 | output_stride /= 4 175 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 176 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 177 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 178 | if global_pool: 179 | # Global average pooling. 180 | net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True) 181 | if num_classes is not None: 182 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 183 | normalizer_fn=None, scope='logits') 184 | if spatial_squeeze: 185 | logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 186 | # Convert end_points_collection into a dictionary of end_points. 187 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 188 | if num_classes is not None: 189 | end_points['predictions'] = slim.softmax(logits, scope='predictions') 190 | return logits, end_points 191 | resnet_v1.default_image_size = 224 192 | 193 | 194 | def resnet_v1_50(inputs, 195 | num_classes=None, 196 | is_training=True, 197 | global_pool=True, 198 | output_stride=None, 199 | reuse=None, 200 | scope='resnet_v1_50'): 201 | """ResNet-50 model of [1]. 
See resnet_v1() for arg and return description.""" 202 | blocks = [ 203 | resnet_utils.Block( 204 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 205 | resnet_utils.Block( 206 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 207 | resnet_utils.Block( 208 | 'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]), 209 | resnet_utils.Block( 210 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 211 | ] 212 | return resnet_v1(inputs, blocks, num_classes, is_training, 213 | global_pool=global_pool, output_stride=output_stride, 214 | include_root_block=True, reuse=reuse, scope=scope) 215 | resnet_v1_50.default_image_size = resnet_v1.default_image_size 216 | 217 | 218 | def resnet_v1_101(inputs, 219 | num_classes=None, 220 | is_training=True, 221 | global_pool=True, 222 | output_stride=None, 223 | reuse=None, 224 | scope='resnet_v1_101'): 225 | """ResNet-101 model of [1]. See resnet_v1() for arg and return description.""" 226 | blocks = [ 227 | resnet_utils.Block( 228 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 229 | resnet_utils.Block( 230 | 'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]), 231 | resnet_utils.Block( 232 | 'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]), 233 | resnet_utils.Block( 234 | 'block4', bottleneck, [(2048, 512, 1)] * 3) 235 | ] 236 | return resnet_v1(inputs, blocks, num_classes, is_training, 237 | global_pool=global_pool, output_stride=output_stride, 238 | include_root_block=True, reuse=reuse, scope=scope) 239 | resnet_v1_101.default_image_size = resnet_v1.default_image_size 240 | 241 | 242 | def resnet_v1_152(inputs, 243 | num_classes=None, 244 | is_training=True, 245 | global_pool=True, 246 | output_stride=None, 247 | reuse=None, 248 | scope='resnet_v1_152'): 249 | """ResNet-152 model of [1]. See resnet_v1() for arg and return description.""" 250 | blocks = [ 251 | resnet_utils.Block( 252 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 253 | resnet_utils.Block( 254 | 'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]), 255 | resnet_utils.Block( 256 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 257 | resnet_utils.Block( 258 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 259 | return resnet_v1(inputs, blocks, num_classes, is_training, 260 | global_pool=global_pool, output_stride=output_stride, 261 | include_root_block=True, reuse=reuse, scope=scope) 262 | resnet_v1_152.default_image_size = resnet_v1.default_image_size 263 | 264 | 265 | def resnet_v1_200(inputs, 266 | num_classes=None, 267 | is_training=True, 268 | global_pool=True, 269 | output_stride=None, 270 | reuse=None, 271 | scope='resnet_v1_200'): 272 | """ResNet-200 model of [2]. 
See resnet_v1() for arg and return description.""" 273 | blocks = [ 274 | resnet_utils.Block( 275 | 'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]), 276 | resnet_utils.Block( 277 | 'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]), 278 | resnet_utils.Block( 279 | 'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]), 280 | resnet_utils.Block( 281 | 'block4', bottleneck, [(2048, 512, 1)] * 3)] 282 | return resnet_v1(inputs, blocks, num_classes, is_training, 283 | global_pool=global_pool, output_stride=output_stride, 284 | include_root_block=True, reuse=reuse, scope=scope) 285 | resnet_v1_200.default_image_size = resnet_v1.default_image_size 286 | -------------------------------------------------------------------------------- /nets/resnet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the preactivation form of Residual Networks. 16 | Residual networks (ResNets) were originally proposed in: 17 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 18 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 19 | The full preactivation 'v2' ResNet variant implemented in this module was 20 | introduced by: 21 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 22 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 23 | The key difference of the full preactivation 'v2' variant compared to the 24 | 'v1' variant in [1] is the use of batch normalization before every weight layer. 25 | Typical use: 26 | from tensorflow.contrib.slim.nets import resnet_v2 27 | ResNet-101 for image classification into 1000 classes: 28 | # inputs has shape [batch, 224, 224, 3] 29 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 30 | net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False) 31 | ResNet-101 for semantic segmentation into 21 classes: 32 | # inputs has shape [batch, 513, 513, 3] 33 | with slim.arg_scope(resnet_v2.resnet_arg_scope(is_training)): 34 | net, end_points = resnet_v2.resnet_v2_101(inputs, 35 | 21, 36 | is_training=False, 37 | global_pool=False, 38 | output_stride=16) 39 | """ 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import tensorflow as tf 45 | 46 | from nets import resnet_utils 47 | 48 | slim = tf.contrib.slim 49 | resnet_arg_scope = resnet_utils.resnet_arg_scope 50 | 51 | 52 | @slim.add_arg_scope 53 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 54 | outputs_collections=None, scope=None): 55 | """Bottleneck residual unit variant with BN before convolutions. 56 | This is the full preactivation residual unit variant proposed in [2]. See 57 | Fig. 1(b) of [2] for its definition. 
Note that we use here the bottleneck 58 | variant which has an extra bottleneck layer. 59 | When putting together two consecutive ResNet blocks that use this unit, one 60 | should use stride = 2 in the last unit of the first block. 61 | Args: 62 | inputs: A tensor of size [batch, height, width, channels]. 63 | depth: The depth of the ResNet unit output. 64 | depth_bottleneck: The depth of the bottleneck layers. 65 | stride: The ResNet unit's stride. Determines the amount of downsampling of 66 | the units output compared to its input. 67 | rate: An integer, rate for atrous convolution. 68 | outputs_collections: Collection to add the ResNet unit output. 69 | scope: Optional variable_scope. 70 | Returns: 71 | The ResNet unit's output. 72 | """ 73 | with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc: 74 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 75 | preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact') 76 | if depth == depth_in: 77 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 78 | else: 79 | shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride, 80 | normalizer_fn=None, activation_fn=None, 81 | scope='shortcut') 82 | 83 | residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1, 84 | scope='conv1') 85 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 86 | rate=rate, scope='conv2') 87 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 88 | normalizer_fn=None, activation_fn=None, 89 | scope='conv3') 90 | 91 | output = shortcut + residual 92 | 93 | return slim.utils.collect_named_outputs(outputs_collections, 94 | sc.original_name_scope, 95 | output) 96 | 97 | 98 | def resnet_v2(inputs, 99 | blocks, 100 | num_classes=None, 101 | is_training=True, 102 | global_pool=True, 103 | output_stride=None, 104 | include_root_block=True, 105 | spatial_squeeze=True, 106 | reuse=None, 107 | scope=None): 108 | """Generator for v2 (preactivation) ResNet models. 109 | This function generates a family of ResNet v2 models. See the resnet_v2_*() 110 | methods for specific model instantiations, obtained by selecting different 111 | block instantiations that produce ResNets of various depths. 112 | Training for image classification on Imagenet is usually done with [224, 224] 113 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 114 | block for the ResNets defined in [1] that have nominal stride equal to 32. 115 | However, for dense prediction tasks we advise that one uses inputs with 116 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 117 | this case the feature maps at the ResNet output will have spatial shape 118 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 119 | and corners exactly aligned with the input image corners, which greatly 120 | facilitates alignment of the features to the image. Using as input [225, 225] 121 | images results in [8, 8] feature maps at the output of the last ResNet block. 122 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 123 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 124 | have nominal stride equal to 32 and a good choice in FCN mode is to use 125 | output_stride=16 in order to increase the density of the computed features at 126 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 127 | Args: 128 | inputs: A tensor of size [batch, height_in, width_in, channels]. 
    blocks: A list of length equal to the number of ResNet blocks. Each element
      is a resnet_utils.Block object describing the units in the block.
    num_classes: Number of predicted classes for classification tasks. If None
      we return the features before the logit layer.
    is_training: whether batch_norm layers are in training mode.
    global_pool: If True, we perform global average pooling before computing the
      logits. Set to True for image classification, False for dense prediction.
    output_stride: If None, then the output will be computed at the nominal
      network stride. If output_stride is not None, it specifies the requested
      ratio of input to output spatial resolution.
    include_root_block: If True, include the initial convolution followed by
      max-pooling, if False excludes it. If excluded, `inputs` should be the
      results of an activation-less convolution.
    spatial_squeeze: if True, logits is of shape [B, C], if False logits is
      of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
      To use this parameter, the input images must be smaller than 300x300
      pixels, in which case the output logit layer does not contain spatial
      information and can be removed.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse, 'scope' must be given.
    scope: Optional variable_scope.
  Returns:
    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
      If global_pool is False, then height_out and width_out are reduced by a
      factor of output_stride compared to the respective height_in and width_in,
      else both height_out and width_out equal one. If num_classes is None, then
      net is the output of the last ResNet block, potentially after global
      average pooling. If num_classes is not None, net contains the pre-softmax
      activations.
    end_points: A dictionary from components of the network to the corresponding
      activation.
  Raises:
    ValueError: If the target output_stride is not valid.
  """
  with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.name + '_end_points'
    with slim.arg_scope([slim.conv2d, bottleneck,
                         resnet_utils.stack_blocks_dense],
                        outputs_collections=end_points_collection):
      with slim.arg_scope([slim.batch_norm], is_training=is_training):
        net = inputs
        if include_root_block:
          if output_stride is not None:
            if output_stride % 4 != 0:
              raise ValueError('The output_stride needs to be a multiple of 4.')
            output_stride /= 4
          # We do not include batch normalization or activation functions in
          # conv1 because the first ResNet unit will perform these. Cf.
          # Appendix of [2].
          with slim.arg_scope([slim.conv2d],
                              activation_fn=None, normalizer_fn=None):
            net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
          net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
        net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
        # This is needed because the pre-activation variant does not have batch
        # normalization or activation functions in the residual unit output. See
        # Appendix of [2].
        net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
        if global_pool:
          # Global average pooling.
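          # reduce_mean over the spatial axes [1, 2] averages each feature map
          # over its full extent, i.e. a pooling window the size of the feature
          # map itself; keep_dims=True preserves the [batch, 1, 1, channels]
          # layout expected by the 1x1 'logits' convolution below.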
          net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
        if num_classes is not None:
          net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                            normalizer_fn=None, scope='logits')
          if spatial_squeeze:
            net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
        # Convert end_points_collection into a dictionary of end_points.
        end_points = slim.utils.convert_collection_to_dict(
            end_points_collection)
        if num_classes is not None:
          end_points['predictions'] = slim.softmax(net, scope='predictions')
        return net, end_points
resnet_v2.default_image_size = 224


def resnet_v2_block(scope, base_depth, num_units, stride):
  """Helper function for creating a resnet_v2 bottleneck block.
  Args:
    scope: The scope of the block.
    base_depth: The depth of the bottleneck layer for each unit.
    num_units: The number of units in the block.
    stride: The stride of the block, implemented as a stride in the last unit.
      All other units have stride=1.
  Returns:
    A resnet_v2 bottleneck block.
  """
  return resnet_utils.Block(scope, bottleneck, [{
      'depth': base_depth * 4,
      'depth_bottleneck': base_depth,
      'stride': 1
  }] * (num_units - 1) + [{
      'depth': base_depth * 4,
      'depth_bottleneck': base_depth,
      'stride': stride
  }])


def resnet_v2_50(inputs,
                 num_classes=None,
                 is_training=True,
                 global_pool=True,
                 output_stride=None,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='resnet_v2_50'):
  """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_50.default_image_size = resnet_v2.default_image_size


def resnet_v2_101(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=True,
                  reuse=None,
                  scope='resnet_v2_101'):
  """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=23, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_101.default_image_size = resnet_v2.default_image_size


def resnet_v2_152(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=True,
                  reuse=None,
                  scope='resnet_v2_152'):
  """ResNet-152 model of [1].
  See resnet_v2() for arg and return description."""
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=8, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_152.default_image_size = resnet_v2.default_image_size


def resnet_v2_200(inputs,
                  num_classes=None,
                  is_training=True,
                  global_pool=True,
                  output_stride=None,
                  spatial_squeeze=True,
                  reuse=None,
                  scope='resnet_v2_200'):
  """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
  blocks = [
      resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
      resnet_v2_block('block2', base_depth=128, num_units=24, stride=2),
      resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
      resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
  ]
  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                   global_pool=global_pool, output_stride=output_stride,
                   include_root_block=True, spatial_squeeze=spatial_squeeze,
                   reuse=reuse, scope=scope)
resnet_v2_200.default_image_size = resnet_v2.default_image_size

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.14.0
opencv_python==3.4.0.12
tensorflow_gpu>=1.2.0
scikit-image==0.13.1
scipy==1.0.0
matplotlib==2.1.2
scikit_learn==0.19.1
Pillow==5.0.0

--------------------------------------------------------------------------------
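
The module docstring of nets/resnet_v2.py above sketches typical classification
use; the following is a minimal sketch of that pattern under the TensorFlow 1.x
/ tf.contrib.slim stack pinned in requirements.txt (the placeholder shape is
illustrative, and the `nets` package is assumed to be importable, e.g. when
running from the repository root):

import tensorflow as tf
from nets import resnet_v2

slim = tf.contrib.slim

# ResNet-101 for 1000-way classification, as in the module docstring.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
  logits, end_points = resnet_v2.resnet_v2_101(images, num_classes=1000,
                                               is_training=False)
# With spatial_squeeze=True (the default), logits has shape [batch, 1000];
# end_points['predictions'] holds the softmax probabilities.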
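
For dense prediction, the resnet_v2() docstring recommends fully-convolutional
mode (global_pool=False) with output_stride=16 and inputs whose spatial size is
a multiple of 32 plus 1. A hedged sketch mirroring the docstring's 513x513
segmentation example; spatial_squeeze is disabled here because the logits keep
their spatial extent when global pooling is off:

import tensorflow as tf
from nets import resnet_v2

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [1, 513, 513, 3])
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
  net, end_points = resnet_v2.resnet_v2_101(images,
                                            num_classes=21,
                                            is_training=False,
                                            global_pool=False,
                                            output_stride=16,
                                            spatial_squeeze=False)
# Per the resnet_v2() docstring, the output spatial size is
# (height - 1) / output_stride + 1, so (513 - 1) / 16 + 1 = 33 and
# net has static shape [1, 33, 33, 21].
print(net.get_shape())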
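
resnet_v2_block() packs the per-unit arguments for a whole block; a small
sketch of what it expands to (this assumes resnet_utils.Block exposes its unit
list as an `args` field, consistent with how Block is constructed in this file
but not shown in resnet_utils.py here):

from nets import resnet_v2

# Three bottleneck units: (num_units - 1) stride-1 units followed by one unit
# that carries the block's stride; output depth is 4x the bottleneck depth.
block = resnet_v2.resnet_v2_block('block1', base_depth=64, num_units=3, stride=2)
# block.args == [{'depth': 256, 'depth_bottleneck': 64, 'stride': 1}] * 2 + \
#               [{'depth': 256, 'depth_bottleneck': 64, 'stride': 2}]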