├── __init__.py
├── lib
│   ├── __init__.py
│   └── utils.py
├── models
│   ├── __init__.py
│   └── fcn8.py
├── paper.pdf
├── media
│   ├── fcn.png
│   ├── gain.png
│   ├── sg_loss.png
│   ├── example1.png
│   ├── example11.png
│   ├── example2.png
│   ├── example3.png
│   ├── example4.png
│   ├── example5.png
│   ├── modification.jpg
│   └── classification_loss.png
├── LICENSE
├── evaluate.py
├── GAIN.py
├── updater.py
├── train_classifier.py
├── visualize.py
├── train_GAIN.py
└── README.md
/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/paper.pdf -------------------------------------------------------------------------------- /media/fcn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/fcn.png -------------------------------------------------------------------------------- /media/gain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/gain.png -------------------------------------------------------------------------------- /media/sg_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/sg_loss.png -------------------------------------------------------------------------------- /media/example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example1.png -------------------------------------------------------------------------------- /media/example11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example11.png -------------------------------------------------------------------------------- /media/example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example2.png -------------------------------------------------------------------------------- /media/example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example3.png -------------------------------------------------------------------------------- /media/example4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example4.png
-------------------------------------------------------------------------------- /media/example5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example5.png -------------------------------------------------------------------------------- /media/modification.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/modification.jpg -------------------------------------------------------------------------------- /media/classification_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/classification_loss.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Alok Kumar Bishoyi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import chainer 3 | from chainer import cuda 4 | import fcn 5 | import numpy as np 6 | import tqdm 7 | from models.fcn8 import FCN8s 8 | 9 | 10 | def evaluate(): 11 | parser = argparse.ArgumentParser( 12 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 13 | parser.add_argument('--file', type=str, help='model file path') 14 | 15 | args = parser.parse_args() 16 | file = args.file 17 | print("evaluating:", file) 18 | dataset = fcn.datasets.VOC2011ClassSeg('seg11valid') 19 | n_class = len(dataset.class_names) 20 | 21 | model = FCN8s() 22 | chainer.serializers.load_npz(file, model) 23 | 24 | gpu = 0 25 | 26 | if gpu >= 0: 27 | cuda.get_device(gpu).use() 28 | model.to_gpu() 29 | 30 | lbl_preds, lbl_trues = [], [] 31 | for i in tqdm.trange(len(dataset)): 32 | datum, lbl_true = fcn.datasets.transform_lsvrc2012_vgg16( 33 | dataset.get_example(i)) 34 | x_data = np.expand_dims(datum, axis=0) 35 | if gpu >= 0: 36 | x_data = cuda.to_gpu(x_data) 37 | 38 | with chainer.no_backprop_mode(): 39 | x = chainer.Variable(x_data) 40 | with chainer.using_config('train', False): 41 | model(x) 42 | lbl_pred = chainer.functions.argmax(model.score, axis=1)[0] 43 | lbl_pred = chainer.cuda.to_cpu(lbl_pred.data) 44 | 45 | lbl_preds.append(lbl_pred) 46 | lbl_trues.append(lbl_true) 47 | 48 | acc, acc_cls, mean_iu, fwavacc = fcn.utils.label_accuracy_score(lbl_trues, lbl_preds, n_class) 49 | print('Accuracy: %.4f' % (100 * acc)) 50 | print('AccClass: %.4f' % (100 * acc_cls)) 51 | print('Mean IoU: %.4f' % (100 * mean_iu)) 52 | print('Fwav Acc: %.4f' % (100 * fwavacc)) 53 | if __name__ == '__main__': 54 | evaluate() -------------------------------------------------------------------------------- /GAIN.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.functions as F 3 | 4 | 5 | class GAIN(chainer.Chain): 6 | def __init__(self): 7 | super(GAIN, self).__init__() 8 | # To be overridden in a child class, or 9 | # set after initialization via the respective setter methods, 10 | # e.g. set_final_conv_layer, set_grad_target_layer 11 | # and set_GAIN_functions 12 | 13 | self.size = None # Size of input images 14 | self.GAIN_functions = None # See the model definitions in /models 15 | self.final_conv_layer = None 16 | self.grad_target_layer = None 17 | 18 | def stream_cl(self, inp, label=None): 19 | h = inp 20 | for key, funcs in self.GAIN_functions.items(): 21 | for func in funcs: 22 | h = func(h) 23 | if key == self.final_conv_layer: 24 | activation = h 25 | if key == self.grad_target_layer: 26 | break 27 | 28 | gcam, class_id = self.get_gcam(h, activation, (inp.shape[-2], inp.shape[-1]), label=label) 29 | return gcam, h, class_id 30 | 31 | def stream_am(self, masked_image): 32 | h = masked_image 33 | for key, funcs in self.GAIN_functions.items(): 34 | for func in funcs: 35 | h = func(h) 36 | 37 | return h 38 | 39 | def stream_ext(self, inp): 40 | raise NotImplementedError 41 | 42 | def get_gcam(self, end_output, activations, shape, label): 43 | self.cleargrads() 44 | class_id = self.set_init_grad(end_output, label) 45 | end_output.backward(retain_grad=True) 46 | grad = activations.grad_var 47 | grad = F.average_pooling_2d(grad, (grad.shape[-2], grad.shape[-1]), 1) 48 | grad = F.expand_dims(F.reshape(grad, (grad.shape[0]*grad.shape[1], grad.shape[2], grad.shape[3])), 0) 49 | weights =
activations 50 | weights = F.expand_dims(F.reshape(weights, (weights.shape[0]*weights.shape[1], weights.shape[2], weights.shape[3])), 0) 51 | gcam = F.resize_images(F.relu(F.convolution_2d(weights, grad, None, 1, 0)), shape) 52 | return gcam, class_id 53 | 54 | def set_init_grad(self, var, label): 55 | var.grad = self.xp.zeros_like(var.data) 56 | if label is None: 57 | class_id = F.argmax(var).data 58 | var.grad[0][class_id] = 1 59 | 60 | else: 61 | class_id = self.xp.random.choice(label, 1) 62 | var.grad[0][class_id] = 1 63 | return class_id 64 | 65 | def add_freeze_layers(self, links_list): 66 | self.freezed_layers = links_list 67 | 68 | def freeze_layers(self): 69 | for link in self.freezed_layers: 70 | getattr(self, link).disable_update() 71 | 72 | def set_final_conv_layer(self, layername): 73 | self.final_conv_layer = layername 74 | 75 | def set_grad_target_layer(self, layername): 76 | self.grad_target_layer = layername 77 | 78 | def set_GAIN_functions(self, ordered_dict): 79 | for key in ordered_dict.keys(): 80 | for item_no in range(len(ordered_dict[key])): 81 | if isinstance(ordered_dict[key][item_no], str): 82 | ordered_dict[key][item_no] = getattr(self, ordered_dict[key][item_no]) 83 | self.GAIN_functions = ordered_dict 84 | 85 | @staticmethod 86 | def get_mask(gcam, sigma=.5, w=8): 87 | gcam = (gcam - F.min(gcam).data)/(F.max(gcam) - F.min(gcam)).data 88 | mask = F.squeeze(F.sigmoid(w * (gcam - sigma))) 89 | return mask 90 | 91 | @staticmethod 92 | def mask_image(img, mask): 93 | broadcasted_mask = F.broadcast_to(mask, img.shape) 94 | to_subtract = img*broadcasted_mask 95 | return img - to_subtract 96 | -------------------------------------------------------------------------------- /updater.py: -------------------------------------------------------------------------------- 1 | from chainer.training import StandardUpdater 2 | from chainer import Variable 3 | from chainer import report 4 | from chainer import functions as F 5 | from chainer.backends.cuda import get_array_module 6 | import numpy as np 7 | 8 | 9 | class VOC_ClassificationUpdater(StandardUpdater): 10 | def __init__(self, iterator, optimizer, no_of_classes=20, device=-1): 11 | super(VOC_ClassificationUpdater, self).__init__(iterator, optimizer) 12 | self.device = device 13 | self.no_of_classes = no_of_classes 14 | 15 | self._optimizers['main'].target.freeze_layers() 16 | 17 | def update_core(self): 18 | image, labels = self.converter(self.get_iterator('main').next()) 19 | assert image.shape[0] == 1, "Batch size of only 1 is allowed for now" 20 | image = Variable(image) 21 | 22 | if self.device >= 0: 23 | image.to_gpu(self.device) 24 | cl_output = self._optimizers['main'].target.classify(image) 25 | xp = get_array_module(cl_output.data) 26 | 27 | target = xp.asarray([[0]*(self.no_of_classes)]*cl_output.shape[0], dtype=np.int32) 28 | for i in range(labels.shape[0]): 29 | gt_labels = np.setdiff1d(np.unique(labels[i]).astype(np.int32), (-1, 0)) - 1 # Drop the boundary (-1) and background (0) labels 30 | target[i][gt_labels] = 1 31 | loss = F.sigmoid_cross_entropy(cl_output, target, normalize=True) 32 | report({'Loss': loss}, self.get_optimizer('main').target) 33 | self._optimizers['main'].target.cleargrads() 34 | loss.backward() 35 | self._optimizers['main'].update() 36 | 37 | 38 | 39 | class VOC_GAIN_Updater(StandardUpdater): 40 | 41 | def __init__(self, iterator, optimizer, no_of_classes=20, device=-1, lambd1=1.5, lambd2=1, lambd3=1.5): 42 | super(VOC_GAIN_Updater, self).__init__(iterator, optimizer) 43 | self.device = device 44 | self.no_of_classes = no_of_classes 45 | self.lambd1
= lambd1 46 | self.lambd2 = lambd2 47 | self.lambd3 = lambd3 48 | 49 | self._optimizers['main'].target.freeze_layers() 50 | 51 | def update_core(self): 52 | image, labels = self.converter(self.get_iterator('main').next()) 53 | image = Variable(image) 54 | 55 | assert image.shape[0] == 1, "Batch size of only 1 is allowed for now" 56 | 57 | if self.device >= 0: 58 | image.to_gpu(self.device) 59 | 60 | xp = get_array_module(image.data) 61 | to_subtract = np.array((-1, 0)) 62 | noise_classes = np.unique(labels[0]).astype(np.int32) 63 | target = xp.asarray([[0] * (self.no_of_classes)], dtype=np.int32) 64 | gt_labels = np.setdiff1d(noise_classes, to_subtract) - 1 # Drop the boundary (-1) and background (0) labels 65 | target[0][gt_labels] = 1 66 | 67 | gcam, cl_scores, class_id = self._optimizers['main'].target.stream_cl(image, gt_labels) 68 | 69 | mask = self._optimizers['main'].target.get_mask(gcam) 70 | masked_image = self._optimizers['main'].target.mask_image(image, mask) 71 | masked_output = self._optimizers['main'].target.stream_am(masked_image) 72 | masked_output = F.sigmoid(masked_output) 73 | 74 | cl_loss = F.sigmoid_cross_entropy(cl_scores, target, normalize=True) 75 | am_loss = masked_output[0][class_id][0] 76 | 77 | labels = Variable(labels) 78 | if self.device >= 0: 79 | labels.to_gpu(self.device) 80 | segment_loss = self._optimizers['main'].target(image, labels) 81 | total_loss = self.lambd1 * cl_loss + self.lambd2 * am_loss + self.lambd3 * segment_loss 82 | report({'AM_Loss': am_loss}, self.get_optimizer('main').target) 83 | report({'CL_Loss': cl_loss}, self.get_optimizer('main').target) 84 | report({'SG_Loss': segment_loss}, self.get_optimizer('main').target) 85 | report({'TotalLoss': total_loss}, self.get_optimizer('main').target) 86 | self._optimizers['main'].target.cleargrads() 87 | total_loss.backward() 88 | self._optimizers['main'].update() -------------------------------------------------------------------------------- /train_classifier.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import fcn 4 | import chainer 5 | from models.fcn8 import FCN8s 6 | from chainercv.datasets import VOCSemanticSegmentationDataset 7 | from chainer.iterators import SerialIterator 8 | from chainer.training.trainer import Trainer 9 | from chainer.training import extensions 10 | from chainer.optimizers import Adam 11 | from updater import VOC_ClassificationUpdater 12 | 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 18 | parser.add_argument('--device', type=int, default=-1, help='gpu id') 19 | parser.add_argument('--lr_init', type=float, default=5*1e-5, help='init learning rate') 20 | # parser.add_argument('--lr_trigger', type=float, default=5, help='trigger to decrease learning rate') 21 | # parser.add_argument('--lr_target', type=float, default=5*1e-5, help='target learning rate') 22 | # parser.add_argument('--lr_factor', type=float, default=.75, help='decay factor') 23 | parser.add_argument('--name', type=str, default='classifier', help='name of the experiment') 24 | parser.add_argument('--resume', action='store_true', help='resume training') 25 | parser.add_argument('--snapshot', type=str, help='snapshot file of the trainer to resume from') 26 | 27 | args = parser.parse_args() 28 | 29 | resume = args.resume 30 | device = args.device 31 | 32 | if resume: 33 | load_snapshot_path = args.snapshot 34 | 35 |
experiment = args.name 36 | lr_init = args.lr_init 37 | # lr_target = args.lr_target 38 | # lr_factor = args.lr_factor 39 | # lr_trigger_interval = (args.lr_trigger, 'epoch') 40 | 41 | 42 | os.makedirs('result/'+experiment, exist_ok=True) 43 | f = open('result/'+experiment+'/details.txt', "w+") 44 | f.write("lr - "+str(lr_init)+"\n") 45 | f.write("optimizer - "+str(Adam)+"\n") 46 | # f.write("lr_trigger_interval - "+str(lr_trigger_interval)+"\n") 47 | f.close() 48 | 49 | if not resume: 50 | # Add the FC layers to the original FCN for GAIN 51 | model_own = FCN8s() 52 | model_original = fcn.models.FCN8s() 53 | model_file = fcn.models.FCN8s.download() 54 | chainer.serializers.load_npz(model_file, model_original) 55 | 56 | for layers in model_original._children: 57 | setattr(model_own, layers, getattr(model_original, layers)) 58 | del model_original, model_file 59 | 60 | 61 | else: 62 | model_own = FCN8s() 63 | 64 | if device >= 0: 65 | model_own.to_gpu(device) 66 | 67 | dataset = VOCSemanticSegmentationDataset() 68 | iterator = SerialIterator(dataset, 1) 69 | optimizer = Adam(alpha=lr_init) 70 | optimizer.setup(model_own) 71 | 72 | updater = VOC_ClassificationUpdater(iterator, optimizer, device=device) 73 | trainer = Trainer(updater, (100, 'epoch')) 74 | log_keys = ['epoch', 'iteration', 'main/Loss'] 75 | trainer.extend(extensions.LogReport(log_keys, (100, 'iteration'), log_name='log_'+experiment)) 76 | trainer.extend(extensions.PrintReport(log_keys), trigger=(100, 'iteration')) 77 | trainer.extend(extensions.snapshot(filename=experiment+"_snapshot_{.updater.iteration}"), trigger=(5, 'epoch')) 78 | trainer.extend(extensions.snapshot_object(trainer.updater._optimizers['main'].target, experiment+"_model_{.updater.iteration}"), trigger=(5, 'epoch')) 79 | trainer.extend(extensions.PlotReport(['main/Loss'], 'iteration', (100, 'iteration'), file_name='trainer_'+experiment+'/loss.png', grid=True, marker=" ")) 80 | 81 | # trainer.extend(extensions.ExponentialShift('lr', lr_factor, target=lr_target), trigger=lr_trigger_interval) 82 | if resume: 83 | chainer.serializers.load_npz(load_snapshot_path, trainer) 84 | 85 | print("Running - - ", experiment) 86 | print('initial lr ', lr_init) 87 | # print('lr_trigger_interval ', lr_trigger_interval) 88 | trainer.run() 89 | 90 | if __name__ == "__main__": 91 | main() 92 | 93 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import chainer 4 | from chainer import Variable 5 | import chainer.functions as F 6 | from models.fcn8 import FCN8s 7 | from chainercv.datasets import VOCSemanticSegmentationDataset 8 | from chainer.iterators import SerialIterator 9 | from chainer.serializers import load_npz 10 | from chainer.backends.cuda import get_array_module 11 | from chainercv.datasets.voc.voc_utils import voc_semantic_segmentation_label_names 12 | 13 | from matplotlib import pyplot as plt 14 | import numpy as np 15 | import cupy as cp 16 | 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser( 20 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 21 | 22 | parser.add_argument('--pretrained', type=str, help='path to model that has trained classifier but has not been trained through GAIN routine') 23 | parser.add_argument('--trained', type=str, help='path to model trained through GAIN') 24 |
parser.add_argument('--device', type=int, default=-1, help='gpu id') 25 | parser.add_argument('--shuffle', action='store_true', help='whether to shuffle the dataset') 26 | parser.add_argument('--whole', action='store_true', help='whether to test on the whole validation dataset') 27 | parser.add_argument('--no', type=int, default=10, help='if not whole, then no of images to visualize') 28 | parser.add_argument('--name', type=str, help='name of the subfolder or experiment under which to save') 29 | 30 | args = parser.parse_args() 31 | 32 | pretrained_file = args.pretrained 33 | trained_file = args.trained 34 | device = args.device 35 | shuffle = args.shuffle 36 | whole = args.whole 37 | name = args.name 38 | N = args.no 39 | 40 | dataset = VOCSemanticSegmentationDataset() 41 | iterator = SerialIterator(dataset, 1, shuffle=shuffle, repeat=False) 42 | converter = chainer.dataset.concat_examples 43 | os.makedirs('viz/'+name, exist_ok=True) 44 | no_of_classes = 20 45 | 46 | pretrained = FCN8s() 47 | trained = FCN8s() 48 | load_npz(pretrained_file, pretrained) 49 | load_npz(trained_file, trained) 50 | 51 | if device >= 0: 52 | pretrained.to_gpu() 53 | trained.to_gpu() 54 | i = 0 55 | 56 | while not iterator.is_new_epoch: 57 | 58 | if not whole and i >= N: 59 | break 60 | 61 | image, labels = converter(iterator.next()) 62 | image = Variable(image) 63 | if device >= 0: 64 | image.to_gpu() 65 | 66 | xp = get_array_module(image.data) 67 | to_subtract = np.array((-1, 0)) 68 | noise_classes = np.unique(labels[0]).astype(np.int32) 69 | target = xp.asarray([[0]*(no_of_classes)]) 70 | gt_labels = np.setdiff1d(noise_classes, to_subtract) - 1 71 | 72 | gcam1, cl_scores1, class_id1 = pretrained.stream_cl(image, gt_labels) 73 | gcam2, cl_scores2, class_id2 = trained.stream_cl(image, gt_labels) 74 | 75 | 76 | 77 | fig1 = plt.figure(figsize=(20,10)) 78 | ax1 = plt.subplot2grid((3, 9), (0, 0), colspan=3, rowspan=3) 79 | ax1.axis('off') 80 | ax1.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.) 81 | 82 | ax2 = plt.subplot2grid((3, 9), (0, 3), colspan=3, rowspan=3) 83 | ax2.axis('off') 84 | ax2.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.) 85 | ax2.imshow(cp.asnumpy(F.squeeze(gcam1[0], 0).data), cmap='jet', alpha=.5) 86 | ax2.set_title("For class - "+str(voc_semantic_segmentation_label_names[cp.asnumpy(class_id1[0])+1]), color='teal') 87 | 88 | ax3 = plt.subplot2grid((3, 9), (0, 6), colspan=3, rowspan=3) 89 | ax3.axis('off') 90 | ax3.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.)
91 | ax3.imshow(cp.asnumpy(F.squeeze(gcam2[0], 0).data), cmap='jet', alpha=.5) 92 | ax3.set_title("For class - "+str(voc_semantic_segmentation_label_names[cp.asnumpy(class_id2[0])+1]), color='teal') 93 | fig1.savefig('viz/'+name+'/'+str(i)+'.png') 94 | plt.close() 95 | print(i) 96 | i += 1 97 | 98 | if __name__ == "__main__": 99 | main() -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import chainer.functions as F 3 | from chainer.dataset import download 4 | from chainer.serializers import npz 5 | from chainer.backends.cuda import get_array_module 6 | import numpy as np 7 | from PIL import Image 8 | 9 | 10 | 11 | 12 | def convert_caffemodel_to_npz(path_caffemodel, path_npz): 13 | from chainer.links.caffe.caffe_function import CaffeFunction 14 | caffemodel = CaffeFunction(path_caffemodel) 15 | npz.save_npz(path_npz, caffemodel, compression=False) 16 | 17 | 18 | def _make_npz(path_npz, url, model): 19 | path_caffemodel = download.cached_download(url) 20 | print('Now loading caffemodel (this may take a few minutes)') 21 | convert_caffemodel_to_npz(path_caffemodel, path_npz) 22 | npz.load_npz(path_npz, model) 23 | return model 24 | 25 | 26 | def _retrieve(name, url, model): 27 | root = download.get_dataset_directory('pfnet/chainer/models/') 28 | path = os.path.join(root, name) 29 | return download.cache_or_load_file( 30 | path, lambda path: _make_npz(path, url, model), 31 | lambda path: npz.load_npz(path, model)) 32 | 33 | 34 | def read_image(path, dtype=np.float32, color=True): 35 | """Read an image from a file. 36 | This function reads an image from a given file. The image is in CHW format and 37 | the range of its values is :math:`[0, 255]`. If :obj:`color = True`, the 38 | order of the channels is RGB. 39 | Args: 40 | path (string): A path of image file. 41 | dtype: The type of array. The default value is :obj:`~numpy.float32`. 42 | color (bool): This option determines the number of channels. 43 | If :obj:`True`, the number of channels is three. In this case, 44 | the order of the channels is RGB. This is the default behaviour. 45 | If :obj:`False`, this function returns a grayscale image. 46 | Returns: 47 | ~numpy.ndarray: An image. 48 | """ 49 | 50 | f = Image.open(path) 51 | try: 52 | if color: 53 | img = f.convert('RGB') 54 | else: 55 | img = f.convert('P') 56 | img = np.asarray(img, dtype=dtype) 57 | finally: 58 | if hasattr(f, 'close'): 59 | f.close() 60 | 61 | return img 62 | 63 | def VGGprepare(image=None, path=None, size=(224, 224)): 64 | """Converts the given image to a numpy array for VGG models. 65 | Note that you have to call this method before ``__call__``, 66 | because the pre-trained VGG model requires resizing the given image, 67 | converting RGB to BGR, subtracting the mean, 68 | and permuting the dimensions before calling. 69 | Args: 70 | image (PIL.Image or numpy.ndarray): Input image. 71 | If an input is ``numpy.ndarray``, its shape must be 72 | ``(height, width)``, ``(height, width, channels)``, 73 | or ``(channels, height, width)``, and 74 | the order of the channels must be RGB. 75 | size (pair of ints): Size of converted images. 76 | If ``None``, the given image is not resized. 77 | Returns: 78 | numpy.ndarray: The converted output array.
79 | """ 80 | if path is not None: 81 | image = read_image(path) 82 | if image.ndim == 4: 83 | image = np.squeeze(image, 0) 84 | if isinstance(image, np.ndarray): 85 | if image.ndim == 3: 86 | if image.shape[0] == 1: 87 | image = image[0, :, :] 88 | elif image.shape[0] == 3: 89 | image = image.transpose((1, 2, 0)) 90 | image = Image.fromarray(image.astype(np.uint8)) 91 | 92 | image = image.convert('RGB') 93 | if size: 94 | image = image.resize(size) 95 | image = np.asarray(image, dtype=np.float32) 96 | image = image[:, :, ::-1] 97 | image -= np.array( 98 | [103.939, 116.779, 123.68], dtype=np.float32) 99 | image = image.transpose((2, 0, 1)) 100 | return np.expand_dims(image, 0) 101 | 102 | def VGGprepare_am_input(var): 103 | xp = get_array_module(var) 104 | 105 | # var = F.resize_images(var, size) 106 | var = F.transpose(var, (0, 2, 3, 1)) # [[W, H, C]] 107 | var = F.flip(var, 3) 108 | var -= xp.array([[103.939, 116.779, 123.68]], dtype=xp.float32) 109 | var = F.transpose(var, (0, 3, 1, 2)) 110 | return var 111 | 112 | -------------------------------------------------------------------------------- /train_GAIN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import chainer 4 | from models.fcn8 import FCN8s 5 | from chainercv.datasets import VOCSemanticSegmentationDataset 6 | from chainer.iterators import SerialIterator 7 | from chainer.training.trainer import Trainer 8 | from chainer.training import extensions 9 | from chainer.optimizers import Adam 10 | from updater import VOC_GAIN_Updater 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser( 14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('--device', type=int, default=-1, help='gpu id') 16 | parser.add_argument('--modelfile', help='pretrained model file of FCN8') 17 | parser.add_argument('--lr', type=float, default=1e-7, help='init learning rate') 18 | parser.add_argument('--name', type=str, default='GAIN', help='name of the experiment') 19 | parser.add_argument('--resume', type=bool, default=False, help='resume training or not') 20 | parser.add_argument('--snapshot', type=str, help='snapshot file to resume from') 21 | parser.add_argument('--lambda1', default=5, type=float, help='lambda1 param') 22 | parser.add_argument('--lambda2', default=1, type=float, help='lambda2 param') 23 | parser.add_argument('--lambda3', default=1.5, type=float, help='lambda3 param') 24 | 25 | args = parser.parse_args() 26 | 27 | 28 | resume = args.resume 29 | device = args.device 30 | 31 | if resume: 32 | load_snapshot_path = args.snapshot 33 | load_model_path = args.modelfile 34 | else: 35 | pretrained_model_path = args.modelfile 36 | 37 | experiment = args.name 38 | lr = args.lr 39 | optim = Adam 40 | training_interval = (20000, 'iteration') 41 | snapshot_interval = (1000, 'iteration') 42 | lambd1 = args.lambda1 43 | lambd2 = args.lambda2 44 | lambd3 = args.lambda3 45 | updtr = VOC_GAIN_Updater 46 | 47 | os.makedirs('result/'+experiment, exist_ok=True) 48 | f = open('result/'+experiment+'/details.txt', "w+") 49 | f.write("lr - "+str(lr)+"\n") 50 | f.write("optimizer - "+str(optim)+"\n") 51 | f.write("lambd1 - "+str(lambd1)+"\n") 52 | f.write("lambd2 - "+str(lambd2)+"\n") 53 | f.write("lambd3 - "+str(lambd3)+"\n") 54 | f.write("training_interval - "+str(training_interval)+"\n") 55 | f.write("Updater - "+str(updtr)+"\n") 56 | f.close() 57 | 58 | if resume: 59 | model = FCN8s() 60 | chainer.serializers.load_npz(load_model_path, model) 61 | 
else: 62 | model = FCN8s() 63 | chainer.serializers.load_npz(pretrained_model_path, model) 64 | 65 | 66 | if device >= 0: 67 | model.to_gpu(device) 68 | dataset = VOCSemanticSegmentationDataset() 69 | iterator = SerialIterator(dataset, 1, shuffle=False) 70 | 71 | optimizer = Adam(alpha=lr) 72 | optimizer.setup(model) 73 | 74 | updater = updtr(iterator, optimizer, device=device, lambd1=lambd1, lambd2=lambd2, lambd3=lambd3) 75 | trainer = Trainer(updater, training_interval) 76 | log_keys = ['epoch', 'iteration', 'main/AM_Loss', 'main/CL_Loss', 'main/TotalLoss'] 77 | trainer.extend(extensions.LogReport(log_keys, (10, 'iteration'), log_name='log_'+experiment)) 78 | trainer.extend(extensions.PrintReport(log_keys), trigger=(100, 'iteration')) 79 | trainer.extend(extensions.ProgressBar(training_length=training_interval, update_interval=100)) 80 | 81 | trainer.extend(extensions.snapshot(filename=experiment+'_snapshot_{.updater.iteration}'), trigger=snapshot_interval) 82 | trainer.extend(extensions.snapshot_object(trainer.updater._optimizers['main'].target, 83 | experiment+'_model_{.updater.iteration}'), trigger=snapshot_interval) 84 | 85 | trainer.extend(extensions.PlotReport(['main/AM_Loss'], 'iteration', (20, 'iteration'), file_name=experiment+'/am_loss.png', grid=True, marker=" ")) 86 | trainer.extend(extensions.PlotReport(['main/CL_Loss'], 'iteration', (20, 'iteration'), file_name=experiment+'/cl_loss.png', grid=True, marker=" ")) 87 | trainer.extend( 88 | extensions.PlotReport(['main/SG_Loss'], 'iteration', (20, 'iteration'), file_name=experiment + '/sg_loss.png', grid=True, marker=" ")) 89 | trainer.extend(extensions.PlotReport(['main/TotalLoss'], 'iteration', (20, 'iteration'), file_name=experiment+'/total_loss.png', grid=True, marker=" ")) 90 | trainer.extend(extensions.PlotReport(log_keys[2:], 'iteration', (20, 'iteration'), file_name=experiment+'/all_loss.png', grid=True, marker=" ")) 91 | 92 | if resume: 93 | chainer.serializers.load_npz(load_snapshot_path, trainer) 94 | print("Running - - ", experiment) 95 | print('initial lr ', lr) 96 | print('optimizer ', optim) 97 | print('lambd1 ', lambd1) 98 | print('lambd2 ', lambd2) 99 | print('lambd3 ', lambd3) 100 | trainer.run() 101 | 102 | if __name__ == "__main__": 103 | main() 104 | 105 | 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Guided Attention for FCN 2 | 3 | ## About 4 | Chainer implementation of *Tell Me Where to Look: Guided Attention Inference Network* (Li et al., CVPR 2018). 5 | This is an experiment in applying the Guided Attention Inference Network (GAIN), as presented in the paper, to Fully Convolutional Networks (FCN) used for segmentation. The trained FCN8s model is fine-tuned using guided attention. 6 | 7 | ## GAIN 8 | GAIN is based on supervising the attention maps that are produced when we train the network for 9 | the task of interest. 10 | 11 | ![Image](media/gain.png) 12 | ## FCN 13 | A Fully Convolutional Network is an architecture that consists of convolution layers followed by deconvolution layers to 14 | produce the segmentation output. 15 | 16 | ![Image](media/fcn.png) 17 | ## Approach 18 | 19 | * We take the fully trained FCN8s network and add an average pooling layer and fully connected layers after its convolutional layers. We freeze the convolutional layers and 20 | train the fully connected layers to classify the objects present in the image. We do this in order to get Grad-CAMs for a particular class, to be used later during GAIN. 21 | 22 | ![Image](media/modification.jpg) 23 | 24 | * Next we train the network as per the GAIN update rule (see the loss below). However, in this implementation I have also included the segmentation loss along with the 25 | GAIN updates/loss, because using only the GAIN updates, while it did lead to convergence of the losses, also resulted in quite a significant dip in segmentation accuracy. In this step, the fully connected layers are frozen and are not updated.
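Concretely, as implemented in `GAIN.py` and `updater.py`: the normalized Grad-CAM heatmap `A` for the chosen class is turned into a soft mask `T(A) = sigmoid(w * (A - sigma))` (with `sigma = 0.5` and `w = 8` by default), the masked image `I* = I - T(A) * I` is fed through the classification stream again, and the score it still assigns to that class is the attention-mining loss. The total objective minimized in each GAIN update is

```
L_total = lambda1 * L_cl + lambda2 * L_am + lambda3 * L_seg
```

where `L_cl` is the sigmoid cross-entropy classification loss, `L_am` is the attention-mining loss, and `L_seg` is the usual softmax cross-entropy segmentation loss; the weights are set via `--lambda1`, `--lambda2`, and `--lambda3` in `train_GAIN.py`.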
26 | 27 | ## Loss Curves 28 | ### For classification training 29 | ![Image](media/classification_loss.png) 30 | 31 | ### Segmentation Loss during GAIN updates 32 | ![Image](media/sg_loss.png) 33 | 34 | 35 | ## Qualitative Results 36 | Each example shows, from left to right: the original image, the pre-trained Grad-CAMs, and the post-GAIN Grad-CAMs. 37 | 38 | ![Image](media/example2.png) 39 | 40 | ![Image](media/example3.png) 41 | 42 | ![Image](media/example4.png) 43 | 44 | ![Image](media/example5.png) 45 | 46 | 47 | ## Quantitative Results 48 | 49 | 50 | ### For FCN8s 51 | 52 | | Implementation | Accuracy | Accuracy Class | Mean IU | FWAVACC | Model File | 53 | |:--------------:|:--------:|:--------------:|:-------:|:-------:|:----------:| 54 | | [Original](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/master/voc-fcn8s) | 91.2212 | 77.6146 | 65.5126 | 84.5445 | [`fcn8s_from_caffe.npz`](https://drive.google.com/uc?id=0B9P1L--7Wd2vb0cxV0VhcG1Lb28) | 55 | | Experimental | 90.5962 | **80.4099** | 64.6869 | 83.9952 | **To be made public soon** | 56 | 57 | ## How to use 58 | ```bash 59 | pip install chainer 60 | pip install chainercv 61 | pip install cupy 62 | pip install fcn 63 | ``` 64 | Training 65 | -------- 66 | To train the classifier on top of the pretrained FCN8s Chainer model, run 67 | ```bash 68 | python3 train_classifier.py --device 0 69 | ``` 70 | This will automatically download the pretrained file and train the classifier on it. You might run into an "xxx.txt file not found" error while running this script. To solve this, go to the place where your `fcn` library is installed, get the missing file from the fcn repository on GitHub, and put it there, recreating the directory structure named in the error message. For more details, refer to this issue 71 | 72 | 73 | For GAIN updates, 74 | ```bash 75 | python3 train_GAIN.py --modelfile <path to the model trained by train_classifier.py> --device 0 76 | ``` 77 | 78 | The accuracy of the original implementation is computed with `evaluate.py`, which has been borrowed from wkentaro's implementation 79 | 80 | Visualization 81 | ------------- 82 | ```bash 83 | python3 visualize.py 84 | ``` 85 | Required arguments - 86 | ``` 87 | --pretrained <path to model with trained classifier, before the GAIN routine> 88 | --trained <path to model trained through GAIN> 89 | ``` 90 | 91 | Optional arguments - 92 | ``` 93 | --device=-1 <gpu id> 94 | --whole=False <whether to test on the whole validation dataset> 95 | --shuffle=False <whether to shuffle the dataset> 96 | --no=10 <if not whole, the number of images to visualize> 97 | ``` 98 | 99 | ## To Do 100 | 101 | - [x] Push Visualization Code 102 | 103 | ## Using GAIN for other models 104 | I have attempted to make GAIN as modular as possible so that it can be used with other models as well. All you need to do is make the GAIN class (which itself inherits from chainer.Chain) the parent class of your model. 105 | Each GAIN model needs a few particular instance variables in order to function, and the GAIN module has methods to set every single one of them. I would advise you to look at ```models/fcn8.py``` as well as ```GAIN.py``` to get an idea of how they are used; a minimal sketch follows the list below. 106 | 107 | * GAIN_functions - An ordered dict consisting of the names of the steps and their associated functions. 108 | * final_conv_layer - Name of the step after which no convolutions happen 109 | * grad_target_layer - Name of the step from which gradients are to be collected for computing the Grad-CAM
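As a rough, untested sketch of the idea (the network, layer names, and sizes below are hypothetical placeholders, not part of this repository):

```python
import collections

import chainer.functions as F
import chainer.links as L

from GAIN import GAIN


class MyNet(GAIN):  # hypothetical example model
    def __init__(self, n_class=20):
        super(MyNet, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(3, 64, 3, 1, 1)
            self.fc_cl = L.Linear(None, n_class)

        # Ordered mapping of step name -> list of callables; stream_cl and
        # stream_am run these in order to perform the forward pass.
        self.GAIN_functions = collections.OrderedDict([
            ('conv1', [self.conv1, F.relu]),
            ('pool1', [lambda x: F.max_pooling_2d(x, 2)]),
            ('avg_pool', [lambda x: F.average_pooling_2d(x, (x.shape[-2], x.shape[-1]))]),
            ('prob', [self.fc_cl, F.sigmoid]),
        ])
        # Activations of this step are used to compute the Grad-CAM ...
        self.final_conv_layer = 'conv1'
        # ... with gradients taken from the output of this step.
        self.grad_target_layer = 'prob'
```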
110 | 111 | 112 | ## Credits 113 | The original FCN module and the fcn package are courtesy of wkentaro 114 | 115 | ## Citation 116 | If you find this code useful in your research, please consider citing: 117 | 118 | @misc{Alok2018, 119 | Author = {Bishoyi, Alok Kumar}, 120 | Title = {Guided Attention Inference Network}, 121 | Year = {2018}, 122 | Publisher = {GitHub}, 123 | Journal = {GitHub repository}, 124 | Howpublished = {\url{https://github.com/alokwhitewolf/Guided-Attention-Inference-Network}}, 125 | } 126 | -------------------------------------------------------------------------------- /models/fcn8.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import collections 3 | 4 | import chainer 5 | import chainer.functions as F 6 | import chainer.links as L 7 | import numpy as np 8 | from GAIN import GAIN 9 | 10 | from fcn import data 11 | from fcn import initializers 12 | 13 | 14 | class FCN8s(GAIN): 15 | 16 | def __init__(self, n_class=21): 17 | self.n_class = n_class 18 | kwargs = { 19 | 'initialW': chainer.initializers.Zero(), 20 | 'initial_bias': chainer.initializers.Zero(), 21 | } 22 | super(FCN8s, self).__init__() 23 | with self.init_scope(): 24 | self.conv1_1 = L.Convolution2D(3, 64, 3, 1, 1) 25 | self.conv1_2 = L.Convolution2D(64, 64, 3, 1, 1) 26 | 27 | self.conv2_1 = L.Convolution2D(64, 128, 3, 1, 1) 28 | self.conv2_2 = L.Convolution2D(128, 128, 3, 1, 1) 29 | 30 | self.conv3_1 = L.Convolution2D(128, 256, 3, 1, 1) 31 | self.conv3_2 = L.Convolution2D(256, 256, 3, 1, 1) 32 | self.conv3_3 = L.Convolution2D(256, 256, 3, 1, 1) 33 | 34 | self.conv4_1 = L.Convolution2D(256, 512, 3, 1, 1) 35 | self.conv4_2 = L.Convolution2D(512, 512, 3, 1, 1) 36 | self.conv4_3 = L.Convolution2D(512, 512, 3, 1, 1) 37 | 38 | self.conv5_1 = L.Convolution2D(512, 512, 3, 1, 1) 39 | self.conv5_2 = L.Convolution2D(512, 512, 3, 1, 1) 40 | self.conv5_3 = L.Convolution2D(512, 512, 3, 1, 1) 41 | 42 | self.fc6 = L.Convolution2D(512, 4096, 7, 1, 0) 43 | self.fc7 = L.Convolution2D(4096, 4096, 1, 1, 0) 44 | self.score_fr = L.Convolution2D(4096, n_class, 1, 1, 0) 45 | 46 | self.fc6_cl = L.Linear(512, 4096) 47 | self.fc7_cl = L.Linear(4096, 4096) 48 | self.score_cl = L.Linear(4096, n_class-1) # Disregard 0 class for classification 49 | 50 | 51 | self.upscore2 = L.Deconvolution2D( 52 | n_class, n_class, 4, 2, 0, nobias=True, 53 | initialW=initializers.UpsamplingDeconvWeight()) 54 | self.upscore8 = L.Deconvolution2D( 55 | n_class, n_class, 16, 8, 0, nobias=True, 56 | initialW=initializers.UpsamplingDeconvWeight()) 57 | 58 | self.score_pool3 = L.Convolution2D(256, n_class, 1, 1, 0) 59 | self.score_pool4 = L.Convolution2D(512, n_class, 1, 1, 0) 60 | self.upscore_pool4 = L.Deconvolution2D( 61 | n_class, n_class, 4, 2, 0, nobias=True, 62 | initialW=initializers.UpsamplingDeconvWeight()) 63 | 64 | self.GAIN_functions = collections.OrderedDict([ 65 | ('conv1_1', [self.conv1_1, F.relu]), 66 | ('conv1_2', [self.conv1_2, F.relu]), 67 | ('pool1', [_max_pooling_2d]), 68 | 69 | ('conv2_1', [self.conv2_1, F.relu]), 70 | ('conv2_2', [self.conv2_2, F.relu]), 71 | ('pool2', [_max_pooling_2d]), 72 | 73 | ('conv3_1', [self.conv3_1, F.relu]), 74 | ('conv3_2', [self.conv3_2, F.relu]), 75 |
('conv3_3', [self.conv3_3, F.relu]), 76 | ('pool3', [_max_pooling_2d]), 77 | 78 | ('conv4_1', [self.conv4_1, F.relu]), 79 | ('conv4_2', [self.conv4_2, F.relu]), 80 | ('conv4_3', [self.conv4_3, F.relu]), 81 | ('pool4', [_max_pooling_2d]), 82 | 83 | ('conv5_1', [self.conv5_1, F.relu]), 84 | ('conv5_2', [self.conv5_2, F.relu]), 85 | ('conv5_3', [self.conv5_3, F.relu]), 86 | ('pool5', [_max_pooling_2d]), 87 | 88 | ('avg_pool', [_average_pooling_2d]), 89 | 90 | ('fc6_cl', [self.fc6_cl, F.relu]), 91 | ('fc7_cl', [self.fc7_cl, F.relu]), 92 | ('prob', [self.score_cl, F.sigmoid]) 93 | 94 | ]) 95 | self.final_conv_layer = 'conv5_3' 96 | self.grad_target_layer = 'prob' 97 | self.freezed_layers = ['fc6_cl', 'fc7_cl', 'score_cl'] 98 | 99 | def segment(self, x, t=None): 100 | # conv1 101 | self.conv1_1.pad = (100, 100) 102 | h = F.relu(self.conv1_1(x)) 103 | conv1_1 = h 104 | h = F.relu(self.conv1_2(conv1_1)) 105 | conv1_2 = h 106 | h = _max_pooling_2d(conv1_2) 107 | pool1 = h # 1/2 108 | 109 | # conv2 110 | h = F.relu(self.conv2_1(pool1)) 111 | conv2_1 = h 112 | h = F.relu(self.conv2_2(conv2_1)) 113 | conv2_2 = h 114 | h = _max_pooling_2d(conv2_2) 115 | pool2 = h # 1/4 116 | 117 | # conv3 118 | h = F.relu(self.conv3_1(pool2)) 119 | conv3_1 = h 120 | h = F.relu(self.conv3_2(conv3_1)) 121 | conv3_2 = h 122 | h = F.relu(self.conv3_3(conv3_2)) 123 | conv3_3 = h 124 | h = _max_pooling_2d(conv3_3) 125 | pool3 = h # 1/8 126 | 127 | # conv4 128 | h = F.relu(self.conv4_1(pool3)) 129 | h = F.relu(self.conv4_2(h)) 130 | h = F.relu(self.conv4_3(h)) 131 | h = _max_pooling_2d(h) 132 | pool4 = h # 1/16 133 | 134 | # conv5 135 | h = F.relu(self.conv5_1(pool4)) 136 | h = F.relu(self.conv5_2(h)) 137 | h = F.relu(self.conv5_3(h)) 138 | h = _max_pooling_2d(h) 139 | pool5 = h # 1/32 140 | 141 | # fc6 142 | h = F.relu(self.fc6(pool5)) 143 | h = F.dropout(h, ratio=.5) 144 | fc6 = h # 1/32 145 | 146 | # fc7 147 | h = F.relu(self.fc7(fc6)) 148 | h = F.dropout(h, ratio=.5) 149 | fc7 = h # 1/32 150 | 151 | # score_fr 152 | h = self.score_fr(fc7) 153 | score_fr = h # 1/32 154 | 155 | # score_pool3 156 | h = self.score_pool3(pool3) 157 | score_pool3 = h # 1/8 158 | 159 | # score_pool4 160 | h = self.score_pool4(pool4) 161 | score_pool4 = h # 1/16 162 | 163 | # upscore2 164 | h = self.upscore2(score_fr) 165 | upscore2 = h # 1/16 166 | 167 | # score_pool4c 168 | h = score_pool4[:, :, 169 | 5:5 + upscore2.shape[2], 170 | 5:5 + upscore2.shape[3]] 171 | score_pool4c = h # 1/16 172 | 173 | # fuse_pool4 174 | h = upscore2 + score_pool4c 175 | fuse_pool4 = h # 1/16 176 | 177 | # upscore_pool4 178 | h = self.upscore_pool4(fuse_pool4) 179 | upscore_pool4 = h # 1/8 180 | 181 | # score_pool3c 182 | h = score_pool3[:, :, 183 | 9:9 + upscore_pool4.shape[2], 184 | 9:9 + upscore_pool4.shape[3]] 185 | score_pool3c = h # 1/8 186 | 187 | # fuse_pool3 188 | h = upscore_pool4 + score_pool3c 189 | fuse_pool3 = h # 1/8 190 | 191 | # upscore8 192 | h = self.upscore8(fuse_pool3) 193 | upscore8 = h # 1/1 194 | 195 | # score 196 | h = upscore8[:, :, 31:31 + x.shape[2], 31:31 + x.shape[3]] 197 | score = h # 1/1 198 | self.score = score 199 | 200 | if t is None: 201 | assert not chainer.config.train 202 | self.conv1_1.pad = (1, 1) # restore padding so later classify()/stream_cl() calls are unaffected 203 | return 204 | loss = F.softmax_cross_entropy(score, t, normalize=True) 205 | if np.isnan(float(loss.data)): 206 | raise ValueError('Loss is nan.') 207 | chainer.report({'loss': loss}, self) 208 | self.conv1_1.pad = (1, 1) 209 | return loss 210 | 211 | 212 | def classify(self, x, is_training=True): 213 | with chainer.using_config('train', False): 214 | # conv1
215 | h = F.relu(self.conv1_1(x)) 216 | h = F.relu(self.conv1_2(h)) 217 | h = _max_pooling_2d(h) 218 | 219 | # conv2 220 | h = F.relu(self.conv2_1(h)) 221 | h = F.relu(self.conv2_2(h)) 222 | h = _max_pooling_2d(h) 223 | 224 | # conv3 225 | h = F.relu(self.conv3_1(h)) 226 | h = F.relu(self.conv3_2(h)) 227 | h = F.relu(self.conv3_3(h)) 228 | h = _max_pooling_2d(h) 229 | 230 | # conv4 231 | h = F.relu(self.conv4_1(h)) 232 | h = F.relu(self.conv4_2(h)) 233 | h = F.relu(self.conv4_3(h)) 234 | h = _max_pooling_2d(h) 235 | 236 | # conv5 237 | h = F.relu(self.conv5_1(h)) 238 | h = F.relu(self.conv5_2(h)) 239 | h = F.relu(self.conv5_3(h)) 240 | h = _max_pooling_2d(h) 241 | h = _average_pooling_2d(h) 242 | 243 | with chainer.using_config('train', is_training): 244 | h = F.relu(F.dropout(self.fc6_cl(h), .5)) 245 | h = F.relu(F.dropout(self.fc7_cl(h), .5)) 246 | h = self.score_cl(h) 247 | 248 | return h 249 | 250 | 251 | def __call__(self, x, t=None): 252 | return self.segment(x, t) 253 | 254 | 255 | def predict(self, imgs): 256 | lbls = [] 257 | for img in imgs: 258 | with chainer.no_backprop_mode(), chainer.using_config('train', False): 259 | x = self.xp.asarray(img[None]) 260 | self.__call__(x) 261 | lbl = chainer.functions.argmax(self.score, axis=1) 262 | lbl = chainer.cuda.to_cpu(lbl.array[0]) 263 | lbls.append(lbl) 264 | return lbls 265 | 266 | def _max_pooling_2d(x): 267 | return F.max_pooling_2d(x, ksize=2, stride=2, pad=0) 268 | 269 | def _average_pooling_2d(x): 270 | return F.average_pooling_2d(x, ksize=(x.shape[-2], x.shape[-1])) 271 | --------------------------------------------------------------------------------