├── LICENSE ├── MNIST_PointNet.py ├── MNIST_Res.py ├── Mmodule.py ├── PointCNN.py ├── PointNet2_cls_train.py ├── PointNet2_module.py ├── PointNet_train.py ├── README.md ├── dataset.py ├── module.py ├── non.py ├── pointnet_util.py └── tmp.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kun Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MNIST_PointNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from Mmodule import PointNet, res_block 5 | from Mmodule import pic2point 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | import tqdm 11 | 12 | # from tensorboardX import SummaryWriter 13 | 14 | transform = transforms.Compose([ 15 | transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 18 | transforms.Normalize((0.5,), (0.5,)) 19 | ]) 20 | 21 | batch_size = 32 22 | # MNIST easily reaches over 95% accuracy 23 | trainset = torchvision.datasets.MNIST(root='./Mdata', train=True, 24 | download=True, transform=transform) 25 | # trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 26 | # download=True, transform=transform) 27 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 28 | shuffle=True, num_workers=2) 29 | testset = torchvision.datasets.MNIST(root='./Mdata', train=False, 30 | download=True, transform=transform) 31 | # testset = torchvision.datasets.CIFAR10(root='./data', train=False, 32 | # download=True, transform=transform) 33 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 34 | shuffle=True, num_workers=2) 35 | num_classes = 10 36 | num_batch = len(trainset) / batch_size 37 | 38 | net = PointNet(res_block, k=num_classes) 39 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 40 | net.to(device) 41 | # net = nn.DataParallel(net, device_ids=[0, 1]) 42 | net = nn.DataParallel(net) 43 | 44 | optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999)) 45 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 46 | 47 | print(device) 48 | print('Start training') 49 | for epoch in range(1): 50 | scheduler.step() 51 | 
for i, data in enumerate(trainloader, 0): 52 | points, target = data 53 | points = pic2point(points) 54 | points, target = points.to(device), target.to(device) 55 | optimizer.zero_grad() 56 | # classifier = net.train() 57 | pred, trans, trans_feat = net(points) 58 | loss = F.nll_loss(pred, target) 59 | loss.backward() 60 | optimizer.step() 61 | pred_choice = pred.data.max(1)[1] 62 | correct = pred_choice.eq(target.data).cpu().sum() 63 | print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(batch_size))) 64 | 65 | # if i % 10 == 0: 66 | # j, data = next(enumerate(testloader, 0)) 67 | # points, target = data 68 | # target = target[:, 0] 69 | # points = points.transpose(2, 1) 70 | # points, target = points.to(device), target.to(device) 71 | # classifier = classifier.eval() 72 | # pred, _, _ = classifier(points) 73 | # loss = F.nll_loss(pred, target) 74 | # pred_choice = pred.data.max(1)[1] 75 | # correct = pred_choice.eq(target.data).cpu().sum() 76 | # print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize))) 77 | 78 | # torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch)) 79 | 80 | total_correct = 0 81 | total_testset = 0 82 | with torch.no_grad(): 83 | # for i, data in tqdm(enumerate(testloader, 0)): 84 | for i, data in enumerate(testloader, 0): 85 | points, target = data 86 | points = pic2point(points) 87 | points, target = points.to(device), target.to(device) 88 | classifier = net.eval() 89 | pred, _, _ = net(points) 90 | pred_choice = pred.data.max(1)[1] 91 | correct = pred_choice.eq(target.data).sum() 92 | total_correct += correct.item() 93 | total_testset += points.size()[0] 94 | 95 | print("final accuracy {}".format(total_correct / float(total_testset))) 96 | -------------------------------------------------------------------------------- /MNIST_Res.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | 10 | from tensorboardX import SummaryWriter 11 | 12 | transform = transforms.Compose([ 13 | transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 16 | transforms.Normalize((0.5,), (0.5,)) 17 | ]) 18 | 19 | batch_size = 16 20 | 21 | trainset = torchvision.datasets.MNIST(root='./data', train=True, 22 | download=True, transform=transform) 23 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 24 | shuffle=True, num_workers=2) 25 | testset = torchvision.datasets.MNIST(root='./data', train=False, 26 | download=True, transform=transform) 27 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 28 | shuffle=True, num_workers=2) 29 | 30 | 31 | class res_block(nn.Module): 32 | def __init__(self, in_channels, out_channels): 33 | super(res_block, self).__init__() 34 | self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) 35 | self.bn = nn.BatchNorm2d(out_channels) 36 | 37 | self.relu = nn.ReLU() 38 | self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1) 39 | self.self_conv = nn.Conv2d(in_channels, out_channels, 1, padding=0) 40 | 41 | def _initialize_weights(self): 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 45 | if m.bias is not None: 46 | nn.init.constant_(m.bias, 0) 47 | elif isinstance(m, nn.BatchNorm2d): 48 | nn.init.constant_(m.weight, 1) 49 | nn.init.constant_(m.bias, 0) 50 | elif isinstance(m, nn.Linear): 51 | nn.init.normal_(m.weight, 0, 0.01) 52 | nn.init.constant_(m.bias, 0) 53 | 54 | def forward(self, x): 55 | out = self.conv1(x) 56 | out = self.relu(self.bn(out)) 57 | out = self.relu(self.bn(self.conv2(out))) 58 | out = self.relu(self.bn(self.conv2(out))) 59 | x = self.self_conv(x) 60 | return out + x 61 | 62 | 63 | class ResNet(nn.Module): 64 | def __init__(self, block): 65 | super(ResNet, self).__init__() 66 | # self.conv1 = block(3, 32) 67 | # self.conv1_2 = block(32, 32) 68 | # self.conv2 = block(32, 64) 69 | # self.conv2_2 = block(64, 64) 70 | # self.conv3 = block(64, 128) 71 | # self.conv3_2 = block(128, 128) 72 | 73 | self.conv1 = block(1, 64) 74 | self.conv1_2 = block(64, 64) 75 | self.conv2 = block(64, 128) 76 | self.conv2_2 = block(128, 128) 77 | self.conv3 = block(128, 256) 78 | self.conv3_2 = block(256, 256) 79 | 80 | self.pool = nn.MaxPool2d(2, 2) 81 | self.conv8 = block(256, 256) 82 | self.global_pool = nn.Conv2d(256, 10, 1) 83 | self.dropout = nn.Dropout2d() 84 | 85 | # The res block cannot be written here: conv layers defined in methods other than __init__ are not registered as submodules, so their weights would not move to CUDA with the model 86 | # def res_block(self, x, in_channels, out_channels, is_pool=False): 87 | # out = nn.Conv2d(in_channels, out_channels, 3, padding=1)(x) 88 | # out = nn.BatchNorm2d(out_channels)(out) 89 | # out = nn.ReLU(out) 90 | # out = nn.Conv2d(out_channels, out_channels, 3, padding=1)(out) 91 | # out = nn.BatchNorm2d(out_channels)(out) 92 | # x = nn.Conv2d(in_channels, out_channels, 1, padding=0)(x) 93 | # out = nn.ReLU(out) + x 94 | # return out 95 | 96 | def _initialize_weights(self): 97 | for m in self.modules(): 98 | if isinstance(m, nn.Conv2d): 99 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 100 | if m.bias is not None: 101 | nn.init.constant_(m.bias, 0) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | nn.init.constant_(m.weight, 1) 104 | nn.init.constant_(m.bias, 0) 105 | elif isinstance(m, nn.Linear): 106 | nn.init.normal_(m.weight, 0, 0.01) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | def forward(self, x): 110 | x = self.pool(self.conv1_2(self.conv1(x))) 111 | x = self.pool(self.conv2_2(self.conv2(x))) 112 | x = self.pool(self.conv3_2(self.conv3(x))) 113 | 114 | # fully convolutional head 115 | x = self.pool(self.conv8(x)) 116 | x = self.dropout(x) 117 | x = self.global_pool(self.pool(x)) 118 | x = x.view(-1, 10) 119 | return x 120 | 121 | 122 | net = ResNet(res_block) 123 | # print(net) 124 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 125 | print(device) 126 | net.to(device) 127 | net = nn.DataParallel(net) 128 | 129 | 130 | criterion = nn.CrossEntropyLoss() 131 | # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) 132 | optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.99)) 133 | 134 | print('Start training') 135 | writer = SummaryWriter('runs') 136 | for epoch in range(10): # loop over the dataset multiple times 137 | 138 | running_loss = 0.0 139 | for i, data in enumerate(trainloader, 0): 140 | # get the inputs; data is a list of [inputs, labels] 141 | inputs, labels = data[0].to(device), data[1].to(device) 142 | # inputs, labels = data[0].cuda(), data[1].cuda() 143 | 144 | # zero the parameter gradients 145 | optimizer.zero_grad() 146 | 147 | # forward + backward + optimize 148 | outputs = net(inputs) 149 
| # CrossEntropyLoss handles integer labels internally; no manual one-hot conversion is needed 150 | # one-liner to convert labels to one-hot: 151 | # label_onehot = torch.eye(10).index_select(0, labels) 152 | # experiments show cross entropy raises an error on one-hot input, while MSE requires one-hot input 153 | loss = criterion(outputs, labels) 154 | 155 | loss.backward() 156 | optimizer.step() 157 | 158 | # print statistics 159 | running_loss += loss.item() 160 | if i*batch_size % 6400 == 6384: # print every 6400 samples (400 mini-batches) 161 | print('[%d, %5d] loss: %.3f' % 162 | (epoch + 1, i + 1, running_loss / 400)) 163 | # print(inputs.shape) 164 | running_loss = 0.0 165 | writer.add_scalar('loss', loss, epoch) 166 | writer.add_scalar('running_loss', running_loss, epoch) 167 | writer.add_graph(net, (inputs,)) 168 | writer.close() 169 | print('Finished Training') 170 | 171 | correct = 0 172 | total = 0 # total = 10000 173 | with torch.no_grad(): 174 | for data in testloader: 175 | images, labels = data[0].to(device), data[1].to(device) 176 | outputs = net(images) 177 | _, predicted = torch.max(outputs.data, 1) 178 | total += labels.size(0) 179 | correct += (predicted == labels).sum().item() 180 | 181 | print('Accuracy of the network on the 10000 test images: %d %%' % ( 182 | 100 * correct / total)) 183 | -------------------------------------------------------------------------------- /Mmodule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | ################################################ 10 | # 11 | # 12 | # Module for adapting ordinary image datasets (e.g. CIFAR-10) into point clouds for PointNet 13 | # 14 | # 15 | ################################################ 16 | 17 | 18 | # Turns a 4-D B*C*W*H image batch into a B*4*(C*W*H) point cloud (3 coordinate dims + 1 dim of raw pixel values) 19 | def pic2point(pic_data): 20 | b, c, w, h = pic_data.size() 21 | x, y, z = np.where(np.ones([c, w, h]) > 0) 22 | coordinates = np.vstack((x, y, z)) 23 | coordinates = coordinates / (coordinates.max(axis=1).reshape(-1, 1)+0.0000001) 24 | coordinates = coordinates.reshape(1, 3, -1).repeat(b, 0) 25 | 26 | point_data = pic_data.view(b, 1, c*w*h) 27 | point_data = np.concatenate((point_data, coordinates), 1) 28 | 29 | return Variable(torch.from_numpy(point_data.astype(np.float32))) 30 | 31 | 32 | # TODO: the idea is to adapt the sizes below to the input dataset 33 | class STN3d(nn.Module): 34 | def __init__(self, channel): 35 | super(STN3d, self).__init__() 36 | self.channel = channel 37 | self.conv1 = torch.nn.Conv1d(channel, 64, 1) 38 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 39 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 40 | self.fc1 = nn.Linear(1024, 512) 41 | self.fc2 = nn.Linear(512, 256) 42 | self.fc3 = nn.Linear(256, channel*channel) 43 | self.relu = nn.ReLU() 44 | 45 | self.bn1 = nn.BatchNorm1d(64) 46 | self.bn2 = nn.BatchNorm1d(128) 47 | self.bn3 = nn.BatchNorm1d(1024) 48 | self.bn4 = nn.BatchNorm1d(512) 49 | self.bn5 = nn.BatchNorm1d(256) 50 | 51 | def forward(self, x): 52 | batchsize = x.size()[0] 53 | x = self.relu(self.bn1(self.conv1(x))) 54 | x = self.relu(self.bn2(self.conv2(x))) 55 | x = self.relu(self.bn3(self.conv3(x))) 56 | x = torch.max(x, 2, keepdim=True)[0] 57 | x = x.view(-1, 1024) 58 | 59 | x = self.relu(self.bn4(self.fc1(x))) 60 | x = self.relu(self.bn5(self.fc2(x))) 61 | x = self.fc3(x) 62 | 63 | iden = Variable(torch.from_numpy(np.eye(self.channel).flatten().astype(np.float32))).view(1, self.channel*self.channel).repeat(batchsize, 1) 64 | if x.is_cuda: 65 | iden = iden.cuda() 66 | x = x + iden 67 | x = x.view(-1, 
self.channel, self.channel) 68 | return x 69 | 70 | 71 | class STNkd(nn.Module): 72 | def __init__(self, k=64): 73 | super(STNkd, self).__init__() 74 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 75 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 76 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 77 | self.fc1 = nn.Linear(1024, 512) 78 | self.fc2 = nn.Linear(512, 256) 79 | self.fc3 = nn.Linear(256, k * k) 80 | self.relu = nn.ReLU() 81 | 82 | self.bn1 = nn.BatchNorm1d(64) 83 | self.bn2 = nn.BatchNorm1d(128) 84 | self.bn3 = nn.BatchNorm1d(1024) 85 | self.bn4 = nn.BatchNorm1d(512) 86 | self.bn5 = nn.BatchNorm1d(256) 87 | 88 | self.k = k 89 | 90 | def forward(self, x): 91 | batchsize = x.size()[0] 92 | x = self.relu(self.bn1(self.conv1(x))) 93 | x = self.relu(self.bn2(self.conv2(x))) 94 | x = self.relu(self.bn3(self.conv3(x))) 95 | x = torch.max(x, 2, keepdim=True)[0] 96 | x = x.view(-1, 1024) 97 | 98 | x = self.relu(self.bn4(self.fc1(x))) 99 | x = self.relu(self.bn5(self.fc2(x))) 100 | x = self.fc3(x) 101 | 102 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat( 103 | batchsize, 1) 104 | if x.is_cuda: 105 | iden = iden.cuda() 106 | x = x + iden 107 | x = x.view(-1, self.k, self.k) 108 | return x 109 | 110 | 111 | class res_block(nn.Module): 112 | def __init__(self, in_channels, out_channels): 113 | super(res_block, self).__init__() 114 | self.conv1 = nn.Conv1d(in_channels, out_channels, 1, padding=0) 115 | self.bn = nn.BatchNorm1d(out_channels) 116 | 117 | self.relu = nn.ReLU() 118 | self.conv2 = nn.Conv1d(out_channels, out_channels, 1, padding=0) 119 | self.self_conv = nn.Conv1d(in_channels, out_channels, 1, padding=0) 120 | 121 | def _initialize_weights(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 125 | if m.bias is not None: 126 | nn.init.constant_(m.bias, 0) 127 | elif isinstance(m, nn.BatchNorm2d): 128 | nn.init.constant_(m.weight, 1) 129 | nn.init.constant_(m.bias, 0) 130 | elif isinstance(m, nn.Linear): 131 | nn.init.normal_(m.weight, 0, 0.01) 132 | nn.init.constant_(m.bias, 0) 133 | 134 | def forward(self, x): 135 | out = self.conv1(x) 136 | out = self.relu(self.bn(out)) 137 | out = self.relu(self.bn(self.conv2(out))) 138 | x = self.self_conv(x) 139 | return out + x 140 | 141 | 142 | class PointNet(nn.Module): 143 | def __init__(self, block, k=2): 144 | super(PointNet, self).__init__() 145 | self.k = k 146 | self.stn3d = STN3d(4) 147 | self.conv1 = torch.nn.Conv1d(4, 64, 1) 148 | self.conv2 = torch.nn.Conv1d(64, 64, 1) 149 | # self.conv1 = block(4, 64) 150 | # self.conv2 = block(64, 64) 151 | self.stnkd = STNkd(64) 152 | self.conv3 = torch.nn.Conv1d(64, 64, 1) 153 | self.conv4 = torch.nn.Conv1d(64, 128, 1) 154 | self.conv5 = torch.nn.Conv1d(128, 256, 1) 155 | self.conv6 = torch.nn.Conv1d(256, 512, 1) 156 | self.conv7 = torch.nn.Conv1d(512, 1024, 1) 157 | # self.conv3 = block(64, 64) 158 | # self.conv4 = block(64, 128) 159 | # self.conv5 = block(128, 256) 160 | # self.conv6 = block(256, 512) 161 | # self.conv7 = block(512, 1024) 162 | 163 | self.fc1 = nn.Linear(1024, 512) 164 | self.fc2 = nn.Linear(512, 256) 165 | self.fc3 = nn.Linear(256, k) 166 | 167 | self.bn1 = nn.BatchNorm1d(64) 168 | self.bn2 = nn.BatchNorm1d(64) 169 | self.bn3 = nn.BatchNorm1d(64) 170 | self.bn4 = nn.BatchNorm1d(128) 171 | self.bn5 = nn.BatchNorm1d(256) 172 | self.bn6 = nn.BatchNorm1d(512) 173 | self.bn7 = nn.BatchNorm1d(1024) 174 | 175 | self.bn8 = 
nn.BatchNorm1d(512) 176 | self.bn9 = nn.BatchNorm1d(256) 177 | 178 | self.dropout = nn.Dropout2d() 179 | 180 | def forward(self, x): 181 | # channel = x.size(dim=1) 182 | trans = self.stn3d(x) 183 | x = x.transpose(2, 1) # transpose to B*N*C so the matrix product x*trans is well-defined 184 | x = torch.bmm(x, trans) # batched matrix multiplication; the batch dim is untouched and the result is B*N*C 185 | x = x.transpose(2, 1) 186 | x = F.relu(self.bn1(self.conv1(x))) 187 | x = F.relu(self.bn2(self.conv2(x))) 188 | # x = F.relu(self.conv1(x)) 189 | # x = F.relu(self.conv2(x)) 190 | trans_feat = self.stnkd(x) 191 | x = x.transpose(2, 1) # transpose to B*N*64 so the matrix product x*trans_feat is well-defined 192 | x = torch.bmm(x, trans_feat) # batched matrix multiplication; the batch dim is untouched and the result is B*N*64 193 | x = x.transpose(2, 1) 194 | x = F.relu(self.bn3(self.conv3(x))) 195 | x = F.relu(self.bn4(self.conv4(x))) 196 | x = F.relu(self.bn5(self.conv5(x))) 197 | x = F.relu(self.bn6(self.conv6(x))) 198 | x = F.relu(self.bn7(self.conv7(x))) 199 | # x = F.relu(self.conv3(x)) 200 | # x = F.relu(self.conv4(x)) 201 | # x = F.relu(self.conv5(x)) 202 | # x = F.relu(self.conv6(x)) 203 | # x = F.relu(self.conv7(x)) 204 | x = torch.max(x, 2, keepdim=True)[0].view(-1, 1024) 205 | 206 | x = F.relu(self.bn8(self.fc1(x))) 207 | x = F.relu(self.bn9(self.fc2(x))) 208 | x = self.dropout(x) 209 | x = F.log_softmax(self.fc3(x), dim=1) 210 | return x, trans, trans_feat 211 | 212 | 213 | def feature_transform_regularizer(trans): 214 | d = trans.size()[1] 215 | batchsize = trans.size()[0] 216 | I = torch.eye(d)[None, :, :] 217 | if trans.is_cuda: 218 | I = I.cuda() 219 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1,2))) 220 | return loss 221 | -------------------------------------------------------------------------------- /PointCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | 8 | def square_distance(src, dst): 9 | """ 10 | Calculate the squared Euclidean distance between each pair of points. 
11 | 12 | src^T * dst = xn * xm + yn * ym + zn * zm; 13 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 14 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 15 | 16 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 17 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 18 | 19 | Input: 20 | src: source points, [B, N, C] 21 | dst: target points, [B, M, C] 22 | Output: 23 | dist: per-point square distance, [B, N, M] 24 | """ 25 | B, N, _ = src.shape 26 | _, M, _ = dst.shape 27 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 28 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 29 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 30 | return dist 31 | 32 | 33 | def index_points(points, idx): 34 | """ 35 | Input: 36 | points: input points data, [B, N, C] 37 | idx: sample index data, [B, S] 38 | Return: 39 | new_points: indexed points data, [B, S, C] 40 | """ 41 | device = points.device 42 | B = points.shape[0] 43 | view_shape = list(idx.shape) 44 | # view_shape = [B, 1] 45 | view_shape[1:] = [1] * (len(view_shape) - 1) 46 | repeat_shape = list(idx.shape) 47 | # repeat_shape = [1, S] 48 | repeat_shape[0] = 1 49 | # the goal is to expand the 1-D range [0, B-1] into a matrix of idx's shape (each batch index repeated S times) so that fancy indexing pairs up element-wise 50 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 51 | new_points = points[batch_indices, idx, :] 52 | return new_points 53 | 54 | 55 | def farthest_point_sample(xyz, npoint): 56 | """ 57 | Input: 58 | xyz: pointcloud data, [B, N, C] 59 | npoint: number of samples 60 | Return: 61 | centroids: sampled pointcloud index, [B, npoint] 62 | """ 63 | device = xyz.device 64 | B, N, C = xyz.shape 65 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 66 | distance = torch.ones(B, N).to(device) * 1e10 67 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 68 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 69 | # each iteration picks one farthest point per batch element: farthest [B, ] 70 | # distance stores, after each step, each of the B*N points' distance to the currently selected point set 71 | # the mask is needed because a point's distance to a set is the minimum over its distances to the set's members 72 | for i in range(npoint): 73 | centroids[:, i] = farthest 74 | # batch_indices [B ], farthest [B ], output [B, 1, 3] 75 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 76 | dist = torch.sum((xyz - centroid) ** 2, -1) 77 | mask = dist < distance 78 | distance[mask] = dist[mask] 79 | farthest = torch.max(distance, -1)[1] 80 | return centroids 81 | 82 | 83 | def find_k_neighbor(pts, xyz, K, D): 84 | """ 85 | Find K nearest neighbor points 86 | :param pts: representative points (B, P, C) 87 | :param xyz: original points (B, N, C) 88 | :param K: 89 | :param D: Dilation rate 90 | :return group_pts: K neighbor points (B, P, K, C) 91 | """ 92 | device = pts.device 93 | B, P, _ = pts.size() 94 | sqrdists = square_distance(pts, xyz) # (B,P,N) 95 | # k_ind = torch.topk(sqrdists, k=K, dim=-1, largest=False)[1] 96 | k_ind = sqrdists.sort(dim=-1)[1][:, :, :K*D] # (B,P,K*D) 97 | rand_columns = torch.randperm(K*D, dtype=torch.long)[:K].to(device) 98 | k_ind = k_ind[:, :, rand_columns] 99 | group_pts = index_points(xyz, k_ind) # (B,P,K,C) 100 | 101 | return group_pts, k_ind 102 | 103 | 104 | class xconv(nn.Module): 105 | def __init__(self, in_channel, lift_channel, out_channel, P, K, D=1, sampling='fps'): 106 | """ 107 | :param in_channel: Input channel of the points' features 108 | :param lift_channel: Lifted channel C_delta 109 | :param out_channel: 110 | :param P: number of representative points 111 | :param K: K neighbors to operate 112 | :param D: Dilation rate 113 | """ 114 | super(xconv, self).__init__() 115 | self.P = P 116 | 
self.K = K 117 | self.D = D 118 | self.sampling = sampling 119 | # Input should be (B, 3, P, K) 120 | self.MLP_delta = nn.Sequential( 121 | nn.Conv2d(3, lift_channel, kernel_size=1), 122 | nn.ELU(inplace=True), 123 | nn.BatchNorm2d(lift_channel), 124 | nn.Conv2d(lift_channel, lift_channel, kernel_size=1), 125 | nn.ELU(inplace=True), 126 | nn.BatchNorm2d(lift_channel) 127 | ) 128 | # Input should be (B, 3, P, K) 129 | self.MLP_X = nn.Sequential( 130 | nn.Conv2d(3, K, kernel_size=1), 131 | nn.ELU(inplace=True), 132 | nn.BatchNorm2d(K), 133 | nn.Conv2d(K, K, kernel_size=1), 134 | nn.ELU(inplace=True), 135 | nn.BatchNorm2d(K), 136 | # nn.Conv2d(K, K, kernel_size=1), 137 | # nn.BatchNorm2d(K) 138 | ) 139 | nn.init.xavier_uniform_(self.MLP_X[0].weight) 140 | nn.init.xavier_uniform_(self.MLP_X[3].weight) 141 | # nn.init.xavier_uniform_(self.MLP_X[6].weight) 142 | 143 | self.MLP_feat0 = nn.Sequential( 144 | nn.Conv2d(K, K, kernel_size=1), 145 | nn.ELU(inplace=True), 146 | nn.BatchNorm2d(K), 147 | nn.Conv2d(K, 1, kernel_size=1), 148 | nn.ELU(inplace=True), 149 | nn.BatchNorm2d(1) 150 | ) 151 | self.MLP_feat1 = nn.Sequential( 152 | nn.Conv1d(lift_channel+in_channel, out_channel, kernel_size=1), 153 | nn.ELU(inplace=True), 154 | nn.BatchNorm1d(out_channel) 155 | ) 156 | 157 | def forward(self, pts, fts): 158 | """ 159 | :param pts: regional point cloud (B, N, 3) 160 | :param fts: regional features (B, N, C), 161 | or None for the first layer 162 | :return: features aggregated at the representative points 163 | """ 164 | B, N, _ = pts.size() 165 | if self.P == -1: 166 | self.P = N 167 | represent_pts = pts 168 | pre_ind = torch.arange(0, N, step=1).unsqueeze(0).repeat((B, 1)).to(pts.device) 169 | else: 170 | if self.sampling == 'fps': 171 | pre_ind = farthest_point_sample(pts, self.P) # (B, P) 172 | represent_pts = index_points(pts, pre_ind) # .view(B, self.P, 1, 3) 173 | else: 174 | # idx = np.random.choice(pts.size()[1], self.P, replace=False).tolist() # .to(pts.device) 175 | # represent_pts = pts[:, idx, :] 176 | pre_ind = torch.randint(low=0, high=N, size=(B, self.P), dtype=torch.long).to(pts.device) 177 | represent_pts = index_points(pts, pre_ind) 178 | 179 | group_pts, k_ind = find_k_neighbor(represent_pts, pts, self.K, self.D) # (B, P, K, 3), (B, P, K) 180 | center_pts = torch.unsqueeze(represent_pts, dim=2) # (B, P, 1, 3) 181 | group_pts = group_pts - center_pts # (B, P, K, 3) 182 | # MLP lifts the relative coordinates into fts_lifted 183 | group_pts = group_pts.permute(0,3,1,2) 184 | fts_lifted = self.MLP_delta(group_pts) # (B, C_delta, P, K) 185 | if fts is not None: 186 | # TODO: ind can go out of bounds 187 | # center_feat = index_points(fts, pre_ind) # (B, P, C_in) 188 | # group_fts = index_points(center_feat, k_ind) # (B, P, K, C_in) 189 | group_fts = index_points(fts, k_ind) 190 | 191 | group_fts = group_fts.permute(0,3,1,2) 192 | feat = torch.cat((fts_lifted, group_fts), 1) # (B, C_delta + C_in, P, K) 193 | else: 194 | feat = fts_lifted 195 | # the X-transformation matrix 196 | X = self.MLP_X(group_pts).permute(0,2,3,1) # (B, P, K, K) 197 | 198 | X = X.contiguous().view(B*self.P, self.K, self.K) 199 | feat = feat.permute(0,2,3,1).contiguous().view(B*self.P, self.K, -1) 200 | feat = torch.bmm(X, feat).view(B, self.P, self.K, -1).permute(0,2,1,3) # (B, K, P, C_delta + C_in) 201 | feat = self.MLP_feat0(feat).squeeze(1) # (B, self.P, C_delta + C_in) 202 | feat = feat.permute(0,2,1) # (B, C_delta + C_in, self.P) 203 | feat = self.MLP_feat1(feat).permute(0,2,1) # (B, self.P, C_out) 204 | 205 | return represent_pts, feat # (B, P, 3), (B, P, C_out) 206 | 207 | 208 | class 
PointCNN_cls(nn.Module): 209 | def __init__(self, num_class): 210 | super().__init__() 211 | # X_conv1 212 | C_out = 16*3 213 | C_delta = C_out // 2 214 | self.x_conv1 = xconv(0, C_delta, C_out, P=-1, K=8)# , sampling='rand') 215 | 216 | # X_conv2 217 | C_in = C_out 218 | C_out = 32*3 219 | C_delta = C_in // 4 220 | self.x_conv2 = xconv(C_in, C_delta, C_out, P=384, K=12, D=2)# , sampling='rand') 221 | 222 | # X_conv3 223 | C_in = C_out 224 | C_out = 64*3 225 | C_delta = C_in // 4 226 | self.x_conv3 = xconv(C_in, C_delta, C_out, P=128, K=16, D=2)# , sampling='rand') 227 | 228 | # X_conv4 229 | C_in = C_out 230 | C_out = 128*3 231 | C_delta = C_in // 4 232 | self.x_conv4 = xconv(C_in, C_delta, C_out, P=-1, K=16, D=3)# , sampling='rand') 233 | 234 | self.fc = nn.Sequential( 235 | nn.Conv1d(C_out, 64 * 3, 1), 236 | nn.ELU(inplace=True), 237 | nn.BatchNorm1d(64 * 3), 238 | nn.Conv1d(64 * 3, num_class, 1), 239 | nn.ELU(inplace=True), 240 | nn.BatchNorm1d(num_class), 241 | nn.Dropout(p=0.8) 242 | ) 243 | 244 | def forward(self, x): # (B, N, 3) 245 | pts, fts = self.x_conv1(x, None) 246 | pts, fts = self.x_conv2(pts, fts) 247 | pts, fts = self.x_conv3(pts, fts) 248 | pts, fts = self.x_conv4(pts, fts) # (B, 128, 3), (B, 128, 384) 249 | 250 | fts = fts.permute(0, 2, 1) 251 | fts = self.fc(fts) # (B, num_class, 128) 252 | logits = torch.mean(fts, dim=-1) 253 | 254 | return logits 255 | -------------------------------------------------------------------------------- /PointNet2_cls_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn.parallel 7 | import torch.optim as optim 8 | import torch.utils.data 9 | from dataset import ShapeNetDataset, ModelNetDataset 10 | from module import PointNet, feature_transform_regularizer 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | '--batchSize', type=int, default=32, help='input batch size') 18 | parser.add_argument( 19 | '--num_points', type=int, default=2500, help='input batch size') 20 | parser.add_argument( 21 | '--workers', type=int, help='number of data loading workers', default=4) 22 | parser.add_argument( 23 | '--nepoch', type=int, default=5, help='number of epochs to train for') 24 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 25 | parser.add_argument('--model', type=str, default='', help='model path') 26 | # parser.add_argument('--dataset', type=str, required=True, help="dataset path") 27 | # path = '/home/kun/Documents/pointnet.pytorch/shapenetcore_partanno_segmentation_benchmark_v0' 28 | path = './shapenetcore_partanno_segmentation_benchmark_v0' 29 | parser.add_argument('--dataset', type=str, default=path, help="dataset path") 30 | parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40") 31 | parser.add_argument('--feature_transform', action='store_true', help="use feature transform") 32 | 33 | opt = parser.parse_args() 34 | print(opt) 35 | 36 | blue = lambda x: '\033[94m' + x + '\033[0m' 37 | 38 | opt.manualSeed = random.randint(1, 10000) # fix seed 39 | print("Random Seed: ", opt.manualSeed) 40 | random.seed(opt.manualSeed) 41 | torch.manual_seed(opt.manualSeed) 42 | 43 | if opt.dataset_type == 'shapenet': 44 | dataset = ShapeNetDataset( 45 | root=opt.dataset, 46 | classification=True, 47 | 
npoints=opt.num_points) 48 | 49 | test_dataset = ShapeNetDataset( 50 | root=opt.dataset, 51 | classification=True, 52 | split='test', 53 | npoints=opt.num_points, 54 | data_augmentation=False) 55 | elif opt.dataset_type == 'modelnet40': 56 | dataset = ModelNetDataset( 57 | root=opt.dataset, 58 | npoints=opt.num_points, 59 | split='trainval') 60 | 61 | test_dataset = ModelNetDataset( 62 | root=opt.dataset, 63 | split='test', 64 | npoints=opt.num_points, 65 | data_augmentation=False) 66 | else: 67 | exit('wrong dataset type') 68 | 69 | 70 | dataloader = torch.utils.data.DataLoader( 71 | dataset, 72 | batch_size=opt.batchSize, 73 | shuffle=True, 74 | num_workers=int(opt.workers)) 75 | 76 | testdataloader = torch.utils.data.DataLoader( 77 | test_dataset, 78 | batch_size=opt.batchSize, 79 | shuffle=True, 80 | num_workers=int(opt.workers)) 81 | 82 | print(len(dataset), len(test_dataset)) 83 | num_classes = len(dataset.classes) 84 | print('classes', num_classes) 85 | 86 | try: 87 | os.makedirs(opt.outf) 88 | except OSError: 89 | pass 90 | 91 | classifier = PointNet(k=num_classes) 92 | 93 | if opt.model != '': 94 | classifier.load_state_dict(torch.load(opt.model)) 95 | 96 | 97 | optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999)) 98 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 99 | classifier.cuda() 100 | 101 | num_batch = len(dataset) / opt.batchSize 102 | 103 | for epoch in range(opt.nepoch): 104 | scheduler.step() 105 | for i, data in enumerate(dataloader, 0): 106 | points, target = data 107 | target = target[:, 0] 108 | points = points.transpose(2, 1) 109 | points, target = points.cuda(), target.cuda() 110 | optimizer.zero_grad() 111 | classifier = classifier.train() 112 | pred, trans, trans_feat = classifier(points) 113 | loss = F.nll_loss(pred, target) 114 | if opt.feature_transform: 115 | loss += feature_transform_regularizer(trans_feat) * 0.001 116 | loss.backward() 117 | optimizer.step() 118 | pred_choice = pred.data.max(1)[1] 119 | correct = pred_choice.eq(target.data).cpu().sum() 120 | print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize))) 121 | 122 | if i % 10 == 0: 123 | j, data = next(enumerate(testdataloader, 0)) 124 | points, target = data 125 | target = target[:, 0] 126 | points = points.transpose(2, 1) 127 | points, target = points.cuda(), target.cuda() 128 | classifier = classifier.eval() 129 | pred, _, _ = classifier(points) 130 | loss = F.nll_loss(pred, target) 131 | pred_choice = pred.data.max(1)[1] 132 | correct = pred_choice.eq(target.data).cpu().sum() 133 | print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize))) 134 | 135 | torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch)) 136 | 137 | total_correct = 0 138 | total_testset = 0 139 | with torch.no_grad(): 140 | for i, data in tqdm(enumerate(testdataloader, 0)): 141 | points, target = data 142 | target = target[:, 0] 143 | points = points.transpose(2, 1) 144 | points, target = points.cuda(), target.cuda() 145 | classifier = classifier.eval() 146 | pred, _, _ = classifier(points) 147 | pred_choice = pred.data.max(1)[1] 148 | correct = pred_choice.eq(target.data).cpu().sum() 149 | total_correct += correct.item() 150 | total_testset += points.size()[0] 151 | 152 | print("final accuracy {}".format(total_correct / float(total_testset))) 153 | 
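The loop above saves one checkpoint per epoch as `cls/cls_model_<epoch>.pth` (with the default `--outf cls`). A minimal inference sketch against such a checkpoint; the epoch index, the 16-class count of the ShapeNet classification split, and the random stand-in cloud are illustrative assumptions, not part of the script:

import torch
from module import PointNet

model = PointNet(k=16)  # k must match the class count used at training time (assumed 16 here)
model.load_state_dict(torch.load('cls/cls_model_4.pth', map_location='cpu'))
model.eval()

points = torch.randn(1, 2500, 3)  # stand-in for a real (B, N, 3) point cloud
points = points.transpose(2, 1)   # the network expects (B, 3, N)
with torch.no_grad():
    pred, _, _ = model(points)    # per-class log-probabilities
print(pred.argmax(dim=1))         # predicted class index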
-------------------------------------------------------------------------------- /PointNet2_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from pointnet_util import PointNetSetAbstractionMsg,PointNetSetAbstraction,PointNetFeaturePropagation 6 | # npoint, radius_list, nsample_list, in_channel, mlp_list 7 | 8 | 9 | class PointNet2ClsMsg(nn.Module): 10 | def __init__(self): 11 | super(PointNet2ClsMsg, self).__init__() 12 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], 0, 13 | [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 14 | self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320, 15 | [[64, 64, 128], [128, 128, 256], [128, 128, 256]]) 16 | self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True) 17 | self.fc1 = nn.Linear(1024, 512) 18 | self.bn1 = nn.BatchNorm1d(512) 19 | self.drop1 = nn.Dropout(0.4) 20 | self.fc2 = nn.Linear(512, 256) 21 | self.bn2 = nn.BatchNorm1d(256) 22 | self.drop2 = nn.Dropout(0.4) 23 | self.fc3 = nn.Linear(256, 40) 24 | 25 | def forward(self, xyz): 26 | B, _, _ = xyz.shape 27 | l1_xyz, l1_points = self.sa1(xyz, None) 28 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 29 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 30 | x = l3_points.view(B, 1024) 31 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 32 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 33 | x = self.fc3(x) 34 | x = F.log_softmax(x, -1) 35 | return x, l3_points 36 | 37 | 38 | class PointNet2ClsSsg(nn.Module): 39 | def __init__(self): 40 | super(PointNet2ClsSsg, self).__init__() 41 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=3, mlp=[64, 64, 128], group_all=False) 42 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) 43 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) 44 | self.fc1 = nn.Linear(1024, 512) 45 | self.bn1 = nn.BatchNorm1d(512) 46 | self.drop1 = nn.Dropout(0.4) 47 | self.fc2 = nn.Linear(512, 256) 48 | self.bn2 = nn.BatchNorm1d(256) 49 | self.drop2 = nn.Dropout(0.4) 50 | self.fc3 = nn.Linear(256, 40) 51 | 52 | def forward(self, xyz): 53 | B, _, _ = xyz.shape 54 | l1_xyz, l1_points = self.sa1(xyz, None) 55 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 56 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 57 | x = l3_points.view(B, 1024) 58 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 59 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 60 | x = self.fc3(x) 61 | x = F.log_softmax(x, -1) 62 | return x 63 | 64 | 65 | class PointNet2SemSeg(nn.Module): 66 | def __init__(self, num_classes): 67 | super(PointNet2SemSeg, self).__init__() 68 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 6 + 3, [32, 32, 64], False) 69 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False) 70 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False) 71 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False) 72 | self.fp4 = PointNetFeaturePropagation(768, [256, 256]) 73 | self.fp3 = PointNetFeaturePropagation(384, [256, 256]) 74 | self.fp2 = PointNetFeaturePropagation(320, [256, 128]) 75 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128]) 76 | self.conv1 = nn.Conv1d(128, 128, 1) 77 | 
self.bn1 = nn.BatchNorm1d(128) 78 | self.drop1 = nn.Dropout(0.5) 79 | self.conv2 = nn.Conv1d(128, num_classes, 1) 80 | 81 | def forward(self, xyz,points): 82 | l1_xyz, l1_points = self.sa1(xyz, points) 83 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 84 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 85 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 86 | 87 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 88 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 89 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 90 | l0_points = self.fp1(xyz, l1_xyz, None, l1_points) 91 | 92 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points)))) 93 | x = self.conv2(x) 94 | x = F.log_softmax(x, dim=1) 95 | x = x.permute(0, 2, 1) 96 | return x 97 | -------------------------------------------------------------------------------- /PointNet_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import os 4 | import random 5 | import torch 6 | import torch.nn.parallel 7 | import torch.optim as optim 8 | import torch.utils.data 9 | from dataset import ShapeNetDataset, ModelNetDataset 10 | from module import PointNet, feature_transform_regularizer 11 | import torch.nn.functional as F 12 | from tqdm import tqdm 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | '--batchSize', type=int, default=32, help='input batch size') 18 | parser.add_argument( 19 | '--num_points', type=int, default=2500, help='input batch size') 20 | parser.add_argument( 21 | '--workers', type=int, help='number of data loading workers', default=4) 22 | parser.add_argument( 23 | '--nepoch', type=int, default=5, help='number of epochs to train for') 24 | parser.add_argument('--outf', type=str, default='cls', help='output folder') 25 | parser.add_argument('--model', type=str, default='', help='model path') 26 | # parser.add_argument('--dataset', type=str, required=True, help="dataset path") 27 | # path = '/home/kun/Documents/pointnet.pytorch/shapenetcore_partanno_segmentation_benchmark_v0' 28 | path = './shapenetcore_partanno_segmentation_benchmark_v0' 29 | parser.add_argument('--dataset', type=str, default=path, help="dataset path") 30 | parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40") 31 | parser.add_argument('--feature_transform', action='store_true', help="use feature transform") 32 | 33 | opt = parser.parse_args() 34 | print(opt) 35 | 36 | blue = lambda x: '\033[94m' + x + '\033[0m' 37 | 38 | opt.manualSeed = random.randint(1, 10000) # fix seed 39 | print("Random Seed: ", opt.manualSeed) 40 | random.seed(opt.manualSeed) 41 | torch.manual_seed(opt.manualSeed) 42 | 43 | if opt.dataset_type == 'shapenet': 44 | dataset = ShapeNetDataset( 45 | root=opt.dataset, 46 | classification=True, 47 | npoints=opt.num_points) 48 | 49 | test_dataset = ShapeNetDataset( 50 | root=opt.dataset, 51 | classification=True, 52 | split='test', 53 | npoints=opt.num_points, 54 | data_augmentation=False) 55 | elif opt.dataset_type == 'modelnet40': 56 | dataset = ModelNetDataset( 57 | root=opt.dataset, 58 | npoints=opt.num_points, 59 | split='trainval') 60 | 61 | test_dataset = ModelNetDataset( 62 | root=opt.dataset, 63 | split='test', 64 | npoints=opt.num_points, 65 | data_augmentation=False) 66 | else: 67 | exit('wrong dataset type') 68 | 69 | 70 | dataloader = torch.utils.data.DataLoader( 71 | dataset, 72 | batch_size=opt.batchSize, 73 | 
shuffle=True, 74 | num_workers=int(opt.workers)) 75 | 76 | testdataloader = torch.utils.data.DataLoader( 77 | test_dataset, 78 | batch_size=opt.batchSize, 79 | shuffle=True, 80 | num_workers=int(opt.workers)) 81 | 82 | print(len(dataset), len(test_dataset)) 83 | num_classes = len(dataset.classes) 84 | print('classes', num_classes) 85 | 86 | try: 87 | os.makedirs(opt.outf) 88 | except OSError: 89 | pass 90 | 91 | classifier = PointNet(k=num_classes) 92 | 93 | if opt.model != '': 94 | classifier.load_state_dict(torch.load(opt.model)) 95 | 96 | 97 | optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999)) 98 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) 99 | classifier.cuda() 100 | 101 | num_batch = len(dataset) / opt.batchSize 102 | 103 | for epoch in range(opt.nepoch): 104 | scheduler.step() 105 | for i, data in enumerate(dataloader, 0): 106 | points, target = data 107 | target = target[:, 0] 108 | points = points.transpose(2, 1) 109 | points, target = points.cuda(), target.cuda() 110 | optimizer.zero_grad() 111 | classifier = classifier.train() 112 | pred, trans, trans_feat = classifier(points) 113 | loss = F.nll_loss(pred, target) 114 | if opt.feature_transform: 115 | loss += feature_transform_regularizer(trans_feat) * 0.001 116 | loss.backward() 117 | optimizer.step() 118 | pred_choice = pred.data.max(1)[1] 119 | correct = pred_choice.eq(target.data).cpu().sum() 120 | print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(opt.batchSize))) 121 | 122 | if i % 10 == 0: 123 | j, data = next(enumerate(testdataloader, 0)) 124 | points, target = data 125 | target = target[:, 0] 126 | points = points.transpose(2, 1) 127 | points, target = points.cuda(), target.cuda() 128 | classifier = classifier.eval() 129 | pred, _, _ = classifier(points) 130 | loss = F.nll_loss(pred, target) 131 | pred_choice = pred.data.max(1)[1] 132 | correct = pred_choice.eq(target.data).cpu().sum() 133 | print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize))) 134 | 135 | torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch)) 136 | 137 | total_correct = 0 138 | total_testset = 0 139 | with torch.no_grad(): 140 | for i, data in tqdm(enumerate(testdataloader, 0)): 141 | points, target = data 142 | target = target[:, 0] 143 | points = points.transpose(2, 1) 144 | points, target = points.cuda(), target.cuda() 145 | classifier = classifier.eval() 146 | pred, _, _ = classifier(points) 147 | pred_choice = pred.data.max(1)[1] 148 | correct = pred_choice.eq(target.data).cpu().sum() 149 | total_correct += correct.item() 150 | total_testset += points.size()[0] 151 | 152 | print("final accuracy {}".format(total_correct / float(total_testset))) 153 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointNet 2 | Vanilla implementations of PointNet and PointCNN 3 | ## Mmodule 4 | Converts common image datasets into point-cloud format for comparing the models 5 | ## PointCNN 6 | PyTorch implementation of PointCNN 7 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | import os 4 | import os.path 5 | import torch 6 | import numpy as np 7 | import sys 8 | from tqdm import 
tqdm 9 | import json 10 | from plyfile import PlyData, PlyElement 11 | 12 | 13 | def get_segmentation_classes(root): 14 | catfile = os.path.join(root, 'synsetoffset2category.txt') 15 | cat = {} 16 | meta = {} 17 | 18 | with open(catfile, 'r') as f: 19 | for line in f: 20 | ls = line.strip().split() 21 | cat[ls[0]] = ls[1] 22 | 23 | for item in cat: 24 | dir_seg = os.path.join(root, cat[item], 'points_label') 25 | dir_point = os.path.join(root, cat[item], 'points') 26 | fns = sorted(os.listdir(dir_point)) 27 | meta[item] = [] 28 | for fn in fns: 29 | token = (os.path.splitext(os.path.basename(fn))[0]) 30 | meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg'))) 31 | 32 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f: 33 | for item in cat: 34 | datapath = [] 35 | num_seg_classes = 0 36 | for fn in meta[item]: 37 | datapath.append((item, fn[0], fn[1])) 38 | 39 | for i in tqdm(range(len(datapath))): 40 | l = len(np.unique(np.loadtxt(datapath[i][-1]).astype(np.uint8))) 41 | if l > num_seg_classes: 42 | num_seg_classes = l 43 | 44 | print("category {} num segmentation classes {}".format(item, num_seg_classes)) 45 | f.write("{}\t{}\n".format(item, num_seg_classes)) 46 | 47 | 48 | def gen_modelnet_id(root): 49 | classes = [] 50 | with open(os.path.join(root, 'train.txt'), 'r') as f: 51 | for line in f: 52 | classes.append(line.strip().split('/')[0]) 53 | classes = np.unique(classes) 54 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f: 55 | for i in range(len(classes)): 56 | f.write('{}\t{}\n'.format(classes[i], i)) 57 | 58 | 59 | class ShapeNetDataset(data.Dataset): 60 | def __init__(self, 61 | root, 62 | npoints=2500, 63 | classification=False, 64 | class_choice=None, 65 | split='train', 66 | data_augmentation=True): 67 | self.npoints = npoints 68 | self.root = root 69 | self.catfile = os.path.join(self.root, 'synsetoffset2category.txt') # maps category labels to folder names 70 | self.cat = {} 71 | self.data_augmentation = data_augmentation 72 | self.classification = classification 73 | self.seg_classes = {} 74 | 75 | with open(self.catfile, 'r') as f: # read-only 76 | for line in f: 77 | ls = line.strip().split() 78 | self.cat[ls[0]] = ls[1] 79 | #print(self.cat) 80 | if class_choice is not None: 81 | self.cat = {k: v for k, v in self.cat.items() if k in class_choice} 82 | 83 | self.id2cat = {v: k for k, v in self.cat.items()} 84 | 85 | self.meta = {} 86 | splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split)) 87 | #from IPython import embed; embed() 88 | filelist = json.load(open(splitfile, 'r')) 89 | for item in self.cat: 90 | self.meta[item] = [] 91 | 92 | for file in filelist: 93 | _, category, uuid = file.split('/') 94 | if category in self.cat.values(): 95 | self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'), 96 | os.path.join(self.root, category, 'points_label', uuid+'.seg'))) 97 | 98 | self.datapath = [] 99 | for item in self.cat: 100 | for fn in self.meta[item]: 101 | self.datapath.append((item, fn[0], fn[1])) 102 | 103 | self.classes = dict(zip(sorted(self.cat), range(len(self.cat)))) 104 | print(self.classes) 105 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), './misc/num_seg_classes.txt'), 'r') as f: 106 | for line in f: 107 | ls = line.strip().split() 108 | self.seg_classes[ls[0]] = int(ls[1]) 109 | self.num_seg_classes = 
self.seg_classes[list(self.cat.keys())[0]] 110 | print(self.seg_classes, self.num_seg_classes) 111 | 112 | def __getitem__(self, index): 113 | fn = self.datapath[index] 114 | cls = self.classes[self.datapath[index][0]] 115 | point_set = np.loadtxt(fn[1]).astype(np.float32) 116 | seg = np.loadtxt(fn[2]).astype(np.int64) 117 | # print(point_set.shape, seg.shape) 118 | 119 | choice = np.random.choice(len(seg), self.npoints, replace=True) 120 | # resample 121 | point_set = point_set[choice, :] 122 | 123 | point_set = point_set - np.expand_dims(np.mean(point_set, axis = 0), 0) # center 124 | dist = np.max(np.sqrt(np.sum(point_set ** 2, axis = 1)),0) 125 | point_set = point_set / dist #scale 126 | 127 | if self.data_augmentation: 128 | theta = np.random.uniform(0,np.pi*2) 129 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],[np.sin(theta), np.cos(theta)]]) 130 | point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # random rotation 131 | point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter 132 | 133 | seg = seg[choice] 134 | point_set = torch.from_numpy(point_set) 135 | seg = torch.from_numpy(seg) 136 | cls = torch.from_numpy(np.array([cls]).astype(np.int64)) 137 | 138 | if self.classification: 139 | return point_set, cls 140 | else: 141 | return point_set, seg 142 | 143 | def __len__(self): 144 | return len(self.datapath) 145 | 146 | 147 | class ModelNetDataset(data.Dataset): 148 | def __init__(self, 149 | root, 150 | npoints=2500, 151 | split='train', 152 | data_augmentation=True): 153 | self.npoints = npoints 154 | self.root = root 155 | self.split = split 156 | self.data_augmentation = data_augmentation 157 | self.fns = [] 158 | with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f: 159 | for line in f: 160 | self.fns.append(line.strip()) 161 | 162 | self.cat = {} 163 | with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f: 164 | for line in f: 165 | ls = line.strip().split() 166 | self.cat[ls[0]] = int(ls[1]) 167 | 168 | print(self.cat) 169 | self.classes = list(self.cat.keys()) 170 | 171 | def __getitem__(self, index): 172 | fn = self.fns[index] 173 | cls = self.cat[fn.split('/')[0]] 174 | with open(os.path.join(self.root, fn), 'rb') as f: 175 | plydata = PlyData.read(f) 176 | pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T 177 | choice = np.random.choice(len(pts), self.npoints, replace=True) 178 | point_set = pts[choice, :] 179 | 180 | point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0) # center 181 | dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0) 182 | point_set = point_set / dist # scale 183 | 184 | if self.data_augmentation: 185 | theta = np.random.uniform(0, np.pi * 2) 186 | rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) 187 | point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix) # random rotation 188 | point_set += np.random.normal(0, 0.02, size=point_set.shape) # random jitter 189 | 190 | point_set = torch.from_numpy(point_set.astype(np.float32)) 191 | cls = torch.from_numpy(np.array([cls]).astype(np.int64)) 192 | return point_set, cls 193 | 194 | def __len__(self): 195 | return len(self.fns) 196 | 197 | 198 | if __name__ == '__main__': 199 | dataset = sys.argv[1] 200 | datapath = sys.argv[2] 201 | 202 | if dataset == 'shapenet': 203 | d = ShapeNetDataset(root=datapath, class_choice=['Chair']) 204 | print(len(d)) 205 | ps, seg = d[0] 206 | 
print(ps.size(), ps.type(), seg.size(),seg.type()) 207 | 208 | d = ShapeNetDataset(root=datapath, classification=True) 209 | print(len(d)) 210 | ps, cls = d[0] 211 | print(ps.size(), ps.type(), cls.size(),cls.type()) 212 | # get_segmentation_classes(datapath) 213 | 214 | if dataset == 'modelnet': 215 | gen_modelnet_id(datapath) 216 | d = ModelNetDataset(root=datapath) 217 | print(len(d)) 218 | print(d[0]) 219 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | class STN3d(nn.Module): 11 | def __init__(self): 12 | super(STN3d, self).__init__() 13 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 14 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 15 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 16 | self.fc1 = nn.Linear(1024, 512) 17 | self.fc2 = nn.Linear(512, 256) 18 | self.fc3 = nn.Linear(256, 9) 19 | self.relu = nn.ReLU() 20 | 21 | self.bn1 = nn.BatchNorm1d(64) 22 | self.bn2 = nn.BatchNorm1d(128) 23 | self.bn3 = nn.BatchNorm1d(1024) 24 | self.bn4 = nn.BatchNorm1d(512) 25 | self.bn5 = nn.BatchNorm1d(256) 26 | 27 | def forward(self, x): 28 | batchsize = x.size()[0] 29 | x = self.relu(self.bn1(self.conv1(x))) 30 | x = self.relu(self.bn2(self.conv2(x))) 31 | x = self.relu(self.bn3(self.conv3(x))) 32 | x = torch.max(x, 2, keepdim=True)[0] 33 | x = x.view(-1, 1024) 34 | 35 | x = self.relu(self.bn4(self.fc1(x))) 36 | x = self.relu(self.bn5(self.fc2(x))) 37 | x = self.fc3(x) 38 | 39 | iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat( 40 | batchsize, 1) 41 | if x.is_cuda: 42 | iden = iden.cuda() 43 | x = x + iden 44 | x = x.view(-1, 3, 3) 45 | return x 46 | 47 | 48 | class STNkd(nn.Module): 49 | def __init__(self, k=64): 50 | super(STNkd, self).__init__() 51 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 52 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 53 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 54 | self.fc1 = nn.Linear(1024, 512) 55 | self.fc2 = nn.Linear(512, 256) 56 | self.fc3 = nn.Linear(256, k * k) 57 | self.relu = nn.ReLU() 58 | 59 | self.bn1 = nn.BatchNorm1d(64) 60 | self.bn2 = nn.BatchNorm1d(128) 61 | self.bn3 = nn.BatchNorm1d(1024) 62 | self.bn4 = nn.BatchNorm1d(512) 63 | self.bn5 = nn.BatchNorm1d(256) 64 | 65 | self.k = k 66 | 67 | def forward(self, x): 68 | batchsize = x.size()[0] 69 | x = self.relu(self.bn1(self.conv1(x))) 70 | x = self.relu(self.bn2(self.conv2(x))) 71 | x = self.relu(self.bn3(self.conv3(x))) 72 | x = torch.max(x, 2, keepdim=True)[0] 73 | x = x.view(-1, 1024) 74 | 75 | x = self.relu(self.bn4(self.fc1(x))) 76 | x = self.relu(self.bn5(self.fc2(x))) 77 | x = self.fc3(x) 78 | 79 | iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat( 80 | batchsize, 1) 81 | if x.is_cuda: 82 | iden = iden.cuda() 83 | x = x + iden 84 | x = x.view(-1, self.k, self.k) 85 | return x 86 | 87 | 88 | class PointNet(nn.Module): 89 | def __init__(self, k=2): 90 | super(PointNet, self).__init__() 91 | self.k = k 92 | self.stn3d = STN3d() 93 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 94 | self.conv2 = torch.nn.Conv1d(64, 64, 1) 95 | self.stnkd = STNkd(64) 96 | self.conv3 = torch.nn.Conv1d(64, 64, 1) 97 | self.conv4 = torch.nn.Conv1d(64, 128, 1) 98 | self.conv5 = 
torch.nn.Conv1d(128, 1024, 1) 99 | 100 | self.fc1 = nn.Linear(1024, 512) 101 | self.fc2 = nn.Linear(512, 256) 102 | self.fc3 = nn.Linear(256, k) 103 | 104 | self.bn1 = nn.BatchNorm1d(64) 105 | self.bn2 = nn.BatchNorm1d(64) 106 | self.bn3 = nn.BatchNorm1d(64) 107 | self.bn4 = nn.BatchNorm1d(128) 108 | self.bn5 = nn.BatchNorm1d(1024) 109 | self.bn6 = nn.BatchNorm1d(512) 110 | self.bn7 = nn.BatchNorm1d(256) 111 | 112 | def forward(self, x): 113 | trans = self.stn3d(x) 114 | x = x.transpose(2, 1) # transpose to B*N*3 so the matrix product x*trans is well-defined 115 | x = torch.bmm(x, trans) # batched matrix multiplication; the batch dim is untouched and the result is B*N*3 116 | x = x.transpose(2, 1) 117 | x = F.relu(self.bn1(self.conv1(x))) 118 | x = F.relu(self.bn2(self.conv2(x))) 119 | trans_feat = self.stnkd(x) 120 | x = x.transpose(2, 1) # transpose to B*N*64 so the matrix product x*trans_feat is well-defined 121 | x = torch.bmm(x, trans_feat) # batched matrix multiplication; the batch dim is untouched and the result is B*N*64 122 | x = x.transpose(2, 1) 123 | x = F.relu(self.bn3(self.conv3(x))) 124 | x = F.relu(self.bn4(self.conv4(x))) 125 | x = F.relu(self.bn5(self.conv5(x))) 126 | x = torch.max(x, 2, keepdim=True)[0].view(-1, 1024) 127 | 128 | x = F.relu(self.bn6(self.fc1(x))) 129 | x = F.relu(self.bn7(self.fc2(x))) 130 | x = F.log_softmax(self.fc3(x), dim=1) 131 | return x, trans, trans_feat 132 | 133 | 134 | def feature_transform_regularizer(trans): 135 | d = trans.size()[1] 136 | I = torch.eye(d)[None, :, :] 137 | if trans.is_cuda: 138 | I = I.cuda() 139 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2))) 140 | return loss 141 | -------------------------------------------------------------------------------- /non.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.parallel 5 | import torch.utils.data 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | 11 | class STN3d(nn.Module): 12 | def __init__(self): 13 | super(STN3d, self).__init__() 14 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 15 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 16 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 17 | self.fc1 = nn.Linear(1024, 512) 18 | self.fc2 = nn.Linear(512, 256) 19 | self.fc3 = nn.Linear(256, 9) 20 | self.relu = nn.ReLU() 21 | 22 | self.bn1 = nn.BatchNorm1d(64) 23 | self.bn2 = nn.BatchNorm1d(128) 24 | self.bn3 = nn.BatchNorm1d(1024) 25 | self.bn4 = nn.BatchNorm1d(512) 26 | self.bn5 = nn.BatchNorm1d(256) 27 | 28 | def forward(self, x): 29 | batchsize = x.size()[0] 30 | x = F.relu(self.bn1(self.conv1(x))) 31 | x = F.relu(self.bn2(self.conv2(x))) 32 | x = F.relu(self.bn3(self.conv3(x))) 33 | x = torch.max(x, 2, keepdim=True)[0] 34 | x = x.view(-1, 1024) 35 | 36 | x = F.relu(self.bn4(self.fc1(x))) 37 | x = F.relu(self.bn5(self.fc2(x))) 38 | x = self.fc3(x) 39 | 40 | iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1) 41 | if x.is_cuda: 42 | iden = iden.cuda() 43 | x = x + iden 44 | x = x.view(-1, 3, 3) 45 | return x 46 | 47 | 48 | class STNkd(nn.Module): 49 | def __init__(self, k=64): 50 | super(STNkd, self).__init__() 51 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 52 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 53 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 54 | self.fc1 = nn.Linear(1024, 512) 55 | self.fc2 = nn.Linear(512, 256) 56 | self.fc3 = nn.Linear(256, k*k) 57 | self.relu = nn.ReLU() 58 | 59 | self.bn1 = nn.BatchNorm1d(64) 60 | self.bn2 = nn.BatchNorm1d(128) 61 | self.bn3 = 
/non.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.parallel
5 | import torch.utils.data
6 | from torch.autograd import Variable
7 | import numpy as np
8 | import torch.nn.functional as F
9 | 
10 | 
11 | class STN3d(nn.Module):
12 |     def __init__(self):
13 |         super(STN3d, self).__init__()
14 |         self.conv1 = torch.nn.Conv1d(3, 64, 1)
15 |         self.conv2 = torch.nn.Conv1d(64, 128, 1)
16 |         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
17 |         self.fc1 = nn.Linear(1024, 512)
18 |         self.fc2 = nn.Linear(512, 256)
19 |         self.fc3 = nn.Linear(256, 9)
20 |         self.relu = nn.ReLU()
21 | 
22 |         self.bn1 = nn.BatchNorm1d(64)
23 |         self.bn2 = nn.BatchNorm1d(128)
24 |         self.bn3 = nn.BatchNorm1d(1024)
25 |         self.bn4 = nn.BatchNorm1d(512)
26 |         self.bn5 = nn.BatchNorm1d(256)
27 | 
28 |     def forward(self, x):
29 |         batchsize = x.size()[0]
30 |         x = F.relu(self.bn1(self.conv1(x)))
31 |         x = F.relu(self.bn2(self.conv2(x)))
32 |         x = F.relu(self.bn3(self.conv3(x)))
33 |         x = torch.max(x, 2, keepdim=True)[0]
34 |         x = x.view(-1, 1024)
35 | 
36 |         x = F.relu(self.bn4(self.fc1(x)))
37 |         x = F.relu(self.bn5(self.fc2(x)))
38 |         x = self.fc3(x)
39 | 
40 |         iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(batchsize, 1)
41 |         if x.is_cuda:
42 |             iden = iden.cuda()
43 |         x = x + iden
44 |         x = x.view(-1, 3, 3)
45 |         return x
46 | 
47 | 
48 | class STNkd(nn.Module):
49 |     def __init__(self, k=64):
50 |         super(STNkd, self).__init__()
51 |         self.conv1 = torch.nn.Conv1d(k, 64, 1)
52 |         self.conv2 = torch.nn.Conv1d(64, 128, 1)
53 |         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
54 |         self.fc1 = nn.Linear(1024, 512)
55 |         self.fc2 = nn.Linear(512, 256)
56 |         self.fc3 = nn.Linear(256, k * k)
57 |         self.relu = nn.ReLU()
58 | 
59 |         self.bn1 = nn.BatchNorm1d(64)
60 |         self.bn2 = nn.BatchNorm1d(128)
61 |         self.bn3 = nn.BatchNorm1d(1024)
62 |         self.bn4 = nn.BatchNorm1d(512)
63 |         self.bn5 = nn.BatchNorm1d(256)
64 | 
65 |         self.k = k
66 | 
67 |     def forward(self, x):
68 |         batchsize = x.size()[0]
69 |         x = F.relu(self.bn1(self.conv1(x)))
70 |         x = F.relu(self.bn2(self.conv2(x)))
71 |         x = F.relu(self.bn3(self.conv3(x)))
72 |         x = torch.max(x, 2, keepdim=True)[0]
73 |         x = x.view(-1, 1024)
74 | 
75 |         x = F.relu(self.bn4(self.fc1(x)))
76 |         x = F.relu(self.bn5(self.fc2(x)))
77 |         x = self.fc3(x)
78 | 
79 |         iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(batchsize, 1)
80 |         if x.is_cuda:
81 |             iden = iden.cuda()
82 |         x = x + iden
83 |         x = x.view(-1, self.k, self.k)
84 |         return x
85 | 
86 | 
87 | class PointNetfeat(nn.Module):
88 |     def __init__(self, global_feat=True, feature_transform=False):
89 |         super(PointNetfeat, self).__init__()
90 |         self.stn = STN3d()
91 |         self.conv1 = torch.nn.Conv1d(3, 64, 1)
92 |         self.conv2 = torch.nn.Conv1d(64, 128, 1)
93 |         self.conv3 = torch.nn.Conv1d(128, 1024, 1)
94 |         self.bn1 = nn.BatchNorm1d(64)
95 |         self.bn2 = nn.BatchNorm1d(128)
96 |         self.bn3 = nn.BatchNorm1d(1024)
97 |         self.global_feat = global_feat
98 |         self.feature_transform = feature_transform
99 |         if self.feature_transform:
100 |             self.fstn = STNkd(k=64)
101 | 
102 |     def forward(self, x):
103 |         n_pts = x.size()[2]  # raw input is B*3*N
104 |         trans = self.stn(x)  # trans.shape = B*3*3
105 |         x = x.transpose(2, 1)  # to B*N*3 so the matrix product x * trans is well-formed
106 |         x = torch.bmm(x, trans)  # batched matrix multiply; the batch dim is untouched, result is B*N*3
107 |         x = x.transpose(2, 1)
108 |         x = F.relu(self.bn1(self.conv1(x)))
109 | 
110 |         if self.feature_transform:
111 |             trans_feat = self.fstn(x)
112 |             x = x.transpose(2, 1)
113 |             x = torch.bmm(x, trans_feat)
114 |             x = x.transpose(2, 1)
115 |         else:
116 |             trans_feat = None
117 | 
118 |         pointfeat = x
119 |         x = F.relu(self.bn2(self.conv2(x)))
120 |         x = self.bn3(self.conv3(x))
121 |         x = torch.max(x, 2, keepdim=True)[0]
122 |         x = x.view(-1, 1024)
123 |         if self.global_feat:
124 |             return x, trans, trans_feat
125 |         else:
126 |             x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)  # expand from B*1024 to B*1024*N for per-point segmentation features
127 |             return torch.cat([x, pointfeat], 1), trans, trans_feat
128 | 
129 | 
130 | class PointNetCls(nn.Module):
131 |     def __init__(self, k=2, feature_transform=False):
132 |         super(PointNetCls, self).__init__()
133 |         self.feature_transform = feature_transform
134 |         self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
135 |         self.fc1 = nn.Linear(1024, 512)
136 |         self.fc2 = nn.Linear(512, 256)
137 |         self.fc3 = nn.Linear(256, k)
138 |         self.dropout = nn.Dropout(p=0.3)
139 |         self.bn1 = nn.BatchNorm1d(512)
140 |         self.bn2 = nn.BatchNorm1d(256)
141 |         self.relu = nn.ReLU()
142 | 
143 |     def forward(self, x):
144 |         x, trans, trans_feat = self.feat(x)
145 |         x = F.relu(self.bn1(self.fc1(x)))
146 |         x = F.relu(self.bn2(self.dropout(self.fc2(x))))
147 |         x = self.fc3(x)
148 |         return F.log_softmax(x, dim=1), trans, trans_feat
149 | 
150 | 
--------------------------------------------------------------------------------
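For reference, a quick shape check of PointNetCls on random input; a minimal sketch, assuming non.py is importable as a module (all sizes are arbitrary):

    import torch
    from non import PointNetCls

    model = PointNetCls(k=16, feature_transform=True)
    sim_points = torch.rand(4, 3, 2500)            # B*3*N
    out, trans, trans_feat = model(sim_points)
    print(out.size())                              # torch.Size([4, 16]) -- log-probabilities
    print(trans.size(), trans_feat.size())         # [4, 3, 3] and [4, 64, 64]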
/pointnet_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from time import time
5 | import numpy as np
6 | 
7 | 
8 | def timeit(tag, t):
9 |     print("{}: {}s".format(tag, time() - t))
10 |     return time()
11 | 
12 | 
13 | def pc_normalize(pc):
14 |     centroid = np.mean(pc, axis=0)
15 |     pc = pc - centroid
16 |     m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
17 |     pc = pc / m
18 |     return pc
19 | 
20 | 
21 | def square_distance(src, dst):
22 |     """
23 |     Calculate the squared Euclidean distance between every pair of points.
24 | 
25 |     src^T * dst = xn * xm + yn * ym + zn * zm;
26 |     sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
27 |     sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
28 | 
29 |     dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
30 |          = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
31 | 
32 |     Input:
33 |         src: source points, [B, N, C]
34 |         dst: target points, [B, M, C]
35 |     Output:
36 |         dist: pairwise squared distance, [B, N, M]
37 |     """
38 |     B, N, _ = src.shape
39 |     _, M, _ = dst.shape
40 |     dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
41 |     dist += torch.sum(src ** 2, -1).view(B, N, 1)
42 |     dist += torch.sum(dst ** 2, -1).view(B, 1, M)
43 |     return dist
44 | 
45 | 
46 | def index_points(points, idx):
47 |     """
48 |     Input:
49 |         points: input points data, [B, N, C]
50 |         idx: sample index data, [B, S]
51 |     Return:
52 |         new_points: indexed points data, [B, S, C]
53 |     """
54 |     device = points.device
55 |     B = points.shape[0]
56 |     view_shape = list(idx.shape)
57 |     # view_shape = [B, 1]
58 |     view_shape[1:] = [1] * (len(view_shape) - 1)
59 |     repeat_shape = list(idx.shape)
60 |     # repeat_shape = [1, S]
61 |     repeat_shape[0] = 1
62 |     # expand the 1-D range [0, B-1] to the same shape as idx (each batch index repeated S times) so that fancy indexing pairs every sample index with its batch
63 |     batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
64 |     new_points = points[batch_indices, idx, :]
65 |     return new_points
66 | 
67 | 
68 | def farthest_point_sample(xyz, npoint):
69 |     """
70 |     Input:
71 |         xyz: pointcloud data, [B, N, C]
72 |         npoint: number of samples
73 |     Return:
74 |         centroids: sampled pointcloud index, [B, npoint]
75 |     """
76 |     device = xyz.device
77 |     B, N, C = xyz.shape
78 |     centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
79 |     distance = torch.ones(B, N).to(device) * 1e10
80 |     farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
81 |     batch_indices = torch.arange(B, dtype=torch.long).to(device)
82 |     # each iteration picks one farthest point per batch element: farthest [B, ]
83 |     # distance holds, for every point, its squared distance to the nearest centroid chosen so far
84 |     # the mask is needed because a point's distance to the selected set is the minimum of its distances to the set members
85 |     for i in range(npoint):
86 |         centroids[:, i] = farthest
87 |         # batch_indices [B], farthest [B], result [B, 1, 3]
88 |         centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
89 |         dist = torch.sum((xyz - centroid) ** 2, -1)
90 |         mask = dist < distance
91 |         distance[mask] = dist[mask]
92 |         farthest = torch.max(distance, -1)[1]
93 |     return centroids
94 | 
95 | 
96 | def query_ball_point(radius, nsample, xyz, new_xyz):
97 |     """
98 |     Input:
99 |         radius: local region radius
100 |         nsample: max sample number in local region
101 |         xyz: all points, [B, N, C]
102 |         new_xyz: query points, [B, S, C]
103 |     Return:
104 |         group_idx: grouped points index, [B, S, nsample]
105 |     """
106 |     device = xyz.device
107 |     B, N, C = xyz.shape
108 |     _, S, _ = new_xyz.shape
109 |     # group_idx [B, S, N]
110 |     group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
111 |     sqrdists = square_distance(new_xyz, xyz)
112 |     # group_idx starts in the range 0..N-1; setting an entry to N marks it as discarded
113 |     group_idx[sqrdists > radius ** 2] = N
114 |     # keep only the first nsample indices: [B, S, N] -> [B, S, nsample]
115 |     group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
116 |     group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
117 |     # if a ball contains fewer than nsample points, pad by repeating the first point
118 |     mask = group_idx == N
119 |     group_idx[mask] = group_first[mask]
120 |     return group_idx
121 | 
122 | 
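# A quick shape check for the sampling/grouping helpers above -- a minimal
# sketch with random inputs and arbitrary sizes:
#
#   xyz = torch.rand(2, 1024, 3)                       # B*N*3 point cloud
#   fps_idx = farthest_point_sample(xyz, 128)          # [2, 128]
#   new_xyz = index_points(xyz, fps_idx)               # [2, 128, 3]
#   group_idx = query_ball_point(0.2, 32, xyz, new_xyz)
#   # group_idx: [2, 128, 32]; square_distance(new_xyz, xyz): [2, 128, 1024]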
123 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
124 |     """
125 |     Input:
126 |         npoint: number of points to keep after farthest point sampling
127 |         radius: ball-query radius
128 |         nsample: max number of points gathered per ball
129 |         xyz: input points position data, [B, N, C]
130 |         points: input points data, [B, N, D]
131 |     Return:
132 |         new_xyz: sampled points position data, [B, npoint, C]
133 |         new_points: sampled points data, [B, npoint, nsample, C+D]
134 |     """
135 |     B, N, C = xyz.shape
136 |     S = npoint
137 |     fps_idx = farthest_point_sample(xyz, npoint)
138 |     new_xyz = index_points(xyz, fps_idx)  # [B, npoint, C]
139 |     idx = query_ball_point(radius, nsample, xyz, new_xyz)  # [B, npoint, nsample]
140 |     grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
141 |     # normalise each neighbourhood by subtracting its centre
142 |     grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
143 |     # if the points carry extra feature channels, concatenate them with the local coordinates; otherwise return the coordinates alone
144 |     if points is not None:
145 |         grouped_points = index_points(points, idx)
146 |         new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
147 |     else:
148 |         new_points = grouped_xyz_norm
149 |     if returnfps:
150 |         return new_xyz, new_points, grouped_xyz, fps_idx
151 |     else:
152 |         return new_xyz, new_points
153 | 
154 | 
155 | def sample_and_group_all(xyz, points):
156 |     """
157 |     Input:
158 |         xyz: input points position data, [B, N, C]
159 |         points: input points data, [B, N, D]
160 |     Return:
161 |         new_xyz: sampled points position data, [B, 1, C]
162 |         new_points: sampled points data, [B, 1, N, C+D]
163 |     """
164 |     device = xyz.device
165 |     B, N, C = xyz.shape
166 |     new_xyz = torch.zeros(B, 1, C).to(device)
167 |     grouped_xyz = xyz.view(B, 1, N, C)
168 |     if points is not None:
169 |         new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
170 |     else:
171 |         new_points = grouped_xyz
172 |     return new_xyz, new_points
173 | 
174 | 
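# How the shapes work out for sample_and_group -- a minimal sketch with
# assumed sizes:
#
#   xyz = torch.rand(2, 1024, 3)                       # positions [B, N, 3]
#   feats = torch.rand(2, 1024, 16)                    # extra features [B, N, D]
#   new_xyz, new_points = sample_and_group(512, 0.2, 32, xyz, feats)
#   # new_xyz: [2, 512, 3]; new_points: [2, 512, 32, 3 + 16]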
175 | class PointNetSetAbstraction(nn.Module):
176 |     def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
177 |         """
178 |         Input:
179 |             npoint: number of points kept by FPS sampling
180 |             radius: radius for the ball query
181 |             nsample: number of points per ball query
182 |             in_channel: number of input channels
183 |             mlp: list of MLP output channels, e.g. [64, 64, 128]
184 |             group_all: whether to group all points into a single region
185 |         """
186 |         super(PointNetSetAbstraction, self).__init__()
187 |         self.npoint = npoint
188 |         self.radius = radius
189 |         self.nsample = nsample
190 |         self.mlp_convs = nn.ModuleList()
191 |         self.mlp_bns = nn.ModuleList()
192 |         last_channel = in_channel
193 |         for out_channel in mlp:
194 |             self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
195 |             self.mlp_bns.append(nn.BatchNorm2d(out_channel))
196 |             last_channel = out_channel
197 |         self.group_all = group_all
198 | 
199 |     def forward(self, xyz, points):
200 |         """
201 |         Input:
202 |             xyz: input points position data, [B, C, N]
203 |             points: input points data, [B, D, N]
204 |         Return:
205 |             new_xyz: sampled points position data, [B, C, S]
206 |             new_points_concat: sampled points feature data, [B, D', S]
207 |         """
208 |         xyz = xyz.permute(0, 2, 1)
209 |         if points is not None:
210 |             points = points.permute(0, 2, 1)
211 | 
212 |         if self.group_all:
213 |             new_xyz, new_points = sample_and_group_all(xyz, points)
214 |         else:
215 |             new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
216 |         # new_xyz: sampled points position data, [B, npoint, C]
217 |         # new_points: sampled points data, [B, npoint, nsample, C+D]
218 |         new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
219 |         for i, conv in enumerate(self.mlp_convs):
220 |             bn = self.mlp_bns[i]
221 |             new_points = F.relu(bn(conv(new_points)))
222 | 
223 |         new_points = torch.max(new_points, 2)[0]
224 |         new_xyz = new_xyz.permute(0, 2, 1)
225 |         return new_xyz, new_points
226 | 
227 | 
228 | class PointNetSetAbstractionMsg(nn.Module):
229 |     def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
230 |         super(PointNetSetAbstractionMsg, self).__init__()
231 |         self.npoint = npoint
232 |         self.radius_list = radius_list
233 |         self.nsample_list = nsample_list
234 |         self.conv_blocks = nn.ModuleList()
235 |         self.bn_blocks = nn.ModuleList()
236 |         for i in range(len(mlp_list)):
237 |             convs = nn.ModuleList()
238 |             bns = nn.ModuleList()
239 |             last_channel = in_channel + 3
240 |             for out_channel in mlp_list[i]:
241 |                 convs.append(nn.Conv2d(last_channel, out_channel, 1))
242 |                 bns.append(nn.BatchNorm2d(out_channel))
243 |                 last_channel = out_channel
244 |             self.conv_blocks.append(convs)
245 |             self.bn_blocks.append(bns)
246 | 
247 |     def forward(self, xyz, points):
248 |         """
249 |         Input:
250 |             xyz: input points position data, [B, C, N]
251 |             points: input points data, [B, D, N]
252 |         Return:
253 |             new_xyz: sampled points position data, [B, C, S]
254 |             new_points_concat: sampled points feature data, [B, D', S]
255 |         """
256 |         xyz = xyz.permute(0, 2, 1)
257 |         if points is not None:
258 |             points = points.permute(0, 2, 1)
259 | 
260 |         B, N, C = xyz.shape
261 |         S = self.npoint
262 |         new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
263 |         new_points_list = []
264 |         for i, radius in enumerate(self.radius_list):
265 |             K = self.nsample_list[i]
266 |             group_idx = query_ball_point(radius, K, xyz, new_xyz)
267 |             grouped_xyz = index_points(xyz, group_idx)
268 |             # translate each neighbourhood to coordinates relative to its centre
269 |             grouped_xyz -= new_xyz.view(B, S, 1, C)
270 |             if points is not None:
271 |                 grouped_points = index_points(points, group_idx)
272 |                 grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
273 |             else:
274 |                 grouped_points = grouped_xyz
275 | 
276 |             grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
277 |             for j in range(len(self.conv_blocks[i])):
278 |                 conv = self.conv_blocks[i][j]
279 |                 bn = self.bn_blocks[i][j]
280 |                 grouped_points = F.relu(bn(conv(grouped_points)))
281 |             new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
282 |             new_points_list.append(new_points)
283 | 
284 |         new_xyz = new_xyz.permute(0, 2, 1)
285 |         new_points_concat = torch.cat(new_points_list, dim=1)
286 |         return new_xyz, new_points_concat
287 | 
288 | 
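# Instantiating the single-scale abstraction layer -- a minimal sketch; the
# hyper-parameters below are illustrative, not prescriptive:
#
#   sa = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32,
#                               in_channel=3, mlp=[64, 64, 128], group_all=False)
#   xyz = torch.rand(2, 3, 1024)                       # channel-first [B, C, N]
#   new_xyz, new_points = sa(xyz, None)
#   # new_xyz: [2, 3, 512]; new_points: [2, 128, 512]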
289 | class PointNetFeaturePropagation(nn.Module):
290 |     def __init__(self, in_channel, mlp):
291 |         super(PointNetFeaturePropagation, self).__init__()
292 |         self.mlp_convs = nn.ModuleList()
293 |         self.mlp_bns = nn.ModuleList()
294 |         last_channel = in_channel
295 |         for out_channel in mlp:
296 |             self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
297 |             self.mlp_bns.append(nn.BatchNorm1d(out_channel))
298 |             last_channel = out_channel
299 | 
300 |     def forward(self, xyz1, xyz2, points1, points2):
301 |         """
302 |         Input:
303 |             xyz1: input points position data, [B, C, N]
304 |             xyz2: sampled input points position data, [B, C, S]
305 |             points1: input points data, [B, D, N]
306 |             points2: sampled points data, [B, D, S]
307 |         Return:
308 |             new_points: upsampled points data, [B, D', N]
309 |         """
310 |         xyz1 = xyz1.permute(0, 2, 1)
311 |         xyz2 = xyz2.permute(0, 2, 1)
312 | 
313 |         points2 = points2.permute(0, 2, 1)
314 |         B, N, C = xyz1.shape
315 |         _, S, _ = xyz2.shape
316 | 
317 |         if S == 1:
318 |             interpolated_points = points2.repeat(1, N, 1)
319 |         else:
320 |             dists = square_distance(xyz1, xyz2)
321 |             dists, idx = dists.sort(dim=-1)
322 |             dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]
323 |             dists[dists < 1e-10] = 1e-10  # avoid division by zero
324 |             weight = 1.0 / dists  # [B, N, 3]
325 |             weight = weight / torch.sum(weight, dim=-1).view(B, N, 1)  # [B, N, 3]
326 |             interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)  # inverse-distance weighted sum over the 3 nearest neighbours
327 | 
328 |         if points1 is not None:
329 |             points1 = points1.permute(0, 2, 1)
330 |             new_points = torch.cat([points1, interpolated_points], dim=-1)
331 |         else:
332 |             new_points = interpolated_points
333 | 
334 |         new_points = new_points.permute(0, 2, 1)
335 |         for i, conv in enumerate(self.mlp_convs):
336 |             bn = self.mlp_bns[i]
337 |             new_points = F.relu(bn(conv(new_points)))
338 |         return new_points
339 | 
--------------------------------------------------------------------------------
/tmp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def cosine_similarity(feature):
5 |     """
6 |     Input:
7 |         feature: source points, [B, N, C]
8 |     Output:
9 |         res: pairwise cosine similarity, [B, N, N]
10 |     """
11 |     B, N, C = feature.shape
12 |     feat = torch.matmul(feature, feature.permute(0, 2, 1))  # [B, N, N]
13 |     norm = torch.sqrt(torch.sum(feature ** 2, -1)).view(B, N, 1)
14 |     norm = torch.matmul(norm, norm.permute(0, 2, 1)).clamp(min=1e-8)  # guard against zero-norm rows
15 |     res = torch.div(feat, norm)
16 |     return res
17 | 
--------------------------------------------------------------------------------
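A quick way to sanity-check cosine_similarity against a reference computation; a minimal sketch, assuming tmp.py is importable as a module:

    import torch
    import torch.nn.functional as F
    from tmp import cosine_similarity

    feat = torch.rand(2, 8, 32) + 0.1                # keep rows away from zero norm
    ours = cosine_similarity(feat)                   # [2, 8, 8]
    unit = F.normalize(feat, dim=-1)                 # rows scaled to unit length
    ref = torch.matmul(unit, unit.transpose(1, 2))   # [2, 8, 8]
    print(torch.allclose(ours, ref, atol=1e-5))      # expect True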