├── .gitignore ├── FeatureExtractorPart ├── pointnext.py └── utils.py ├── Fig └── test_acc.svg ├── Loss.py ├── Model.py ├── Parameters.py ├── README.md ├── Trainer.py ├── Transforms.py ├── dataset └── ModelNet40.py ├── main.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.pkl 3 | *.ipynb 4 | *.pcd 5 | *.pyc 6 | __pycache__ 7 | result_train 8 | result_test 9 | result_eval 10 | result_demo 11 | runs 12 | .vscode 13 | *.ffs_db 14 | *.ffs_lock 15 | wandb 16 | __* 17 | .idea 18 | test.py 19 | -------------------------------------------------------------------------------- /FeatureExtractorPart/pointnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from FeatureExtractorPart.utils import index_points, farthest_point_sample, query_hybrid, \ 4 | coordinate_distance, build_mlp 5 | 6 | 7 | class SetAbstraction(nn.Module): 8 | """ 9 | 点云特征提取 10 | 包含一个单尺度S-G-P过程 11 | """ 12 | 13 | def __init__(self, 14 | npoint: int, 15 | radius: int, 16 | nsample: int, 17 | in_channel: int, 18 | coor_dim: int = 3): 19 | """ 20 | :param npoint: 采样点数量 21 | :param radius: 采样半径 22 | :param nsample: 采样点数量 23 | :param in_channel: 特征维度的输入值 24 | :param coor_dim: 点的坐标维度,默认为3 25 | """ 26 | super().__init__() 27 | self.npoint = npoint 28 | self.radius = radius 29 | self.nsample = nsample 30 | self.in_channel = in_channel 31 | self.coor_dim = coor_dim 32 | self.mlp = build_mlp(in_channel=in_channel + coor_dim, channel_list=[in_channel * 2], dim=2) 33 | 34 | def forward(self, points_coor, points_fea): 35 | """ 36 | :param points_coor: (B, 3, N) 点云原始坐标 37 | :param points_fea: (B, C, N) 点云特征 38 | :return: 39 | new_xyz: (B, 3, S) 下采样后的点云坐标 40 | new_fea: (B, D, S) 采样点特征 41 | """ 42 | points_coor, points_fea = points_coor.permute(0, 2, 1), points_fea.permute(0, 2, 1) # (B, C, N) -> (B, N, C) 43 | bs, nbr_point_in, _ = points_coor.shape 44 | num_point_out = self.npoint 45 | 46 | '''S 采样''' 47 | new_coor = index_points(points_coor, farthest_point_sample(points_coor, num_point_out)) # 获取新采样点 (B, S, coor) 48 | 49 | '''G 分组''' 50 | # 每个group的点云索引 (B, S, K) 51 | group_idx = query_hybrid(self.radius, self.nsample, points_coor[..., :3], new_coor[..., :3]) 52 | 53 | # 基于分组获取各组内点云坐标和特征,并进行拼接 54 | grouped_points_coor = index_points(points_coor[..., :3], group_idx) # 每个group内所有点云的坐标 (B, S, K, 3) 55 | grouped_points_coor -= new_coor[..., :3].view(bs, num_point_out, 1, 3) # 坐标转化为与采样点的偏移量 (B, S, K, 3) 56 | grouped_points_coor = grouped_points_coor / self.radius # 相对坐标归一化 57 | grouped_points_fea = index_points(points_fea, group_idx) # 每个group内所有点云的特征 (B, S, K, C) 58 | grouped_points_fea = torch.cat([grouped_points_fea, grouped_points_coor], dim=-1) # 拼接坐标偏移量 (B, S, K, C+3) 59 | 60 | '''P 特征提取''' 61 | # (B, S, K, C+3) -> (B, C+3, K, S) -mlp-> (B, D, K, S) -pooling-> (B, D, S) 62 | grouped_points_fea = grouped_points_fea.permute(0, 3, 2, 1) # 2d卷积作用于维度1 63 | grouped_points_fea = self.mlp(grouped_points_fea) 64 | new_points_fea = torch.max(grouped_points_fea, dim=2)[0] 65 | 66 | new_coor = new_coor.permute(0, 2, 1) # (B, 3, S) 67 | return new_coor, new_points_fea 68 | 69 | 70 | class LocalAggregation(nn.Module): 71 | """ 72 | 局部特征提取 73 | 包含一个单尺度G-P过程,每一个点都作为采样点进行group以聚合局部特征,无下采样过程 74 | """ 75 | 76 | def __init__(self, 77 | radius: int, 78 | nsample: int, 79 | in_channel: int, 80 | coor_dim: int = 3): 81 | """ 82 | :param radius: 采样半径 83 | :param nsample: 采样点数量 84 | :param in_channel: 特征维度的输入值 85 | :param coor_dim: 点的坐标维度,默认为3 86 | """ 87 | super().__init__() 88 | self.radius = radius 89 | self.nsample = nsample 90 | self.in_channel = in_channel 91 | self.mlp = build_mlp(in_channel=in_channel + coor_dim, channel_list=[in_channel], dim=2) 92 | 93 | def forward(self, points_coor, points_fea): 94 | """ 95 | :param points_coor: (B, 3, N) 点云原始坐标 96 | :param points_fea: (B, C, N) 点云特征 97 | :return: 98 | new_fea: (B, D, N) 局部特征聚合后的特征 99 | """ 100 | # (B, C, N) -> (B, N, C) 101 | points_coor, points_fea = points_coor.permute(0, 2, 1), points_fea.permute(0, 2, 1) 102 | bs, npoint, _ = points_coor.shape 103 | 104 | '''G 分组''' 105 | # 每个group的点云索引 (B, N, K) 106 | group_idx = query_hybrid(self.radius, self.nsample, points_coor[..., :3], points_coor[..., :3]) 107 | 108 | # 基于分组获取各组内点云坐标和特征,并进行拼接 109 | grouped_points_coor = index_points(points_coor[..., :3], group_idx) # 每个group内所有点云的坐标 (B, N, K, 3) 110 | grouped_points_coor = grouped_points_coor - points_coor[..., :3].view(bs, npoint, 1, 3) # 坐标转化为与采样点的偏移量 111 | grouped_points_coor = grouped_points_coor / self.radius # 相对坐标归一化 112 | grouped_points_fea = index_points(points_fea, group_idx) # 每个group内所有点云的特征 (B, N, K, C) 113 | grouped_points_fea = torch.cat([grouped_points_fea, grouped_points_coor], dim=-1) # 拼接坐标偏移量 (B, N, K, C+3) 114 | 115 | '''P 特征提取''' 116 | # (B, N, K, C+3) -> (B, C+3, K, N) -mlp-> (B, D, K, N) -pooling-> (B, D, N) 117 | grouped_points_fea = grouped_points_fea.permute(0, 3, 2, 1) # 2d卷积作用于维度1 118 | grouped_points_fea = self.mlp(grouped_points_fea) 119 | new_fea = torch.max(grouped_points_fea, dim=2)[0] 120 | 121 | return new_fea 122 | 123 | 124 | class InvResMLP(nn.Module): 125 | """ 126 | 逆瓶颈残差块 127 | """ 128 | 129 | def __init__(self, 130 | radius: int, 131 | nsample: int, 132 | in_channel: int, 133 | coor_dim: int = 3, 134 | expansion: int = 4): 135 | """ 136 | :param radius: 采样半径 137 | :param nsample: 采样点数量 138 | :param in_channel: 特征维度的输入值 139 | :param coor_dim: 点的坐标维度,默认为3 140 | :param expansion: 中间层通道数扩张倍数 141 | """ 142 | super().__init__() 143 | self.la = LocalAggregation(radius=radius, nsample=nsample, in_channel=in_channel, coor_dim=coor_dim) 144 | channel_list = [in_channel * expansion, in_channel] 145 | self.pw_conv = build_mlp(in_channel=in_channel, channel_list=channel_list, dim=1, drop_last_act=True) 146 | self.act = nn.ReLU(inplace=True) 147 | 148 | def forward(self, points): 149 | """ 150 | :param points: 151 | (B, 3, N) 点云原始坐标 152 | (B, C, N) 点云特征 153 | :return: 154 | new_fea: (B, D, N) 155 | """ 156 | points_coor, points_fea = points 157 | identity = points_fea 158 | points_fea = self.la(points_coor, points_fea) 159 | points_fea = self.pw_conv(points_fea) 160 | points_fea = points_fea + identity 161 | points_fea = self.act(points_fea) 162 | return [points_coor, points_fea] 163 | 164 | 165 | class Stage(nn.Module): 166 | """ 167 | PointNeXt一个下采样阶段 168 | """ 169 | 170 | def __init__(self, 171 | npoint: int, 172 | radius_list: list, 173 | nsample_list: list, 174 | in_channel: int, 175 | coor_dim: int = 3, 176 | expansion: int = 4): 177 | """ 178 | :param npoint: 采样点数量 179 | :param radius_list: 采样半径 180 | :param nsample_list: 采样邻点数量 181 | :param in_channel: 特征维度的输入值 182 | :param coor_dim: 点的坐标维度,默认为3 183 | :param expansion: 中间层通道数扩张倍数 184 | """ 185 | super().__init__() 186 | self.sa = SetAbstraction(npoint=npoint, radius=radius_list[0], nsample=nsample_list[0], 187 | in_channel=in_channel, coor_dim=coor_dim) 188 | 189 | irm = [] 190 | for i in range(1, len(radius_list)): 191 | irm.append( 192 | InvResMLP(radius=radius_list[i], nsample=nsample_list[i], in_channel=in_channel * 2, 193 | coor_dim=coor_dim, expansion=expansion) 194 | ) 195 | self.irm = nn.Sequential(*irm) 196 | 197 | def forward(self, points_coor, points_fea): 198 | """ 199 | :param points_coor: (B, 3, N) 点云原始坐标 200 | :param points_fea: (B, D, N) 点云特征 201 | :return: 202 | new_xyz: (B, 3, S) 下采样后的点云坐标 203 | new_points_concat: (B, D', S) 下采样后的点云特征 204 | """ 205 | new_coor, new_points_fea = self.sa(points_coor, points_fea) 206 | new_coor, new_points_fea = self.irm([new_coor, new_points_fea]) 207 | return new_coor, new_points_fea 208 | 209 | 210 | class FeaturePropagation(nn.Module): 211 | """ 212 | FP上采样模块 213 | """ 214 | 215 | def __init__(self, in_channel, mlp, coor_dim=3): 216 | """ 217 | :param in_channel: > 同层和下层特征维度的输入值 218 | :param mlp: mlp的通道维度数 219 | """ 220 | super(FeaturePropagation, self).__init__() 221 | self.mlp_modules = build_mlp(in_channel=sum(in_channel), channel_list=mlp, dim=1) 222 | self.coor_dim = coor_dim 223 | 224 | def forward(self, xyz1, xyz2, points1, points2): 225 | """ 226 | :param xyz1: (B, 3, N) 同层点云原始坐标 227 | :param xyz2: (B, 3, S) 下层点云原始坐标 228 | :param points1: (B, D1, N) 同层点云特征 229 | :param points2: (B, D2, S) 下层点云特征 230 | :return: (B, D, N) 上采样后的点云特征 231 | """ 232 | B, _, N = xyz1.shape 233 | _, _, S = xyz2.shape 234 | 235 | if S == 1: 236 | # (B, D2, 1) -> (B, D2, N) 只有一个特征点则直接扩展至所有点 237 | new_points = points2.repeat(1, 1, N) 238 | else: 239 | # (B, C, N) -> (B, N, C) 240 | xyz1, xyz2, points2 = xyz1.permute(0, 2, 1), xyz2.permute(0, 2, 1), points2.permute(0, 2, 1) 241 | # 找到与每个同层点最近的前3个下层特征点 242 | dists = coordinate_distance(xyz1[..., :3], xyz2[..., :3]) 243 | dists, idx = torch.topk(dists, k=3, dim=-1, largest=False) # 基于距离选择最近点 (B, N, 3) 244 | 245 | # 基于距离进行特征值加权求和 246 | dist_recip = 1.0 / dists.clamp(min=1e-8) 247 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 248 | weight = dist_recip / norm 249 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 250 | # (B, N, D2) -> (B, D2, N) 251 | new_points = interpolated_points.permute(0, 2, 1) 252 | 253 | # 下层特征值与同层特征值拼接,作为扩展后的特征值 (B, D2, N) -> (B, D1+D2, N) -> (B, D, N) 254 | new_points = torch.cat((points1, new_points), dim=1) 255 | new_points = self.mlp_modules(new_points) 256 | return new_points 257 | 258 | 259 | class Head(nn.Module): 260 | """分类头 & 分割头""" 261 | def __init__(self, in_channel, mlp, num_class, task_type): 262 | """ 263 | :param in_channel: 特征维度的输入值 264 | :param mlp: mlp的通道维度数 265 | :param num_class: 输出类别的数量 266 | """ 267 | super(Head, self).__init__() 268 | mlp.append(num_class) 269 | self.mlp_modules = build_mlp(in_channel=in_channel, channel_list=mlp, dim=1, 270 | drop_last_norm_act=True, dropout=True) 271 | self.task_type = task_type 272 | 273 | def forward(self, points_fea): 274 | """ 275 | :param points_fea: (B, C, N) 点云特征 276 | :return: (B, num_class, N) 点云特征 277 | """ 278 | if self.task_type == 'classification': 279 | points_fea = torch.max(points_fea, dim=-1, keepdim=True)[0] # (B, C, N) -> (B, C, 1) 280 | points_cls = self.mlp_modules(points_fea) 281 | return points_cls 282 | 283 | -------------------------------------------------------------------------------- /FeatureExtractorPart/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def build_mlp(in_channel, channel_list, dim=2, bias=False, drop_last_act=False, 6 | drop_last_norm_act=False, dropout=False): 7 | """ 8 | 构造基于n dim 1x1卷积的mlp 9 | :param in_channel: 特征维度的输入值 10 | :param channel_list: mlp各层的输出通道维度数 11 | :param dim: 维度,1或2 12 | :param bias: 卷积层是否添加bias,一般BN前的卷积层不使用bias 13 | :param drop_last_act: 是否去除最后一层激活函数 14 | :param drop_last_norm_act: 是否去除最后一层标准化层和激活函数 15 | :param dropout: 是否添加dropout层 16 | :return: 17 | """ 18 | # 解析参数获取相应卷积层、归一化层、激活函数 19 | if dim == 1: 20 | Conv = nn.Conv1d 21 | NORM = nn.BatchNorm1d 22 | else: 23 | Conv = nn.Conv2d 24 | NORM = nn.BatchNorm2d 25 | ACT = nn.ReLU 26 | 27 | # 根据通道数构建mlp 28 | mlp = [] 29 | for i, channel in enumerate(channel_list): 30 | if dropout and i > 0: 31 | mlp.append(nn.Dropout(0.5, inplace=False)) 32 | # 每层为conv-bn-relu 33 | mlp.append(Conv(in_channels=in_channel, out_channels=channel, kernel_size=1, bias=bias)) 34 | mlp.append(NORM(channel)) 35 | mlp.append(ACT(inplace=True)) 36 | if i < len(channel_list) - 1: 37 | in_channel = channel 38 | 39 | if drop_last_act: 40 | mlp = mlp[:-1] 41 | elif drop_last_norm_act: 42 | mlp = mlp[:-2] 43 | mlp[-1] = Conv(in_channels=in_channel, out_channels=channel, kernel_size=1, bias=True) 44 | 45 | return nn.Sequential(*mlp) 46 | 47 | 48 | def coordinate_distance(src, dst): 49 | """ 50 | 计算两个点集的各点间距 51 | !!!使用半精度运算或自动混合精度时[不要]使用化简的方法,否则会出现严重的浮点误差 52 | :param src: (B, M, C) C为坐标 53 | :param dst: (B, N, C) C为坐标 54 | :return: (B, M, N) 55 | """ 56 | B, M, _ = src.shape 57 | _, N, _ = dst.shape 58 | dist = -2 * torch.matmul(src, dst.transpose(1, 2)) 59 | dist += torch.sum(src ** 2, -1).view(B, M, 1) 60 | dist += torch.sum(dst ** 2, -1).view(B, 1, N) 61 | 62 | # dist = torch.sum((src.unsqueeze(2) - dst.unsqueeze(1)).pow(2), dim=-1) 63 | return dist 64 | 65 | 66 | def index_points(points, idx): 67 | """ 68 | 跟据采样点索引获取其原始点云xyz坐标等信息 69 | :param points: (B, N, 3+) 原始点云 70 | :param idx: (B, S)/(B, S, G) 采样点索引,S为采样点数量,G为每个采样点grouping的点数 71 | :return: (B, S, 3+)/(B, S, G, 3+) 获取了原始点云信息的采样点 72 | """ 73 | B = points.shape[0] 74 | view_shape = list(idx.shape) 75 | view_shape[1:] = [1] * (len(view_shape) - 1) 76 | repeat_shape = list(idx.shape) 77 | repeat_shape[0] = 1 78 | batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape).repeat(repeat_shape) 79 | new_points = points[batch_indices, idx, :] 80 | return new_points 81 | 82 | 83 | def farthest_point_sample(xyz, npoint): 84 | """ 85 | 最远点采样 86 | 随机选择一个初始点作为采样点,循环的将与当前采样点距离最远的点当作下一个采样点,直至满足采样点的数量需求 87 | :param xyz: (B, N, 3+) 原始点云 88 | :param npoint: 采样点数量 89 | :return: (B, npoint) 采样点索引 90 | """ 91 | device = xyz.device 92 | B, N, C = xyz.shape 93 | npoint = min(npoint, N) 94 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 95 | distance = torch.ones(B, N).to(device) * 1e10 # 每个点与最近采样点的最小距离 96 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) # 随机选取初始点 97 | 98 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 99 | for i in range(npoint): 100 | centroids[:, i] = farthest 101 | centroid = xyz[batch_indices, farthest, :].view(B, 1, -1) # [bs, 1, coor_dim] 102 | dist = torch.nn.functional.pairwise_distance(xyz, centroid) 103 | mask = dist < distance 104 | distance[mask] = dist[mask] 105 | farthest = torch.max(distance, -1)[1] 106 | return centroids 107 | 108 | 109 | def query_hybrid(radius, nsample, xyz, new_xyz): 110 | """ 111 | 基于采样点进行KNN与ball query混合的grouping 112 | :param radius: grouping半径 113 | :param nsample: group内点云数量 114 | :param xyz: (B, N, 3) 原始点云 115 | :param new_xyz: (B, S, 3) 采样点 116 | :return: (B, S, nsample) 每个采样点grouping的点云索引 117 | """ 118 | B, N, C = xyz.shape 119 | _, S, _ = new_xyz.shape 120 | 121 | dist = coordinate_distance(new_xyz, xyz) # 每个采样点与其他点的距离的平方 122 | dist, group_idx = torch.topk(dist, k=nsample, dim=-1, largest=False) # 基于距离选择最近的作为采样点 123 | radius = radius ** 2 124 | mask = dist > radius # 距离较远的点替换为距离最近的点 125 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 126 | group_idx[mask] = group_first[mask] 127 | 128 | return group_idx 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /Fig/test_acc.svg: -------------------------------------------------------------------------------- 1 | 0.8650.870.8750.880.8850.890.8950.90.9050.910.9150.920.9250.93-50050100150200250300350400450500550600650 -------------------------------------------------------------------------------- /Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LabelSmoothingCE(nn.Module): 7 | """ 8 | 带有标签平滑的交叉熵损失 9 | """ 10 | 11 | def __init__(self, smoothing: float = 0.1): 12 | super().__init__() 13 | self.smoothing = smoothing 14 | self.confidence = 1 - smoothing 15 | 16 | def forward(self, pred: torch.Tensor, gt: torch.Tensor): 17 | """ 18 | :param pred: (B, num_class, N) 分类时N=1,分割时N等于点云数量 19 | :param gt: (B, N) 20 | :return: loss, acc 21 | """ 22 | B, cls, N = pred.shape 23 | 24 | # (B, cls, N) -> (B*N, cls) 25 | pred = pred.permute(0, 2, 1).reshape(B*N, cls) 26 | gt = gt.reshape(B*N,) 27 | 28 | acc = torch.sum(torch.max(pred, dim=-1)[1] == gt) / (B * N) 29 | 30 | logprobs = F.log_softmax(pred, dim=-1) 31 | loss_pos = -logprobs.gather(dim=-1, index=gt.unsqueeze(1)).squeeze(1) 32 | loss_smoothing = -logprobs.mean(-1) 33 | loss = self.confidence * loss_pos + self.smoothing * loss_smoothing 34 | 35 | return loss.mean(), acc.item() 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /Model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from FeatureExtractorPart.pointnext import Stage, FeaturePropagation, Head 3 | 4 | 5 | class PointNeXt(nn.Module): 6 | """ 7 | PointNeXt语义分割模型特征提取部分 8 | """ 9 | 10 | def __init__(self, cfg): 11 | super().__init__() 12 | self.type = cfg['type'] 13 | self.num_class = cfg['num_class'] 14 | self.coor_dim = cfg['coor_dim'] 15 | self.normal = cfg['normal'] 16 | width = cfg['width'] 17 | 18 | self.mlp = nn.Conv1d(in_channels=self.coor_dim + self.coor_dim * self.normal, 19 | out_channels=width, kernel_size=1) 20 | self.stage = nn.ModuleList() 21 | 22 | for i in range(len(cfg['npoint'])): 23 | self.stage.append( 24 | Stage( 25 | npoint=cfg['npoint'][i], radius_list=cfg['radius_list'][i], nsample_list=cfg['nsample_list'][i], 26 | in_channel=width, expansion=cfg['expansion'], coor_dim=self.coor_dim 27 | ) 28 | ) 29 | width *= 2 30 | 31 | if self.type == 'segmentation': 32 | self.decoder = nn.ModuleList() 33 | for i in range(len(cfg['npoint'])): 34 | self.decoder.append( 35 | FeaturePropagation(in_channel=[width, width // 2], mlp=[width // 2, width // 2], 36 | coor_dim=self.coor_dim) 37 | ) 38 | width = width // 2 39 | 40 | self.head = Head(in_channel=width, mlp=cfg['head'], num_class=self.num_class, task_type=self.type) 41 | 42 | def forward(self, x): 43 | l0_xyz, l0_points = x[:, :self.coor_dim, :], x[:, :self.coor_dim + self.coor_dim * self.normal, :] 44 | l0_points = self.mlp(l0_points) 45 | 46 | record = [[l0_xyz, l0_points]] 47 | for stage in self.stage: 48 | record.append(list(stage(*record[-1]))) 49 | if self.type == 'segmentation': 50 | for i, decoder in enumerate(self.decoder): 51 | record[-i-2][1] = decoder(record[-i-2][0], record[-i-1][0], record[-i-2][1], record[-i-1][1]) 52 | points_cls = self.head(record[0][1]) 53 | 54 | else: 55 | points_cls = self.head(record[-1][1]) 56 | 57 | return points_cls 58 | 59 | -------------------------------------------------------------------------------- /Parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | 5 | 6 | def str_to_bool(s): 7 | if (s.lower() == 'true'): 8 | return True 9 | elif (s.lower() == 'false'): 10 | return False 11 | else: 12 | raise TypeError(f'str {s} can not convert to bool.') 13 | 14 | 15 | parser = argparse.ArgumentParser(description='Feature Extractor for Alignment') 16 | 17 | # 基本参数 18 | parser.add_argument('--name', default='PointNeXt', type=str, 19 | help='Name of the model') 20 | parser.add_argument('--mode', default='train', type=str, 21 | choices=['train', 'test'], 22 | help='train or test') 23 | 24 | # 数据参数 25 | parser.add_argument('--dataset', default='ModelNet40', type=str, 26 | choices=['ModelNet40'], 27 | help='Dataset name') 28 | parser.add_argument('--dataset_path', default=None, type=str, 29 | help='Path to checkpoint file') 30 | 31 | # 训练参数 32 | parser.add_argument('--batch_size', '-bs', default=64, type=int, 33 | help='Batch size for training') 34 | parser.add_argument('--num_epochs', '-ne', default=600, type=int, 35 | help='Number of epochs for training') 36 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, 37 | help='initial learning rate for optimizer') 38 | parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float, 39 | help='Weight decay for optimizer') 40 | parser.add_argument('--momentum', default=0.9, type=float, 41 | help='Momentum value for optimizer') 42 | parser.add_argument('--data_aug', default='basic', type=str, 43 | help='configuration for data augment') 44 | parser.add_argument('--optimizer', default='AdamW', type=str, 45 | help='optimizer') 46 | parser.add_argument('--scheduler', default='cosine', type=str, 47 | help='scheduler') 48 | 49 | # 模型参数 50 | parser.add_argument('--model_cfg', default='basic_c', type=str, 51 | help='Configuration for building pcd backbone') 52 | 53 | # 日志参数 54 | parser.add_argument('--eval_cycle', '-ec', default=5, type=int, 55 | help='Evaluate every n epochs') 56 | parser.add_argument('--log_cycle', '-lc', default=320, type=int, 57 | help='Log every n steps') 58 | parser.add_argument('--save_cycle', '-sc', default=1, type=int, 59 | help='Save every n epochs') 60 | parser.add_argument('--checkpoint', '-cp', default='', type=str, 61 | help='Checkpoint file name') 62 | 63 | # 设备参数 64 | parser.add_argument('--num_workers', '-nw', default=32, type=int, 65 | help='Number of workers used in dataloader') 66 | parser.add_argument('--use_cuda', default='True', type=str_to_bool, 67 | help='Using cuda to run') 68 | parser.add_argument('--auto_cast', default='False', type=str_to_bool, 69 | help='Using torch.cuda.amp.autocast to accelerate computing') 70 | parser.add_argument('--gpu_index', default='0', type=str, 71 | help='Index of gpu') 72 | 73 | '''数据增强配置参数''' 74 | aug_None = {} 75 | aug_basic = {'RT': [0, 0.5, 0, 0.1, 1], 'jitter': [0, 0.001, 1], 'random_drop': 0.5, 'random_sample': True} 76 | DATA_AUG_CONFIG = {'None': aug_None, 'basic': aug_basic} 77 | 78 | '''模型配置参数''' 79 | basic_c = { 80 | 'type': 'classification', 81 | 'num_class': 40, 82 | 'max_input': 2048, # 输入点最大数量 83 | 'npoint': [512, 128, 32, 8], 84 | 'radius_list': [[0.1, 0.2], [0.2, 0.4, 0.4], [0.4, 0.8], [0.8, 1.6]], 85 | 'nsample_list': [[16, 16], [16, 16, 16], [16, 16], [8, 8]], 86 | 'coor_dim': 3, 87 | 'width': 32, 88 | 'expansion': 4, 89 | 'normal': True, 90 | 'head': [512, 256] 91 | } 92 | 93 | MODEL_CONFIG = { 94 | 'basic_c': basic_c, 95 | } 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PointNeXt with pure python 2 | 3 | --- 4 | 5 | ## 简介 6 | 7 | 本项目为PointNeXt复现,原论文地址:[[arXiv]](https://arxiv.org/abs/2206.04670) 8 | 基本环境:Python 3.8 + PyTorch 1.12.1 9 | > ModelNet40点云分类任务的测试精度 10 | > 11 | > 12 | 13 | 14 | ### 与官方源码的区别 15 | - 重构了整体框架,重写了所有代码,去除部分可选参数,更加轻量、易读、易修改 16 | - 代码100%由python编写,无需编译cuda算子,兼容性更强但推理速度有所下降 17 | - group方法由ball-query改为ball-query + KNN的混合query方式 18 | 19 | --- 20 | 21 | ## 使用方法 22 | 23 | ### 训练与测试 24 | 1. 训练 `python main.py --gpu_index 0` 25 | 2. 测试 `python main.py --gpu_index 0 --mode test --checkpoint PATH/TO/YOUR/CHECKPOINT` 26 | 27 | ### Tensorboard 28 | 1. 控制台执行`tensorboard --logdir=runs --port 6006` 29 | 2. 浏览器访问`http://localhost:6006` 30 | 31 | 32 | -------------------------------------------------------------------------------- /Trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger(__name__) 3 | logger.setLevel(logging.INFO) 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from torch.cuda.amp import autocast as autocast 7 | from torch.utils.tensorboard import SummaryWriter 8 | import os 9 | from tqdm import tqdm 10 | from utils import MetricLogger, fakecast 11 | from utils import show_pcd 12 | 13 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 14 | torch.manual_seed(3407) 15 | 16 | 17 | class Trainer: 18 | """ 19 | 训练器,输入待训练的模型、参数,封装训练过程 20 | """ 21 | def __init__(self, args, model, optimizer, scheduler, criterion, dataset, mode): 22 | self.args = args 23 | self.model = model 24 | self.optimizer = optimizer 25 | self.scheduler = scheduler 26 | self.criterion = criterion 27 | self.dataset = dataset 28 | self.mode = mode 29 | self.dataloader = None 30 | self.epoch = 1 31 | self.step = 1 32 | 33 | self.epoch_metric_logger = MetricLogger() 34 | 35 | # 恢复检查点 36 | if self.args.checkpoint != '': 37 | checkpoint = torch.load(self.args.checkpoint, map_location=self.args.device) 38 | # Load model 39 | if 'model' in checkpoint: 40 | model_state_dict = checkpoint['model'] 41 | if model_state_dict.keys() != self.model.state_dict().keys(): 42 | logger.info("Load model Failed, keys not match..") 43 | else: 44 | self.model.load_state_dict(model_state_dict) 45 | logger.info("Load model state") 46 | if 'optimizer' in checkpoint: 47 | self.optimizer.load_state_dict(checkpoint['optimizer']) 48 | logger.info("Load optimizer state") 49 | if 'scheduler' in checkpoint: 50 | self.scheduler.load_state_dict(checkpoint['scheduler']) 51 | logger.info("Load scheduler state") 52 | 53 | if 'epoch' in checkpoint: 54 | self.epoch = checkpoint['epoch'] + 1 55 | logger.info(f"Load epoch, current = {self.epoch}") 56 | 57 | if 'step' in checkpoint: 58 | self.step = checkpoint['step'] + 1 59 | logger.info(f"Load step, current = {self.step}") 60 | logger.info(f'Load checkpoint complete: \'{self.args.checkpoint}\'') 61 | else: 62 | logger.info(f'{mode} with a initial model') 63 | 64 | if self.args.auto_cast: 65 | self.cast = autocast 66 | else: 67 | self.cast = fakecast 68 | 69 | # 创建训练、测试结果保存目录 70 | self.log = f'{self.args.name}_model={self.args.model_cfg}_ds={self.args.dataset}_aug={self.args.data_aug}_' \ 71 | f'lr={self.args.lr}_wd={self.args.wd}_bs={self.args.batch_size}_' \ 72 | f'{self.args.optimizer}_{self.args.scheduler}' 73 | if self.mode == 'train': 74 | self.save_root = os.path.join('./result_train', self.log) 75 | elif self.mode == 'test': 76 | self.save_root = os.path.join('./result_test', self.log) 77 | else: 78 | raise ValueError 79 | os.makedirs(self.save_root, exist_ok=True) 80 | logger.info(f'save root = \'{self.save_root}\'') 81 | logger.info(f'run in {self.args.device}') 82 | 83 | def run(self): 84 | if self.mode == 'train': 85 | self.train() 86 | elif self.mode == 'test': 87 | self.test() 88 | 89 | def train(self): 90 | # tensorboard可视化训练过程,记录训练时的相关数据,使用指令:tensorboard --logdir=runs 91 | self.writer = SummaryWriter(os.path.join('./runs', self.log)) 92 | 93 | self.dataloader = DataLoader(dataset=self.dataset, 94 | batch_size=self.args.batch_size, 95 | num_workers=self.args.num_workers, 96 | shuffle=True, 97 | pin_memory=True, 98 | drop_last=False) 99 | 100 | start_epoch = self.epoch 101 | for ep in range(start_epoch, self.args.num_epochs + 1): 102 | # 记录日志 103 | self.writer.add_scalar("learning_rate", self.optimizer.param_groups[0]['lr'], ep) 104 | 105 | # 单轮训练 106 | self.train_one_epoch() 107 | 108 | # 动态学习率 109 | self.scheduler.step() 110 | 111 | # 定期保存 112 | if self.epoch % self.args.save_cycle == 0: 113 | self.save() 114 | 115 | # 定期验证 116 | if self.epoch % self.args.eval_cycle == 0: 117 | self.test() 118 | 119 | self.epoch += 1 120 | 121 | self.save(finish=True) 122 | 123 | def train_one_epoch(self): 124 | self.model.train() 125 | self.dataset.train() 126 | epoch_loss, epoch_acc = [], [] 127 | count = self.args.log_cycle // self.args.batch_size 128 | 129 | loop = tqdm(self.dataloader, total=len(self.dataloader), leave=False) 130 | loop.set_description('train') 131 | for data in loop: 132 | pcd, label = data 133 | # show_pcd([pcd[0].T], normal=True) 134 | pcd, label = pcd.to(self.args.device, non_blocking=True), label.to(self.args.device, non_blocking=True) 135 | 136 | # 前向传播与反向传播 137 | with self.cast(): 138 | points_cls = self.model(pcd) 139 | loss, acc = self.criterion(points_cls, label) 140 | 141 | self.epoch_metric_logger.add_metric('loss', loss.item()) 142 | self.epoch_metric_logger.add_metric('acc', acc) 143 | 144 | self.optimizer.zero_grad() 145 | loss.backward() 146 | self.optimizer.step() 147 | 148 | # 记录日志 149 | loop.set_postfix(train_loss=loss.item(), acc=f'{acc * 100:.2f}%') 150 | epoch_loss.append(loss.item()) 151 | epoch_acc.append(acc) 152 | count -= 1 153 | if count <= 0: 154 | count = self.args.log_cycle // self.args.batch_size 155 | self.writer.add_scalar("train/step_loss", sum(epoch_loss[-count:]) / count, self.step) 156 | self.writer.add_scalar("train/step_acc", sum(epoch_acc[-count:]) / count, self.step) 157 | self.step += 1 158 | self.writer.add_scalar("train/epoch_loss", sum(epoch_loss) / len(epoch_loss), self.epoch) 159 | self.writer.add_scalar("train/epoch_acc", sum(epoch_acc) / len(epoch_acc), self.epoch) 160 | logger.info(f'Train Epoch {self.epoch:>4d} ' + self.epoch_metric_logger.tostring()) 161 | self.epoch_metric_logger.clear() 162 | 163 | def test(self): 164 | self.model.eval() 165 | self.dataset.eval() 166 | if self.mode == 'test': 167 | self.epoch -= 1 168 | self.dataset.transforms.set_padding(False) 169 | eval_dataloader = DataLoader(dataset=self.dataset, 170 | batch_size=1, 171 | num_workers=min(self.args.num_workers, 16), 172 | pin_memory=True, 173 | drop_last=False, 174 | shuffle=False) 175 | 176 | loop = tqdm(eval_dataloader, total=len(eval_dataloader), leave=False) 177 | loop.set_description('eval') 178 | for data in loop: 179 | pcd, label = data 180 | # show_pcd([pcd[0].T], normal=True) 181 | pcd, label = pcd.to(self.args.device, non_blocking=True), label.to(self.args.device, non_blocking=True) 182 | 183 | # 前向传播 184 | with torch.no_grad(): 185 | points_cls = self.model(pcd) 186 | loss, acc = self.criterion(points_cls, label) 187 | 188 | self.epoch_metric_logger.add_metric('loss', loss.item()) 189 | self.epoch_metric_logger.add_metric('acc', acc) 190 | 191 | loop.set_postfix(eval_loss=loss.item()) 192 | 193 | self.dataset.transforms.set_padding(True) 194 | print('Eval :', self.epoch_metric_logger.tostring()) 195 | metric = self.epoch_metric_logger.get_average_value() 196 | 197 | if self.mode == 'train': 198 | self.writer.add_scalar("eval/loss", metric['loss'], self.epoch) 199 | self.writer.add_scalar("eval/acc", metric['acc'], self.epoch) 200 | self.epoch_metric_logger.clear() 201 | 202 | def save(self, finish=False): 203 | model_state_dict = self.model.state_dict() 204 | if not finish: 205 | state = { 206 | 'model': model_state_dict, 207 | 'optimizer': self.optimizer.state_dict(), 208 | 'scheduler': self.scheduler.state_dict(), 209 | 'epoch': self.epoch, 210 | 'step': self.step, 211 | } 212 | file_path = os.path.join(self.save_root, f'{self.args.name}_{self.args.dataset}_epoch{self.epoch}.pth') 213 | else: 214 | state = { 215 | 'model': model_state_dict, 216 | } 217 | file_path = os.path.join(self.save_root, f'{self.args.name}_{self.args.dataset}.pth') 218 | 219 | torch.save(state, file_path) 220 | 221 | -------------------------------------------------------------------------------- /Transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.transforms import Compose 4 | import math 5 | import random 6 | from utils import voxel_down_sample 7 | from FeatureExtractorPart.utils import index_points, farthest_point_sample 8 | 9 | 10 | class PCDPretreatment(nn.Module): 11 | """ 12 | 点云预处理与部分数据增强 13 | """ 14 | 15 | def __init__(self, num=2048, padding=True, down_sample='fps', mode='train', normal=True, 16 | data_augmentation=None, random_drop=0, resampling=False): 17 | super().__init__() 18 | self.num = num 19 | self.padding = padding 20 | self.normal = normal 21 | self.random_drop = random_drop 22 | self.resampling = resampling 23 | self.mode = mode 24 | self.sampling = down_sample 25 | self.set_sampling(down_sample) 26 | self.data_aug = data_augmentation if data_augmentation is not None else nn.Identity() 27 | 28 | def forward(self, pcd): 29 | """ 30 | :param pcd: (N, 3+) 点云矩阵 31 | :return: (3+, N) 32 | """ 33 | # 坐标归一化 34 | pcd_xyz = pcd[:, :3] 35 | pcd_xyz = pcd_xyz - pcd_xyz.mean(dim=0, keepdim=True) 36 | dis = torch.norm(pcd_xyz, dim=1) 37 | max_dis = dis.max() 38 | pcd_xyz /= max_dis 39 | 40 | # 法线 41 | if self.normal: 42 | pcd[:, :3] = pcd_xyz 43 | else: 44 | pcd = pcd_xyz 45 | 46 | # 随机丢弃一定比率的点 47 | if self.random_drop > 0 and self.mode == 'train': 48 | drop_ratio = random.uniform(0, self.random_drop) 49 | remain_points = torch.rand(size=(pcd.shape[0],), device=pcd.device) >= drop_ratio 50 | pcd = pcd[remain_points] 51 | 52 | # 点云数量统一化 53 | if pcd.shape[0] < self.num and self.padding: 54 | padding = torch.zeros(size=(self.num - pcd.shape[0], pcd.shape[1]), device=pcd.device) 55 | padding[:, 2] = -10 56 | pcd = torch.cat((pcd, padding), dim=0) 57 | elif pcd.shape[0] > self.num: 58 | pcd = self.down_sample(pcd) 59 | pcd = pcd.T 60 | 61 | # 点云数量无关的数据增强 62 | if self.mode == 'train': 63 | if self.normal: 64 | pcd[:3, :] = self.data_aug(pcd[:3, :]) 65 | else: 66 | pcd[:6, :] = self.data_aug(pcd[:6, :]) 67 | 68 | return pcd 69 | 70 | def min_dis(self, pcd): 71 | """ 72 | 基于距离的下采样,保留距离车辆最近的点 73 | """ 74 | dis = torch.norm(pcd[:, :3], p=2, dim=1) # 计算点云距离车辆的直线距离 75 | _, sorted_ids = torch.sort(dis) 76 | sorted_ids = sorted_ids[:self.num] # 只保留距离最近的num个点 77 | pcd = pcd[sorted_ids] 78 | return pcd 79 | 80 | def random(self, pcd): 81 | """ 82 | 随机点云下采样 83 | """ 84 | downsample_ids = torch.randperm(pcd.shape[0])[:self.num] 85 | pcd = pcd[downsample_ids] 86 | return pcd 87 | 88 | def voxel_down_sample(self, pcd): 89 | return voxel_down_sample(pcd, voxel_size=0.01, num=self.num, padding=self.padding) 90 | 91 | def fps(self, pcd): 92 | sample_ids = farthest_point_sample(pcd[:, :3].unsqueeze(0), self.num) 93 | pcd = index_points(pcd.unsqueeze(0), sample_ids)[0] 94 | return pcd 95 | 96 | def set_padding(self, option: bool): 97 | self.padding = option 98 | 99 | def set_sampling(self, sampling): 100 | if sampling == 'dis': 101 | self.down_sample = self.min_dis 102 | elif sampling == 'voxel': 103 | self.down_sample = self.voxel_down_sample 104 | elif sampling == 'random': 105 | self.down_sample = self.random 106 | elif sampling == 'fps': 107 | self.down_sample = self.fps 108 | elif sampling == 'identical': 109 | self.down_sample = lambda x: x 110 | else: 111 | raise ValueError 112 | 113 | def set_mode(self, mode): 114 | self.mode = mode 115 | if self.resampling: 116 | if mode == 'train': 117 | self.down_sample = self.random 118 | else: 119 | self.set_sampling(self.sampling) 120 | 121 | 122 | class RandomRT(nn.Module): 123 | def __init__(self, r_mean=0, r_std=0.5, t_mean=0, t_std=0.1, p=1) -> None: 124 | super().__init__() 125 | self.r_mean = r_mean 126 | self.r_std = r_std 127 | self.t_mean = t_mean 128 | self.t_std = t_std 129 | self.p = p 130 | 131 | def forward(self, input): 132 | """ 133 | pcd: Tensor [3+, n] 134 | """ 135 | if random.random() > self.p: 136 | return input 137 | pcd = input 138 | 139 | # 生成三方向的随机角度,得到各方位的旋转矩阵,最后整合为总体旋转矩阵 140 | z = (torch.rand(size=(1,)) - 0.5) * 2 * self.r_std 141 | x = (torch.rand(size=(1,)) - 0.5) * 2 * self.r_std 142 | y = (torch.rand(size=(1,)) - 0.5) * 2 * self.r_std 143 | 144 | R_x = torch.tensor([[1, 0, 0], 145 | [0, math.cos(x), -math.sin(x)], 146 | [0, math.sin(x), math.cos(x)]]) 147 | R_y = torch.tensor([[math.cos(y), 0, math.sin(y)], 148 | [0, 1, 0], 149 | [-math.sin(y), 0, math.cos(y)]]) 150 | R_z = torch.tensor([[math.cos(z), -math.sin(z), 0], 151 | [math.sin(z), math.cos(z), 0], 152 | [0, 0, 1]]) 153 | 154 | R_aug = R_x @ R_y @ R_z 155 | R_aug.to(pcd.device) 156 | 157 | if self.t_std > 0: 158 | T_aug = (torch.rand(size=(3, 1)) - 0.5) * 2 * self.t_std 159 | else: 160 | T_aug = torch.zeros(size=(3, 1), device=pcd.device) 161 | 162 | pcd[:3, :] = R_aug @ pcd[:3, :] + T_aug 163 | if pcd.shape[0] >= 6: 164 | pcd[3:6, :] = R_aug @ pcd[3:6, :] 165 | 166 | return pcd 167 | 168 | 169 | class RandomPosJitter(nn.Module): 170 | """点云位置随机抖动""" 171 | def __init__(self, mean=0, std=0.01, p=1): 172 | super().__init__() 173 | self.mean = mean 174 | self.std = std 175 | self.p = p 176 | 177 | def forward(self, input): 178 | """ 179 | :param input: 180 | pcd: Tensor [3+, n] 181 | :return: 182 | """ 183 | if random.random() > self.p: 184 | return input 185 | pcd = input 186 | pos_jitter = (torch.rand(size=(3, pcd.shape[1])) - 0.5) * 2 * self.std 187 | pcd[:3, :] += pos_jitter 188 | return pcd 189 | 190 | 191 | def get_data_augment(data_aug): 192 | aug_list, data_augment, random_sample, random_drop = [], None, False, 0 193 | if 'RT' in data_aug and data_aug['RT'] is not None: 194 | aug_list.append(RandomRT(*data_aug['RT'])) 195 | if 'jitter' in data_aug and data_aug['jitter'] is not None: 196 | aug_list.append(RandomPosJitter(*data_aug['jitter'])) 197 | if 'random_sample' in data_aug and data_aug['random_sample'] is not None: 198 | random_sample = data_aug['random_sample'] 199 | if 'random_drop' in data_aug and data_aug['random_drop'] is not None: 200 | random_drop = data_aug['random_drop'] 201 | 202 | if len(aug_list) == 1: 203 | data_augment = aug_list[0] 204 | elif len(aug_list) > 1: 205 | data_augment = Compose(aug_list) 206 | return data_augment, random_sample, random_drop 207 | -------------------------------------------------------------------------------- /dataset/ModelNet40.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | import torch 6 | 7 | 8 | class ModelNet40(Dataset): 9 | def __init__(self, dataroot, transforms=None): 10 | super(ModelNet40, self).__init__() 11 | self.dataroot = dataroot 12 | self.transforms = transforms 13 | self.train_set = list(glob(os.path.join(self.dataroot, 'train', '*.npz'))) 14 | # self.train_set = list(glob(os.path.join(self.dataroot, 'test', '*.npz'))) 15 | self.test_set = list(glob(os.path.join(self.dataroot, 'test', '*.npz'))) 16 | self.training = True 17 | self.pcd = self.train_set 18 | 19 | self.LABEL_DICT = {'airplane': 0, 'bathtub': 1, 'bed': 2, 'bench': 3, 'bookshelf': 4, 'bottle': 5, 'bowl': 6, 20 | 'car': 7, 'chair': 8, 'cone': 9, 'cup': 10, 'curtain': 11, 'desk': 12, 'door': 13, 21 | 'dresser': 14, 'flower': 15, 'glass': 16, 'guitar': 17, 'keyboard': 18, 'lamp': 19, 22 | 'laptop': 20, 'mantel': 21, 'monitor': 22, 'night': 23, 'person': 24, 'piano': 25, 23 | 'plant': 26, 'radio': 27, 'range': 28, 'sink': 29, 'sofa': 30, 'stairs': 31, 'stool': 32, 24 | 'table': 33, 'tent': 34, 'toilet': 35, 'tv': 36, 'vase': 37, 'wardrobe': 38, 'xbox': 39} 25 | # LABEL = set() 26 | # for x in self.test_set: 27 | # LABEL.add(os.path.split(x)[1].split('_')[0]) 28 | # LABEL = sorted(list(LABEL)) 29 | # LABEL_DICT = {k: v for v, k in enumerate(LABEL)} 30 | 31 | print( 32 | f'Load ModelNet40 done, load {len(self.train_set)} items for training and {len(self.test_set)} items for testing.' 33 | ) 34 | 35 | def __len__(self): 36 | return len(self.pcd) 37 | 38 | def __getitem__(self, index): 39 | # index = 0 40 | x = self.pcd[index] 41 | label = os.path.split(x)[1].split('_')[0] 42 | label = self.LABEL_DICT[label] 43 | with np.load(x) as npz: 44 | pcd, norm = npz['pcd'], npz['norm'] 45 | x = torch.from_numpy(np.concatenate([pcd, norm], axis=1)).float() 46 | if self.transforms is not None: 47 | x = self.transforms(x) 48 | return x, torch.tensor([label], dtype=torch.long) 49 | 50 | def train(self): 51 | self.training = True 52 | self.pcd = self.train_set 53 | self.transforms.set_mode('train') 54 | 55 | def eval(self): 56 | self.training = False 57 | self.pcd = self.test_set 58 | self.transforms.set_mode('eval') 59 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from Parameters import * 2 | args = parser.parse_args() 3 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_index 4 | sys.path.insert(1, os.path.dirname(os.path.abspath(__name__))) 5 | import logging 6 | logger = logging.getLogger(__name__) 7 | logger.setLevel(logging.INFO) 8 | import torch 9 | from Model import PointNeXt 10 | from dataset.ModelNet40 import ModelNet40 11 | from Loss import LabelSmoothingCE 12 | from Transforms import PCDPretreatment, get_data_augment 13 | from Trainer import Trainer 14 | from utils import IdentityScheduler 15 | 16 | 17 | def main(): 18 | # 解析参数 19 | if args.use_cuda and torch.cuda.is_available(): 20 | args.device = torch.device('cuda') 21 | gpus = list(range(torch.cuda.device_count())) 22 | torch.cuda.set_device('cuda:{}'.format(gpus[0])) 23 | else: 24 | args.device = torch.device('cpu') 25 | 26 | if sys.platform == 'darwin': 27 | args.num_workers = 0 28 | args.batch_size = 2 29 | 30 | model_cfg = MODEL_CONFIG[args.model_cfg] 31 | max_input = model_cfg['max_input'] 32 | normal = model_cfg['normal'] 33 | 34 | if args.optimizer.lower() == 'adamw': 35 | Optimizer = torch.optim.AdamW 36 | else: 37 | args.optimizer = 'Adam' 38 | Optimizer = torch.optim.Adam 39 | 40 | if args.scheduler.lower() == 'identity': 41 | Scheduler = IdentityScheduler 42 | else: 43 | args.scheduler = 'cosine' 44 | Scheduler = torch.optim.lr_scheduler.CosineAnnealingLR 45 | 46 | # 数据变换、加载数据集 47 | logger.info('Prepare Data') 48 | '''数据变换、加载数据集''' 49 | data_augment, random_sample, random_drop = get_data_augment(DATA_AUG_CONFIG[args.data_aug]) 50 | transforms = PCDPretreatment(num=max_input, down_sample='fps', normal=normal, 51 | data_augmentation=data_augment, random_drop=random_drop, resampling=random_sample) 52 | 53 | '''Prepare dataset''' 54 | if args.dataset_path is None or args.dataset_path == 'default': 55 | if args.dataset == 'ModelNet40': 56 | default_dataset_path_list = [r'../../Dataset/ModelNet40_points', 57 | r'/Users/dingziheng/dataset/ModelNet40_points', 58 | r'/root/dataset/ModelNet40_points'] 59 | else: 60 | raise ValueError 61 | for path in default_dataset_path_list: 62 | if os.path.exists(path): 63 | args.dataset_path = path 64 | break 65 | else: # this is for-else block, indent is not missing 66 | raise FileNotFoundError(f'Dataset path not found.') 67 | logger.info(f'Load default dataset from {args.dataset_path}') 68 | if args.dataset == 'ModelNet40': 69 | dataset = ModelNet40(dataroot=args.dataset_path, transforms=transforms) 70 | else: 71 | raise ValueError 72 | 73 | # 模型与损失函数 74 | logger.info('Prepare Models...') 75 | model = PointNeXt(model_cfg).to(device=args.device) 76 | optimizer = Optimizer(model.parameters(), lr=args.lr, weight_decay=args.wd) 77 | scheduler = Scheduler(optimizer, T_max=args.num_epochs, eta_min=args.lr * 0.001) 78 | criterion = LabelSmoothingCE() 79 | 80 | # 训练器 81 | logger.info('Trainer launching...') 82 | trainer = Trainer( 83 | args=args, 84 | model=model, 85 | optimizer=optimizer, 86 | scheduler=scheduler, 87 | criterion=criterion, 88 | dataset=dataset, 89 | mode=args.mode 90 | ) 91 | trainer.run() 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | print('Done.') 97 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.21.5 2 | open3d==0.15.1 3 | torch==1.12.0 4 | torchvision==0.13.0 5 | tqdm==4.64.0 6 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import open3d as o3d 4 | 5 | 6 | class fakecast: 7 | def __enter__(self): 8 | pass 9 | 10 | def __exit__(self, exc_type, exc_val, exc_tb): 11 | pass 12 | 13 | 14 | class MetricLogger(object): 15 | def __init__(self) -> None: 16 | self.metrics = dict() 17 | 18 | def add_metric(self, name: str, value): 19 | assert (isinstance(value, int) or isinstance(value, float) 20 | or (isinstance(value, np.ndarray) and value.ndim in [1,2]) 21 | or (isinstance(value, torch.Tensor) and value.ndim == [1,2])) 22 | if (name not in self.metrics.keys()): 23 | self.metrics[name] = [value] 24 | else: 25 | self.metrics[name].append(value) 26 | 27 | def clear(self): 28 | self.metrics = dict() 29 | 30 | def tostring(self): 31 | result = self.get_average_value() 32 | s = '' 33 | for name, mean_value in result.items(): 34 | s += f'| {name} = ' 35 | if (isinstance(mean_value, int) or isinstance(mean_value, float)): 36 | s += f'{mean_value:>5.3f} ' 37 | elif (isinstance(mean_value, np.ndarray)): 38 | for v in mean_value: 39 | s += f'{float(v):>5.3f}, ' 40 | elif (isinstance(mean_value, torch.Tensor)): 41 | for v in mean_value: 42 | s += f'{float(v):>5.3f}, ' 43 | return s 44 | 45 | def get_average_value(self): 46 | result = dict() 47 | for name, values in self.metrics.items(): 48 | if (isinstance(values[0], int) or isinstance(values[0], float)): 49 | meanval = sum(values) / len(values) 50 | elif (isinstance(values[0], np.ndarray)): 51 | meanval = np.stack(values, axis=0).mean(0) 52 | elif (isinstance(values[0], torch.tensor)): 53 | meanval = torch.stack(values, dim=0).mean(0) 54 | result[name] = meanval 55 | return result 56 | 57 | 58 | def GetPcdFromNumpy(pcd_np: np.ndarray, color=None): 59 | ''' 60 | convert a numpy.ndarray with shape(xyz+, n) to pointcloud in o3d 61 | ''' 62 | pcd = o3d.geometry.PointCloud() 63 | pcd.points = o3d.utility.Vector3dVector(pcd_np[:3, :].T) 64 | if (color is not None): 65 | pcd.paint_uniform_color(color) 66 | return pcd 67 | 68 | 69 | def show_pcd(pcds, colors=None, normal=False, window_name="PCD"): 70 | ''' 71 | pcds: List(ArrayLike), points to be shown, shape (K, xyz+) 72 | colors: List[Tuple], color list, shape (r,g,b) scaled 0~1 73 | ''' 74 | import open3d as o3d 75 | # 创建窗口对象 76 | vis = o3d.visualization.Visualizer() 77 | # 设置窗口标题 78 | vis.create_window(window_name=window_name) 79 | # 设置点云大小 80 | # vis.get_render_option().point_size = 1 81 | # 设置颜色背景为黑色 82 | # opt = vis.get_render_option() 83 | # opt.background_color = np.asarray([0, 0, 0]) 84 | 85 | for i in range(len(pcds)): 86 | # 创建点云对象 87 | pcd_o3d = o3d.open3d.geometry.PointCloud() 88 | # 将点云数据转换为Open3d可以直接使用的数据类型 89 | if (isinstance(pcds[i], np.ndarray)): 90 | pcd_points = pcds[i][:, :3] 91 | if normal: 92 | pcd_normals = pcds[i][:, 3:6] 93 | elif (isinstance(pcds[i], torch.Tensor)): 94 | pcd_points = pcds[i][:, :3].detach().cpu().numpy() 95 | if normal: 96 | pcd_normals = pcds[i][:, 3:6].detach().cpu().numpy() 97 | else: 98 | pcd_points = np.array(pcds[i][:, :3]) 99 | if normal: 100 | pcd_normals = np.array(pcds[i][:, 3:6]) 101 | pcd_o3d.points = o3d.open3d.utility.Vector3dVector(pcd_points) 102 | if normal: 103 | pcd_o3d.normals = o3d.open3d.utility.Vector3dVector(pcd_normals) 104 | # 设置点的颜色 105 | if colors is not None: 106 | pcd_o3d.paint_uniform_color(colors[i]) 107 | # 将点云加入到窗口中 108 | vis.add_geometry(pcd_o3d) 109 | 110 | vis.run() 111 | vis.destroy_window() 112 | 113 | 114 | def voxel_down_sample(pcd, voxel_size=0.3, num=None, padding=True): 115 | """ 116 | 点云体素化下采样 117 | :param pcd: (N, 3+) 原始点云 118 | :param voxel_size: 体素边长 119 | :param num: 下采样后的点数数量,为None时不约束,否则采样后点云数量不超过该值 120 | :param padding: num不为None且采样后数量不足时,使用0填充 121 | :return: (N, 3+) 下采样后的点云 122 | """ 123 | pcd_xyz = pcd[:, :3] 124 | # 根据点云范围确定voxel数量 125 | xyz_min = torch.min(pcd_xyz, dim=0)[0] 126 | xyz_max = torch.max(pcd_xyz, dim=0)[0] 127 | X, Y, Z = torch.div(xyz_max[0] - xyz_min[0], voxel_size, rounding_mode='trunc') + 1, \ 128 | torch.div(xyz_max[1] - xyz_min[1], voxel_size, rounding_mode='trunc') + 1, \ 129 | torch.div(xyz_max[2] - xyz_min[2], voxel_size, rounding_mode='trunc') + 1 130 | 131 | # 计算每个点云所在voxel的xyz索引和总索引 132 | relative_xyz = pcd_xyz - xyz_min 133 | voxel_xyz = torch.div(relative_xyz, voxel_size, rounding_mode='trunc').int() 134 | voxel_id = (voxel_xyz[:, 0] + voxel_xyz[:, 1] * X + voxel_xyz[:, 2] * X * Y).int() 135 | 136 | '''每个voxel仅保留最接近中心点的点云,并根据voxel内点云数量排序''' 137 | dis = torch.sum((relative_xyz - voxel_xyz * voxel_size - voxel_size / 2).pow(2), dim=-1) 138 | 139 | # 预先根据点云距离voxel中心的距离由近到远进行排序,使得每个voxel第一次被统计时即对应了最近点云 140 | dis, sorted_id = torch.sort(dis) 141 | voxel_id = voxel_id[sorted_id] 142 | pcd = pcd[sorted_id] 143 | 144 | # 去除相同voxel,id即为每个voxel内的采样点,cnt为当前采样点所在voxel的点云数量之和 145 | _, unique_id, cnt = np.unique(voxel_id.cpu(), return_index=True, return_counts=True) 146 | unique_id, cnt = torch.tensor(unique_id, device=pcd.device), torch.tensor(cnt, device=pcd.device) 147 | 148 | # 保留点云数量最多的voxel 149 | if num is not None and unique_id.shape[0] > num: 150 | _, cnt_topk_id = torch.topk(cnt, k=num) 151 | unique_id = unique_id[cnt_topk_id] 152 | new_pcd = pcd[unique_id] 153 | 154 | if num is not None: 155 | if new_pcd.shape[0] < num and padding: 156 | padding_num = num - new_pcd.shape[0] 157 | padding = torch.zeros(size=(padding_num, new_pcd.shape[1]), device=new_pcd.device) 158 | padding[:, 2] = -10 159 | new_pcd = torch.cat((new_pcd, padding), dim=0) 160 | else: 161 | new_pcd = new_pcd[:num] 162 | 163 | return new_pcd 164 | 165 | 166 | class IdentityScheduler(torch.nn.Module): 167 | def __init__(self, *args, **kwargs): 168 | super().__init__() 169 | 170 | def step(self): 171 | pass 172 | 173 | --------------------------------------------------------------------------------