├── __init__.py
├── lib
│   ├── __init__.py
│   └── utils.py
├── models
│   ├── __init__.py
│   └── fcn8.py
├── paper.pdf
├── media
│   ├── fcn.png
│   ├── gain.png
│   ├── sg_loss.png
│   ├── example1.png
│   ├── example11.png
│   ├── example2.png
│   ├── example3.png
│   ├── example4.png
│   ├── example5.png
│   ├── modification.jpg
│   └── classification_loss.png
├── LICENSE
├── evaluate.py
├── GAIN.py
├── updater.py
├── train_classifier.py
├── visualize.py
├── train_GAIN.py
└── README.md
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/paper.pdf
--------------------------------------------------------------------------------
/media/fcn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/fcn.png
--------------------------------------------------------------------------------
/media/gain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/gain.png
--------------------------------------------------------------------------------
/media/sg_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/sg_loss.png
--------------------------------------------------------------------------------
/media/example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example1.png
--------------------------------------------------------------------------------
/media/example11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example11.png
--------------------------------------------------------------------------------
/media/example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example2.png
--------------------------------------------------------------------------------
/media/example3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example3.png
--------------------------------------------------------------------------------
/media/example4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example4.png
--------------------------------------------------------------------------------
/media/example5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/example5.png
--------------------------------------------------------------------------------
/media/modification.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/modification.jpg
--------------------------------------------------------------------------------
/media/classification_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alokwhitewolf/Guided-Attention-Inference-Network/HEAD/media/classification_loss.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Alok Kumar Bishoyi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import chainer
3 | from chainer import cuda
4 | import fcn
5 | import numpy as np
6 | import tqdm
7 | from models.fcn8 import FCN8s
8 |
9 |
10 | def evaluate():
11 | parser = argparse.ArgumentParser(
12 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
13 | parser.add_argument('--file', type=str, help='model file path')
14 |
15 | args = parser.parse_args()
16 | file = args.file
17 |     print("evaluating:", file)
18 | dataset = fcn.datasets.VOC2011ClassSeg('seg11valid')
19 | n_class = len(dataset.class_names)
20 |
21 | model = FCN8s()
22 | chainer.serializers.load_npz(file, model)
23 |
24 |     gpu = 0  # set to a negative value to evaluate on cpu
25 |
26 | if gpu >= 0:
27 | cuda.get_device(gpu).use()
28 | model.to_gpu()
29 |
30 | lbl_preds, lbl_trues = [], []
31 | for i in tqdm.trange(len(dataset)):
32 | datum, lbl_true = fcn.datasets.transform_lsvrc2012_vgg16(
33 | dataset.get_example(i))
34 | x_data = np.expand_dims(datum, axis=0)
35 | if gpu >= 0:
36 | x_data = cuda.to_gpu(x_data)
37 |
38 | with chainer.no_backprop_mode():
39 | x = chainer.Variable(x_data)
40 | with chainer.using_config('train', False):
41 | model(x)
42 | lbl_pred = chainer.functions.argmax(model.score, axis=1)[0]
43 | lbl_pred = chainer.cuda.to_cpu(lbl_pred.data)
44 |
45 | lbl_preds.append(lbl_pred)
46 | lbl_trues.append(lbl_true)
47 |
48 | acc, acc_cls, mean_iu, fwavacc = fcn.utils.label_accuracy_score(lbl_trues, lbl_preds, n_class)
49 | print('Accuracy: %.4f' % (100 * acc))
50 | print('AccClass: %.4f' % (100 * acc_cls))
51 | print('Mean IoU: %.4f' % (100 * mean_iu))
52 | print('Fwav Acc: %.4f' % (100 * fwavacc))
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     evaluate()
--------------------------------------------------------------------------------
/GAIN.py:
--------------------------------------------------------------------------------
1 | import chainer
2 | import chainer.functions as F
3 |
4 |
5 | class GAIN(chainer.Chain):
6 | def __init__(self):
7 | super(GAIN, self).__init__()
 8 |         # To be overridden in a child class, or
 9 |         # set after initialization via the respective setter methods,
10 |         # e.g. set_final_conv_layer, set_grad_target_layer,
11 |         # set_GAIN_functions
12 |
13 | self.size = None # Size of images
14 | self.GAIN_functions = None # Refer files in /models
15 | self.final_conv_layer = None
16 | self.grad_target_layer = None
17 |
18 | def stream_cl(self, inp, label=None):
19 | h = inp
20 | for key, funcs in self.GAIN_functions.items():
21 | for func in funcs:
22 | h = func(h)
23 | if key == self.final_conv_layer:
24 | activation = h
25 | if key == self.grad_target_layer:
26 | break
27 |
28 | gcam, class_id = self.get_gcam(h, activation, (inp.shape[-2], inp.shape[-1]), label=label)
29 | return gcam, h, class_id
30 |
31 | def stream_am(self, masked_image):
32 | h = masked_image
33 | for key, funcs in self.GAIN_functions.items():
34 | for func in funcs:
35 | h = func(h)
36 |
37 | return h
38 |
39 | def stream_ext(self, inp):
40 | raise NotImplementedError
41 |
42 |     def get_gcam(self, end_output, activations, shape, label):
43 |         self.cleargrads()
44 |         class_id = self.set_init_grad(end_output, label)
45 |         end_output.backward(retain_grad=True)
46 |         grad = activations.grad_var
47 |         grad = F.average_pooling_2d(grad, (grad.shape[-2], grad.shape[-1]), 1)  # global-average-pool the gradients to get per-channel weights
48 |         grad = F.expand_dims(F.reshape(grad, (grad.shape[0]*grad.shape[1], grad.shape[2], grad.shape[3])), 0)
49 |         weights = activations
50 |         weights = F.expand_dims(F.reshape(weights, (weights.shape[0]*weights.shape[1], weights.shape[2], weights.shape[3])), 0)
51 |         gcam = F.resize_images(F.relu(F.convolution_2d(weights, grad, None, 1, 0)), shape)  # weighted sum of activation maps, ReLU, upsample to input size
52 |         return gcam, class_id
53 |
54 | def set_init_grad(self, var, label):
55 | var.grad = self.xp.zeros_like(var.data)
56 | if label is None:
57 | class_id = F.argmax(var).data
58 | var.grad[0][class_id] = 1
59 |
60 | else:
61 | class_id = self.xp.random.choice(label, 1)
62 | var.grad[0][class_id] = 1
63 | return class_id
64 |
65 | def add_freeze_layers(self, links_list):
66 | self.freezed_layers = links_list
67 |
68 | def freeze_layers(self):
69 | for link in self.freezed_layers:
70 | getattr(self, link).disable_update()
71 |
72 | def set_final_conv_layer(self, layername):
73 | self.final_conv_layer = layername
74 |
75 | def set_grad_target_layer(self, layername):
76 | self.grad_target_layer = layername
77 |
78 | def set_GAIN_functions(self, ordered_dict):
79 | for key in ordered_dict.keys():
80 | for item_no in range(len(ordered_dict[key])):
81 | if isinstance(ordered_dict[key][item_no], str):
82 | ordered_dict[key][item_no] = getattr(self, ordered_dict[key][item_no])
83 | self.GAIN_functions = ordered_dict
84 |
85 | @staticmethod
86 | def get_mask(gcam, sigma=.5, w=8):
87 |         gcam = (gcam - F.min(gcam).data)/(F.max(gcam) - F.min(gcam)).data  # normalize gcam to [0, 1]
88 |         mask = F.squeeze(F.sigmoid(w * (gcam - sigma)))  # soft threshold T(A) = sigmoid(w * (A - sigma))
89 | return mask
90 |
91 | @staticmethod
92 | def mask_image(img, mask):
93 | broadcasted_mask = F.broadcast_to(mask, img.shape)
94 | to_subtract = img*broadcasted_mask
95 | return img - to_subtract
96 |
--------------------------------------------------------------------------------
/updater.py:
--------------------------------------------------------------------------------
1 | from chainer.training import StandardUpdater
2 | from chainer import Variable
3 | from chainer import report
4 | from chainer import functions as F
5 | from chainer.backends.cuda import get_array_module
6 | import numpy as np
7 |
8 |
9 | class VOC_ClassificationUpdater(StandardUpdater):
10 | def __init__(self, iterator, optimizer, no_of_classes=20, device=-1):
11 | super(VOC_ClassificationUpdater, self).__init__(iterator, optimizer)
12 | self.device = device
13 |         self.no_of_classes = no_of_classes
14 | 
15 |         self._optimizers['main'].target.freeze_layers()
16 | 
17 |     def update_core(self):
18 | image, labels = self.converter(self.get_iterator('main').next())
19 | assert image.shape[0] == 1, "Batchsize of only 1 is allowed for now"
20 | image = Variable(image)
21 |
22 | if self.device >= 0:
23 | image.to_gpu(self.device)
24 | cl_output = self._optimizers['main'].target.classify(image)
25 | xp = get_array_module(cl_output.data)
26 |
27 | target = xp.asarray([[0]*(self.no_of_classes)]*cl_output.shape[0])
28 | for i in range(labels.shape[0]):
29 |             gt_labels = np.setdiff1d(np.unique(labels[i]).astype(np.int32), (-1, 0)) - 1  # drop border (-1) & background (0)
30 | target[i][gt_labels] = 1
31 | loss = F.sigmoid_cross_entropy(cl_output, target, normalize=True)
32 | report({'Loss':loss}, self.get_optimizer('main').target)
33 | self._optimizers['main'].target.cleargrads()
34 | loss.backward()
35 | self._optimizers['main'].update()
36 |
37 |
38 |
39 | class VOC_GAIN_Updater(StandardUpdater):
40 |
41 | def __init__(self, iterator, optimizer, no_of_classes=20, device=-1, lambd1=1.5, lambd2=1, lambd3=1.5):
42 | super(VOC_GAIN_Updater, self).__init__(iterator, optimizer)
43 | self.device = device
44 | self.no_of_classes = no_of_classes
45 | self.lambd1 = lambd1
46 | self.lambd2 = lambd2
47 | self.lambd3 = lambd3
48 |
49 | self._optimizers['main'].target.freeze_layers()
50 |
51 | def update_core(self):
52 | image, labels = self.converter(self.get_iterator('main').next())
53 | image = Variable(image)
54 |
55 | assert image.shape[0] == 1, "Batchsize of only 1 is allowed for now"
56 |
57 | if self.device >= 0:
58 | image.to_gpu(self.device)
59 |
60 | xp = get_array_module(image.data)
61 |         to_subtract = np.array((-1, 0))
62 |         noise_classes = np.unique(labels[0]).astype(np.int32)
63 |         target = xp.asarray([[0] * (self.no_of_classes)])
64 |         gt_labels = np.setdiff1d(noise_classes, to_subtract) - 1  # drop border (-1) & background (0)
65 | target[0][gt_labels] = 1
66 |
67 | gcam, cl_scores, class_id = self._optimizers['main'].target.stream_cl(image, gt_labels)
68 |
69 | mask = self._optimizers['main'].target.get_mask(gcam)
70 | masked_image = self._optimizers['main'].target.mask_image(image, mask)
71 | masked_output = self._optimizers['main'].target.stream_am(masked_image)
72 | masked_output = F.sigmoid(masked_output)
73 |
74 | cl_loss = F.sigmoid_cross_entropy(cl_scores, target, normalize=True)
75 | am_loss = masked_output[0][class_id][0]
76 |
77 | labels = Variable(labels)
78 | if self.device >= 0:
79 | labels.to_gpu(self.device)
80 | segment_loss = self._optimizers['main'].target(image, labels)
81 | total_loss = self.lambd1 * cl_loss + self.lambd2 * am_loss + self.lambd3*segment_loss
82 | report({'AM_Loss': am_loss}, self.get_optimizer('main').target)
83 | report({'CL_Loss': cl_loss}, self.get_optimizer('main').target)
84 | report({'SG_Loss': segment_loss}, self.get_optimizer('main').target)
85 | report({'TotalLoss': total_loss}, self.get_optimizer('main').target)
86 | self._optimizers['main'].target.cleargrads()
87 | total_loss.backward()
88 | self._optimizers['main'].update()
--------------------------------------------------------------------------------
/train_classifier.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import fcn
4 | import chainer
5 | from models.fcn8 import FCN8s
6 | from chainercv.datasets import VOCSemanticSegmentationDataset
7 | from chainer.iterators import SerialIterator
8 | from chainer.training.trainer import Trainer
9 | from chainer.training import extensions
10 | from chainer.optimizers import Adam, SGD
11 | from updater import VOC_ClassificationUpdater
12 |
13 | import matplotlib
14 | matplotlib.use('Agg')
15 |
16 | def main():
17 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
18 | parser.add_argument('--device', type=int, default=-1, help='gpu id')
19 |     parser.add_argument('--lr_init', type=float, default=5e-5, help='init learning rate')
20 | # parser.add_argument('--lr_trigger', type=float, default=5, help='trigger to decreace learning rate')
21 | # parser.add_argument('--lr_target', type=float, default=5*1e-5, help='target learning rate')
22 | # parser.add_argument('--lr_factor', type=float, default=.75, help='decay factor')
23 | parser.add_argument('--name', type=str, default='classifier', help='name of the experiment')
24 |     parser.add_argument('--resume', action='store_true', help='resume training or not')
25 | parser.add_argument('--snapshot', type=str, help='snapshot file of the trainer to resume from')
26 |
27 | args = parser.parse_args()
28 |
29 | resume = args.resume
30 | device = args.device
31 |
32 | if resume:
33 | load_snapshot_path = args.snapshot
34 |
35 | experiment = args.name
36 | lr_init = args.lr_init
37 | # lr_target = args.lr_target
38 | # lr_factor = args.lr_factor
39 | # lr_trigger_interval = (args.lr_trigger, 'epoch')
40 |
41 |
42 |     os.makedirs('result/'+experiment, exist_ok=True)
43 |     with open('result/'+experiment+'/details.txt', "w+") as f:
44 |         f.write("lr - "+str(lr_init)+"\n")
45 |         f.write("optimizer - "+str(Adam))
46 |         # f.write("lr_trigger_interval - "+str(lr_trigger_interval)+"\n")
47 | 
48 |
49 | if not resume:
50 | # Add the FC layers to original FCN for GAIN
51 | model_own = FCN8s()
52 | model_original = fcn.models.FCN8s()
53 | model_file = fcn.models.FCN8s.download()
54 | chainer.serializers.load_npz(model_file, model_original)
55 |
56 | for layers in model_original._children:
57 | setattr(model_own, layers, getattr(model_original, layers))
58 |         del model_original, model_file
59 |
60 |
61 | else:
62 | model_own = FCN8s()
63 |
64 |     if device >= 0:
65 | model_own.to_gpu(device)
66 |
67 | dataset = VOCSemanticSegmentationDataset()
68 | iterator = SerialIterator(dataset, 1)
69 | optimizer = Adam(alpha=lr_init)
70 | optimizer.setup(model_own)
71 |
72 | updater = VOC_ClassificationUpdater(iterator, optimizer, device=device)
73 | trainer = Trainer(updater, (100, 'epoch'))
74 | log_keys = ['epoch', 'iteration', 'main/Loss']
75 | trainer.extend(extensions.LogReport(log_keys, (100, 'iteration'), log_name='log_'+experiment))
76 | trainer.extend(extensions.PrintReport(log_keys), trigger=(100, 'iteration'))
77 | trainer.extend(extensions.snapshot(filename=experiment+"_snapshot_{.updater.iteration}"), trigger=(5, 'epoch'))
78 | trainer.extend(extensions.snapshot_object(trainer.updater._optimizers['main'].target, experiment+"_model_{.updater.iteration}"), trigger=(5, 'epoch'))
79 | trainer.extend(extensions.PlotReport(['main/Loss'], 'iteration',(100, 'iteration'), file_name='trainer_'+experiment+'/loss.png', grid=True, marker=" "))
80 |
81 | # trainer.extend(extensions.ExponentialShift('lr', lr_factor, target=lr_target), trigger=lr_trigger_interval)
82 | if resume:
83 | chainer.serializers.load_npz(load_snapshot_path, trainer)
84 |
85 | print("Running - - ", experiment)
86 | print('initial lr ', lr_init)
87 | # print('lr_trigger_interval ', lr_trigger_interval)
88 | trainer.run()
89 |
90 | if __name__ == "__main__":
91 | main()
92 |
93 |
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import chainer
4 | from chainer import Variable
5 | import chainer.functions as F
6 | from models.fcn8 import FCN8s
7 | from chainercv.datasets import VOCSemanticSegmentationDataset
8 | from chainer.iterators import SerialIterator
9 | from chainer.serializers import load_npz
10 | from chainer.backends.cuda import get_array_module
11 | from chainercv.datasets.voc.voc_utils import voc_semantic_segmentation_label_names
12 |
13 | from matplotlib import pyplot as plt
14 | import numpy as np
15 | import cupy as cp
16 |
17 |
18 | def main():
19 |     parser = argparse.ArgumentParser(
20 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
21 | 
22 | parser.add_argument('--pretrained', type=str, help='path to model that has trained classifier but has not been trained through GAIN routine')
23 | parser.add_argument('--trained', type=str, help='path to model trained through GAIN')
24 | parser.add_argument('--device', type=int, default=-1, help='gpu id')
25 |     parser.add_argument('--shuffle', action='store_true', help='whether to shuffle the dataset')
26 |     parser.add_argument('--whole', action='store_true', help='whether to test on the whole validation dataset')
27 |     parser.add_argument('--no', type=int, default=10, help='if not whole, number of images to visualize')
28 | parser.add_argument('--name', type=str, help='name of the subfolder or experiment under which to save')
29 |
30 | args = parser.parse_args()
31 |
32 | pretrained_file = args.pretrained
33 | trained_file = args.trained
34 | device = args.device
35 | shuffle = args.shuffle
36 | whole = args.whole
37 | name = args.name
38 | N = args.no
39 |
40 | dataset = VOCSemanticSegmentationDataset()
41 | iterator = SerialIterator(dataset, 1, shuffle=shuffle, repeat=False)
42 | converter = chainer.dataset.concat_examples
43 | os.makedirs('viz/'+name, exist_ok=True)
44 |     no_of_classes = 20
45 | 
46 |     pretrained = FCN8s()
47 |     trained = FCN8s()
48 |     load_npz(pretrained_file, pretrained)
49 |     load_npz(trained_file, trained)
50 |
51 |     if device >= 0:
52 | pretrained.to_gpu()
53 | trained.to_gpu()
54 | i = 0
55 |
56 | while not iterator.is_new_epoch:
57 |
58 | if not whole and i >= N:
59 | break
60 |
61 | image, labels = converter(iterator.next())
62 | image = Variable(image)
63 |         if device >= 0:
64 | image.to_gpu()
65 |
66 | xp = get_array_module(image.data)
67 |         to_subtract = np.array((-1, 0))
68 |         noise_classes = np.unique(labels[0]).astype(np.int32)
69 |         target = xp.asarray([[0]*(no_of_classes)])
70 |         gt_labels = np.setdiff1d(noise_classes, to_subtract) - 1  # drop border (-1) & background (0)
71 |
72 | gcam1, cl_scores1, class_id1 = pretrained.stream_cl(image, gt_labels)
73 | gcam2, cl_scores2, class_id2 = trained.stream_cl(image, gt_labels)
74 |
75 |         if device >= 0:
76 |             class_id1, class_id2 = cp.asnumpy(class_id1), cp.asnumpy(class_id2)
77 | fig1 = plt.figure(figsize=(20,10))
78 | ax1= plt.subplot2grid((3, 9), (0, 0), colspan=3, rowspan=3)
79 | ax1.axis('off')
80 | ax1.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.)
81 |
82 | ax2= plt.subplot2grid((3, 9), (0, 3), colspan=3, rowspan=3)
83 | ax2.axis('off')
84 | ax2.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.)
85 | ax2.imshow(cp.asnumpy(F.squeeze(gcam1[0], 0).data), cmap='jet', alpha=.5)
86 | ax2.set_title("For class - "+str(voc_semantic_segmentation_label_names[cp.asnumpy(class_id1[0])+1]), color='teal')
87 |
88 | ax3= plt.subplot2grid((3, 9), (0, 6), colspan=3, rowspan=3)
89 | ax3.axis('off')
90 | ax3.imshow(cp.asnumpy(F.transpose(F.squeeze(image, 0), (1, 2, 0)).data) / 255.)
91 | ax3.imshow(cp.asnumpy(F.squeeze(gcam2[0], 0).data), cmap='jet', alpha=.5)
92 | ax3.set_title("For class - "+str(voc_semantic_segmentation_label_names[cp.asnumpy(class_id2[0])+1]), color='teal')
93 | fig1.savefig('viz/'+name+'/'+str(i)+'.png')
94 | plt.close()
95 | print(i)
96 | i += 1
97 |
98 | if __name__ == "__main__":
99 | main()
--------------------------------------------------------------------------------
/lib/utils.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import chainer.functions as F
 4 | from chainer.dataset import download
 5 | from chainer.serializers import npz
 6 | from chainer.backends.cuda import get_array_module  # works for numpy and cupy arrays alike
 7 | import numpy as np
 8 | from PIL import Image
 9 | 
10 | 
11 | 
12 | def convert_caffemodel_to_npz(path_caffemodel, path_npz):
13 | from chainer.links.caffe.caffe_function import CaffeFunction
14 | caffemodel = CaffeFunction(path_caffemodel)
15 | npz.save_npz(path_npz, caffemodel, compression=False)
16 |
17 |
18 | def _make_npz(path_npz, url, model):
19 | path_caffemodel = download.cached_download(url)
20 |     print('Now loading caffemodel (this may take a few minutes)')
21 | convert_caffemodel_to_npz(path_caffemodel, path_npz)
22 | npz.load_npz(path_npz, model)
23 | return model
24 |
25 |
26 | def _retrieve(name, url, model):
27 | root = download.get_dataset_directory('pfnet/chainer/models/')
28 | path = os.path.join(root, name)
29 | return download.cache_or_load_file(
30 | path, lambda path: _make_npz(path, url, model),
31 | lambda path: npz.load_npz(path, model))
32 |
33 |
34 | def read_image(path, dtype=np.float32, color=True):
35 | """Read an image from a file.
36 |     This function reads an image from the given file. The image is in CHW format and
37 | the range of its value is :math:`[0, 255]`. If :obj:`color = True`, the
38 | order of the channels is RGB.
39 | Args:
40 | path (string): A path of image file.
41 | dtype: The type of array. The default value is :obj:`~numpy.float32`.
42 | color (bool): This option determines the number of channels.
43 | If :obj:`True`, the number of channels is three. In this case,
44 | the order of the channels is RGB. This is the default behaviour.
45 | If :obj:`False`, this function returns a grayscale image.
46 | Returns:
47 | ~numpy.ndarray: An image.
48 | """
49 |
50 | f = Image.open(path)
51 | try:
52 | if color:
53 | img = f.convert('RGB')
54 | else:
55 | img = f.convert('P')
56 | img = np.asarray(img, dtype=dtype)
57 | finally:
58 | if hasattr(f, 'close'):
59 | f.close()
60 |
61 | return img
62 |
63 | def VGGprepare(image=None, path=None, size=(224, 224)):
64 | """Converts the given image to the numpy array for VGG models.
65 |     Note that you have to call this method before ``__call__``,
66 |     because the pre-trained VGG model requires you to resize the given image,
67 |     convert RGB to BGR, subtract the mean,
68 |     and permute the dimensions before calling.
69 | Args:
70 | image (PIL.Image or numpy.ndarray): Input image.
71 | If an input is ``numpy.ndarray``, its shape must be
72 | ``(height, width)``, ``(height, width, channels)``,
73 | or ``(channels, height, width)``, and
74 | the order of the channels must be RGB.
75 | size (pair of ints): Size of converted images.
76 | If ``None``, the given image is not resized.
77 | Returns:
78 | numpy.ndarray: The converted output array.
79 | """
80 | if path is not None:
81 | image = read_image(path)
82 | if image.ndim == 4:
83 | image = np.squeeze(image, 0)
84 | if isinstance(image, np.ndarray):
85 | if image.ndim == 3:
86 | if image.shape[0] == 1:
87 | image = image[0, :, :]
88 | elif image.shape[0] == 3:
89 | image = image.transpose((1, 2, 0))
90 | image = Image.fromarray(image.astype(np.uint8))
91 |
92 | image = image.convert('RGB')
93 | if size:
94 | image = image.resize(size)
95 | image = np.asarray(image, dtype=np.float32)
96 | image = image[:, :, ::-1]
97 | image -= np.array(
98 | [103.939, 116.779, 123.68], dtype=np.float32)
99 | image = image.transpose((2, 0, 1))
100 | return np.expand_dims(image, 0)
101 |
102 | def VGGprepare_am_input(var):
103 | xp = get_array_module(var)
104 |
105 | # var = F.resize_images(var, size)
106 |     var = F.transpose(var, (0, 2, 3, 1))  # NCHW -> NHWC
107 |     var = F.flip(var, 3)  # RGB -> BGR
108 |     var -= xp.array([[103.939, 116.779, 123.68]], dtype=xp.float32)  # subtract the BGR mean
109 |     var = F.transpose(var, (0, 3, 1, 2))  # back to NCHW
110 | return var
111 |
112 |
--------------------------------------------------------------------------------
/train_GAIN.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import chainer
4 | from models.fcn8 import FCN8s
5 | from chainercv.datasets import VOCSemanticSegmentationDataset
6 | from chainer.iterators import SerialIterator
7 | from chainer.training.trainer import Trainer
8 | from chainer.training import extensions
9 | from chainer.optimizers import Adam
10 | from updater import VOC_GAIN_Updater
11 |
12 | def main():
13 | parser = argparse.ArgumentParser(
14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
15 | parser.add_argument('--device', type=int, default=-1, help='gpu id')
16 | parser.add_argument('--modelfile', help='pretrained model file of FCN8')
17 | parser.add_argument('--lr', type=float, default=1e-7, help='init learning rate')
18 | parser.add_argument('--name', type=str, default='GAIN', help='name of the experiment')
19 |     parser.add_argument('--resume', action='store_true', help='resume training or not')
20 | parser.add_argument('--snapshot', type=str, help='snapshot file to resume from')
21 | parser.add_argument('--lambda1', default=5, type=float, help='lambda1 param')
22 | parser.add_argument('--lambda2', default=1, type=float, help='lambda2 param')
23 | parser.add_argument('--lambda3', default=1.5, type=float, help='lambda3 param')
24 |
25 | args = parser.parse_args()
26 |
27 |
28 | resume = args.resume
29 | device = args.device
30 |
31 | if resume:
32 | load_snapshot_path = args.snapshot
33 | load_model_path = args.modelfile
34 | else:
35 | pretrained_model_path = args.modelfile
36 |
37 | experiment = args.name
38 | lr = args.lr
39 | optim = Adam
40 | training_interval = (20000, 'iteration')
41 | snapshot_interval = (1000, 'iteration')
42 | lambd1 = args.lambda1
43 | lambd2 = args.lambda2
44 | lambd3 = args.lambda3
45 | updtr = VOC_GAIN_Updater
46 |
47 | os.makedirs('result/'+experiment, exist_ok=True)
48 | f = open('result/'+experiment+'/details.txt', "w+")
49 | f.write("lr - "+str(lr)+"\n")
50 | f.write("optimizer - "+str(optim)+"\n")
51 | f.write("lambd1 - "+str(lambd1)+"\n")
52 | f.write("lambd2 - "+str(lambd2)+"\n")
53 | f.write("lambd3 - "+str(lambd3)+"\n")
54 | f.write("training_interval - "+str(training_interval)+"\n")
55 | f.write("Updater - "+str(updtr)+"\n")
56 | f.close()
57 |
58 |     model = FCN8s()
59 |     if resume:
60 |         chainer.serializers.load_npz(load_model_path, model)
61 |     else:
62 |         chainer.serializers.load_npz(pretrained_model_path, model)
63 | 
64 |
65 |
66 | if device >= 0:
67 | model.to_gpu(device)
68 | dataset = VOCSemanticSegmentationDataset()
69 | iterator = SerialIterator(dataset, 1, shuffle=False)
70 |
71 | optimizer = Adam(alpha=lr)
72 | optimizer.setup(model)
73 |
74 |     updater = updtr(iterator, optimizer, device=device, lambd1=lambd1, lambd2=lambd2, lambd3=lambd3)
75 | trainer = Trainer(updater, training_interval)
76 |     log_keys = ['epoch', 'iteration', 'main/AM_Loss', 'main/CL_Loss', 'main/SG_Loss', 'main/TotalLoss']
77 | trainer.extend(extensions.LogReport(log_keys, (10, 'iteration'), log_name='log'+experiment))
78 | trainer.extend(extensions.PrintReport(log_keys), trigger=(100, 'iteration'))
79 | trainer.extend(extensions.ProgressBar(training_length=training_interval, update_interval=100))
80 |
81 | trainer.extend(extensions.snapshot(filename=experiment+'_snapshot_{.updater.iteration}'), trigger=snapshot_interval)
82 | trainer.extend(extensions.snapshot_object(trainer.updater._optimizers['main'].target,
83 | experiment+'_model_{.updater.iteration}'), trigger=snapshot_interval)
84 |
85 | trainer.extend(extensions.PlotReport(['main/AM_Loss'], 'iteration',(20, 'iteration'), file_name=experiment+'/am_loss.png', grid=True, marker=" "))
86 | trainer.extend(extensions.PlotReport(['main/CL_Loss'], 'iteration',(20, 'iteration'), file_name=experiment+'/cl_loss.png', grid=True, marker=" "))
87 | trainer.extend(
88 | extensions.PlotReport(['main/SG_Loss'], 'iteration', (20, 'iteration'), file_name=experiment + '/sg_loss.png',grid=True, marker=" "))
89 | trainer.extend(extensions.PlotReport(['main/TotalLoss'], 'iteration',(20, 'iteration'), file_name=experiment+'/total_loss.png', grid=True, marker=" "))
90 | trainer.extend(extensions.PlotReport(log_keys[2:], 'iteration',(20, 'iteration'), file_name=experiment+'/all_loss.png', grid=True, marker=" "))
91 |
92 | if resume:
93 | chainer.serializers.load_npz(load_snapshot_path, trainer)
94 | print("Running - - ", experiment)
95 | print('initial lr ',lr)
96 | print('optimizer ', optim)
97 | print('lambd1 ', lambd1)
98 | print('lambd2 ', lambd2)
99 | print('lambd3', lambd3)
100 | trainer.run()
101 |
102 | if __name__ == "__main__":
103 | main()
104 |
105 |
106 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Guided Attention for FCN
2 |
3 | ## About
4 | Chainer implementation of *Tell Me Where to Look: Guided Attention Inference Network* (CVPR 2018).
5 | This is an experiment in applying the Guided Attention Inference Network (GAIN), as presented in the paper, to Fully Convolutional Networks (FCN) used for segmentation. The trained FCN8s model is fine-tuned using guided attention.
6 |
7 | ## GAIN
8 | GAIN is based on supervising the attention maps that are produced when we train the network for
9 | the task of interest.
10 |
11 | 
12 | ## FCN
13 | A Fully Convolutional Network is a network architecture that consists of convolution layers followed by deconvolutions that
14 | give the segmentation output.
15 |
16 | 
17 | ## Approach
18 |
19 | * We take the fully trained FCN8s network and add an average pooling layer and fully connected layers after its convolutional layers. We freeze the convolutional layers and
20 | train the fully connected layers to classify the objects. We do this in order to get Grad-CAMs for a particular class, to be used later during GAIN. A condensed view of the added head is shown below the figure.
21 |
22 | 
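
The head added in `models/fcn8.py` is a global average pool over the final convolution features followed by fully connected layers; the snippet below is the relevant excerpt (class 0, the background, is dropped for classification):

```python
# classification head appended after the FCN8s convolutions (models/fcn8.py)
self.fc6_cl = L.Linear(512, 4096)
self.fc7_cl = L.Linear(4096, 4096)
self.score_cl = L.Linear(4096, n_class - 1)  # disregard class 0 for classification
```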
23 |
24 | * Next we train the network as per the GAIN update rule. However, in this implementation I have also considered the segmentation loss along with the
25 | GAIN updates/loss. This is because I found that using only the GAIN updates, though it did lead to convergence of the losses, also resulted in quite a significant dip in segmentation accuracy. In this step, the fully connected layers are frozen and are not updated. The combined objective is sketched below.
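
In code (see `VOC_GAIN_Updater.update_core` in `updater.py`), the objective is the weighted sum below; the weights are exposed as the `--lambda1`, `--lambda2`, and `--lambda3` flags of `train_GAIN.py`:

```python
# abridged from VOC_GAIN_Updater.update_core
total_loss = (self.lambd1 * cl_loss        # classification loss on the full image
              + self.lambd2 * am_loss      # attention mining loss on the masked image
              + self.lambd3 * segment_loss)  # segmentation loss, kept to avoid the accuracy dip
```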
26 |
27 | ## Loss Curves
28 | ### For classification training
29 | 
30 |
31 | ### Segmentation Loss during GAIN updates
32 | 
33 |
34 |
35 | ## Qualitative Results
36 | | Original Image | PreTrained GCAMs | Post GAIN GCAMs |
37 | |:--------------:|:----------------:|:---------------:|
38 | 
39 |
40 | 
41 |
42 | 
43 |
44 | 
45 |
46 |
47 | ## Quantitative Results
48 |
49 |
50 | ### For FCN8s
51 |
52 | | Implementation | Accuracy | Accuracy Class | Mean IU | FWAVACC | Model File |
53 | |:--------------:|:--------:|:--------------:|:-------:|:-------:|:----------:|
54 | | [Original](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/master/voc-fcn8s) | 91.2212 | 77.6146 | 65.5126 | 84.5445 | [`fcn8s_from_caffe.npz`](https://drive.google.com/uc?id=0B9P1L--7Wd2vb0cxV0VhcG1Lb28) |
55 | | Experimental| 90.5962 | **80.4099** | 64.6869 | 83.9952 | **To make public soon** |
56 |
57 | ## How to use
58 | ```bash
59 | pip install chainer
60 | pip install chainercv
61 | pip install cupy
62 | pip install fcn
63 | ```
64 | Training
65 | --------
66 | To train the classifier on top of the pretrained FCN8s chainer model, run
67 | ```bash
68 | python3 train_classifier.py --device 0
69 | ```
70 | This will automatically download the pretrained file and train the classifier on it. You might run into an "xxx.txt file not found" error while running this script. To solve this, go to where your `fcn` library is installed, get the missing file from the fcn repository on GitHub, and place it in the directory structure the error message asks for. For more details, refer to this issue.
71 |
72 |
73 | For GAIN updates,
74 | ```bash
75 | python3 train_GAIN.py --modelfile <path to trained classifier model> --device 0
76 | ```
77 |
78 | The accuracy of the original implementation is computed with `evaluate.py`, which has been borrowed from wkentaro's implementation.
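
To reproduce these numbers, pass the model file to the script (`--file` is its only argument):

```bash
python3 evaluate.py --file <path to model npz>
```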
79 |
80 | Visualization
81 | -------------
82 | ```bash
83 | python3 visualize.py
84 | ```
85 | required arguments -
86 | ```
87 | --pretrained    <path to a model with a trained classifier, before GAIN training>
88 | --trained       <path to a model trained through GAIN>
89 | ```
90 |
91 | optional arguments -
92 | ```
93 | --device=-1    <gpu id; negative value for cpu>
94 | --whole        <flag; test on the whole validation dataset>
95 | --shuffle      <flag; shuffle the dataset>
96 | --no=10        <if not whole, number of images to visualize>
97 | ```
98 |
99 | ## To Do
100 |
101 | - [x] Push Visualization Code
102 |
103 | ## Using GAIN for other models
104 | I have attempted to make GAIN as modular as possible so that it can be used with other models as well. All you need to do is make the GAIN class (which itself inherits from chainer.Chain) the parent class of your model.
105 | Each GAIN model needs a few particular instance variables in order to function. The GAIN module has methods to set every one of them. I would advise you to look at ```models/fcn8.py``` as well as ```GAIN.py``` to get an idea of how they are used; a minimal sketch follows the list below.
106 |
107 | * GAIN_functions - An ordered dict mapping the name of each step to its associated functions.
108 | * final_conv_layer - Name of the last step that produces convolutional activations.
109 | * grad_target_layer - Name of the step from which gradients are collected for computing Grad-CAM.
110 |
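As a concrete illustration, here is a minimal, untested sketch for a hypothetical three-layer classifier; the layer names and shapes are made up for this example:

```python
import collections

import chainer.functions as F
import chainer.links as L

from GAIN import GAIN


class TinyNet(GAIN):
    def __init__(self, n_class=10):
        super(TinyNet, self).__init__()
        with self.init_scope():
            self.conv1 = L.Convolution2D(3, 64, 3, 1, 1)
            self.conv2 = L.Convolution2D(64, 128, 3, 1, 1)
            self.fc = L.Linear(None, n_class)

        # name of each step -> list of callables applied at that step
        self.GAIN_functions = collections.OrderedDict([
            ('conv1', [self.conv1, F.relu]),
            ('conv2', [self.conv2, F.relu]),
            ('prob', [self.fc, F.sigmoid]),
        ])
        self.final_conv_layer = 'conv2'  # last step that produces conv activations
        self.grad_target_layer = 'prob'  # step whose output is backpropagated for Grad-CAM
        self.freezed_layers = ['fc']     # links disabled by freeze_layers() during GAIN updates
```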
111 |
112 | ## Credits
113 | The original FCN module and the fcn package are courtesy of wkentaro.
114 |
115 | ## Citation
116 | If you find this code useful in your research, please consider citing:
117 |
118 | @misc{Alok2018,
119 | Author = {Bishoyi, Alok Kumar},
120 | Title = {Guided Attention Inference Network},
121 | Year = {2018},
122 | Publisher = {GitHub},
123 |         journal = {GitHub repository},
124 | howpublished = {\url{https://github.com/alokwhitewolf/Guided-Attention-Inference-Network}},
125 | }
126 |
--------------------------------------------------------------------------------
/models/fcn8.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import collections
3 |
4 | import chainer
5 | import chainer.functions as F
6 | import chainer.links as L
7 | import numpy as np
8 | from GAIN import GAIN
9 |
10 | from fcn import initializers
11 | 
12 |
13 |
14 | class FCN8s(GAIN):
15 |
16 | def __init__(self, n_class=21):
17 | self.n_class = n_class
18 |         # The layers below keep chainer's default initializers; in practice
19 |         # their weights are overwritten by a pretrained FCN8s snapshot
20 |         # before training (see train_classifier.py).
21 | 
22 | super(FCN8s, self).__init__()
23 | with self.init_scope():
24 | self.conv1_1 = L.Convolution2D(3, 64, 3, 1, 1)
25 | self.conv1_2 = L.Convolution2D(64, 64, 3, 1, 1)
26 |
27 | self.conv2_1 = L.Convolution2D(64, 128, 3, 1, 1)
28 | self.conv2_2 = L.Convolution2D(128, 128, 3, 1, 1)
29 |
30 | self.conv3_1 = L.Convolution2D(128, 256, 3, 1, 1)
31 | self.conv3_2 = L.Convolution2D(256, 256, 3, 1, 1)
32 | self.conv3_3 = L.Convolution2D(256, 256, 3, 1, 1)
33 |
34 | self.conv4_1 = L.Convolution2D(256, 512, 3, 1, 1)
35 | self.conv4_2 = L.Convolution2D(512, 512, 3, 1, 1)
36 | self.conv4_3 = L.Convolution2D(512, 512, 3, 1, 1)
37 |
38 | self.conv5_1 = L.Convolution2D(512, 512, 3, 1, 1)
39 | self.conv5_2 = L.Convolution2D(512, 512, 3, 1, 1)
40 | self.conv5_3 = L.Convolution2D(512, 512, 3, 1, 1)
41 |
42 | self.fc6 = L.Convolution2D(512, 4096, 7, 1, 0)
43 | self.fc7 = L.Convolution2D(4096, 4096, 1, 1, 0)
44 | self.score_fr = L.Convolution2D(4096, n_class, 1, 1, 0)
45 |
46 | self.fc6_cl = L.Linear(512, 4096)
47 | self.fc7_cl = L.Linear(4096, 4096)
48 | self.score_cl = L.Linear(4096, n_class-1) # Disregard 0 class for classification
49 |
50 |
51 | self.upscore2 = L.Deconvolution2D(
52 | n_class, n_class, 4, 2, 0, nobias=True,
53 | initialW=initializers.UpsamplingDeconvWeight())
54 | self.upscore8 = L.Deconvolution2D(
55 | n_class, n_class, 16, 8, 0, nobias=True,
56 | initialW=initializers.UpsamplingDeconvWeight())
57 |
58 | self.score_pool3 = L.Convolution2D(256, n_class, 1, 1, 0)
59 | self.score_pool4 = L.Convolution2D(512, n_class, 1, 1, 0)
60 | self.upscore_pool4 = L.Deconvolution2D(
61 | n_class, n_class, 4, 2, 0, nobias=True,
62 | initialW=initializers.UpsamplingDeconvWeight())
63 |
64 | self.GAIN_functions = collections.OrderedDict([
65 | ('conv1_1', [self.conv1_1, F.relu]),
66 | ('conv1_2', [self.conv1_2, F.relu]),
67 | ('pool1', [_max_pooling_2d]),
68 |
69 | ('conv2_1', [self.conv2_1, F.relu]),
70 | ('conv2_2', [self.conv2_2, F.relu]),
71 | ('pool2', [_max_pooling_2d]),
72 |
73 | ('conv3_1', [self.conv3_1, F.relu]),
74 | ('conv3_2', [self.conv3_2, F.relu]),
75 | ('conv3_3', [self.conv3_3, F.relu]),
76 | ('pool3', [_max_pooling_2d]),
77 |
78 | ('conv4_1', [self.conv4_1, F.relu]),
79 | ('conv4_2', [self.conv4_2, F.relu]),
80 | ('conv4_3', [self.conv4_3, F.relu]),
81 | ('pool4', [_max_pooling_2d]),
82 |
83 | ('conv5_1', [self.conv5_1, F.relu]),
84 | ('conv5_2', [self.conv5_2, F.relu]),
85 | ('conv5_3', [self.conv5_3, F.relu]),
86 | ('pool5', [_max_pooling_2d]),
87 |
88 | ('avg_pool', [_average_pooling_2d]),
89 |
90 | ('fc6_cl', [self.fc6_cl, F.relu]),
91 | ('fc7_cl', [self.fc7_cl, F.relu]),
92 | ('prob', [self.score_cl, F.sigmoid])
93 |
94 | ])
95 | self.final_conv_layer = 'conv5_3'
96 | self.grad_target_layer = 'prob'
97 | self.freezed_layers = ['fc6_cl', 'fc7_cl', 'score_cl']
98 |
99 | def segment(self, x, t=None):
100 | # conv1
101 | self.conv1_1.pad = (100, 100)
102 | h = F.relu(self.conv1_1(x))
103 | conv1_1 = h
104 | h = F.relu(self.conv1_2(conv1_1))
105 | conv1_2 = h
106 | h = _max_pooling_2d(conv1_2)
107 | pool1 = h # 1/2
108 |
109 | # conv2
110 | h = F.relu(self.conv2_1(pool1))
111 | conv2_1 = h
112 | h = F.relu(self.conv2_2(conv2_1))
113 | conv2_2 = h
114 | h = _max_pooling_2d(conv2_2)
115 | pool2 = h # 1/4
116 |
117 | # conv3
118 | h = F.relu(self.conv3_1(pool2))
119 | conv3_1 = h
120 | h = F.relu(self.conv3_2(conv3_1))
121 | conv3_2 = h
122 | h = F.relu(self.conv3_3(conv3_2))
123 | conv3_3 = h
124 | h = _max_pooling_2d(conv3_3)
125 | pool3 = h # 1/8
126 |
127 | # conv4
128 | h = F.relu(self.conv4_1(pool3))
129 | h = F.relu(self.conv4_2(h))
130 | h = F.relu(self.conv4_3(h))
131 | h = _max_pooling_2d(h)
132 | pool4 = h # 1/16
133 |
134 | # conv5
135 | h = F.relu(self.conv5_1(pool4))
136 | h = F.relu(self.conv5_2(h))
137 | h = F.relu(self.conv5_3(h))
138 | h = _max_pooling_2d(h)
139 | pool5 = h # 1/32
140 |
141 | # fc6
142 | h = F.relu(self.fc6(pool5))
143 | h = F.dropout(h, ratio=.5)
144 | fc6 = h # 1/32
145 |
146 | # fc7
147 | h = F.relu(self.fc7(fc6))
148 | h = F.dropout(h, ratio=.5)
149 | fc7 = h # 1/32
150 |
151 | # score_fr
152 | h = self.score_fr(fc7)
153 | score_fr = h # 1/32
154 |
155 | # score_pool3
156 | h = self.score_pool3(pool3)
157 | score_pool3 = h # 1/8
158 |
159 | # score_pool4
160 | h = self.score_pool4(pool4)
161 | score_pool4 = h # 1/16
162 |
163 | # upscore2
164 | h = self.upscore2(score_fr)
165 | upscore2 = h # 1/16
166 |
167 | # score_pool4c
168 | h = score_pool4[:, :,
169 | 5:5 + upscore2.shape[2],
170 | 5:5 + upscore2.shape[3]]
171 | score_pool4c = h # 1/16
172 |
173 | # fuse_pool4
174 | h = upscore2 + score_pool4c
175 | fuse_pool4 = h # 1/16
176 |
177 | # upscore_pool4
178 | h = self.upscore_pool4(fuse_pool4)
179 | upscore_pool4 = h # 1/8
180 |
181 |         # score_pool3c
182 | h = score_pool3[:, :,
183 | 9:9 + upscore_pool4.shape[2],
184 | 9:9 + upscore_pool4.shape[3]]
185 | score_pool3c = h # 1/8
186 |
187 | # fuse_pool3
188 | h = upscore_pool4 + score_pool3c
189 | fuse_pool3 = h # 1/8
190 |
191 | # upscore8
192 | h = self.upscore8(fuse_pool3)
193 | upscore8 = h # 1/1
194 |
195 | # score
196 | h = upscore8[:, :, 31:31 + x.shape[2], 31:31 + x.shape[3]]
197 | score = h # 1/1
198 |         self.score = score
199 |         self.conv1_1.pad = (1, 1)  # restore pad for the classification/GAIN streams
200 | 
201 |         if t is None:
202 |             assert not chainer.config.train
203 |             return
204 | 
205 |         loss = F.softmax_cross_entropy(score, t, normalize=True)
206 |         if np.isnan(float(loss.data)):
207 |             raise ValueError('Loss is nan.')
208 |         chainer.report({'loss': loss}, self)
209 |         return loss
210 |
211 |
212 | def classify(self, x, is_training=True):
213 |         with chainer.using_config('train', False):
214 | # conv1
215 | h = F.relu(self.conv1_1(x))
216 | h = F.relu(self.conv1_2(h))
217 | h = _max_pooling_2d(h)
218 |
219 | # conv2
220 | h = F.relu(self.conv2_1(h))
221 | h = F.relu(self.conv2_2(h))
222 | h = _max_pooling_2d(h)
223 |
224 | # conv3
225 | h = F.relu(self.conv3_1(h))
226 | h = F.relu(self.conv3_2(h))
227 | h = F.relu(self.conv3_3(h))
228 | h = _max_pooling_2d(h)
229 |
230 | # conv4
231 | h = F.relu(self.conv4_1(h))
232 | h = F.relu(self.conv4_2(h))
233 | h = F.relu(self.conv4_3(h))
234 | h = _max_pooling_2d(h)
235 |
236 | # conv5
237 | h = F.relu(self.conv5_1(h))
238 | h = F.relu(self.conv5_2(h))
239 | h = F.relu(self.conv5_3(h))
240 | h = _max_pooling_2d(h)
241 | h = _average_pooling_2d(h)
242 |
243 |         with chainer.using_config('train', is_training):
244 | h = F.relu(F.dropout(self.fc6_cl(h), .5))
245 | h = F.relu(F.dropout(self.fc7_cl(h), .5))
246 | h = self.score_cl(h)
247 |
248 | return h
249 |
250 |
251 | def __call__(self, x, t=None):
252 | return self.segment(x, t)
253 |
254 |
255 | def predict(self, imgs):
256 | lbls = []
257 | for img in imgs:
258 | with chainer.no_backprop_mode(), chainer.using_config('train', False):
259 | x = self.xp.asarray(img[None])
260 | self.__call__(x)
261 | lbl = chainer.functions.argmax(self.score, axis=1)
262 | lbl = chainer.cuda.to_cpu(lbl.array[0])
263 | lbls.append(lbl)
264 | return lbls
265 |
266 | def _max_pooling_2d(x):
267 | return F.max_pooling_2d(x, ksize=2, stride=2, pad=0)
268 |
269 | def _average_pooling_2d(x):
270 | return F.average_pooling_2d(x, ksize=(x.shape[-2], x.shape[-1]))
271 |
--------------------------------------------------------------------------------