├── Dataset.py
├── LICENSE
├── README.md
├── generate_voxel.py
├── images
│   ├── gt.png
│   ├── mean shift prediction.png
│   ├── mean shift.png
│   └── point cloud.png
├── model.py
├── show_result.py
├── train.py
└── voxel
    └── generate_data_list.py

/Dataset.py:
--------------------------------------------------------------------------------
'''
For training
'''

import os
import random

import h5py
import numpy as np
from scipy.ndimage.interpolation import rotate


class Data_Configs:
    # 'refridgerator' is misspelled on purpose: it matches the official ScanNet label set.
    sem_names = ['background', 'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window',
                 'bookshelf', 'picture', 'counter', 'desk', 'curtain', 'refridgerator',
                 'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture']
    sem_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
    sem_num = len(sem_names)
    ins_max_num = 20  # maximum number of instances per ScanNet scene


class Data_SCANNET:
    def __init__(self, dataset_path, train_scenes, test_scenes, mode):
        self.root_folder_4_traintest = dataset_path
        if mode == 'train':
            self.train_files = train_scenes
            print('train files:', len(self.train_files))
        elif mode == 'val':
            self.test_files = test_scenes
            print('val files:', len(self.test_files))

        self.mode = mode

    def load_data_file_voxel(self, file_path):
        scene = file_path
        with h5py.File(os.path.join(self.root_folder_4_traintest, scene + '.h5'), 'r') as fin:
            rgbs = fin['rgbs'][:]       # [X, Y, Z, 3]
            sem = fin['sem_labels'][:]  # [X, Y, Z, 1]
            ins = fin['ins_labels'][:]  # [X, Y, Z, 1]
        return rgbs, sem, ins

    def load_voxel(self, file_path):
        rgbs, sem, ins = self.load_data_file_voxel(file_path)
        return rgbs, sem, ins

    def __getitem__(self, index):
        if self.mode == 'train':
            bat_files = self.train_files[index]
        elif self.mode == 'val':
            bat_files = self.test_files[index]
        rgbs, sem, ins = self.load_voxel(bat_files)

        # Data augmentation (training only): random flip and random rotation
        if self.mode == 'train':
            # rotation angles (degrees) to sample from
            angle_list = [15, 30, 45, 75, 90, 105, 120, 135, 150, 165, 180]

            # decide which operations to apply
            do_flip = np.random.uniform(0.0, 1.0) < 0.5
            do_rotate = np.random.uniform(0.0, 1.0) < 0.5
            # pick one of the angles
            angles = random.randint(0, 10)

            # flip along the x and y axes
            if do_flip:
                rgbs = np.ascontiguousarray(np.flip(rgbs, (0, 1)))
                sem = np.ascontiguousarray(np.flip(sem, (0, 1)))
                ins = np.ascontiguousarray(np.flip(ins, (0, 1)))
            # rotate in the x-y plane
            if do_rotate:
                rgbs = rotate(rgbs, angle=angle_list[angles], axes=(0, 1), mode='mirror', order=0, reshape=False)
                sem = rotate(sem, angle=angle_list[angles], axes=(0, 1), mode='mirror', order=0, reshape=False)
                ins = rotate(ins, angle=angle_list[angles], axes=(0, 1), mode='mirror', order=0, reshape=False)

        rgbs = np.asarray(rgbs, dtype=np.float32)
        sem = np.asarray(sem, dtype=np.int32)
        ins = np.asarray(ins, dtype=np.int32)

        return rgbs, sem, ins

    def __len__(self):
        if self.mode == 'train':
            return len(self.train_files)
        elif self.mode == 'val':
            return len(self.test_files)
--------------------------------------------------------------------------------
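Data_SCANNET returns one whole scene per index, so it plugs straight into a standard DataLoader. A minimal usage sketch (not a file in the repo), assuming the preprocessing steps below have already produced voxel/voxel/*.h5 and voxel/train.txt; batch_size stays at 1 because every scene has a different grid size:

# Usage sketch: iterate voxelized ScanNet scenes with Data_SCANNET.
import torch
from Dataset import Data_SCANNET

train_scenes = [line.strip() for line in open('voxel/train.txt')]
data = Data_SCANNET('voxel/voxel', train_scenes, [], mode='train')
loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=True)

rgbs, sem, ins = next(iter(loader))
print(rgbs.shape, sem.shape, ins.shape)  # [1, X, Y, Z, 3], [1, X, Y, Z, 1], [1, X, Y, Z, 1]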
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Shang Yi Yu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Unofficial implementation of 3D Instance Segmentation via Multi-Task Metric Learning (MTML)
This is an unofficial implementation of MTML, written in PyTorch.
Notice: we only built the network and watched its loss decrease until epoch 100 during training, and we have not implemented any post-processing yet (we just apply a simple mean shift for visualization); therefore, **we have no experiments or performance reports right now**.

### (1) Setup
* Ubuntu 16.04 + CUDA 9.0
* Python 3.6 + PyTorch 1.2
* pyntcloud library

### (2) Data Download
We use the ScanNet dataset.
ScanNet official website: http://www.scan-net.org/ (for data download)

### (3) Data Preprocess (point cloud -> voxel)
```
# Generate voxels from point clouds (pass the ScanNet scans directory)
python generate_voxel.py <scannet_dataset_path>
# Generate train / val scene lists
cd voxel
python generate_data_list.py
cd ..
```

### (4) Train/test
```
python train.py
```

### (5) Visualization
```
python show_result.py
```

### (6) Qualitative Results on ScanNet
1. Point cloud
![Arch Image](https://github.com/FishWantToFly/MTML_pytorch_implementation/blob/master/images/point%20cloud.png)
2. Ground-truth segmentation
![Arch Image](https://github.com/FishWantToFly/MTML_pytorch_implementation/blob/master/images/gt.png)
3. Mean-shift results on the feature embedding
![Arch Image](https://github.com/FishWantToFly/MTML_pytorch_implementation/blob/master/images/mean%20shift.png)
4. Instance segmentation results from mean shift of the feature embedding
![Arch Image](https://github.com/FishWantToFly/MTML_pytorch_implementation/blob/master/images/mean%20shift%20prediction.png)
--------------------------------------------------------------------------------
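Each preprocessed scene ends up in voxel/voxel/<scene>.h5 with three datasets: rgbs, sem_labels, and ins_labels (see generate_voxel.py below). A quick sanity-check sketch, with a hypothetical scene name:

# Sanity check: inspect one generated voxel scene.
import h5py
import numpy as np

with h5py.File('voxel/voxel/scene0000_00.h5', 'r') as f:  # example scene name
    rgbs = f['rgbs'][:]       # [X, Y, Z, 3], per-voxel mean color scaled to [0, 1]
    sem = f['sem_labels'][:]  # [X, Y, Z, 1], -1 = empty voxel, 0 = ignored, 1..20 = classes
    ins = f['ins_labels'][:]  # [X, Y, Z, 1], -1 = empty voxel

print(rgbs.shape, sem.shape, ins.shape)
print('occupied voxels:', int(np.sum(ins >= 0)))
print('instances:', np.unique(ins[ins >= 0]).size)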
/generate_voxel.py:
--------------------------------------------------------------------------------
import copy
import glob
import json
import os
import pickle
import sys
from collections import Counter

import h5py
import numpy as np
from plyfile import PlyData
from pyntcloud import PyntCloud

MTML_VOXEL_SIZE = 0.1  # voxel edge length in meters


def make_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)


def read_label_ply(filename):
    plydata = PlyData.read(filename)
    x = np.asarray(plydata.elements[0].data['x'])
    y = np.asarray(plydata.elements[0].data['y'])
    z = np.asarray(plydata.elements[0].data['z'])
    label = np.asarray(plydata.elements[0].data['label'])
    return np.stack([x, y, z], axis=1), label


def read_color_ply(filename):
    plydata = PlyData.read(filename)
    x = np.asarray(plydata.elements[0].data['x'])
    y = np.asarray(plydata.elements[0].data['y'])
    z = np.asarray(plydata.elements[0].data['z'])
    r = np.asarray(plydata.elements[0].data['red'])
    g = np.asarray(plydata.elements[0].data['green'])
    b = np.asarray(plydata.elements[0].data['blue'])
    return np.stack([x, y, z, r, g, b], axis=1)


def collect_label(labelPath, scan):
    aggregation = os.path.join(labelPath, scan + '.aggregation.json')
    segs = os.path.join(labelPath, scan + '_vh_clean_2.0.010000.segs.json')
    sem = os.path.join(labelPath, scan + '_vh_clean_2.labels.ply')
    # Load all labels
    with open(aggregation, 'r') as fid:
        aggreData = json.load(fid)
    with open(segs, 'r') as fid:
        segsData = json.load(fid)
    _, semLabel = read_label_ply(sem)

    # Convert segments to per-point instance labels
    segGroups = aggreData['segGroups']
    segIndices = np.array(segsData['segIndices'])

    # outGroups holds the output instance labels (-1 = unassigned)
    outGroups = np.zeros(np.shape(segIndices)) - 1

    for j in range(np.shape(segGroups)[0]):
        segGroup = segGroups[j]['segments']
        objectId = segGroups[j]['objectId']
        for k in range(np.shape(segGroup)[0]):
            segment = segGroup[k]
            ind = np.where(segIndices == segment)
            if not np.all(outGroups[ind] == -1):
                print('Error: segment assigned to more than one object!')
            outGroups[ind] = int(objectId)

    outGroups = outGroups.astype(int)
    return semLabel, outGroups


def save_h5(h5_filename, rgbs, sem_labels, ins_labels, data_dtype='float32', label_dtype='int32'):
    h5_fout = h5py.File(h5_filename, 'w')
    h5_fout.create_dataset(
        'rgbs', data=rgbs,
        compression='gzip', compression_opts=4,
        dtype=data_dtype)
    h5_fout.create_dataset(
        'sem_labels', data=sem_labels,
        compression='gzip', compression_opts=4,
        dtype=label_dtype)
    h5_fout.create_dataset(
        'ins_labels', data=ins_labels,
        compression='gzip', compression_opts=1,
        dtype=label_dtype)
    h5_fout.close()


def most_common(lst):
    data = Counter(lst)
    return max(lst, key=data.get)


# NYU40 class ids that are kept (wall and floor included); everything else maps to 0
VALID_CLASS_IDS = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39])
target_sem_idx = np.arange(40)
count = 0
for i in range(40):
    if i in VALID_CLASS_IDS:
        count += 1
        target_sem_idx[i] = count
    else:
        target_sem_idx[i] = 0


def changem(input_array, source_idx, target_idx):
    # Remap every value in input_array from source_idx to the matching target_idx
    mapping = {}
    for i, sidx in enumerate(source_idx):
        mapping[sidx] = target_idx[i]
    input_array = np.array([mapping[i] for i in input_array])
    return input_array
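# To make the remapping concrete: changem compresses the 40 NYU40 ids down to
# the 21 ids used in Data_Configs (0 for ignored classes). For example:
#   changem(np.array([1, 2, 13, 24, 39]), np.arange(40), target_sem_idx)
#   -> array([ 1,  2,  0, 15, 20])
# i.e. wall -> 1, floor -> 2, an ignored id -> 0, refridgerator -> 15,
# otherfurniture -> 20, matching the index order of Data_Configs.sem_names.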
if __name__ == '__main__':
    dataset_path = sys.argv[1]
    save_path = './voxel'

    anky_cloud_list = sorted(glob.glob(os.path.join(dataset_path, '*/*_vh_clean_2.ply')))

    save_map_dir = os.path.join(save_path, 'mapping')
    save_voxel_dir = os.path.join(save_path, 'voxel')
    make_dir(save_path)
    make_dir(save_map_dir)
    make_dir(save_voxel_dir)

    non_empty_voxel_num_list = []
    index = -1

    for anky_cloud_path in anky_cloud_list:
        index += 1
        anky_cloud_label_path = anky_cloud_path[:-3] + 'labels.ply'
        scene_name = os.path.basename(anky_cloud_path)[:12]
        anky_cloud = PyntCloud.from_file(anky_cloud_path)
        # regular_bounding_box=False allows different extents along x, y and z
        voxelgrid_id = anky_cloud.add_structure("voxelgrid", size_x=MTML_VOXEL_SIZE, size_y=MTML_VOXEL_SIZE,
                                                size_z=MTML_VOXEL_SIZE, regular_bounding_box=False)
        voxelgrid = anky_cloud.structures[voxelgrid_id]
        pc_num = voxelgrid.voxel_x.shape[0]

        x_len, y_len, z_len = voxelgrid.x_y_z
        # build the voxel -> point-index mapping
        voxel_pc_mapping_dict = {}
        for i in range(pc_num):
            _x = voxelgrid.voxel_x[i]
            _y = voxelgrid.voxel_y[i]
            _z = voxelgrid.voxel_z[i]
            voxel_pc_mapping_dict.setdefault((_x, _y, _z), []).append(i)

        # save the voxel -> point cloud mapping
        with open(os.path.join(save_map_dir, '%s.pkl' % (scene_name)), 'wb') as fp:
            pickle.dump(voxel_pc_mapping_dict, fp, protocol=pickle.HIGHEST_PROTOCOL)

        '''
        create rgb / ins_label / sem_label
        '''
        # pad to even lengths (required by the 3D conv/deconv in the network)
        x_len_even, y_len_even, z_len_even = x_len, y_len, z_len
        if x_len_even % 2 == 1:
            x_len_even += 1
        if y_len_even % 2 == 1:
            y_len_even += 1
        if z_len_even % 2 == 1:
            z_len_even += 1

        voxel_sem_label = -1 * np.ones((x_len_even, y_len_even, z_len_even, 1), dtype=np.int32)
        voxel_ins_label = -1 * np.ones((x_len_even, y_len_even, z_len_even, 1), dtype=np.int32)
        voxel_rgb = np.zeros((x_len_even, y_len_even, z_len_even, 3))

        sem_label_gt, ins_label_gt = collect_label(os.path.dirname(anky_cloud_label_path), scene_name)
        sem_label_gt[sem_label_gt >= 40] = 0
        sem_label_gt[sem_label_gt < 0] = 0
        sem_label_gt = changem(sem_label_gt, np.arange(40), target_sem_idx)
        rgb_label_gt = read_color_ply(anky_cloud_path)[:, 3:6]

        for i in range(x_len):
            for j in range(y_len):
                for k in range(z_len):
                    pc_list = voxel_pc_mapping_dict.get((i, j, k))
                    if pc_list is not None:
                        this_voxel_sem_list = sem_label_gt[pc_list]
                        this_voxel_ins_list = ins_label_gt[pc_list]
                        this_voxel_rgb_list = rgb_label_gt[pc_list]

                        # sem + ins: majority vote over the points inside this voxel
                        _sem = most_common(this_voxel_sem_list)
                        _ins = most_common(this_voxel_ins_list)
                        voxel_sem_label[i][j][k] = _sem
                        voxel_ins_label[i][j][k] = _ins

                        # rgb: average color of the points, scaled to [0, 1]
                        voxel_rgb[i][j][k] = np.mean(this_voxel_rgb_list, axis=0) / 255

        # store as an .h5 file
        rgbs = copy.deepcopy(voxel_rgb)
        sem_labels = copy.deepcopy(voxel_sem_label)
        ins_labels = copy.deepcopy(voxel_ins_label)

        h5_filename = os.path.join(save_voxel_dir, '%s.h5' % scene_name)
        print(index)
        print('{0}'.format(h5_filename))
        print()
        if not os.path.isfile(h5_filename):
            save_h5(h5_filename,
                    rgbs,
                    sem_labels,
                    ins_labels
                    )
--------------------------------------------------------------------------------
/images/gt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FishWantToFly/MTML_pytorch_implementation/05589fb632383e95bfb763f027be2fbf111660a8/images/gt.png
--------------------------------------------------------------------------------
/images/mean shift prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FishWantToFly/MTML_pytorch_implementation/05589fb632383e95bfb763f027be2fbf111660a8/images/mean shift prediction.png
--------------------------------------------------------------------------------
/images/mean shift.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FishWantToFly/MTML_pytorch_implementation/05589fb632383e95bfb763f027be2fbf111660a8/images/mean shift.png
--------------------------------------------------------------------------------
/images/point cloud.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FishWantToFly/MTML_pytorch_implementation/05589fb632383e95bfb763f027be2fbf111660a8/images/point cloud.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class MTML(nn.Module):
    def __init__(self):
        super(MTML, self).__init__()

        self.conv1 = nn.Conv3d(3, 8, kernel_size=7, padding=3)
        self.act = nn.ReLU()
        self.conv2_1 = nn.Conv3d(8, 16, kernel_size=1)
        self.conv2_2 = nn.Conv3d(8, 16, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(16, 16, kernel_size=3, padding=1)
        self.maxpool = nn.MaxPool3d(2)
        self.conv4 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv3d(32, 32, kernel_size=1)
        self.conv5_2 = nn.Conv3d(32, 32, kernel_size=3, padding=1)
        self.conv6 = nn.Conv3d(32, 32, kernel_size=3, padding=1)
        self.conv7 = nn.Conv3d(32, 32, kernel_size=3, padding=1)
        self.conv8 = nn.Conv3d(32, 32, kernel_size=3, dilation=2, padding=2)  # dilated convs
        self.conv9 = nn.Conv3d(32, 32, kernel_size=3, dilation=2, padding=2)
        self.conv10 = nn.Conv3d(32, 32, kernel_size=3, dilation=2, padding=2)
        self.conv11 = nn.Conv3d(32, 32, kernel_size=3, dilation=2, padding=2)
        self.deconv1 = nn.ConvTranspose3d(96, 128, kernel_size=4, stride=2, padding=1)
        self.conv12 = nn.Conv3d(128, 64, kernel_size=1)
        self.conv13_1 = nn.Conv3d(64, 3, kernel_size=1)  # directional embedding head
        self.conv13_2 = nn.Conv3d(64, 3, kernel_size=1)  # feature embedding head

    def forward(self, x):
        # [B, X, Y, Z, 3] -> [B, 3, X, Y, Z]
        x = x.permute(0, 4, 1, 2, 3)

        out = self.conv1(x)
        out = self.act(out)
        out1 = self.conv2_1(out)
        out1 = self.act(out1)
        out = self.conv2_2(out)
        out = self.act(out)
        out = self.conv3(out)
        out = self.act(out)
        out = out1 + out  # residual connection

        out = self.maxpool(out)
        out = self.conv4(out)
        out = self.act(out)
        out1 = self.conv5_1(out)
        out1 = self.act(out1)
        out = self.conv5_2(out)
        out = self.act(out)
        out = out1 + out
        out2 = out

        out = self.conv6(out)
        out = self.act(out)
        out = self.conv7(out)
        out = self.act(out)
        out = self.conv8(out)
        out = self.act(out)
        out = self.conv9(out)
        out = self.act(out)
        out3 = out

        out = self.conv10(out)
        out = self.act(out)
        out = self.conv11(out)
        out = self.act(out)

        # skip connections: concatenate 32 + 32 + 32 = 96 channels
        out = torch.cat((out, out2, out3), dim=1)
        out = self.deconv1(out)
        out = self.conv12(out)
        out = self.act(out)

        dir_embedding = self.conv13_1(out)
        feature_embedding = self.conv13_2(out)

        # back to [B, X, Y, Z, 3]
        dir_embedding = dir_embedding.permute(0, 2, 3, 4, 1)
        feature_embedding = feature_embedding.permute(0, 2, 3, 4, 1)
        return dir_embedding, feature_embedding
--------------------------------------------------------------------------------
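Since the network downsamples once (MaxPool3d(2)) and upsamples once (a stride-2 transposed convolution), even input extents come back unchanged, which is why generate_voxel.py pads every grid to even lengths. A quick shape-check sketch on random input:

# Shape check: both heads return one 3-D embedding per voxel.
import torch
from model import MTML

net = MTML()
x = torch.randn(1, 32, 32, 16, 3)     # [B, X, Y, Z, 3]; extents must be even
dir_emb, feat_emb = net(x)
print(dir_emb.shape, feat_emb.shape)  # torch.Size([1, 32, 32, 16, 3]) for both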
/show_result.py:
--------------------------------------------------------------------------------
import os

import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the 3d projection)
from sklearn.cluster import MeanShift

from Dataset import Data_SCANNET as Data
from model import MTML


def read_txt(filename):
    res = []
    with open(filename) as f:
        for line in f:
            res.append(line.strip())
    return res


if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    epoch_num = '080'  # which checkpoint epoch to visualize
    MODEL_PATH = os.path.join('checkpoint')
    dataset_path = 'voxel'

    train_scene_txt = os.path.join(dataset_path, 'train.txt')
    val_scene_txt = os.path.join(dataset_path, 'val.txt')
    train_scenes = read_txt(train_scene_txt)
    val_scenes = read_txt(val_scene_txt)

    _dataset_path = os.path.join(dataset_path, 'voxel')
    val_data = Data(_dataset_path, train_scenes, val_scenes, mode='val')
    val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=False,
                                                 num_workers=10)

    mtml = MTML().cuda().eval()
    mtml.load_state_dict(torch.load(os.path.join(MODEL_PATH, 'mtml_%s.pth' % (epoch_num))))

    # take the first validation scene as the visualization target
    for i, data in enumerate(val_dataloader):
        rgb, sem, ins = data
        rgb, sem, ins = rgb.cuda(), sem.cuda(), ins.cuda()
        sem = sem.squeeze(0).squeeze(-1).view(-1)
        ins = ins.squeeze(0).squeeze(-1).view(-1)
        dir_embedding, feature_embedding = mtml(rgb)
        rgb = rgb.view(-1, 3)
        loc = feature_embedding.view(-1, 3)
        # keep only foreground voxels
        loc = loc[sem > 0]
        ins = ins[sem > 0]
        rgb = rgb[sem > 0]
        if i == 0:
            break

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # run mean shift on the feature embedding
    loc_np = loc.detach().cpu().numpy()
    ms = MeanShift()
    ms.fit(loc_np)
    cluster_centers = ms.cluster_centers_
    labels = ms.labels_
    ax.scatter(cluster_centers[:, 0], cluster_centers[:, 1], cluster_centers[:, 2],
               marker='x', color='red', s=300, linewidth=5, zorder=10)

    # visualization: one color per discovered cluster
    colors = cm.rainbow(np.linspace(0, 1, len(np.unique(labels))))
    for i in np.unique(labels):
        ss = loc_np[labels == i]
        ax.scatter(ss[:, 0], ss[:, 1], ss[:, 2], marker='o', color=colors[i])
    plt.show()
--------------------------------------------------------------------------------
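Mean shift assigns each embedded voxel to the mode its iterates converge to, so the number of instances never has to be specified; only the bandwidth controls how aggressively embeddings merge (MeanShift() above estimates it automatically). A standalone sketch on synthetic 3-D embeddings, with an explicitly estimated bandwidth:

# Mean-shift sketch on fake 3-D feature embeddings with three well-separated modes.
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth

rng = np.random.default_rng(0)
emb = np.concatenate([rng.normal(c, 0.05, size=(100, 3)) for c in (0.0, 1.0, 2.0)])

bw = estimate_bandwidth(emb, quantile=0.2)  # data-driven bandwidth guess
ms = MeanShift(bandwidth=bw).fit(emb)
print('clusters found:', len(ms.cluster_centers_))  # expect 3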
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import time

import numpy as np
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data

from Dataset import Data_SCANNET as Data
from model import MTML


def mkdir_p(dir_path):
    os.makedirs(dir_path, exist_ok=True)


def read_txt(filename):
    res = []
    with open(filename) as f:
        for line in f:
            res.append(line.strip())
    return res


parser = argparse.ArgumentParser('MTML')
parser.add_argument('--batchsize', type=int, default=1, help='input batch size')
parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
parser.add_argument('--epoch', type=int, default=100, help='number of epochs for training')
parser.add_argument('--pretrain', type=str, default=None, help='whether to use a pretrained model')
parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
parser.add_argument('--learning_rate', type=float, default=5e-4, help='learning rate for training')
parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer')
parser.add_argument('--multi_gpu', type=str, default=None, help='whether to use multi-gpu training')
parser.add_argument('--model_name', type=str, default='Bo-net-revise', help='name of model')
FLAGS = parser.parse_args()

checkpoint_dir = 'checkpoint'
mkdir_p(checkpoint_dir)

Weight_Decay = 1e-4
learning_rate = FLAGS.learning_rate

LOG_FOUT_train = open(os.path.join(checkpoint_dir, 'log_train.txt'), 'w')
LOG_FOUT_train.write(str(FLAGS) + '\n')
LOG_FOUT_test = open(os.path.join(checkpoint_dir, 'log_test.txt'), 'w')
LOG_FOUT_test.write(str(FLAGS) + '\n')


def log_string_train(out_str):
    LOG_FOUT_train.write(out_str + '\n')
    LOG_FOUT_train.flush()
    print(out_str)


def log_string_test(out_str):
    LOG_FOUT_test.write(out_str + '\n')
    LOG_FOUT_test.flush()
    print(out_str)

def get_fea_loss(feature_embedding, batch_group, batch_sem):
    '''
    Input :
        feature_embedding [B, X, Y, Z, 3]
        batch_group       [B, X, Y, Z, 1]
        batch_sem         [B, X, Y, Z, 1]
    '''
    batch_size = batch_group.shape[0]

    delta_var = 0.1
    delta_dist = 1.5
    cn = 0  # flag: set to 1 when a scene has too few valid instances, so the batch can be skipped

    # MTML setting (gamma_reg is applied as 0.001 in the main loop)
    gamma_var, gamma_dist, gamma_reg = 1, 1, 0.001
    loss_var_batch, loss_dist_batch, loss_reg_batch = 0, 0, 0

    for i in range(batch_size):
        pc_group = batch_group[i].squeeze(-1)
        pc_sem = batch_sem[i].squeeze(-1)
        pc_feature_embedding = feature_embedding[i]  # [X, Y, Z, 3]
        pc_group_unique = torch.unique(pc_group)

        # count valid instances to size average_feature_embeddings
        total_ins = 0
        for ins in pc_group_unique:
            if ins == -1: continue  # invalid instance
            pos = (pc_group == ins).nonzero()
            if pc_sem[pos[0][0]][pos[0][1]][pos[0][2]] <= 0: continue  # background instance
            total_ins += 1
        average_feature_embeddings = torch.zeros((total_ins, 3)).cuda()

        ##################
        # compute loss_var: pull voxel embeddings towards their instance mean
        _id = 0
        loss_var_sum = 0
        for ins in pc_group_unique:
            pos = (pc_group == ins).nonzero()

            if ins == -1: continue
            if pc_sem[pos[0][0]][pos[0][1]][pos[0][2]] <= 0: continue

            # feature embeddings of the voxels belonging to this instance
            _pc_embedding_feature = pc_feature_embedding[pos[:, 0], pos[:, 1], pos[:, 2]].squeeze(1)  # [num_voxel, 3]
            num_voxel = _pc_embedding_feature.shape[0]

            # mean embedding of this instance
            average_feature_embeddings_now = torch.mean(_pc_embedding_feature, dim=0)  # [3]
            average_feature_embeddings[_id] = average_feature_embeddings_now
            _id += 1

            # hinged distance of each voxel embedding to the instance mean
            diff_embedding_features = average_feature_embeddings_now.repeat(num_voxel, 1) - _pc_embedding_feature  # [num_voxel, 3]
            diff_embedding_features = torch.norm(diff_embedding_features, p=2, dim=1, keepdim=True)
            diff_embedding_features = torch.clamp(diff_embedding_features - delta_var, min=0) ** 2
            loss_var_sum += torch.sum(diff_embedding_features) / num_voxel

        if total_ins == 0:
            cn = 1
            loss_var_batch = torch.tensor(0)
        else:
            loss_var_batch += (loss_var_sum / total_ins)

        ##################
        # compute loss_dist: push means of different instances apart
        C = total_ins
        loss_dist_sum = 0
        for i in range(C):
            for j in range(i + 1, C):  # each pair counted once
                diff_average_features = average_feature_embeddings[i] - average_feature_embeddings[j]
                diff_average_features = torch.norm(diff_average_features, p=2)
                diff_average_features = torch.clamp(2 * delta_dist - diff_average_features, min=0) ** 2
                loss_dist_sum += diff_average_features
        if C == 0 or C == 1:
            loss_dist_batch = torch.tensor(0)
            cn = 1
        else:
            loss_dist_batch += (loss_dist_sum / (C * (C - 1)))

        ##################
        # compute loss_reg: keep the instance means close to the origin
        loss_reg_sum = 0
        for i in range(C):
            diff_average_features = torch.norm(average_feature_embeddings[i], p=2)
            loss_reg_sum += diff_average_features
        if C == 0:
            loss_reg_batch = torch.tensor(0)
        else:
            loss_reg_batch += (loss_reg_sum / C)

    loss_var = loss_var_batch / batch_size
    loss_dist = loss_dist_batch / batch_size
    loss_reg = loss_reg_batch / batch_size

    return loss_var, loss_dist, loss_reg, cn
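# In equation form, get_fea_loss above computes, per scene (f_i: voxel embedding,
# mu_c: mean embedding of instance c, N_c: its voxel count, C: valid instances,
# [x]_+ = max(x, 0), delta_var = 0.1, delta_dist = 1.5):
#   L_var  = (1/C) * sum_c (1/N_c) * sum_i [ ||mu_c - f_i||_2 - delta_var ]_+^2
#   L_dist = (1/(C(C-1))) * sum_{c < c'} [ 2*delta_dist - ||mu_c - mu_{c'}||_2 ]_+^2
#   L_reg  = (1/C) * sum_c ||mu_c||_2
# i.e. the discriminative metric-learning loss: pull within instances, push
# between instances, and lightly regularize the instance means.
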
def get_dir_loss(dir_embedding, batch_group, batch_sem):
    '''
    Input :
        dir_embedding [B, X, Y, Z, 3]
        batch_group   [B, X, Y, Z, 1]
        batch_sem     [B, X, Y, Z, 1]
    '''
    batch_size = batch_group.shape[0]
    loss_dir_batch = 0

    for i in range(batch_size):
        pc_group = batch_group[i].squeeze(-1)
        pc_sem = batch_sem[i].squeeze(-1)
        pc_dir_embedding = dir_embedding[i]  # [X, Y, Z, 3]
        pc_group_unique = torch.unique(pc_group)

        ##################
        # compute loss_dir
        total_ins = 0
        loss_dir_sum = 0
        for ins in pc_group_unique:
            pos = (pc_group == ins).nonzero()
            if ins == -1: continue
            if pc_sem[pos[0][0]][pos[0][1]][pos[0][2]] <= 0: continue

            _pc_dir_embedding = pc_dir_embedding[pos[:, 0], pos[:, 1], pos[:, 2]].squeeze(1)  # [num_voxel, 3]
            num_voxel = _pc_dir_embedding.shape[0]

            ## Exception : an instance with only one voxel
            ## Current solution : skip it
            if num_voxel <= 1: continue

            # compute v_i and v_i_GT
            v_i = _pc_dir_embedding / torch.norm(_pc_dir_embedding, p=2, dim=1, keepdim=True)  # normalized dir embedding, [num_voxel, 3]
            x_i = pos.float()  # [num_voxel, 3]
            x_center = torch.mean(pos.float(), dim=0).unsqueeze(0)  # center of this instance, [1, 3]
            x_center = x_center.repeat(num_voxel, 1)  # [num_voxel, 3]

            ## Exception : x_i equal to x_center makes the denominator 0
            ## Solution : zero out that center-like voxel
            # find which voxel sits exactly at the instance center
            check_center = torch.sum(torch.abs(x_i - x_center), axis=1) == 0

            v_i_GT = (x_i - x_center) / torch.norm((x_i - x_center), p=2, dim=1, keepdim=True)  # [num_voxel, 3]
            v_i_GT[check_center, :] = 0
            loss_dir_sum += torch.sum(torch.mul(v_i, v_i_GT)) / num_voxel
            total_ins += 1

        if total_ins == 0:
            loss_dir_batch += torch.tensor(0)
        else:
            loss_dir_batch += loss_dir_sum / total_ins

    # negated: training maximizes the cosine similarity between v_i and v_i_GT
    loss_dir = -1 * loss_dir_batch / batch_size
    return loss_dir
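# In equation form, get_dir_loss above computes, with x_c the voxel-coordinate
# centroid of instance c and v_i the normalized predicted direction:
#   L_dir = -(1/C) * sum_c (1/N_c) * sum_i < v_i , (x_i - x_c) / ||x_i - x_c||_2 >
# i.e. each voxel's predicted direction should align with the direction from the
# instance centroid to that voxel; the cosine similarity is maximized, hence the
# leading minus sign.
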
if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    dataset_path = './voxel'
    train_scene_txt = os.path.join(dataset_path, 'train.txt')
    val_scene_txt = os.path.join(dataset_path, 'val.txt')

    train_scenes = read_txt(train_scene_txt)
    val_scenes = read_txt(val_scene_txt)

    _dataset_path = os.path.join(dataset_path, 'voxel')
    train_data = Data(_dataset_path, train_scenes, val_scenes, mode='train')
    val_data = Data(_dataset_path, train_scenes, val_scenes, mode='val')
    train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=FLAGS.batchsize, shuffle=True,
                                                   num_workers=10)
    val_dataloader = torch.utils.data.DataLoader(val_data, batch_size=FLAGS.batchsize, shuffle=False,
                                                 num_workers=10)

    mtml = MTML().cuda()
    mtml = torch.nn.DataParallel(mtml, device_ids=[0])
    optim_params = [
        {'params': mtml.parameters(), 'lr': FLAGS.learning_rate, 'betas': (0.9, 0.999), 'eps': 1e-08},
    ]
    optimizer = optim.Adam(optim_params, lr=learning_rate, weight_decay=Weight_Decay)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

    # loss weights (the semantic branch is not implemented yet, so sem_ratio is unused)
    fea_ratio, dir_ratio, sem_ratio = 1, 0.5, 1

    print("Start training.")
    for epoch in range(FLAGS.epoch):
        loss_list = []
        for i, data in enumerate(train_dataloader):
            optimizer.zero_grad()
            rgb, sem, ins = data
            rgb, sem, ins = rgb.cuda(), sem.cuda(), ins.cuda()
            dir_embedding, feature_embedding = mtml(rgb)

            loss_var, loss_dist, loss_reg, cn = get_fea_loss(feature_embedding, ins, sem)
            dir_loss = get_dir_loss(dir_embedding, ins, sem) * dir_ratio
            # skip scenes with too few instances for a well-defined loss
            if cn == 1:
                continue
            fea_loss = (loss_var + loss_dist + 0.001 * loss_reg) * fea_ratio
            total_loss = fea_loss + dir_loss

            total_loss.backward()
            optimizer.step()

            loss_list.append([total_loss.item(), loss_var.item(), loss_dist.item(), loss_reg.item(), fea_loss.item(), dir_loss.item()])

            if i % 20 == 0:
                print("Epoch %3d Iteration %3d (train)" % (epoch, i))
                print("%.3f %.3f %.3f %.3f %.3f %.3f" % (total_loss.item(), loss_var.item(), loss_dist.item(), loss_reg.item(), fea_loss.item(), dir_loss.item()))
                print('')

        loss_list_final = np.mean(loss_list, axis=0)
        log_string_train(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        log_string_train("Epoch %3d (train)" % (epoch))
        log_string_train("%.3f %.3f %.3f %.3f %.3f %.3f" % (loss_list_final[0], loss_list_final[1], loss_list_final[2], loss_list_final[3], loss_list_final[4], loss_list_final[5]))
        log_string_train('')
        scheduler.step()

        # save a checkpoint every 5 epochs
        if epoch % 5 == 0:
            torch.save(mtml.module.state_dict(), '%s/%s_%.3d.pth' % (checkpoint_dir, 'mtml', epoch))

        # validation every 5 epochs
        if epoch % 5 == 0:
            loss_list = []
            with torch.no_grad():
                for i, data in enumerate(val_dataloader):
                    rgb, sem, ins = data
                    rgb, sem, ins = rgb.cuda(), sem.cuda(), ins.cuda()
                    dir_embedding, feature_embedding = mtml(rgb)

                    loss_var, loss_dist, loss_reg, cn = get_fea_loss(feature_embedding, ins, sem)
                    dir_loss = get_dir_loss(dir_embedding, ins, sem) * dir_ratio
                    if cn == 1:
                        continue
                    fea_loss = (loss_var + loss_dist + 0.001 * loss_reg) * fea_ratio
                    total_loss = fea_loss + dir_loss

                    loss_list.append([total_loss.item(), loss_var.item(), loss_dist.item(), loss_reg.item(), fea_loss.item(), dir_loss.item()])

            loss_list_final = np.mean(loss_list, axis=0)
            log_string_test(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            log_string_test("Epoch %3d (test)" % (epoch))
            log_string_test("%.3f %.3f %.3f %.3f %.3f %.3f" % (loss_list_final[0], loss_list_final[1], loss_list_final[2], loss_list_final[3], loss_list_final[4], loss_list_final[5]))
            log_string_test('')
--------------------------------------------------------------------------------
/voxel/generate_data_list.py:
--------------------------------------------------------------------------------
import glob
import os

from sklearn.model_selection import train_test_split

# run from inside the voxel/ directory (see the README), so ./voxel holds the .h5 scenes
data_list = []
for scene in glob.glob("./voxel/*.h5"):
    _, _scene = os.path.split(scene)
    data_list.append(_scene[:-3])  # strip the '.h5' extension

train_list, val_list = train_test_split(data_list, test_size=0.2)
print("Train data len : %d" % (len(train_list)))
print("Val data len : %d" % (len(val_list)))

with open('train.txt', 'w') as f:
    for scene in train_list:
        f.write("%s\n" % scene)
with open('val.txt', 'w') as f:
    for scene in val_list:
        f.write("%s\n" % scene)
--------------------------------------------------------------------------------
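As written, the split above changes on every run because train_test_split is left unseeded. A sketch of a reproducible variant, should a fixed split matter, using the random_state parameter (scene names are examples):

# Reproducible split: identical train/val lists on every run.
from sklearn.model_selection import train_test_split

scenes = ['scene0000_00', 'scene0001_00', 'scene0002_00', 'scene0003_00', 'scene0004_00']
train_list, val_list = train_test_split(scenes, test_size=0.2, random_state=42)
print(train_list, val_list)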