"""PointRend-pytorch: an unofficial implementation of the PointRend head.

PointRend ("PointRend: Image Segmentation as Rendering", Kirillov et al.,
https://arxiv.org/abs/1912.08193) refines a coarse segmentation by
re-predicting class labels at a small set of adaptively selected points.

Only the PointRend head is defined here; no segmentation backbone is
included.

Usage::

    import torch
    from point_rend import PointRend

    # Random inputs. In practice, take them from your segmentation model:
    #   from your_seg_model import seg_model
    #   coarse_prediction, fine_grained = seg_model(your_image_input)
    coarse_prediction = torch.rand([32, 3, 128, 128]).cuda()
    fine_grained = torch.rand([32, 128, 128, 128]).cuda()

    net = PointRend(3, 1000, [128, 128], [128, 128], [256, 256], 128).cuda()
    output_point, output_mask = net(fine_grained, coarse_prediction)
"""
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Point-wise MLP that predicts per-point class logits.

    The MLP is built from 1x1 convolutions over a (batch, channel, N, 1)
    layout. The last ``skip_channels`` input channels are concatenated again
    to the hidden activations before the second and third convolutions.

    Args:
        nIn: number of input channels per point.
        num_class: number of output channels (classes).
        skip_channels: number of trailing input channels to re-concatenate.
            Defaults to 3, the value the original code hard-coded.
            ``PointRend`` passes its ``num_classes`` here, because it appends
            the coarse logits as the trailing channels.
    """

    def __init__(self, nIn, num_class, skip_channels=3):
        super(MLP, self).__init__()
        self.skip_channels = skip_channels
        # The batch-norm layers are not used by forward(). They are kept so
        # the sub-module names stay the same, and sized to match the tensors
        # they would normalise.
        self.bn0 = nn.BatchNorm2d(nIn)
        self.conv1 = nn.Conv2d(nIn, 256, kernel_size=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(256 + skip_channels)
        self.conv2 = nn.Conv2d(256 + skip_channels, 128, kernel_size=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(128 + skip_channels)
        self.conv3 = nn.Conv2d(128 + skip_channels, num_class, kernel_size=1, padding=0, bias=False)
        # A single PReLU (one learnable slope) shared by both hidden layers.
        self.ReLU = nn.PReLU()
        self._init_weight()

    def _init_weight(self):
        """Xavier-normal conv/linear weights, zero biases, identity batch-norm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map point features (batch, N, channel) to logits (batch, num_class, N)."""
        # (batch, N, channel) -> (batch, channel, N, 1) so the 1x1 convs act per point.
        model_input = x.unsqueeze(2).permute([0, 3, 1, 2])
        # Slice from an explicit start: `[-0:]` would select every channel.
        skip = model_input[:, model_input.shape[1] - self.skip_channels:, :, :]
        layer1_output = self.ReLU(self.conv1(model_input))
        layer2_output = self.ReLU(self.conv2(torch.cat((layer1_output, skip), 1)))
        last_conv = self.conv3(torch.cat((layer2_output, skip), 1))
        return torch.squeeze(last_conv, 3)


class PointRend(nn.Module):
    """PointRend head (training-time point selection and per-point prediction).

    Both inputs are upsampled to ``img_size``. ``k * N`` random candidate
    pixels are scored by uncertainty, defined as ``1 - (p_top1 - p_top2)`` of
    the softmaxed coarse prediction. The ``int(N * belta)`` most uncertain
    candidates are kept, and the remaining points are random coverage pixels
    disjoint from the candidates. Each point's fine features are concatenated
    with its (pre-softmax) coarse logits and passed through ``MLP``.

    Args:
        num_classes: number of classes (channels of ``coarse_pre``).
        N: number of points selected per image.
        coarse_size: stored only; presumably the spatial size of
            ``coarse_pre``.
        fine_size: stored only; presumably the spatial size of
            ``fine_grained``.
        img_size: (h, w) both inputs are upsampled to; points are sampled
            on this grid.
        fine_channels: channels of ``fine_grained``.
        is_training: only ``True`` is implemented.
        k: oversampling factor; ``k * N`` candidate points are scored.
        belta: fraction of ``N`` taken from the most uncertain candidates.
    """

    def __init__(self, num_classes, N, coarse_size, fine_size, img_size, fine_channels, is_training=True, k=3, belta=0.75):
        super(PointRend, self).__init__()
        self.num_classes = num_classes
        self.N = N
        self.is_training = is_training
        self.coarse_size = coarse_size
        self.img_size = img_size
        self.fine_size = fine_size
        self.k = k
        self.belta = belta
        # Each point feature is fine_channels fine features followed by
        # num_classes coarse logits; the coarse logits are the skip channels.
        self.mlp = MLP(fine_channels + num_classes, num_classes, skip_channels=num_classes)
        if self.is_training:
            # Upsample to an explicit size so non-square images work. For an
            # exact integer ratio this matches the old scale_factor form.
            out_size = (int(img_size[0]), int(img_size[1]))
            self.up1 = nn.Upsample(size=out_size, mode='bilinear', align_corners=False)
            self.up2 = nn.Upsample(size=out_size, mode='bilinear', align_corners=False)
        else:
            # Reserved for the (unimplemented) inference path.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, fine_grained, coarse_pre):
        """Select N points per image and predict their class logits.

        Args:
            fine_grained: fine features, (batch, fine_channels, h, w).
            coarse_pre: coarse logits, (batch, num_classes, h', w').

        Returns:
            out: logits at the selected points, (batch, num_classes, N).
                Points are ordered in raster (row-major) order of their
                pixel location, not in selection order.
            mask: (batch, img_h, img_w) torch.long tensor, 1 at the N
                selected pixels of each image and 0 elsewhere.

        Raises:
            NotImplementedError: if the module was built with
                ``is_training=False``.
            ValueError: if the image is too small to draw ``k * N``
                candidates plus the coverage points without overlap.
        """
        if not self.is_training:
            raise NotImplementedError('PointRend inference (is_training=False) is not implemented')

        img_h, img_w = int(self.img_size[0]), int(self.img_size[1])
        num_over = self.k * self.N
        step1_n = int(self.N * self.belta)
        step2_n = self.N - step1_n
        if num_over + step2_n > img_h * img_w:
            # Overlapping candidate/coverage points would collapse in the mask
            # and break the per-point reshape below.
            raise ValueError('k*N + (N - int(N*belta)) = %d exceeds img_size pixels (%d)'
                             % (num_over + step2_n, img_h * img_w))
        device = coarse_pre.device

        # Upsample both inputs to img_size, then move channels last: (batch, h, w, c).
        up_coarse_o = self.up1(coarse_pre)
        up_coarse = F.softmax(up_coarse_o, dim=1).permute([0, 2, 3, 1])
        up_coarse_ori_fea = up_coarse_o.permute([0, 2, 3, 1])
        up_fine = self.up2(fine_grained).permute([0, 2, 3, 1])
        batch = up_coarse.shape[0]

        # One random permutation of all pixels: the first num_over entries are
        # candidates, the last step2_n are coverage points (disjoint by the check above).
        random_1d_index = torch.randperm(img_h * img_w)
        all_index = torch.stack((random_1d_index // img_w, random_1d_index % img_w), 1).to(device)
        index = all_index[:num_over]

        # Class probabilities at every candidate: (batch, num_over, num_classes).
        over_pre = up_coarse[:, index[:, 0], index[:, 1], :]
        # Uncertainty = 1 - (top-1 prob - top-2 prob).
        top2, _ = torch.topk(over_pre, 2, 2)
        uncertain_score = 1 - (top2[:, :, 0] - top2[:, :, 1])
        _, uncertain_index = torch.topk(uncertain_score, step1_n, 1)
        # (batch, step1_n, 2): pixel coordinates of the most uncertain candidates.
        uncertain_points = index[uncertain_index]
        # (batch, step2_n, 2): coverage points shared by every image in the batch.
        covrage_points = all_index[all_index.shape[0] - step2_n:].unsqueeze(0).expand(batch, -1, -1)
        # (batch, N, 2): all selected points.
        all_select_point = torch.cat([uncertain_points, covrage_points], 1)

        # Dense 0/1 mask of the selected pixels.
        batch_index = torch.arange(batch, device=device).unsqueeze(1).expand(-1, self.N)
        mask = torch.zeros(batch, img_h, img_w, dtype=torch.long, device=device)
        mask[batch_index, all_select_point[:, :, 0], all_select_point[:, :, 1]] = 1

        # masked_select yields points in raster order. Reshape each feature set
        # to (batch, N, c) BEFORE concatenating so each row pairs one point's
        # fine and coarse features.
        select_mask = mask.bool().unsqueeze(3)
        select_coarse = torch.masked_select(up_coarse_ori_fea, select_mask).reshape(batch, self.N, -1)
        select_fine = torch.masked_select(up_fine, select_mask).reshape(batch, self.N, -1)
        # (batch, N, fine_channels + num_classes)
        select_feature = torch.cat([select_fine, select_coarse], -1)

        # (batch, num_classes, N)
        out = self.mlp(select_feature)
        return out, mask


if __name__ == '__main__':
    net = PointRend(3, 1000, [128, 128], [128, 128], [256, 256], 128)
    net = net.cuda()
    for _ in range(5):
        start = time.time()
        coarse_prediction = torch.rand([32, 3, 128, 128]).cuda()
        fine_grained = torch.rand([32, 128, 128, 128]).cuda()
        out, mask = net(fine_grained, coarse_prediction)
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
        end = time.time()
        print(end - start)
        print(out.shape)
        print(mask.shape)
    torch.save(net, 'model.pkl')