├── helper
    ├── white.png
    └── graycolor.png
├── imgs
    └── sample.jpg
├── LICENSE
├── README.MD
├── genIntSegPairs.m
├── IntSeg_GUI.py
├── our_func_cvpr18.py
└── IntSeg_Train.py

/helper/white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/Intseg/HEAD/helper/white.png
--------------------------------------------------------------------------------
/imgs/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/Intseg/HEAD/imgs/sample.jpg
--------------------------------------------------------------------------------
/helper/graycolor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isl-org/Intseg/HEAD/helper/graycolor.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 IntelVCL
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # DISCONTINUATION OF PROJECT #
2 | This project will no longer be maintained by Intel.
3 | Intel has ceased development and contributions including, but not limited to, maintenance, bug fixes, new releases, or updates, to this project.
4 | Intel no longer accepts patches to this project.
5 | If you have an ongoing need to use this project, are interested in independently developing it, or would like to maintain patches for the open source software community, please create your own fork of this project.
6 |
7 | # Interactive Image Segmentation with Latent Diversity
8 | This is a TensorFlow implementation of Interactive Image Segmentation with Latent Diversity. It receives positive and negative clicks and produces segmentation masks.
9 |
10 | ## Setup
11 |
12 | ### Requirement
13 | Required Python libraries: TensorFlow (>=1.3) + OpenCV + SciPy + NumPy.
14 |
15 | Tested on Ubuntu 16.04 LTS + Intel i7 CPU + Nvidia Titan X (Pascal) with CUDA (>=8.0) and cuDNN (>=6.0).
16 |
17 | ### Quick Start
18 | 1. Clone this repository.
19 | 2. Download the pre-trained model from this [link](https://drive.google.com/open?id=1u96zu0VyNpy-1VL90EbriN74hGaBBK08). Unzip it and put the files into the "Models" folder.
20 | 3. Run "IntSeg_GUI.py", and a window will show up.
21 | 4. Open an image (one sample image is provided in "imgs"); the image will be shown at the top-left.
22 | 5. Use the mouse to input positive (left) and/or negative (right) clicks.
23 |
24 | The segmentation mask will be shown at the bottom-left, and the overlay image at the top-right. The bottom-right window can be ignored for now. The click inputs and segmentation results will be saved in the "res" folder, in a subfolder named after a generated user ID.
25 |
26 | Note that the GUI is designed for demonstration only, and thus it is not optimized for images with arbitrary resolution.
27 |
28 | ### Training
29 |
30 | The MATLAB script "genIntSegPairs.m" is provided for automatically generating positive/negative clicks. Note that the synthesizing strategies follow "Deep interactive object selection" ([arxiv link](https://arxiv.org/abs/1603.04042)).
31 |
32 | With the generated positive/negative clicks, run "IntSeg_Train.py" to start training after "im_path" and "seg_path" are properly set.
33 |
34 | The current implementation processes the SBD dataset ([link](http://home.bharathh.info/pubs/codes/SBD/download.html)), and it can be modified to process any dataset with image and instance mask pairs.
35 |
36 | ## Citation
37 | If you use our code for research, please cite our paper:
38 |
39 | Zhuwen Li, Qifeng Chen, and Vladlen Koltun. Interactive Image Segmentation with Latent Diversity. In CVPR 2018.
40 |
41 | ## Question
42 | If you have any questions or requests about the code and data, please email me at lzhuwen@gmail.com.
43 | 44 | ## License 45 | MIT License 46 | -------------------------------------------------------------------------------- /genIntSegPairs.m: -------------------------------------------------------------------------------- 1 | %% 2 | 3 | % The path to store the synthesized clicks 4 | 5 | savepath = ['./train']; 6 | 7 | %% 8 | % get the names of all training images 9 | imgset='train'; 10 | ids = textread(['./' imgset '.txt'],'%s'); 11 | num_img = length(ids); 12 | d_step = 10; 13 | d_margin = 5; 14 | % 15 | for i = 1:num_img 16 | if ~isdir([savepath '/' ids{i} '/objs']) 17 | mkdir([savepath '/' ids{i} '/objs']); 18 | end 19 | if ~isdir([savepath '/' ids{i} '/ints']) 20 | mkdir([savepath '/' ids{i} '/ints']); 21 | end 22 | imgpath=sprintf('./img/%s.jpg',ids{i}); 23 | img = imread(imgpath); 24 | objsegpath=sprintf('./inst/%s.mat',ids{i}); 25 | load(objsegpath) 26 | objseg_img = GTinst.Segmentation; 27 | sz = size(objseg_img); 28 | tmp_img = objseg_img; 29 | % tmp_img(objseg_img==255) = 0; 30 | num_obj = max(tmp_img(:)); 31 | for j = 1:num_obj 32 | seg_mask = (tmp_img==j); 33 | imwrite(seg_mask,[savepath '/' ids{i} '/objs/' num2str(j,'%05d') '.png']); 34 | 35 | for k = 1:15 % N_pairs 36 | pc = zeros(sz); % positive channel 37 | nc = zeros(sz); % negative channel 38 | 39 | %% positive clicks 40 | pc_num = randi(10); 41 | dis_bd = bwdist(1-seg_mask); 42 | dis_pt = 255*ones(sz); 43 | for n = 1:pc_num 44 | [m, ind] = max(rand(sz(1)*sz(2),1).*(dis_bd(:)>d_margin).*(dis_pt(:)>d_step)); 45 | if m ~= 0 46 | [r, c] = ind2sub(sz,ind); 47 | pc(r,c) = 1; 48 | dis_pt = bwdist(pc); 49 | else 50 | break; 51 | end 52 | end 53 | imwrite(uint8(dis_pt),[savepath '/' ids{i} '/ints/' num2str(j,'%03d') '_' num2str(k,'%03d') '_pos.png']); 54 | 55 | 56 | %% negative clicks 57 | if num_obj > 1 58 | strat = randi(3); 59 | else 60 | strat = randi(2); 61 | end 62 | dis_bd = bwdist(seg_mask); 63 | switch strat 64 | % Strategy 1 65 | case 1 66 | np_num = randi(15); 67 | dis_pt = 255*ones(sz); 68 | 
for n = 1:np_num 69 | [m, ind] = max(rand(sz(1)*sz(2),1).*(dis_bd(:)>d_margin).*(dis_pt(:)>d_step)); 70 | if m ~= 0 71 | [r, c] = ind2sub(sz,ind); 72 | nc(r,c) = 1; 73 | dis_pt = bwdist(nc); 74 | else 75 | break; 76 | end 77 | end 78 | imwrite(uint8(dis_pt),[savepath '/' ids{i} '/ints/' num2str(j,'%03d') '_' num2str(k,'%03d') '_neg.png']); 79 | % Strategy 2 80 | case 3 81 | for tmpj = 1:num_obj 82 | if tmpj ~= j 83 | np_num = randi(10); 84 | tmp_mask = (tmp_img==tmpj); 85 | dis_bd = bwdist(1-tmp_mask); 86 | dis_pt = 255*ones(sz); 87 | for n = 1:np_num 88 | [m, ind] = max(rand(sz(1)*sz(2),1).*(dis_bd(:)>d_margin).*(dis_pt(:)>d_step)); 89 | if m ~= 0 90 | [r, c] = ind2sub(sz,ind); 91 | nc(r,c) = 1; 92 | dis_pt = bwdist(nc); 93 | else 94 | break; 95 | end 96 | end 97 | end 98 | end 99 | imwrite(uint8(dis_pt),[savepath '/' ids{i} '/ints/' num2str(j,'%03d') '_' num2str(k,'%03d') '_neg.png']); 100 | % Strategy 3 101 | case 2 102 | np_num = 15; 103 | dis_bd = bwdist(seg_mask); 104 | sample_region = seg_mask + (dis_bd>=40); 105 | % randomly generate the 1st point 106 | [m, ind] = max(rand(sz(1)*sz(2),1).*(1-sample_region(:))); 107 | [r, c] = ind2sub(sz,ind); 108 | nc(r,c) = 1; 109 | sample_region = sample_region + nc; 110 | for n = 2:np_num 111 | dis_bd = bwdist(sample_region); 112 | [m, ind] = max(dis_bd(:)); 113 | if m ~= 0 114 | [r, c] = ind2sub(sz,ind); 115 | nc(r,c) = 1; 116 | sample_region = sample_region + nc; 117 | else 118 | break; 119 | end 120 | end 121 | dis_pt = bwdist(nc); 122 | imwrite(uint8(dis_pt),[savepath '/' ids{i} '/ints/' num2str(j,'%03d') '_' num2str(k,'%03d') '_neg.png']); 123 | end 124 | 125 | %% show points 126 | % figure,imshow(img); 127 | % hold on 128 | % [pos_r,pos_c] = find(pc==1); 129 | % [neg_r,neg_c] = find(nc==1); 130 | % plot(pos_c,pos_r,'+g','markersize',8); 131 | % plot(neg_c,neg_r,'xr','markersize',8); 132 | % hold off 133 | % drawnow;pause(0.1); 134 | % close all; 135 | end 136 | 137 | 138 | end 139 | end 140 | 
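Note on genIntSegPairs.m: the positive-click loop in the script above repeatedly picks a random pixel that lies more than `d_margin` pixels inside the object and more than `d_step` pixels away from every click already placed, then stores the distance map to the clicks. The Python sketch below mirrors that loop for readers without MATLAB. It is an illustration under stated assumptions, not part of the repository: `dist_to_nearest` and `sample_positive_clicks` are hypothetical names, and the brute-force distance computation stands in for MATLAB's `bwdist`, so it is only practical for small masks.

```python
import numpy as np


def dist_to_nearest(targets):
    """Euclidean distance from every pixel to the nearest True pixel in `targets`.

    Brute force (pixels x targets); a stand-in for bwdist on small masks.
    Returns 255 everywhere when `targets` is empty, like the script's
    `255*ones(sz)` initialisation of dis_pt.
    """
    h, w = targets.shape
    pts = np.argwhere(targets)
    if len(pts) == 0:
        return np.full((h, w), 255.0)
    rr, cc = np.mgrid[0:h, 0:w]
    d2 = (rr[..., None] - pts[:, 0]) ** 2 + (cc[..., None] - pts[:, 1]) ** 2
    return np.sqrt(d2.min(axis=-1))


def sample_positive_clicks(seg_mask, max_clicks, d_margin=5, d_step=10, rng=None):
    """Place up to `max_clicks` clicks inside a boolean object mask.

    Returns the boolean click map and a uint8 distance map to the clicks.
    """
    rng = np.random.default_rng() if rng is None else rng
    dis_bd = dist_to_nearest(~seg_mask)  # distance to the nearest background pixel
    clicks = np.zeros(seg_mask.shape, dtype=bool)
    dis_pt = dist_to_nearest(clicks)     # all 255 until the first click exists
    for _ in range(max_clicks):
        candidates = np.argwhere((dis_bd > d_margin) & (dis_pt > d_step))
        if len(candidates) == 0:
            break                        # no admissible pixel left
        r, c = candidates[rng.integers(len(candidates))]
        clicks[r, c] = True
        dis_pt = dist_to_nearest(clicks)
    # Round and saturate, as MATLAB's uint8() does before imwrite.
    return clicks, np.uint8(np.rint(np.minimum(dis_pt, 255)))


if __name__ == "__main__":
    mask = np.zeros((40, 40), dtype=bool)
    mask[5:35, 5:35] = True
    click_map, dist_map = sample_positive_clicks(mask, max_clicks=3)
    print("clicks placed:", int(click_map.sum()))
```

The MATLAB script draws the click budget with `randi(10)` for each pair; the sketch takes it as the `max_clicks` argument so the behaviour is easy to test.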
-------------------------------------------------------------------------------- /IntSeg_GUI.py: -------------------------------------------------------------------------------- 1 | from Tkinter import * 2 | import Tkinter, Tkconstants, tkFileDialog 3 | from PIL import Image, ImageTk 4 | from ttk import Frame, Style, Button, Radiobutton 5 | import os, time, cv2 6 | import numpy as np 7 | import tkMessageBox as mbox 8 | from our_func_cvpr18 import our_func 9 | import tensorflow as tf 10 | 11 | class Example(Frame): 12 | 13 | cnt = 0 14 | imIdx = 0 15 | usrId = -1 16 | flag = 0 17 | filename = "" 18 | x_bd = 0 19 | y_bd = 0 20 | 21 | def __init__(self, parent): 22 | Frame.__init__(self, parent) 23 | 24 | self.parent = parent 25 | self.initUI() 26 | 27 | def onNextImg(self): 28 | mbox.showinfo("Information", "One task completed! Thank you!") 29 | 30 | def onNextMethod(self): 31 | mbox.showinfo("Information", "One task completed! Thank you!" ) 32 | 33 | def callback_left(self, event): 34 | if self.flag == 1: 35 | return 36 | self.flag = 1 37 | self.focus_set() 38 | T.insert('1.0', "Click %d: Positive click at [%d, %d].\n\n"%(self.cnt, event.x, event.y)) 39 | # print "left clicked at", event.x, event.y, self.cnt 40 | if event.y >= self.y_bd or event.x >= self.x_bd: 41 | T.insert('1.0', "Click is outside the image! 
Ignored!\n") 42 | self.flag = 0 43 | return 44 | target0 = open("res/%d/Ours/time_log.txt" % self.usrId, 'a+') 45 | st = time.time() 46 | our_iou = our_func(self.usrId, self.imIdx, self.filename, self.cnt, 1, event) 47 | target0.write("%f\n" % (time.time()-st)) 48 | target0.close() 49 | self.update_show() 50 | self.cnt = self.cnt + 1 51 | self.flag = 0 52 | 53 | def callback_right(self, event): 54 | if self.flag == 1: 55 | return 56 | self.flag = 1 57 | self.focus_set() 58 | T.insert('1.0', "Click %d: Negative click at [%d, %d].\n\n" % (self.cnt, event.x, event.y)) 59 | # print "right clicked at", event.x, event.y, self.cnt 60 | if event.y >= self.y_bd or event.x >= self.x_bd: 61 | T.insert('1.0', "Click is outside the image! Ignored!\n") 62 | self.flag = 0 63 | return 64 | target0 = open("res/%d/Ours/time_log.txt" % self.usrId, 'a+') 65 | st = time.time() 66 | our_iou = our_func(self.usrId, self.imIdx, self.filename, self.cnt, 2, event) 67 | target0.write("%f\n" % (time.time()-st)) 68 | target0.close() 69 | self.update_show() 70 | self.cnt = self.cnt + 1 71 | self.flag = 0 72 | 73 | def initUI(self): 74 | # os.system('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp') 75 | # os.environ['CUDA_VISIBLE_DEVICES'] = str(np.argmax([int(x.split()[2]) for x in open('tmp', 'r').readlines()])) 76 | # os.system('rm tmp') 77 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 78 | sess=tf.Session() 79 | self.usrId = time.time() 80 | self.usrId = int(time.time()-1495000000) 81 | 82 | if not os.path.isdir("res/%d" % self.usrId): 83 | os.makedirs("res/%d" % self.usrId) 84 | os.makedirs("res/%d/Ours" % self.usrId) 85 | 86 | self.parent.title("Interactive Image Segmentation") 87 | self.pack(fill=BOTH, expand=1) 88 | 89 | global T 90 | T = Text(self, height=100, width=20) 91 | T.pack() 92 | T.insert(END, "Welcome!\n") 93 | T.place(x=1760, y=20) 94 | Style().configure("TFrame", background="#333") 95 | 96 | w = 1920 97 | h = 1080 98 | # w = 1280 99 | # h = 900 100 | 101 | sw = 
self.parent.winfo_screenwidth() 102 | sh = self.parent.winfo_screenheight() 103 | 104 | x = (sw - w) / 2 105 | y = (sh - h) / 2 106 | self.parent.geometry('%dx%d+%d+%d' % (w, h, x, y)) 107 | 108 | self.filename = tkFileDialog.askopenfilename(initialdir="./", title="Select file", 109 | filetypes=(("jpeg files", "*.jpg"), ("all files", "*.*"))) 110 | 111 | self.update_all() 112 | 113 | self.bind_all("<Button-1>", self.callback_left) 114 | self.bind_all("<Button-3>", self.callback_right) 115 | 116 | def update_all(self): 117 | time.sleep(2) 118 | im_path = self.filename 119 | bard = Image.open(im_path) 120 | self.x_bd = bard.width 121 | self.y_bd = bard.height 122 | bardejov = ImageTk.PhotoImage(bard) 123 | label1 = Label(self, image=bardejov) 124 | label1.image = bardejov 125 | label1.place(x=10, y=10) 126 | 127 | imgray_path = "helper/graycolor.png" 128 | rot = Image.open(imgray_path) 129 | rotunda = ImageTk.PhotoImage(rot) 130 | label2 = Label(self, image=rotunda) 131 | label2.image = rotunda 132 | label2.place(x=880, y=500) 133 | 134 | bard = Image.open(imgray_path) 135 | bardejov = ImageTk.PhotoImage(bard) 136 | label1 = Label(self, image=bardejov) 137 | label1.image = bardejov 138 | label1.place(x=10, y=500) 139 | 140 | rot = Image.open(imgray_path) 141 | rotunda = ImageTk.PhotoImage(rot) 142 | label2 = Label(self, image=rotunda) 143 | label2.image = rotunda 144 | label2.place(x=880, y=10) 145 | 146 | self.update() 147 | 148 | def update_show(self): 149 | 150 | res_path = "res/%d/Ours/%05d/segs/%03d.png" % (self.usrId, self.imIdx, self.cnt) 151 | tmp_clk_path = 'res/%d/Ours/%05d/tmps/clk_%03d.png' % (self.usrId, self.imIdx, self.cnt) 152 | tmp_ol_path = 'res/%d/Ours/%05d/tmps/ol_%03d.png' % (self.usrId, self.imIdx, self.cnt) 153 | 154 | bard = Image.open(tmp_clk_path) 155 | bardejov = ImageTk.PhotoImage(bard) 156 | label1 = Label(self, image=bardejov) 157 | label1.image = bardejov 158 | label1.place(x=10, y=10) 159 | bard = Image.open(tmp_ol_path) 160 | bardejov = 
ImageTk.PhotoImage(bard) 161 | label1 = Label(self, image=bardejov) 162 | label1.image = bardejov 163 | label1.place(x=880, y=10) 164 | minc = Image.open(res_path) 165 | mincol = ImageTk.PhotoImage(minc) 166 | label3 = Label(self, image=mincol) 167 | label3.image = mincol 168 | label3.place(x=10, y=500) 169 | self.update() 170 | 171 | def main(): 172 | root = Tk() 173 | app = Example(root) 174 | root.mainloop() 175 | 176 | 177 | if __name__ == '__main__': 178 | main() -------------------------------------------------------------------------------- /our_func_cvpr18.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os,time,cv2 3 | import scipy.io as sio 4 | import tensorflow as tf 5 | import tensorflow.contrib.slim as slim 6 | import numpy as np 7 | from numpy import * 8 | import scipy.linalg 9 | from copy import copy, deepcopy 10 | from scipy import ndimage 11 | 12 | def compIoU(im1, im2): 13 | im1_mask = (im1>0.5) 14 | im2_mask = (im2>0.5) 15 | iou = np.sum(im1_mask&im2_mask)/np.sum(im1_mask|im2_mask) 16 | return iou 17 | 18 | def lrelu(x): 19 | return tf.maximum(x*0.2,x) 20 | 21 | def identity_initializer(): 22 | def _initializer(shape, dtype=tf.float32, partition_info=None): 23 | array = np.zeros(shape, dtype=float) 24 | cx, cy = shape[0]//2, shape[1]//2 25 | for i in range(min(shape[2],shape[3])): 26 | array[cx, cy, i, i] = 1 27 | return tf.constant(array, dtype=dtype) 28 | return _initializer 29 | 30 | def nm(x): 31 | w0=tf.Variable(1.0,name='w0') 32 | w1=tf.Variable(0.0,name='w1') 33 | return w0*x+w1*slim.batch_norm(x) 34 | 35 | MEAN_VALUES = np.array([123.6800, 116.7790, 103.9390]).reshape((1,1,1,3)) 36 | 37 | def build_net(ntype,nin,nwb=None,name=None): 38 | if ntype=='conv': 39 | return tf.nn.relu(tf.nn.conv2d(nin,nwb[0],strides=[1,1,1,1],padding='SAME',name=name)+nwb[1]) 40 | elif ntype=='pool': 41 | return tf.nn.avg_pool(nin,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') 
42 | 43 | def get_weight_bias(vgg_layers,i): 44 | weights=vgg_layers[i][0][0][2][0][0] 45 | weights=tf.constant(weights) 46 | bias=vgg_layers[i][0][0][2][0][1] 47 | bias=tf.constant(np.reshape(bias,(bias.size))) 48 | return weights,bias 49 | 50 | def build_vgg19(input,reuse=False): 51 | if reuse: 52 | tf.get_variable_scope().reuse_variables() 53 | net={} 54 | vgg_rawnet=scipy.io.loadmat('Models/imagenet-vgg-verydeep-19.mat') 55 | vgg_layers=vgg_rawnet['layers'][0] 56 | net['input']=input-MEAN_VALUES 57 | net['conv1_1']=build_net('conv',net['input'],get_weight_bias(vgg_layers,0),name='vgg_conv1_1') 58 | net['conv1_2']=build_net('conv',net['conv1_1'],get_weight_bias(vgg_layers,2),name='vgg_conv1_2') 59 | net['pool1']=build_net('pool',net['conv1_2']) 60 | net['conv2_1']=build_net('conv',net['pool1'],get_weight_bias(vgg_layers,5),name='vgg_conv2_1') 61 | net['conv2_2']=build_net('conv',net['conv2_1'],get_weight_bias(vgg_layers,7),name='vgg_conv2_2') 62 | net['pool2']=build_net('pool',net['conv2_2']) 63 | net['conv3_1']=build_net('conv',net['pool2'],get_weight_bias(vgg_layers,10),name='vgg_conv3_1') 64 | net['conv3_2']=build_net('conv',net['conv3_1'],get_weight_bias(vgg_layers,12),name='vgg_conv3_2') 65 | net['conv3_3']=build_net('conv',net['conv3_2'],get_weight_bias(vgg_layers,14),name='vgg_conv3_3') 66 | net['conv3_4']=build_net('conv',net['conv3_3'],get_weight_bias(vgg_layers,16),name='vgg_conv3_4') 67 | net['pool3']=build_net('pool',net['conv3_4']) 68 | net['conv4_1']=build_net('conv',net['pool3'],get_weight_bias(vgg_layers,19),name='vgg_conv4_1') 69 | net['conv4_2']=build_net('conv',net['conv4_1'],get_weight_bias(vgg_layers,21),name='vgg_conv4_2') 70 | net['conv4_3']=build_net('conv',net['conv4_2'],get_weight_bias(vgg_layers,23),name='vgg_conv4_3') 71 | net['conv4_4']=build_net('conv',net['conv4_3'],get_weight_bias(vgg_layers,25),name='vgg_conv4_4') 72 | net['pool4']=build_net('pool',net['conv4_4']) 73 | 
net['conv5_1']=build_net('conv',net['pool4'],get_weight_bias(vgg_layers,28),name='vgg_conv5_1') 74 | net['conv5_2']=build_net('conv',net['conv5_1'],get_weight_bias(vgg_layers,30),name='vgg_conv5_2') 75 | #net['conv5_3']=build_net('conv',net['conv5_2'],get_weight_bias(vgg_layers,32),name='vgg_conv5_3') 76 | #net['conv5_4']=build_net('conv',net['conv5_3'],get_weight_bias(vgg_layers,34),name='vgg_conv5_4') 77 | #net['pool5']=build_net('pool',net['conv5_4']) 78 | return net 79 | 80 | def build(input,sz): 81 | vgg19_features=build_vgg19(input[:,:,:,0:3]) 82 | for layer_id in range(1,6): 83 | vgg19_f = vgg19_features['conv%d_2'%layer_id] 84 | input = tf.concat([input, tf.image.resize_bilinear(vgg19_f,sz)], axis=3) 85 | input = input/255.0 86 | net=slim.conv2d(input,64,[1,1],rate=1,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv0') 87 | net=slim.conv2d(net,64,[3,3],rate=1,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv1') 88 | net=slim.conv2d(net,64,[3,3],rate=2,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv2') 89 | net=slim.conv2d(net,64,[3,3],rate=4,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv3') 90 | net=slim.conv2d(net,64,[3,3],rate=8,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv4') 91 | net=slim.conv2d(net,64,[3,3],rate=16,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv5') 92 | net=slim.conv2d(net,64,[3,3],rate=32,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv6') 93 | net=slim.conv2d(net,64,[3,3],rate=64,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv7') 94 | net=slim.conv2d(net,64,[3,3],rate=128,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv8') 95 | 
net=slim.conv2d(net,64,[3,3],rate=1,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv9') 96 | net=slim.conv2d(net,6,[1,1],rate=1,activation_fn=None,scope='g_conv_last') 97 | return tf.tanh(net) 98 | 99 | def our_func(usrId, imIdx, im_path, cnt, pn, clk): 100 | 101 | if not os.path.isdir("res/%d/Ours/%05d" % (usrId, imIdx)): 102 | os.makedirs("res/%d/Ours/%05d/ints" % (usrId, imIdx)) 103 | os.makedirs("res/%d/Ours/%05d/segs" % (usrId, imIdx)) 104 | os.makedirs("res/%d/Ours/%05d/tmps" % (usrId, imIdx)) 105 | 106 | sess=tf.Session() 107 | 108 | if cnt == 0 and imIdx == 0: 109 | global network,input,output,sz 110 | input = tf.placeholder(tf.float32, shape=[None, None, None, 7]) 111 | output = tf.placeholder(tf.float32, shape=[None, None, None, 1]) 112 | sz = tf.placeholder(tf.int32, shape=[2]) 113 | network=build(input,sz) 114 | 115 | saver = tf.train.Saver(var_list=[var for var in tf.trainable_variables() if var.name.startswith('g_')]) 116 | sess.run(tf.initialize_all_variables()) 117 | 118 | ckpt=tf.train.get_checkpoint_state("Models/ours_cvpr18") 119 | if ckpt: 120 | # print('loaded '+ckpt.model_checkpoint_path) 121 | saver.restore(sess,ckpt.model_checkpoint_path) 122 | 123 | 124 | input_image = cv2.imread(im_path, -1) 125 | iH, iW, _ = input_image.shape 126 | if cnt == 0: 127 | int_pos = np.uint8(255*np.ones([iH,iW])) 128 | int_neg = np.uint8(255*np.ones([iH,iW])) 129 | tmp_clk = cv2.imread(im_path, -1) 130 | else: 131 | int_pos = cv2.imread('res/%d/Ours/%05d/ints/pos_dt_%03d.png' % (usrId, imIdx, cnt - 1), -1) 132 | int_neg = cv2.imread('res/%d/Ours/%05d/ints/neg_dt_%03d.png' % (usrId, imIdx, cnt - 1), -1) 133 | tmp_clk = cv2.imread('res/%d/Ours/%05d/tmps/clk_%03d.png' % (usrId, imIdx, cnt - 1), -1) 134 | clk_pos = (int_pos==0) 135 | clk_neg = (int_neg==0) 136 | if pn == 1: 137 | clk_pos[clk.y,clk.x] = 1 138 | int_pos = ndimage.distance_transform_edt(1-clk_pos) 139 | int_pos = np.uint8(np.minimum(np.maximum(int_pos, 
0.0), 255.0)) 140 | cv2.imwrite('res/%d/Ours/%05d/ints/pos_dt_%03d.png' % (usrId, imIdx, cnt), int_pos) 141 | cv2.imwrite('res/%d/Ours/%05d/ints/neg_dt_%03d.png' % (usrId, imIdx, cnt), int_neg) 142 | cv2.circle(tmp_clk, (clk.x, clk.y), 5, (0, 255, 0), -1) 143 | else: 144 | clk_neg[clk.y,clk.x] = 1 145 | int_neg = ndimage.distance_transform_edt(1-clk_neg) 146 | int_neg = np.uint8(np.minimum(np.maximum(int_neg, 0.0), 255.0)) 147 | cv2.imwrite('res/%d/Ours/%05d/ints/pos_dt_%03d.png' % (usrId, imIdx, cnt), int_pos) 148 | cv2.imwrite('res/%d/Ours/%05d/ints/neg_dt_%03d.png' % (usrId, imIdx, cnt), int_neg) 149 | cv2.circle(tmp_clk, (clk.x, clk.y), 5, (0, 0, 255), -1) 150 | input_pos_clks = deepcopy(int_pos) 151 | input_neg_clks = deepcopy(int_neg) 152 | input_pos_clks[int_pos != 0] = 255 153 | input_neg_clks[int_neg != 0] = 255 154 | input_ = np.expand_dims(np.float32(np.concatenate([input_image, np.expand_dims(int_pos, axis=2), np.expand_dims(int_neg, axis=2), 155 | np.expand_dims(input_pos_clks, axis=2), np.expand_dims(input_neg_clks, axis=2)],axis=2)), axis=0) 156 | output_image = sess.run([network],feed_dict={input:input_,sz:[iH,iW]}) 157 | output_image = np.minimum(np.maximum(output_image, 0.0), 1.0) 158 | output_image[np.where(output_image>0.5)]=1 159 | output_image[np.where(output_image<=0.5)]=0 160 | res_path = 'res/%d/Ours/%05d/segs/%03d.png' % (usrId, imIdx, cnt) 161 | segmask = np.uint8(output_image[0, 0, :, :, 0] * 255.0) 162 | 163 | cv2.imwrite(res_path, segmask) 164 | 165 | tmp_ol = cv2.imread(im_path, -1) 166 | tmp_ol[:,:,0] = 0.5*tmp_ol[:,:,0] + 0.5*segmask 167 | tmp_ol[:,:,1] = 0.5*tmp_ol[:,:,1] + 0.5*segmask 168 | tmp_ol[:,:,2] = 0.5*tmp_ol[:,:,2] + 0.5*segmask 169 | 170 | tmp_clk_path = 'res/%d/Ours/%05d/tmps/clk_%03d.png' % (usrId, imIdx, cnt) 171 | tmp_ol_path = 'res/%d/Ours/%05d/tmps/ol_%03d.png' % (usrId, imIdx, cnt) 172 | cv2.imwrite(tmp_clk_path, tmp_clk) 173 | cv2.imwrite(tmp_ol_path, tmp_ol) 174 | 
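Note on our_func_cvpr18.py: `our_func` above feeds the network a 7-channel tensor made of the three image channels, the positive and negative click distance maps (clipped to 255), and two binary click maps in which click pixels are 0 and every other pixel is 255. The NumPy sketch below illustrates that encoding, assuming clicks are given as (row, col) pairs. `click_distance_map` and `build_network_input` are hypothetical helper names that do not exist in the repository, and the brute-force distance computation replaces `ndimage.distance_transform_edt` only to keep the example free of extra dependencies.

```python
import numpy as np


def click_distance_map(shape, clicks):
    """Distance from each pixel to the nearest click, clipped to [0, 255].

    With no clicks the map is all 255, matching how our_func initialises
    int_pos / int_neg before the first click.
    """
    h, w = shape
    if len(clicks) == 0:
        return np.full((h, w), 255, dtype=np.uint8)
    pts = np.asarray(clicks)
    rr, cc = np.mgrid[0:h, 0:w]
    d2 = (rr[..., None] - pts[:, 0]) ** 2 + (cc[..., None] - pts[:, 1]) ** 2
    dist = np.sqrt(d2.min(axis=-1))
    # Clip, then truncate to uint8 as our_func does with np.uint8(...).
    return np.uint8(np.clip(dist, 0.0, 255.0))


def build_network_input(image, int_pos, int_neg):
    """Stack an H x W x 3 image with the four click channels into 1 x H x W x 7."""
    # Binary click maps: 0 at click pixels, 255 elsewhere.
    pos_clks = np.where(int_pos != 0, 255, 0).astype(np.uint8)
    neg_clks = np.where(int_neg != 0, 255, 0).astype(np.uint8)
    extra = [m[:, :, None] for m in (int_pos, int_neg, pos_clks, neg_clks)]
    stacked = np.concatenate([image] + extra, axis=2)
    return np.float32(stacked)[None]  # add the batch axis


if __name__ == "__main__":
    image = np.zeros((6, 8, 3), dtype=np.uint8)
    pos = click_distance_map(image.shape[:2], [(2, 3)])  # one positive click
    neg = click_distance_map(image.shape[:2], [])        # no negative clicks yet
    print(build_network_input(image, pos, neg).shape)
```

Channels 3 and 4 carry the distance maps and channels 5 and 6 the binary click maps, in the same order as the `np.concatenate` call in `our_func`.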
-------------------------------------------------------------------------------- /IntSeg_Train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os,time,cv2 3 | import scipy.io as sio 4 | import tensorflow as tf 5 | import tensorflow.contrib.slim as slim 6 | import numpy as np 7 | from numpy import * 8 | import scipy.linalg 9 | from copy import copy, deepcopy 10 | 11 | def lrelu(x): 12 | return tf.maximum(x*0.2,x) 13 | 14 | def identity_initializer(): 15 | def _initializer(shape, dtype=tf.float32, partition_info=None): 16 | array = np.zeros(shape, dtype=float) 17 | cx, cy = shape[0]//2, shape[1]//2 18 | for i in range(min(shape[2],shape[3])): 19 | array[cx, cy, i, i] = 1 20 | return tf.constant(array, dtype=dtype) 21 | return _initializer 22 | 23 | def nm(x): 24 | w0=tf.Variable(1.0,name='w0') 25 | w1=tf.Variable(0.0,name='w1') 26 | return w0*x+w1*slim.batch_norm(x) 27 | 28 | MEAN_VALUES = np.array([123.6800, 116.7790, 103.9390]).reshape((1,1,1,3)) 29 | 30 | def build_net(ntype,nin,nwb=None,name=None): 31 | if ntype=='conv': 32 | return tf.nn.relu(tf.nn.conv2d(nin,nwb[0],strides=[1,1,1,1],padding='SAME',name=name)+nwb[1]) 33 | elif ntype=='pool': 34 | return tf.nn.avg_pool(nin,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME') 35 | 36 | def get_weight_bias(vgg_layers,i): 37 | weights=vgg_layers[i][0][0][2][0][0] 38 | weights=tf.constant(weights) 39 | bias=vgg_layers[i][0][0][2][0][1] 40 | bias=tf.constant(np.reshape(bias,(bias.size))) 41 | return weights,bias 42 | 43 | def build_vgg19(input,reuse=False): 44 | if reuse: 45 | tf.get_variable_scope().reuse_variables() 46 | net={} 47 | vgg_rawnet=scipy.io.loadmat('Models/imagenet-vgg-verydeep-19.mat') 48 | vgg_layers=vgg_rawnet['layers'][0] 49 | net['input']=input-MEAN_VALUES 50 | net['conv1_1']=build_net('conv',net['input'],get_weight_bias(vgg_layers,0),name='vgg_conv1_1') 51 | 
net['conv1_2']=build_net('conv',net['conv1_1'],get_weight_bias(vgg_layers,2),name='vgg_conv1_2') 52 | net['pool1']=build_net('pool',net['conv1_2']) 53 | net['conv2_1']=build_net('conv',net['pool1'],get_weight_bias(vgg_layers,5),name='vgg_conv2_1') 54 | net['conv2_2']=build_net('conv',net['conv2_1'],get_weight_bias(vgg_layers,7),name='vgg_conv2_2') 55 | net['pool2']=build_net('pool',net['conv2_2']) 56 | net['conv3_1']=build_net('conv',net['pool2'],get_weight_bias(vgg_layers,10),name='vgg_conv3_1') 57 | net['conv3_2']=build_net('conv',net['conv3_1'],get_weight_bias(vgg_layers,12),name='vgg_conv3_2') 58 | net['conv3_3']=build_net('conv',net['conv3_2'],get_weight_bias(vgg_layers,14),name='vgg_conv3_3') 59 | net['conv3_4']=build_net('conv',net['conv3_3'],get_weight_bias(vgg_layers,16),name='vgg_conv3_4') 60 | net['pool3']=build_net('pool',net['conv3_4']) 61 | net['conv4_1']=build_net('conv',net['pool3'],get_weight_bias(vgg_layers,19),name='vgg_conv4_1') 62 | net['conv4_2']=build_net('conv',net['conv4_1'],get_weight_bias(vgg_layers,21),name='vgg_conv4_2') 63 | net['conv4_3']=build_net('conv',net['conv4_2'],get_weight_bias(vgg_layers,23),name='vgg_conv4_3') 64 | net['conv4_4']=build_net('conv',net['conv4_3'],get_weight_bias(vgg_layers,25),name='vgg_conv4_4') 65 | net['pool4']=build_net('pool',net['conv4_4']) 66 | net['conv5_1']=build_net('conv',net['pool4'],get_weight_bias(vgg_layers,28),name='vgg_conv5_1') 67 | net['conv5_2']=build_net('conv',net['conv5_1'],get_weight_bias(vgg_layers,30),name='vgg_conv5_2') 68 | #net['conv5_3']=build_net('conv',net['conv5_2'],get_weight_bias(vgg_layers,32),name='vgg_conv5_3') 69 | #net['conv5_4']=build_net('conv',net['conv5_3'],get_weight_bias(vgg_layers,34),name='vgg_conv5_4') 70 | #net['pool5']=build_net('pool',net['conv5_4']) 71 | return net 72 | 73 | def build(input,sz): 74 | vgg19_features=build_vgg19(input[:,:,:,0:3]) 75 | for layer_id in range(1,6): 76 | vgg19_f = vgg19_features['conv%d_2'%layer_id] 77 | input = tf.concat([input, 
tf.image.resize_bilinear(vgg19_f,sz)], axis=3) 78 | input = input/255.0 79 | net=slim.conv2d(input,64,[1,1],rate=1,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv0') 80 | net=slim.conv2d(net,64,[3,3],rate=1,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv1') 81 | net=slim.conv2d(net,64,[3,3],rate=2,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv2') 82 | net=slim.conv2d(net,64,[3,3],rate=4,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv3') 83 | net=slim.conv2d(net,64,[3,3],rate=8,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv4') 84 | net=slim.conv2d(net,64,[3,3],rate=16,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv5') 85 | net=slim.conv2d(net,64,[3,3],rate=32,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv6') 86 | net=slim.conv2d(net,64,[3,3],rate=64,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv7') 87 | net=slim.conv2d(net,64,[3,3],rate=128,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv8') 88 | net=slim.conv2d(net,64,[3,3],rate=1,activation_fn=lrelu,normalizer_fn=nm,weights_initializer=identity_initializer(),scope='g_conv9') 89 | net=slim.conv2d(net,6,[1,1],rate=1,activation_fn=None,scope='g_conv_last') 90 | return tf.tanh(net) 91 | 92 | def prepare_data(): 93 | train_im_names = [line.rstrip() for line in open('./train.txt')] 94 | val_im_names = [line.rstrip() for line in open('./val.txt')] 95 | return train_im_names,val_im_names 96 | 97 | config=tf.ConfigProto() 98 | config.gpu_options.allow_growth=True 99 | sess=tf.Session(config=config) 100 | 101 | im_path = "./img" 102 | seg_path = "./inst" 103 | train_im_names,val_im_names = prepare_data() 104 | 
input=tf.placeholder(tf.float32,shape=[None,None,None,7])
output=tf.placeholder(tf.float32,shape=[None,None,None,1])
sz=tf.placeholder(tf.int32,shape=[2])
input_vgg=tf.placeholder(tf.float32,shape=[None,None,None,3])
network=build(input,sz)
vgg19_network=build_vgg19(input_vgg)

# L2 Loss
loss_d1=tf.reduce_mean(tf.square(tf.expand_dims(network[:,:,:,0],axis=3)-output))
loss_d2=tf.reduce_mean(tf.square(tf.expand_dims(network[:,:,:,1],axis=3)-output))
loss_d3=tf.reduce_mean(tf.square(tf.expand_dims(network[:,:,:,2],axis=3)-output))
loss_d4=tf.reduce_mean(tf.square(tf.expand_dims(network[:,:,:,3],axis=3)-output))
loss_d5=tf.reduce_mean(tf.square(tf.expand_dims(network[:,:,:,4],axis=3)-output))
loss_d6=tf.reduce_mean(tf.square(tf.expand_dims(network[:,:,:,5],axis=3)-output))
loss = tf.reduce_min([loss_d1, loss_d2, loss_d3, loss_d4, loss_d5, loss_d6]) + 0.0025*(32*loss_d1+16*loss_d2+8*loss_d3+4*loss_d4+2*loss_d5+1*loss_d6)

# L1 Loss
loss2_d1=tf.reduce_mean(tf.abs(tf.expand_dims(network[:,:,:,0],axis=3)-output))
loss2_d2=tf.reduce_mean(tf.abs(tf.expand_dims(network[:,:,:,1],axis=3)-output))
loss2_d3=tf.reduce_mean(tf.abs(tf.expand_dims(network[:,:,:,2],axis=3)-output))
loss2_d4=tf.reduce_mean(tf.abs(tf.expand_dims(network[:,:,:,3],axis=3)-output))
loss2_d5=tf.reduce_mean(tf.abs(tf.expand_dims(network[:,:,:,4],axis=3)-output))
loss2_d6=tf.reduce_mean(tf.abs(tf.expand_dims(network[:,:,:,5],axis=3)-output))
loss2 = tf.reduce_min([loss2_d1, loss2_d2, loss2_d3, loss2_d4, loss2_d5, loss2_d6]) + 0.0025*(32*loss2_d1+16*loss2_d2+8*loss2_d3+4*loss2_d4+2*loss2_d5+1*loss2_d6)

# IoU Loss
nw1 = tf.expand_dims(network[:,:,:,0],axis=3)
nw2 = tf.expand_dims(network[:,:,:,1],axis=3)
nw3 = tf.expand_dims(network[:,:,:,2],axis=3)
nw4 = tf.expand_dims(network[:,:,:,3],axis=3)
nw5 = tf.expand_dims(network[:,:,:,4],axis=3)
nw6 = tf.expand_dims(network[:,:,:,5],axis=3)
iou_d1 = 1-tf.reduce_mean(tf.multiply(nw1,output))/(tf.reduce_mean(tf.maximum(nw1,output))+1e-6)
iou_d2 = 1-tf.reduce_mean(tf.multiply(nw2,output))/(tf.reduce_mean(tf.maximum(nw2,output))+1e-6)
iou_d3 = 1-tf.reduce_mean(tf.multiply(nw3,output))/(tf.reduce_mean(tf.maximum(nw3,output))+1e-6)
iou_d4 = 1-tf.reduce_mean(tf.multiply(nw4,output))/(tf.reduce_mean(tf.maximum(nw4,output))+1e-6)
iou_d5 = 1-tf.reduce_mean(tf.multiply(nw5,output))/(tf.reduce_mean(tf.maximum(nw5,output))+1e-6)
iou_d6 = 1-tf.reduce_mean(tf.multiply(nw6,output))/(tf.reduce_mean(tf.maximum(nw6,output))+1e-6)
loss_iou = tf.reduce_min([iou_d1, iou_d2, iou_d3, iou_d4, iou_d5, iou_d6]) + 0.0025*(32*iou_d1+16*iou_d2+8*iou_d3+4*iou_d4+2*iou_d5+1*iou_d6)

# add positive/negative clicks as soft constraints
ct_mask = tf.cast(input[:,:,:,3],dtype=tf.bool) & tf.cast(input[:,:,:,4],dtype=tf.bool)
ct_mask = tf.tile(tf.expand_dims(~ct_mask,axis=3), [1,1,1,6])
ct_mask = tf.cast(ct_mask, dtype=tf.float32)
ct_mask /= tf.reduce_mean(ct_mask)
output_tile = tf.tile(output,[1,1,1,6])
ct_loss = tf.reduce_mean(tf.abs(network - output_tile) * ct_mask)

all_loss = loss_iou + ct_loss

opt=tf.train.AdamOptimizer(learning_rate=0.0001).minimize(all_loss,var_list=[var for var in tf.trainable_variables() if var.name.startswith('g_')])

saver=tf.train.Saver(max_to_keep=1000)
sess.run(tf.initialize_all_variables())
ckpt=tf.train.get_checkpoint_state("result64_vgg19_RDL6_IoU_dt_pt_ct_tanh")
if ckpt:
    print('loaded '+ckpt.model_checkpoint_path)
    saver.restore(sess,ckpt.model_checkpoint_path)

input_images=[None]*len(train_im_names)
output_masks=[None]*len(train_im_names)

# For displaying the losses
all=np.zeros(30000,dtype=float)
all2=np.zeros(30000,dtype=float)
all_iou=np.zeros(30000,dtype=float)
all_d1=np.zeros(30000,dtype=float)
all_d2=np.zeros(30000,dtype=float)
all_d3=np.zeros(30000,dtype=float)
all_d4=np.zeros(30000,dtype=float)
all_d5=np.zeros(30000,dtype=float)
all_d6=np.zeros(30000,dtype=float)

for epoch in range(1,101):
    if os.path.isdir("result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/%04d"%epoch):
        continue
    cnt=0
    for id in np.random.permutation(len(train_im_names)):
    # for id in np.random.permutation(1):

        if input_images[id] is None:
            # The input image
            input_images[id] = cv2.imread(im_path + "/" + train_im_names[id]+".jpg",-1)
        if output_masks[id] is None:
            # The SBD Groundtruth mask
            mat_contents = sio.loadmat(seg_path + "/" + train_im_names[id] + ".mat")
            tmpstr = mat_contents['GTinst']
            tmpmat = tmpstr[0,0]
            output_masks[id] = tmpmat['Segmentation']
        output_mask = deepcopy(output_masks[id])
        output_mask[output_mask==255] = 0
        num_obj = output_mask.max()
        for obj_id in range(num_obj):
            st = time.time()
            # random clicks
            input_pos = cv2.imread("./train" + "/" + train_im_names[id] + "/ints/%03d_%03d_pos.png" % (obj_id + 1, np.random.randint(1, 16)),-1)
            input_neg = cv2.imread("./train" + "/" + train_im_names[id] + "/ints/%03d_%03d_neg.png" % (obj_id + 1, np.random.randint(1, 16)),-1)
            input_pos_clks = deepcopy(input_pos)
            input_neg_clks = deepcopy(input_neg)
            input_pos_clks[input_pos != 0] = 255
            input_neg_clks[input_neg != 0] = 255
            if np.sum(input_pos==0)==0:
                continue
            input_image=np.expand_dims(np.float32(np.concatenate(
                [input_images[id], np.expand_dims(input_pos,axis=2), np.expand_dims(input_neg,axis=2),
                 np.expand_dims(input_pos_clks,axis=2), np.expand_dims(input_neg_clks,axis=2)], axis=2)),axis=0)
            _,iH,iW,_=input_image.shape

            output_image = deepcopy(output_mask)
            output_image[output_mask != (obj_id+1)] = 0
            output_image[output_mask == (obj_id+1)] = 255
            output_image=np.expand_dims(np.expand_dims(np.float32(output_image),axis=0),axis=3)/255.0
            _,current,current2,current3,d1,d2,d3,d4,d5,d6=sess.run([opt,loss,loss2,loss_iou, iou_d1, iou_d2, iou_d3, iou_d4, iou_d5, iou_d6],feed_dict={input:input_image,sz:[iH,iW],output:output_image})
            all[cnt]=current*255.0*255.0 #squared in 255 range (remember the network takes [0,1])
            all2[cnt]=current2*255.0 #changed to 255 in error
            all_iou[cnt]=current3
            all_d1[cnt]=d1
            all_d2[cnt]=d2
            all_d3[cnt]=d3
            all_d4[cnt]=d4
            all_d5[cnt]=d5
            all_d6[cnt]=d6
            cnt+=1
            print("%d %d l2: %.4f l1: %.4f IoU: %.4f d1-6: %.4f %.4f %.4f %.4f %.4f %.4f time: %.4f %s"%(epoch,cnt,np.mean(all[np.where(all)]),np.mean(all2[np.where(all2)]),np.mean(all_iou[np.where(all_iou)]),np.mean(all_d1[np.where(all_d1)]),
                np.mean(all_d2[np.where(all_d2)]),np.mean(all_d3[np.where(all_d3)]),np.mean(all_d4[np.where(all_d4)]), np.mean(all_d5[np.where(all_d5)]), np.mean(all_d6[np.where(all_d6)]),
                time.time()-st,os.getcwd().split('/')[-2]))

    os.makedirs("result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/%04d"%epoch)
    target=open("result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/%04d/score.txt"%epoch,'w')
    target.write("%f\n%f\n%f"%(np.mean(all[np.where(all)]),np.mean(all2[np.where(all2)]),np.mean(all_iou[np.where(all_iou)])))
    target.close()

    saver.save(sess,"result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/model.ckpt")
    saver.save(sess,"result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/%04d/model.ckpt"%epoch)

    # validation
    all_test = np.zeros(100, dtype=float)
    all2_test = np.zeros(100, dtype=float)
    all_iou_test = np.zeros(100, dtype=float)
    target = open("result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/%04d/test_score.txt" % epoch, 'w')

    for id in range(100):
        input_image = cv2.imread(im_path + "/" + val_im_names[id] + ".jpg", -1)
        input_pos = cv2.imread("./val" + "/" + val_im_names[id] + "/ints/%03d_%03d_pos.png" % (1, 1), -1)
        input_neg = cv2.imread("./val" + "/" + val_im_names[id] + "/ints/%03d_%03d_neg.png" % (1, 1), -1)
        input_pos_clks = deepcopy(input_pos)
        input_neg_clks = deepcopy(input_neg)
        input_pos_clks[input_pos != 0] = 255
        input_neg_clks[input_neg != 0] = 255
        output_gt = cv2.imread("./val" + "/" + val_im_names[id] + "/objs/%05d.png" % 1, -1)
        output_gt = np.expand_dims(np.expand_dims(np.float32(output_gt), axis=0), axis=3) / 255.0
        iH, iW, _ = input_image.shape
        input_image = np.expand_dims(np.float32(np.concatenate(
            [input_image, np.expand_dims(input_pos, axis=2), np.expand_dims(input_neg, axis=2),
             np.expand_dims(input_pos_clks, axis=2), np.expand_dims(input_neg_clks, axis=2)], axis=2)), axis=0)
        st=time.time()
        output_image, loss_test, loss2_test, iou_test = sess.run([network, loss, loss2, loss_iou],feed_dict={input:input_image,sz:[iH,iW],output: output_gt})
        all_test[id] = loss_test * 255.0 * 255.0
        all2_test[id] = loss2_test * 255
        all_iou_test[id] = iou_test
        target.write("%f %f %f\n" % (all_test[id], all2_test[id], all_iou_test[id]))
        print("%.3f"%(time.time()-st))
        output_image = np.minimum(np.maximum(output_image, 0.0), 1.0)
        for output_d in range(6):
            save_image = input_image[0, :, :, 0:3] / 255.0
            save_image[:, :, 0] = (save_image[:, :, 0] + 0.5 * output_image[0, :, :, output_d])
            save_image[:, :, 1] = (save_image[:, :, 1] + 0.5 * output_image[0, :, :, output_d])
            save_image[:, :, 2] = (save_image[:, :, 2] + 0.5 * output_image[0, :, :, output_d])
            save_image = np.minimum(np.maximum(save_image, 0.0), 1.0) * 255.0
            cv2.imwrite("result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/%04d/%s_%02d_BW.png" % (epoch, val_im_names[id], output_d),
                        np.uint8(output_image[0, :, :, output_d] * 255.0))
            cv2.imwrite("result64_vgg19_RDL6_IoU_dt_pt_ct_tanh/%04d/%s_%02d.jpg" % (epoch, val_im_names[id], output_d),
                        np.uint8(save_image))
    target.write("Mean: %f %f %f\n" % (np.mean(all_test[np.where(all_test)]), np.mean(all2_test[np.where(all2_test)]), np.mean(all_iou_test[np.where(all_iou_test)])))
    target.close()

--------------------------------------------------------------------------------
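The training script above combines six per-branch soft-IoU terms as `reduce_min(...) + 0.0025*(32*d1 + ... + 1*d6)`. The following is a minimal NumPy sketch of that combination for readers who want to check the arithmetic outside a TensorFlow session. The function names `soft_iou_loss` and `diversity_loss` are hypothetical and not part of the repository, and the toy inputs assume predictions in [0, 1], whereas the network's `tanh` output is not clamped during training.

```python
import numpy as np

def soft_iou_loss(pred, target, eps=1e-6):
    # Mirrors the script's per-branch term:
    # 1 - mean(pred * target) / (mean(max(pred, target)) + eps)
    return 1.0 - np.mean(pred * target) / (np.mean(np.maximum(pred, target)) + eps)

def diversity_loss(preds, target, weights=(32, 16, 8, 4, 2, 1), lam=0.0025):
    # preds: H x W x K array, one channel per output branch (K = 6 in the script).
    per_branch = np.array(
        [soft_iou_loss(preds[..., k], target) for k in range(preds.shape[-1])]
    )
    # Best branch plus a small, rank-weighted sum over all branches.
    total = per_branch.min() + lam * float(np.dot(weights, per_branch))
    return total, per_branch

# Toy example: branch 0 matches the mask exactly, the other branches are empty.
target = np.zeros((4, 4))
target[1:3, 1:3] = 1.0
preds = np.zeros((4, 4, 6))
preds[..., 0] = target
total, per_branch = diversity_loss(preds, target)
print(total, per_branch)
```

In this toy case the minimum term is driven almost to zero by the one matching branch, so the total is dominated by the weighted sum over the five empty branches. This illustrates why the weighted term is kept small (0.0025) relative to the minimum.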