├── data ├── .gitkeep └── label_coco.txt ├── .pep8 ├── imgs ├── 008.jpg ├── 082.jpg ├── key.jpg └── fpn_008.jpg ├── .gitmodules ├── chainer_maskrcnn ├── utils │ ├── depth_transformer.py │ ├── proposal_target_creator.py │ └── proposal_creator.py ├── functions │ └── roi_align_2d_yx.py ├── model │ ├── extractor │ │ ├── c4_backbone.py │ │ ├── darknet.py │ │ └── feature_pyramid_network.py │ ├── head │ │ ├── resnet_roi_mask_head.py │ │ ├── fpn_roi_mask_head.py │ │ ├── fpn_roi_keypoint_head.py │ │ └── light_roi_mask_head.py │ ├── fpn_maskrcnn_train_chain.py │ ├── maskrcnn_train_chain.py │ ├── rpn │ │ └── multilevel_region_proposal_network.py │ └── maskrcnn.py └── dataset │ ├── depth_dataset.py │ └── coco_dataset.py ├── .gitignore ├── README.md ├── evaluator.py ├── vis.py ├── viewer.py ├── train.py ├── train_keypoints.py └── LICENSE /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.pep8: -------------------------------------------------------------------------------- 1 | [pep8] 2 | exclude=caffe_pb*,.eggs,*.egg,build 3 | 4 | -------------------------------------------------------------------------------- /imgs/008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katotetsuro/chainer-maskrcnn/HEAD/imgs/008.jpg -------------------------------------------------------------------------------- /imgs/082.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katotetsuro/chainer-maskrcnn/HEAD/imgs/082.jpg -------------------------------------------------------------------------------- /imgs/key.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katotetsuro/chainer-maskrcnn/HEAD/imgs/key.jpg -------------------------------------------------------------------------------- /imgs/fpn_008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/katotetsuro/chainer-maskrcnn/HEAD/imgs/fpn_008.jpg -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "roi_align"] 2 | path = chainer_maskrcnn/functions/roi_align 3 | url = https://github.com/katotetsuro/roi_align.git 4 | -------------------------------------------------------------------------------- /chainer_maskrcnn/utils/depth_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DepthTransformer(): 5 | def __call__(self, in_data): 6 | x, bbox, keypoint = in_data 7 | 8 | x += (np.random.rand(1).astype(np.float32) - 0.5) * 30 9 | 10 | return x, bbox, keypoint 11 | -------------------------------------------------------------------------------- /chainer_maskrcnn/functions/roi_align_2d_yx.py: -------------------------------------------------------------------------------- 1 | from .roi_align.roi_align_2d import roi_align_2d 2 | 3 | 4 | def _roi_align_2d_yx(x, indices_and_rois, outh, outw, spatial_scale): 5 | xy_indices_and_rois = indices_and_rois[:, [0, 2, 1, 4, 3]] 6 | pool = roi_align_2d(x, xy_indices_and_rois, outh, outw, spatial_scale) 7 | return pool 8 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pkl 3 | .ipynb_checkpoints 4 | *.ipynb 5 | result 6 | data/**/*.jpg 7 | data/**/*.json 8 | *.zip 9 | *.swp 10 | .DS_Store 11 | data/annotations 12 | data/train2014 13 | data/val2014 14 | *.profile 15 | *.nvvp 16 | data/annotations 17 | data/train2014 18 | data/val2014 19 | data/rgbd 20 | *.stats 21 | *.bag 22 | libs 23 | src 24 | .vscode 25 | -------------------------------------------------------------------------------- /data/label_coco.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorcycle 5 | airplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | couch 59 | potted plant 60 | bed 61 | dining table 62 | toilet 63 | tv 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /chainer_maskrcnn/model/extractor/c4_backbone.py: -------------------------------------------------------------------------------- 1 | from chainer.links.model.vision.resnet import ResNet50Layers 2 | import collections 3 | import chainer.functions as F 4 | from chainer.links import BatchNormalization 5 | 6 | 7 | class C4Backbone(ResNet50Layers): 8 | def __init__(self, pretrained_model): 9 | super().__init__(pretrained_model) 10 | del self.res5 11 | del self.fc6 12 | 13 | for l in self.links(): 14 | if isinstance(l, BatchNormalization): 15 | l.disable_update() 16 | 17 | @property 18 | def functions(self): 19 | return collections.OrderedDict( 20 | [('conv1', [self.conv1, self.bn1, F.relu]), 21 | ('pool1', [lambda x: F.max_pooling_2d(x, ksize=3, stride=2)]), 22 | ('res2', [self.res2]), ('res3', [self.res3]), ('res4', 23 | [self.res4])]) 24 | 25 | def __call__(self, x, **kwargs): 26 | return super().__call__(x, ['res4'], **kwargs)['res4'], 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # chainer-maskrcnn 2 | 3 | ## original paper 4 | Mask R-CNN http://arxiv.org/abs/1703.06870 5 | 6 | Light-Head R-CNN: http://arxiv.org/abs/1711.07264 7 | 8 | # current status 9 | 10 | ## using LightHead architecture 11 | 12 | Good examples :) 13 | 14 | ![](imgs/008.jpg) 15 | ![](imgs/082.jpg) 16 | 17 | many results are available here. 18 | https://drive.google.com/drive/u/1/folders/1BwYDFdGpaRNWTU2HyV18VuDvwqndv0e_ 19 | 20 | ## Feature Pyramid network 21 | 22 | mask accuracy looks better than above. 
23 | 24 | ![](imgs/fpn_008.jpg) 25 | 26 | ## keypoint 27 | 28 | trained for 120,000 iterations 29 | 30 | ![](imgs/key.jpg) 31 | 32 | # todo, issues 33 | 34 | - Memory usage keeps increasing as training proceeds; this is still under investigation. 35 | When training on a p2.xlarge instance with 64 GB of memory, the growth stopped after one full pass over the data, so this may simply be the expected behaviour. 36 | However, training currently does not fit on a 16 GB machine, which is inconvenient, so the investigation continues. 37 | - add prediction notebook 38 | - use COCO 2017 39 | - currently, only the FPN backbone works (backward compatibility with the other configurations is broken) 40 | 41 | # setup 42 | 43 | python=3.6 44 | 45 | ``` 46 | pip install chainer chainercv chainerui cupy cython 47 | pip install -e 'git+https://github.com/pdollar/coco.git#egg=pycocotools&subdirectory=PythonAPI' 48 | ``` 49 | For the second command, see https://github.com/cocodataset/cocoapi/issues/53#issuecomment-306667323 50 | 51 | 52 | Download MSCOCO and extract the following archives under data/ 53 | - train2014.zip 54 | - val2014.zip 55 | - annotations_trainval2014.zip 56 | 57 | # pretrained model 58 | 59 | | architecture | url | 60 | |:-----------|:------------:| 61 | | light head | https://drive.google.com/file/d/10tBJpWkimyr5r_DZ8wXsKPsb7-zm_7BT/view?usp=sharing | 62 | 63 | 64 | # acknowledgement 65 | 66 | [Mask R-CNN implementation with chainercv](https://engineer.dena.jp/2017/12/chainercvmask-r-cnn.html) 67 | 68 | The implementations of MaskRCNN and ProposalTargetCreator draw heavily on this article. 69 | -------------------------------------------------------------------------------- /chainer_maskrcnn/dataset/depth_dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import chainer 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | class DepthDataset(chainer.dataset.DatasetMixin): 8 | n_keypoints = 20 9 | 10 | def __init__(self, path, root='.'): 11 | super().__init__() 12 | with open(path, 'r') as f: 13 | self.data = list(map(lambda x: x.strip(), f.readlines())) 14 | self.root = root 15 | 16 | def __len__(self): 17 | return len(self.data) 18 | 19 | def get_example(self, index): 20 | if index >= self.__len__(): 21 | raise IndexError 22 | 23 | with np.load(Path(self.root).joinpath(self.data[index])) as f: 24 | img = f['depth'] 25 | keypoints = f['keypoints'] 26 | 27 | h, w = img.shape 28 | h -= 1 29 | w -= 1 30 | keypoints[:, :2] = np.clip(keypoints[:, :2], 0, [h, w]) 31 | 32 | if keypoints.shape[1] == 2: 33 | visible = np.zeros((len(keypoints))).reshape((-1, 1)) 34 | visible.fill(2) 35 | keypoints = np.concatenate((keypoints, visible), axis=1) 36 | else: 37 | keypoints[:, 2] = (keypoints[:, 2] > 0.2) * 2 38 | 39 | assert keypoints.shape[1] == 3 40 | 41 | if len(keypoints) > 20: 42 | print(f'more than one person in this sample: {self.data[index]}') 43 | 44 | # compute bounding box 45 | x0 = np.clip( 46 | np.min(keypoints[:, :2], axis=0) - [10, 10], 0, [h, w]) 47 | x1 = np.clip(np.max(keypoints[:, :2], axis=0) + [0, 10], 0, [h, w]) 48 | bbox = np.concatenate([x0, x1]).reshape((1, 4)) 49 | 50 | # deliberately swap the keypoint (y, x) order here to stay consistent with the rest of the pipeline 51 | keypoints[:, :2] = keypoints[:, [1, 0]] 52 | # (number of boxes, number of keypoints, (x, y, visibility)) 53 | keypoints = keypoints[None] 54 | 55 | # make number of channels 3 56 | # roughly scale the depth to a float array in the [0, 255] range 57 | # (confusingly, FasterRCNN's prepare method later divides this by 255) 58 | img = (img.astype(np.float32) - 1000) / 3000 * 255 59 | img = np.stack([img, img, img]) 60 | 61 | return img, bbox, keypoints 62 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/extractor/darknet.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.links as L 3 | import
chainer.functions as F 4 | 5 | 6 | class ConvBatch(chainer.Chain): 7 | def __init__(self, out_channels, ksize, stride, pad, activation): 8 | super().__init__() 9 | with self.init_scope(): 10 | self.c = L.Convolution2D(in_channels=None, out_channels=out_channels, 11 | ksize=ksize, stride=stride, pad=pad) 12 | self.bn = L.BatchNormalization(size=out_channels) 13 | self.activation = activation 14 | 15 | def __call__(self, x): 16 | return self.activation(self.bn(self.c(x))) 17 | 18 | 19 | class Darknet(chainer.Chain): 20 | # determined by network architecture (where stride >1 occurs.) 21 | feat_strides = [16] 22 | # inverse of feat_strides. used in RoIAlign to calculate x in Image Coord to x' in feature map 23 | spatial_scales = list(map(lambda x: 1. / x, feat_strides)) 24 | anchor_base = 16 # from original implementation. why? 25 | # from FPN paper. 26 | anchor_sizes = [64] 27 | # anchor_sizes / anchor_base anchor_base is invisible from lamba function?? 28 | anchor_scales = list(map(lambda x: x / 16., anchor_sizes)) 29 | 30 | def __init__(self, activation=F.relu): 31 | super().__init__() 32 | 33 | with self.init_scope(): 34 | self.conv1 = ConvBatch(16, 3, 1, 1, activation) 35 | self.conv2 = ConvBatch(32, 3, 1, 1, activation) 36 | self.conv3 = ConvBatch(64, 3, 1, 1, activation) 37 | self.conv4 = ConvBatch(128, 3, 1, 1, activation) 38 | self.conv5 = ConvBatch(256, 3, 1, 1, activation) 39 | # self.conv6 = ConvBatch(512, 3, 1, 1) 40 | # self.conv7 = ConvBatch(1024, 3, 1, 1) 41 | # anchor_sizes / anchor_base anchor_base is invisible from lamba function?? 42 | self.anchor_scales = list( 43 | map(lambda x: x / float(self.anchor_base), self.anchor_sizes)) 44 | 45 | def __call__(self, x): 46 | h = self.conv1(x) 47 | h = F.max_pooling_2d(h, ksize=2, stride=2) 48 | h = self.conv2(h) 49 | h = F.max_pooling_2d(h, ksize=2, stride=2) 50 | h = self.conv3(h) 51 | h = F.max_pooling_2d(h, ksize=2, stride=2) 52 | h = self.conv4(h) 53 | h = F.max_pooling_2d(h, ksize=2, stride=2) 54 | h = self.conv5(h) 55 | # h = F.max_pooling_2d(h, ksize=2, stride=2) 56 | # h = self.conv6(h) 57 | # h = F.max_pooling_2d(h, ksize=2, stride=2) 58 | # h = self.conv7(h) 59 | 60 | return h, 61 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/head/resnet_roi_mask_head.py: -------------------------------------------------------------------------------- 1 | # ResNet50を使ったMaskHeadの実装 2 | import chainer 3 | import chainer.links as L 4 | import chainer.functions as F 5 | from chainer.links.model.vision.resnet import ResNet50Layers, BuildingBlock, _global_average_pooling_2d 6 | import numpy as np 7 | import copy 8 | from chainer_maskrcnn.functions.roi_align_2d_yx import _roi_align_2d_yx 9 | 10 | 11 | class ResnetRoIMaskHead(chainer.Chain): 12 | mask_size = 14 13 | 14 | def __init__(self, 15 | n_class, 16 | roi_size, 17 | spatial_scale, 18 | loc_initialW=None, 19 | score_initialW=None, 20 | mask_initialW=None): 21 | # n_class includes the background 22 | super().__init__() 23 | with self.init_scope(): 24 | # res5ブロックがほしいだけなのに全部読み込むのは無駄ではある 25 | resnet50 = ResNet50Layers() 26 | self.res5 = copy.deepcopy(resnet50.res5) 27 | # strideは1にする 28 | self.res5.a.conv1.stride = (1, 1) 29 | self.res5.a.conv4.stride = (1, 1) 30 | # 論文 図3の左から2つめ 31 | self.conv1 = L.Convolution2D( 32 | in_channels=None, out_channels=2048, ksize=3, stride=1, pad=1) 33 | # マスク推定ブランチへ 34 | self.deconv1 = L.Deconvolution2D( 35 | in_channels=None, 36 | out_channels=256, 37 | ksize=2, 38 | stride=2, 39 | pad=0, 40 | 
initialW=mask_initialW) 41 | self.conv2 = L.Convolution2D( 42 | in_channels=None, 43 | out_channels=n_class - 1, 44 | ksize=3, 45 | stride=1, 46 | pad=1, 47 | initialW=mask_initialW) 48 | 49 | self.cls_loc = L.Linear(2048, n_class * 4, initialW=loc_initialW) 50 | self.score = L.Linear(2048, n_class, initialW=score_initialW) 51 | 52 | self.n_class = n_class 53 | self.roi_size = roi_size 54 | self.spatial_scale = spatial_scale 55 | 56 | def __call__(self, x, rois, roi_indices, spatial_scale): 57 | roi_indices = roi_indices.astype(np.float32) 58 | indices_and_rois = self.xp.concatenate( 59 | (roi_indices[:, None], rois), axis=1) 60 | 61 | pool = _roi_align_2d_yx(x, indices_and_rois, self.roi_size, 62 | self.roi_size, spatial_scale) 63 | 64 | # h: 分岐する直前まで 65 | h = F.relu(self.res5(pool)) 66 | h = F.relu(self.conv1(h)) 67 | # global average pooling 68 | gap = _global_average_pooling_2d(h) 69 | roi_cls_locs = self.cls_loc(gap) 70 | roi_scores = self.score(gap) 71 | # mask 72 | mask = self.conv2(F.relu(self.deconv1(h))) 73 | return roi_cls_locs, roi_scores, mask 74 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/extractor/feature_pyramid_network.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.links as L 3 | import chainer.functions as F 4 | from chainer.links.model.vision.resnet import ResNet50Layers 5 | 6 | 7 | class FeaturePyramidNetwork(chainer.Chain): 8 | # determined by network architecture (where stride >1 occurs.) 9 | feat_strides = [4, 8, 16, 32, 64] 10 | # inverse of feat_strides. used in RoIAlign to calculate x in Image Coord to x' in feature map 11 | spatial_scales = list(map(lambda x: 1./x, feat_strides)) 12 | anchor_base = 16 # from original implementation. why? 13 | # from FPN paper. 14 | anchor_sizes = [32, 64, 128, 256, 512] 15 | # anchor_sizes / anchor_base anchor_base is invisible from lamba function?? 16 | anchor_scales = list(map(lambda x: x/16., anchor_sizes)) 17 | 18 | def __init__(self): 19 | super().__init__() 20 | with self.init_scope(): 21 | # bottom up 22 | self.resnet = ResNet50Layers('auto') 23 | del self.resnet.fc6 24 | # top layer (reduce channel) 25 | self.toplayer = L.Convolution2D( 26 | in_channels=None, out_channels=256, ksize=1, stride=1, pad=0) 27 | 28 | # conv layer for top-down pathway 29 | self.conv_p4 = L.Convolution2D(None, 256, ksize=3, stride=1, pad=1) 30 | self.conv_p3 = L.Convolution2D(None, 256, ksize=3, stride=1, pad=1) 31 | self.conv_p2 = L.Convolution2D(None, 256, ksize=3, stride=1, pad=1) 32 | self.conv_p6 = L.Convolution2D(None, 256, ksize=1, stride=2, pad=0) 33 | 34 | # lateral connection 35 | self.lat_p4 = L.Convolution2D( 36 | in_channels=None, out_channels=256, ksize=1, stride=1, pad=0) 37 | self.lat_p3 = L.Convolution2D( 38 | in_channels=None, out_channels=256, ksize=1, stride=1, pad=0) 39 | self.lat_p2 = L.Convolution2D( 40 | in_channels=None, out_channels=256, ksize=1, stride=1, pad=0) 41 | 42 | # anchor_sizes / anchor_base anchor_base is invisible from lamba function?? 
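# note: names defined in a class body are not visible inside a lambda that is itself
# defined in that class body (the class namespace is not an enclosing scope for functions),
# which appears to be why the class-level anchor_scales above hard-codes 16. Recomputing
# the value through self.anchor_base here keeps the two definitions consistent.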
43 | self.anchor_scales = list( 44 | map(lambda x: x/float(self.anchor_base), self.anchor_sizes)) 45 | 46 | def __call__(self, x): 47 | # bottom-up pathway 48 | h = F.relu(self.resnet.bn1(self.resnet.conv1(x))) 49 | h = F.max_pooling_2d(h, ksize=(2, 2)) 50 | c2 = self.resnet.res2(h) 51 | c3 = self.resnet.res3(c2) 52 | c4 = self.resnet.res4(c3) 53 | c5 = self.resnet.res5(c4) 54 | 55 | # top-down 56 | p5 = self.toplayer(c5) 57 | p4 = self.conv_p4( 58 | F.unpooling_2d(p5, ksize=2, outsize=( 59 | c4.shape[2:4])) + self.lat_p4(c4)) 60 | p3 = self.conv_p3( 61 | F.unpooling_2d(p4, ksize=2, outsize=( 62 | c3.shape[2:4])) + self.lat_p3(c3)) 63 | p2 = self.conv_p2( 64 | F.unpooling_2d(p3, ksize=2, outsize=( 65 | c2.shape[2:4])) + self.lat_p2(c2)) 66 | 67 | # other 68 | p6 = self.conv_p6(p5) 69 | 70 | # fine to coarse 71 | return p2, p3, p4, p5, p6 72 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/head/fpn_roi_mask_head.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.links as L 3 | import chainer.functions as F 4 | from chainer.links.model.vision.resnet import ResNet50Layers, BuildingBlock, _global_average_pooling_2d 5 | import numpy as np 6 | import copy 7 | from chainer_maskrcnn.functions.roi_align_2d_yx import _roi_align_2d_yx 8 | 9 | 10 | class FPNRoIMaskHead(chainer.Chain): 11 | mask_size = 28 12 | 13 | def __init__(self, 14 | n_class, 15 | roi_size_box, 16 | roi_size_mask, 17 | loc_initialW=None, 18 | score_initialW=None, 19 | mask_initialW=None): 20 | # n_class includes the background 21 | super().__init__() 22 | with self.init_scope(): 23 | # layers for box prediction path 24 | self.conv1 = L.Convolution2D( 25 | in_channels=None, out_channels=256, ksize=3, pad=1) 26 | self.fc1 = L.Linear(None, 1024) 27 | self.fc2 = L.Linear(None, 1024) 28 | self.cls_loc = L.Linear(1024, 4, initialW=loc_initialW) 29 | self.score = L.Linear(1024, n_class, initialW=score_initialW) 30 | 31 | # mask prediction path 32 | self.mask1 = L.Convolution2D(None, 256, ksize=3, pad=1) 33 | self.mask2 = L.Convolution2D(None, 256, ksize=3, pad=1) 34 | self.mask3 = L.Convolution2D(None, 256, ksize=3, pad=1) 35 | self.mask4 = L.Convolution2D(None, 256, ksize=3, pad=1) 36 | self.deconv1 = L.Deconvolution2D( 37 | in_channels=None, 38 | out_channels=256, 39 | ksize=2, 40 | stride=2, 41 | pad=0, 42 | initialW=mask_initialW) 43 | self.conv2 = L.Convolution2D( 44 | in_channels=None, 45 | out_channels=n_class - 1, 46 | ksize=1, 47 | stride=1, 48 | pad=0, 49 | initialW=mask_initialW) 50 | 51 | self.n_class = n_class 52 | self.roi_size_box = roi_size_box 53 | self.roi_size_mask = roi_size_mask 54 | 55 | def __call__(self, x, indices_and_rois, levels, spatial_scales): 56 | 57 | pool_box = list() 58 | levels = chainer.cuda.to_cpu(levels).astype(np.int32) 59 | for l, i in zip(levels, indices_and_rois): 60 | pool_box.append(_roi_align_2d_yx(x[l], i[None], self.roi_size_box, 61 | self.roi_size_box, spatial_scales[l])) 62 | 63 | pool_box = F.concat(pool_box, axis=0) 64 | 65 | h = F.relu(self.conv1(pool_box)) 66 | h = F.relu(self.fc1(h)) 67 | h = F.relu(self.fc2(h)) 68 | roi_cls_locs = self.cls_loc(h) 69 | roi_scores = self.score(h) 70 | # at prediction time, we use two pass method. 
71 | # at first path, we predict box location and class 72 | # at second path, we predict mask with accurate location from first path 73 | if chainer.config.train: 74 | pool_mask = list() 75 | for l, i in zip(levels, indices_and_rois): 76 | pool_mask.append(_roi_align_2d_yx(x[l], i[None], self.roi_size_mask, 77 | self.roi_size_mask, spatial_scales[l])) 78 | pool_mask = F.concat(pool_mask, axis=0) 79 | mask = F.relu(self.mask1(pool_mask)) 80 | mask = F.relu(self.mask2(mask)) 81 | mask = F.relu(self.mask3(mask)) 82 | mask = F.relu(self.mask4(mask)) 83 | mask = self.conv2(self.deconv1(mask)) 84 | return roi_cls_locs, roi_scores, mask 85 | else: 86 | # cache 87 | self.x = x 88 | return roi_cls_locs, roi_scores 89 | 90 | def predict_mask(self, levels, indices_and_rois, spatial_scales): 91 | pool_mask = list() 92 | for l, i in zip(levels, indices_and_rois): 93 | pool_mask.append(_roi_align_2d_yx(self.x[l], i[None], self.roi_size_mask, 94 | self.roi_size_mask, spatial_scales[l])) 95 | pool_mask = F.concat(pool_mask, axis=0) 96 | mask = F.relu(self.mask1(pool_mask)) 97 | mask = F.relu(self.mask2(mask)) 98 | mask = F.relu(self.mask3(mask)) 99 | mask = F.relu(self.mask4(mask)) 100 | mask = self.conv2(self.deconv1(mask)) 101 | 102 | return mask 103 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | from chainer import reporter 5 | import chainer.training.extensions 6 | 7 | from chainercv.evaluations import eval_instance_segmentation_voc 8 | from chainercv.utils import apply_to_iterator 9 | 10 | 11 | class InstanceSegmentationVOCEvaluator(chainer.training.extensions.Evaluator): 12 | 13 | """An evaluation extension of instance-segmentation by PASCAL VOC metric. 14 | 15 | This extension iterates over an iterator and evaluates the prediction 16 | results by average precisions (APs) and mean of them 17 | (mean Average Precision, mAP). 18 | This extension reports the following values with keys. 19 | Please note that :obj:`'ap/'` is reported only if 20 | :obj:`label_names` is specified. 21 | 22 | * :obj:`'map'`: Mean of average precisions (mAP). 23 | * :obj:`'ap/'`: Average precision for class \ 24 | :obj:`label_names[l]`, where :math:`l` is the index of the class. \ 25 | For example, this evaluator reports :obj:`'ap/aeroplane'`, \ 26 | :obj:`'ap/bicycle'`, etc. if :obj:`label_names` is \ 27 | :obj:`~chainercv.datasets.sbd_instance_segmentation_label_names`. \ 28 | If there is no bounding box assigned to class :obj:`label_names[l]` \ 29 | in either ground truth or prediction, it reports :obj:`numpy.nan` as \ 30 | its average precision. \ 31 | In this case, mAP is computed without this class. 32 | 33 | Args: 34 | iterator (chainer.Iterator): An iterator. Each sample should be 35 | following tuple :obj:`img, bbox, label` or 36 | :obj:`img, bbox, label, difficult`. 37 | :obj:`img` is an image, :obj:`bbox` is coordinates of bounding 38 | boxes, :obj:`label` is labels of the bounding boxes and 39 | :obj:`difficult` is whether the bounding boxes are difficult or 40 | not. If :obj:`difficult` is returned, difficult ground truth 41 | will be ignored from evaluation. 42 | target (chainer.Link): An instance-segmentation link. This link must 43 | have :meth:`predict` method that takes a list of images and returns 44 | :obj:`bboxes`, :obj:`labels` and :obj:`scores`. 
45 | iou_thresh (float): Intersection over Union (IoU) threshold for 46 | calulating average precision. The default value is 0.5. 47 | use_07_metric (bool): Whether to use PASCAL VOC 2007 evaluation metric 48 | for calculating average precision. The default value is 49 | :obj:`False`. 50 | label_names (iterable of strings): An iterable of names of classes. 51 | If this value is specified, average precision for each class is 52 | also reported with the key :obj:`'ap/'`. 53 | """ 54 | 55 | trigger = 1, 'epoch' 56 | default_name = 'validation' 57 | priority = chainer.training.PRIORITY_WRITER 58 | 59 | def __init__( 60 | self, iterator, target, 61 | iou_thresh=0.5, use_07_metric=False, label_names=None 62 | ): 63 | super().__init__(iterator, target) 64 | self.iou_thresh = iou_thresh 65 | self.use_07_metric = use_07_metric 66 | self.label_names = label_names 67 | 68 | def evaluate(self): 69 | iterator = self._iterators['main'] 70 | target = self._targets['main'] 71 | 72 | if hasattr(iterator, 'reset'): 73 | iterator.reset() 74 | it = iterator 75 | else: 76 | it = copy.copy(iterator) 77 | 78 | in_values, out_values, rest_values = apply_to_iterator( 79 | target.predict, it) 80 | # delete unused iterators explicitly 81 | del in_values 82 | 83 | pred_masks, pred_labels, pred_scores = out_values 84 | gt_masks, gt_labels = rest_values 85 | 86 | result = eval_instance_segmentation_voc( 87 | pred_masks, pred_labels, pred_scores, 88 | gt_masks, gt_labels, 89 | iou_thresh=self.iou_thresh, 90 | use_07_metric=self.use_07_metric) 91 | 92 | report = {'map': result['map']} 93 | 94 | if self.label_names is not None: 95 | for l, label_name in enumerate(self.label_names): 96 | try: 97 | report['ap/{:s}'.format(label_name)] = result['ap'][l] 98 | except IndexError: 99 | report['ap/{:s}'.format(label_name)] = np.nan 100 | 101 | observation = {} 102 | with reporter.report_scope(observation): 103 | reporter.report(report, target) 104 | return observation 105 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def get_keypoints(): 7 | keypoints = [ 8 | 'SpineBase', 9 | 'SpineMid', 10 | 'Neck', 11 | 'Head', 12 | 'ShoulderLeft', 13 | 'ElbowLeft', 14 | 'WristLeft', 15 | 'HandLeft', 16 | 'ShoulderRight', 17 | 'ElbowRight', 18 | 'WristRight', 19 | 'HandRight', 20 | 'HipLeft', 21 | 'KneeLeft', 22 | 'AnkleLeft', 23 | 'FootLeft', 24 | 'HipRight', 25 | 'KneeRight', 26 | 'AnkleRight', 27 | 'FootRight' 28 | ] 29 | keypoint_flip_map = { 30 | 'ShoulderLeft': 'ShoulderRight', 31 | 'ElbowLeft': 'ElbowRight', 32 | 'WristLeft': 'WristRight', 33 | 'HipLeft': 'HipRight', 34 | 'KneeLeft': 'KneeRight', 35 | 'FootLeft': 'FootRight' 36 | } 37 | return keypoints, keypoint_flip_map 38 | 39 | 40 | def kp_connections(keypoints): 41 | kp_lines = [ 42 | [keypoints.index('ShoulderRight'), keypoints.index('ElbowRight')], 43 | [keypoints.index('ElbowRight'), keypoints.index('WristRight')], 44 | [keypoints.index('ShoulderLeft'), keypoints.index('ElbowLeft')], 45 | [keypoints.index('ElbowLeft'), keypoints.index('WristLeft')], 46 | [keypoints.index('HipRight'), keypoints.index('KneeRight')], 47 | [keypoints.index('KneeRight'), keypoints.index('AnkleRight')], 48 | [keypoints.index('HipLeft'), keypoints.index('KneeLeft')], 49 | [keypoints.index('KneeLeft'), keypoints.index('AnkleLeft')], 50 | [keypoints.index('ShoulderRight'), 
keypoints.index('Neck')], 51 | [keypoints.index('Neck'), keypoints.index('ShoulderLeft')], 52 | [keypoints.index('Neck'), keypoints.index('Head')], 53 | [keypoints.index('Neck'), keypoints.index('SpineBase')], 54 | [keypoints.index('SpineBase'), keypoints.index('HipRight')], 55 | [keypoints.index('SpineBase'), keypoints.index('HipLeft')] 56 | ] 57 | return kp_lines 58 | 59 | 60 | def vis_keypoints(img, kps, kp_thresh=2, alpha=0.7): 61 | """Visualizes keypoints (adapted from vis_one_image). 62 | kps has shape (4, #keypoints) where 4 rows are (x, y, logit, prob). 63 | """ 64 | dataset_keypoints, _ = get_keypoints() 65 | kp_lines = kp_connections(dataset_keypoints) 66 | 67 | # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv. 68 | cmap = plt.get_cmap('rainbow') 69 | colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)] 70 | colors = [(int(c[2] * 255), int(c[1] * 255), int(c[0] * 255)) 71 | for c in colors] 72 | 73 | # Perform the drawing on a copy of the image, to allow for blending. 74 | kp_mask = np.copy(img) 75 | 76 | # Draw mid shoulder / mid hip first for better visualization. 77 | mid_shoulder = ( 78 | kps[:2, dataset_keypoints.index('ShoulderRight')] + 79 | kps[:2, dataset_keypoints.index('ShoulderLeft')]) / 2.0 80 | sc_mid_shoulder = np.minimum( 81 | kps[2, dataset_keypoints.index('ShoulderRight')], 82 | kps[2, dataset_keypoints.index('ShoulderLeft')]) 83 | mid_hip = ( 84 | kps[:2, dataset_keypoints.index('HipRight')] + 85 | kps[:2, dataset_keypoints.index('HipLeft')]) / 2.0 86 | sc_mid_hip = np.minimum( 87 | kps[2, dataset_keypoints.index('HipRight')], 88 | kps[2, dataset_keypoints.index('HipLeft')]) 89 | 90 | if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh: 91 | cv2.line( 92 | kp_mask, tuple(mid_shoulder.astype(np.int32)), tuple( 93 | mid_hip.astype(np.int32)), 94 | color=colors[len(kp_lines) + 1], thickness=2, lineType=cv2.LINE_AA) 95 | 96 | # Draw the keypoints. 97 | for l in range(len(kp_lines)): 98 | i1 = kp_lines[l][0] 99 | i2 = kp_lines[l][1] 100 | p1 = kps[0, i1].astype(np.int32), kps[1, i1].astype(np.int32) 101 | p2 = kps[0, i2].astype(np.int32), kps[1, i2].astype(np.int32) 102 | if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh: 103 | cv2.line( 104 | kp_mask, p1, p2, 105 | color=colors[l], thickness=2, lineType=cv2.LINE_AA) 106 | if kps[2, i1] > kp_thresh: 107 | cv2.circle( 108 | kp_mask, p1, 109 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) 110 | if kps[2, i2] > kp_thresh: 111 | cv2.circle( 112 | kp_mask, p2, 113 | radius=3, color=colors[l], thickness=-1, lineType=cv2.LINE_AA) 114 | 115 | # Blend the keypoints. 
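# cv2.addWeighted returns (1 - alpha) * img + alpha * kp_mask, so alpha controls how
# strongly the keypoint overlay shows through the original image.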
116 | return cv2.addWeighted(img, 1.0 - alpha, kp_mask, alpha, 0) 117 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/head/fpn_roi_keypoint_head.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.links as L 3 | import chainer.functions as F 4 | from chainer.links.model.vision.resnet import ResNet50Layers, BuildingBlock, _global_average_pooling_2d 5 | import numpy as np 6 | import copy 7 | from chainer_maskrcnn.functions.roi_align_2d_yx import _roi_align_2d_yx 8 | 9 | 10 | class FPNRoIKeypointHead(chainer.Chain): 11 | mask_size = 56 12 | 13 | def __init__(self, 14 | n_class, 15 | n_keypoints, 16 | roi_size_box, 17 | roi_size_mask, 18 | n_mask_convs=8, 19 | loc_initialW=None, 20 | score_initialW=None, 21 | mask_initialW=None): 22 | # n_class includes the background 23 | super().__init__() 24 | with self.init_scope(): 25 | # layers for box prediction path 26 | self.conv1 = L.Convolution2D( 27 | in_channels=None, out_channels=256, ksize=3, pad=1) 28 | self.fc1 = L.Linear(None, 1024) 29 | self.fc2 = L.Linear(None, 1024) 30 | self.cls_loc = L.Linear(1024, 4, initialW=loc_initialW) 31 | self.score = L.Linear(1024, n_class, initialW=score_initialW) 32 | 33 | # mask prediction path 34 | self.mask_convs = chainer.ChainList() 35 | for i in range(n_mask_convs): 36 | self.mask_convs.add_link( 37 | L.Convolution2D(None, 256, ksize=3, pad=1)) 38 | self.deconv1 = L.Deconvolution2D( 39 | in_channels=None, 40 | out_channels=256, 41 | ksize=2, 42 | stride=2, 43 | pad=0, 44 | initialW=mask_initialW) 45 | self.conv2 = L.Convolution2D( 46 | in_channels=None, 47 | out_channels=n_keypoints, 48 | ksize=1, 49 | stride=1, 50 | pad=0, 51 | initialW=mask_initialW) 52 | 53 | self.n_class = n_class 54 | self.roi_size_box = roi_size_box 55 | self.roi_size_mask = roi_size_mask 56 | 57 | def __call__(self, x, indices_and_rois, levels, spatial_scales): 58 | 59 | pool_box = list() 60 | levels = chainer.cuda.to_cpu(levels).astype(np.int32) 61 | 62 | if len(np.unique(levels)) == 1: 63 | pool_box = _roi_align_2d_yx(x[0], indices_and_rois, self.roi_size_box, 64 | self.roi_size_box, spatial_scales[0]) 65 | else: 66 | for l, i in zip(levels, indices_and_rois): 67 | v = _roi_align_2d_yx(x[l], i[None], self.roi_size_box, 68 | self.roi_size_box, spatial_scales[l]) 69 | pool_box.append(v) 70 | 71 | pool_box = F.concat(pool_box, axis=0) 72 | 73 | h = self.conv1(pool_box) 74 | h = F.relu(h) 75 | h = F.relu(self.fc1(h)) 76 | h = F.relu(self.fc2(h)) 77 | roi_cls_locs = self.cls_loc(h) 78 | roi_scores = self.score(h) 79 | # at prediction time, we use two pass method. 
80 | # at first path, we predict box location and class 81 | # at second path, we predict mask with accurate location from first path 82 | if chainer.config.train: 83 | pool_mask = list() 84 | for l, i in zip(levels, indices_and_rois): 85 | pool_mask.append(_roi_align_2d_yx(x[l], i[None], self.roi_size_mask, 86 | self.roi_size_mask, spatial_scales[l])) 87 | mask = F.concat(pool_mask, axis=0) 88 | for l in self.mask_convs.children(): 89 | mask = F.relu(l(mask)) 90 | mask = self.conv2(self.deconv1(mask)) 91 | *_, h, w = mask.shape 92 | mask = F.resize_images(mask, output_shape=(2 * h, 2 * w)) 93 | return roi_cls_locs, roi_scores, mask 94 | else: 95 | # cache 96 | self.x = x 97 | return roi_cls_locs, roi_scores 98 | 99 | def predict_mask(self, levels, indices_and_rois, spatial_scales): 100 | pool_mask = list() 101 | for l, i in zip(levels, indices_and_rois): 102 | pool_mask.append(_roi_align_2d_yx(self.x[l], i[None], self.roi_size_mask, 103 | self.roi_size_mask, spatial_scales[l])) 104 | mask = F.concat(pool_mask, axis=0) 105 | for l in self.mask_convs: 106 | mask = F.relu(l(mask)) 107 | mask = self.conv2(self.deconv1(mask)) 108 | *_, h, w = mask.shape 109 | mask = F.resize_images(mask, output_shape=(2 * h, 2 * w)) 110 | 111 | return mask 112 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/fpn_maskrcnn_train_chain.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.functions as F 3 | from chainercv.links.model.faster_rcnn.utils.anchor_target_creator import\ 4 | AnchorTargetCreator 5 | from chainercv.links.model.faster_rcnn.faster_rcnn_train_chain import FasterRCNNTrainChain, _smooth_l1_loss, _fast_rcnn_loc_loss 6 | import cv2 7 | import numpy as np 8 | from .maskrcnn import MaskRCNN 9 | from .extractor.c4_backbone import C4Backbone 10 | from .extractor.feature_pyramid_network import FeaturePyramidNetwork 11 | from chainer_maskrcnn.utils.proposal_target_creator import ProposalTargetCreator 12 | 13 | 14 | class FPNMaskRCNNTrainChain(FasterRCNNTrainChain): 15 | def __init__(self, 16 | faster_rcnn, 17 | mask_loss_fun, 18 | binary_mask=True, 19 | rpn_sigma=3., 20 | roi_sigma=1., 21 | anchor_target_creator=AnchorTargetCreator()): 22 | # todo: clean up class dependencies 23 | proposal_target_creator = ProposalTargetCreator( 24 | faster_rcnn.extractor.anchor_sizes) 25 | super().__init__( 26 | faster_rcnn, proposal_target_creator=proposal_target_creator) 27 | self.mask_loss_fun = mask_loss_fun 28 | self.binary_mask = binary_mask 29 | 30 | def __call__(self, imgs, bboxes, labels, masks, scale): 31 | def strip(x): return x.data if isinstance(x, chainer.Variable) else x 32 | bboxes = strip(bboxes) 33 | labels = strip(labels) 34 | scale = strip(scale) 35 | masks = strip(masks) 36 | scale = np.asscalar(chainer.cuda.to_cpu(scale)) 37 | n = bboxes.shape[0] 38 | if n != 1: 39 | raise ValueError( 40 | 'Currently only batch size 1 is supported. 
n={}'.format(n)) 41 | 42 | _, _, H, W = imgs.shape 43 | img_size = (H, W) 44 | 45 | features = self.faster_rcnn.extractor(imgs) 46 | 47 | # Since batch size is one, convert variables to singular form 48 | bbox = bboxes[0] 49 | label = labels[0] 50 | mask = masks[0] 51 | 52 | rpn_locs, rpn_scores, rois, roi_indices, anchor, levels = self.faster_rcnn.rpn( 53 | features, img_size, scale) 54 | 55 | # Since batch size is one, convert variables to singular form 56 | rpn_score = rpn_scores[0] 57 | rpn_loc = rpn_locs[0] 58 | roi = rois 59 | 60 | # Sample RoIs and forward 61 | # gt_roi_labelになった時点で [0, NUM_FOREGROUND_CLASS-1]が[1, NUM_FOREGROUND_CLASS]にシフトしている 62 | sample_roi, sample_levels, gt_roi_loc, gt_roi_label, gt_roi_mask = self.proposal_target_creator( 63 | roi, 64 | bbox, 65 | label, 66 | mask, 67 | levels, 68 | self.loc_normalize_mean, 69 | self.loc_normalize_std, 70 | mask_size=self.faster_rcnn.head.mask_size, 71 | binary_mask=self.binary_mask) 72 | 73 | sample_roi_index = self.xp.zeros( 74 | (len(sample_roi), ), dtype=np.int32) 75 | 76 | # join roi and index of batch 77 | indices_and_rois = self.xp.concatenate( 78 | (sample_roi_index[:, None], sample_roi), axis=1).astype(self.xp.float32) 79 | 80 | # RPN losses 81 | gt_rpn_loc, gt_rpn_label = self.anchor_target_creator( 82 | bbox, anchor, img_size) 83 | rpn_loc_loss = _fast_rcnn_loc_loss(rpn_loc, gt_rpn_loc, 84 | gt_rpn_label, self.rpn_sigma) 85 | rpn_cls_loss = F.softmax_cross_entropy(rpn_score, gt_rpn_label) 86 | 87 | # Losses for outputs of the head. 88 | roi_cls_loc, roi_score, roi_cls_mask = self.faster_rcnn.head( 89 | features, indices_and_rois, sample_levels, self.faster_rcnn.extractor.spatial_scales) 90 | 91 | # Losses for outputs of the head. 92 | n_sample = roi_cls_loc.shape[0] 93 | roi_cls_loc = roi_cls_loc.reshape(n_sample, -1, 4) 94 | # light-headのときはそのまま使う 95 | if roi_cls_loc.shape[1] == 1: 96 | roi_loc = roi_cls_loc.reshape(n_sample, 4) 97 | else: 98 | roi_loc = roi_cls_loc[self.xp.arange(n_sample), gt_roi_label] 99 | 100 | roi_loc_loss = _fast_rcnn_loc_loss(roi_loc, gt_roi_loc, gt_roi_label, 101 | self.roi_sigma) 102 | roi_cls_loss = F.softmax_cross_entropy(roi_score, gt_roi_label) 103 | 104 | mask_loss = self.mask_loss_fun( 105 | roi_cls_mask, gt_roi_mask, self.xp, gt_roi_label) 106 | loss = rpn_loc_loss + rpn_cls_loss + roi_loc_loss + roi_cls_loss + mask_loss 107 | 108 | chainer.reporter.report({ 109 | 'rpn_loc_loss': rpn_loc_loss, 110 | 'rpn_cls_loss': rpn_cls_loss, 111 | 'roi_loc_loss': roi_loc_loss, 112 | 'roi_cls_loss': roi_cls_loss, 113 | 'mask_loss': mask_loss, 114 | 'loss': loss 115 | }, self) 116 | 117 | return loss 118 | -------------------------------------------------------------------------------- /viewer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | import pyrealsense2 as rs 5 | import numpy as np 6 | import cv2 7 | 8 | import chainer 9 | from chainer import using_config 10 | from chainer_maskrcnn.model.maskrcnn import MaskRCNN 11 | 12 | import vis 13 | 14 | 15 | class SimpleInfer: 16 | def __init__(self, weight, file=None): 17 | self.model = MaskRCNN(1, n_keypoints=20, n_mask_convs=2, 18 | min_size=240, backbone='darknet', head_arch='fpn_keypoint') 19 | chainer.serializers.load_npz(weight, self.model, strict=True) 20 | self.in_channels = self.model.extractor.conv1.c.W.shape[1] 21 | print('number of parameters:{}'.format( 22 | sum(p.data.size for p in self.model.params()))) 23 | if 
chainer.backends.intel64.is_ideep_available(): 24 | self.model.to_intel64() 25 | 26 | # Configure depth and color streams 27 | self.pipeline = rs.pipeline() 28 | self.config = rs.config() 29 | 30 | if file is not None: 31 | print('load from {}'.format(file)) 32 | self.config.enable_device_from_file(file) 33 | else: 34 | print('launch device') 35 | self.config.enable_stream( 36 | rs.stream.depth, 424, 240, rs.format.z16, 30) 37 | self.config.enable_stream( 38 | rs.stream.color, 640, 480, rs.format.bgr8, 30) 39 | 40 | self.margin = 52 41 | self.depth_offset = 0 42 | self.wname = 'depth' 43 | self.avg_fps = 15.0 44 | 45 | def run(self): 46 | # Start streaming 47 | self.pipeline.start(self.config) 48 | try: 49 | while True: 50 | self.main_loop() 51 | key = cv2.waitKey(1) & 0xFF 52 | if key == ord('q'): 53 | print('quit') 54 | break 55 | 56 | finally: 57 | # Stop streaming 58 | self.pipeline.stop() 59 | 60 | def main_loop(self): 61 | start = time.time() 62 | # Wait for a coherent pair of frames: depth and color 63 | frames = self.pipeline.wait_for_frames() 64 | depth_frame = frames.get_depth_frame() 65 | color_frame = frames.get_color_frame() 66 | if not depth_frame or not color_frame: 67 | return 68 | 69 | # Convert images to numpy arrays 70 | depth_image = np.asanyarray(depth_frame.get_data()) 71 | color_image = np.asanyarray(color_frame.get_data()) 72 | 73 | # Apply colormap on depth image (image must be converted to 8-bit per pixel first) 74 | depth_colormap = cv2.applyColorMap(cv2.convertScaleAbs( 75 | depth_image, alpha=0.03), cv2.COLORMAP_JET) 76 | 77 | # D435が16:9でしかキャプチャできないので4:3にクロップする 78 | cropped_depth = depth_image[:, self.margin:-self.margin] 79 | # depth_datasetでやってしまった謎の変換 80 | cropped_depth = np.clip(cropped_depth.astype( 81 | np.float32) - self.depth_offset, 0, 4000) / 3000 * 255 82 | 83 | if self.in_channels == 1: 84 | cropped_depth = cropped_depth[None] 85 | else: 86 | cropped_depth = np.stack( 87 | [cropped_depth, cropped_depth, cropped_depth]) 88 | 89 | s = 56 90 | with using_config('train', False), using_config('enable_backprop', False), using_config('use_ideep', 'auto'): 91 | box, label, score, keypoints = self.model.predict( 92 | cropped_depth[None]) 93 | 94 | if len(box[0]): 95 | kps = np.argmax(keypoints[0][0], axis=2) 96 | indices = np.argsort(keypoints[0][0], axis=2) 97 | kp_logits = keypoints[0][0][0, 98 | np.arange(20), indices[:, :, -1]] 99 | kp_probs = chainer.functions.softmax( 100 | keypoints[0][0]).array[0, np.arange(20), indices[:, :, -1]] 101 | vis_keys = [] 102 | for kp, b, logits, probs in zip(kps, box[0], kp_logits, kp_probs): 103 | sh, sw = (b[2] - b[0]) / s, (b[3] - b[1]) / s 104 | for i, (k, l, p) in enumerate(zip(kp, logits, probs)): 105 | vis_keys.append([k // s * sh + b[0], k % 106 | s * sw + b[1] + self.margin, l, p]) 107 | 108 | vis_keys = np.array(vis_keys) 109 | depth_colormap = vis.vis_keypoints( 110 | depth_colormap.transpose((1, 0, 2)).copy(), vis_keys.transpose((1, 0)), kp_thresh=3) 111 | depth_colormap = depth_colormap.transpose((1, 0, 2)) 112 | 113 | # Show images 114 | # cv2.imshow('color', color_image) 115 | cv2.imshow(self.wname, depth_colormap) 116 | end = time.time() 117 | self.avg_fps = self.avg_fps * 0.9 + 1. 
/ (end - start) * 0.1 118 | print(self.avg_fps) 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument('--file', help='bag file') 124 | parser.add_argument('--weight', help='pretrained_weight') 125 | args = parser.parse_args() 126 | SimpleInfer(weight=args.weight, file=args.file).run() 127 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/head/light_roi_mask_head.py: -------------------------------------------------------------------------------- 1 | # Light-Head R-CNN: In Defense of Two-Stage Object Detector http://arxiv.org/abs/1711.07264 2 | import chainer 3 | import chainer.links as L 4 | import chainer.functions as F 5 | from chainer.links.model.vision.resnet import ResNet50Layers, BuildingBlock, _global_average_pooling_2d 6 | import numpy as np 7 | import copy 8 | from chainer_maskrcnn.functions.roi_align_2d_yx import _roi_align_2d_yx 9 | 10 | 11 | class LightRoIMaskHead(chainer.Chain): 12 | mask_size = 14 13 | 14 | def __init__(self, 15 | n_class, 16 | roi_size, 17 | loc_initialW=None, 18 | score_initialW=None, 19 | mask_initialW=None): 20 | # n_class includes the background 21 | super().__init__() 22 | with self.init_scope(): 23 | # Separable Convolution Layers 変数の名前はpaperに準拠してみた 24 | k = 15 25 | C_mid = 256 26 | C_out = 490 27 | # 初期値ワカラン・・ chainer.initializers.Normal(0.001)とかの方がいいか? 28 | # レイヤーの名前は論文の図を見たときに up left, bottom left, up right, bottom rightの4つw 29 | p = int(k / 2) 30 | self.conv_ul = L.Convolution2D( 31 | in_channels=None, out_channels=C_mid, ksize=(k, 1), pad=(p, 0)) 32 | self.conv_bl = L.Convolution2D( 33 | in_channels=C_mid, 34 | out_channels=C_out, 35 | ksize=(1, k), 36 | pad=(0, p)) 37 | self.conv_ur = L.Convolution2D( 38 | in_channels=None, out_channels=C_mid, ksize=(1, k), pad=(0, p)) 39 | self.conv_br = L.Convolution2D( 40 | in_channels=C_mid, 41 | out_channels=C_out, 42 | ksize=(k, 1), 43 | pad=(p, 0)) 44 | self.fc = L.Linear(None, 2048) 45 | self.cls_loc = L.Linear(2048, 4, initialW=loc_initialW) 46 | self.score = L.Linear(2048, n_class, initialW=score_initialW) 47 | # マスク推定ブランチへ 48 | self.conv2 = L.Convolution2D( 49 | in_channels=None, 50 | out_channels=256, 51 | ksize=3, 52 | stride=1, 53 | pad=1, 54 | initialW=mask_initialW) 55 | self.conv3_ = L.Convolution2D( 56 | in_channels=None, 57 | out_channels=256, 58 | ksize=3, 59 | stride=1, 60 | pad=1, 61 | initialW=mask_initialW) 62 | self.conv4 = L.Convolution2D( 63 | in_channels=None, 64 | out_channels=256, 65 | ksize=3, 66 | stride=1, 67 | pad=1, 68 | initialW=mask_initialW) 69 | self.deconv1_ = L.Deconvolution2D( 70 | in_channels=None, 71 | out_channels=n_class - 1, 72 | ksize=2, 73 | stride=2, 74 | pad=0, 75 | initialW=mask_initialW) 76 | 77 | self.n_class = n_class 78 | self.roi_size = roi_size 79 | 80 | def __call__(self, x, rois, roi_indices, spatial_scale): 81 | roi_indices = roi_indices.astype(np.float32) 82 | indices_and_rois = self.xp.concatenate( 83 | (roi_indices[:, None], rois), axis=1) 84 | 85 | # roi poolingをする前に、thin feature mapに変換します 86 | # activationしないっぽいことが書いてあるんだよなー 87 | left_path = self.conv_bl(self.conv_ul(x)) 88 | right_path = self.conv_br(self.conv_ur(x)) 89 | tfp = left_path + right_path 90 | 91 | pool = _roi_align_2d_yx(tfp, indices_and_rois, self.roi_size, 92 | self.roi_size, spatial_scale) 93 | 94 | h = F.relu(self.fc(pool)) 95 | roi_cls_locs = self.cls_loc(h) 96 | roi_scores = self.score(h) 97 | # at prediction time, we use two pass method. 
98 | # at first path, we predict box location and class 99 | # at second path, we predict mask with accurate location from first path 100 | if chainer.config.train: 101 | mask = F.relu(self.conv2(pool)) 102 | mask = F.relu(self.conv3_(mask)) 103 | mask = F.relu(self.conv4(mask)) 104 | mask = self.deconv1_(pool) 105 | #mask = self.conv2(self.deconv1(pool)) 106 | return roi_cls_locs, roi_scores, mask 107 | else: 108 | # cache tfp for second path 109 | self.tfp = tfp 110 | return roi_cls_locs, roi_scores 111 | 112 | def predict_mask(self, rois, roi_indices, spatial_scale): 113 | roi_indices = roi_indices.astype(np.float32) 114 | indices_and_rois = self.xp.concatenate( 115 | (roi_indices[:, None], rois), axis=1) 116 | pool = _roi_align_2d_yx(self.tfp, indices_and_rois, self.roi_size, 117 | self.roi_size, spatial_scale) 118 | 119 | mask = F.relu(self.conv2(pool)) 120 | mask = F.relu(self.conv3_(mask)) 121 | mask = F.relu(self.conv4(mask)) 122 | mask = self.deconv1_(pool) 123 | # mask = self.deconv1(pool) 124 | # mask = self.conv2_(mask) 125 | # mask = self.conv3(mask) 126 | # mask = self.conv2(self.deconv1(pool)) 127 | return mask 128 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/maskrcnn_train_chain.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | import chainer.functions as F 3 | from chainercv.links.model.faster_rcnn.utils.anchor_target_creator import\ 4 | AnchorTargetCreator 5 | from chainercv.links.model.faster_rcnn.faster_rcnn_train_chain import FasterRCNNTrainChain, _smooth_l1_loss, _fast_rcnn_loc_loss 6 | import cv2 7 | import numpy as np 8 | from .maskrcnn import MaskRCNN 9 | from chainer_maskrcnn.utils.proposal_target_creator import ProposalTargetCreator 10 | from .extractor.feature_pyramid_network import FeaturePyramidNetwork 11 | from .extractor.c4_backbone import C4Backbone 12 | 13 | 14 | class MaskRCNNTrainChain(FasterRCNNTrainChain): 15 | def __init__(self, 16 | faster_rcnn, 17 | rpn_sigma=3., 18 | roi_sigma=1., 19 | anchor_target_creator=AnchorTargetCreator(), 20 | proposal_target_creator=ProposalTargetCreator()): 21 | super(MaskRCNNTrainChain, self).__init__( 22 | faster_rcnn, proposal_target_creator=proposal_target_creator) 23 | 24 | def __call__(self, imgs, bboxes, labels, masks, scale): 25 | def strip(x): return x.data if isinstance(x, chainer.Variable) else x 26 | bboxes = strip(bboxes) 27 | labels = strip(labels) 28 | masks = strip(masks) 29 | 30 | scale = np.asscalar(chainer.cuda.to_cpu(scale)) 31 | n = bboxes.shape[0] 32 | if n != 1: 33 | raise ValueError( 34 | 'Currently only batch size 1 is supported. 
n={}'.format(n)) 35 | 36 | _, _, H, W = imgs.shape 37 | img_size = (H, W) 38 | 39 | features = self.faster_rcnn.extractor(imgs) 40 | 41 | # Since batch size is one, convert variables to singular form 42 | bbox = bboxes[0] 43 | label = labels[0] 44 | mask = masks[0] 45 | 46 | # iterate over feature pyramids 47 | proposals = [] 48 | rpn_outputs = [] 49 | gt_data = [] 50 | for feature in features: 51 | rpn_locs, rpn_scores, rois, roi_indices, anchor = self.faster_rcnn.rpn( 52 | feature, img_size, scale) 53 | 54 | # Since batch size is one, convert variables to singular form 55 | rpn_score = rpn_scores[0] 56 | rpn_loc = rpn_locs[0] 57 | roi = rois 58 | 59 | # Sample RoIs and forward 60 | # gt_roi_labelになった時点で [0, NUM_FOREGROUND_CLASS-1]が[1, NUM_FOREGROUND_CLASS]にシフトしている 61 | sample_roi, gt_roi_loc, gt_roi_label, gt_roi_mask = self.proposal_target_creator( 62 | roi, 63 | bbox, 64 | label, 65 | mask, 66 | self.loc_normalize_mean, 67 | self.loc_normalize_std, 68 | mask_size=self.faster_rcnn.head.mask_size) 69 | 70 | sample_roi_index = self.xp.zeros( 71 | (len(sample_roi), ), dtype=np.int32) 72 | 73 | proposals.append((sample_roi, sample_roi_index, 74 | 1 / self.faster_rcnn.feat_stride * s)) 75 | rpn_outputs.append((rpn_loc, rpn_score, roi, anchor)) 76 | gt_data.append((gt_roi_loc, gt_roi_label, gt_roi_mask)) 77 | 78 | if len(features) == 1: 79 | sample_roi, sample_roi_index, s = proposals[0] 80 | roi_cls_loc, roi_score, roi_cls_mask = self.faster_rcnn.head( 81 | features[0], sample_roi, sample_roi_index, s) 82 | 83 | else: 84 | roi_cls_loc, roi_score, roi_cls_mask = self.faster_rcnn.head( 85 | features, proposals) 86 | 87 | # RPN losses 88 | rpn_loc_loss = chainer.Variable( 89 | self.xp.array(0, dtype=self.xp.float32)) 90 | rpn_cls_loss = chainer.Variable( 91 | self.xp.array(0, dtype=self.xp.float32)) 92 | for (p, r) in zip(proposals, rpn_outputs): 93 | rpn_loc, rpn_score, _, anchor = r 94 | gt_rpn_loc, gt_rpn_label = self.anchor_target_creator( 95 | bbox, anchor, img_size) 96 | rpn_loc_loss += _fast_rcnn_loc_loss(rpn_loc, gt_rpn_loc, 97 | gt_rpn_label, self.rpn_sigma) 98 | rpn_cls_loss += F.softmax_cross_entropy(rpn_score, gt_rpn_label) 99 | 100 | # Losses for outputs of the head. 
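# roi_cls_loc has shape (n_sample, n_class * 4); reshape it to (n_sample, n_class, 4) and
# select the regression output belonging to each sample's ground-truth class, so the loc
# loss is computed only on that class's box. The light-head variant predicts a single
# class-agnostic box (second axis of size 1), which is used as-is.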
101 | n_sample = roi_cls_loc.shape[0] 102 | roi_cls_loc = roi_cls_loc.reshape(n_sample, -1, 4) 103 | # light-headのときはそのまま使う 104 | if roi_cls_loc.shape[1] == 1: 105 | roi_loc = roi_cls_loc.reshape(n_sample, 4) 106 | else: 107 | roi_loc = roi_cls_loc[self.xp.arange(n_sample), gt_roi_label] 108 | 109 | gt_roi_loc = self.xp.concatenate([g[0] for g in gt_data], axis=0) 110 | gt_roi_label = self.xp.concatenate([g[1] for g in gt_data], axis=0) 111 | gt_roi_mask = self.xp.concatenate([g[2] for g in gt_data], axis=0) 112 | roi_loc_loss = _fast_rcnn_loc_loss(roi_loc, gt_roi_loc, 113 | gt_roi_label, self.roi_sigma) 114 | roi_cls_loss = F.softmax_cross_entropy(roi_score, gt_roi_label) 115 | 116 | # mask 117 | # https://engineer.dena.jp/2017/12/chainercvmask-r-cnn.html 118 | roi_mask = roi_cls_mask[self.xp.arange(n_sample), gt_roi_label] 119 | mask_loss = F.sigmoid_cross_entropy( 120 | roi_mask[0:gt_roi_mask.shape[0]], gt_roi_mask) 121 | loss = rpn_loc_loss + rpn_cls_loss + roi_loc_loss + roi_cls_loss + mask_loss 122 | 123 | chainer.reporter.report({ 124 | 'rpn_loc_loss': rpn_loc_loss, 125 | 'rpn_cls_loss': rpn_cls_loss, 126 | 'roi_loc_loss': roi_loc_loss, 127 | 'roi_cls_loss': roi_cls_loss, 128 | 'mask_loss': mask_loss, 129 | 'loss': loss 130 | }, self) 131 | 132 | return loss 133 | -------------------------------------------------------------------------------- /chainer_maskrcnn/dataset/coco_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from os.path import join 4 | import chainer 5 | from chainercv.utils import read_image 6 | from pycocotools.coco import COCO 7 | from PIL import Image 8 | from random import shuffle 9 | 10 | 11 | class COCOMaskLoader(chainer.dataset.DatasetMixin): 12 | def __init__(self, 13 | anno_dir='data/annotations', 14 | img_dir='data', 15 | split='train', 16 | data_type='2014', 17 | category_filter=None): 18 | if split not in ['train', 'val', 'validation']: 19 | raise ValueError( 20 | 'please pick split from \'train\', \'val\',\'validation\'') 21 | 22 | if split == 'validation': 23 | split = 'val' 24 | 25 | ann_file = '{}/instances_{}{}.json'.format(anno_dir, split, data_type) 26 | self.coco = COCO(ann_file) 27 | 28 | self.img_dir = '{}/{}{}'.format(img_dir, split, data_type) 29 | print('load jpg images from {}'.format(self.img_dir)) 30 | target_cats = [] if category_filter is None else category_filter 31 | self.cat_ids = self.coco.getCatIds(catNms=target_cats) 32 | # cat_idsの中のどれかが含まれる画像、を探したい(or検索) 33 | # getImgIdsの引数にしていするとand検索されるので、泥草する 34 | img_ids = set() 35 | for cat_id in self.cat_ids: 36 | img_ids |= set(self.coco.getImgIds(catIds=[cat_id])) 37 | 38 | self.img_infos = [(i['file_name'], i['id']) 39 | for i in self.coco.loadImgs(img_ids)] 40 | self.length = len(self.img_infos) 41 | 42 | def __len__(self): 43 | return self.length 44 | 45 | # 少なくとも1つ十分大きいものがあればOKとするフィルタ 46 | def _contain_large_enough_annotation(self, img_id, min_w=10, min_h=10): 47 | anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 48 | for ann in anns: 49 | x, y, w, h = [int(j) for j in ann['bbox']] 50 | if w <= min_w or h <= min_h: 51 | continue 52 | 53 | if ann['category_id'] in self.cat_ids: 54 | return True 55 | 56 | return False 57 | 58 | # 画像内の全てのアノテーションが大きくないとダメだぞというフィルタ 59 | def _contain_large_annotation_only(self, img_id, min_w=10, min_h=10): 60 | anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 61 | for ann in anns: 62 | x, y, w, h = [int(j) for j in ann['bbox']] 63 | if 
ann['category_id'] in self.cat_ids and (w <= min_w 64 | or h <= min_h): 65 | return False 66 | 67 | return True 68 | 69 | def get_example(self, i): 70 | if i >= self.length: 71 | raise IndexError('index is out of bounds.') 72 | file_name, img_id = self.img_infos[i] 73 | img = read_image(join(self.img_dir, file_name), color=True) 74 | assert img.shape[0] == 3 75 | anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 76 | gt_boxes = [] 77 | gt_masks = [] 78 | gt_labels = [] 79 | for ann in anns: 80 | x, y, w, h = [int(j) for j in ann['bbox']] 81 | 82 | # これめっちゃ罠で、category_idが連続してないんだよなー 83 | if ann['category_id'] in self.cat_ids: 84 | continuous_cat_id = self.cat_ids.index(ann['category_id']) 85 | gt_boxes.append( 86 | np.array([y, x, y + h, x + w], dtype=np.float32)) 87 | gt_masks.append(self.coco.annToMask(ann)) 88 | gt_labels.append(continuous_cat_id) 89 | 90 | if len(gt_boxes) == 0: 91 | print( 92 | '小さすぎるアノテーションを削除した結果、この画像にはground_truthが一つも含まれませんでした。これで学習にエラーが出る場合、事前に小さなアノテーションしか含まれない画像はself.imgsから削除することを検討してください' 93 | ) 94 | 95 | # gt_masksはlistのままにしておいて、Transformでcv::resizeしたあとにnumpy arrayにするという泥臭 96 | return img, np.array(gt_boxes), np.array( 97 | gt_labels, dtype=np.int32), gt_masks 98 | 99 | 100 | class COCOKeypointsLoader(chainer.dataset.DatasetMixin): 101 | n_keypoints = 17 102 | 103 | def __init__(self, 104 | anno_dir='data/annotations', 105 | img_dir='data', 106 | split='train', 107 | data_type='2014'): 108 | if split not in ['train', 'val', 'validation']: 109 | raise ValueError( 110 | 'please pick split from \'train\', \'val\',\'validation\'') 111 | 112 | if split == 'validation': 113 | split = 'val' 114 | 115 | ann_file = '{}/person_keypoints_{}{}.json'.format( 116 | anno_dir, split, data_type) 117 | self.coco = COCO(ann_file) 118 | 119 | self.img_dir = '{}/{}{}'.format(img_dir, split, data_type) 120 | print('load jpg images from {}'.format(self.img_dir)) 121 | img_ids = self.coco.getImgIds(catIds=[1]) # person only 122 | all_img_infos = [(i['file_name'], i['id']) 123 | for i in self.coco.loadImgs(img_ids)] 124 | # keypointsが空のデータもあるので、それは間引く 125 | self.img_infos = list() 126 | for info in all_img_infos: 127 | file_name, img_id = info 128 | anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 129 | if len(anns) > 0: 130 | self.img_infos.append(info) 131 | 132 | self.length = len(self.img_infos) 133 | print('number of valid data.', self.length) 134 | 135 | def __len__(self): 136 | return self.length 137 | 138 | def get_example(self, i): 139 | if i >= self.length: 140 | raise IndexError() 141 | 142 | file_name, img_id = self.img_infos[i] 143 | img = read_image(join(self.img_dir, file_name), color=True) 144 | anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id)) 145 | 146 | keypoints = list() 147 | gt_boxes = list() 148 | for ann in anns: 149 | kp = ann['keypoints'] 150 | kp = np.array(kp).reshape((-1, 3)) 151 | keypoints.append(kp) 152 | 153 | x, y, w, h = [int(j) for j in ann['bbox']] 154 | h = max(1.0, h) 155 | w = max(1.0, w) 156 | gt_boxes.append(np.array([y, x, y + h, x + w], dtype=np.float32)) 157 | 158 | keypoints = np.array(keypoints).reshape((-1, 17, 3)) 159 | gt_boxes = np.array(gt_boxes).reshape((-1, 4)) 160 | 161 | return img, gt_boxes, keypoints 162 | -------------------------------------------------------------------------------- /chainer_maskrcnn/utils/proposal_target_creator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from chainer import cuda 3 | from 
chainercv.links.model.faster_rcnn.utils.bbox2loc import bbox2loc 4 | from chainercv.utils.bbox.bbox_iou import bbox_iou 5 | import cv2 6 | from chainer_maskrcnn.model.rpn.multilevel_region_proposal_network import map_rois_to_fpn_levels 7 | 8 | # GroundTruthと近いbox, label, maskだけをフィルタリングする 9 | 10 | 11 | class ProposalTargetCreator(object): 12 | def __init__(self, 13 | sizes=[16], 14 | n_sample=256, 15 | pos_ratio=0.25, 16 | pos_iou_thresh=0.5, 17 | neg_iou_thresh_hi=0.5, 18 | neg_iou_thresh_lo=0.0): 19 | self.sizes = sizes 20 | self.n_sample = n_sample 21 | self.pos_ratio = pos_ratio 22 | self.pos_iou_thresh = pos_iou_thresh 23 | self.neg_iou_thresh_hi = neg_iou_thresh_hi 24 | self.neg_iou_thresh_lo = neg_iou_thresh_lo 25 | 26 | def __call__(self, 27 | roi, 28 | bbox, 29 | label, 30 | mask, 31 | levels, 32 | loc_normalize_mean=(0., 0., 0., 0.), 33 | loc_normalize_std=(0.1, 0.1, 0.2, 0.2), 34 | mask_size=14, 35 | binary_mask=True): 36 | """ 37 | binary_mask = False -> keypoint 38 | """ 39 | xp = cuda.get_array_module(roi) 40 | roi = cuda.to_cpu(roi) 41 | bbox = cuda.to_cpu(bbox) 42 | label = cuda.to_cpu(label) 43 | mask = cuda.to_cpu(mask) 44 | levels = cuda.to_cpu(levels) 45 | 46 | n_bbox, _ = bbox.shape 47 | n_proposal = roi.shape[0] 48 | roi = np.concatenate((roi, bbox), axis=0) 49 | 50 | # assign feature levels of ground truth boxes 51 | bbox_levels = map_rois_to_fpn_levels(bbox) 52 | levels = np.concatenate([levels, bbox_levels]) 53 | 54 | pos_roi_per_image = np.round(self.n_sample * self.pos_ratio) 55 | iou = bbox_iou(roi, bbox) 56 | gt_assignment = iou.argmax(axis=1) 57 | max_iou = iou.max(axis=1) 58 | # Offset range of classes from [0, n_fg_class - 1] to [1, n_fg_class]. 59 | # The label with value 0 is the background. 60 | gt_roi_label = label[gt_assignment] + 1 61 | 62 | # Select foreground RoIs as those with >= pos_iou_thresh IoU. 63 | pos_index = np.where(max_iou >= self.pos_iou_thresh)[0] 64 | pos_roi_per_this_image = int(min(pos_roi_per_image, pos_index.size)) 65 | if pos_index.size > 0: 66 | pos_index = np.random.choice( 67 | pos_index, size=pos_roi_per_this_image, replace=False) 68 | 69 | # Select background RoIs as those within 70 | # [neg_iou_thresh_lo, neg_iou_thresh_hi). 71 | neg_index = np.where((max_iou < self.neg_iou_thresh_hi) & 72 | (max_iou >= self.neg_iou_thresh_lo))[0] 73 | neg_roi_per_this_image = self.n_sample - pos_roi_per_this_image 74 | neg_roi_per_this_image = int( 75 | min(neg_roi_per_this_image, neg_index.size)) 76 | if neg_index.size > 0: 77 | neg_index = np.random.choice( 78 | neg_index, size=neg_roi_per_this_image, replace=False) 79 | 80 | # The indices that we're selecting (both positive and negative). 81 | keep_index = np.append(pos_index, neg_index) 82 | gt_roi_label = gt_roi_label[keep_index] 83 | gt_roi_label[pos_roi_per_this_image:] = 0 # negative labels --> 0 84 | sample_roi = roi[keep_index] 85 | sample_levels = levels[keep_index] 86 | 87 | # Compute offsets and scales to match sampled RoIs to the GTs. 
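        # bbox2loc from chainercv encodes each matched ground-truth box relative to its
        # sampled RoI as (dy, dx, dh, dw): center offsets divided by the RoI height/width
        # and log-encoded size ratios, the standard Faster R-CNN parameterization. The
        # encoded targets are then shifted and scaled by loc_normalize_mean / loc_normalize_std
        # so that the loc head regresses roughly zero-mean, unit-variance values.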
88 | gt_roi_loc = bbox2loc(sample_roi, bbox[gt_assignment[keep_index]]) 89 | gt_roi_loc = ((gt_roi_loc - np.array(loc_normalize_mean, np.float32)) / 90 | np.array(loc_normalize_std, np.float32)) 91 | 92 | # https://engineer.dena.jp/2017/12/chainercvmask-r-cnn.html 93 | gt_roi_mask = [] 94 | _, h, w = mask.shape 95 | if binary_mask: 96 | for i, idx in enumerate(gt_assignment[pos_index]): 97 | A = mask[idx, 98 | np.max((int(sample_roi[i, 0]), 99 | 0)):np.min((int(sample_roi[i, 2]), h)), 100 | np.max((int(sample_roi[i, 1]), 101 | 0)):np.min((int(sample_roi[i, 3]), w))] 102 | gt_roi_mask.append( 103 | cv2.resize(A, (mask_size, mask_size)).astype(np.int32)) 104 | else: 105 | for i, idx in enumerate(gt_assignment[pos_index]): 106 | m = np.zeros((mask_size, mask_size), dtype=np.int32) 107 | # remind: shape of keypoints is (N, 17, 3), N is number of bbox, 17 is number of keypoints, 3 is (x, y, v) 108 | # v=0: unlabeled, v=1, labeled but invisible, v=2 labeled and visible 109 | 110 | # bbox's (y0, x0), (y1, x1) 111 | y0, x0, y1, x1 = list(map(int, sample_roi[i, :4])) 112 | kp = mask[idx] # shape is (17, 3) 113 | # convert keypoints coordinate (y, x) into mask coordinate system [0, mask_size]x[0, mask_size] 114 | kp[:, :2] = (kp[:, :2] - [y0, x0]) / \ 115 | [max(y1 - y0, 1), max(x1 - x0, 1)] * mask_size 116 | # mask_size x mask_size 空間でどこにあるかをラベルとして扱う(あとでsoftmax cross entropyする) 117 | # -1でignoreされる 118 | keypoint_labels = np.zeros(kp.shape[0], dtype=np.int32) 119 | for j, r in enumerate(kp): 120 | y, x, v = list(map(int, r)) 121 | if v == 2 and 0 <= y and y < mask_size and 0 <= x and x < mask_size: 122 | keypoint_labels[j] = y * mask_size + x 123 | 124 | else: 125 | keypoint_labels[j] = -1 126 | 127 | gt_roi_mask.append(keypoint_labels) 128 | 129 | gt_roi_mask = xp.array(gt_roi_mask) 130 | 131 | if xp != np: 132 | sample_roi = cuda.to_gpu(sample_roi) 133 | gt_roi_loc = cuda.to_gpu(gt_roi_loc) 134 | gt_roi_label = cuda.to_gpu(gt_roi_label) 135 | gt_roi_mask = cuda.to_gpu(gt_roi_mask) 136 | sample_levels = cuda.to_gpu(sample_levels) 137 | return sample_roi, sample_levels, gt_roi_loc, gt_roi_label, gt_roi_mask 138 | -------------------------------------------------------------------------------- /chainer_maskrcnn/utils/proposal_creator.py: -------------------------------------------------------------------------------- 1 | """ 2 | original code by chainercv 3 | https://github.com/chainer/chainercv/master/chainercv/links/model/faster_rcnn/utils/proposal_creator.py 4 | modified to adapt multi level features produced by Feature Pyramid Network. 5 | """ 6 | import numpy as np 7 | 8 | import chainer 9 | from chainer import cuda 10 | 11 | from chainercv.links.model.faster_rcnn.utils.loc2bbox import loc2bbox 12 | from chainercv.utils.bbox.non_maximum_suppression import \ 13 | non_maximum_suppression 14 | 15 | 16 | class ProposalCreator(object): 17 | """Proposal regions are generated by calling this object. 18 | 19 | The :meth:`__call__` of this object outputs object detection proposals by 20 | applying estimated bounding box offsets 21 | to a set of anchors. 22 | 23 | This class takes parameters to control number of bounding boxes to 24 | pass to NMS and keep after NMS. 25 | If the paramters are negative, it uses all the bounding boxes supplied 26 | or keep all the bounding boxes returned by NMS. 27 | 28 | This class is used for Region Proposal Networks introduced in 29 | Faster R-CNN [#]_. 30 | 31 | .. [#] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. 
\ 32 | Faster R-CNN: Towards Real-Time Object Detection with \ 33 | Region Proposal Networks. NIPS 2015. 34 | 35 | Args: 36 | nms_thresh (float): Threshold value used when calling NMS. 37 | n_train_pre_nms (int): Number of top scored bounding boxes 38 | to keep before passing to NMS in train mode. 39 | n_train_post_nms (int): Number of top scored bounding boxes 40 | to keep after passing to NMS in train mode. 41 | n_test_pre_nms (int): Number of top scored bounding boxes 42 | to keep before passing to NMS in test mode. 43 | n_test_post_nms (int): Number of top scored bounding boxes 44 | to keep after passing to NMS in test mode. 45 | force_cpu_nms (bool): If this is :obj:`True`, 46 | always use NMS in CPU mode. If :obj:`False`, 47 | the NMS mode is selected based on the type of inputs. 48 | min_size (int): A paramter to determine the threshold on 49 | discarding bounding boxes based on their sizes. 50 | 51 | """ 52 | 53 | def __init__(self, 54 | nms_thresh=0.7, 55 | n_train_pre_nms=12000, 56 | n_train_post_nms=2000, 57 | n_test_pre_nms=6000, 58 | n_test_post_nms=300, 59 | force_cpu_nms=False, 60 | min_size=16 61 | ): 62 | 63 | self.nms_thresh = nms_thresh 64 | self.n_train_pre_nms = n_train_pre_nms 65 | self.n_train_post_nms = n_train_post_nms 66 | self.n_test_pre_nms = n_test_pre_nms 67 | self.n_test_post_nms = n_test_post_nms 68 | self.force_cpu_nms = force_cpu_nms 69 | self.min_size = min_size 70 | 71 | def __call__(self, loc, score, 72 | anchor, level_indices, img_size, scale=1.): 73 | """Propose RoIs. 74 | 75 | Inputs :obj:`loc, score, anchor` refer to the same anchor when indexed 76 | by the same index. 77 | 78 | On notations, :math:`R` is the total number of anchors. This is equal 79 | to product of the height and the width of an image and the number of 80 | anchor bases per pixel. 81 | 82 | Type of the output is same as the inputs. 83 | 84 | Args: 85 | loc (array): Predicted offsets and scaling to anchors. 86 | Its shape is :math:`(R, 4)`. 87 | score (array): Predicted foreground probability for anchors. 88 | Its shape is :math:`(R,)`. 89 | anchor (array): Coordinates of anchors. Its shape is 90 | :math:`(R, 4)`. 91 | level_indices (array): level index which indicate where this anchor comes from in feature levels proposed by FPN. Its shape is 92 | :math:`(R,)`. 93 | img_size (tuple of ints): A tuple :obj:`height, width`, 94 | which contains image size after scaling. 95 | scale (float): The scaling factor used to scale an image after 96 | reading it from a file. 97 | 98 | Returns: 99 | array: 100 | An array of coordinates of proposal boxes. 101 | Its shape is :math:`(S, 4)`. :math:`S` is less than 102 | :obj:`self.n_test_post_nms` in test time and less than 103 | :obj:`self.n_train_post_nms` in train time. :math:`S` depends on 104 | the size of the predicted bounding boxes and the number of 105 | bounding boxes discarded by NMS. 106 | 107 | """ 108 | if chainer.config.train: 109 | n_pre_nms = self.n_train_pre_nms 110 | n_post_nms = self.n_train_post_nms 111 | else: 112 | n_pre_nms = self.n_test_pre_nms 113 | n_post_nms = self.n_test_post_nms 114 | 115 | xp = cuda.get_array_module(loc) 116 | loc = cuda.to_cpu(loc) 117 | score = cuda.to_cpu(score) 118 | anchor = cuda.to_cpu(anchor) 119 | level_indices = cuda.to_cpu(level_indices) 120 | assert score.shape == level_indices.shape, ( 121 | score.shape, level_indices.shape) 122 | 123 | # Convert anchors into proposal via bbox transformations. 124 | roi = loc2bbox(anchor, loc) 125 | 126 | # Clip predicted boxes to image. 
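        # roi columns are laid out as (y_min, x_min, y_max, x_max), so the even columns
        # selected by slice(0, 4, 2) are y coordinates clipped to the image height
        # img_size[0], and the odd columns selected by slice(1, 4, 2) are x coordinates
        # clipped to the image width img_size[1].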
127 | roi[:, slice(0, 4, 2)] = np.clip( 128 | roi[:, slice(0, 4, 2)], 0, img_size[0]) 129 | roi[:, slice(1, 4, 2)] = np.clip( 130 | roi[:, slice(1, 4, 2)], 0, img_size[1]) 131 | 132 | # Remove predicted boxes with either height or width < threshold. 133 | min_size = self.min_size * scale 134 | hs = roi[:, 2] - roi[:, 0] 135 | ws = roi[:, 3] - roi[:, 1] 136 | keep = np.where((hs >= min_size) & (ws >= min_size))[0] 137 | roi = roi[keep, :] 138 | score = score[keep] 139 | level_indices = level_indices[keep] 140 | 141 | # Sort all (proposal, score) pairs by score from highest to lowest. 142 | # Take top pre_nms_topN (e.g. 6000). 143 | order = score.ravel().argsort()[::-1] 144 | if n_pre_nms > 0: 145 | order = order[:n_pre_nms] 146 | roi = roi[order, :] 147 | score = score[order] 148 | level_indices = level_indices[order] 149 | 150 | # Apply nms (e.g. threshold = 0.7). 151 | # Take after_nms_topN (e.g. 300). 152 | if xp != np and not self.force_cpu_nms: 153 | keep = non_maximum_suppression( 154 | cuda.to_gpu(roi), 155 | thresh=self.nms_thresh) 156 | keep = cuda.to_cpu(keep) 157 | else: 158 | keep = non_maximum_suppression( 159 | roi, 160 | thresh=self.nms_thresh) 161 | if n_post_nms > 0: 162 | keep = keep[:n_post_nms] 163 | roi = roi[keep] 164 | level_indices = level_indices[keep] 165 | 166 | if xp != np: 167 | roi = cuda.to_gpu(roi) 168 | level_indices = cuda.to_gpu(level_indices) 169 | return roi, level_indices 170 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer.datasets import TransformDataset 3 | from chainer.training import extensions 4 | from chainercv import transforms 5 | from chainercv.extensions.evaluator.instance_segmentation_voc_evaluator import InstanceSegmentationVOCEvaluator 6 | from chainerui.utils import save_args 7 | from chainerui.extensions import CommandsExtension 8 | import cv2 9 | import numpy as np 10 | from chainer_maskrcnn.model.fpn_maskrcnn_train_chain import FPNMaskRCNNTrainChain 11 | from chainer_maskrcnn.model.maskrcnn import MaskRCNN 12 | from chainer_maskrcnn.dataset.coco_dataset import COCOMaskLoader 13 | 14 | import argparse 15 | from os.path import exists, isfile 16 | import time 17 | import _pickle as pickle 18 | import warnings 19 | 20 | 21 | class Transform(object): 22 | def __init__(self, faster_rcnn): 23 | self.faster_rcnn = faster_rcnn 24 | 25 | def __call__(self, in_data): 26 | img, bbox, label, label_img = in_data 27 | _, H, W = img.shape 28 | img = self.faster_rcnn.prepare(img) 29 | _, o_H, o_W = img.shape 30 | scale = o_H / H 31 | bbox = transforms.resize_bbox(bbox, (H, W), (o_H, o_W)) 32 | bbox[:, 2:] = np.maximum(bbox[:, 2:], bbox[:, 2:] + 1) 33 | for i, im in enumerate(label_img): 34 | label_img[i] = cv2.resize( 35 | im, (o_W, o_H), interpolation=cv2.INTER_NEAREST) 36 | 37 | return img, bbox, label, label_img, scale 38 | 39 | 40 | class EvaluatorTransform(): 41 | """ 42 | chainercvのevalに合うようにTransform 43 | """ 44 | 45 | def __call__(self, in_data): 46 | i, b, l, m = in_data 47 | return i, np.stack(m), l 48 | 49 | 50 | def calc_mask_loss(roi_cls_mask, gt_roi_mask, xp, gt_roi_label): 51 | """ 52 | mask loss 53 | https://engineer.dena.jp/2017/12/chainercvmask-r-cnn.html 54 | """ 55 | roi_mask = roi_cls_mask[xp.arange( 56 | roi_cls_mask.shape[0]), gt_roi_label - 1] 57 | return chainer.functions.sigmoid_cross_entropy(roi_mask[:gt_roi_mask.shape[0]], 58 | gt_roi_mask) 59 | 60 | 61 | def 
main(): 62 | parser = argparse.ArgumentParser(description='Mask R-CNN') 63 | parser.add_argument('--gpu', '-g', type=int, default=0) 64 | parser.add_argument('--lr', '-l', type=float, default=1e-3) 65 | parser.add_argument( 66 | '--out', '-o', default='result', help='Output directory') 67 | parser.add_argument('--iteration', '-i', type=int, default=200000) 68 | parser.add_argument('--weight', '-w', type=str, default='') 69 | parser.add_argument( 70 | '--label_file', '-f', type=str, default='data/label_coco.txt') 71 | parser.add_argument('--backbone', type=str, default='fpn') 72 | parser.add_argument('--head-arch', '-a', type=str, default='fpn') 73 | parser.add_argument('--multi-gpu', '-m', type=int, default=0) 74 | parser.add_argument('--batch-size', '-b', type=int, default=1) 75 | 76 | args = parser.parse_args() 77 | 78 | print('lr:{}'.format(args.lr)) 79 | print('output:{}'.format(args.out)) 80 | print('weight:{}'.format(args.weight)) 81 | print('label file:{}'.format(args.label_file)) 82 | print('iteration::{}'.format(args.iteration)) 83 | print('backbone architecture:{}'.format(args.backbone)) 84 | print('head architecture:{}'.format(args.head_arch)) 85 | 86 | if args.multi_gpu: 87 | print('try to use chainer.training.updaters.MultiprocessParallelUpdater') 88 | if not chainer.training.updaters.MultiprocessParallelUpdater.available(): 89 | print('MultiprocessParallelUpdater is not available') 90 | args.multi_gpu = 0 91 | 92 | with open(args.label_file, "r") as f: 93 | labels = f.read().strip().split("\n") 94 | 95 | faster_rcnn = MaskRCNN( 96 | n_fg_class=len(labels), backbone=args.backbone, head_arch=args.head_arch) 97 | faster_rcnn.use_preset('evaluate') 98 | model = FPNMaskRCNNTrainChain(faster_rcnn, mask_loss_fun=calc_mask_loss) 99 | if exists(args.weight): 100 | chainer.serializers.load_npz( 101 | args.weight, model.faster_rcnn, strict=False) 102 | 103 | if args.gpu >= 0: 104 | chainer.cuda.get_device_from_id(args.gpu).use() 105 | model.to_gpu() 106 | 107 | optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9) 108 | optimizer.setup(model) 109 | optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005)) 110 | 111 | train_data = COCOMaskLoader(category_filter=labels, data_type='2017') 112 | train_data = TransformDataset(train_data, Transform(faster_rcnn)) 113 | test_data = COCOMaskLoader( 114 | category_filter=labels, data_type='2017', split='val') 115 | test_data = TransformDataset(test_data, EvaluatorTransform()) 116 | 117 | if args.multi_gpu: 118 | train_iters = [chainer.iterators.SerialIterator( 119 | train_data, 1, repeat=True, shuffle=True) for i in range(8)] 120 | updater = chainer.training.updater.MultiprocessParallelUpdater( 121 | train_iters, optimizer, device=range(8)) 122 | 123 | else: 124 | train_iter = chainer.iterators.MultithreadIterator( 125 | train_data, batch_size=args.batch_size, repeat=True, shuffle=False) 126 | test_iter = chainer.iterators.SerialIterator( 127 | test_data, batch_size=args.batch_size, repeat=False, shuffle=False) 128 | updater = chainer.training.updater.StandardUpdater( 129 | train_iter, optimizer, device=args.gpu) 130 | 131 | trainer = chainer.training.Trainer(updater, (args.iteration, 'iteration'), 132 | args.out) 133 | 134 | trainer.extend( 135 | extensions.snapshot_object(model.faster_rcnn, 136 | 'model_{.updater.iteration}.npz'), 137 | trigger=(5000, 'iteration')) 138 | 139 | trainer.extend( 140 | extensions.ExponentialShift('lr', 0.1), trigger=(2, 'epoch')) 141 | 142 | log_interval = 100, 'iteration' 143 | trainer.extend( 
144 | chainer.training.extensions.observe_lr(), trigger=log_interval) 145 | trainer.extend(extensions.LogReport(trigger=log_interval)) 146 | trainer.extend( 147 | extensions.PrintReport([ 148 | 'iteration', 149 | 'epoch', 150 | 'elapsed_time', 151 | 'lr', 152 | 'main/loss', 153 | 'main/mask_loss', 154 | 'main/roi_loc_loss', 155 | 'main/roi_cls_loss', 156 | 'main/rpn_loc_loss', 157 | 'main/rpn_cls_loss', 158 | 'validation/main/map' 159 | ]), 160 | trigger=(100, 'iteration')) 161 | trainer.extend(extensions.ProgressBar(update_interval=200)) 162 | trainer.extend(extensions.dump_graph('main/loss')) 163 | 164 | evaluator = InstanceSegmentationVOCEvaluator( 165 | test_iter, model.faster_rcnn, label_names=labels) 166 | trainer.extend(evaluator, trigger=(10000, 'iteration')) 167 | 168 | save_args(args, args.out) 169 | trainer.extend(CommandsExtension(), trigger=(100, 'iteration')) 170 | 171 | try: 172 | np.seterr(all='warn') 173 | trainer.run() 174 | except Warning: 175 | import pdb 176 | pdb.set_trace() 177 | 178 | 179 | if __name__ == '__main__': 180 | main() 181 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/rpn/multilevel_region_proposal_network.py: -------------------------------------------------------------------------------- 1 | # original code by chainercv 2 | # https://github.com/chainer/chainercv/blob/master/chainercv/links/model/faster_rcnn/region_proposal_network.py 3 | import numpy as np 4 | import chainer 5 | import chainer.links as L 6 | import chainer.functions as F 7 | from chainercv.links.model.faster_rcnn.region_proposal_network import _enumerate_shifted_anchor 8 | 9 | from chainercv.links.model.faster_rcnn.utils.generate_anchor_base import \ 10 | generate_anchor_base 11 | from chainercv.links.model.faster_rcnn.utils.proposal_creator import ProposalCreator 12 | 13 | 14 | # original code from Detectron 15 | # https://github.com/facebookresearch/Detectron/blob/master/lib/modeling/FPN.py 16 | def map_rois_to_fpn_levels(rois, k_min=0, k_max=4): 17 | """Determine which FPN level each RoI in a set of RoIs should map to based 18 | on the heuristic in the FPN paper. 19 | roi: assume (R, 4), y_min, x_min, y_max, x_max 20 | """ 21 | # Compute level ids 22 | xp = chainer.backends.cuda.get_array_module(rois) 23 | area = xp.prod(rois[:, 2:] - rois[:, :2], axis=1) 24 | s = xp.sqrt(area) 25 | s0 = 224 26 | lvl0 = 4 27 | 28 | # Eqn.(1) in FPN paper 29 | target_lvls = xp.floor(lvl0 + xp.log2(s / s0 + 1e-6)) 30 | target_lvls = xp.clip(target_lvls, k_min, k_max) 31 | return target_lvls 32 | 33 | 34 | class MultilevelRegionProposalNetwork(chainer.Chain): 35 | 36 | """Region Proposal Network introduced in Faster R-CNN. 37 | This is Region Proposal Network introduced in Faster R-CNN [#]_. 38 | This takes features extracted from images and propose 39 | class agnostic bounding boxes around "objects". 40 | .. [#] Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun. \ 41 | Faster R-CNN: Towards Real-Time Object Detection with \ 42 | Region Proposal Networks. NIPS 2015. 43 | Args: 44 | in_channels (int): The channel size of input. 45 | mid_channels (int): The channel size of the intermediate tensor. 46 | ratios (list of floats): This is ratios of width to height of 47 | the anchors. 48 | 49 | anchor_scales (list of numbers): This is areas of anchors. 50 | Those areas will be the product of the square of an element in 51 | :obj:`anchor_scales` and the original area of the reference 52 | window. 53 | initialW (callable): Initial weight value. 
If :obj:`None` then this 54 | function uses Gaussian distribution scaled by 0.1 to 55 | initialize weight. 56 | May also be a callable that takes an array and edits its values. 57 | proposal_creator_params (dict): Key valued paramters for 58 | :class:`~chainercv.links.model.faster_rcnn.ProposalCreator`. 59 | .. seealso:: 60 | :class:`~chainercv.links.model.faster_rcnn.ProposalCreator` 61 | """ 62 | 63 | def __init__( 64 | self, anchor_scales, feat_strides, in_channels=256, mid_channels=256, ratios=[0.5, 1, 2], 65 | initialW=None, 66 | proposal_creator_params=dict()): 67 | if len(anchor_scales) != len(feat_strides): 68 | raise ValueError( 69 | 'length of anchor_scales and feat_strides should be same!') 70 | self.anchor_bases = [generate_anchor_base( 71 | anchor_scales=[s], ratios=ratios) for s in anchor_scales] 72 | self.feat_strides = feat_strides 73 | self.proposal_layer = ProposalCreator(**proposal_creator_params) 74 | # note: to share conv layers, number of output channel should be same in all levels. 75 | # debug 76 | a = self.anchor_bases[0] 77 | for ab in self.anchor_bases[1:]: 78 | assert a.shape == ab.shape 79 | n_anchor = self.anchor_bases[0].shape[0] 80 | super(MultilevelRegionProposalNetwork, self).__init__() 81 | with self.init_scope(): 82 | # note: according fpn paper, parameters are sharable among levels. 83 | self.conv = L.Convolution2D( 84 | in_channels, mid_channels, 3, 1, 1, initialW=initialW) 85 | self.score = L.Convolution2D( 86 | mid_channels, n_anchor * 2, 1, 1, 0, initialW=initialW) 87 | self.loc = L.Convolution2D( 88 | mid_channels, n_anchor * 4, 1, 1, 0, initialW=initialW) 89 | 90 | def __call__(self, xs, img_size, scale=1.): 91 | """Forward Region Proposal Network. 92 | Here are notations. 93 | * :math:`N` is batch size. 94 | * :math:`C` channel size of the input. 95 | * :math:`H` and :math:`W` are height and witdh of the input feature. 96 | * :math:`A` is number of anchors assigned to each pixel. 97 | Args: 98 | xs (list of ~chainer.Variable): The Features extracted from images in multilevel. 99 | img_size (tuple of ints): A tuple :obj:`height, width`, 100 | which contains image size after scaling. 101 | scale (float): The amount of scaling done to the input images after 102 | reading them from files. 103 | Returns: 104 | (~chainer.Variable, ~chainer.Variable, array, array, array): 105 | This is a tuple of five following values. 106 | * **rpn_locs**: Predicted bounding box offsets and scales for \ 107 | anchors. Its shape is :math:`(N, H W A, 4)`. 108 | * **rpn_scores**: Predicted foreground scores for \ 109 | anchors. Its shape is :math:`(N, H W A, 2)`. 110 | * **rois**: A bounding box array containing coordinates of \ 111 | proposal boxes. This is a concatenation of bounding box \ 112 | arrays from multiple images in the batch. \ 113 | Its shape is :math:`(R', 4)`. Given :math:`R_i` predicted \ 114 | bounding boxes from the :math:`i` th image, \ 115 | :math:`R' = \\sum _{i=1} ^ N R_i`. 116 | * **roi_indices**: An array containing indices of images to \ 117 | which RoIs correspond to. Its shape is :math:`(R',)`. 118 | * **anchor**: Coordinates of enumerated shifted anchors. \ 119 | Its shape is :math:`(H W A, 4)`. 
120 | """ 121 | 122 | locs = [] 123 | scores = [] 124 | fg_scores = [] 125 | anchors = [] 126 | for i, x in enumerate(xs): 127 | n, _, hh, ww = x.shape 128 | anchor = _enumerate_shifted_anchor( 129 | self.xp.array(self.anchor_bases[i]), self.feat_strides[i], hh, ww) 130 | n_anchor = anchor.shape[0] // (hh * ww) 131 | h = F.relu(self.conv(x)) 132 | 133 | rpn_locs = self.loc(h) 134 | rpn_locs = rpn_locs.transpose((0, 2, 3, 1)).reshape((n, -1, 4)) 135 | 136 | rpn_scores = self.score(h) 137 | rpn_scores = rpn_scores.transpose((0, 2, 3, 1)) 138 | rpn_fg_scores =\ 139 | rpn_scores.reshape((n, hh, ww, n_anchor, 2))[:, :, :, :, 1] 140 | rpn_fg_scores = rpn_fg_scores.reshape((n, -1)) 141 | rpn_scores = rpn_scores.reshape((n, -1, 2)) 142 | 143 | locs.append(rpn_locs) 144 | scores.append(rpn_scores) 145 | fg_scores.append(rpn_fg_scores) 146 | anchors.append(anchor) 147 | 148 | # chainer.functions's default axis=1, but explicitly for myself. 149 | locs = F.concat(locs, axis=1) 150 | scores = F.concat(scores, axis=1) 151 | fg_scores = F.concat(fg_scores, axis=1) 152 | anchors = self.xp.concatenate(anchors, axis=0) 153 | 154 | rois = [] 155 | roi_indices = [] 156 | for i in range(n): 157 | roi = self.proposal_layer( 158 | locs[i].array, fg_scores[i].array, anchors, img_size, scale=scale) 159 | batch_index = i * self.xp.ones((len(roi),), dtype=np.int32) 160 | rois.append(roi) 161 | roi_indices.append(batch_index) 162 | 163 | rois = self.xp.concatenate(rois, axis=0) 164 | levels = map_rois_to_fpn_levels(rois) 165 | roi_indices = self.xp.concatenate(roi_indices, axis=0) 166 | return locs, scores, rois, roi_indices, anchors, levels 167 | -------------------------------------------------------------------------------- /train_keypoints.py: -------------------------------------------------------------------------------- 1 | import chainer 2 | from chainer.datasets import TransformDataset 3 | from chainer.training import extensions 4 | from chainercv import transforms 5 | from chainerui.utils import save_args 6 | from chainerui.extensions import CommandsExtension 7 | import cv2 8 | import numpy as np 9 | from chainer_maskrcnn.model.fpn_maskrcnn_train_chain import FPNMaskRCNNTrainChain 10 | from chainer_maskrcnn.model.maskrcnn import MaskRCNN 11 | from chainer_maskrcnn.dataset.coco_dataset import COCOKeypointsLoader 12 | from chainer_maskrcnn.dataset.depth_dataset import DepthDataset 13 | from chainer_maskrcnn.utils.depth_transformer import DepthTransformer 14 | 15 | import argparse 16 | from os.path import exists, isfile 17 | import time 18 | import _pickle as pickle 19 | 20 | 21 | def calc_mask_loss(roi_cls_mask, gt_roi_mask, xp, gt_roi_label, num_keypoints=17): 22 | # 出力を (n_proposals, 17, mask_size, mask_size) から (n_positive_sample *17, mask_size*mask_size) にreshapeして、softmax crossentropyを取る 23 | num_positives = gt_roi_mask.shape[0] 24 | roi_mask = roi_cls_mask[:num_positives].reshape( 25 | (num_positives * num_keypoints, -1)) 26 | gt_roi_mask = gt_roi_mask.reshape((-1,)) 27 | return chainer.functions.softmax_cross_entropy(roi_mask, gt_roi_mask) 28 | 29 | 30 | def load_dataset(dataset, file): 31 | if isfile(file): 32 | print('pklから読み込みます') 33 | dataload_start = time.time() 34 | with open(file, 'rb') as f: 35 | train_data = pickle.load(f) 36 | dataload_end = time.time() 37 | print('pklからの読み込み {}'.format(dataload_end - dataload_start)) 38 | else: 39 | dataload_start = time.time() 40 | train_data = dataset() 41 | dataload_end = time.time() 42 | print('普通の読み込み {}'.format(dataload_end - dataload_start)) 43 | if 
file is not '': 44 | print('次回のために保存します') 45 | with open(file, 'wb') as f: 46 | pickle.dump(train_data, f) 47 | return train_data 48 | 49 | 50 | class Transform(): 51 | def __init__(self, faster_rcnn): 52 | self.faster_rcnn = faster_rcnn 53 | 54 | def __call__(self, in_data): 55 | img, bbox, keypoints = in_data 56 | _, H, W = img.shape 57 | img = self.faster_rcnn.prepare(img) 58 | _, o_H, o_W = img.shape 59 | scale = o_H / H 60 | 61 | bbox = transforms.resize_bbox(bbox, (H, W), (o_H, o_W)) 62 | label = np.zeros(bbox.shape[0], dtype=np.int32) 63 | # shape of keypoints is (N, 17, 3), N is number of bbox, 17 is number of keypoints, 3 is (x, y, v) 64 | # v=0: unlabeled, v=1, labeled but invisible, v=2 labeled and visible 65 | keypoints = keypoints.astype(np.float32) 66 | kp = keypoints[:, :, [1, 0]] 67 | kp = np.concatenate([kp * scale, keypoints[:, :, 2, None]], axis=2) 68 | 69 | return img, bbox, label, kp, scale 70 | 71 | 72 | def main(): 73 | parser = argparse.ArgumentParser(description='Mask R-CNN') 74 | parser.add_argument('--gpu', '-g', type=int, default=0) 75 | parser.add_argument('--lr', '-l', type=float, default=1e-3) 76 | parser.add_argument( 77 | '--out', '-o', default='result', help='Output directory') 78 | parser.add_argument('--iteration', '-i', type=int, default=200000) 79 | parser.add_argument('--weight', '-w', type=str, default='') 80 | parser.add_argument( 81 | '--label_file', '-f', type=str, default='data/label_coco.txt') 82 | parser.add_argument('--backbone', type=str, default='fpn') 83 | parser.add_argument('--head_arch', '-a', type=str, default='fpn_keypoint') 84 | parser.add_argument('--multi_gpu', '-m', type=int, default=0) 85 | parser.add_argument('--batch_size', '-b', type=int, default=1) 86 | parser.add_argument('--dataset', default='coco', choices=['coco', 'depth']) 87 | parser.add_argument('--n_mask_convs', type=int, default=None) 88 | parser.add_argument('--min_size', type=int, default=600) 89 | parser.add_argument('--max_size', type=int, default=1000) 90 | 91 | args = parser.parse_args() 92 | 93 | print('lr:{}'.format(args.lr)) 94 | print('output:{}'.format(args.out)) 95 | print('weight:{}'.format(args.weight)) 96 | print('label file:{}'.format(args.label_file)) 97 | print('iteration::{}'.format(args.iteration)) 98 | print('backbone architecture:{}'.format(args.backbone)) 99 | print('head architecture:{}'.format(args.head_arch)) 100 | 101 | if args.dataset == 'coco': 102 | train_data = load_dataset(COCOKeypointsLoader, 'train_data_kp.pkl') 103 | n_keypoints = train_data.n_keypoints 104 | elif args.dataset == 'depth': 105 | train_data = load_dataset( 106 | lambda: DepthDataset(path='data/rgbd/train.txt', root='data/rgbd/'), '') 107 | n_keypoints = train_data.n_keypoints 108 | train_data = chainer.datasets.TransformDataset( 109 | train_data, DepthTransformer()) 110 | print(f'number of keypoints={n_keypoints}') 111 | 112 | if args.multi_gpu: 113 | print('try to use chainer.training.updaters.MultiprocessParallelUpdater') 114 | if not chainer.training.updaters.MultiprocessParallelUpdater.available(): 115 | print('MultiprocessParallelUpdater is not available') 116 | args.multi_gpu = 0 117 | 118 | faster_rcnn = MaskRCNN( 119 | n_fg_class=1, backbone=args.backbone, head_arch=args.head_arch, 120 | n_keypoints=n_keypoints, n_mask_convs=args.n_mask_convs, min_size=args.min_size, max_size=args.max_size) 121 | faster_rcnn.use_preset('evaluate') 122 | model = FPNMaskRCNNTrainChain( 123 | faster_rcnn, mask_loss_fun=lambda x, y, z, w: calc_mask_loss(x, y, z, w, 
num_keypoints=n_keypoints), binary_mask=False) 124 | if exists(args.weight): 125 | chainer.serializers.load_npz( 126 | args.weight, model.faster_rcnn, strict=False) 127 | 128 | if args.gpu >= 0: 129 | chainer.cuda.get_device_from_id(args.gpu).use() 130 | model.to_gpu() 131 | 132 | optimizer = chainer.optimizers.MomentumSGD(lr=args.lr, momentum=0.9) 133 | optimizer.setup(model) 134 | optimizer.add_hook(chainer.optimizer.WeightDecay(rate=0.0005)) 135 | 136 | # TransformでFaster-RCNNのprepareを参照するので、初期化順が複雑に入り組んでしまったなー 137 | train_data = TransformDataset(train_data, Transform(faster_rcnn)) 138 | if args.multi_gpu: 139 | train_iters = [chainer.iterators.SerialIterator( 140 | train_data, 1, repeat=True, shuffle=True) for i in range(8)] 141 | updater = chainer.training.updater.MultiprocessParallelUpdater( 142 | train_iters, optimizer, device=range(8)) 143 | 144 | else: 145 | train_iter = chainer.iterators.SerialIterator( 146 | train_data, batch_size=args.batch_size, repeat=True, shuffle=True) 147 | updater = chainer.training.updater.StandardUpdater( 148 | train_iter, optimizer, device=args.gpu) 149 | 150 | trainer = chainer.training.Trainer(updater, (args.iteration, 'iteration'), 151 | args.out) 152 | 153 | trainer.extend( 154 | extensions.snapshot_object(model.faster_rcnn, 155 | 'model_{.updater.iteration}.npz'), 156 | trigger=(20000, 'iteration')) 157 | 158 | trainer.extend( 159 | extensions.ExponentialShift('lr', 0.1), trigger=(3, 'epoch')) 160 | 161 | log_interval = 100, 'iteration' 162 | trainer.extend( 163 | chainer.training.extensions.observe_lr(), trigger=log_interval) 164 | trainer.extend(extensions.LogReport(trigger=log_interval)) 165 | trainer.extend( 166 | extensions.PrintReport([ 167 | 'iteration', 168 | 'epoch', 169 | 'elapsed_time', 170 | 'lr', 171 | 'main/loss', 172 | 'main/mask_loss', 173 | 'main/roi_loc_loss', 174 | 'main/roi_cls_loss', 175 | 'main/rpn_loc_loss', 176 | 'main/rpn_cls_loss', 177 | ]), 178 | trigger=(100, 'iteration')) 179 | trainer.extend(extensions.ProgressBar(update_interval=200)) 180 | trainer.extend(extensions.dump_graph('main/loss')) 181 | 182 | save_args(args, args.out) 183 | trainer.extend(CommandsExtension(), trigger=(100, 'iteration')) 184 | 185 | trainer.run() 186 | 187 | 188 | if __name__ == '__main__': 189 | main() 190 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /chainer_maskrcnn/model/maskrcnn.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import chainer 4 | from chainer import cuda 5 | import chainer.functions as F 6 | from chainercv.links.model.faster_rcnn.faster_rcnn import FasterRCNN 7 | from chainercv.links.model.faster_rcnn.region_proposal_network import \ 8 | RegionProposalNetwork 9 | from chainercv.links.model.faster_rcnn.utils.loc2bbox import loc2bbox 10 | from chainercv.transforms.image.resize import resize 11 | from chainercv.utils import non_maximum_suppression 12 | from .extractor.c4_backbone import C4Backbone 13 | from .extractor.feature_pyramid_network import FeaturePyramidNetwork 14 | from .extractor.darknet import Darknet 15 | from .rpn.multilevel_region_proposal_network import MultilevelRegionProposalNetwork 16 | from .head.resnet_roi_mask_head import ResnetRoIMaskHead 17 | from .head.light_roi_mask_head import LightRoIMaskHead 18 | from .head.fpn_roi_mask_head import FPNRoIMaskHead 19 | from .head.fpn_roi_keypoint_head import FPNRoIKeypointHead 20 | import cv2 21 | 22 | 23 | class MaskRCNN(FasterRCNN): 24 | feat_stride = 16 25 | 26 | def __init__(self, 27 | n_fg_class, 28 | n_keypoints=None, 29 | n_mask_convs=None, 30 | pretrained_model=None, 31 | min_size=600, 32 | max_size=1000, 33 | ratios=[0.5, 1, 2], 34 | anchor_scales=[8], 35 | rpn_initialW=None, 36 | loc_initialW=None, 37 | score_initialW=None, 38 | proposal_creator_params={}, 39 | backbone='fpn', 40 | head_arch='fpn'): 41 | if n_fg_class is None: 42 | raise ValueError( 43 | 'The n_fg_class needs to be supplied as an argument') 44 | 45 | if loc_initialW is None: 46 | loc_initialW = chainer.initializers.Normal(0.001) 47 | if score_initialW is None: 48 | score_initialW = chainer.initializers.Normal(0.01) 49 | if rpn_initialW is None: 50 | rpn_initialW = chainer.initializers.Normal(0.01) 51 | 52 | if backbone == 'fpn': 53 | extractor = FeaturePyramidNetwork() 54 | print('feat_strides:', extractor.feat_strides, 55 | 'spatial_scales:', extractor.spatial_scales) 56 | rpn = MultilevelRegionProposalNetwork( 57 | anchor_scales=extractor.anchor_scales, feat_strides=extractor.feat_strides) 58 | elif backbone == 'c4': 59 | extractor = C4Backbone('auto') 60 | rpn = RegionProposalNetwork( 61 | 1024, 62 | 516, 63 | ratios=ratios, 64 | anchor_scales=anchor_scales, 65 | feat_stride=self.feat_stride, 66 | initialW=rpn_initialW, 67 | proposal_creator_params=proposal_creator_params, 68 | ) 69 | elif backbone == 'darknet': 70 | extractor = Darknet() 71 | rpn = MultilevelRegionProposalNetwork( 72 | anchor_scales=extractor.anchor_scales, feat_strides=extractor.feat_strides, in_channels=256, 73 | proposal_creator_params={'n_test_pre_nms': 50, 74 | 'n_test_post_nms': 10}) 75 | else: 76 | raise ValueError( 77 | 'unknown backbone: {}'.format(backbone)) 78 | 79 | if head_arch == 'res5': 80 | head = ResnetRoIMaskHead( 81 | n_fg_class + 1, 82 | roi_size=7, 83 | spatial_scale=1. 
/ self.feat_stride, 84 | loc_initialW=loc_initialW, 85 | score_initialW=score_initialW, 86 | mask_initialW=chainer.initializers.Normal(0.01)) 87 | self.predict_mask = True 88 | 89 | elif head_arch == 'light': 90 | head = LightRoIMaskHead( 91 | n_fg_class + 1, 92 | roi_size=7, 93 | loc_initialW=loc_initialW, 94 | score_initialW=score_initialW, 95 | mask_initialW=chainer.initializers.Normal(0.01)) 96 | self.predict_mask = True 97 | elif head_arch == 'fpn': 98 | head = FPNRoIMaskHead( 99 | n_fg_class + 1, 100 | roi_size_box=7, 101 | roi_size_mask=14, 102 | loc_initialW=loc_initialW, 103 | score_initialW=score_initialW, 104 | mask_initialW=chainer.initializers.Normal(0.01)) 105 | self.predict_mask = True 106 | elif head_arch == 'fpn_keypoint': 107 | if n_keypoints == None: 108 | raise ValueError( 109 | 'n_keypoints must be set in keypoint detection') 110 | if n_mask_convs == None: 111 | n_mask_convs = 8 112 | head = FPNRoIKeypointHead( 113 | 2, 114 | n_keypoints, 115 | roi_size_box=7, 116 | roi_size_mask=14, 117 | n_mask_convs=n_mask_convs, 118 | loc_initialW=loc_initialW, 119 | score_initialW=score_initialW, 120 | mask_initialW=chainer.initializers.Normal(0.01)) 121 | self.predict_mask = False 122 | else: 123 | raise ValueError( 124 | 'unknown head archtecture specified. {}'.format(head_arch)) 125 | 126 | super().__init__( 127 | extractor, 128 | rpn, 129 | head, 130 | mean=np.array([122.7717, 115.9465, 102.9801], 131 | dtype=np.float32)[:, None, None], 132 | min_size=min_size, 133 | max_size=max_size) 134 | 135 | def __call__(self, x, scale=1.): 136 | img_size = x.shape[2:] 137 | 138 | h = self.extractor(x) 139 | rpn_locs, rpn_scores, rois, roi_indices, anchor, levels =\ 140 | self.rpn(h, img_size, scale) 141 | levels = np.clip(levels, 0, len(h) - 1) 142 | 143 | # join roi and index of batch 144 | roi_indices = roi_indices.astype(np.float32) 145 | indices_and_rois = self.xp.concatenate( 146 | (roi_indices[:, None], rois), axis=1) 147 | 148 | if chainer.config.train: 149 | roi_cls_locs, roi_scores, mask = self.head( 150 | h, indices_and_rois, levels, self.extractor.spatial_scales) 151 | return roi_cls_locs, roi_scores, rois, roi_indices, mask 152 | else: 153 | roi_cls_locs, roi_scores = self.head( 154 | h, indices_and_rois, levels, self.extractor.spatial_scales) 155 | return roi_cls_locs, roi_scores, rois, roi_indices, levels 156 | 157 | def predict(self, imgs): 158 | prepared_imgs = [] 159 | sizes = [] 160 | for img in imgs: 161 | size = img.shape[1:] 162 | img = self.prepare(img.astype(np.float32)) 163 | prepared_imgs.append(img) 164 | sizes.append(size) 165 | 166 | bboxes = [] 167 | labels = [] 168 | scores = [] 169 | masks = [] 170 | for img, size in zip(prepared_imgs, sizes): 171 | with chainer.using_config('train', False), \ 172 | chainer.using_config('enable_backprop', False): 173 | img_var = chainer.Variable(self.xp.asarray(img[None])) 174 | scale = img_var.shape[3] / size[1] 175 | roi_cls_locs, roi_scores, rois, roi_indices, levels = self.__call__( 176 | img_var, scale=scale) 177 | # We are assuming that batch size is 1. 178 | roi = rois / scale 179 | roi_cls_loc = roi_cls_locs.data 180 | roi_score = roi_scores.data 181 | 182 | if roi_cls_loc.shape[1] == 4: 183 | roi_cls_loc = self.xp.tile(roi_cls_loc, self.n_class) 184 | 185 | # if loc prediction layer uses shared weight, expand (though, not optimized way) 186 | if roi_cls_loc.shape[1] == 4: 187 | roi_cls_loc = self.xp.tile(roi_cls_loc, self.n_class) 188 | 189 | # Convert predictions to bounding boxes in image coordinates. 
190 |             # Bounding boxes are scaled to the scale of the input images.
191 |             mean = self.xp.tile(
192 |                 self.xp.asarray(self.loc_normalize_mean), self.n_class)
193 |             std = self.xp.tile(
194 |                 self.xp.asarray(self.loc_normalize_std), self.n_class)
195 |             roi_cls_loc = (roi_cls_loc * std + mean).astype(np.float32)
196 |             roi_cls_loc = roi_cls_loc.reshape((-1, self.n_class, 4))
197 |             roi = self.xp.broadcast_to(roi[:, None], roi_cls_loc.shape)
198 |             cls_bbox = loc2bbox(
199 |                 roi.reshape((-1, 4)), roi_cls_loc.reshape((-1, 4)))
200 |             cls_bbox = cls_bbox.reshape((-1, self.n_class * 4))
201 |             # clip bounding box
202 |             cls_bbox[:, 0::2] = self.xp.clip(cls_bbox[:, 0::2], 0, size[0])
203 |             cls_bbox[:, 1::2] = self.xp.clip(cls_bbox[:, 1::2], 0, size[1])
204 |
205 |             prob = F.softmax(roi_score).data
206 |
207 |             raw_cls_bbox = cuda.to_cpu(cls_bbox)
208 |             raw_prob = cuda.to_cpu(prob)
209 |             raw_roi = cuda.to_cpu(roi)
210 |             raw_levels = cuda.to_cpu(levels)
211 |
212 |             bbox, label, score, roi, levels = self._suppress(raw_cls_bbox, raw_prob,
213 |                                                              raw_roi, raw_levels)
214 |
215 |             # predict only the mask, based on the detected RoIs
216 |             mask_per_image = np.zeros(
217 |                 (len(bbox),) + size, dtype=np.bool) if self.predict_mask else []
218 |             if len(label) > 0:
219 |                 with chainer.using_config('train', False), \
220 |                         chainer.using_config('enable_backprop', False):
221 |                     # because we assume batch size = 1, all elements of roi_indices are zero.
222 |                     roi_indices = self.xp.zeros(roi.shape[0], dtype=np.float32)
223 |                     bbox_gpu = cuda.to_gpu(
224 |                         bbox) if chainer.cuda.available else bbox
225 |                     indices_and_rois = self.xp.concatenate(
226 |                         (roi_indices[:, None], bbox_gpu * scale), axis=1)
227 |
228 |                     mask = self.head.predict_mask(
229 |                         levels, indices_and_rois, self.extractor.spatial_scales)
230 |
231 |                 if self.predict_mask:
232 |                     mask = F.sigmoid(mask).data
233 |                     mask = mask[np.arange(mask.shape[0]), label]
234 |                     mask = cuda.to_cpu(mask)
235 |                     # resize each mask and paste it into a full-size mask image matching the original image
236 |                     for i, (b, m) in enumerate(zip(bbox, mask)):
237 |                         w = b[3] - b[1]
238 |                         h = b[2] - b[0]
239 |                         m = cv2.resize(m, (int(w), int(h))) * 255
240 |                         m = m.astype(np.uint8)
241 |                         _, m = cv2.threshold(m, 127, 255, cv2.THRESH_BINARY)
242 |                         mask_h, mask_w = m.shape
243 |                         s, t = b[:2].astype(np.int32)
244 |                         # todo: in theory b[0] + mask_h == b[2], but it can be off by one pixel?
245 |                         mask_per_image[i, s:(s + mask_h),
246 |                                        t:(t + mask_w)] = m.astype(np.bool)
247 |
248 |                 else:
249 |                     mask = mask.reshape((mask.shape[0], 20, -1)).data
250 |                     mask = cuda.to_cpu(mask)
251 |                     mask_per_image.append(mask)
252 |
253 |             bboxes.append(bbox)
254 |             labels.append(label)
255 |             scores.append(score)
256 |             masks.append(mask_per_image)
257 |
258 |         # return bboxes, labels, scores, masks
259 |         return masks, labels, scores
260 |
261 |     def prepare(self, img):
262 |         _, H, W = img.shape
263 |
264 |         scale = 1.
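        # scale the image so that its shorter side becomes self.min_size, unless that would
        # push the longer side beyond self.max_size, in which case the longer side is pinned
        # to self.max_size instead.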
265 |
266 |         scale = self.min_size / min(H, W)
267 |
268 |         if scale * max(H, W) > self.max_size:
269 |             scale = self.max_size / max(H, W)
270 |
271 |         img = resize(img, (int(H * scale), int(W * scale)))
272 |
273 |         # the original code only subtracted the mean; here we simply scale the image to [0, 1] and see whether that works well enough
274 |         img = img.astype(np.float32) / 255
275 |
276 |         return img
277 |
278 |     def _suppress(self, raw_cls_bbox, raw_prob, raw_roi, raw_level):
279 |         bbox = []
280 |         label = []
281 |         score = []
282 |         roi = []
283 |         level = []
284 |         # skip cls_id = 0 because it is the background class
285 |         # -> the mask channels start at 0, so use l - 1 when indexing them
286 |         # -> oops: with that offset the last class (toothbrush) ends up out of range in the TrainChain...
287 |         for l in range(1, self.n_class):
288 |             if self.predict_mask and l == self.n_class - 1:
289 |                 # not essential at all, but the offset was handled wrongly when training the mask branch,
290 |                 # so l == self.n_class - 1 may raise an index-out-of-bounds error? needs verification
291 |                 continue
292 |             cls_bbox_l = raw_cls_bbox.reshape((-1, self.n_class, 4))[:, l, :]
293 |             prob_l = raw_prob[:, l]
294 |             mask = prob_l > self.score_thresh
295 |             cls_bbox_l = cls_bbox_l[mask]
296 |             prob_l = prob_l[mask]
297 |             keep = non_maximum_suppression(cls_bbox_l, self.nms_thresh, prob_l)
298 |             bbox.append(cls_bbox_l[keep])
299 |             # The labels are in [0, self.n_class - 2].
300 |             label.append((l - 1) * np.ones((len(keep), )))
301 |             score.append(prob_l[keep])
302 |             raw_roi_l = raw_roi[:, l, :][mask]
303 |             roi.append(raw_roi_l[keep])
304 |             level_l = raw_level[mask]
305 |             level.append(level_l[keep])
306 |
307 |         bbox = np.concatenate(bbox, axis=0).astype(np.float32)
308 |         label = np.concatenate(label, axis=0).astype(np.int32)
309 |         score = np.concatenate(score, axis=0).astype(np.float32)
310 |         roi = np.concatenate(roi, axis=0)
311 |         level = np.concatenate(level, axis=0).astype(np.int32)
312 |         return bbox, label, score, roi, level
313 |
--------------------------------------------------------------------------------
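For reference, a minimal inference sketch (not part of the repository itself; the snapshot filename and test image are placeholders) showing how the MaskRCNN model defined above would typically be driven once a checkpoint written by train.py exists:

import chainer
from chainercv.utils import read_image
from chainer_maskrcnn.model.maskrcnn import MaskRCNN

# class names, one per line, in the same order used during training
with open('data/label_coco.txt') as f:
    labels = f.read().strip().split('\n')

# build the FPN-based model and load a snapshot produced by train.py (placeholder name)
model = MaskRCNN(n_fg_class=len(labels), backbone='fpn', head_arch='fpn')
chainer.serializers.load_npz('result/model_200000.npz', model, strict=False)

img = read_image('imgs/008.jpg', color=True)  # CHW float32 image
masks, pred_labels, scores = model.predict([img])
print('detected {} instances'.format(len(masks[0])))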