├── .gitignore
├── example
│   └── 2007_000129.jpg
├── matlab
│   ├── pascal_seg_colormap.mat
│   ├── EvalSegResults.m
│   ├── MyVOCevalseg.m
│   └── GetVOCopts.m
├── npy2tfmodel.py
├── README.md
├── deeplab_main.py
├── caffemodel2npy.py
├── LICENSE
└── deeplab_model.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | *.pyc 3 | 4 | model/ 5 | -------------------------------------------------------------------------------- /example/2007_000129.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxi116/TF-deeplab/HEAD/example/2007_000129.jpg -------------------------------------------------------------------------------- /matlab/pascal_seg_colormap.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenxi116/TF-deeplab/HEAD/matlab/pascal_seg_colormap.mat -------------------------------------------------------------------------------- /matlab/EvalSegResults.m: -------------------------------------------------------------------------------- 1 | seg_root = '/media/Work_HD/cxliu/datasets/VOCdevkit/VOC2012/'; 2 | seg_res_dir = '../example/val'; 3 | add_colormap = 0; 4 | evaluate = 1; 5 | 6 | %% add colormap to prediction image 7 | if add_colormap 8 | load('pascal_seg_colormap.mat'); 9 | imgs_dir = dir(fullfile(seg_res_dir, '*.png')); 10 | for i = 1:numel(imgs_dir) 11 | fprintf(1, 'adding colormap %d (%d) ...\n', i, numel(imgs_dir)); 12 | img = imread(fullfile(seg_res_dir, imgs_dir(i).name)); 13 | imwrite(img, colormap, fullfile(seg_res_dir, imgs_dir(i).name)); 14 | end 15 | end 16 | 17 | %% evaluate IOU 18 | if evaluate 19 | VOCopts = GetVOCopts(seg_root, seg_res_dir, 'train', 'val', 'VOC2012'); 20 | [accuracies, avacc, conf, rawcounts] = MyVOCevalseg(VOCopts); 21 | end -------------------------------------------------------------------------------- /npy2tfmodel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Chenxi Liu. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | # sample usage: 17 | # python npy2tfmodel.py 0 ./model/ResNet101_init.npy ./model/ResNet101_init.tfmodel 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | import deeplab_model 22 | import sys 23 | import os; os.environ['CUDA_VISIBLE_DEVICES'] = sys.argv[1] 24 | import pdb 25 | 26 | 27 | weights = np.load(sys.argv[2])[()] # np.save wraps the dict in a 0-d object array; [()] unpacks it 28 | 29 | model = deeplab_model.DeepLab() 30 | 31 | sess = tf.Session() 32 | sess.run(tf.global_variables_initializer()) 33 | var_list = tf.global_variables() 34 | count = 0 35 | for item in var_list: 36 | item_name = item.name[8:-2] # strip "DeepLab/" at the beginning and ":0" at the end 37 | if item_name not in weights: 38 | continue 39 | print(item_name) 40 | count += 1 41 | sess.run(tf.assign(item, weights[item_name])) 42 | assert(count == len(weights)) 43 | 44 | snapshot_saver = tf.train.Saver() 45 | snapshot_saver.save(sess, sys.argv[3]) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TF-deeplab 2 | 3 | This is a TensorFlow implementation of [DeepLab](http://liangchiehchen.com/projects/DeepLab.html), compatible with TensorFlow 1.2.1. 4 | 5 | Currently it supports both training and testing of the ResNet 101 version, using weights converted from the caffemodel provided by Jay. 6 | 7 | Note that the current version is not multi-scale, i.e. it only uses the original-resolution branch and discards all layers at the 0.5 and 0.75 resolutions. 8 | 9 | `caffemodel2npy.py` is modified from [here](https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/utils/loadcaffe.py), and `deeplab_model.py` is modified from [here](https://github.com/tensorflow/models/blob/master/resnet/resnet_model.py). 10 | 11 | ## Example Usage 12 | - Download the prototxt and caffemodel [provided by Jay](http://liangchiehchen.com/projects/DeepLabv2_resnet.html) 13 | - Convert the caffemodels to npy files 14 | ```bash 15 | python caffemodel2npy.py deploy.prototxt ../deeplab/ResNet101/init.caffemodel ./model/ResNet101_init.npy 16 | python caffemodel2npy.py deploy.prototxt ../deeplab/ResNet101/train_iter_20000.caffemodel ./model/ResNet101_train.npy 17 | python caffemodel2npy.py deploy.prototxt ../deeplab/ResNet101/train2_iter_20000.caffemodel ./model/ResNet101_train2.npy 18 | ``` 19 | - Convert the npy files to tfmodels 20 | ```bash 21 | python npy2tfmodel.py 0 ./model/ResNet101_init.npy ./model/ResNet101_init.tfmodel 22 | python npy2tfmodel.py 0 ./model/ResNet101_train.npy ./model/ResNet101_train.tfmodel 23 | python npy2tfmodel.py 0 ./model/ResNet101_train2.npy ./model/ResNet101_train2.tfmodel 24 | ``` 25 | - Test on a single image 26 | ```bash 27 | python deeplab_main.py 0 single 28 | ``` 29 | - Test on the PASCAL VOC2012 validation set (you will also want to look at the `matlab` folder and run `EvalSegResults.m` after you run the following command) 30 | ```bash 31 | python deeplab_main.py 0 test 32 | ``` 33 | 34 | - To train on the PASCAL VOC2012 `train_aug` set, run 35 | ```bash 36 | python deeplab_main.py 0 train 37 | ``` 38 | 39 | ## Performance 40 | 41 | The converted DeepLab ResNet 101 model achieves a mean IOU of 73.296% on the PASCAL VOC2012 validation set. Again, this is only with the original-resolution branch, which is likely the reason for the performance gap (according to the [paper](https://arxiv.org/pdf/1606.00915.pdf) this number should be around 75%). 
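If you would rather not run the MATLAB script, the same mean IOU can be computed directly in Python from the predicted and ground-truth PNGs. The snippet below is a minimal sketch (not part of this repo); it assumes 21 classes, VOC-style label maps, and void pixels marked as 255, mirroring the logic of `MyVOCevalseg.m`:

```python
import numpy as np
from PIL import Image

def mean_iou(pred_paths, gt_paths, num_classes=21):
    # accumulate one global confusion matrix over all images
    conf = np.zeros((num_classes, num_classes), dtype=np.int64)
    for pred_path, gt_path in zip(pred_paths, gt_paths):
        pred = np.asarray(Image.open(pred_path), dtype=np.int64).flatten()
        gt = np.asarray(Image.open(gt_path), dtype=np.int64).flatten()
        valid = gt < 255  # VOC marks void/boundary pixels as 255
        conf += np.bincount(num_classes * gt[valid] + pred[valid],
                            minlength=num_classes ** 2
                            ).reshape(num_classes, num_classes)
    tp = np.diag(conf).astype(np.float64)
    # per-class IOU: TP / (TP + FP + FN), averaged over all classes
    denom = conf.sum(axis=1) + conf.sum(axis=0) - np.diag(conf)
    return np.mean(tp / np.maximum(denom, 1))
```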
42 | 43 | ## TODO 44 | 45 | - Incorporate the 0.5 and 0.75 resolution branches 46 | -------------------------------------------------------------------------------- /deeplab_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Chenxi Liu. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # sample usage: 17 | # python deeplab_main.py 0 single 18 | 19 | import tensorflow as tf 20 | import numpy as np 21 | from deeplab_model import DeepLab 22 | from PIL import Image 23 | import sys 24 | import os; os.environ['CUDA_VISIBLE_DEVICES'] = sys.argv[1] 25 | import pdb 26 | 27 | def process_im(imname, mu): 28 | im = np.array(Image.open(imname), dtype=np.float32) 29 | if im.ndim == 3: 30 | if im.shape[2] == 4: 31 | im = im[:, :, 0:3] # drop the alpha channel 32 | im = im[:,:,::-1] # RGB to BGR, matching the Caffe model 33 | else: 34 | im = np.tile(im[:, :, np.newaxis], (1, 1, 3)) # replicate grayscale to 3 channels 35 | im -= mu 36 | im = np.expand_dims(im, axis=0) 37 | return im 38 | 39 | if __name__ == "__main__": 40 | 41 | mu = np.array((104.00698793, 116.66876762, 122.67891434)) 42 | 43 | if sys.argv[2] == 'train': 44 | pretrained_model = './model/ResNet101_init.tfmodel' 45 | model = DeepLab(mode='train') 46 | load_var = {var.op.name: var for var in tf.global_variables() 47 | if 'Momentum' not in var.op.name and 'global_step' not in var.op.name} 48 | snapshot_restorer = tf.train.Saver(load_var) 49 | else: 50 | pretrained_model = './model/ResNet101_train.tfmodel' 51 | # pretrained_model = './model/ResNet101_epoch_2.tfmodel' 52 | model = DeepLab() 53 | snapshot_restorer = tf.train.Saver() 54 | sess = tf.Session() 55 | sess.run(tf.global_variables_initializer()) 56 | snapshot_restorer.restore(sess, pretrained_model) 57 | 58 | if sys.argv[2] == 'single': 59 | im = process_im('example/2007_000129.jpg', mu) 60 | pred = sess.run(model.up, feed_dict={ 61 | model.images : im 62 | }) 63 | pred = np.argmax(pred, axis=3).squeeze().astype(np.uint8) 64 | seg = Image.fromarray(pred) 65 | seg.save('example/2007_000129.png') 66 | 67 | elif sys.argv[2] == 'test': 68 | pascal_dir = '/media/Work_HD/cxliu/datasets/VOCdevkit/VOC2012/JPEGImages/' 69 | list_dir = '/media/Work_HD/cxliu/projects/deeplab/list/' 70 | save_dir = 'example/val/'
if not os.path.isdir(save_dir): os.makedirs(save_dir) # make sure the output directory exists
71 | lines = np.loadtxt(list_dir + 'val_id.txt', dtype=str) 72 | for i, line in enumerate(lines): 73 | imname = line 74 | im = process_im(pascal_dir + imname + '.jpg', mu) 75 | pred = sess.run(model.up, feed_dict={ 76 | model.images : im 77 | }) 78 | pred = np.argmax(pred, axis=3).squeeze().astype(np.uint8) 79 | seg = Image.fromarray(pred) 80 | seg.save(save_dir + imname + '.png') 81 | print('processing %d/%d' % (i + 1, len(lines))) 82 | 83 | elif sys.argv[2] == 'train': 84 | cls_loss_avg = 0 85 | decay = 0.99 86 | num_epochs = 2 # train for 2 epochs 87 | snapshot_saver = tf.train.Saver(max_to_keep = 1000) 88 | snapshot_file = './model/ResNet101_epoch_%d.tfmodel' 89 | pascal_dir = 
'/media/Work_HD/cxliu/datasets/VOCdevkit/VOC2012' 90 | list_dir = '/media/Work_HD/cxliu/projects/deeplab/list/' 91 | lines = np.loadtxt(list_dir + 'train_aug.txt', dtype=str) 92 | for epoch in range(num_epochs): 93 | lines = np.random.permutation(lines) 94 | for i, line in enumerate(lines): 95 | imname, labelname = line 96 | im = process_im(pascal_dir + imname, mu) 97 | label = np.array(Image.open(pascal_dir + labelname)) 98 | label = np.expand_dims(label, axis=0) 99 | _, cls_loss_val, lr_val, label_val = sess.run([model.train_step, 100 | model.cls_loss, 101 | model.learning_rate, 102 | model.labels_coarse], 103 | feed_dict={ 104 | model.images : im, 105 | model.labels : np.expand_dims(label, axis=3) 106 | }) 107 | cls_loss_avg = decay*cls_loss_avg + (1-decay)*cls_loss_val 108 | print('iter = %d / %d, loss (cur) = %f, loss (avg) = %f, lr = %f' % (i, 109 | len(lines), cls_loss_val, cls_loss_avg, lr_val)) 110 | snapshot_saver.save(sess, snapshot_file % (epoch + 1)) -------------------------------------------------------------------------------- /matlab/MyVOCevalseg.m: -------------------------------------------------------------------------------- 1 | %VOCEVALSEG Evaluates a set of segmentation results. 2 | % VOCEVALSEG(VOCopts,ID); prints out the per class and overall 3 | % segmentation accuracies. Accuracies are given using the intersection/union 4 | % metric: 5 | % true positives / (true positives + false positives + false negatives) 6 | % 7 | % [ACCURACIES,AVACC,CONF] = VOCEVALSEG(VOCopts,ID) returns the per class 8 | % percentage ACCURACIES, the average accuracy AVACC and the confusion 9 | % matrix CONF. 10 | % 11 | % [ACCURACIES,AVACC,CONF,RAWCOUNTS] = VOCEVALSEG(VOCopts,ID) also returns 12 | % the unnormalised confusion matrix, which contains raw pixel counts. 
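%
% For example, a class with 80 true-positive pixels, 10 false-positive
% pixels and 10 false-negative pixels scores 80 / (80 + 10 + 10) = 80%.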
13 | function [accuracies,avacc,conf,rawcounts] = MyVOCevalseg(VOCopts) 14 | 15 | % image test set 16 | [gtids,t]=textread(sprintf(VOCopts.seg.imgsetpath,VOCopts.testset),'%s %d'); 17 | 18 | % number of labels = number of classes plus one for the background 19 | num = VOCopts.nclasses+1; 20 | confcounts = zeros(num); 21 | count=0; 22 | 23 | num_missing_img = 0; 24 | 25 | tic; 26 | for i=1:length(gtids) 27 | % display progress 28 | if toc>1 29 | fprintf('test confusion: %d/%d\n',i,length(gtids)); 30 | drawnow; 31 | tic; 32 | end 33 | 34 | imname = gtids{i}; 35 | 36 | % ground truth label file 37 | gtfile = sprintf(VOCopts.seg.clsimgpath,imname); 38 | [gtim,map] = imread(gtfile); 39 | gtim = double(gtim); 40 | 41 | % results file 42 | resfile = sprintf(VOCopts.seg.clsrespath,imname); 43 | try 44 | [resim,map] = imread(resfile); 45 | catch err 46 | num_missing_img = num_missing_img + 1; 47 | %fprintf(1, 'Fail to read %s\n', resfile); 48 | continue; 49 | end 50 | 51 | resim = double(resim); 52 | 53 | % Check validity of results image 54 | maxlabel = max(resim(:)); 55 | if (maxlabel>VOCopts.nclasses) 56 | error('Results image ''%s'' has out of range value %d (the value should be <= %d)',imname,maxlabel,VOCopts.nclasses); 57 | end 58 | 59 | szgtim = size(gtim); szresim = size(resim); 60 | if any(szgtim~=szresim) 61 | error('Results image ''%s'' is the wrong size, was %d x %d, should be %d x %d.',imname,szresim(1),szresim(2),szgtim(1),szgtim(2)); 62 | end 63 | 64 | %pixel locations to include in computation 65 | locs = gtim<255; 66 | 67 | % joint histogram 68 | sumim = 1+gtim+resim*num; 69 | hs = histc(sumim(locs),1:num*num); 70 | count = count + numel(find(locs)); 71 | confcounts(:) = confcounts(:) + hs(:); 72 | end 73 | 74 | if (num_missing_img > 0) 75 | fprintf(1, 'WARNING: There are %d missing results!\n', num_missing_img); 76 | end 77 | 78 | % confusion matrix - first index is true label, second is inferred label 79 | %conf = zeros(num); 80 | conf = 100*confcounts./repmat(1E-20+sum(confcounts,2),[1 size(confcounts,2)]); 81 | rawcounts = confcounts; 82 | 83 | % Pixel Accuracy 84 | overall_acc = 100*sum(diag(confcounts)) / sum(confcounts(:)); 85 | fprintf('Percentage of pixels correctly labelled overall: %6.3f%%\n',overall_acc); 86 | 87 | % Class Accuracy 88 | class_acc = zeros(1, num); 89 | class_count = 0; 90 | fprintf('Accuracy for each class (pixel accuracy)\n'); 91 | for i = 1 : num 92 | denom = sum(confcounts(i, :)); 93 | if (denom == 0) 94 | denom = 1; 95 | end 96 | class_acc(i) = 100 * confcounts(i, i) / denom; 97 | if i == 1 98 | clname = 'background'; 99 | else 100 | clname = VOCopts.classes{i-1}; 101 | end 102 | 103 | if ~strcmp(clname, 'void') 104 | class_count = class_count + 1; 105 | fprintf(' %14s: %6.3f%%\n', clname, class_acc(i)); 106 | end 107 | end 108 | fprintf('-------------------------\n'); 109 | avg_class_acc = sum(class_acc) / class_count; 110 | fprintf('Mean Class Accuracy: %6.3f%%\n', avg_class_acc); 111 | 112 | % Pixel IOU 113 | accuracies = zeros(VOCopts.nclasses,1); 114 | fprintf('Accuracy for each class (intersection/union measure)\n'); 115 | 116 | real_class_count = 0; 117 | 118 | for j=1:num 119 | 120 | gtj=sum(confcounts(j,:)); 121 | resj=sum(confcounts(:,j)); 122 | gtjresj=confcounts(j,j); 123 | % The accuracy is: true positive / (true positive + false positive + false negative) 124 | % which is equivalent to the following percentage: 125 | denom = (gtj+resj-gtjresj); 126 | 127 | if denom == 0 128 | denom = 1; 129 | end 130 | 131 | 
accuracies(j)=100*gtjresj/denom; 132 | 133 | clname = 'background'; 134 | if (j>1), clname = VOCopts.classes{j-1};end; 135 | 136 | if ~strcmp(clname, 'void') 137 | real_class_count = real_class_count + 1; 138 | else 139 | if denom ~= 1 140 | fprintf(1, 'WARNING: this void class has denom = %d\n', denom); 141 | end 142 | end 143 | 144 | if ~strcmp(clname, 'void') 145 | fprintf(' %14s: %6.3f%%\n',clname,accuracies(j)); 146 | end 147 | end 148 | 149 | %accuracies = accuracies(1:end); 150 | %avacc = mean(accuracies); 151 | avacc = sum(accuracies) / real_class_count; 152 | 153 | fprintf('-------------------------\n'); 154 | fprintf('Average accuracy: %6.3f%%\n',avacc); 155 | -------------------------------------------------------------------------------- /matlab/GetVOCopts.m: -------------------------------------------------------------------------------- 1 | function VOCopts = GetVOCopts(seg_root, seg_res_dir, trainset, testset, dataset) 2 | %clear VOCopts 3 | 4 | if nargin < 5 5 | dataset = 'VOC2012'; 6 | end 7 | 8 | % dataset 9 | % 10 | % Note for experienced users: the VOC2008-11 test sets are subsets 11 | % of the VOC2012 test set. You don't need to do anything special 12 | % to submit results for VOC2008-11. 13 | 14 | VOCopts.dataset=dataset; 15 | 16 | % get devkit directory with forward slashes 17 | devkitroot=strrep(fileparts(fileparts(mfilename('fullpath'))),'\','/'); 18 | 19 | % change this path to point to your copy of the PASCAL VOC data 20 | VOCopts.datadir=[devkitroot '/']; 21 | 22 | % change this path to a writable directory for your results 23 | %VOCopts.resdir=[devkitroot '/results/' VOCopts.dataset '/']; 24 | VOCopts.resdir = seg_res_dir; 25 | 26 | % change this path to a writable local directory for the example code 27 | VOCopts.localdir=[devkitroot '/local/' VOCopts.dataset '/']; 28 | 29 | % initialize the training set 30 | 31 | VOCopts.trainset = trainset; 32 | %VOCopts.trainset='train'; % use train for development 33 | % VOCopts.trainset='trainval'; % use train+val for final challenge 34 | 35 | % initialize the test set 36 | 37 | VOCopts.testset = testset; 38 | %VOCopts.testset='val'; % use validation data for development test set 39 | % VOCopts.testset='test'; % use test set for final challenge 40 | 41 | % initialize main challenge paths 42 | 43 | %VOCopts.annopath=[VOCopts.datadir VOCopts.dataset '/Annotations/%s.xml']; 44 | %VOCopts.imgpath=[VOCopts.datadir VOCopts.dataset '/JPEGImages/%s.jpg']; 45 | %VOCopts.imgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Main/%s.txt']; 46 | %VOCopts.clsimgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Main/%s_%s.txt']; 47 | 48 | VOCopts.annopath=[seg_root '/Annotations/%s.xml']; 49 | VOCopts.imgpath=[seg_root '/JPEGImages/%s.jpg']; 50 | VOCopts.imgsetpath=[seg_root '/ImageSets/Main/%s.txt']; 51 | VOCopts.clsimgsetpath=[seg_root '/ImageSets/Main/%s_%s.txt']; 52 | 53 | 54 | VOCopts.clsrespath=[VOCopts.resdir 'Main/%s_cls_' VOCopts.testset '_%s.txt']; 55 | VOCopts.detrespath=[VOCopts.resdir 'Main/%s_det_' VOCopts.testset '_%s.txt']; 56 | 57 | % initialize segmentation task paths 58 | 59 | %if strcmp(dataset, 'VOC2012') 60 | % VOCopts.seg.clsimgpath=[seg_root '/SegmentationClassAug/%s.png']; 61 | %else 62 | VOCopts.seg.clsimgpath=[seg_root '/SegmentationClass/%s.png']; 63 | %end 64 | 65 | VOCopts.seg.instimgpath=[seg_root '/SegmentationObject/%s.png']; 66 | VOCopts.seg.imgsetpath=[seg_root '/ImageSets/Segmentation/%s.txt']; 67 | 68 | %VOCopts.seg.clsimgpath=[VOCopts.datadir VOCopts.dataset '/SegmentationClass/%s.png']; 69 | 
%VOCopts.seg.instimgpath=[VOCopts.datadir VOCopts.dataset '/SegmentationObject/%s.png']; 70 | %VOCopts.seg.imgsetpath=[VOCopts.dataset '/ImageSets/Segmentation/%s.txt']; 71 | 72 | 73 | VOCopts.seg.clsresdir=[VOCopts.resdir 'Segmentation/%s_%s_cls']; 74 | VOCopts.seg.instresdir=[VOCopts.resdir 'Segmentation/%s_%s_inst']; 75 | VOCopts.seg.clsrespath=[VOCopts.resdir '/%s.png']; 76 | VOCopts.seg.instrespath=[VOCopts.resdir '/%s.png']; 77 | 78 | % initialize layout task paths 79 | 80 | VOCopts.layout.imgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Layout/%s.txt']; 81 | VOCopts.layout.respath=[VOCopts.resdir 'Layout/%s_layout_' VOCopts.testset '.xml']; 82 | 83 | % initialize action task paths 84 | 85 | VOCopts.action.imgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Action/%s.txt']; 86 | VOCopts.action.clsimgsetpath=[VOCopts.datadir VOCopts.dataset '/ImageSets/Action/%s_%s.txt']; 87 | VOCopts.action.respath=[VOCopts.resdir 'Action/%s_action_' VOCopts.testset '_%s.txt']; 88 | 89 | % initialize the VOC challenge options 90 | 91 | % classes 92 | 93 | if ~isempty(strfind(seg_root, 'VOC')) 94 | VOCopts.classes={... 95 | 'aeroplane' 96 | 'bicycle' 97 | 'bird' 98 | 'boat' 99 | 'bottle' 100 | 'bus' 101 | 'car' 102 | 'cat' 103 | 'chair' 104 | 'cow' 105 | 'diningtable' 106 | 'dog' 107 | 'horse' 108 | 'motorbike' 109 | 'person' 110 | 'pottedplant' 111 | 'sheep' 112 | 'sofa' 113 | 'train' 114 | 'tvmonitor'}; 115 | 116 | elseif ~isempty(strfind(seg_root, 'coco')) || ~isempty(strfind(seg_root, 'COCO')) 117 | coco_categories = GetCocoCategories(); 118 | VOCopts.classes = coco_categories.values(); 119 | else 120 | error('Unknown dataset!\n'); 121 | end 122 | 123 | VOCopts.nclasses=length(VOCopts.classes); 124 | 125 | 126 | % poses 127 | 128 | VOCopts.poses={... 129 | 'Unspecified' 130 | 'Left' 131 | 'Right' 132 | 'Frontal' 133 | 'Rear'}; 134 | 135 | VOCopts.nposes=length(VOCopts.poses); 136 | 137 | % layout parts 138 | 139 | VOCopts.parts={... 140 | 'head' 141 | 'hand' 142 | 'foot'}; 143 | 144 | VOCopts.nparts=length(VOCopts.parts); 145 | 146 | VOCopts.maxparts=[1 2 2]; % max of each of above parts 147 | 148 | % actions 149 | 150 | VOCopts.actions={... 151 | 'other' % skip this when training classifiers 152 | 'jumping' 153 | 'phoning' 154 | 'playinginstrument' 155 | 'reading' 156 | 'ridingbike' 157 | 'ridinghorse' 158 | 'running' 159 | 'takingphoto' 160 | 'usingcomputer' 161 | 'walking'}; 162 | 163 | VOCopts.nactions=length(VOCopts.actions); 164 | 165 | % overlap threshold 166 | 167 | VOCopts.minoverlap=0.5; 168 | 169 | % annotation cache for evaluation 170 | 171 | VOCopts.annocachepath=[VOCopts.localdir '%s_anno.mat']; 172 | 173 | % options for example implementations 174 | 175 | VOCopts.exfdpath=[VOCopts.localdir '%s_fd.mat']; 176 | -------------------------------------------------------------------------------- /caffemodel2npy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Chenxi Liu. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # This code is modified from 17 | # https://github.com/ppwwyyxx/tensorpack/blob/master/tensorpack/utils/loadcaffe.py 18 | 19 | # sample usage: 20 | # python caffemodel2npy.py deploy.prototxt 21 | # ../deeplab/ResNet101/init.caffemodel ./model/ResNet101_init.npy 22 | 23 | import numpy as np 24 | import pdb 25 | import re 26 | 27 | 28 | class CaffeLayerProcessor(object): 29 | def __init__(self, net): 30 | self.net = net 31 | self.layer_names = net._layer_names 32 | self.param_dict = {} 33 | self.processors = { 34 | 'Convolution': self.proc_conv, 35 | 'InnerProduct': self.proc_fc, 36 | 'BatchNorm': self.proc_bn, 37 | 'Scale': self.proc_scale 38 | } 39 | 40 | def process(self): 41 | for idx, layer in enumerate(self.net.layers): 42 | param = layer.blobs 43 | name = self.layer_names[idx] 44 | if 'res05' in name or 'res075' in name: 45 | continue # skip the 0.5 and 0.75 resolution branches 46 | if layer.type in self.processors: 47 | name_ = self.rename(name) 48 | dic = self.processors[layer.type](idx, name_, param) 49 | self.param_dict.update(dic) 50 | return self.param_dict 51 | 52 | def rename(self, caffe_layer_name): 53 | if caffe_layer_name.startswith('scale'): 54 | caffe_layer_name = 'bn' + caffe_layer_name[5:] 55 | 56 | NAME_MAP = {'bn_conv1': 'group_1/bn_conv1', 57 | 'conv1': 'group_1/conv1', 58 | 'fc1_voc12_c0': 'fc1_voc12/conv0', 59 | 'fc1_voc12_c1': 'fc1_voc12/conv1', 60 | 'fc1_voc12_c2': 'fc1_voc12/conv2', 61 | 'fc1_voc12_c3': 'fc1_voc12/conv3'} 62 | if caffe_layer_name in NAME_MAP: 63 | return NAME_MAP[caffe_layer_name] 64 | 65 | s = re.search('([a-z]+)([0-9]+)([a-z]+)_', caffe_layer_name) 66 | if s is None: 67 | s = re.search('([a-z]+)([0-9]+)([a-z]+)([0-9]+)_', caffe_layer_name) 68 | layer_block_part1 = s.group(3) 69 | layer_block_part2 = s.group(4) 70 | assert layer_block_part1 in ['a', 'b'] 71 | layer_block = 0 if layer_block_part1 == 'a' else int(layer_block_part2) 72 | else: 73 | layer_block = ord(s.group(3)) - ord('a') 74 | layer_type = s.group(1) 75 | layer_group = s.group(2) 76 | 77 | layer_branch = int(re.search('_branch([0-9])', caffe_layer_name).group(1)) 78 | assert layer_branch in [1, 2] 79 | if layer_branch == 2: 80 | layer_id = re.search('_branch[0-9]([a-z])', caffe_layer_name).group(1) 81 | layer_id = ord(layer_id) - ord('a') + 1 82 | else: 83 | layer_id = 'add' 84 | 85 | TYPE_DICT = {'res':'conv', 'bn':'bn'} 86 | 87 | layer_type = TYPE_DICT[layer_type] 88 | tf_name = 'group_{}_{}/block_{}/{}'.format( 89 | int(layer_group), layer_block, layer_id, layer_type) 90 | print('%s -> %s' % (caffe_layer_name, tf_name)) 91 | return tf_name 92 | 93 | def proc_conv(self, idx, name, param): 94 | assert len(param) <= 2 95 | assert param[0].data.ndim == 4 96 | # caffe: ch_out, ch_in, h, w 97 | W = param[0].data.transpose(2,3,1,0) 98 | if len(param) == 1: 99 | return {name + '/DW': W} 100 | else: 101 | return {name + '/DW': W, 102 | name + '/biases': param[1].data} 103 | 104 | def proc_fc(self, idx, name, param): 105 | # TODO: caffe has a 'transpose' option for fc/W 106 | assert len(param) == 2 107 | prev_layer_name = self.net.bottom_names[name][0] 108 | prev_layer_output = self.net.blobs[prev_layer_name].data 109 | if prev_layer_output.ndim == 4: 110 | W = param[0].data 111 | # original: out x (C x H x W) 112 | W = W.reshape((-1,) + prev_layer_output.shape[1:]).transpose(2,3,1,0) 113 | # becomes: (H x W x C) x out 114 | else: 115 | W = 
param[0].data.transpose() 116 | return {name + '/DW': W.squeeze(), 117 | name + '/biases': param[1].data.squeeze()} 118 | 119 | def proc_bn(self, idx, name, param): 120 | # assert param[2].data[0] == 1.0 121 | return {name + '/mean': param[0].data, 122 | name + '/variance': param[1].data, 123 | name + '/factor': param[2].data } 124 | 125 | def proc_scale(self, idx, name, param): 126 | # bottom_name = self.net.bottom_names[name][0] 127 | # # find the bn layer before this scaling 128 | # for i, layer in enumerate(self.net.layers): 129 | # if layer.type == 'BatchNorm': 130 | # name2 = self.layer_names[i] 131 | # bottom_name2 = self.net.bottom_names[name2][0] 132 | # if bottom_name2 == bottom_name: 133 | # # scaling and BN share the same bottom, should merge 134 | # return {name2 + '/beta': param[1].data, 135 | # name2 + '/gamma': param[0].data } 136 | return {name + '/beta': param[1].data, 137 | name + '/gamma': param[0].data} 138 | # assume this scaling layer is part of some BN 139 | # raise ValueError() 140 | 141 | 142 | def load_caffe(model_desc, model_file): 143 | """ 144 | return a dict of params 145 | """ 146 | import caffe 147 | caffe.set_mode_cpu() 148 | net = caffe.Net(model_desc, model_file, caffe.TEST) 149 | param_dict = CaffeLayerProcessor(net).process() 150 | return param_dict 151 | 152 | if __name__ == '__main__': 153 | import argparse 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument('model') 156 | parser.add_argument('weights') 157 | parser.add_argument('output') 158 | args = parser.parse_args() 159 | ret = load_caffe(args.model, args.weights) 160 | 161 | # pdb.set_trace() 162 | 163 | import numpy as np 164 | np.save(args.output, ret) 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /deeplab_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Chenxi Liu. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # This code is modified from 17 | # https://github.com/tensorflow/models/blob/master/resnet/resnet_model.py 18 | 19 | """DeepLab model. 20 | 21 | Related paper: 22 | https://arxiv.org/abs/1606.00915 23 | """ 24 | from tensorflow.python.training import moving_averages 25 | import numpy as np 26 | import tensorflow as tf 27 | import pdb 28 | 29 | 30 | class DeepLab(object): 31 | """DeepLab model.""" 32 | 33 | def __init__(self, batch_size = 1, 34 | num_classes = 21, 35 | lrn_rate = 0.00025, 36 | lr_decay_step = 20000, 37 | num_residual_units = [3, 4, 23, 3], 38 | use_bottleneck = True, 39 | weight_decay_rate = 0.0005, 40 | relu_leakiness = 0.0, 41 | bn = False, 42 | images = tf.placeholder(tf.float32), # default placeholders; note they are created once, at class-definition time 43 | labels = tf.placeholder(tf.int32), 44 | filters = [64, 256, 512, 1024, 2048], 45 | optimizer = 'mom', 46 | mode = 'eval'): 47 | """DeepLab constructor. 48 | 49 | Args: 50 | images: Batches of images. [batch_size, image_size, image_size, 3] 51 | labels: Batches of labels. [batch_size, image_size, image_size] 52 | mode: One of 'train' and 'eval'. 53 | All remaining arguments: hyperparameters. 54 | """ 55 | self.images = images 56 | self.labels = labels 57 | self.H = tf.shape(self.images)[1] 58 | self.W = tf.shape(self.images)[2] 59 | self.batch_size = batch_size 60 | self.num_classes = num_classes 61 | self.lrn_rate = lrn_rate 62 | self.lr_decay_step = lr_decay_step 63 | self.num_residual_units = num_residual_units 64 | self.use_bottleneck = use_bottleneck 65 | self.weight_decay_rate = weight_decay_rate 66 | self.relu_leakiness = relu_leakiness 67 | self.bn = bn 68 | self.filters = filters 69 | self.optimizer = optimizer 70 | self.mode = mode 71 | self._extra_train_ops = [] 72 | 73 | with tf.variable_scope("DeepLab"): 74 | self.build_graph() 75 | 76 | def build_graph(self): 77 | """Build a whole graph for the model.""" 78 | self._build_model() 79 | if self.mode == 'train': 80 | self._build_train_op() 81 | # self.summaries = tf.summary.merge_all() 82 | 83 | def _stride_arr(self, stride): 84 | """Map a stride scalar to the stride array for tf.nn.conv2d.""" 85 | return [1, stride, stride, 1] 86 | 87 | def _build_model(self): 88 | """Build the core model within the graph.""" 89 | with tf.variable_scope('group_1'): 90 | x = self.images 91 | x = self._conv('conv1', x, 7, 3, 64, self._stride_arr(2)) 92 | x = self._batch_norm('bn_conv1', x) 93 | x = self._relu(x, self.relu_leakiness) 94 | x = tf.nn.max_pool(x, [1, 3, 3, 1], [1, 2, 2, 1], padding='SAME') 95 | 96 | res_func = self._bottleneck_residual 97 | filters = self.filters 98 | 99 | with tf.variable_scope('group_2_0'): 100 | x = res_func(x, filters[0], filters[1], self._stride_arr(1)) 101 | for i in xrange(1, self.num_residual_units[0]): 102 | with tf.variable_scope('group_2_%d' % i): 103 | x = res_func(x, filters[1], filters[1], self._stride_arr(1)) 104 | 105 | with tf.variable_scope('group_3_0'): 106 | x = res_func(x, filters[1], filters[2], self._stride_arr(2)) 107 | for i in xrange(1, 
self.num_residual_units[1]): 109 | with tf.variable_scope('group_3_%d' % i): 110 | x = res_func(x, filters[2], filters[2], self._stride_arr(1)) 111 | 112 | with tf.variable_scope('group_4_0'): 113 | x = res_func(x, filters[2], filters[3], self._stride_arr(1), 2) 114 | for i in xrange(1, self.num_residual_units[2]): 115 | with tf.variable_scope('group_4_%d' % i): 116 | x = res_func(x, filters[3], filters[3], self._stride_arr(1), 2) 117 | 118 | with tf.variable_scope('group_5_0'): 119 | x = res_func(x, filters[3], filters[4], self._stride_arr(1), 4) 120 | for i in xrange(1, self.num_residual_units[3]): 121 | with tf.variable_scope('group_5_%d' % i): 122 | x = res_func(x, filters[4], filters[4], self._stride_arr(1), 4) 123 | 124 | with tf.variable_scope('group_last'): 125 | x = self._relu(x, self.relu_leakiness) 126 | self.res5c = x 127 | 128 | with tf.variable_scope('fc1_voc12'): 129 | x0 = self._conv('conv0', x, 3, filters[4], self.num_classes, self._stride_arr(1), 6, True) 130 | x1 = self._conv('conv1', x, 3, filters[4], self.num_classes, self._stride_arr(1), 12, True) 131 | x2 = self._conv('conv2', x, 3, filters[4], self.num_classes, self._stride_arr(1), 18, True) 132 | x3 = self._conv('conv3', x, 3, filters[4], self.num_classes, self._stride_arr(1), 24, True) 133 | x = tf.add(x0, x1) 134 | x = tf.add(x, x2) 135 | x = tf.add(x, x3) 136 | self.logits = x 137 | x_flat = tf.reshape(x, [-1, self.num_classes]) 138 | pred = tf.nn.softmax(x_flat) 139 | self.pred = tf.reshape(pred, tf.shape(x)) 140 | self.up = tf.image.resize_bilinear(self.pred, [self.H, self.W]) 141 | 142 | def _build_train_op(self): 143 | """Build training specific ops for the graph.""" 144 | labels_coarse = tf.image.resize_nearest_neighbor(self.labels, 145 | [tf.shape(self.pred)[1], tf.shape(self.pred)[2]]) 146 | labels_coarse = tf.squeeze(labels_coarse, squeeze_dims=[3]) 147 | self.labels_coarse = tf.to_int32(labels_coarse) 148 | 149 | # ignore illegal labels 150 | raw_pred = tf.reshape(self.logits, [-1, self.num_classes]) 151 | raw_gt = tf.reshape(self.labels_coarse, [-1,]) 152 | indices = tf.squeeze(tf.where(tf.less_equal(raw_gt, self.num_classes - 1)), 1) 153 | remain_pred = tf.gather(raw_pred, indices) 154 | remain_gt = tf.gather(raw_gt, indices) 155 | 156 | xent = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=remain_pred, 157 | labels=remain_gt) 158 | self.cls_loss = tf.reduce_mean(xent, name='xent') 159 | self.cost = self.cls_loss + self._decay() 160 | # tf.summary.scalar('cost', self.cost) 161 | 162 | self.global_step = tf.Variable(0, name='global_step', trainable=False) 163 | self.learning_rate = tf.train.polynomial_decay(self.lrn_rate, 164 | self.global_step, self.lr_decay_step, power=0.9) 165 | # tf.summary.scalar('learning rate', self.learning_rate) 166 | 167 | tvars = tf.trainable_variables() 168 | 169 | if self.optimizer == 'sgd': 170 | optimizer = tf.train.GradientDescentOptimizer(self.learning_rate) 171 | elif self.optimizer == 'mom': 172 | optimizer = tf.train.MomentumOptimizer(self.learning_rate, 0.9) 173 | else: 174 | raise NameError("Unknown optimizer type %s!" % self.optimizer) 175 | 176 | grads_and_vars = optimizer.compute_gradients(self.cost, var_list=tvars) 177 | var_lr_mult = {} 178 | for var in tvars: 179 | if var.op.name.find(r'fc1_voc12') > 0 and var.op.name.find(r'biases') > 0: 180 | var_lr_mult[var] = 20. 181 | elif var.op.name.find(r'fc1_voc12') > 0: 182 | var_lr_mult[var] = 10. 183 | else: 184 | var_lr_mult[var] = 1. 
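# DeepLab's lr-mult convention: the newly initialized classifier (fc1_voc12)
# trains at 10x the base learning rate (20x for its biases), while the
# pretrained ResNet backbone keeps the base rate; the multipliers are folded
# into the gradients below instead of using per-variable learning rates.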
185 | grads_and_vars = [((g if var_lr_mult[v] == 1 else tf.multiply(var_lr_mult[v], g)), v) 186 | for g, v in grads_and_vars] 187 | 188 | apply_op = optimizer.apply_gradients(grads_and_vars, 189 | global_step=self.global_step, name='train_step') 190 | 191 | train_ops = [apply_op] + self._extra_train_ops 192 | self.train_step = tf.group(*train_ops) 193 | 194 | # TODO(xpan): Consider batch_norm in contrib/layers/python/layers/layers.py 195 | def _batch_norm(self, name, x): 196 | """Batch normalization.""" 197 | with tf.variable_scope(name): 198 | params_shape = [x.get_shape()[-1]] 199 | 200 | beta = tf.get_variable( 201 | 'beta', params_shape, tf.float32, 202 | initializer=tf.constant_initializer(0.0, tf.float32), 203 | trainable=False) 204 | gamma = tf.get_variable( 205 | 'gamma', params_shape, tf.float32, 206 | initializer=tf.constant_initializer(1.0, tf.float32), 207 | trainable=False) 208 | factor = tf.get_variable( 209 | 'factor', 1, tf.float32, 210 | initializer=tf.constant_initializer(1.0, tf.float32), 211 | trainable=False) 212 | 213 | if self.bn: 214 | mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments') 215 | 216 | moving_mean = tf.get_variable( 217 | 'mean', params_shape, tf.float32, 218 | initializer=tf.constant_initializer(0.0, tf.float32), 219 | trainable=False) 220 | moving_variance = tf.get_variable( 221 | 'variance', params_shape, tf.float32, 222 | initializer=tf.constant_initializer(1.0, tf.float32), 223 | trainable=False) 224 | 225 | self._extra_train_ops.append(moving_averages.assign_moving_average( 226 | moving_mean, mean, 0.9)) 227 | self._extra_train_ops.append(moving_averages.assign_moving_average( 228 | moving_variance, variance, 0.9)) 229 | else: 230 | mean = tf.get_variable( 231 | 'mean', params_shape, tf.float32, 232 | initializer=tf.constant_initializer(0.0, tf.float32), 233 | trainable=False) 234 | variance = tf.get_variable( 235 | 'variance', params_shape, tf.float32, 236 | initializer=tf.constant_initializer(1.0, tf.float32), 237 | trainable=False) 238 | 239 | # inv_factor = tf.reciprocal(factor) 240 | inv_factor = tf.div(1., factor) 241 | mean = tf.multiply(inv_factor, mean) 242 | variance = tf.multiply(inv_factor, variance) 243 | 244 | # tf.summary.histogram(mean.op.name, mean) 245 | # tf.summary.histogram(variance.op.name, variance) 246 | # epsilon used to be 1e-5. Maybe 0.001 solves the NaN problem in deeper nets. 
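# 'factor' is the moving-average scale blob from Caffe's BatchNorm layer
# (exported as '<name>/factor' by caffemodel2npy.py); Caffe stores unscaled
# running sums, so dividing the stored mean and variance by it (above)
# recovers the true moving averages.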
247 | y = tf.nn.batch_normalization( 248 | x, mean, variance, beta, gamma, 0.001) 249 | y.set_shape(x.get_shape()) 250 | return y 251 | 252 | def _bottleneck_residual(self, x, in_filter, out_filter, stride, atrous=1): 253 | """Bottleneck residual unit with 3 sub layers.""" 254 | 255 | orig_x = x 256 | 257 | with tf.variable_scope('block_1'): 258 | x = self._conv('conv', x, 1, in_filter, out_filter/4, stride, atrous) 259 | x = self._batch_norm('bn', x) 260 | x = self._relu(x, self.relu_leakiness) 261 | 262 | with tf.variable_scope('block_2'): 263 | x = self._conv('conv', x, 3, out_filter/4, out_filter/4, self._stride_arr(1), atrous) 264 | x = self._batch_norm('bn', x) 265 | x = self._relu(x, self.relu_leakiness) 266 | 267 | with tf.variable_scope('block_3'): 268 | x = self._conv('conv', x, 1, out_filter/4, out_filter, self._stride_arr(1), atrous) 269 | x = self._batch_norm('bn', x) 270 | 271 | with tf.variable_scope('block_add'): 272 | if in_filter != out_filter: 273 | orig_x = self._conv('conv', orig_x, 1, in_filter, out_filter, stride, atrous) 274 | orig_x = self._batch_norm('bn', orig_x) 275 | x += orig_x 276 | x = self._relu(x, self.relu_leakiness) 277 | 278 | tf.logging.info('image after unit %s', x.get_shape()) 279 | return x 280 | 281 | def _decay(self): 282 | """L2 weight decay loss.""" 283 | costs = [] 284 | for var in tf.trainable_variables(): 285 | if var.op.name.find(r'DW') > 0: 286 | costs.append(tf.nn.l2_loss(var)) 287 | # tf.histogram_summary(var.op.name, var) 288 | 289 | return tf.multiply(self.weight_decay_rate, tf.add_n(costs)) 290 | 291 | def _conv(self, name, x, filter_size, in_filters, out_filters, strides, atrous=1, bias=False): 292 | """Convolution.""" 293 | with tf.variable_scope(name): 294 | n = filter_size * filter_size * out_filters 295 | w = tf.get_variable( 296 | 'DW', [filter_size, filter_size, in_filters, out_filters], 297 | tf.float32, initializer=tf.random_normal_initializer( 298 | stddev=np.sqrt(2.0/n))) 299 | if atrous == 1: 300 | conv = tf.nn.conv2d(x, w, strides, padding='SAME') 301 | else: 302 | assert(strides == self._stride_arr(1)) 303 | conv = tf.nn.atrous_conv2d(x, w, rate=atrous, padding='SAME') 304 | if bias: 305 | b = tf.get_variable('biases', [out_filters], initializer=tf.constant_initializer()) 306 | return conv + b 307 | else: 308 | return conv 309 | 310 | def _relu(self, x, leakiness=0.0): 311 | """Relu, with optional leaky support.""" 312 | return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu') 313 | 314 | def _fully_connected(self, x, out_dim): 315 | """FullyConnected layer for final output.""" 316 | x = tf.reshape(x, [self.batch_size, -1]) 317 | w = tf.get_variable( 318 | 'DW', [self.filters[-1], out_dim], 319 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) 320 | b = tf.get_variable('biases', [out_dim], 321 | initializer=tf.constant_initializer()) 322 | return tf.nn.xw_plus_b(x, w, b) 323 | 324 | def _fully_convolutional(self, x, out_dim): 325 | """FullyConvolutional layer for final output.""" 326 | w = tf.get_variable( 327 | 'DW', [1, 1, self.filters[-1], out_dim], 328 | initializer=tf.uniform_unit_scaling_initializer(factor=1.0)) 329 | b = tf.get_variable('biases', [out_dim], 330 | initializer=tf.constant_initializer()) 331 | return tf.nn.conv2d(x, w, self._stride_arr(1), padding='SAME') + b 332 | 333 | def _global_avg_pool(self, x): 334 | assert x.get_shape().ndims == 4 335 | return tf.expand_dims(tf.expand_dims(tf.reduce_mean(x, [1, 2]), 0), 0) 336 | 
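# Minimal smoke test (an illustrative sketch; deeplab_main.py is the real
# entry point): build the model with random weights and run one forward pass
# on a dummy image, checking that the upsampled prediction comes back at the
# input resolution. 321x321 is the usual DeepLab crop size, but any size works.
if __name__ == '__main__':
  model = DeepLab()  # 'eval' mode by default
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    dummy = np.zeros((1, 321, 321, 3), dtype=np.float32)  # BGR, mean-subtracted
    prob = sess.run(model.up, feed_dict={model.images: dummy})
    print(prob.shape)  # expected: (1, 321, 321, 21)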
--------------------------------------------------------------------------------