├── .gitignore ├── Dataloader ├── ModelNet40.py ├── S3DIS.py ├── S3DIS_random.py └── scanobjectnn.py ├── Image ├── Motivation.PNG └── model.PNG ├── README.md ├── main_S3DIS.py ├── main_S3DIS_ref.py ├── main_modelnet40.py ├── main_modelnet40_ref.py ├── main_scanobj.py ├── main_scanobj_ref.py ├── model ├── cls │ ├── DGCNN.py │ └── DGCNN_repmax.py └── seg │ ├── DGCNN_seg.py │ └── DGCNN_seg_repmax.py └── utils ├── GDAutil.py ├── GDM_util.py ├── all_utils.py ├── aug.py ├── aug_loss.py ├── cal_final_result.py ├── centerloss.py ├── create_cluster_data.py ├── create_heatmap.py ├── curvenet_util.py ├── part_segmentation_evaluation.py ├── pointnet2_utils.py ├── test_perform_cal.py ├── vis_feature_cluster.py ├── voting_eval_cls.py └── walk.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | pth_file 3 | Exp 4 | __pycache__ -------------------------------------------------------------------------------- /Dataloader/ModelNet40.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import torch.utils.data as data 5 | 6 | 7 | import sys 8 | import glob 9 | import h5py 10 | 11 | 12 | 13 | ########## label meaning ######### 14 | ### 0 is looking forward ### 15 | ### 1 is looking left ### 16 | ### 2 is looking right ### 17 | 18 | 19 | # np.random.seed(0) 20 | class ModuleNet40(data.Dataset): 21 | def __init__(self,root,split): 22 | if split=='train': 23 | self.split=split 24 | else: 25 | self.split='test' 26 | self.root=root 27 | self.data,self.label=self.get_datalabel() 28 | self.num_points=1024 29 | 30 | def get_datalabel(self): 31 | all_data = [] 32 | all_label = [] 33 | for h5_name in glob.glob(os.path.join(self.root, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%self.split)): 34 | f = h5py.File(h5_name) 35 | data = f['data'][:].astype('float32') 36 | label = f['label'][:].astype('int64') 37 | f.close() 38 | all_data.append(data) 39 | all_label.append(label) 40 | all_data = np.concatenate(all_data, axis=0) 41 | all_label = np.concatenate(all_label, axis=0) 42 | return all_data, all_label 43 | 44 | def translate_pointcloud(self,pointcloud): 45 | xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) 46 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 47 | 48 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 49 | return translated_pointcloud 50 | 51 | 52 | 53 | def __getitem__(self, item): 54 | pointcloud = self.data[item][:self.num_points] 55 | label = self.label[item] 56 | if self.split == 'train': 57 | pointcloud = self.translate_pointcloud(pointcloud) 58 | np.random.shuffle(pointcloud) 59 | 60 | pointcloud=torch.FloatTensor(pointcloud) 61 | label=torch.LongTensor(label) 62 | 63 | pointcloud=pointcloud.permute(1,0) 64 | 65 | return pointcloud, label 66 | 67 | def __len__(self): 68 | # return 32 69 | return self.data.shape[0] 70 | 71 | 72 | 73 | 74 | 75 | def get_sets(data_path,train_batch_size,test_batch_size): 76 | train_data=ModuleNet40(data_path,split='train') 77 | train_loader=data.DataLoader(dataset=train_data,batch_size=train_batch_size,shuffle=True,num_workers=2) 78 | 79 | test_data=ModuleNet40(data_path,split='test') 80 | test_loader=data.DataLoader(dataset=test_data,batch_size=test_batch_size,shuffle=True,num_workers=2) 81 | 82 | valid_dataset=ModuleNet40(data_path,split='valid') 83 | valid_loader=data.DataLoader(dataset=valid_dataset,batch_size=test_batch_size,shuffle=True,num_workers=2) 84 | 85 | 
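    # Note: ModuleNet40 maps every split other than 'train' to the test files,
    # so this 'valid' loader draws from the same data as the test loader
    # (ModelNet40 ships no official validation split).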
return train_loader,test_loader,valid_loader
 86 | 
 87 | 
 88 | 
 89 | 
 90 | 
 91 | 
 92 | if __name__=='__main__':
 93 |     data_path='/data1/jiajing/dataset/ModelNet40/data'
 94 |     dataset=ModuleNet40(data_path,'train')
 95 | 
 96 |     #### inpt shape is (3,1024) #####
 97 |     inpt,label=dataset[2]
 98 |     a,b,c=get_sets(data_path,10,10)
 99 | 
100 | 
--------------------------------------------------------------------------------
/Dataloader/S3DIS.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import numpy as np
  3 | # import open3d as o3d
  4 | import torch.utils.data as data
  5 | import os
  6 | 
  7 | 
  8 | 
  9 | 
 10 | 
 11 | 
 12 | # cls_list = ['clutter', 'ceiling', 'floor',
 13 | #             'wall', 'beam', 'column', 'door','window',
 14 | #             'table', 'chair', 'sofa', 'bookcase', 'board']
 15 | 
 16 | 
 17 | np.random.seed(0)
 18 | class S3DISDataset(data.Dataset):
 19 |     def __init__(self,root,split,test_area):
 20 |         self.test_area=test_area
 21 |         self.split=split
 22 |         total_area_list=['Area_1','Area_2','Area_3','Area_4','Area_5','Area_6']
 23 |         # self.area_list=[]
 24 |         assert test_area in np.arange(1,7)
 25 |         if split=='train':
 26 |             self.area_list=[i for i in total_area_list if int(i.split('_')[1])!=test_area]
 27 |             # self.area_list=['Area_6','Area_1','Area_2','Area_3','Area_4']
 28 |         else:
 29 |             self.area_list=[i for i in total_area_list if int(i.split('_')[1])==test_area]
 30 | 
 31 |         self.root=root
 32 |         self.batch_list=self.create_batch_list()
 33 | 
 34 | 
 35 |     def create_batch_list(self):
 36 |         all_batch_list=[]
 37 |         for area in self.area_list:
 38 |             area_path=os.path.join(self.root,area)
 39 |             room_list=os.listdir(area_path)
 40 |             for room in room_list:
 41 |                 if (self.test_area==2) and (self.split!='train') and (room=='auditorium_2'):
 42 |                     continue
 43 | 
 44 |                 batch_folder_path=os.path.join(area_path,room,'Batch_Folder')
 45 |                 batch_list=os.listdir(batch_folder_path)
 46 |                 for batch in batch_list:
 47 |                     batch_path=os.path.join(batch_folder_path,batch)
 48 |                     all_batch_list.append(batch_path)
 49 | 
 50 | 
 51 |         if (self.split=='train') and (self.test_area==2):
 52 |             auditorium1_path=os.path.join(self.root,'Area_2','auditorium_2')
 53 |             batch_folder_path=os.path.join(auditorium1_path,'Batch_Folder')
 54 |             batch_list=os.listdir(batch_folder_path)
 55 |             for batch in batch_list:
 56 |                 batch_path=os.path.join(batch_folder_path,batch)
 57 |                 all_batch_list.append(batch_path)
 58 | 
 59 | 
 60 |         return all_batch_list
 61 | 
 62 |     def __getitem__(self,batch_index):
 63 |         np_file=self.batch_list[batch_index]
 64 |         data=np.load(np_file)
 65 |         # inpt=torch.FloatTensor(data[:,0:6])
 66 |         # inpt_color=torch.FloatTensor(data[:,6:9])
 67 |         inpt=torch.FloatTensor(data[:,:-1])
 68 | 
 69 |         # index=[6,7,8,3,4,5,0,1,2]
 70 |         # inpt=inpt[:,index]
 71 |         label=torch.LongTensor(data[:,-1])
 72 | 
 73 | 
 74 |         return inpt,label
 75 | 
 76 |     def __len__(self):
 77 |         # return 20
 78 |         return len(self.batch_list)
 79 | 
 80 | 
 81 | 
 82 | 
 83 | def get_sets(data_path,batch_size,test_batch,test_area):
 84 |     train_data=S3DISDataset(data_path,split='train',test_area=test_area)
 85 |     train_loader=data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True,num_workers=2)
 86 | 
 87 |     test_data=S3DISDataset(data_path,split='test',test_area=test_area)
 88 |     test_loader=data.DataLoader(dataset=test_data,batch_size=test_batch,shuffle=True,num_workers=2)
 89 | 
 90 |     valid_data=S3DISDataset(data_path,split='valid',test_area=test_area)
 91 |     valid_loader=data.DataLoader(dataset=valid_data,batch_size=test_batch,shuffle=True,num_workers=2)
 92 | 
 93 |     return train_loader,test_loader,valid_loader
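# ---------------------------------------------------------------------------
# Hedged usage sketch (inferred from create_batch_list/__getitem__ above; the
# exact feature width F depends on how the batches were pre-processed): the
# loader expects pre-batched .npy files laid out as
#   <root>/Area_{1..6}/<room_name>/Batch_Folder/<batch_file>.npy
# where each array has shape (num_points, F + 1); columns [:F] are the point
# features returned as `inpt` and the last column is the integer class label
# (13 S3DIS classes).
#   dataset = S3DISDataset('/path/to/root', split='train', test_area=5)
#   inpt, label = dataset[0]   # inpt: (num_points, F), label: (num_points,)
# ---------------------------------------------------------------------------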
94 | 95 | 96 | 97 | 98 | 99 | # def visulize_point(point_path,cls=0): 100 | # data=np.load(point_path) 101 | # if cls!=None: 102 | # pos=(data[:,-2]==cls) 103 | # data[pos,3:6]=np.array([255,0,0]) 104 | 105 | 106 | # points_info=o3d.geometry.PointCloud() 107 | # points_info.points=o3d.utility.Vector3dVector(data[:,0:3]) 108 | # points_info.colors=o3d.utility.Vector3dVector(data[:,3:6]/255) 109 | # o3d.visualization.draw_geometries([points_info]) 110 | 111 | 112 | 113 | if __name__=='__main__': 114 | data_path='/data1/jiajing/dataset/S3DIS_area/data' 115 | dataset=S3DISDataset(data_path,split='test',test_area=2) 116 | inpt,label=dataset[0] 117 | 118 | 119 | get_sets(data_path,5,5,2) 120 | # for i in range(len(dataset)): 121 | # inpt,label=dataset[i] 122 | # s 123 | 124 | 125 | 126 | # point_path='D:/Computer_vision/3D_Dataset\Stanford_Large_Scale/plane_seg_sample/Area_1/conferenceRoom_1/whole_room_point.npy' 127 | # visulize_point(point_path) 128 | 129 | -------------------------------------------------------------------------------- /Dataloader/S3DIS_random.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from tqdm import tqdm 5 | from torch.utils.data import Dataset 6 | import torch.utils.data as data 7 | 8 | class S3DISDataset(Dataset): 9 | def __init__(self, split='train', data_root='trainval_fullarea', num_point=4096, test_area=5, block_size=1.0, sample_rate=1.0, transform=None): 10 | super().__init__() 11 | self.num_point = num_point 12 | self.block_size = block_size 13 | self.transform = transform 14 | rooms = sorted(os.listdir(data_root)) 15 | rooms = [room for room in rooms if 'Area_' in room] 16 | if split == 'train': 17 | rooms_split = [room for room in rooms if not 'Area_{}'.format(test_area) in room] 18 | else: 19 | rooms_split = [room for room in rooms if 'Area_{}'.format(test_area) in room] 20 | 21 | self.room_points, self.room_labels = [], [] 22 | self.room_coord_min, self.room_coord_max = [], [] 23 | num_point_all = [] 24 | labelweights = np.zeros(13) 25 | 26 | for room_name in tqdm(rooms_split, total=len(rooms_split)): 27 | room_path = os.path.join(data_root, room_name) 28 | room_data = np.load(room_path) # xyzrgbl, N*7 29 | points, labels = room_data[:, 0:6], room_data[:, 6] # xyzrgb, N*6; l, N 30 | tmp, _ = np.histogram(labels, range(14)) 31 | labelweights += tmp 32 | coord_min, coord_max = np.amin(points, axis=0)[:3], np.amax(points, axis=0)[:3] 33 | self.room_points.append(points), self.room_labels.append(labels) 34 | self.room_coord_min.append(coord_min), self.room_coord_max.append(coord_max) 35 | num_point_all.append(labels.size) 36 | labelweights = labelweights.astype(np.float32) 37 | labelweights = labelweights / np.sum(labelweights) 38 | self.labelweights = np.power(np.amax(labelweights) / labelweights, 1 / 3.0) 39 | print(self.labelweights) 40 | sample_prob = num_point_all / np.sum(num_point_all) 41 | num_iter = int(np.sum(num_point_all) * sample_rate / num_point) 42 | room_idxs = [] 43 | for index in range(len(rooms_split)): 44 | room_idxs.extend([index] * int(round(sample_prob[index] * num_iter))) 45 | self.room_idxs = np.array(room_idxs) 46 | print("Totally {} samples in {} set.".format(len(self.room_idxs), split)) 47 | 48 | def __getitem__(self, idx): 49 | room_idx = self.room_idxs[idx] 50 | points = self.room_points[room_idx] # N * 6 51 | labels = self.room_labels[room_idx] # N 52 | N_points = points.shape[0] 53 | 54 | while (True): 55 | center = 
points[np.random.choice(N_points)][:3]
 56 |             block_min = center - [self.block_size / 2.0, self.block_size / 2.0, 0]
 57 |             block_max = center + [self.block_size / 2.0, self.block_size / 2.0, 0]
 58 |             point_idxs = np.where((points[:, 0] >= block_min[0]) & (points[:, 0] <= block_max[0]) & (points[:, 1] >= block_min[1]) & (points[:, 1] <= block_max[1]))[0]
 59 |             if point_idxs.size > 1024:
 60 |                 break
 61 | 
 62 |         if point_idxs.size >= self.num_point:
 63 |             selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=False)
 64 |         else:
 65 |             selected_point_idxs = np.random.choice(point_idxs, self.num_point, replace=True)
 66 | 
 67 |         # normalize
 68 |         selected_points = points[selected_point_idxs, :]  # num_point * 6
 69 |         current_points = np.zeros((self.num_point, 9))  # num_point * 9
 70 |         current_points[:, 6] = selected_points[:, 0] / self.room_coord_max[room_idx][0]
 71 |         current_points[:, 7] = selected_points[:, 1] / self.room_coord_max[room_idx][1]
 72 |         current_points[:, 8] = selected_points[:, 2] / self.room_coord_max[room_idx][2]
 73 |         selected_points[:, 0] = selected_points[:, 0] - center[0]
 74 |         selected_points[:, 1] = selected_points[:, 1] - center[1]
 75 |         selected_points[:, 3:6] /= 255.0
 76 |         current_points[:, 0:6] = selected_points
 77 |         current_labels = labels[selected_point_idxs]
 78 |         if self.transform is not None:
 79 |             current_points, current_labels = self.transform(current_points, current_labels)
 80 | 
 81 | 
 82 |         current_points=torch.FloatTensor(current_points)
 83 |         current_labels=torch.LongTensor(current_labels)
 84 | 
 85 |         return current_points, current_labels
 86 | 
 87 |     def __len__(self):
 88 |         return len(self.room_idxs)
 89 | 
 90 | 
 91 | 
 92 | def get_sets(data_root,batch_size,test_batch,test_area):
 93 |     num_point, block_size, sample_rate = 4096, 1.0, 0.01
 94 | 
 95 |     train_data=S3DISDataset(split='train', data_root=data_root, num_point=num_point, test_area=test_area,
 96 |                             block_size=block_size, sample_rate=sample_rate, transform=None)
 97 |     train_loader=data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True,num_workers=0)
 98 | 
 99 |     test_data=S3DISDataset(split='test', data_root=data_root, num_point=num_point, test_area=test_area,
100 |                            block_size=block_size, sample_rate=sample_rate, transform=None)
101 |     test_loader=data.DataLoader(dataset=test_data,batch_size=test_batch,shuffle=True,num_workers=0)
102 | 
103 |     valid_data=S3DISDataset(split='valid', data_root=data_root, num_point=num_point, test_area=test_area,
104 |                             block_size=block_size, sample_rate=sample_rate, transform=None)
105 |     valid_loader=data.DataLoader(dataset=valid_data,batch_size=test_batch,shuffle=True,num_workers=0)
106 | 
107 |     return train_loader,test_loader,valid_loader
108 | 
109 | 
110 | 
111 | 
112 | 
113 | 
114 | if __name__=='__main__':
115 |     data_root = '/data1/jiajing/worksapce/Algorithm/PointNet/Pointnet_Pointnet2_pytorch/data/stanford_indoor3d/'
116 |     num_point, test_area, block_size, sample_rate = 4096, 1, 1.0, 0.01
117 |     point_data = S3DISDataset(split='valid', data_root=data_root, num_point=num_point, test_area=test_area,
118 |                               block_size=block_size, sample_rate=sample_rate, transform=None)
119 | 
120 | 
121 | 
122 | 
--------------------------------------------------------------------------------
/Dataloader/scanobjectnn.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | import glob
  4 | import h5py
  5 | import numpy as np
  6 | from torch.utils.data import Dataset, dataset
  7 | import torch.utils.data as data
  8 | # import open3d as o3d
  9 | import torch
 10 | 
 11 | 
 12 | class ScanObjectNN(Dataset):
 13 |     def __init__(self,data_path,split='train',num_points=1024):
 14 |         self.split = split
 15 |         self.BASE_DIR=data_path
 16 |         self.data, self.label = self.load_scanobjectnn_data()
 17 |         self.num_points = num_points
 18 | 
 19 | 
 20 | 
 21 | 
 22 |     def translate_pointcloud(self,pointcloud):
 23 |         xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3])
 24 |         xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])
 25 | 
 26 |         translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
 27 |         return translated_pointcloud
 28 | 
 29 | 
 30 | 
 31 | 
 32 | 
 33 |     def load_scanobjectnn_data(self):
 34 |         # self.BASE_DIR = os.path.dirname(os.path.abspath(__file__))
 35 |         # self.BASE_DIR='D:/Computer_vision/Dataset/ScanObjectNN/h5_files/h5_files/main_split'
 36 |         # DATA_DIR = os.path.join(self.BASE_DIR, 'data')
 37 |         all_data = []
 38 |         all_label = []
 39 | 
 40 |         if self.split=='train':
 41 |             partition='training'
 42 |         else:
 43 |             partition='test'
 44 | 
 45 | 
 46 |         # h5_name = self.BASE_DIR + '/data/' + partition + '_objectdataset_augmentedrot_scale75.h5'
 47 |         h5_name = os.path.join(self.BASE_DIR, partition + '_objectdataset.h5')
 48 |         f = h5py.File(h5_name, 'r')
 49 |         data = f['data'][:].astype('float32')
 50 |         label = f['label'][:].astype('int64')
 51 |         f.close()
 52 |         all_data.append(data)
 53 |         all_label.append(label)
 54 |         all_data = np.concatenate(all_data, axis=0)
 55 |         all_label = np.concatenate(all_label, axis=0)
 56 |         return all_data, all_label
 57 | 
 58 | 
 59 | 
 60 | 
 61 |     def __len__(self):
 62 |         return self.data.shape[0]
 63 | 
 64 | 
 65 | 
 66 |     def __getitem__(self, item):
 67 |         pointcloud = self.data[item][:self.num_points,:].astype(np.float32)
 68 |         label = self.label[item]
 69 |         # if self.split == 'train':
 70 |         #     pointcloud = self.translate_pointcloud(pointcloud)
 71 |         #     np.random.shuffle(pointcloud)
 72 | 
 73 |         pointcloud=torch.FloatTensor(pointcloud)
 74 |         label=torch.LongTensor([label])
 75 | 
 76 |         pointcloud=pointcloud.permute(1,0)
 77 | 
 78 |         return pointcloud, label
 79 | 
 80 | 
 81 | 
 82 | 
 83 | def get_sets(data_path,train_batch_size,test_batch_size):
 84 |     train_data=ScanObjectNN(data_path,split='train')
 85 |     train_loader=data.DataLoader(dataset=train_data,batch_size=train_batch_size,shuffle=True,num_workers=2)
 86 | 
 87 |     test_data=ScanObjectNN(data_path,split='test')
 88 |     test_loader=data.DataLoader(dataset=test_data,batch_size=test_batch_size,shuffle=True,num_workers=2)
 89 | 
 90 |     valid_dataset=ScanObjectNN(data_path,split='valid')
 91 |     valid_loader=data.DataLoader(dataset=valid_dataset,batch_size=test_batch_size,shuffle=True,num_workers=2)
 92 | 
 93 |     return train_loader,test_loader,valid_loader
 94 | 
 95 | 
 96 | 
 97 | 
 98 | 
 99 | 
100 | 
101 | 
102 | 
103 | 
104 | if __name__=='__main__':
105 |     data_path='/data1/jiajing/dataset/scanobjectnn/main_split_nobg'
106 |     dataset=ScanObjectNN(data_path,split='train')
107 | 
108 |     picked_index=100
109 |     picked_data=dataset[picked_index]
110 | 
111 |     a,b,c=get_sets(data_path,10,10)
112 | 
113 |     # for (x,y) in a:
114 |     #     s
115 |     # pointcloud=o3d.geometry.PointCloud()
116 |     # # pc=o3d.geometry.PointCloud()
117 |     # # pc.points=o3d.utility.Vector3dVector(point_cloud.transpose(1,0)+2)
118 |     # pointcloud.points=o3d.utility.Vector3dVector(picked_data[0].reshape(-1,3))
119 |     # o3d.visualization.draw_geometries([pointcloud])
--------------------------------------------------------------------------------
/Image/Motivation.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajingchen113322/Recycle_Maxpooling_Module/3f28fdc12dd87bee9dc416cbff0a422a47d18686/Image/Motivation.PNG
--------------------------------------------------------------------------------
/Image/model.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiajingchen113322/Recycle_Maxpooling_Module/3f28fdc12dd87bee9dc416cbff0a422a47d18686/Image/model.PNG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Recycling Max Pooling Module for 3D Point Cloud Analysis (CVPR2022)
 2 | This is a PyTorch implementation of the paper *Why Discard if You Can Recycle?: A Recycling Max Pooling Module for 3D Point Cloud Analysis*. The paper can be found [here](https://openaccess.thecvf.com/content/CVPR2022/papers/Chen_Why_Discard_if_You_Can_Recycle_A_Recycling_Max_Pooling_CVPR_2022_paper.pdf).
 3 | **For a quick comparison, see /model/cls and /model/seg, which contain the original DGCNN and the DGCNN with RMP for the classification and segmentation tasks, respectively. The code for the ScanObjectNN, ModelNet40, and S3DIS experiments is provided; instructions for ScanObjectNN and S3DIS will follow as soon as possible.**
 4 | ## Recycling Max Pooling Module
 5 | The picture below shows the typical structure of a point-based network. Most point-based methods use a max pooling module to extract a permutation-invariant feature for downstream tasks. However, according to our observation, a great number of points are completely discarded by max pooling.
 6 | ![image width="100" height="100"](Image/Motivation.PNG)
 7 | 
 8 | To solve this problem, we propose the Recycling Max Pooling (RMP) module, which reuses the features of the discarded points, as shown below:
 9 | ![image width="100" height="100"](Image/model.PNG) \
10 | For more details, please refer to the [paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Chen_Why_Discard_if_You_Can_Recycle_A_Recycling_Max_Pooling_CVPR_2022_paper.pdf).
11 | 
12 | ## Point Cloud Classification on ModelNet40
13 | Download the [official data](https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip) and unzip it; the path to the unzipped data folder is needed for training.
14 | 
15 | ### Train
16 | Training the original DGCNN:
17 | 
18 | ```
19 | python main_modelnet40.py --data_path /path/to/Data --exp_name DGCNN
20 | ```
21 | 
22 | Training the DGCNN with RMP (Recycling Max Pooling):
23 | ```
24 | python main_modelnet40_ref.py --data_path /path/to/Data --exp_name DGCNN_RMP
25 | ```
26 | 
27 | Evaluation is performed after each epoch of training.
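During training, the checkpoint with the best validation accuracy so far is saved under `Exp/<exp_name>/pth_file/`.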
You could check the accuracy by 28 | ``` 29 | tensorboard --logdir /path/to/the/experiment 30 | ``` 31 | 32 | -------------------------------------------------------------------------------- /main_S3DIS.py: -------------------------------------------------------------------------------- 1 | from ast import parse 2 | import numpy as np 3 | import torch 4 | 5 | 6 | 7 | from utils.test_perform_cal import get_mean_accuracy 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | from torch.utils.tensorboard import SummaryWriter 11 | from datetime import datetime 12 | import os 13 | from utils.cal_final_result import accuracy_calculation 14 | from Dataloader.S3DIS_random import get_sets 15 | 16 | from model.seg.DGCNN_seg import DGCNN_semseg 17 | 18 | from sklearn.metrics import confusion_matrix 19 | import argparse 20 | from utils.all_utils import smooth_loss 21 | 22 | TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now()) 23 | 24 | 25 | 26 | 27 | def get_parse(): 28 | parser=argparse.ArgumentParser(description='argumment') 29 | parser.add_argument('--seed',default=0) 30 | parser.add_argument('--test_area',default=5,type=int) 31 | parser.add_argument('--exp_name',default='DGCNN_area5_ori',type=str) 32 | parser.add_argument('--batch_size',default=10,type=int) 33 | parser.add_argument('--lr',default=0.001) 34 | parser.add_argument('--neighbor',default=20) 35 | parser.add_argument('--data_path',default='/data1/jiajing/worksapce/Algorithm/PointNet/Pointnet_Pointnet2_pytorch/data/stanford_indoor3d/') 36 | parser.add_argument('--epoch',default=100,type=int) 37 | parser.add_argument('--multi_gpu',default=0,type=int) 38 | parser.add_argument('--max_iter',default=6,type=int) 39 | 40 | return parser.parse_args() 41 | 42 | cfg=get_parse() 43 | 44 | def main(): 45 | seed=cfg.seed 46 | np.random.seed(seed) 47 | torch.manual_seed(seed) 48 | torch.cuda.manual_seed(seed) 49 | torch.backends.cudnn.deterministic = True 50 | torch.backends.cudnn.benchmark = False 51 | torch.backends.cudnn.enabled=False 52 | cuda=0 53 | 54 | datapath=cfg.data_path 55 | model=DGCNN_semseg(num_cls=13,inpt_length=9) 56 | 57 | 58 | train_loader,test_loader,valid_loader=get_sets(datapath,batch_size=cfg.batch_size,test_batch=cfg.batch_size,test_area=cfg.test_area) 59 | 60 | train_model(model,train_loader,valid_loader,cfg.exp_name,cuda) 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | def train_model(model,train_loader,valid_loader,exp_name,cuda_n): 69 | assert torch.cuda.is_available() 70 | device=torch.device('cuda:{}'.format(cuda_n)) 71 | #这里应该用GPU 72 | 73 | if cfg.multi_gpu: 74 | model = nn.DataParallel(model).to(device) 75 | else: 76 | model=model.to(device) 77 | 78 | initial_epoch=0 79 | training_epoch=cfg.epoch 80 | 81 | loss_func=smooth_loss 82 | optimizer=torch.optim.Adam(model.parameters(),lr=cfg.lr) 83 | lr_schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=np.arange(10,training_epoch,20),gamma=0.7) 84 | 85 | 86 | 87 | 88 | #here we define train_one_epoch 89 | def train_one_epoch(): 90 | iterations=tqdm(train_loader,ncols=100,unit='batch',leave=False) 91 | epsum=run_one_epoch(model,iterations,"train",loss_func=loss_func,optimizer=optimizer,loss_interval=10) 92 | 93 | summary={"loss/train":np.mean(epsum['losses'])} 94 | return summary 95 | 96 | 97 | def eval_one_epoch(): 98 | iteration=tqdm(valid_loader,ncols=100,unit='batch',leave=False) 99 | 100 | epsum=run_one_epoch(model,iteration,"valid",loss_func=loss_func) 101 | mean_acc=np.mean(epsum['acc']) 102 | summary={'meac':mean_acc} 103 | 
summary["loss/valid"]=np.mean(epsum['losses']) 104 | return summary,epsum['conf_mat'] 105 | 106 | 107 | 108 | 109 | exp_path=os.path.join('Exp',exp_name) 110 | if not os.path.exists(exp_path): 111 | os.mkdir(exp_path) 112 | 113 | tensorboard=SummaryWriter(log_dir=os.path.join(exp_path,'TB')) 114 | tqdm_epoch=tqdm(range(initial_epoch,training_epoch),unit='epoch',ncols=100) 115 | 116 | #build folder for pth_file 117 | # pth_save_path=os.path.join('./pth_file',exp_name) 118 | pth_save_path=os.path.join(exp_path,'pth_file') 119 | if not os.path.exists(pth_save_path): 120 | os.mkdir(pth_save_path) 121 | 122 | acc_list=[] 123 | for e in tqdm_epoch: 124 | train_summary=train_one_epoch() 125 | valid_summary,conf_mat=eval_one_epoch() 126 | summary={**train_summary,**valid_summary} 127 | acc_list.append(summary['meac']) 128 | lr_schedule.step() 129 | 130 | if np.max(acc_list)==acc_list[-1]: 131 | 132 | if cfg.multi_gpu: 133 | summary_saved={**train_summary, 134 | 'conf_mat':conf_mat, 135 | 'model_state':model.module.state_dict(), 136 | 'optimizer_state':optimizer.state_dict()} 137 | else: 138 | summary_saved={**train_summary, 139 | 'conf_mat':conf_mat, 140 | 'model_state':model.state_dict(), 141 | 'optimizer_state':optimizer.state_dict()} 142 | 143 | 144 | torch.save(summary_saved,os.path.join(pth_save_path,'epoch_{}'.format(e))) 145 | 146 | 147 | 148 | for name,val in summary.items(): 149 | tensorboard.add_scalar(name,val,e) 150 | 151 | 152 | 153 | def run_one_epoch(model,tqdm_iter,mode,loss_func=None,optimizer=None,loss_interval=10): 154 | if mode=='train': 155 | model.train() 156 | else: 157 | model.eval() 158 | param_grads=[] 159 | for param in model.parameters(): 160 | param_grads+=[param.requires_grad] 161 | param.requires_grad=False 162 | 163 | summary={"losses":[],"acc":[]} 164 | device=next(model.parameters()).device 165 | confusion_martrix=np.zeros((13,13)) 166 | 167 | 168 | 169 | for i,(x_cpu,y_cpu) in enumerate(tqdm_iter): 170 | x,y=x_cpu.to(device),y_cpu.to(device) 171 | 172 | if mode=='train': 173 | optimizer.zero_grad() 174 | 175 | logits=model(x) 176 | if loss_func is not None: 177 | re_logit=logits.reshape(-1,logits.shape[-1]) 178 | loss=loss_func(re_logit,y.view(-1)) 179 | summary['losses']+=[loss.item()] 180 | 181 | 182 | if mode=='train': 183 | loss.backward() 184 | optimizer.step() 185 | 186 | if loss_func is not None and i%loss_interval==0: 187 | tqdm_iter.set_description("Loss: %.3f"%(np.mean(summary['losses']))) 188 | 189 | else: 190 | log=logits.cpu().detach().numpy() 191 | lab=y_cpu.numpy() 192 | num_cls=model.num_cls 193 | 194 | mean_acc=get_mean_accuracy(log,lab,num_cls) 195 | summary['acc'].append(mean_acc) 196 | 197 | 198 | label=lab.reshape(-1) 199 | prediction=log.reshape(-1,num_cls) 200 | prediction=np.argmax(prediction,1) 201 | confusion_martrix+=confusion_matrix(label,prediction,labels=np.arange(13)) 202 | 203 | 204 | 205 | if i%loss_interval==0: 206 | tqdm_iter.set_description("mea_ac: %.3f"%(np.mean(summary['acc']))) 207 | 208 | 209 | if mode!='train': 210 | for param,value in zip(model.parameters(),param_grads): 211 | param.requires_grad=value 212 | 213 | summary['conf_mat']=confusion_martrix 214 | 215 | return summary 216 | 217 | 218 | if __name__=='__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /main_S3DIS_ref.py: -------------------------------------------------------------------------------- 1 | from ast import parse 2 | import numpy as np 3 | import torch 4 | 5 | from utils.test_perform_cal 
import get_mean_accuracy 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | from torch.utils.tensorboard import SummaryWriter 9 | from datetime import datetime 10 | import os 11 | from utils.cal_final_result import accuracy_calculation 12 | from Dataloader.S3DIS_random import get_sets 13 | 14 | from model.seg.DGCNN_seg_repmax import DGCNN_semseg_ref 15 | 16 | from utils.all_utils import smooth_loss 17 | 18 | 19 | from sklearn.metrics import confusion_matrix 20 | import argparse 21 | 22 | 23 | TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now()) 24 | 25 | 26 | 27 | 28 | def get_parse(): 29 | parser=argparse.ArgumentParser(description='argumment') 30 | parser.add_argument('--seed',default=0) 31 | parser.add_argument('--cuda',default=0,type=int) 32 | parser.add_argument('--test_area',default=5,type=int) 33 | parser.add_argument('--exp_name',default='DGCNN_area5_ref',type=str) 34 | parser.add_argument('--batch_size',default=10,type=int) 35 | parser.add_argument('--lr',default=0.001) 36 | parser.add_argument('--neighbor',default=20) 37 | parser.add_argument('--data_path',default='/data1/jiajing/worksapce/Algorithm/PointNet/Pointnet_Pointnet2_pytorch/data/stanford_indoor3d/') 38 | parser.add_argument('--epoch',default=100,type=int) 39 | parser.add_argument('--multi_gpu',default=0,type=int) 40 | parser.add_argument('--max_iter',default=6,type=int) 41 | 42 | return parser.parse_args() 43 | 44 | cfg=get_parse() 45 | 46 | 47 | def main(): 48 | seed=cfg.seed 49 | np.random.seed(seed) 50 | torch.manual_seed(seed) 51 | torch.cuda.manual_seed(seed) 52 | torch.backends.cudnn.deterministic = True 53 | torch.backends.cudnn.benchmark = False 54 | torch.backends.cudnn.enabled=False 55 | 56 | cuda=0 57 | datapath=cfg.data_path 58 | datapath=cfg.data_path 59 | 60 | 61 | model=DGCNN_semseg_ref(num_cls=13,inpt_length=9) 62 | 63 | train_loader,test_loader,valid_loader=get_sets(datapath,batch_size=cfg.batch_size,test_batch=cfg.batch_size,test_area=cfg.test_area) 64 | train_model(model,train_loader,valid_loader,cfg.exp_name,cuda) 65 | 66 | 67 | 68 | def train_model(model,train_loader,valid_loader,exp_name,cuda_n): 69 | assert torch.cuda.is_available() 70 | device=torch.device('cuda:{}'.format(cuda_n)) 71 | #这里应该用GPU 72 | 73 | if cfg.multi_gpu: 74 | model = nn.DataParallel(model).to(device) 75 | else: 76 | model=model.to(device) 77 | # device=torch.device('cpu') 78 | # model=model.to(device) 79 | initial_epoch=0 80 | training_epoch=cfg.epoch 81 | 82 | loss_func=smooth_loss 83 | optimizer=torch.optim.Adam(model.parameters(),lr=cfg.lr) 84 | lr_schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=np.arange(10,training_epoch,20),gamma=0.7) 85 | 86 | 87 | 88 | 89 | #here we define train_one_epoch 90 | def train_one_epoch(): 91 | iterations=tqdm(train_loader,ncols=100,unit='batch',leave=False) 92 | epsum=run_one_epoch(model,iterations,"train",loss_func=loss_func,optimizer=optimizer,loss_interval=10) 93 | 94 | summary={"loss/train":np.mean(epsum['losses'])} 95 | return summary 96 | 97 | 98 | def eval_one_epoch(): 99 | iteration=tqdm(valid_loader,ncols=100,unit='batch',leave=False) 100 | epsum=run_one_epoch(model,iteration,"valid",loss_func=loss_func) 101 | mean_acc=np.mean(epsum['acc']) 102 | summary={'meac':mean_acc} 103 | summary["loss/valid"]=np.mean(epsum['losses']) 104 | return summary,epsum['conf_mat'] 105 | 106 | 107 | 108 | 109 | exp_path=os.path.join('Exp',exp_name) 110 | if not os.path.exists(exp_path): 111 | os.mkdir(exp_path) 112 | 113 | 
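    # TensorBoard scalars ('loss/train', 'loss/valid', 'meac') are written under
    # Exp/<exp_name>/TB; point `tensorboard --logdir` at that folder to monitor a run.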
tensorboard=SummaryWriter(log_dir=os.path.join(exp_path,'TB')) 114 | tqdm_epoch=tqdm(range(initial_epoch,training_epoch),unit='epoch',ncols=100) 115 | 116 | pth_save_path=os.path.join(exp_path,'pth_file') 117 | if not os.path.exists(pth_save_path): 118 | os.mkdir(pth_save_path) 119 | 120 | acc_list=[] 121 | for e in tqdm_epoch: 122 | train_summary=train_one_epoch() 123 | valid_summary,conf_mat=eval_one_epoch() 124 | summary={**train_summary,**valid_summary} 125 | acc_list.append(summary['meac']) 126 | lr_schedule.step() 127 | 128 | if np.max(acc_list)==acc_list[-1]: 129 | 130 | if cfg.multi_gpu: 131 | summary_saved={**train_summary, 132 | 'conf_mat':conf_mat, 133 | 'model_state':model.module.state_dict(), 134 | 'optimizer_state':optimizer.state_dict()} 135 | else: 136 | summary_saved={**train_summary, 137 | 'conf_mat':conf_mat, 138 | 'model_state':model.state_dict(), 139 | 'optimizer_state':optimizer.state_dict()} 140 | 141 | # torch.save(summary_saved,'./pth_file/{0}/epoch_{1}'.format(exp_name,e)) 142 | torch.save(summary_saved,os.path.join(pth_save_path,'epoch_{}'.format(e))) 143 | 144 | 145 | 146 | for name,val in summary.items(): 147 | tensorboard.add_scalar(name,val,e) 148 | 149 | 150 | 151 | def run_one_epoch(model,tqdm_iter,mode,loss_func=None,optimizer=None,loss_interval=10): 152 | if mode=='train': 153 | model.train() 154 | else: 155 | model.eval() 156 | param_grads=[] 157 | for param in model.parameters(): 158 | param_grads+=[param.requires_grad] 159 | param.requires_grad=False 160 | 161 | summary={"losses":[],"acc":[]} 162 | device=next(model.parameters()).device 163 | confusion_martrix=np.zeros((13,13)) 164 | 165 | 166 | 167 | for i,(x_cpu,y_cpu) in enumerate(tqdm_iter): 168 | x,y=x_cpu.to(device),y_cpu.to(device) 169 | 170 | if mode=='train': 171 | optimizer.zero_grad() 172 | 173 | 174 | logits,loss=model(x,y) 175 | if loss_func is not None: 176 | summary['losses']+=[loss.item()] 177 | 178 | 179 | if mode=='train': 180 | loss.backward() 181 | optimizer.step() 182 | 183 | #display 184 | if loss_func is not None and i%loss_interval==0: 185 | tqdm_iter.set_description("Loss: %.3f"%(np.mean(summary['losses']))) 186 | 187 | else: 188 | log=logits.cpu().detach().numpy() 189 | lab=y_cpu.numpy() 190 | # num_cls=model.num_cls 191 | num_cls=13 192 | 193 | mean_acc=get_mean_accuracy(log,lab,num_cls) 194 | summary['acc'].append(mean_acc) 195 | 196 | 197 | label=lab.reshape(-1) 198 | prediction=log.reshape(-1,num_cls) 199 | prediction=np.argmax(prediction,1) 200 | confusion_martrix+=confusion_matrix(label,prediction,labels=np.arange(13)) 201 | 202 | 203 | 204 | if i%loss_interval==0: 205 | tqdm_iter.set_description("mea_ac: %.3f"%(np.mean(summary['acc']))) 206 | 207 | 208 | if mode!='train': 209 | for param,value in zip(model.parameters(),param_grads): 210 | param.requires_grad=value 211 | 212 | summary['conf_mat']=confusion_martrix 213 | 214 | return summary 215 | 216 | 217 | if __name__=='__main__': 218 | main() 219 | -------------------------------------------------------------------------------- /main_modelnet40.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.core.arrayprint import DatetimeFormat 3 | import torch 4 | from torch.nn.modules import module 5 | 6 | from Dataloader.ModelNet40 import get_sets 7 | 8 | from utils.test_perform_cal import get_cls_accuracy 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | from torch.utils.tensorboard import SummaryWriter 12 | from datetime import datetime 13 | 
import os 14 | import shutil 15 | from utils.cal_final_result import accuracy_calculation 16 | from utils.all_utils import smooth_loss 17 | import random 18 | import time 19 | import argparse 20 | 21 | 22 | def get_parse(): 23 | parser=argparse.ArgumentParser(description='argumment') 24 | parser.add_argument('--exp_name',type=str,default='DGCNN_exp') 25 | parser.add_argument('--train',default=True) 26 | parser.add_argument('--seed',default=0) 27 | parser.add_argument('--batch_size',default=16) 28 | parser.add_argument('--data_path',default='/data1/jiajing/dataset/ModelNet40/data') 29 | parser.add_argument('--lr',default=0.001) 30 | return parser.parse_args() 31 | 32 | 33 | cfg=get_parse() 34 | 35 | 36 | 37 | 38 | 39 | def main(): 40 | random.seed(cfg.seed) 41 | np.random.seed(cfg.seed) 42 | torch.manual_seed(cfg.seed) 43 | torch.cuda.manual_seed(cfg.seed) 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | torch.backends.cudnn.enabled=False 47 | 48 | cuda=0 49 | 50 | 51 | from model.cls.DGCNN import DGCNN 52 | model=DGCNN(40) 53 | 54 | train_loader,test_loader,valid_loader=get_sets(cfg.data_path,train_batch_size=cfg.batch_size,test_batch_size=cfg.batch_size) 55 | 56 | train_model(model,train_loader,valid_loader,cfg.exp_name,cuda) 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | def train_model(model,train_loader,valid_loader,exp_name,cuda_n): 67 | assert torch.cuda.is_available() 68 | epoch_acc=[] 69 | 70 | #这里应该用GPU 71 | device=torch.device('cuda:{}'.format(cuda_n)) 72 | model=model.to(device) 73 | 74 | 75 | initial_epoch=0 76 | training_epoch=350 77 | 78 | loss_func=smooth_loss 79 | optimizer=torch.optim.Adam(model.parameters(),lr=cfg.lr) 80 | lr_schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=np.arange(10,training_epoch,40),gamma=0.7) 81 | 82 | 83 | #here we define train_one_epoch 84 | def train_one_epoch(): 85 | iterations=tqdm(train_loader,ncols=100,unit='batch',leave=False) 86 | 87 | #真正训练这里应该解封 88 | epsum=run_one_epoch(model,iterations,"train",loss_func=loss_func,optimizer=optimizer,loss_interval=10) 89 | 90 | summary={"loss/train":np.mean(epsum['losses'])} 91 | return summary 92 | 93 | 94 | def eval_one_epoch(): 95 | iteration=tqdm(valid_loader,ncols=100,unit='batch',leave=False) 96 | epsum=run_one_epoch(model,iteration,"valid",loss_func=loss_func) 97 | mean_acc=np.mean(epsum['acc']) 98 | 99 | epoch_acc.append(mean_acc) 100 | 101 | summary={'meac':mean_acc} 102 | summary["loss/valid"]=np.mean(epsum['losses']) 103 | return summary 104 | 105 | 106 | tqdm_epoch=tqdm(range(initial_epoch,training_epoch),unit='epoch',ncols=100) 107 | 108 | 109 | if not os.path.exists('./Exp'): 110 | os.mkdir('./Exp') 111 | 112 | 113 | exp_path=os.path.join('./Exp',cfg.exp_name) 114 | pth_path=os.path.join(exp_path,'pth_file') 115 | tensorboard_path=os.path.join(exp_path,'TB') 116 | if not os.path.exists(exp_path): 117 | os.mkdir(exp_path) 118 | os.mkdir(pth_path) 119 | os.mkdir(tensorboard_path) 120 | 121 | tensorboard=SummaryWriter(log_dir=tensorboard_path) 122 | 123 | 124 | for e in tqdm_epoch: 125 | train_summary=train_one_epoch() 126 | valid_summary=eval_one_epoch() 127 | summary={**train_summary,**valid_summary} 128 | lr_schedule.step() 129 | #save checkpoint 130 | if np.max(epoch_acc)==epoch_acc[-1]: 131 | summary_saved={**summary, 132 | 'model_state':model.state_dict(), 133 | 'optimizer_state':optimizer.state_dict()} 134 | 135 | 136 | # torch.save(summary_saved,'./pth_file/{0}/epoch_{1}'.format(exp_name,e)) 137 | 
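            # A checkpoint is written only when this epoch's validation mean
            # accuracy matches the best seen so far (the np.max(epoch_acc)
            # guard above), so pth_file/ holds the best model, not every epoch.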
torch.save(summary_saved,os.path.join(pth_path,'epoch_{}'.format(e))) 138 | 139 | for name,val in summary.items(): 140 | tensorboard.add_scalar(name,val,e) 141 | 142 | 143 | 144 | def run_one_epoch(model,tqdm_iter,mode,loss_func=None,optimizer=None,loss_interval=10): 145 | if mode=='train': 146 | model.train() 147 | else: 148 | model.eval() 149 | param_grads=[] 150 | for param in model.parameters(): 151 | param_grads+=[param.requires_grad] 152 | param.requires_grad=False 153 | 154 | summary={"losses":[],"acc":[]} 155 | device=next(model.parameters()).device 156 | 157 | for i,(x_cpu,y_cpu) in enumerate(tqdm_iter): 158 | x,y=x_cpu.to(device),y_cpu.to(device) 159 | 160 | if mode=='train': 161 | optimizer.zero_grad() 162 | 163 | #logtis' shape is [batch,40] 164 | #y size is [batch,1] 165 | 166 | logits=model(x) 167 | 168 | 169 | 170 | if loss_func is not None: 171 | re_logit=logits.reshape(-1,logits.shape[-1]) 172 | 173 | 174 | #### here is the loss ##### 175 | loss=loss_func(re_logit,y.view(-1)) 176 | summary['losses']+=[loss.item()] 177 | 178 | if mode=='train': 179 | loss.backward(retain_graph=True) 180 | optimizer.step() 181 | 182 | #display 183 | if loss_func is not None and i%loss_interval==0: 184 | tqdm_iter.set_description("Loss: {:.3f}".format(np.mean(summary['losses']))) 185 | 186 | else: 187 | log=logits.cpu().detach().numpy() 188 | lab=y_cpu.numpy() 189 | 190 | mean_acc=get_cls_accuracy(log,lab) 191 | summary['acc'].append(mean_acc) 192 | if i%loss_interval==0: 193 | tqdm_iter.set_description("mea_ac: %.3f"%(np.mean(summary['acc']))) 194 | 195 | 196 | if mode!='train': 197 | for param,value in zip(model.parameters(),param_grads): 198 | param.requires_grad=value 199 | 200 | 201 | return summary 202 | 203 | 204 | 205 | if __name__=='__main__': 206 | main() 207 | -------------------------------------------------------------------------------- /main_modelnet40_ref.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | 5 | 6 | from Dataloader.ModelNet40 import get_sets 7 | 8 | from utils.test_perform_cal import get_cls_accuracy 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | from torch.utils.tensorboard import SummaryWriter 12 | from datetime import datetime 13 | import os 14 | import random 15 | import argparse 16 | 17 | 18 | TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now()) 19 | 20 | 21 | 22 | 23 | def get_parse(): 24 | parser=argparse.ArgumentParser(description='argumment') 25 | parser.add_argument('--exp_name',type=str,default='DGCNN_ref_exp') 26 | parser.add_argument('--seed',default=0) 27 | parser.add_argument('--batch_size',default=16) 28 | parser.add_argument('--data_path',default='/data1/jiajing/dataset/ModelNet40/data') 29 | parser.add_argument('--lr',default=0.001) 30 | return parser.parse_args() 31 | 32 | cfg=get_parse() 33 | 34 | def main(): 35 | random.seed(cfg.seed) 36 | np.random.seed(cfg.seed) 37 | torch.manual_seed(cfg.seed) 38 | torch.cuda.manual_seed(cfg.seed) 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | torch.backends.cudnn.enabled=False 42 | 43 | cuda=0 44 | from model.cls.DGCNN_repmax import DGCNN_ref 45 | model=DGCNN_ref(40,0.6,2.1) 46 | 47 | 48 | 49 | 50 | train_loader,test_loader,valid_loader=get_sets(cfg.data_path,train_batch_size=cfg.batch_size,test_batch_size=cfg.batch_size) 51 | train_model(model,train_loader,valid_loader,cfg.exp_name,cuda) 52 | 53 | 54 | 55 | 56 | def 
train_model(model,train_loader,valid_loader,exp_name,cuda_n): 57 | assert torch.cuda.is_available() 58 | epoch_acc=[] 59 | 60 | #这里应该用GPU 61 | device=torch.device('cuda:{}'.format(cuda_n)) 62 | model=model.to(device) 63 | 64 | 65 | initial_epoch=0 66 | training_epoch=350 67 | 68 | loss_func=nn.CrossEntropyLoss() 69 | optimizer=torch.optim.Adam(model.parameters(),lr=cfg.lr) 70 | lr_schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=np.arange(10,training_epoch,40),gamma=0.7) 71 | 72 | 73 | #here we define train_one_epoch 74 | def train_one_epoch(): 75 | iterations=tqdm(train_loader,ncols=100,unit='batch',leave=False) 76 | 77 | #真正训练这里应该解封 78 | epsum=run_one_epoch(model,iterations,"train",loss_func=loss_func,optimizer=optimizer,loss_interval=10) 79 | 80 | summary={"loss/train":np.mean(epsum['losses'])} 81 | return summary 82 | 83 | 84 | def eval_one_epoch(): 85 | iteration=tqdm(valid_loader,ncols=100,unit='batch',leave=False) 86 | 87 | 88 | epsum=run_one_epoch(model,iteration,"valid",loss_func=loss_func) 89 | mean_acc=np.mean(epsum['acc']) 90 | 91 | epoch_acc.append(mean_acc) 92 | 93 | summary={'meac':mean_acc} 94 | summary["loss/valid"]=np.mean(epsum['losses']) 95 | return summary 96 | 97 | 98 | 99 | #build tensorboard 100 | 101 | # tensorboard=SummaryWriter(log_dir='.Exp/{}/TB'.format(exp_name)) 102 | tqdm_epoch=tqdm(range(initial_epoch,training_epoch),unit='epoch',ncols=100) 103 | 104 | #build folder for pth_file 105 | if not os.path.exists('./Exp'): 106 | os.mkdir('./Exp') 107 | 108 | 109 | exp_path=os.path.join('./Exp',cfg.exp_name) 110 | pth_path=os.path.join(exp_path,'pth_file') 111 | tensorboard_path=os.path.join(exp_path,'TB') 112 | if not os.path.exists(exp_path): 113 | os.mkdir(exp_path) 114 | os.mkdir(pth_path) 115 | os.mkdir(tensorboard_path) 116 | # pth_save_path=os.path.join('Exp',exp_name,'pth_file') 117 | # if not os.path.exists(pth_save_path): 118 | # os.mkdir(pth_save_path) 119 | 120 | tensorboard=SummaryWriter(log_dir=tensorboard_path) 121 | 122 | 123 | for e in tqdm_epoch: 124 | train_summary=train_one_epoch() 125 | valid_summary=eval_one_epoch() 126 | summary={**train_summary,**valid_summary} 127 | lr_schedule.step() 128 | 129 | if np.max(epoch_acc)==epoch_acc[-1]: 130 | summary_saved={**summary, 131 | 'model_state':model.state_dict(), 132 | 'optimizer_state':optimizer.state_dict()} 133 | 134 | torch.save(summary_saved,os.path.join(pth_path,'epoch_{}'.format(e))) 135 | 136 | for name,val in summary.items(): 137 | tensorboard.add_scalar(name,val,e) 138 | 139 | 140 | 141 | def run_one_epoch(model,tqdm_iter,mode,loss_func=None,optimizer=None,loss_interval=10): 142 | if mode=='train': 143 | model.train() 144 | else: 145 | model.eval() 146 | param_grads=[] 147 | for param in model.parameters(): 148 | param_grads+=[param.requires_grad] 149 | param.requires_grad=False 150 | 151 | summary={"losses":[],"acc":[]} 152 | device=next(model.parameters()).device 153 | 154 | for i,(x_cpu,y_cpu) in enumerate(tqdm_iter): 155 | x,y=x_cpu.to(device),y_cpu.to(device) 156 | 157 | if mode=='train': 158 | optimizer.zero_grad() 159 | 160 | if mode=='train': 161 | logits,loss=model(x,y.view(-1)) 162 | else: 163 | logits,loss=model(x,y.view(-1)) 164 | 165 | 166 | if loss_func is not None: 167 | summary['losses']+=[loss.item()] 168 | 169 | if mode=='train': 170 | loss.backward(retain_graph=True) 171 | optimizer.step() 172 | 173 | #display 174 | if loss_func is not None and i%loss_interval==0: 175 | tqdm_iter.set_description("Loss: {:.3f}".format(np.mean(summary['losses']))) 176 | 
177 | else: 178 | log=logits.cpu().detach().numpy() 179 | lab=y_cpu.numpy() 180 | 181 | mean_acc=get_cls_accuracy(log,lab) 182 | summary['acc'].append(mean_acc) 183 | 184 | 185 | # summary['logits']+=[logits.cpu().detach().numpy()] 186 | # summary['labels']+=[y_cpu.numpy()] 187 | if i%loss_interval==0: 188 | tqdm_iter.set_description("mea_ac: %.3f"%(np.mean(summary['acc']))) 189 | 190 | 191 | if mode!='train': 192 | for param,value in zip(model.parameters(),param_grads): 193 | param.requires_grad=value 194 | 195 | 196 | return summary 197 | 198 | 199 | if __name__=='__main__': 200 | main() 201 | -------------------------------------------------------------------------------- /main_scanobj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn.modules import module 4 | 5 | # from Dataloader.ModelNet40 import get_sets 6 | from Dataloader.scanobjectnn import get_sets 7 | 8 | from utils.test_perform_cal import get_cls_accuracy 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | from torch.utils.tensorboard import SummaryWriter 12 | from datetime import datetime 13 | import os 14 | import shutil 15 | from utils.cal_final_result import accuracy_calculation 16 | from utils.all_utils import smooth_loss 17 | # from model.cls.DPFA import DGCNN_new_cls 18 | import random 19 | 20 | 21 | import time 22 | import argparse 23 | 24 | 25 | 26 | 27 | def get_parse(): 28 | parser=argparse.ArgumentParser(description='argumment') 29 | parser.add_argument('--exp_name',type=str,default='DGCNN_scanobj_exp') 30 | parser.add_argument('--train',default=True) 31 | parser.add_argument('--seed',default=0) 32 | parser.add_argument('--batch_size',default=16) 33 | parser.add_argument('--data_path',default='/data1/jiajing/dataset/scanobjectnn/main_split_nobg') 34 | parser.add_argument('--lr',default=0.001) 35 | return parser.parse_args() 36 | 37 | 38 | cfg=get_parse() 39 | 40 | 41 | 42 | def main(): 43 | random.seed(cfg.seed) 44 | np.random.seed(cfg.seed) 45 | torch.manual_seed(cfg.seed) 46 | torch.cuda.manual_seed(cfg.seed) 47 | torch.backends.cudnn.deterministic = True 48 | torch.backends.cudnn.benchmark = False 49 | 50 | 51 | 52 | 53 | torch.backends.cudnn.enabled=False 54 | 55 | cuda=0 56 | datapath=cfg.data_path 57 | 58 | 59 | 60 | from model.cls.DGCNN import DGCNN 61 | model=DGCNN(15) 62 | 63 | 64 | train_loader,test_loader,valid_loader=get_sets(datapath,train_batch_size=cfg.batch_size,test_batch_size=cfg.batch_size) 65 | 66 | 67 | train_model(model,train_loader,valid_loader,cfg.exp_name,cuda) 68 | 69 | 70 | 71 | 72 | 73 | 74 | def train_model(model,train_loader,valid_loader,exp_name,cuda_n): 75 | assert torch.cuda.is_available() 76 | epoch_acc=[] 77 | 78 | device=torch.device('cuda:{}'.format(cuda_n)) 79 | model=model.to(device) 80 | 81 | 82 | initial_epoch=0 83 | training_epoch=350 84 | 85 | loss_func=smooth_loss 86 | optimizer=torch.optim.Adam(model.parameters(),lr=cfg.lr) 87 | lr_schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=np.arange(10,training_epoch,40),gamma=0.7) 88 | 89 | #here we define train_one_epoch 90 | def train_one_epoch(): 91 | iterations=tqdm(train_loader,ncols=100,unit='batch',leave=False) 92 | epsum=run_one_epoch(model,iterations,"train",loss_func=loss_func,optimizer=optimizer,loss_interval=10) 93 | 94 | summary={"loss/train":np.mean(epsum['losses'])} 95 | return summary 96 | 97 | 98 | def eval_one_epoch(): 99 | iteration=tqdm(valid_loader,ncols=100,unit='batch',leave=False) 100 | 101 
| epsum=run_one_epoch(model,iteration,"valid",loss_func=loss_func) 102 | mean_acc=np.mean(epsum['acc']) 103 | 104 | epoch_acc.append(mean_acc) 105 | 106 | summary={'meac':mean_acc} 107 | summary["loss/valid"]=np.mean(epsum['losses']) 108 | return summary 109 | 110 | 111 | 112 | #build tensorboard 113 | 114 | # tensorboard=SummaryWriter(log_dir='.Exp/{}/TB'.format(exp_name)) 115 | tqdm_epoch=tqdm(range(initial_epoch,training_epoch),unit='epoch',ncols=100) 116 | 117 | #build folder for pth_file 118 | exp_path=os.path.join('./Exp',exp_name) 119 | pth_path=os.path.join(exp_path,'pth_file') 120 | tensorboard_path=os.path.join(exp_path,'TB') 121 | if not os.path.exists(exp_path): 122 | os.mkdir(exp_path) 123 | os.mkdir(pth_path) 124 | os.mkdir(tensorboard_path) 125 | # pth_save_path=os.path.join('Exp',exp_name,'pth_file') 126 | # if not os.path.exists(pth_save_path): 127 | # os.mkdir(pth_save_path) 128 | 129 | tensorboard=SummaryWriter(log_dir=tensorboard_path) 130 | 131 | 132 | for e in tqdm_epoch: 133 | train_summary=train_one_epoch() 134 | valid_summary=eval_one_epoch() 135 | summary={**train_summary,**valid_summary} 136 | lr_schedule.step() 137 | #save checkpoint 138 | if np.max(epoch_acc)==epoch_acc[-1]: 139 | summary_saved={**summary, 140 | 'model_state':model.state_dict(), 141 | 'optimizer_state':optimizer.state_dict()} 142 | 143 | 144 | # torch.save(summary_saved,'./pth_file/{0}/epoch_{1}'.format(exp_name,e)) 145 | torch.save(summary_saved,os.path.join(pth_path,'epoch_{}'.format(e))) 146 | 147 | for name,val in summary.items(): 148 | tensorboard.add_scalar(name,val,e) 149 | 150 | 151 | 152 | def run_one_epoch(model,tqdm_iter,mode,loss_func=None,optimizer=None,loss_interval=10): 153 | if mode=='train': 154 | model.train() 155 | else: 156 | model.eval() 157 | param_grads=[] 158 | for param in model.parameters(): 159 | param_grads+=[param.requires_grad] 160 | param.requires_grad=False 161 | 162 | summary={"losses":[],"acc":[]} 163 | device=next(model.parameters()).device 164 | 165 | for i,(x_cpu,y_cpu) in enumerate(tqdm_iter): 166 | x,y=x_cpu.to(device),y_cpu.to(device) 167 | 168 | if mode=='train': 169 | optimizer.zero_grad() 170 | 171 | #logtis' shape is [batch,40] 172 | #y size is [batch,1] 173 | if mode=='train': 174 | logits=model(x) 175 | else: 176 | logits=model(x) 177 | 178 | 179 | if loss_func is not None: 180 | re_logit=logits.reshape(-1,logits.shape[-1]) 181 | 182 | 183 | #### here is the loss ##### 184 | loss=loss_func(re_logit,y.view(-1)) 185 | summary['losses']+=[loss.item()] 186 | 187 | if mode=='train': 188 | loss.backward(retain_graph=True) 189 | optimizer.step() 190 | 191 | #display 192 | if loss_func is not None and i%loss_interval==0: 193 | tqdm_iter.set_description("Loss: {:.3f}".format(np.mean(summary['losses']))) 194 | 195 | else: 196 | log=logits.cpu().detach().numpy() 197 | lab=y_cpu.numpy() 198 | 199 | mean_acc=get_cls_accuracy(log,lab) 200 | summary['acc'].append(mean_acc) 201 | 202 | if i%loss_interval==0: 203 | tqdm_iter.set_description("mea_ac: %.3f"%(np.mean(summary['acc']))) 204 | 205 | 206 | if mode!='train': 207 | for param,value in zip(model.parameters(),param_grads): 208 | param.requires_grad=value 209 | 210 | # summary["logits"] = np.concatenate(summary["logits"], axis=0) 211 | # summary["labels"] = np.concatenate(summary["labels"], axis=0) 212 | 213 | return summary 214 | 215 | 216 | if __name__=='__main__': 217 | main() 218 | -------------------------------------------------------------------------------- /main_scanobj_ref.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn.modules import module 4 | 5 | from Dataloader.scanobjectnn import get_sets 6 | 7 | from utils.test_perform_cal import get_cls_accuracy 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | from torch.utils.tensorboard import SummaryWriter 11 | from datetime import datetime 12 | import os 13 | import shutil 14 | from utils.cal_final_result import accuracy_calculation 15 | from utils.all_utils import smooth_loss 16 | 17 | 18 | # from model.cls.DGCNN_repmax import DGCNN_ref 19 | # from model.cls.PointNet_pp_repmax import PointNet_pp 20 | # from model.cls.PointNet_repmax import PointNet_hie_ref 21 | # from model.cls.curvenet_repmax import CurveNet_repmax 22 | 23 | import random 24 | 25 | 26 | import time 27 | import argparse 28 | 29 | 30 | 31 | def get_parse(): 32 | parser=argparse.ArgumentParser(description='argumment') 33 | parser.add_argument('--exp_name',type=str,default='DGCNN_ref_scanobj_exp') 34 | parser.add_argument('--seed',default=0) 35 | parser.add_argument('--batch_size',default=16) 36 | parser.add_argument('--data_path',default='/data1/jiajing/dataset/scanobjectnn/main_split_nobg') 37 | parser.add_argument('--lr',default=0.001) 38 | return parser.parse_args() 39 | 40 | 41 | cfg=get_parse() 42 | 43 | 44 | 45 | 46 | 47 | 48 | def main(): 49 | seed=cfg.seed 50 | random.seed(seed) 51 | np.random.seed(seed) 52 | torch.manual_seed(seed) 53 | torch.cuda.manual_seed(seed) 54 | torch.backends.cudnn.deterministic = True 55 | torch.backends.cudnn.benchmark = False 56 | torch.backends.cudnn.enabled=False 57 | 58 | 59 | cuda=0 60 | datapath=cfg.data_path 61 | 62 | 63 | from model.cls.DGCNN_repmax import DGCNN_ref 64 | model=DGCNN_ref(40,0.8,2.1) 65 | 66 | train_loader,test_loader,valid_loader=get_sets(datapath,train_batch_size=cfg.batch_size,test_batch_size=cfg.batch_size) 67 | 68 | train_model(model,train_loader,valid_loader,cfg.exp_name,cuda) 69 | 70 | 71 | 72 | 73 | 74 | 75 | def train_model(model,train_loader,valid_loader,exp_name,cuda_n): 76 | assert torch.cuda.is_available() 77 | epoch_acc=[] 78 | 79 | #这里应该用GPU 80 | device=torch.device('cuda:{}'.format(cuda_n)) 81 | model=model.to(device) 82 | 83 | 84 | initial_epoch=0 85 | training_epoch=350 86 | 87 | loss_func=smooth_loss 88 | optimizer=torch.optim.Adam(model.parameters(),lr=0.001) 89 | lr_schedule=torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=np.arange(10,training_epoch,40),gamma=0.7) 90 | 91 | 92 | 93 | #here we define train_one_epoch 94 | def train_one_epoch(): 95 | iterations=tqdm(train_loader,ncols=100,unit='batch',leave=False) 96 | 97 | #真正训练这里应该解封 98 | epsum=run_one_epoch(model,iterations,"train",loss_func=loss_func,optimizer=optimizer,loss_interval=10) 99 | 100 | summary={"loss/train":np.mean(epsum['losses'])} 101 | return summary 102 | 103 | 104 | def eval_one_epoch(): 105 | iteration=tqdm(valid_loader,ncols=100,unit='batch',leave=False) 106 | #epsum only have logit and labes 107 | #epsum['logti'] is (batch,4096,13) 108 | #epsum['labels] is (batch,4096) 109 | 110 | epsum=run_one_epoch(model,iteration,"valid",loss_func=loss_func) 111 | mean_acc=np.mean(epsum['acc']) 112 | 113 | epoch_acc.append(mean_acc) 114 | 115 | summary={'meac':mean_acc} 116 | summary["loss/valid"]=np.mean(epsum['losses']) 117 | return summary 118 | 119 | 120 | 121 | #build tensorboard 122 | 123 | # tensorboard=SummaryWriter(log_dir='.Exp/{}/TB'.format(exp_name)) 124 | 
tqdm_epoch=tqdm(range(initial_epoch,training_epoch),unit='epoch',ncols=100) 125 | 126 | #build folder for pth_file 127 | exp_path=os.path.join('./Exp',exp_name) 128 | pth_path=os.path.join(exp_path,'pth_file') 129 | tensorboard_path=os.path.join(exp_path,'TB') 130 | if not os.path.exists(exp_path): 131 | os.mkdir(exp_path) 132 | os.mkdir(pth_path) 133 | os.mkdir(tensorboard_path) 134 | # pth_save_path=os.path.join('Exp',exp_name,'pth_file') 135 | # if not os.path.exists(pth_save_path): 136 | # os.mkdir(pth_save_path) 137 | 138 | tensorboard=SummaryWriter(log_dir=tensorboard_path) 139 | 140 | 141 | for e in tqdm_epoch: 142 | train_summary=train_one_epoch() 143 | valid_summary=eval_one_epoch() 144 | summary={**train_summary,**valid_summary} 145 | lr_schedule.step() 146 | #save checkpoint 147 | if np.max(epoch_acc)==epoch_acc[-1]: 148 | summary_saved={**summary, 149 | 'model_state':model.state_dict(), 150 | 'optimizer_state':optimizer.state_dict()} 151 | 152 | 153 | # torch.save(summary_saved,'./pth_file/{0}/epoch_{1}'.format(exp_name,e)) 154 | torch.save(summary_saved,os.path.join(pth_path,'epoch_{}'.format(e))) 155 | 156 | for name,val in summary.items(): 157 | tensorboard.add_scalar(name,val,e) 158 | 159 | 160 | 161 | def run_one_epoch(model,tqdm_iter,mode,loss_func=None,optimizer=None,loss_interval=10): 162 | if mode=='train': 163 | model.train() 164 | else: 165 | model.eval() 166 | param_grads=[] 167 | for param in model.parameters(): 168 | param_grads+=[param.requires_grad] 169 | param.requires_grad=False 170 | 171 | summary={"losses":[],"acc":[]} 172 | device=next(model.parameters()).device 173 | 174 | for i,(x_cpu,y_cpu) in enumerate(tqdm_iter): 175 | x,y=x_cpu.to(device),y_cpu.to(device) 176 | 177 | if mode=='train': 178 | optimizer.zero_grad() 179 | 180 | #logtis' shape is [batch,40] 181 | #y size is [batch,1] 182 | if mode=='train': 183 | logits,loss=model(x,y.view(-1)) 184 | else: 185 | logits,loss=model(x,y.view(-1)) 186 | 187 | 188 | if loss_func is not None: 189 | summary['losses']+=[loss.item()] 190 | 191 | if mode=='train': 192 | loss.backward(retain_graph=True) 193 | optimizer.step() 194 | 195 | #display 196 | if loss_func is not None and i%loss_interval==0: 197 | tqdm_iter.set_description("Loss: {:.3f}".format(np.mean(summary['losses']))) 198 | 199 | else: 200 | log=logits.cpu().detach().numpy() 201 | lab=y_cpu.numpy() 202 | 203 | mean_acc=get_cls_accuracy(log,lab) 204 | summary['acc'].append(mean_acc) 205 | 206 | if i%loss_interval==0: 207 | tqdm_iter.set_description("mea_ac: %.3f"%(np.mean(summary['acc']))) 208 | 209 | 210 | if mode!='train': 211 | for param,value in zip(model.parameters(),param_grads): 212 | param.requires_grad=value 213 | 214 | 215 | return summary 216 | 217 | 218 | if __name__=='__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /model/cls/DGCNN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | def knn(x, k): 12 | inner = -2*torch.matmul(x.transpose(2, 1), x) 13 | xx = torch.sum(x**2, dim=1, keepdim=True) 14 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 15 | 16 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 17 | return idx 18 | 19 | 20 | def get_graph_feature(x, k=20, idx=None): 21 | batch_size = x.size(0) 22 | num_points = x.size(2) 23 | x = 
x.view(batch_size, -1, num_points) 24 | if idx is None: 25 | idx = knn(x, k=k) # (batch_size, num_points, k) 26 | device = x.device 27 | 28 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points 29 | 30 | idx = idx + idx_base 31 | 32 | idx = idx.view(-1) 33 | 34 | _, num_dims, _ = x.size() 35 | 36 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points) 37 | feature = x.view(batch_size*num_points, -1)[idx, :] 38 | feature = feature.view(batch_size, num_points, k, num_dims) 39 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 40 | 41 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous() 42 | 43 | return feature 44 | 45 | 46 | 47 | class DGCNN(nn.Module): 48 | def __init__(self, output_channels=40): 49 | super(DGCNN, self).__init__() 50 | self.k = 20 51 | 52 | self.bn1 = nn.BatchNorm2d(64) 53 | self.bn2 = nn.BatchNorm2d(64) 54 | self.bn3 = nn.BatchNorm2d(128) 55 | self.bn4 = nn.BatchNorm2d(256) 56 | self.bn5 = nn.BatchNorm1d(1024) 57 | 58 | self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), 59 | self.bn1, 60 | nn.LeakyReLU(negative_slope=0.2)) 61 | self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), 62 | self.bn2, 63 | nn.LeakyReLU(negative_slope=0.2)) 64 | self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False), 65 | self.bn3, 66 | nn.LeakyReLU(negative_slope=0.2)) 67 | self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False), 68 | self.bn4, 69 | nn.LeakyReLU(negative_slope=0.2)) 70 | self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 71 | self.bn5, 72 | nn.LeakyReLU(negative_slope=0.2)) 73 | self.linear1 = nn.Linear(1024, 512, bias=False) 74 | self.bn6 = nn.BatchNorm1d(512) 75 | self.dp1 = nn.Dropout(p=0.5) 76 | self.linear2 = nn.Linear(512, 256) 77 | self.bn7 = nn.BatchNorm1d(256) 78 | self.dp2 = nn.Dropout(p=0.5) 79 | self.linear3 = nn.Linear(256, output_channels) 80 | 81 | def forward(self, x): 82 | batch_size = x.size(0) 83 | x = get_graph_feature(x, k=self.k) 84 | x = self.conv1(x) 85 | x1 = x.max(dim=-1, keepdim=False)[0] 86 | 87 | 88 | x = get_graph_feature(x1, k=self.k) 89 | x = self.conv2(x) 90 | x2 = x.max(dim=-1, keepdim=False)[0] 91 | 92 | 93 | 94 | x = get_graph_feature(x2, k=self.k) 95 | x = self.conv3(x) 96 | x3 = x.max(dim=-1, keepdim=False)[0] 97 | 98 | x = get_graph_feature(x3, k=self.k) 99 | x = self.conv4(x) 100 | x4 = x.max(dim=-1, keepdim=False)[0] 101 | 102 | x = torch.cat((x1, x2, x3, x4), dim=1) 103 | 104 | x = self.conv5(x) 105 | x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 106 | 107 | x = F.leaky_relu(self.bn6(self.linear1(x1)), negative_slope=0.2) 108 | x = self.dp1(x) 109 | x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2) 110 | x = self.dp2(x) 111 | x = self.linear3(x) 112 | return x 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | if __name__=='__main__': 134 | inpt=torch.rand((10,3,1024)) 135 | # network=DGCNN(output_channels=40) 136 | network=DGCNN(output_channels=40) 137 | out=network(inpt) 138 | -------------------------------------------------------------------------------- /model/cls/DGCNN_repmax.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import math 5 | import numpy as np 6 | 
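# Two building blocks of this file, sketched briefly:
# - cal_loss_raw() below is the per-sample variant of the label-smoothing loss:
#   it returns (loss.mean(), loss_raw), and the per-sample loss_raw is what
#   get_aug_loss() needs to weight the consistency term between the predictions
#   of the two refinement levels sample by sample.
# - knn()/get_graph_feature() (identical to the copies in DGCNN.py above) build
#   the EdgeConv input: nearest neighbors come from the expansion
#     -||x_i - x_j||^2 = 2*x_i.x_j - ||x_i||^2 - ||x_j||^2
#   (topk on that value picks the k smallest distances), and for each point i
#   the feature [x_j - x_i, x_i] is stacked over its k neighbors, turning a
#   (B, C, N) tensor into (B, 2C, N, k).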
import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | def cal_loss_raw(pred, gold): 11 | ''' Calculate cross entropy loss, apply label smoothing if needed. ''' 12 | 13 | gold = gold.contiguous().view(-1) 14 | 15 | eps = 0.2 16 | n_class = pred.size(1) 17 | 18 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 19 | #one_hot = F.one_hot(gold, pred.shape[1]).float() 20 | 21 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 22 | log_prb = F.log_softmax(pred, dim=1) 23 | 24 | loss_raw = -(one_hot * log_prb).sum(dim=1) 25 | 26 | 27 | loss = loss_raw.mean() 28 | 29 | return loss,loss_raw 30 | 31 | 32 | 33 | 34 | def knn(x, k): 35 | inner = -2*torch.matmul(x.transpose(2, 1), x) 36 | xx = torch.sum(x**2, dim=1, keepdim=True) 37 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 38 | 39 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 40 | return idx 41 | 42 | 43 | def get_graph_feature(x, k=20, idx=None): 44 | batch_size = x.size(0) 45 | num_points = x.size(2) 46 | x = x.view(batch_size, -1, num_points) 47 | if idx is None: 48 | idx = knn(x, k=k) # (batch_size, num_points, k) 49 | device = x.device 50 | 51 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points 52 | 53 | idx = idx + idx_base 54 | 55 | idx = idx.view(-1) 56 | 57 | _, num_dims, _ = x.size() 58 | 59 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points) 60 | feature = x.view(batch_size*num_points, -1)[idx, :] 61 | feature = feature.view(batch_size, num_points, k, num_dims) 62 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 63 | 64 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous() 65 | 66 | return feature 67 | 68 | 69 | 70 | 71 | class DGCNN_ref(nn.Module): 72 | def __init__(self, output_channels=40,lamda=None,alpha=2.1): 73 | super(DGCNN_ref, self).__init__() 74 | self.k = 20 75 | 76 | self.bn1 = nn.BatchNorm2d(64) 77 | self.bn2 = nn.BatchNorm2d(64) 78 | self.bn3 = nn.BatchNorm2d(128) 79 | self.bn4 = nn.BatchNorm2d(256) 80 | self.bn5 = nn.BatchNorm1d(1024) 81 | 82 | self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), 83 | self.bn1, 84 | nn.LeakyReLU(negative_slope=0.2)) 85 | self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), 86 | self.bn2, 87 | nn.LeakyReLU(negative_slope=0.2)) 88 | self.conv3 = nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False), 89 | self.bn3, 90 | nn.LeakyReLU(negative_slope=0.2)) 91 | self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False), 92 | self.bn4, 93 | nn.LeakyReLU(negative_slope=0.2)) 94 | self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 95 | self.bn5, 96 | nn.LeakyReLU(negative_slope=0.2)) 97 | 98 | self.linear1 = nn.Linear(1024, 512, bias=False) 99 | self.bn6 = nn.BatchNorm1d(512) 100 | self.dp1 = nn.Dropout(p=0.5) 101 | self.linear2 = nn.Linear(512, 256) 102 | self.bn7 = nn.BatchNorm1d(256) 103 | self.dp2 = nn.Dropout(p=0.5) 104 | self.linear3 = nn.Linear(256, output_channels) 105 | 106 | self.head=nn.Sequential(self.linear1, 107 | self.bn6, 108 | nn.LeakyReLU(0.2), 109 | self.dp1, 110 | self.linear2, 111 | self.bn7, 112 | nn.LeakyReLU(0.2), 113 | self.dp2, 114 | self.linear3) 115 | 116 | 117 | self.refine_time=2 118 | self.NUM=alpha 119 | self.lamda=lamda 120 | 121 | def 
get_legal_id(self,used_index_list,data_index,num_point): 122 | mask=torch.zeros(num_point) 123 | used_index=used_index_list[data_index] 124 | mask[used_index]=1 125 | legal_index=torch.where(mask==0)[0] 126 | return legal_index 127 | 128 | def feature_refinement(self,point_feat): 129 | device=point_feat.device 130 | num_point=point_feat.shape[2] 131 | batch_size=point_feat.shape[0] 132 | 133 | feat_list=[] 134 | used_index_list=[torch.LongTensor([]).to(device) for _ in range(batch_size)] 135 | 136 | for i in range(self.refine_time): 137 | hie_feat_list=[] 138 | for data_index,single_data in enumerate(point_feat): 139 | legal_index=self.get_legal_id(used_index_list,data_index,num_point) 140 | legal_feat=single_data[:,legal_index] 141 | 142 | max_feat,max_index=torch.max(legal_feat,-1) 143 | max_index=torch.unique(max_index).detach() 144 | hie_feat_list.append(max_feat) 145 | used_index_list[data_index]=torch.cat((used_index_list[data_index],max_index)) 146 | 147 | hie_feat_list=torch.stack(hie_feat_list,0) 148 | feat_list.append(hie_feat_list) 149 | 150 | feat_list=torch.stack(feat_list,0) 151 | # feat_list=feat_list.permute(1,0,2) 152 | 153 | return feat_list 154 | 155 | 156 | def get_aug_loss(self,inv_feat,y): 157 | 158 | device=inv_feat.device 159 | 160 | pred_1=self.head(inv_feat[:,:,0]) 161 | pred_2=self.head(inv_feat[:,:,1]) 162 | 163 | pred1_loss,pred1_row_loss=cal_loss_raw(pred_1,y) 164 | pred2_loss,pred2_row_loss=cal_loss_raw(pred_2,y) 165 | 166 | pc_con = F.softmax(pred_1, dim=-1)#.max(dim=1)[0] 167 | one_hot = F.one_hot(y, pred_1.shape[1]).float() 168 | pc_con = (pc_con*one_hot).max(dim=1)[0] 169 | 170 | parameters = torch.max(torch.tensor(self.NUM).to(device), torch.exp(pc_con) * self.NUM).to(device) 171 | aug_diff = torch.abs(1.0 - torch.exp(pred2_row_loss - pred1_row_loss * parameters)).mean() 172 | 173 | if self.lamda==None: 174 | loss=pred1_loss+pred2_loss+aug_diff 175 | else: 176 | loss=(1-self.lamda)*(pred1_loss+pred2_loss)+self.lamda*aug_diff 177 | 178 | return pred_1,loss 179 | 180 | 181 | 182 | 183 | 184 | 185 | def forward(self,x,y): 186 | # y=torch.LongTensor([0,0]).cuda() 187 | 188 | batch_size = x.size(0) 189 | x = get_graph_feature(x, k=self.k) 190 | x = self.conv1(x) 191 | x1 = x.max(dim=-1, keepdim=False)[0] 192 | 193 | 194 | x = get_graph_feature(x1, k=self.k) 195 | x = self.conv2(x) 196 | x2 = x.max(dim=-1, keepdim=False)[0] 197 | 198 | 199 | 200 | x = get_graph_feature(x2, k=self.k) 201 | x = self.conv3(x) 202 | x3 = x.max(dim=-1, keepdim=False)[0] 203 | 204 | x = get_graph_feature(x3, k=self.k) 205 | x = self.conv4(x) 206 | x4 = x.max(dim=-1, keepdim=False)[0] 207 | 208 | x = torch.cat((x1, x2, x3, x4), dim=1) 209 | 210 | x = self.conv5(x) 211 | inv_feat=self.feature_refinement(x).permute(1,2,0) 212 | pred1,loss = self.get_aug_loss(inv_feat,y) 213 | return pred1,loss 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | class DGCNN_hie_ref(nn.Module): 223 | def __init__(self, output_channels=40): 224 | super(DGCNN_hie_ref, self).__init__() 225 | self.k = 20 226 | 227 | self.bn1 = nn.BatchNorm2d(64) 228 | self.bn2 = nn.BatchNorm2d(64) 229 | self.bn3 = nn.BatchNorm2d(128) 230 | self.bn4 = nn.BatchNorm2d(256) 231 | self.bn5 = nn.BatchNorm1d(1024) 232 | 233 | self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False), 234 | self.bn1, 235 | nn.LeakyReLU(negative_slope=0.2)) 236 | self.conv2 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), 237 | self.bn2, 238 | nn.LeakyReLU(negative_slope=0.2)) 239 | self.conv3 = 
nn.Sequential(nn.Conv2d(64*2, 128, kernel_size=1, bias=False), 240 | self.bn3, 241 | nn.LeakyReLU(negative_slope=0.2)) 242 | self.conv4 = nn.Sequential(nn.Conv2d(128*2, 256, kernel_size=1, bias=False), 243 | self.bn4, 244 | nn.LeakyReLU(negative_slope=0.2)) 245 | self.conv5 = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 246 | self.bn5, 247 | nn.LeakyReLU(negative_slope=0.2)) 248 | 249 | self.linear1 = nn.Linear(1024, 512, bias=False) 250 | self.bn6 = nn.BatchNorm1d(512) 251 | self.dp1 = nn.Dropout(p=0.5) 252 | self.linear2 = nn.Linear(512, 256) 253 | self.bn7 = nn.BatchNorm1d(256) 254 | self.dp2 = nn.Dropout(p=0.5) 255 | self.linear3 = nn.Linear(256, output_channels) 256 | 257 | self.head=nn.Sequential(self.linear1, 258 | self.bn6, 259 | nn.LeakyReLU(0.2), 260 | self.dp1, 261 | self.linear2, 262 | self.bn7, 263 | nn.LeakyReLU(0.2), 264 | self.dp2, 265 | self.linear3) 266 | 267 | 268 | self.refine_time=4 269 | self.NUM=1.5 270 | 271 | ####### if hie, the previous level works for next level, 272 | ####### if not hie, all level serves for the first level 273 | self.hie=True 274 | 275 | 276 | def get_legal_id(self,used_index_list,data_index,num_point): 277 | mask=torch.zeros(num_point) 278 | used_index=used_index_list[data_index] 279 | mask[used_index]=1 280 | legal_index=torch.where(mask==0)[0] 281 | return legal_index 282 | 283 | def feature_refinement(self,point_feat): 284 | device=point_feat.device 285 | num_point=point_feat.shape[2] 286 | batch_size=point_feat.shape[0] 287 | 288 | feat_list=[] 289 | used_index_list=[torch.LongTensor([]).to(device) for _ in range(batch_size)] 290 | 291 | for i in range(self.refine_time): 292 | hie_feat_list=[] 293 | for data_index,single_data in enumerate(point_feat): 294 | legal_index=self.get_legal_id(used_index_list,data_index,num_point) 295 | legal_feat=single_data[:,legal_index] 296 | 297 | max_feat,max_index=torch.max(legal_feat,-1) 298 | max_index=torch.unique(max_index).detach() 299 | hie_feat_list.append(max_feat) 300 | used_index_list[data_index]=torch.cat((used_index_list[data_index],max_index)) 301 | 302 | hie_feat_list=torch.stack(hie_feat_list,0) 303 | feat_list.append(hie_feat_list) 304 | 305 | feat_list=torch.stack(feat_list,0) 306 | # feat_list=feat_list.permute(1,0,2) 307 | 308 | return feat_list 309 | 310 | 311 | def get_aug_loss(self,inv_feat,y): 312 | if self.hie: 313 | level_num=-2 314 | else: 315 | level_num=0 316 | 317 | 318 | device=inv_feat.device 319 | iter_time=inv_feat.shape[-1] 320 | 321 | pred_list=[] 322 | # pred_loss_list=[] 323 | pred_row_loss_list=[] 324 | 325 | self.num_list=[self.NUM,self.NUM] 326 | 327 | pred_loss_total=0 328 | aug_diff=0 329 | for i in range(iter_time): 330 | pred=self.head(inv_feat[:,:,i]) 331 | pred_list.append(pred) 332 | 333 | pred_loss,pred_row_loss=cal_loss_raw(pred,y) 334 | # pred_loss_list.append(pred_loss) 335 | pred_loss_total+=pred_loss 336 | pred_row_loss_list.append(pred_row_loss) 337 | 338 | if i!=0: 339 | pc_con=F.softmax(pred_list[level_num],dim=-1) 340 | one_hot=F.one_hot(y,pred_list[level_num].shape[1]).float() 341 | pc_con = (pc_con*one_hot).max(dim=1)[0] 342 | parameters = torch.max(torch.tensor(self.num_list[i-1]).to(device), torch.exp(pc_con) * self.num_list[i-1]).to(device) 343 | aug_diff += torch.abs(1.0 - torch.exp(pred_row_loss_list[-1] - pred_row_loss_list[level_num] * parameters)).mean() 344 | 345 | total_loss=aug_diff+pred_loss_total 346 | return pred_list[0],total_loss 347 | 348 | 349 | 350 | 351 | 352 | def forward(self, x,y): 353 | batch_size = 
x.size(0) 354 | x = get_graph_feature(x, k=self.k) 355 | x = self.conv1(x) 356 | x1 = x.max(dim=-1, keepdim=False)[0] 357 | 358 | 359 | x = get_graph_feature(x1, k=self.k) 360 | x = self.conv2(x) 361 | x2 = x.max(dim=-1, keepdim=False)[0] 362 | 363 | 364 | 365 | x = get_graph_feature(x2, k=self.k) 366 | x = self.conv3(x) 367 | x3 = x.max(dim=-1, keepdim=False)[0] 368 | 369 | x = get_graph_feature(x3, k=self.k) 370 | x = self.conv4(x) 371 | x4 = x.max(dim=-1, keepdim=False)[0] 372 | 373 | x = torch.cat((x1, x2, x3, x4), dim=1) 374 | 375 | x = self.conv5(x) 376 | inv_feat=self.feature_refinement(x).permute(1,2,0) 377 | pred1,loss = self.get_aug_loss(inv_feat,y) 378 | return pred1,loss 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | if __name__=='__main__': 389 | inpt=torch.rand((5,3,1024)) 390 | network=DGCNN_ref(output_channels=40) 391 | label=torch.LongTensor([0,0,0,0,0]) 392 | 393 | out=network(inpt,label) 394 | -------------------------------------------------------------------------------- /model/seg/DGCNN_seg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import copy 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | 12 | def knn(x, k): 13 | inner = -2*torch.matmul(x.transpose(2, 1), x) 14 | xx = torch.sum(x**2, dim=1, keepdim=True) 15 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 16 | 17 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 18 | return idx 19 | 20 | 21 | def get_graph_feature(x, k=20, idx=None, dim9=False): 22 | batch_size = x.size(0) 23 | num_points = x.size(2) 24 | x = x.view(batch_size, -1, num_points) 25 | if idx is None: 26 | if dim9 == False: 27 | idx = knn(x, k=k) # (batch_size, num_points, k) 28 | else: 29 | idx = knn(x[:, 6:], k=k) 30 | device = x.device 31 | 32 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points 33 | 34 | idx = idx + idx_base 35 | 36 | idx = idx.view(-1) 37 | 38 | _, num_dims, _ = x.size() 39 | 40 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points) 41 | feature = x.view(batch_size*num_points, -1)[idx, :] 42 | feature = feature.view(batch_size, num_points, k, num_dims) 43 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 44 | 45 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous() 46 | 47 | return feature # (batch_size, 2*num_dims, num_points, k) 48 | 49 | 50 | 51 | 52 | class DGCNN_semseg(nn.Module): 53 | def __init__(self,num_cls,inpt_length): 54 | super(DGCNN_semseg, self).__init__() 55 | # self.args = args 56 | self.k = 20 57 | self.num_cls=num_cls 58 | self.inpt_length=inpt_length 59 | self.bn1 = nn.BatchNorm2d(64) 60 | self.bn2 = nn.BatchNorm2d(64) 61 | self.bn3 = nn.BatchNorm2d(64) 62 | self.bn4 = nn.BatchNorm2d(64) 63 | self.bn5 = nn.BatchNorm2d(64) 64 | # self.bn6 = nn.BatchNorm1d(args.emb_dims) 65 | self.bn6 = nn.BatchNorm1d(1024) 66 | self.bn7 = nn.BatchNorm1d(512) 67 | self.bn8 = nn.BatchNorm1d(256) 68 | 69 | self.conv1 = nn.Sequential(nn.Conv2d(2*self.inpt_length, 64, kernel_size=1, bias=False), 70 | self.bn1, 71 | nn.LeakyReLU(negative_slope=0.2)) 72 | self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), 73 | self.bn2, 74 | nn.LeakyReLU(negative_slope=0.2)) 75 | self.conv3 = 
nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
76 | self.bn3,
77 | nn.LeakyReLU(negative_slope=0.2))
78 | self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
79 | self.bn4,
80 | nn.LeakyReLU(negative_slope=0.2))
81 | self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
82 | self.bn5,
83 | nn.LeakyReLU(negative_slope=0.2))
84 | # self.conv6 = nn.Sequential(nn.Conv1d(192, args.emb_dims, kernel_size=1, bias=False),
85 | # self.bn6,
86 | # nn.LeakyReLU(negative_slope=0.2))
87 | self.conv6 = nn.Sequential(nn.Conv1d(192, 1024, kernel_size=1, bias=False),
88 | self.bn6,
89 | nn.LeakyReLU(negative_slope=0.2))
90 | self.conv7 = nn.Sequential(nn.Conv1d(1216, 512, kernel_size=1, bias=False),
91 | self.bn7,
92 | nn.LeakyReLU(negative_slope=0.2))
93 | self.conv8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False),
94 | self.bn8,
95 | nn.LeakyReLU(negative_slope=0.2))
96 | self.dp1 = nn.Dropout(p=0.5)
97 | self.conv9 = nn.Conv1d(256, self.num_cls, kernel_size=1, bias=False)
98 | 
99 | 
100 | def forward(self, x):
101 | x=x.permute(0,2,1)
102 | 
103 | batch_size = x.size(0)
104 | num_points = x.size(2)
105 | 
106 | x = get_graph_feature(x, k=self.k, dim9=True)   # (batch_size, 9, num_points) -> (batch_size, 9*2, num_points, k)
107 | x = self.conv1(x)   # (batch_size, 9*2, num_points, k) -> (batch_size, 64, num_points, k)
108 | x = self.conv2(x)   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
109 | x1 = x.max(dim=-1, keepdim=False)[0]   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
110 | 
111 | x = get_graph_feature(x1, k=self.k)   # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
112 | x = self.conv3(x)   # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
113 | x = self.conv4(x)   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
114 | x2 = x.max(dim=-1, keepdim=False)[0]   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
115 | 
116 | x = get_graph_feature(x2, k=self.k)   # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
117 | x = self.conv5(x)   # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
118 | x3 = x.max(dim=-1, keepdim=False)[0]   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
119 | 
120 | x = torch.cat((x1, x2, x3), dim=1)   # (batch_size, 64*3, num_points)
121 | 
122 | x = self.conv6(x)   # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
123 | x = x.max(dim=-1, keepdim=True)[0]   # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)
124 | 
125 | x = x.repeat(1, 1, num_points)   # (batch_size, 1024, num_points)
126 | x = torch.cat((x, x1, x2, x3), dim=1)   # (batch_size, 1024+64*3, num_points)
127 | 
128 | x = self.conv7(x)   # (batch_size, 1024+64*3, num_points) -> (batch_size, 512, num_points)
129 | x = self.conv8(x)   # (batch_size, 512, num_points) -> (batch_size, 256, num_points)
130 | x = self.dp1(x)
131 | x = self.conv9(x)   # (batch_size, 256, num_points) -> (batch_size, num_cls, num_points)
132 | 
133 | x=x.permute(0,2,1)
134 | 
135 | 
136 | return x
137 | 
138 | 
139 | if __name__=='__main__':
140 | inpt=torch.randn((5,4096,9))
141 | 
142 | 
143 | net=DGCNN_semseg(num_cls=13,inpt_length=9)
144 | out=net(inpt)   # the plain seg model takes only the points; labels are not needed
145 | 
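# The *_repmax variant in the next file replaces conv6's single global max-pool
# with an iterative, masked max-pool (feature_refinement): after each pooling
# round the per-channel argmax point indices are marked as used, and the next
# round pools only over the remaining points, yielding refine_time complementary
# global vectors that are classified separately and tied together by a
# consistency loss. A minimal sketch of one round, assuming a single (C, N)
# feature map with hypothetical sizes:
#   feat = torch.randn(1024, 4096)              # (C, N) per-point features
#   used = torch.zeros(4096, dtype=torch.bool)  # points already selected
#   legal = (~used).nonzero().squeeze(1)        # indices still allowed
#   max_feat, max_idx = feat[:, legal].max(-1)  # (C,) pooled feature
#   used[legal[max_idx.unique()]] = True        # mask the winners for round 2
-------------------------------------------------------------------------------- /model/seg/DGCNN_seg_repmax.py:
--------------------------------------------------------------------------------
1 | import 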
os 2 | import sys 3 | import copy 4 | import math 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.init as init 9 | import torch.nn.functional as F 10 | 11 | 12 | def knn(x, k): 13 | inner = -2*torch.matmul(x.transpose(2, 1), x) 14 | xx = torch.sum(x**2, dim=1, keepdim=True) 15 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 16 | 17 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 18 | return idx 19 | 20 | 21 | 22 | 23 | def cal_loss_raw(pred, gold): 24 | ''' Calculate cross entropy loss, apply label smoothing if needed. ''' 25 | 26 | gold = gold.contiguous().view(-1) 27 | 28 | eps = 0.2 29 | n_class = pred.size(1) 30 | 31 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 32 | #one_hot = F.one_hot(gold, pred.shape[1]).float() 33 | 34 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 35 | log_prb = F.log_softmax(pred, dim=1) 36 | 37 | loss_raw = -(one_hot * log_prb).sum(dim=1) 38 | 39 | 40 | loss = loss_raw.mean() 41 | 42 | return loss,loss_raw 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | def get_graph_feature(x, k=20, idx=None, dim9=False): 55 | batch_size = x.size(0) 56 | num_points = x.size(2) 57 | x = x.view(batch_size, -1, num_points) 58 | if idx is None: 59 | if dim9 == False: 60 | idx = knn(x, k=k) # (batch_size, num_points, k) 61 | else: 62 | idx = knn(x[:, 6:], k=k) 63 | device = x.device 64 | 65 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1)*num_points 66 | 67 | idx = idx + idx_base 68 | 69 | idx = idx.view(-1) 70 | 71 | _, num_dims, _ = x.size() 72 | 73 | x = x.transpose(2, 1).contiguous() # (batch_size, num_points, num_dims) -> (batch_size*num_points, num_dims) # batch_size * num_points * k + range(0, batch_size*num_points) 74 | feature = x.view(batch_size*num_points, -1)[idx, :] 75 | feature = feature.view(batch_size, num_points, k, num_dims) 76 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 77 | 78 | feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous() 79 | 80 | return feature # (batch_size, 2*num_dims, num_points, k) 81 | 82 | 83 | 84 | 85 | class DGCNN_semseg_ref(nn.Module): 86 | def __init__(self,num_cls,inpt_length): 87 | super(DGCNN_semseg_ref, self).__init__() 88 | # self.args = args 89 | self.k = 20 90 | self.num_cls=num_cls 91 | self.inpt_length=inpt_length 92 | self.bn1 = nn.BatchNorm2d(64) 93 | self.bn2 = nn.BatchNorm2d(64) 94 | self.bn3 = nn.BatchNorm2d(64) 95 | self.bn4 = nn.BatchNorm2d(64) 96 | self.bn5 = nn.BatchNorm2d(64) 97 | # self.bn6 = nn.BatchNorm1d(args.emb_dims) 98 | self.bn6 = nn.BatchNorm1d(1024) 99 | self.bn7 = nn.BatchNorm1d(512) 100 | self.bn8 = nn.BatchNorm1d(256) 101 | 102 | self.conv1 = nn.Sequential(nn.Conv2d(2*self.inpt_length, 64, kernel_size=1, bias=False), 103 | self.bn1, 104 | nn.LeakyReLU(negative_slope=0.2)) 105 | self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), 106 | self.bn2, 107 | nn.LeakyReLU(negative_slope=0.2)) 108 | self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), 109 | self.bn3, 110 | nn.LeakyReLU(negative_slope=0.2)) 111 | self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False), 112 | self.bn4, 113 | nn.LeakyReLU(negative_slope=0.2)) 114 | self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False), 115 | self.bn5, 116 | nn.LeakyReLU(negative_slope=0.2)) 117 | # self.conv6 = nn.Sequential(nn.Conv1d(192, args.emb_dims, kernel_size=1, bias=False), 118 | # self.bn6, 
119 | # nn.LeakyReLU(negative_slope=0.2)) 120 | self.conv6 = nn.Sequential(nn.Conv1d(192, 1024, kernel_size=1, bias=False), 121 | self.bn6, 122 | nn.LeakyReLU(negative_slope=0.2)) 123 | self.conv7 = nn.Sequential(nn.Conv1d(1216, 512, kernel_size=1, bias=False), 124 | self.bn7, 125 | nn.LeakyReLU(negative_slope=0.2)) 126 | self.conv8 = nn.Sequential(nn.Conv1d(512, 256, kernel_size=1, bias=False), 127 | self.bn8, 128 | nn.LeakyReLU(negative_slope=0.2)) 129 | self.dp1 = nn.Dropout(p=0.5) 130 | self.conv9 = nn.Conv1d(256, self.num_cls, kernel_size=1, bias=False) 131 | 132 | 133 | self.refine_time=2 134 | self.NUM=1.2 135 | 136 | self.head=nn.Sequential(self.conv7, 137 | self.conv8, 138 | self.dp1, 139 | self.conv9) 140 | 141 | 142 | def get_legal_id(self,used_index_list,data_index,num_point): 143 | mask=torch.zeros(num_point) 144 | used_index=used_index_list[data_index] 145 | mask[used_index]=1 146 | legal_index=torch.where(mask==0)[0] 147 | return legal_index 148 | 149 | def feature_refinement(self,point_feat): 150 | device=point_feat.device 151 | num_point=point_feat.shape[2] 152 | batch_size=point_feat.shape[0] 153 | 154 | feat_list=[] 155 | used_index_list=[torch.LongTensor([]).to(device) for _ in range(batch_size)] 156 | 157 | for i in range(self.refine_time): 158 | hie_feat_list=[] 159 | for data_index,single_data in enumerate(point_feat): 160 | legal_index=self.get_legal_id(used_index_list,data_index,num_point) 161 | legal_feat=single_data[:,legal_index] 162 | 163 | max_feat,max_index=torch.max(legal_feat,-1) 164 | max_index=torch.unique(max_index).detach() 165 | hie_feat_list.append(max_feat) 166 | used_index_list[data_index]=torch.cat((used_index_list[data_index],max_index)) 167 | 168 | hie_feat_list=torch.stack(hie_feat_list,0) 169 | feat_list.append(hie_feat_list) 170 | 171 | feat_list=torch.stack(feat_list,0) 172 | # feat_list=feat_list.permute(1,0,2) 173 | 174 | return feat_list 175 | 176 | 177 | def get_aug_loss(self,inv_feat,point_feat,y): 178 | 179 | _,_,num_point=point_feat.shape 180 | device=inv_feat.device 181 | 182 | 183 | pred_1=self.head(torch.cat([inv_feat[:,:,0].unsqueeze(-1).repeat(1,1,num_point),point_feat],1)).permute(0,2,1) 184 | pred_2=self.head(torch.cat([inv_feat[:,:,1].unsqueeze(-1).repeat(1,1,num_point),point_feat],1)).permute(0,2,1) 185 | 186 | pred1_loss,pred1_row_loss=cal_loss_raw(pred_1.reshape(-1,pred_1.shape[-1]),y.reshape(-1)) 187 | pred2_loss,pred2_row_loss=cal_loss_raw(pred_2.reshape(-1,pred_2.shape[-1]),y.reshape(-1)) 188 | 189 | pc_con = F.softmax(pred_1, dim=-1)#.max(dim=1)[0] 190 | one_hot = F.one_hot(y, pred_1.shape[-1]).float() 191 | pc_con = (pc_con*one_hot).max(dim=-1)[0] 192 | 193 | parameters = torch.max(torch.tensor(self.NUM).to(device), torch.exp(pc_con) * self.NUM).reshape(-1).to(device) 194 | aug_diff = torch.abs(1.0 - torch.exp(pred2_row_loss - pred1_row_loss * parameters)).mean() 195 | 196 | loss=pred1_loss+pred2_loss+aug_diff 197 | return pred_1,loss 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | def forward(self, x,y): 208 | x=x.permute(0,2,1) 209 | 210 | batch_size = x.size(0) 211 | num_points = x.size(2) 212 | 213 | x = get_graph_feature(x, k=self.k, dim9=True) # (batch_size, 9, num_points) -> (batch_size, 9*2, num_points, k) 214 | x = self.conv1(x) # (batch_size, 9*2, num_points, k) -> (batch_size, 64, num_points, k) 215 | x = self.conv2(x) # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k) 216 | x1 = x.max(dim=-1, keepdim=False)[0] # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points) 217 
| 218 | x = get_graph_feature(x1, k=self.k)   # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
219 | x = self.conv3(x)   # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
220 | x = self.conv4(x)   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
221 | x2 = x.max(dim=-1, keepdim=False)[0]   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
222 | 
223 | x = get_graph_feature(x2, k=self.k)   # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
224 | x = self.conv5(x)   # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
225 | x3 = x.max(dim=-1, keepdim=False)[0]   # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
226 | 
227 | point_feat = torch.cat((x1, x2, x3), dim=1)   # (batch_size, 64*3, num_points)
228 | 
229 | x = self.conv6(point_feat)   # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
230 | 
231 | 
232 | inv_feat=self.feature_refinement(x).permute(1,2,0)
233 | pred1,loss = self.get_aug_loss(inv_feat,point_feat,y)
234 | 
235 | 
236 | 
237 | return pred1,loss
238 | 
239 | 
240 | if __name__=='__main__':
241 | inpt=torch.randn((5,4096,9))
242 | label=torch.randint(low=0,high=13,size=(5,4096))
243 | 
244 | net=DGCNN_semseg_ref(num_cls=13,inpt_length=9)
245 | out=net(inpt,label)
246 | 
-------------------------------------------------------------------------------- /utils/GDAutil.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | def knn(x, k):
6 | inner = -2*torch.matmul(x.transpose(2, 1), x)
7 | xx = torch.sum(x**2, dim=1, keepdim=True)
8 | pairwise_distance = -xx - inner - xx.transpose(2, 1)
9 | 
10 | idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
11 | return idx, pairwise_distance
12 | 
13 | 
14 | def local_operator(x, k):
15 | batch_size = x.size(0)
16 | num_points = x.size(2)
17 | x = x.view(batch_size, -1, num_points)
18 | idx, _ = knn(x, k=k)
19 | # device = torch.device('cuda')
20 | device=x.device
21 | 
22 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
23 | 
24 | idx = idx + idx_base
25 | 
26 | idx = idx.view(-1)
27 | 
28 | _, num_dims, _ = x.size()
29 | 
30 | x = x.transpose(2, 1).contiguous()
31 | 
32 | neighbor = x.view(batch_size * num_points, -1)[idx, :]
33 | 
34 | neighbor = neighbor.view(batch_size, num_points, k, num_dims)
35 | 
36 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
37 | 
38 | ##### concatenate the neighbor features and the feature difference between query points and neighbor points
39 | feature = torch.cat((neighbor-x, neighbor), dim=3).permute(0, 3, 1, 2)  # local and global all in
40 | 
41 | return feature
42 | 
43 | 
44 | def local_operator_withnorm(x, norm_plt, k):
45 | batch_size = x.size(0)
46 | num_points = x.size(2)
47 | x = x.view(batch_size, -1, num_points)
48 | norm_plt = norm_plt.view(batch_size, -1, num_points)
49 | idx, _ = knn(x, k=k)   # (batch_size, num_points, k)
50 | device = x.device
51 | 
52 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points
53 | 
54 | idx = idx + idx_base
55 | 
56 | idx = idx.view(-1)
57 | 
58 | _, num_dims, _ = x.size()
59 | 
60 | x = x.transpose(2, 1).contiguous()
61 | norm_plt = norm_plt.transpose(2, 1).contiguous()
62 | 
63 | neighbor = x.view(batch_size * num_points, -1)[idx, :]
64 | neighbor_norm = norm_plt.view(batch_size * num_points, -1)[idx, :]
65 | 
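# Gather-by-index trick used throughout this repo: knn indices are per-batch, so
# adding idx_base = arange(B)*N offsets them into a flat (B*N, C) view of the
# point features, and a single fancy-indexing op fetches every neighbor at once:
#   flat = x.view(B * N, C)         # x as (B, N, C), flattened over the batch
#   nbrs = flat[idx, :]             # idx holds B*N*k offset indices
#   nbrs = nbrs.view(B, N, k, C)    # neighbors regrouped per point
66 | neighbor = neighbor.view(batch_size, num_points, k, 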
num_dims) 67 | neighbor_norm = neighbor_norm.view(batch_size, num_points, k, num_dims) 68 | 69 | x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1) 70 | 71 | feature = torch.cat((neighbor-x, neighbor, neighbor_norm), dim=3).permute(0, 3, 1, 2) # 3c 72 | 73 | return feature 74 | 75 | 76 | def GDM(x, M): 77 | """ 78 | Geometry-Disentangle Module 79 | M: number of disentangled points in both sharp and gentle variation components 80 | """ 81 | k = 64 # number of neighbors to decide the range of j in Eq.(5) 82 | tau = 0.2 # threshold in Eq.(2) 83 | sigma = 2 # parameters of f (Gaussian function in Eq.(2)) 84 | ############### 85 | """Graph Construction:""" 86 | # device = torch.device('cuda') 87 | device=x.device 88 | 89 | batch_size = x.size(0) 90 | num_points = x.size(2) 91 | x = x.view(batch_size, -1, num_points) 92 | 93 | idx, p = knn(x, k=k) # p: -[(x1-x2)^2+...] 94 | 95 | # here we add a tau 96 | p1 = torch.abs(p) 97 | p1 = torch.sqrt(p1) 98 | mask = p1 < tau 99 | 100 | # here we add a sigma 101 | p = p / (sigma * sigma) 102 | w = torch.exp(p) # b,n,n 103 | w = torch.mul(mask.float(), w) 104 | 105 | b = 1/torch.sum(w, dim=1) 106 | b = b.reshape(batch_size, num_points, 1).repeat(1, 1, num_points) 107 | c = torch.eye(num_points, num_points, device=device) 108 | c = c.expand(batch_size, num_points, num_points) 109 | D = b * c # b,n,n 110 | 111 | A = torch.matmul(D, w) # normalized adjacency matrix A_hat 112 | 113 | # Get Aij in a local area: 114 | idx2 = idx.view(batch_size * num_points, -1) 115 | idx_base2 = torch.arange(0, batch_size * num_points, device=device).view(-1, 1) * num_points 116 | idx2 = idx2 + idx_base2 117 | 118 | idx2 = idx2.reshape(batch_size * num_points, k)[:, 1:k] 119 | idx2 = idx2.reshape(batch_size * num_points * (k - 1)) 120 | idx2 = idx2.view(-1) 121 | 122 | A = A.view(-1) 123 | A = A[idx2].reshape(batch_size, num_points, k - 1) # Aij: b,n,k 124 | ############### 125 | """Disentangling Point Clouds into Sharp(xs) and Gentle(xg) Variation Components:""" 126 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 127 | idx = idx + idx_base 128 | idx = idx.reshape(batch_size * num_points, k)[:, 1:k] 129 | idx = idx.reshape(batch_size * num_points * (k - 1)) 130 | 131 | _, num_dims, _ = x.size() 132 | 133 | x = x.transpose(2, 1).contiguous() # b,n,c 134 | neighbor = x.view(batch_size * num_points, -1)[idx, :] 135 | neighbor = neighbor.view(batch_size, num_points, k - 1, num_dims) # b,n,k,c 136 | A = A.reshape(batch_size, num_points, k - 1, 1) # b,n,k,1 137 | n = A.mul(neighbor) # b,n,k,c 138 | n = torch.sum(n, dim=2) # b,n,c 139 | 140 | pai = torch.norm(x - n, dim=-1).pow(2) # Eq.(5) 141 | pais = pai.topk(k=M, dim=-1)[1] # first M points as the sharp variation component 142 | paig = (-pai).topk(k=M, dim=-1)[1] # last M points as the gentle variation component 143 | 144 | pai_base = torch.arange(0, batch_size, device=device).view(-1, 1) * num_points 145 | indices = (pais + pai_base).view(-1) 146 | indiceg = (paig + pai_base).view(-1) 147 | 148 | xs = x.view(batch_size * num_points, -1)[indices, :] 149 | xg = x.view(batch_size * num_points, -1)[indiceg, :] 150 | 151 | xs = xs.view(batch_size, M, -1) # b,M,c 152 | xg = xg.view(batch_size, M, -1) # b,M,c 153 | 154 | return xs, xg 155 | 156 | 157 | class SGCAM(nn.Module): 158 | """Sharp-Gentle Complementary Attention Module:""" 159 | def __init__(self, in_channels, inter_channels=None, bn_layer=True): 160 | super(SGCAM, self).__init__() 161 | 162 | self.in_channels = in_channels 163 
| self.inter_channels = inter_channels 164 | 165 | if self.inter_channels is None: 166 | self.inter_channels = in_channels // 2 167 | if self.inter_channels == 0: 168 | self.inter_channels = 1 169 | 170 | conv_nd = nn.Conv1d 171 | bn = nn.BatchNorm1d 172 | 173 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 174 | kernel_size=1, stride=1, padding=0) 175 | 176 | if bn_layer: 177 | self.W = nn.Sequential( 178 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 179 | kernel_size=1, stride=1, padding=0), 180 | bn(self.in_channels) 181 | ) 182 | nn.init.constant(self.W[1].weight, 0) 183 | nn.init.constant(self.W[1].bias, 0) 184 | else: 185 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 186 | kernel_size=1, stride=1, padding=0) 187 | nn.init.constant(self.W.weight, 0) 188 | nn.init.constant(self.W.bias, 0) 189 | 190 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 191 | kernel_size=1, stride=1, padding=0) 192 | 193 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 194 | kernel_size=1, stride=1, padding=0) 195 | 196 | def forward(self, x, x_2): 197 | batch_size = x.size(0) 198 | 199 | g_x = self.g(x_2).view(batch_size, self.inter_channels, -1) 200 | g_x = g_x.permute(0, 2, 1) 201 | 202 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 203 | theta_x = theta_x.permute(0, 2, 1) 204 | phi_x = self.phi(x_2).view(batch_size, self.inter_channels, -1) 205 | W = torch.matmul(theta_x, phi_x) # Attention Matrix 206 | N = W.size(-1) 207 | W_div_C = W / N 208 | 209 | y = torch.matmul(W_div_C, g_x) 210 | y = y.permute(0, 2, 1).contiguous() 211 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 212 | W_y = self.W(y) 213 | y = W_y + x 214 | 215 | return y 216 | 217 | -------------------------------------------------------------------------------- /utils/GDM_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | 6 | 7 | def knn(x, k): 8 | inner = -2*torch.matmul(x.transpose(2, 1), x) 9 | xx = torch.sum(x**2, dim=1, keepdim=True) 10 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 11 | 12 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 13 | return idx, pairwise_distance 14 | 15 | 16 | 17 | 18 | def GDM_repli(x): 19 | k = 64 # number of neighbors to decide the range of j in Eq.(5) 20 | tau = 0.2 # threshold in Eq.(2) 21 | sigma = 2 # parameters of f (Gaussian function in Eq.(2)) 22 | ############### 23 | """Graph Construction:""" 24 | device = x.device 25 | batch_size = x.size(0) 26 | num_points = x.size(2) 27 | x = x.view(batch_size, -1, num_points) 28 | 29 | idx, p = knn(x, k=k) # p: -[(x1-x2)^2+...] 
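# The graph weights built below form a thresholded Gaussian kernel,
#   w_ij = exp(-||x_i - x_j||^2 / sigma^2) if ||x_i - x_j|| < tau else 0,
# and A = D^{-1} W is the row-normalized (random-walk) adjacency with
# D = diag(sum_j w_ij); note that p above already holds -||x_i - x_j||^2,
# so exp(p / sigma^2) gives the kernel directly.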
30 | 31 | # here we add a tau 32 | p1 = torch.abs(p) 33 | p1 = torch.sqrt(p1) 34 | mask = p1 < tau 35 | 36 | # here we add a sigma 37 | p = p / (sigma * sigma) 38 | w = torch.exp(p) # b,n,n 39 | w = torch.mul(mask.float(), w) 40 | 41 | b = 1/torch.sum(w, dim=1) 42 | b = b.reshape(batch_size, num_points, 1).repeat(1, 1, num_points) 43 | c = torch.eye(num_points, num_points, device=device) 44 | c = c.expand(batch_size, num_points, num_points) 45 | D = b * c # b,n,n 46 | 47 | A = torch.matmul(D, w) # normalized adjacency matrix A_hat 48 | 49 | # Get Aij in a local area: 50 | idx2 = idx.view(batch_size * num_points, -1) 51 | idx_base2 = torch.arange(0, batch_size * num_points, device=device).view(-1, 1) * num_points 52 | idx2 = idx2 + idx_base2 53 | 54 | idx2 = idx2.reshape(batch_size * num_points, k)[:, 1:k] 55 | idx2 = idx2.reshape(batch_size * num_points * (k - 1)) 56 | idx2 = idx2.view(-1) 57 | 58 | A = A.view(-1) 59 | A = A[idx2].reshape(batch_size, num_points, k - 1) # Aij: b,n,k 60 | ############### 61 | """Disentangling Point Clouds into Sharp(xs) and Gentle(xg) Variation Components:""" 62 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 63 | idx = idx + idx_base 64 | idx = idx.reshape(batch_size * num_points, k)[:, 1:k] 65 | idx = idx.reshape(batch_size * num_points * (k - 1)) 66 | 67 | _, num_dims, _ = x.size() 68 | 69 | x = x.transpose(2, 1).contiguous() # b,n,c 70 | neighbor = x.view(batch_size * num_points, -1)[idx, :] 71 | neighbor = neighbor.view(batch_size, num_points, k - 1, num_dims) # b,n,k,c 72 | A = A.reshape(batch_size, num_points, k - 1, 1) # b,n,k,1 73 | n = A.mul(neighbor) # b,n,k,c 74 | n = torch.sum(n, dim=2) # b,n,c 75 | 76 | pai = torch.norm(x - n, dim=-1).pow(2) # Eq.(5) 77 | return pai 78 | 79 | 80 | 81 | def create_cluster(xyz,point_feat,num_hie): 82 | batch_size,feat_dim,num_point=point_feat.shape 83 | 84 | # pow_list=np.arange(1,num_hie+1) 85 | num_point_list=np.array([64,128]) 86 | feat_list=[] 87 | 88 | 89 | pi=GDM_repli(xyz) 90 | sort_index=torch.argsort(pi,-1) 91 | repeat_index=sort_index.unsqueeze(1).repeat(1,feat_dim,1) 92 | gather_feat=torch.gather(point_feat,-1,repeat_index) 93 | 94 | gather_feat=gather_feat.reshape(batch_size,feat_dim,num_hie,-1) 95 | gather_feat=torch.max(gather_feat,-1,keepdim=False)[0] 96 | 97 | 98 | 99 | # for num_point in num_point_list: 100 | # sf=gather_feat[:,:,:num_point] 101 | # sf_max_feat=torch.max(sf,-1,keepdim=True)[0] 102 | 103 | # gf=gather_feat[:,:,-num_point:] 104 | # gf_max_feat=torch.max(gf,-1,keepdim=True)[0] 105 | 106 | # feat_list.append(sf_max_feat) 107 | # feat_list.append(gf_max_feat) 108 | 109 | # cluster_max_feat=torch.cat(feat_list,-1) 110 | 111 | return gather_feat 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /utils/all_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | 8 | def smooth_loss(pred, gold): 9 | eps = 0.2 10 | 11 | n_class = pred.size(1) 12 | 13 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 14 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 15 | log_prb = F.log_softmax(pred, dim=1) 16 | 17 | loss = -(one_hot * log_prb).sum(dim=1).mean() 18 | 19 | return loss 20 | 21 | 22 | 23 | def to_categorical(y, num_classes): 24 | """ 1-hot encodes a tensor """ 25 | new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] 26 | if (y.is_cuda): 27 
| return new_y.cuda(non_blocking=True) 28 | return new_y 29 | -------------------------------------------------------------------------------- /utils/aug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @author:liruihui 5 | @file: augmentor.py 6 | @time: 2019/09/16 7 | @contact: ruihuili.lee@gmail.com 8 | @github: https://liruihui.github.io/ 9 | @description: 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.utils.data 16 | from torch.autograd import Variable 17 | import numpy as np 18 | import torch.nn.functional as F 19 | import random 20 | 21 | def batch_quat_to_rotmat(q, out=None): 22 | 23 | B = q.size(0) 24 | 25 | if out is None: 26 | out = q.new_empty(B, 3, 3) 27 | 28 | # 2 / squared quaternion 2-norm 29 | len = torch.sum(q.pow(2), 1) 30 | s = 2/len 31 | 32 | s_ = torch.clamp(len,2.0/3.0,3.0/2.0) 33 | 34 | # coefficients of the Hamilton product of the quaternion with itself 35 | h = torch.bmm(q.unsqueeze(2), q.unsqueeze(1)) 36 | 37 | out[:, 0, 0] = (1 - (h[:, 2, 2] + h[:, 3, 3]).mul(s))#.mul(s_) 38 | out[:, 0, 1] = (h[:, 1, 2] - h[:, 3, 0]).mul(s) 39 | out[:, 0, 2] = (h[:, 1, 3] + h[:, 2, 0]).mul(s) 40 | 41 | out[:, 1, 0] = (h[:, 1, 2] + h[:, 3, 0]).mul(s) 42 | out[:, 1, 1] = (1 - (h[:, 1, 1] + h[:, 3, 3]).mul(s))#.mul(s_) 43 | out[:, 1, 2] = (h[:, 2, 3] - h[:, 1, 0]).mul(s) 44 | 45 | out[:, 2, 0] = (h[:, 1, 3] - h[:, 2, 0]).mul(s) 46 | out[:, 2, 1] = (h[:, 2, 3] + h[:, 1, 0]).mul(s) 47 | out[:, 2, 2] = (1 - (h[:, 1, 1] + h[:, 2, 2]).mul(s))#.mul(s_) 48 | 49 | return out, s_ 50 | 51 | class Augmentor_Rotation(nn.Module): 52 | def __init__(self,dim): 53 | super(Augmentor_Rotation, self).__init__() 54 | self.fc1 = nn.Linear(dim + 1024, 512) 55 | self.fc2 = nn.Linear(512, 256) 56 | self.fc3 = nn.Linear(256, 4) 57 | self.bn1 = nn.BatchNorm1d(512) 58 | self.bn2 = nn.BatchNorm1d(256) 59 | 60 | def forward(self, x): 61 | B = x.size()[0] 62 | x = F.relu(self.bn1(self.fc1(x))) 63 | x = F.relu(self.bn2(self.fc2(x))) 64 | x = self.fc3(x) 65 | 66 | # iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(B, 1) 67 | # if x.is_cuda: 68 | # iden = iden.cuda() 69 | # x = x + iden 70 | # x = x.view(-1, 3, 3) 71 | 72 | iden = x.new_tensor([1, 0, 0, 0]) 73 | x = x + iden 74 | 75 | # convert quaternion to rotation matrix 76 | x, s = batch_quat_to_rotmat(x) 77 | x = x.view(-1, 3, 3) 78 | s = s.view(B, 1, 1) 79 | return x, None 80 | 81 | 82 | class Augmentor_Displacement(nn.Module): 83 | def __init__(self, dim): 84 | super(Augmentor_Displacement, self).__init__() 85 | 86 | self.conv1 = torch.nn.Conv1d(dim+1024+64, 1024, 1) 87 | 88 | self.conv2 = torch.nn.Conv1d(1024, 512, 1) 89 | self.conv3 = torch.nn.Conv1d(512, 64, 1) 90 | self.conv4 = torch.nn.Conv1d(64, 3, 1) 91 | 92 | self.bn1 = nn.BatchNorm1d(1024) 93 | self.bn2 = nn.BatchNorm1d(512) 94 | self.bn3 = nn.BatchNorm1d(64) 95 | 96 | def forward(self, x): 97 | batchsize = x.size()[0] 98 | x = F.relu(self.bn1(self.conv1(x))) 99 | x = F.relu(self.bn2(self.conv2(x))) 100 | x = F.relu(self.bn3(self.conv3(x))) 101 | x = self.conv4(x) 102 | 103 | return x 104 | 105 | 106 | class Augmentor(nn.Module): 107 | def __init__(self,dim=1024,in_dim=3): 108 | super(Augmentor, self).__init__() 109 | self.dim = dim 110 | self.conv1 = torch.nn.Conv1d(in_dim, 64, 1) 111 | self.conv2 = torch.nn.Conv1d(64, 64, 1) 112 | self.conv3 = torch.nn.Conv1d(64, 128, 1) 113 | 
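# How the Augmentor (PointAugment-style) fits together: a shared per-point MLP
# (conv1..conv4) encodes the cloud; the max-pooled 1024-d global feature,
# concatenated with an external noise vector, drives two heads —
# Augmentor_Rotation regresses a quaternion (offset by the identity [1, 0, 0, 0]
# and converted to a 3x3 rotation matrix via batch_quat_to_rotmat), and
# Augmentor_Displacement regresses a per-point 3D offset; forward() applies each
# with probability 0.5 and rotates the normals alongside the points when present.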
self.conv4 = torch.nn.Conv1d(128, 1024, 1) 114 | self.bn1 = nn.BatchNorm1d(64) 115 | self.bn2 = nn.BatchNorm1d(64) 116 | self.bn3 = nn.BatchNorm1d(128) 117 | self.bn4 = nn.BatchNorm1d(1024) 118 | 119 | self.rot = Augmentor_Rotation(self.dim) 120 | self.dis = Augmentor_Displacement(self.dim) 121 | 122 | def forward(self, pt, noise): 123 | 124 | 125 | B, C, N = pt.size() 126 | raw_pt = pt[:,:3,:].contiguous() 127 | normal = pt[:,3:,:].transpose(1, 2).contiguous() if C > 3 else None 128 | 129 | 130 | x = F.relu(self.bn1(self.conv1(raw_pt))) 131 | x = F.relu(self.bn2(self.conv2(x))) 132 | pointfeat = x.clone() 133 | x = F.relu(self.bn3(self.conv3(x))) 134 | x = F.relu(self.bn4(self.conv4(x))) 135 | x = torch.max(x, 2, keepdim=True)[0] 136 | 137 | feat_r = x.view(-1, 1024) 138 | feat_r = torch.cat([feat_r,noise],1) 139 | rotation, scale = self.rot(feat_r) 140 | 141 | feat_d = x.view(-1, 1024, 1).repeat(1, 1, N) 142 | noise_d = noise.view(B, -1, 1).repeat(1, 1, N) 143 | 144 | feat_d = torch.cat([pointfeat, feat_d,noise_d],1) 145 | displacement = self.dis(feat_d) 146 | 147 | pt = raw_pt.transpose(2, 1).contiguous() 148 | 149 | p1 = random.uniform(0, 1) 150 | possi = 0.5#0.0 151 | if p1 > possi: 152 | pt = torch.bmm(pt, rotation).transpose(1, 2).contiguous() 153 | else: 154 | pt = pt.transpose(1, 2).contiguous() 155 | p2 = random.uniform(0, 1) 156 | if p2 > possi: 157 | pt = pt + displacement 158 | 159 | if normal is not None: 160 | normal = (torch.bmm(normal, rotation)).transpose(1, 2).contiguous() 161 | pt = torch.cat([pt,normal],1) 162 | 163 | return pt 164 | 165 | 166 | if __name__=='__main__': 167 | point = torch.randn(8,3,1024) 168 | noise = 0.02 * torch.randn(8, 1024) 169 | 170 | aug=Augmentor() 171 | aug(point, noise) -------------------------------------------------------------------------------- /utils/aug_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #-*- coding:utf-8 _*- 3 | """ 4 | @author:liruihui 5 | @file: loss_utils.py 6 | @time: 2019/09/23 7 | @contact: ruihuili.lee@gmail.com 8 | @github: https://liruihui.github.io/ 9 | @description: 10 | """ 11 | 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | NUM = 1.2#2.0 18 | W = 1.0#10.0 19 | 20 | 21 | def cal_loss_raw(pred, gold): 22 | ''' Calculate cross entropy loss, apply label smoothing if needed. ''' 23 | 24 | gold = gold.contiguous().view(-1) 25 | 26 | eps = 0.2 27 | n_class = pred.size(1) 28 | 29 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 30 | #one_hot = F.one_hot(gold, pred.shape[1]).float() 31 | 32 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 33 | log_prb = F.log_softmax(pred, dim=1) 34 | 35 | loss_raw = -(one_hot * log_prb).sum(dim=1) 36 | 37 | 38 | loss = loss_raw.mean() 39 | 40 | return loss,loss_raw 41 | 42 | def mat_loss(trans): 43 | d = trans.size()[1] 44 | I = torch.eye(d)[None, :, :] 45 | if trans.is_cuda: 46 | I = I.cuda() 47 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2))) 48 | return loss 49 | 50 | 51 | 52 | def cls_loss(pred, pred_aug, gold, pc_feat, aug_feat): 53 | ''' Calculate cross entropy loss, apply label smoothing if needed. 
''' 54 | mse_fn = torch.nn.MSELoss(reduce=True, reduction='mean') 55 | 56 | cls_pc, _ = cal_loss_raw(pred, gold) 57 | cls_aug, _ = cal_loss_raw(pred_aug, gold) 58 | # if ispn: 59 | # cls_pc = cls_pc + 0.001*mat_loss(pc_tran) 60 | # cls_aug = cls_aug + 0.001*mat_loss(aug_tran) 61 | 62 | feat_diff = 10.0*mse_fn(pc_feat,aug_feat) 63 | # parameters = torch.max(torch.tensor(NUM).cuda(), torch.exp(1.0-cls_pc_raw)**2).cuda() 64 | # cls_diff = (torch.abs(cls_pc_raw - cls_aug_raw) * (parameters*2)).mean() 65 | cls_loss = cls_pc + cls_aug + feat_diff# + cls_diff 66 | 67 | return cls_loss 68 | 69 | def aug_loss(pred, pred_aug, gold): 70 | ''' Calculate cross entropy loss, apply label smoothing if needed. ''' 71 | # mse_fn = torch.nn.MSELoss(reduce=True, size_average=True) 72 | 73 | cls_pc, cls_pc_raw = cal_loss_raw(pred, gold) 74 | cls_aug, cls_aug_raw = cal_loss_raw(pred_aug, gold) 75 | # if ispn: 76 | # cls_pc = cls_pc + 0.001*mat_loss(pc_tran) 77 | # cls_aug = cls_aug + 0.001*mat_loss(aug_tran) 78 | pc_con = F.softmax(pred, dim=-1)#.max(dim=1)[0] 79 | one_hot = F.one_hot(gold, pred.shape[1]).float() 80 | pc_con = (pc_con*one_hot).max(dim=1)[0] 81 | 82 | 83 | parameters = torch.max(torch.tensor(NUM).cuda(), torch.exp(pc_con) * NUM).cuda() 84 | 85 | # both losses are usable 86 | aug_diff = W * torch.abs(1.0 - torch.exp(cls_aug_raw - cls_pc_raw * parameters)).mean() 87 | #aug_diff = W*torch.abs(cls_aug_raw - cls_pc_raw*parameters).mean() 88 | aug_loss = cls_aug + aug_diff 89 | 90 | return aug_loss 91 | 92 | 93 | 94 | 95 | 96 | if __name__=='__main__': 97 | batch_size=3 98 | num_point=1024 99 | num_cls=40 100 | 101 | target=torch.LongTensor([0,1,2]).cuda() 102 | 103 | pred_pc=torch.randn((batch_size,num_cls)).cuda() 104 | pc_tran=torch.randn((batch_size,64,64)).cuda() 105 | 106 | pred_aug=torch.rand_like(pred_pc).cuda() 107 | aug_tran=torch.rand_like(pred_aug).cuda() 108 | 109 | pc_feat=torch.randn((batch_size,1024)).cuda() 110 | aug_feat=torch.rand_like(pc_feat).cuda() 111 | 112 | augLoss = aug_loss(pred_pc, pred_aug, target) 113 | clsLoss = cls_loss(pred_pc, pred_aug, target, pc_feat, 114 | aug_feat) 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /utils/cal_final_result.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class accuracy_calculation(): 5 | def __init__(self,confusion_matrix): 6 | self.confusion_matrix=confusion_matrix 7 | self.number_of_labels=self.confusion_matrix.shape[0] 8 | 9 | def get_over_all_accuracy(self): 10 | matrix_diagonal = 0 11 | all_values = 0 12 | for row in range(self.number_of_labels): 13 | for column in range(self.number_of_labels): 14 | all_values += self.confusion_matrix[row][column] 15 | if row == column: 16 | matrix_diagonal += self.confusion_matrix[row][column] 17 | if all_values == 0: 18 | all_values = 1 19 | return float(matrix_diagonal) / all_values 20 | 21 | def get_mean_class_accuracy(self): # added 22 | re = 0 23 | for i in range(self.number_of_labels): 24 | re = re + self.confusion_matrix[i][i] / max(1,np.sum(self.confusion_matrix[i,:])) 25 | return re/self.number_of_labels 26 | 27 | def get_intersection_union_per_class(self): 28 | matrix_diagonal = [self.confusion_matrix[i][i] for i in range(self.number_of_labels)] 29 | errors_summed_by_row = [0] * self.number_of_labels 30 | for row in range(self.number_of_labels): 31 | for column in range(self.number_of_labels): 32 | if row != column: 33 | errors_summed_by_row[row] += 
self.confusion_matrix[row][column] 34 | errors_summed_by_column = [0] * self.number_of_labels 35 | for column in range(self.number_of_labels): 36 | for row in range(self.number_of_labels): 37 | if row != column: 38 | errors_summed_by_column[column] += self.confusion_matrix[row][column] 39 | 40 | divisor = [0] * self.number_of_labels 41 | for i in range(self.number_of_labels): 42 | divisor[i] = matrix_diagonal[i] + errors_summed_by_row[i] + errors_summed_by_column[i] 43 | if matrix_diagonal[i] == 0: 44 | divisor[i] = 1 45 | 46 | return [float(matrix_diagonal[i]) / divisor[i] for i in range(self.number_of_labels)] 47 | 48 | def get_average_intersection_union(self): 49 | values = self.get_intersection_union_per_class() 50 | class_seen = ((self.confusion_matrix.sum(1)+self.confusion_matrix.sum(0))!=0).sum() 51 | return sum(values) / class_seen 52 | 53 | 54 | # if __name__=='__main__': 55 | # confusion_array=np.load('confusion_m.npy') 56 | # calcula_class=accuracy_calculation(confusion_array) 57 | # print(calcula_class.get_over_all_accuracy()) 58 | # print(calcula_class.get_intersection_union_per_class()) 59 | # print(calcula_class.get_average_intersection_union()) 60 | -------------------------------------------------------------------------------- /utils/centerloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | 6 | 7 | 8 | class CenterLoss(nn.Module): 9 | """Center loss. 10 | 11 | Reference: 12 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 13 | 14 | Args: 15 | num_classes (int): number of classes. 16 | feat_dim (int): feature dimension. 17 | """ 18 | def __init__(self, num_classes=40, feat_dim=2, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (batch_size). 
34 | """ 35 | batch_size = x.size(0) 36 | 37 | ### distmat shape is (5,40) #### 38 | # distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | # torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | # distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | distmat=torch.cdist(x,self.centers) 43 | 44 | classes = torch.arange(self.num_classes).long() 45 | if self.use_gpu: classes = classes.cuda() 46 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 47 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 48 | 49 | dist = distmat * mask.float() 50 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 51 | 52 | return loss 53 | 54 | 55 | if __name__=='__main__': 56 | cl=CenterLoss().cuda() 57 | feat=torch.randn((5,1024)).cuda() 58 | label=torch.LongTensor([1,2,3,0,1]).cuda() 59 | cl(feat,label) -------------------------------------------------------------------------------- /utils/create_cluster_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import torch.utils.data as data 5 | import open3d as o3d 6 | 7 | import sys 8 | import glob 9 | import h5py 10 | import scipy.spatial as spa 11 | 12 | from sklearn.cluster import KMeans 13 | from tqdm import tqdm 14 | 15 | 16 | np.random.seed(0) 17 | class ModuleNet40(data.Dataset): 18 | def __init__(self,root,split): 19 | if split=='train': 20 | self.split=split 21 | else: 22 | self.split='test' 23 | self.root=root 24 | self.data,self.label=self.get_datalabel() 25 | self.num_points=1024 26 | 27 | def get_datalabel(self): 28 | all_data = [] 29 | all_label = [] 30 | for h5_name in glob.glob(os.path.join(self.root, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5'%self.split)): 31 | f = h5py.File(h5_name) 32 | data = f['data'][:].astype('float32') 33 | label = f['label'][:].astype('int64') 34 | f.close() 35 | all_data.append(data) 36 | all_label.append(label) 37 | all_data = np.concatenate(all_data, axis=0) 38 | all_label = np.concatenate(all_label, axis=0) 39 | return all_data, all_label 40 | 41 | def translate_pointcloud(self,pointcloud): 42 | xyz1 = np.random.uniform(low=2./3., high=3./2., size=[3]) 43 | xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3]) 44 | 45 | translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32') 46 | return translated_pointcloud 47 | 48 | 49 | def get_point_cluster(self,point): 50 | #Create Point Cloud 51 | num_point=point.shape[0] 52 | point_cloud=o3d.geometry.PointCloud() 53 | point_cloud.points=o3d.utility.Vector3dVector(point) 54 | 55 | #Downsample PointCloud into Voxel 56 | voxel=o3d.geometry.voxel_down_sample(point_cloud,0.1) 57 | voxel_coor=np.array(voxel.points) 58 | 59 | # Get Each Voxel's coordinate covairance matrix 60 | neighbor_size=20 61 | num_voxel=voxel_coor.shape[0] 62 | dist_matrix=spa.distance_matrix(voxel_coor,voxel_coor) 63 | index=np.argsort(dist_matrix,1) 64 | index=index[:,:neighbor_size] 65 | index=index.reshape(-1) 66 | normalized_coor=voxel_coor[index].reshape(num_voxel,neighbor_size,-1) 67 | normalized_coor=normalized_coor-np.mean(normalized_coor,1,keepdims=True) 68 | neighbor_coor=torch.FloatTensor(normalized_coor) 69 | coor_mat=torch.bmm(neighbor_coor.permute(0,2,1),neighbor_coor) 70 | 71 | ## get each feature cluster's coordinate ### 72 | e,v = torch.symeig(coor_mat, eigenvectors=True) 73 | labels = self.get_cluster_result(e,v) 74 | # coor_list=[] 75 | 
coor_list=np.zeros((0,50,3)) 76 | for l in np.unique(labels): 77 | l_index=np.where(labels==l)[0] 78 | picked_ind=np.random.permutation(l_index)[:50] 79 | while len(picked_ind)<50: 80 | make_up=50-len(picked_ind) 81 | make_up_index=np.random.permutation(l_index)[:make_up] 82 | picked_ind=np.append(picked_ind,make_up_index) 83 | 84 | vo_co=np.expand_dims(voxel_coor[picked_ind],0) 85 | 86 | coor_list=np.append(coor_list,vo_co,0) 87 | 88 | # coor_list.append(voxel_coor[picked_ind]) 89 | # coor_list=np.array(coor_list) 90 | return coor_list 91 | 92 | 93 | 94 | 95 | 96 | def get_cluster_result(self,value,vector): 97 | eig_value=np.sort(value,1)[:,::-1] 98 | 99 | linearity=(eig_value[:,0]-eig_value[:,1])/eig_value[:,0] 100 | planarity=(eig_value[:,1]-eig_value[:,2])/eig_value[:,1] 101 | scaterring=eig_value[:,2]/eig_value[:,1] 102 | 103 | neighbor_feat=np.stack((linearity,planarity,scaterring),1) 104 | 105 | kmeans = KMeans(n_clusters=3, random_state=0).fit(neighbor_feat) 106 | labels=kmeans.labels_ 107 | 108 | 109 | 110 | return labels 111 | 112 | 113 | 114 | 115 | def __getitem__(self, item): 116 | pointcloud = self.data[item][:self.num_points] 117 | label = self.label[item] 118 | if self.split == 'train': 119 | pointcloud = self.translate_pointcloud(pointcloud) 120 | np.random.shuffle(pointcloud) 121 | 122 | 123 | ### get feature cluster ### 124 | cluster_coord=self.get_point_cluster(pointcloud) 125 | 126 | ### visulization ### 127 | # sampled_voxel=cluster_coord.reshape(150,-1) 128 | # pointcloud_o3d=o3d.geometry.PointCloud() 129 | # pointcloud_o3d.points=o3d.utility.Vector3dVector(cluster_coord[2].reshape(-1,3)) 130 | 131 | # origi_point=o3d.geometry.PointCloud() 132 | # origi_point.points=o3d.utility.Vector3dVector(pointcloud+2) 133 | 134 | # o3d.visualization.draw_geometries([pointcloud_o3d,origi_point]) 135 | #### visulization ### 136 | 137 | pointcloud=torch.FloatTensor(pointcloud) 138 | label=torch.LongTensor(label) 139 | 140 | pointcloud=pointcloud.permute(1,0) 141 | 142 | return pointcloud, torch.FloatTensor(cluster_coord), label 143 | 144 | def __len__(self): 145 | return self.data.shape[0] 146 | 147 | 148 | 149 | 150 | 151 | def get_sets(data_path,batch_size): 152 | train_data=ModuleNet40(data_path,split='train') 153 | train_loader=data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True,num_workers=0) 154 | 155 | test_data=ModuleNet40(data_path,split='test') 156 | test_loader=data.DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True,num_workers=0) 157 | 158 | valid_dataset=ModuleNet40(data_path,split='valid') 159 | valid_loader=data.DataLoader(dataset=valid_dataset,batch_size=batch_size,shuffle=True,num_workers=0) 160 | 161 | return train_loader,test_loader,valid_loader 162 | 163 | 164 | 165 | 166 | 167 | 168 | if __name__=='__main__': 169 | data_path='D:\Computer_vision\Dataset\Modulenet40\ModelNet40\data' 170 | dataset=ModuleNet40(data_path,'test') 171 | 172 | save_path='D:/Computer_vision/Dataset/Modulenet40/ModelNet40_cluster/Test' 173 | # point_cloud,point_cluster,label=dataset[0] 174 | 175 | for i in tqdm(range(len(dataset))): 176 | point_cloud,point_cluster,label=dataset[i] 177 | 178 | point_cloud=point_cloud.numpy() 179 | point_cluster=point_cluster.numpy() 180 | label=label.numpy() 181 | 182 | file_name='data_{}'.format(i) 183 | file_save_path=os.path.join(save_path,file_name) 184 | np.savez(file_save_path,pc=point_cloud,p_clster=point_cluster,label=label) 185 | 186 | # target_cls=4 187 | # label=dataset.label 188 | # 
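# --- Illustrative sketch (not part of the original file) ----------------------
# Standalone version of the eigenvalue features used by get_cluster_result:
# linearity, planarity and scattering come from the sorted covariance
# eigenvalues e1 >= e2 >= e3 of each local neighborhood, then get clustered
# with KMeans. The random eigenvalues below are made up for illustration.
# Two hedged notes: scattering is often defined as e3/e1 in the literature,
# while the code above uses e3/e2; and torch.symeig is deprecated in recent
# PyTorch in favor of torch.linalg.eigh.
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
eigvals = np.sort(rng.random((100, 3)), axis=1)[:, ::-1]   # descending e1..e3

linearity = (eigvals[:, 0] - eigvals[:, 1]) / eigvals[:, 0]
planarity = (eigvals[:, 1] - eigvals[:, 2]) / eigvals[:, 1]
scattering = eigvals[:, 2] / eigvals[:, 1]

feats = np.stack((linearity, planarity, scattering), axis=1)
labels = KMeans(n_clusters=3, random_state=0, n_init=10).fit_predict(feats)
print(np.unique(labels, return_counts=True))
# ------------------------------------------------------------------------------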
index=np.where(label==target_cls)[0] 189 | # # for i in index: 190 | # # inpt,label=dataset[i] 191 | # a,b,c=get_sets(data_path,10) 192 | # for (point,cluster_point,label) in tqdm(a): 193 | # aaa=1 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /utils/create_heatmap.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from numpy import random 4 | 5 | from numpy.core.defchararray import mod 6 | sys.path.append(os.getcwd()) 7 | 8 | 9 | import numpy as np 10 | import torch 11 | from model.PointNet import PointNet_cls_CAM 12 | from Dataloader.ModelNet40 import ModuleNet40 13 | 14 | import open3d as o3d 15 | 16 | def vis_point_imp(datset,label_index): 17 | model=PointNet_cls_CAM(inpt_dim=3,num_cls=40) 18 | model.eval() 19 | 20 | model_param_path='./Exp/PointNet_CAM/pth/epoch_190' 21 | dic=torch.load(model_param_path) 22 | model.load_state_dict(dic['model_state']) 23 | params=list(model.parameters()) 24 | softmaxs_weight=params[-2].detach().cpu().numpy() 25 | 26 | 27 | 28 | point_feat = 'mlp1' 29 | features_blobs = [] 30 | def hook_feature(module, input, output): 31 | outdata=output.squeeze(0).permute(1,0) 32 | features_blobs.append(outdata.data.cpu().numpy()) 33 | 34 | model._modules.get(point_feat).register_forward_hook(hook_feature) 35 | 36 | # label_index=2 37 | all_data=dataset.data 38 | all_label=dataset.label.squeeze() 39 | target_index=np.where(all_label==label_index)[0] 40 | for index in target_index: 41 | inpt=torch.FloatTensor(all_data[index]).permute(1,0).unsqueeze(0) 42 | out=model(inpt) 43 | point_importance=get_point_importance(features_blobs[0],softmaxs_weight[label_index]) 44 | vis_point(all_data[index],point_importance) 45 | 46 | 47 | 48 | def vis_point(point,importance): 49 | num_point=point.shape[0] 50 | 51 | pointcloud=o3d.geometry.PointCloud() 52 | pointcloud.points=o3d.utility.Vector3dVector(point) 53 | 54 | point_color=np.array([[0,0,255]]) 55 | point_color=np.repeat(point_color,num_point,0) 56 | 57 | for index in range(point_color.shape[0]): 58 | imp_factor=importance[index] 59 | point_color[index,0]=255*imp_factor 60 | point_color[index,-1]=255*(1-imp_factor) 61 | 62 | 63 | 64 | affect_index=np.where(importance!=0)[0] 65 | print(affect_index.shape[0]) 66 | # point_color[affect_index]=np.array([255,0,0]) 67 | pointcloud.colors=o3d.utility.Vector3dVector(point_color) 68 | 69 | 70 | ske_point=o3d.geometry.PointCloud() 71 | ske_point.points=o3d.utility.Vector3dVector(point[affect_index]+np.array([2,0,0])) 72 | 73 | 74 | random_point_index=np.random.permutation(np.arange(num_point))[:affect_index.shape[0]] 75 | random_point=o3d.geometry.PointCloud() 76 | random_point.points=o3d.utility.Vector3dVector(point[random_point_index]+np.array([-2,0,0])) 77 | random_color=np.repeat(np.array([[0,255,0]]),random_point_index.shape[0],0) 78 | random_point.colors=o3d.utility.Vector3dVector(random_color) 79 | 80 | o3d.visualization.draw_geometries([pointcloud,ske_point,random_point]) 81 | 82 | 83 | 84 | 85 | 86 | 87 | def get_point_importance(feature,weight): 88 | num_point=feature.shape[0] 89 | point_imp=np.zeros(num_point) 90 | 91 | point_source=np.argmax(feature,0) 92 | pinv_feat=np.max(feature,0) 93 | matter_point=np.unique(point_source) 94 | for mater_index in matter_point: 95 | source=np.where(point_source==mater_index)[0] 96 | 97 | picked_feat=pinv_feat[source] 98 | picked_weight=weight[source] 99 | 100 | 
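# --- Illustrative sketch (not part of the original file) ----------------------
# Generic version of the hook pattern used in vis_point_imp above: register a
# forward hook to capture an intermediate activation, then weight it with the
# classifier weights of a target class (CAM-style). The tiny two-layer model
# here is made up for illustration.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv1d(3, 16, 1), nn.Conv1d(16, 40, 1))
feats = []
handle = model[0].register_forward_hook(lambda m, i, o: feats.append(o.detach()))

x = torch.randn(1, 3, 1024)
_ = model(x)                                   # hook fires during this forward
w = model[1].weight.squeeze(-1)                # (40, 16) per-class weights
cam = (w[5].view(1, 16, 1) * feats[0]).sum(1)  # importance for class 5: (1, 1024)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-12)
handle.remove()
# ------------------------------------------------------------------------------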
point_imp[mater_index]=np.sum(picked_feat*picked_weight) 101 | # cam = cam - np.min(cam) 102 | # cam_img = cam / np.max(cam) 103 | point_imp=np.clip(point_imp,0,None) 104 | point_imp=(point_imp-np.min(point_imp))/(np.max(point_imp)-np.min(point_imp)) 105 | 106 | return point_imp 107 | 108 | 109 | 110 | 111 | if __name__=='__main__': 112 | target_cls=6 113 | datapath='D:/Computer_vision/Dataset/Modulenet40/ModelNet40/data' 114 | dataset=ModuleNet40(datapath,'test') 115 | # train_loader,test_loader,valid_loader=get_sets(datapath,batch_size=10) 116 | vis_point_imp(dataset,label_index=target_cls) 117 | -------------------------------------------------------------------------------- /utils/curvenet_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Yue Wang 3 | @Contact: yuewangx@mit.edu 4 | @File: pointnet_util.py 5 | @Time: 2018/10/13 10:39 PM 6 | 7 | Modified by 8 | @Author: Tiange Xiang 9 | @Contact: txia7609@uni.sydney.edu.au 10 | @Time: 2021/01/21 3:10 PM 11 | """ 12 | 13 | import torch 14 | from torch._C import device 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from time import time 18 | import numpy as np 19 | 20 | from .walk import Walk 21 | 22 | 23 | def knn(x, k): 24 | k = k + 1 25 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 26 | xx = torch.sum(x**2, dim=1, keepdim=True) 27 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 28 | 29 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 30 | return idx 31 | 32 | def normal_knn(x, k): 33 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 34 | xx = torch.sum(x**2, dim=1, keepdim=True) 35 | pairwise_distance = -xx - inner - xx.transpose(2, 1) 36 | 37 | idx = pairwise_distance.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 38 | return idx 39 | 40 | def pc_normalize(pc): 41 | l = pc.shape[0] 42 | centroid = np.mean(pc, axis=0) 43 | pc = pc - centroid 44 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 45 | pc = pc / m 46 | return pc 47 | 48 | def square_distance(src, dst): 49 | """ 50 | Calculate Euclid distance between each two points. 
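# --- Illustrative sketch (not part of the original file) ----------------------
# Sanity check for the pairwise-distance trick in knn() above:
# -||xi - xj||^2 = -||xi||^2 + 2*xi.xj - ||xj||^2, so topk on this quantity
# ranks neighbors from nearest to farthest. knn() asks for k+1 neighbors to
# leave room for dropping the self-match later.
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 128)                       # (batch, 3, num_points)

inner = -2 * torch.matmul(x.transpose(2, 1), x)
xx = torch.sum(x ** 2, dim=1, keepdim=True)
neg_sq_dist = -xx - inner - xx.transpose(2, 1)   # (batch, n, n)

ref = torch.cdist(x.transpose(2, 1), x.transpose(2, 1)) ** 2
assert torch.allclose(-neg_sq_dist, ref, atol=1e-4)
# ------------------------------------------------------------------------------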
51 | """ 52 | B, N, _ = src.shape 53 | _, M, _ = dst.shape 54 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 55 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 56 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 57 | return dist 58 | 59 | def index_points(points, idx): 60 | """ 61 | 62 | Input: 63 | points: input points data, [B, N, C] 64 | idx: sample index data, [B, S] 65 | Return: 66 | new_points:, indexed points data, [B, S, C] 67 | """ 68 | device = points.device 69 | B = points.shape[0] 70 | view_shape = list(idx.shape) 71 | view_shape[1:] = [1] * (len(view_shape) - 1) 72 | repeat_shape = list(idx.shape) 73 | repeat_shape[0] = 1 74 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 75 | new_points = points[batch_indices, idx, :] 76 | return new_points 77 | 78 | 79 | def farthest_point_sample(xyz, npoint): 80 | """ 81 | Input: 82 | xyz: pointcloud data, [B, N, 3] 83 | npoint: number of samples 84 | Return: 85 | centroids: sampled pointcloud index, [B, npoint] 86 | """ 87 | device = xyz.device 88 | B, N, C = xyz.shape 89 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 90 | distance = torch.ones(B, N).to(device) * 1e10 91 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) * 0 92 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 93 | for i in range(npoint): 94 | centroids[:, i] = farthest 95 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 96 | dist = torch.sum((xyz - centroid) ** 2, -1) 97 | mask = dist < distance 98 | distance[mask] = dist[mask] 99 | farthest = torch.max(distance, -1)[1] 100 | return centroids 101 | 102 | def query_ball_point(radius, nsample, xyz, new_xyz): 103 | """ 104 | Input: 105 | radius: local region radius 106 | nsample: max sample number in local region 107 | xyz: all points, [B, N, 3] 108 | new_xyz: query points, [B, S, 3] 109 | Return: 110 | group_idx: grouped points index, [B, S, nsample] 111 | """ 112 | device = xyz.device 113 | B, N, C = xyz.shape 114 | _, S, _ = new_xyz.shape 115 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 116 | sqrdists = square_distance(new_xyz, xyz) 117 | group_idx[sqrdists > radius ** 2] = N 118 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 119 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 120 | mask = group_idx == N 121 | group_idx[mask] = group_first[mask] 122 | return group_idx 123 | 124 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 125 | """ 126 | Input: 127 | npoint: 128 | radius: 129 | nsample: 130 | xyz: input points position data, [B, N, 3] 131 | points: input points data, [B, N, D] 132 | Return: 133 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 134 | new_points: sampled points data, [B, npoint, nsample, 3+D] 135 | """ 136 | new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint)) 137 | torch.cuda.empty_cache() 138 | 139 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 140 | torch.cuda.empty_cache() 141 | 142 | new_points = index_points(points, idx) 143 | torch.cuda.empty_cache() 144 | 145 | if returnfps: 146 | return new_xyz, new_points, idx 147 | else: 148 | return new_xyz, new_points 149 | 150 | class Attention_block(nn.Module): 151 | ''' 152 | Used in attention U-Net. 
153 | ''' 154 | def __init__(self,F_g,F_l,F_int): 155 | super(Attention_block,self).__init__() 156 | self.W_g = nn.Sequential( 157 | nn.Conv1d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 158 | nn.BatchNorm1d(F_int) 159 | ) 160 | 161 | self.W_x = nn.Sequential( 162 | nn.Conv1d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 163 | nn.BatchNorm1d(F_int) 164 | ) 165 | 166 | self.psi = nn.Sequential( 167 | nn.Conv1d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 168 | nn.BatchNorm1d(1), 169 | nn.Sigmoid() 170 | ) 171 | 172 | def forward(self,g,x): 173 | g1 = self.W_g(g) 174 | x1 = self.W_x(x) 175 | psi = F.leaky_relu(g1+x1, negative_slope=0.2) 176 | psi = self.psi(psi) 177 | 178 | return psi, 1. - psi 179 | 180 | 181 | class LPFA(nn.Module): 182 | def __init__(self, in_channel, out_channel, k, mlp_num=2, initial=False): 183 | super(LPFA, self).__init__() 184 | self.k = k 185 | # self.device = 'cuda' 186 | self.initial = initial 187 | 188 | if not initial: 189 | self.xyz2feature = nn.Sequential( 190 | nn.Conv2d(9, in_channel, kernel_size=1, bias=False), 191 | nn.BatchNorm2d(in_channel)) 192 | 193 | self.mlp = [] 194 | for _ in range(mlp_num): 195 | self.mlp.append(nn.Sequential(nn.Conv2d(in_channel, out_channel, 1, bias=False), 196 | nn.BatchNorm2d(out_channel), 197 | nn.LeakyReLU(0.2))) 198 | in_channel = out_channel 199 | self.mlp = nn.Sequential(*self.mlp) 200 | 201 | def forward(self, x, xyz, idx=None): 202 | x = self.group_feature(x, xyz, idx) 203 | x = self.mlp(x) 204 | 205 | if self.initial: 206 | x = x.max(dim=-1, keepdim=False)[0] 207 | else: 208 | x = x.mean(dim=-1, keepdim=False) 209 | 210 | return x 211 | 212 | def group_feature(self, x, xyz, idx): 213 | batch_size, num_dims, num_points = x.size() 214 | device=x.device 215 | if idx is None: 216 | idx = knn(xyz, k=self.k)[:,:,:self.k] # (batch_size, num_points, k) 217 | 218 | idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points 219 | idx = idx + idx_base 220 | idx = idx.view(-1) 221 | 222 | xyz = xyz.transpose(2, 1).contiguous() # bs, n, 3 223 | point_feature = xyz.view(batch_size * num_points, -1)[idx, :] 224 | point_feature = point_feature.view(batch_size, num_points, self.k, -1) # bs, n, k, 3 225 | points = xyz.view(batch_size, num_points, 1, 3).expand(-1, -1, self.k, -1) # bs, n, k, 3 226 | 227 | 228 | 229 | 230 | #### it's made up by original points xyz, neighbor's xyz, the xyz's difference of neighbor point and original point 231 | point_feature = torch.cat((points, point_feature, point_feature - points), 232 | dim=3).permute(0, 3, 1, 2).contiguous() 233 | 234 | if self.initial: 235 | return point_feature 236 | 237 | x = x.transpose(2, 1).contiguous() # bs, n, c 238 | feature = x.view(batch_size * num_points, -1)[idx, :] 239 | feature = feature.view(batch_size, num_points, self.k, num_dims) #bs, n, k, c 240 | x = x.view(batch_size, num_points, 1, num_dims) 241 | feature = feature - x 242 | 243 | feature = feature.permute(0, 3, 1, 2).contiguous() 244 | point_feature = self.xyz2feature(point_feature) #bs, c, n, k 245 | feature = F.leaky_relu(feature + point_feature, 0.2) 246 | return feature #bs, c, n, k 247 | 248 | 249 | class PointNetFeaturePropagation(nn.Module): 250 | def __init__(self, in_channel, mlp, att=None): 251 | super(PointNetFeaturePropagation, self).__init__() 252 | self.mlp_convs = nn.ModuleList() 253 | self.mlp_bns = nn.ModuleList() 254 | last_channel = in_channel 255 | self.att = None 256 | if att is not None: 257 | self.att = 
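# --- Illustrative usage sketch (not part of the original file) ----------------
# Quick check of Attention_block: the gate maps (g, x) to per-point weights
# psi in (0, 1) and returns (psi, 1 - psi) so callers can split features into
# attended and suppressed parts. The channel sizes below are made up, and the
# class above is assumed to be in scope.
import torch

att = Attention_block(F_g=64, F_l=64, F_int=32)
g = torch.randn(2, 64, 1024)
x = torch.randn(2, 64, 1024)
psi, inv_psi = att(g, x)
assert psi.shape == (2, 1, 1024)
assert torch.allclose(psi + inv_psi, torch.ones_like(psi))
# ------------------------------------------------------------------------------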
Attention_block(F_g=att[0],F_l=att[1],F_int=att[2]) 258 | 259 | for out_channel in mlp: 260 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 261 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 262 | last_channel = out_channel 263 | 264 | def forward(self, xyz1, xyz2, points1, points2): 265 | """ 266 | Input: 267 | xyz1: input points position data, [B, C, N] 268 | xyz2: sampled input points position data, [B, C, S], skipped xyz 269 | points1: input points data, [B, D, N] 270 | points2: input points data, [B, D, S], skipped features 271 | Return: 272 | new_points: upsampled points data, [B, D', N] 273 | """ 274 | xyz1 = xyz1.permute(0, 2, 1) 275 | xyz2 = xyz2.permute(0, 2, 1) 276 | 277 | points2 = points2.permute(0, 2, 1) 278 | B, N, C = xyz1.shape 279 | _, S, _ = xyz2.shape 280 | 281 | if S == 1: 282 | interpolated_points = points2.repeat(1, N, 1) 283 | else: 284 | dists = square_distance(xyz1, xyz2) 285 | dists, idx = dists.sort(dim=-1) 286 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 287 | 288 | dist_recip = 1.0 / (dists + 1e-8) 289 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 290 | weight = dist_recip / norm 291 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 292 | 293 | # skip attention 294 | if self.att is not None: 295 | psix, psig = self.att(interpolated_points.permute(0, 2, 1), points1) 296 | points1 = points1 * psix 297 | 298 | if points1 is not None: 299 | points1 = points1.permute(0, 2, 1) 300 | new_points = torch.cat([points1, interpolated_points], dim=-1) 301 | else: 302 | new_points = interpolated_points 303 | 304 | new_points = new_points.permute(0, 2, 1) 305 | 306 | for i, conv in enumerate(self.mlp_convs): 307 | bn = self.mlp_bns[i] 308 | new_points = F.leaky_relu(bn(conv(new_points)), 0.2) 309 | 310 | return new_points 311 | 312 | 313 | class CIC(nn.Module): 314 | def __init__(self, npoint, radius, k, in_channels, output_channels, bottleneck_ratio=2, mlp_num=2, curve_config=None): 315 | super(CIC, self).__init__() 316 | self.in_channels = in_channels 317 | self.output_channels = output_channels 318 | self.bottleneck_ratio = bottleneck_ratio 319 | self.radius = radius 320 | self.k = k 321 | self.npoint = npoint 322 | 323 | planes = in_channels // bottleneck_ratio 324 | 325 | self.use_curve = curve_config is not None 326 | if self.use_curve: 327 | self.curveaggregation = CurveAggregation(planes) 328 | self.curvegrouping = CurveGrouping(planes, k, curve_config[0], curve_config[1]) 329 | 330 | self.conv1 = nn.Sequential( 331 | nn.Conv1d(in_channels, 332 | planes, 333 | kernel_size=1, 334 | bias=False), 335 | nn.BatchNorm1d(in_channels // bottleneck_ratio), 336 | nn.LeakyReLU(negative_slope=0.2, inplace=True)) 337 | 338 | self.conv2 = nn.Sequential( 339 | nn.Conv1d(planes, output_channels, kernel_size=1, bias=False), 340 | nn.BatchNorm1d(output_channels)) 341 | 342 | if in_channels != output_channels: 343 | self.shortcut = nn.Sequential( 344 | nn.Conv1d(in_channels, 345 | output_channels, 346 | kernel_size=1, 347 | bias=False), 348 | nn.BatchNorm1d(output_channels)) 349 | 350 | self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 351 | 352 | self.maxpool = MaskedMaxPool(npoint, radius, k) 353 | 354 | self.lpfa = LPFA(planes, planes, k, mlp_num=mlp_num, initial=False) 355 | 356 | def forward(self, xyz, x): 357 | 358 | # max pool 359 | if xyz.size(-1) != self.npoint: 360 | xyz, x = self.maxpool( 361 | xyz.transpose(1, 2).contiguous(), x) 362 | xyz = xyz.transpose(1, 2) 363 | 364 | shortcut 
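# --- Illustrative sketch (not part of the original file) ----------------------
# Distilled version of the 3-NN inverse-distance interpolation used in
# PointNetFeaturePropagation.forward, on made-up tensors: sparse features
# points2 at positions xyz2 are interpolated onto the dense positions xyz1.
import torch

torch.manual_seed(0)
B, N, S, D = 1, 8, 4, 6
xyz1 = torch.rand(B, N, 3)
xyz2 = torch.rand(B, S, 3)
points2 = torch.randn(B, S, D)

dists = torch.cdist(xyz1, xyz2) ** 2
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3]          # 3 nearest sparse points

w = 1.0 / (dists + 1e-8)
w = w / w.sum(dim=2, keepdim=True)                   # normalized weights
neigh = points2.gather(1, idx.reshape(B, -1, 1).expand(-1, -1, D)).reshape(B, N, 3, D)
interp = (neigh * w.unsqueeze(-1)).sum(dim=2)        # (B, N, D)
assert interp.shape == (B, N, D)
# ------------------------------------------------------------------------------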
= x 365 | x = self.conv1(x) # bs, c', n 366 | 367 | idx = knn(xyz, self.k) 368 | 369 | if self.use_curve: 370 | # curve grouping 371 | curves = self.curvegrouping(x, xyz, idx[:,:,1:]) # avoid self-loop 372 | 373 | # curve aggregation 374 | x = self.curveaggregation(x, curves) 375 | 376 | x = self.lpfa(x, xyz, idx=idx[:,:,:self.k]) #bs, c', n, k 377 | 378 | x = self.conv2(x) # bs, c, n 379 | 380 | if self.in_channels != self.output_channels: 381 | shortcut = self.shortcut(shortcut) 382 | 383 | x = self.relu(x + shortcut) 384 | 385 | return xyz, x 386 | 387 | 388 | class CurveAggregation(nn.Module): 389 | def __init__(self, in_channel): 390 | super(CurveAggregation, self).__init__() 391 | self.in_channel = in_channel 392 | mid_feature = in_channel // 2 393 | self.conva = nn.Conv1d(in_channel, 394 | mid_feature, 395 | kernel_size=1, 396 | bias=False) 397 | self.convb = nn.Conv1d(in_channel, 398 | mid_feature, 399 | kernel_size=1, 400 | bias=False) 401 | self.convc = nn.Conv1d(in_channel, 402 | mid_feature, 403 | kernel_size=1, 404 | bias=False) 405 | self.convn = nn.Conv1d(mid_feature, 406 | mid_feature, 407 | kernel_size=1, 408 | bias=False) 409 | self.convl = nn.Conv1d(mid_feature, 410 | mid_feature, 411 | kernel_size=1, 412 | bias=False) 413 | self.convd = nn.Sequential( 414 | nn.Conv1d(mid_feature * 2, 415 | in_channel, 416 | kernel_size=1, 417 | bias=False), 418 | nn.BatchNorm1d(in_channel)) 419 | self.line_conv_att = nn.Conv2d(in_channel, 420 | 1, 421 | kernel_size=1, 422 | bias=False) 423 | 424 | def forward(self, x, curves): 425 | curves_att = self.line_conv_att(curves) # bs, 1, c_n, c_l 426 | 427 | curver_inter = torch.sum(curves * F.softmax(curves_att, dim=-1), dim=-1) #bs, c, c_n 428 | curves_intra = torch.sum(curves * F.softmax(curves_att, dim=-2), dim=-2) #bs, c, c_l 429 | 430 | curver_inter = self.conva(curver_inter) # bs, mid, n 431 | curves_intra = self.convb(curves_intra) # bs, mid ,n 432 | 433 | x_logits = self.convc(x).transpose(1, 2).contiguous() 434 | x_inter = F.softmax(torch.bmm(x_logits, curver_inter), dim=-1) # bs, n, c_n 435 | x_intra = F.softmax(torch.bmm(x_logits, curves_intra), dim=-1) # bs, l, c_l 436 | 437 | 438 | curver_inter = self.convn(curver_inter).transpose(1, 2).contiguous() 439 | curves_intra = self.convl(curves_intra).transpose(1, 2).contiguous() 440 | 441 | x_inter = torch.bmm(x_inter, curver_inter) 442 | x_intra = torch.bmm(x_intra, curves_intra) 443 | 444 | curve_features = torch.cat((x_inter, x_intra),dim=-1).transpose(1, 2).contiguous() 445 | x = x + self.convd(curve_features) 446 | 447 | return F.leaky_relu(x, negative_slope=0.2) 448 | 449 | 450 | class CurveGrouping(nn.Module): 451 | def __init__(self, in_channel, k, curve_num, curve_length): 452 | super(CurveGrouping, self).__init__() 453 | self.curve_num = curve_num 454 | self.curve_length = curve_length 455 | self.in_channel = in_channel 456 | self.k = k 457 | 458 | self.att = nn.Conv1d(in_channel, 1, kernel_size=1, bias=False) 459 | 460 | self.walk = Walk(in_channel, k, curve_num, curve_length) 461 | 462 | def forward(self, x, xyz, idx): 463 | # starting point selection in self attention style 464 | x_att = torch.sigmoid(self.att(x)) 465 | x = x * x_att 466 | 467 | _, start_index = torch.topk(x_att, 468 | self.curve_num, 469 | dim=2, 470 | sorted=False) 471 | start_index = start_index.squeeze().unsqueeze(2) 472 | 473 | curves = self.walk(xyz, x, idx, start_index) #bs, c, c_n, c_l 474 | 475 | return curves 476 | 477 | 478 | class MaskedMaxPool(nn.Module): 479 | def __init__(self, npoint, 
radius, k): 480 | super(MaskedMaxPool, self).__init__() 481 | self.npoint = npoint 482 | self.radius = radius 483 | self.k = k 484 | 485 | def forward(self, xyz, features): 486 | sub_xyz, neighborhood_features = sample_and_group(self.npoint, self.radius, self.k, xyz, features.transpose(1,2)) 487 | 488 | neighborhood_features = neighborhood_features.permute(0, 3, 1, 2).contiguous() 489 | sub_features = F.max_pool2d( 490 | neighborhood_features, kernel_size=[1, neighborhood_features.shape[3]] 491 | ) # bs, c, n, 1 492 | sub_features = torch.squeeze(sub_features, -1) # bs, c, n 493 | return sub_xyz, sub_features 494 | -------------------------------------------------------------------------------- /utils/part_segmentation_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 5 | seg_label_to_cat = {} # {0:Airplane, 1:Airplane, ...49:Table} 6 | for cat in seg_classes.keys(): 7 | for label in seg_classes[cat]: 8 | seg_label_to_cat[label] = cat 9 | class_labels = {c: l for l, c in enumerate(sorted(seg_classes.keys()))} 10 | part_codes = [] 11 | for k in sorted(seg_classes.keys()): part_codes += [seg_classes[k]] 12 | 13 | def get_evaluation_metrics(logits, labels): 14 | 15 | seg = np.ones_like(labels)*(-1) 16 | shape_IoUs = {c: [] for c in seg_classes.keys()} 17 | for i, (l, y) in enumerate(zip(logits, labels)): 18 | y = y.reshape(-1) 19 | cls_parts = seg_classes[seg_label_to_cat[y[0]]] 20 | category = cls_parts[0] 21 | 22 | # Point predictions 23 | s = l[:, cls_parts].argmax(-1) + category 24 | 25 | # Find IoU for each part in the point cloud 26 | part_IoUs = [] 27 | for p in cls_parts: 28 | s_p, y_p = (s == p), (y == p) 29 | iou = (s_p & y_p).sum() / float((s_p | y_p).sum()) if np.any(s_p | y_p) else 1.0 30 | part_IoUs += [iou] 31 | 32 | seg[i] = s 33 | shape_IoUs[seg_label_to_cat[category]] += [np.mean(part_IoUs)] 34 | 35 | # Overall point accuracy 36 | acc = (seg == labels).sum() / np.prod(labels.shape) 37 | 38 | class_accs = [] 39 | for i in range(len(np.unique(labels))): 40 | labels_i = (labels == i) 41 | seg_i = (seg == i) 42 | class_accs.append((labels_i & seg_i).sum() / labels_i.sum()) 43 | 44 | # Mean class accuracy (point-wise) 45 | mean_class_accuracy = np.mean(class_accs) 46 | 47 | mean_shape_IoUs = [] 48 | instance_IoUs = [] 49 | for c in shape_IoUs.keys(): 50 | instance_IoUs += shape_IoUs[c] 51 | mean_shape_IoUs += [np.mean(shape_IoUs[c])] 52 | 53 | # Overall IoU on all samples 54 | average_instance_IoUs = np.mean(instance_IoUs) 55 | 56 | # Mean class IoU: average IoUs of (Airplane, bag, cap, ..., table) 57 | average_shape_IoUs = np.mean(mean_shape_IoUs) 58 | 59 | summary = {} 60 | summary["acc"] = acc 61 | summary["mean_class_accuracy"] = mean_class_accuracy 62 | summary["average_instance_IoUs"] = average_instance_IoUs 63 | summary["average_shape_IoUs"] = average_shape_IoUs 64 | summary["shape_IoUs"] = {k: v for k, v in zip(seg_classes.keys(), mean_shape_IoUs)} 65 | 66 | return summary 67 | 68 | 69 | def compute_overall_iou(pred, target, num_classes): 70 | shape_ious = [] 71 | pred = pred.max(dim=2)[1] # (batch_size, 
num_points) the pred_class_idx of each point in each sample 72 | pred_np = pred.cpu().data.numpy() 73 | 74 | target_np = target.cpu().data.numpy() 75 | for shape_idx in range(pred.size(0)): # sample_idx 76 | part_ious = [] 77 | for part in range(num_classes): # class_idx! no matter which category, only consider all part_classes of all categories, check all 50 classes 78 | # for target, each point has a class no matter which category owns this point! also 50 classes!!! 79 | # only return 1 when both belongs to this class, which means correct: 80 | I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part)) 81 | # always return 1 when either is belongs to this class: 82 | U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part)) 83 | 84 | F = np.sum(target_np[shape_idx] == part) 85 | 86 | if F != 0: 87 | iou = I / float(U) # iou across all points for this class 88 | part_ious.append(iou) # append the iou of this class 89 | shape_ious.append(np.mean(part_ious)) # each time append an average iou across all classes of this sample (sample_level!) 90 | return shape_ious # [batch_size] 91 | -------------------------------------------------------------------------------- /utils/pointnet2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | def timeit(tag, t): 8 | print("{}: {}s".format(tag, time() - t)) 9 | return time() 10 | 11 | def pc_normalize(pc): 12 | l = pc.shape[0] 13 | centroid = np.mean(pc, axis=0) 14 | pc = pc - centroid 15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 16 | pc = pc / m 17 | return pc 18 | 19 | def square_distance(src, dst): 20 | """ 21 | Calculate Euclid distance between each two points. 
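# --- Illustrative sketch (not part of the original file) ----------------------
# Toy numeric example of the per-part IoU convention used above: a part's IoU
# is intersection/union over points, and parts absent from the ground truth
# are skipped (as in compute_overall_iou) before averaging per shape.
import numpy as np

pred = np.array([0, 0, 1, 1, 1, 2])      # predicted part id per point
target = np.array([0, 1, 1, 1, 2, 2])    # ground-truth part id per point

part_ious = []
for part in range(3):
    i = np.sum((pred == part) & (target == part))
    u = np.sum((pred == part) | (target == part))
    if np.sum(target == part) != 0:      # part exists in the ground truth
        part_ious.append(i / float(u))
print(part_ious, np.mean(part_ious))     # [0.5, 0.5, 0.5] -> shape IoU 0.5
# ------------------------------------------------------------------------------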
22 | 23 | src^T * dst = xn * xm + yn * ym + zn * zm; 24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 28 | 29 | Input: 30 | src: source points, [B, N, C] 31 | dst: target points, [B, M, C] 32 | Output: 33 | dist: per-point square distance, [B, N, M] 34 | """ 35 | B, N, _ = src.shape 36 | _, M, _ = dst.shape 37 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 38 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 39 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 40 | return dist 41 | 42 | 43 | def index_points(points, idx): 44 | """ 45 | 46 | Input: 47 | points: input points data, [B, N, C] 48 | idx: sample index data, [B, S] 49 | Return: 50 | new_points:, indexed points data, [B, S, C] 51 | """ 52 | device = points.device 53 | B = points.shape[0] 54 | view_shape = list(idx.shape) 55 | view_shape[1:] = [1] * (len(view_shape) - 1) 56 | repeat_shape = list(idx.shape) 57 | repeat_shape[0] = 1 58 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 59 | new_points = points[batch_indices, idx, :] 60 | return new_points 61 | 62 | 63 | def farthest_point_sample(xyz, npoint): 64 | """ 65 | Input: 66 | xyz: pointcloud data, [B, N, 3] 67 | npoint: number of samples 68 | Return: 69 | centroids: sampled pointcloud index, [B, npoint] 70 | """ 71 | device = xyz.device 72 | B, N, C = xyz.shape 73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 74 | distance = torch.ones(B, N).to(device) * 1e10 75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 76 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 77 | for i in range(npoint): 78 | centroids[:, i] = farthest 79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 80 | dist = torch.sum((xyz - centroid) ** 2, -1) 81 | mask = dist < distance 82 | distance[mask] = dist[mask] 83 | farthest = torch.max(distance, -1)[1] 84 | return centroids 85 | 86 | 87 | def query_ball_point(radius, nsample, xyz, new_xyz): 88 | """ 89 | Input: 90 | radius: local region radius 91 | nsample: max sample number in local region 92 | xyz: all points, [B, N, 3] 93 | new_xyz: query points, [B, S, 3] 94 | Return: 95 | group_idx: grouped points index, [B, S, nsample] 96 | """ 97 | device = xyz.device 98 | B, N, C = xyz.shape 99 | _, S, _ = new_xyz.shape 100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 101 | sqrdists = square_distance(new_xyz, xyz) 102 | group_idx[sqrdists > radius ** 2] = N 103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 105 | mask = group_idx == N 106 | group_idx[mask] = group_first[mask] 107 | return group_idx 108 | 109 | 110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 111 | """ 112 | Input: 113 | npoint: 114 | radius: 115 | nsample: 116 | xyz: input points position data, [B, N, 3] 117 | points: input points data, [B, N, D] 118 | Return: 119 | new_xyz: sampled points position data, [B, npoint, nsample, 3] 120 | new_points: sampled points data, [B, npoint, nsample, 3+D] 121 | """ 122 | B, N, C = xyz.shape 123 | S = npoint 124 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 125 | new_xyz = index_points(xyz, fps_idx) 126 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 127 | grouped_xyz = index_points(xyz, idx) # [B, npoint, 
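# --- Illustrative sketch (not part of the original file) ----------------------
# index_points() above gathers per-batch rows with advanced indexing; a
# torch.gather formulation gives the same result and makes the semantics
# explicit: out[b, s, :] = points[b, idx[b, s], :].
import torch

torch.manual_seed(0)
points = torch.randn(2, 10, 4)                 # (B, N, C)
idx = torch.randint(0, 10, (2, 5))             # (B, S)

batch = torch.arange(2).view(2, 1).expand(-1, 5)
ref = points[batch, idx, :]                    # the advanced-indexing version
alt = torch.gather(points, 1, idx.unsqueeze(-1).expand(-1, -1, 4))
assert torch.equal(ref, alt)
# ------------------------------------------------------------------------------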
nsample, C] 128 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 129 | 130 | if points is not None: 131 | grouped_points = index_points(points, idx) 132 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 133 | else: 134 | new_points = grouped_xyz_norm 135 | if returnfps: 136 | return new_xyz, new_points, grouped_xyz, fps_idx 137 | else: 138 | return new_xyz, new_points 139 | 140 | 141 | def sample_and_group_all(xyz, points): 142 | """ 143 | Input: 144 | xyz: input points position data, [B, N, 3] 145 | points: input points data, [B, N, D] 146 | Return: 147 | new_xyz: sampled points position data, [B, 1, 3] 148 | new_points: sampled points data, [B, 1, N, 3+D] 149 | """ 150 | device = xyz.device 151 | B, N, C = xyz.shape 152 | new_xyz = torch.zeros(B, 1, C).to(device) 153 | grouped_xyz = xyz.view(B, 1, N, C) 154 | if points is not None: 155 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 156 | else: 157 | new_points = grouped_xyz 158 | return new_xyz, new_points 159 | 160 | 161 | class PointNetSetAbstraction(nn.Module): 162 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 163 | super(PointNetSetAbstraction, self).__init__() 164 | self.npoint = npoint 165 | self.radius = radius 166 | self.nsample = nsample 167 | self.mlp_convs = nn.ModuleList() 168 | self.mlp_bns = nn.ModuleList() 169 | last_channel = in_channel 170 | for out_channel in mlp: 171 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 172 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 173 | last_channel = out_channel 174 | self.group_all = group_all 175 | 176 | def forward(self, xyz, points): 177 | """ 178 | Input: 179 | xyz: input points position data, [B, C, N] 180 | points: input points data, [B, D, N] 181 | Return: 182 | new_xyz: sampled points position data, [B, C, S] 183 | new_points_concat: sample points feature data, [B, D', S] 184 | """ 185 | xyz = xyz.permute(0, 2, 1) 186 | if points is not None: 187 | points = points.permute(0, 2, 1) 188 | 189 | if self.group_all: 190 | new_xyz, new_points = sample_and_group_all(xyz, points) 191 | else: 192 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 193 | # new_xyz: sampled points position data, [B, npoint, C] 194 | # new_points: sampled points data, [B, npoint, nsample, C+D] 195 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 196 | for i, conv in enumerate(self.mlp_convs): 197 | bn = self.mlp_bns[i] 198 | new_points = F.relu(bn(conv(new_points))) 199 | 200 | 201 | 202 | new_points=new_points.squeeze(-1) 203 | # new_points = torch.max(new_points, 2)[0] 204 | # new_xyz = new_xyz.permute(0, 2, 1) 205 | return new_xyz, new_points 206 | 207 | 208 | class PointNetSetAbstractionMsg(nn.Module): 209 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 210 | super(PointNetSetAbstractionMsg, self).__init__() 211 | self.npoint = npoint 212 | self.radius_list = radius_list 213 | self.nsample_list = nsample_list 214 | self.conv_blocks = nn.ModuleList() 215 | self.bn_blocks = nn.ModuleList() 216 | for i in range(len(mlp_list)): 217 | convs = nn.ModuleList() 218 | bns = nn.ModuleList() 219 | last_channel = in_channel + 3 220 | for out_channel in mlp_list[i]: 221 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 222 | bns.append(nn.BatchNorm2d(out_channel)) 223 | last_channel = out_channel 224 | self.conv_blocks.append(convs) 225 | self.bn_blocks.append(bns) 226 | 227 | 
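# --- Illustrative usage sketch (not part of the original file) ----------------
# Shape walk-through for this file's sample_and_group (assuming it is in
# scope): the output's last dimension concatenates the re-centered neighbor
# coordinates (grouped_xyz_norm, 3 channels) with the grouped features (D).
import torch

B, N, D = 2, 1024, 16
xyz = torch.rand(B, N, 3)
feats = torch.randn(B, N, D)
new_xyz, new_points = sample_and_group(npoint=128, radius=0.2, nsample=32,
                                       xyz=xyz, points=feats)
assert new_xyz.shape == (B, 128, 3)
assert new_points.shape == (B, 128, 32, 3 + D)
# ------------------------------------------------------------------------------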
def forward(self, xyz, points): 228 | """ 229 | Input: 230 | xyz: input points position data, [B, C, N] 231 | points: input points data, [B, D, N] 232 | Return: 233 | new_xyz: sampled points position data, [B, C, S] 234 | new_points_concat: sample points feature data, [B, D', S] 235 | """ 236 | xyz = xyz.permute(0, 2, 1) 237 | if points is not None: 238 | points = points.permute(0, 2, 1) 239 | 240 | B, N, C = xyz.shape 241 | S = self.npoint 242 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 243 | new_points_list = [] 244 | for i, radius in enumerate(self.radius_list): 245 | K = self.nsample_list[i] 246 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 247 | grouped_xyz = index_points(xyz, group_idx) 248 | grouped_xyz -= new_xyz.view(B, S, 1, C) 249 | if points is not None: 250 | grouped_points = index_points(points, group_idx) 251 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 252 | else: 253 | grouped_points = grouped_xyz 254 | 255 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 256 | for j in range(len(self.conv_blocks[i])): 257 | conv = self.conv_blocks[i][j] 258 | bn = self.bn_blocks[i][j] 259 | grouped_points = F.relu(bn(conv(grouped_points))) 260 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 261 | new_points_list.append(new_points) 262 | 263 | new_xyz = new_xyz.permute(0, 2, 1) 264 | new_points_concat = torch.cat(new_points_list, dim=1) 265 | return new_xyz, new_points_concat 266 | 267 | 268 | class PointNetFeaturePropagation(nn.Module): 269 | def __init__(self, in_channel, mlp): 270 | super(PointNetFeaturePropagation, self).__init__() 271 | self.mlp_convs = nn.ModuleList() 272 | self.mlp_bns = nn.ModuleList() 273 | last_channel = in_channel 274 | for out_channel in mlp: 275 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 276 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 277 | last_channel = out_channel 278 | 279 | def forward(self, xyz1, xyz2, points1, points2): 280 | """ 281 | Input: 282 | xyz1: input points position data, [B, C, N] 283 | xyz2: sampled input points position data, [B, C, S] 284 | points1: input points data, [B, D, N] 285 | points2: input points data, [B, D, S] 286 | Return: 287 | new_points: upsampled points data, [B, D', N] 288 | """ 289 | xyz1 = xyz1.permute(0, 2, 1) 290 | xyz2 = xyz2.permute(0, 2, 1) 291 | 292 | points2 = points2.permute(0, 2, 1) 293 | B, N, C = xyz1.shape 294 | _, S, _ = xyz2.shape 295 | 296 | if S == 1: 297 | interpolated_points = points2.repeat(1, N, 1) 298 | else: 299 | dists = square_distance(xyz1, xyz2) 300 | dists, idx = dists.sort(dim=-1) 301 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 302 | 303 | dist_recip = 1.0 / (dists + 1e-8) 304 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 305 | weight = dist_recip / norm 306 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 307 | 308 | if points1 is not None: 309 | points1 = points1.permute(0, 2, 1) 310 | new_points = torch.cat([points1, interpolated_points], dim=-1) 311 | else: 312 | new_points = interpolated_points 313 | 314 | new_points = new_points.permute(0, 2, 1) 315 | for i, conv in enumerate(self.mlp_convs): 316 | bn = self.mlp_bns[i] 317 | new_points = F.relu(bn(conv(new_points))) 318 | return new_points 319 | 320 | -------------------------------------------------------------------------------- /utils/test_perform_cal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 
import torch 3 | 4 | 5 | 6 | 7 | def get_mean_accuracy(prediction,labels,num_cls): 8 | labels=labels.reshape(-1) 9 | prediction=prediction.reshape(-1,num_cls) 10 | prediction=np.argmax(prediction,1) 11 | 12 | similar=(labels==prediction).astype(int) 13 | mean_acc=np.sum(similar)/len(similar) 14 | 15 | return mean_acc 16 | 17 | 18 | 19 | def get_cls_accuracy(prediction,labels): 20 | prediction=np.argmax(prediction,-1) 21 | label=labels.reshape(-1) 22 | 23 | right=np.sum(prediction==label) 24 | accuracy=right/prediction.shape[0] 25 | return accuracy 26 | 27 | 28 | if __name__=='__main__': 29 | predi=np.ones((4,40)) 30 | label=np.ones((4,1)) 31 | get_cls_accuracy(predi,label) 32 | 33 | 34 | 35 | # if __name__=='__main__': 36 | # data=np.load('result.npy',allow_pickle=True).item() 37 | # get_mean_accuracy(data) 38 | 39 | -------------------------------------------------------------------------------- /utils/vis_feature_cluster.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from numpy import random 4 | import numpy 5 | import torch 6 | from numpy.core.defchararray import array, mod 7 | from torch.autograd.grad_mode import enable_grad 8 | sys.path.append(os.getcwd()) 9 | 10 | import scipy.spatial as spa 11 | 12 | import numpy as np 13 | import torch 14 | from model.PointNet import PointNet_cls_CAM 15 | from Dataloader.ModelNet40 import ModuleNet40 16 | 17 | import open3d as o3d 18 | from sklearn.cluster import KMeans 19 | np.random.seed(1) 20 | 21 | 22 | 23 | 24 | def get_point_cluster(point): 25 | num_point=point.shape[0] 26 | 27 | color=np.tile(np.array([[0,0,255]]),(num_point,1)) 28 | 29 | point_cloud=o3d.geometry.PointCloud() 30 | point_cloud.points=o3d.utility.Vector3dVector(point) 31 | point_cloud.colors=o3d.utility.Vector3dVector(color) 32 | # o3d.visualization.draw_geometries([point_cloud]) 33 | 34 | voxel=o3d.geometry.voxel_down_sample(point_cloud,0.05) 35 | voxel_coor=np.array(voxel.points) 36 | voxel_color=np.array(voxel.colors) 37 | 38 | 39 | neighbor_size=20 40 | num_voxel=voxel_coor.shape[0] 41 | dist_matrix=spa.distance_matrix(voxel_coor,voxel_coor) 42 | index=np.argsort(dist_matrix,1) 43 | index=index[:,:neighbor_size] 44 | 45 | index=index.reshape(-1) 46 | normalized_coor=voxel_coor[index].reshape(num_voxel,neighbor_size,-1) 47 | normalized_coor=normalized_coor-np.mean(normalized_coor,1,keepdims=True) 48 | neighbor_coor=torch.FloatTensor(normalized_coor) 49 | coor_mat=torch.bmm(neighbor_coor.permute(0,2,1),neighbor_coor) 50 | 51 | e,v = torch.symeig(coor_mat, eigenvectors=True) 52 | labels = get_cluster_result(e,v) 53 | # print(np.unique(labels,return_counts=True)) 54 | 55 | sample_index=[] 56 | 57 | for l in np.unique(labels): 58 | ind=np.where(labels==l)[0] 59 | picked_ind=np.random.permutation(ind)[:50] 60 | voxel_color[ind]=np.random.randint(low=0,high=255,size=(1,3)) 61 | sample_index.append(picked_ind) 62 | 63 | sample_index=np.concatenate(sample_index) 64 | sample_coor=voxel_coor[sample_index] 65 | sample_color=voxel_color[sample_index] 66 | 67 | sample_point=o3d.geometry.PointCloud() 68 | sample_point.points=o3d.utility.Vector3dVector(sample_coor) 69 | sample_point.colors=o3d.utility.Vector3dVector(sample_color/255) 70 | 71 | 72 | # voxel.colors=o3d.utility.Vector3dVector(voxel_color.astype(np.int)/255) 73 | o3d.visualization.draw_geometries([sample_point]) 74 | 75 | 76 | # voxel_color[index[10]]=np.array([255,0,0]) 77 | # voxel.colors=o3d.utility.Vector3dVector(voxel_color) 78 | 79 | # 
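# --- Illustrative sketch (not part of the original file) ----------------------
# Toy check of get_cls_accuracy above: argmax over the class axis, then the
# fraction of matches against the flattened labels.
import numpy as np

logits = np.array([[2.0, 0.1, 0.0],
                   [0.0, 1.5, 0.2],
                   [0.3, 0.2, 0.1],
                   [0.0, 0.0, 9.0]])
labels = np.array([[0], [1], [2], [2]])

pred = np.argmax(logits, -1)
acc = np.sum(pred == labels.reshape(-1)) / pred.shape[0]
print(acc)   # 0.75 -- the third sample is predicted as class 0 but labeled 2
# ------------------------------------------------------------------------------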
o3d.visualization.draw_geometries([voxel]) 80 | 81 | 82 | def get_cluster_result(value,vector): 83 | eig_value=np.sort(value,1)[:,::-1] 84 | 85 | linearity=(eig_value[:,0]-eig_value[:,1])/eig_value[:,0] 86 | planarity=(eig_value[:,1]-eig_value[:,2])/eig_value[:,1] 87 | scaterring=eig_value[:,2]/eig_value[:,1] 88 | 89 | neighbor_feat=np.stack((linearity,planarity,scaterring),1) 90 | 91 | kmeans = KMeans(n_clusters=3, random_state=0).fit(neighbor_feat) 92 | labels=kmeans.labels_ 93 | 94 | 95 | 96 | return labels 97 | 98 | # nei_point=o3d.geometry.PointCloud() 99 | # nei_point.points=o3d.utility.Vector3dVector(neighbor_feat) 100 | # o3d.visualization.draw_geometries([nei_point]) 101 | 102 | 103 | # voxel=point_cloud.voxel_down_sample(voxel_size=0.05) 104 | # o3d.visualization.draw_geometries([voxel]) 105 | 106 | 107 | 108 | if __name__=='__main__': 109 | target_cls=1 110 | datapath='D:/Computer_vision/Dataset/Modulenet40/ModelNet40/data' 111 | dataset=ModuleNet40(datapath,'test') 112 | 113 | data,label=dataset.data,dataset.label 114 | index=np.where(label==target_cls)[0] 115 | target_data=data[index] 116 | for i in range(len(data)): 117 | get_point_cluster(target_data[i]) 118 | # train_loader,test_loader,valid_loader=get_sets(datapath,batch_size=10) 119 | 120 | -------------------------------------------------------------------------------- /utils/voting_eval_cls.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | # from util.data_util import ModelNet40 8 | from Dataloader.ModelNet40 import ModuleNet40 9 | # from model.GDANet_cls import GDANET 10 | # from model.cls.DPFA_refine2 import DPFA_refine 11 | from model.cls.GD_DPFA_refine_cls import GD_DPFA 12 | import numpy as np 13 | from torch.utils.data import DataLoader, dataloader 14 | # from util.util import cal_loss, IOStream 15 | import sklearn.metrics as metrics 16 | 17 | from tqdm import tqdm 18 | 19 | class PointcloudScale(object): 20 | def __init__(self, scale_low=2. / 3., scale_high=3. 
/ 2., trans_low=-0.2, trans_high=0.2, trans_open=True): 21 | self.scale_low = scale_low 22 | self.scale_high = scale_high 23 | self.trans_low = trans_low 24 | self.trans_high = trans_high 25 | self.trans_open = trans_open # whether add translation during voting or not 26 | 27 | def __call__(self, pc): 28 | pc=pc.permute(0,2,1) 29 | 30 | bsize = pc.size()[0] 31 | for i in range(bsize): 32 | xyz1 = np.random.uniform(low=self.scale_low, high=self.scale_high, size=[3]) 33 | xyz2 = np.random.uniform(low=self.trans_low, high=self.trans_high, size=[3]) 34 | scales = torch.from_numpy(xyz1).float().cuda() 35 | trans = torch.from_numpy(xyz2).float().cuda() if self.trans_open else 0 36 | pc[i, :, 0:3] = torch.mul(pc[i, :, 0:3], scales)+trans 37 | return pc.permute(0,2,1) 38 | 39 | 40 | def test(): 41 | # test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=5, 42 | # batch_size=args.test_batch_size, shuffle=False, drop_last=False) 43 | 44 | data_path='/data1/jiajing/dataset/ModelNet40/data' 45 | test_loader=DataLoader(ModuleNet40(data_path,'test'),batch_size=16,shuffle=False,drop_last=False) 46 | 47 | 48 | # device = torch.device("cuda" if args.cuda else "cpu") 49 | device = torch.device("cuda") 50 | NUM_PEPEAT = 300 51 | NUM_VOTE = 10 52 | # Try to load models 53 | model = GD_DPFA(num_cls=40,inpt_length=3).to(device) 54 | # model = nn.DataParallel(model) 55 | 56 | 57 | model_path='/data1/jiajing/worksapce/My_Research/bottom-up-segmentation/Exp/GD_DPFA/pth_file/epoch_260' 58 | # model_file=torch.load(model_path)[''] 59 | model.load_state_dict(torch.load(model_path)['model_state']) 60 | model = model.eval() 61 | best_acc = 0 62 | 63 | pointscale=PointcloudScale(scale_low=2. / 3., scale_high=3. / 2., trans_low=-0.2, trans_high=0.2, trans_open=True) 64 | for i in tqdm(range(NUM_PEPEAT)): 65 | test_true = [] 66 | test_pred = [] 67 | 68 | for data, label in tqdm(test_loader,leave=False): 69 | data, label = data.to(device), label.to(device).squeeze() 70 | pred = 0 71 | for v in tqdm(range(NUM_VOTE),leave=False): 72 | new_data = data 73 | batch_size = data.size()[0] 74 | if v > 0: 75 | new_data.data = pointscale(new_data.data) 76 | with torch.no_grad(): 77 | pred += F.softmax(model(new_data), dim=1) 78 | pred /= NUM_VOTE 79 | label = label.view(-1) 80 | pred_choice = pred.max(dim=1)[1] 81 | test_true.append(label.cpu().numpy()) 82 | test_pred.append(pred_choice.detach().cpu().numpy()) 83 | test_true = np.concatenate(test_true) 84 | test_pred = np.concatenate(test_pred) 85 | test_acc = metrics.accuracy_score(test_true, test_pred) 86 | if test_acc > best_acc: 87 | best_acc = test_acc 88 | outstr = 'Voting %d, test acc: %.6f,' % (i, test_acc*100) 89 | # io.cprint(outstr) 90 | print(outstr) 91 | 92 | final_outstr = 'Final voting result test acc: %.6f,' % (best_acc * 100) 93 | # io.cprint(final_outstr) 94 | print(final_outstr) 95 | 96 | 97 | def _init_(): 98 | if not os.path.exists('checkpoints'): 99 | os.makedirs('checkpoints') 100 | if not os.path.exists('checkpoints/'+args.exp_name): 101 | os.makedirs('checkpoints/'+args.exp_name) 102 | 103 | os.system('cp voting_eval_modelnet.py checkpoints'+'/'+args.exp_name+'/'+'voting_eval_modelnet.py.backup') 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser(description='3D Object Classification') 108 | parser.add_argument('--exp_name', type=str, default='GDANet', metavar='N', 109 | help='Name of the experiment') 110 | parser.add_argument('--test_batch_size', type=int, default=16, 
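# --- Illustrative sketch (not part of the original file) ----------------------
# Distilled version of the voting loop above: average the softmax over
# several randomly rescaled copies of each batch. Note that in the original,
# `new_data = data` aliases the batch and PointcloudScale modifies tensors in
# place, so successive votes rescale an already-rescaled cloud; this sketch
# clones first to keep each vote independent. `model` and `transform` stand
# in for the real network and augmentation.
import torch
import torch.nn.functional as F

def vote_predict(model, data, transform, num_vote=10):
    pred = 0
    for v in range(num_vote):
        new_data = data if v == 0 else transform(data.clone())
        with torch.no_grad():
            pred = pred + F.softmax(model(new_data), dim=1)
    return (pred / num_vote).max(dim=1)[1]
# ------------------------------------------------------------------------------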
metavar='batch_size', 111 | help='Size of batch)') 112 | parser.add_argument('--no_cuda', type=bool, default=False, 113 | help='enables CUDA training') 114 | parser.add_argument('--seed', type=int, default=1, metavar='S', 115 | help='random seed (default: 1)') 116 | parser.add_argument('--num_points', type=int, default=1024, 117 | help='num of points to use') 118 | parser.add_argument('--model_path', type=str, default='pretrained/GDANet_ModelNet40_93.4.t7', metavar='N', 119 | help='Pretrained model path') 120 | parser.add_argument('--trans_open', type=bool, default=True, metavar='N', 121 | help='enables input translation during voting') 122 | args = parser.parse_args() 123 | 124 | # _init_() 125 | 126 | # io = IOStream('checkpoints/' + args.exp_name + '/%s_voting.log' % (args.exp_name)) 127 | 128 | # io.cprint(str(args)) 129 | 130 | # args.cuda = not args.no_cuda and torch.cuda.is_available() 131 | # torch.manual_seed(args.seed) 132 | # # if args.cuda: 133 | # # io.cprint('Using GPU') 134 | # torch.cuda.manual_seed(args.seed) 135 | # # else: 136 | # io.cprint('Using CPU') 137 | 138 | test() 139 | -------------------------------------------------------------------------------- /utils/walk.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Tiange Xiang 3 | @Contact: txia7609@uni.sydney.edu.au 4 | @File: walk.py 5 | @Time: 2021/01/21 3:10 PM 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | def batched_index_select(input, dim, index): 15 | views = [input.shape[0]] + \ 16 | [1 if i != dim else -1 for i in range(1, len(input.shape))] 17 | expanse = list(input.shape) 18 | expanse[0] = -1 19 | expanse[dim] = -1 20 | index = index.view(views).expand(expanse) 21 | return torch.gather(input, dim, index) 22 | 23 | def gumbel_softmax(logits, dim, temperature=1): 24 | """ 25 | ST-gumple-softmax w/o random gumbel samplings 26 | input: [*, n_class] 27 | return: flatten --> [*, n_class] an one-hot vector 28 | """ 29 | y = F.softmax(logits / temperature, dim=dim) 30 | 31 | shape = y.size() 32 | _, ind = y.max(dim=-1) 33 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 34 | y_hard.scatter_(1, ind.view(-1, 1), 1) 35 | y_hard = y_hard.view(*shape) 36 | 37 | y_hard = (y_hard - y).detach() + y 38 | return y_hard 39 | 40 | class Walk(nn.Module): 41 | ''' 42 | Walk in the cloud 43 | ''' 44 | def __init__(self, in_channel, k, curve_num, curve_length): 45 | super(Walk, self).__init__() 46 | self.curve_num = curve_num 47 | self.curve_length = curve_length 48 | self.k = k 49 | 50 | self.agent_mlp = nn.Sequential( 51 | nn.Conv2d(in_channel * 2, 52 | 1, 53 | kernel_size=1, 54 | bias=False), nn.BatchNorm2d(1)) 55 | self.momentum_mlp = nn.Sequential( 56 | nn.Conv1d(in_channel * 2, 57 | 2, 58 | kernel_size=1, 59 | bias=False), nn.BatchNorm1d(2)) 60 | 61 | def crossover_suppression(self, cur, neighbor, bn, n, k): 62 | # cur: bs*n, 3 63 | # neighbor: bs*n, 3, k 64 | neighbor = neighbor.detach() 65 | cur = cur.unsqueeze(-1).detach() 66 | dot = torch.bmm(cur.transpose(1,2), neighbor) # bs*n, 1, k 67 | norm1 = torch.norm(cur, dim=1, keepdim=True) 68 | norm2 = torch.norm(neighbor, dim=1, keepdim=True) 69 | divider = torch.clamp(norm1 * norm2, min=1e-8) 70 | ans = torch.div(dot, divider).squeeze() # bs*n, k 71 | 72 | # normalize to [0, 1] 73 | ans = 1. 
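# --- Illustrative sketch (not part of the original file) ----------------------
# The straight-through trick used by gumbel_softmax above: the forward value
# is a one-hot vector, but gradients flow as if the soft softmax had been
# returned, because (y_hard - y).detach() + y has value y_hard and gradient y.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 8, requires_grad=True)
y = F.softmax(logits, dim=-1)
y_hard = torch.zeros_like(y).scatter_(-1, y.argmax(-1, keepdim=True), 1.0)
out = (y_hard - y).detach() + y

assert torch.equal(out.detach(), y_hard)      # hard one-hot in the forward pass
(out * torch.arange(8.)).sum().backward()     # yet gradients reach the logits
assert logits.grad is not None
# ------------------------------------------------------------------------------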
+ ans 74 | ans = torch.clamp(ans, 0., 1.0) 75 | 76 | return ans.detach() 77 | 78 | def forward(self, xyz, x, adj, cur): 79 | bn, c, tot_points = x.size() 80 | device=xyz.device 81 | # raw point coordinates 82 | xyz = xyz.transpose(1,2).contiguous() # bs, n, 3 83 | 84 | # point features 85 | x = x.transpose(1,2).contiguous() # bs, n, c 86 | 87 | flatten_x = x.view(bn * tot_points, -1) 88 | batch_offset = torch.arange(0, bn, device=device).detach() * tot_points 89 | 90 | # indices of neighbors for the starting points 91 | tmp_adj = (adj + batch_offset.view(-1,1,1)).view(adj.size(0)*adj.size(1),-1) #bs, n, k 92 | 93 | # batch flattened indices for the starting points 94 | flatten_cur = (cur + batch_offset.view(-1,1,1)).view(-1) 95 | 96 | curves = [] 97 | 98 | # one step at a time 99 | for step in range(self.curve_length): 100 | 101 | if step == 0: 102 | # get starting point features using flattened indices 103 | starting_points = flatten_x[flatten_cur, :].contiguous() 104 | pre_feature = starting_points.view(bn, self.curve_num, -1, 1).transpose(1,2) # bs * n, c 105 | else: 106 | # dynamic momentum 107 | cat_feature = torch.cat((cur_feature.squeeze(), pre_feature.squeeze()),dim=1) 108 | att_feature = F.softmax(self.momentum_mlp(cat_feature),dim=1).view(bn, 1, self.curve_num, 2) # bs, 1, n, 2 109 | cat_feature = torch.cat((cur_feature, pre_feature),dim=-1) # bs, c, n, 2 110 | 111 | # update curve descriptor 112 | pre_feature = torch.sum(cat_feature * att_feature, dim=-1, keepdim=True) # bs, c, n 113 | pre_feature_cos = pre_feature.transpose(1,2).contiguous().view(bn * self.curve_num, -1) 114 | 115 | pick_idx = tmp_adj[flatten_cur] # bs*n, k 116 | 117 | # get the neighbors of current points 118 | pick_values = flatten_x[pick_idx.view(-1),:] 119 | 120 | # reshape to fit crossover suppression below 121 | pick_values_cos = pick_values.view(bn * self.curve_num, self.k, c) 122 | pick_values = pick_values_cos.view(bn, self.curve_num, self.k, c) 123 | pick_values_cos = pick_values_cos.transpose(1,2).contiguous() 124 | 125 | pick_values = pick_values.permute(0,3,1,2) # bs, c, n, k 126 | 127 | pre_feature_expand = pre_feature.expand_as(pick_values) 128 | 129 | # concat current point features with curve descriptors 130 | pre_feature_expand = torch.cat((pick_values, pre_feature_expand),dim=1) 131 | 132 | # which node to pick next? 133 | pre_feature_expand = self.agent_mlp(pre_feature_expand) # bs, 1, n, k 134 | 135 | if step !=0: 136 | # crossover suppression 137 | d = self.crossover_suppression(cur_feature_cos - pre_feature_cos, 138 | pick_values_cos - cur_feature_cos.unsqueeze(-1), 139 | bn, self.curve_num, self.k) 140 | d = d.view(bn, self.curve_num, self.k).unsqueeze(1) # bs, 1, n, k 141 | pre_feature_expand = torch.mul(pre_feature_expand, d) 142 | 143 | pre_feature_expand = gumbel_softmax(pre_feature_expand, -1) #bs, 1, n, k 144 | 145 | cur_feature = torch.sum(pick_values * pre_feature_expand, dim=-1, keepdim=True) # bs, c, n, 1 146 | 147 | cur_feature_cos = cur_feature.transpose(1,2).contiguous().view(bn * self.curve_num, c) 148 | 149 | cur = torch.argmax(pre_feature_expand, dim=-1).view(-1, 1) # bs * n, 1 150 | 151 | flatten_cur = batched_index_select(pick_idx, 1, cur).squeeze() # bs * n 152 | 153 | # collect curve progress 154 | curves.append(cur_feature) 155 | 156 | return torch.cat(curves,dim=-1) 157 | --------------------------------------------------------------------------------
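# --- Illustrative usage sketch (not part of the original document) ------------
# batched_index_select (defined at the top of walk.py, assumed in scope here)
# picks, for every batch element, a different set of rows along the given
# dimension: out[b, j, :] = input[b, index[b, j], :] for dim=1.
import torch

torch.manual_seed(0)
x = torch.randn(2, 6, 3)                          # (bs, n, c)
idx = torch.tensor([[0, 5], [2, 3]])              # per-batch row indices

out = batched_index_select(x, dim=1, index=idx)   # -> (2, 2, 3)
assert torch.equal(out[0, 1], x[0, 5])
assert torch.equal(out[1, 0], x[1, 2])
# ------------------------------------------------------------------------------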