├── requirements.txt ├── pipeline.png ├── .gitignore ├── models ├── __init__.py ├── dpa.py ├── adapter.py └── clip2point.py ├── render ├── __init__.py ├── selector.py ├── render.py └── blocks.py ├── datasets ├── __init__.py ├── scanobjectnn.py ├── modelnet10.py ├── modelnet40_align.py ├── utils.py └── shapenet.py ├── utils.py ├── zeroshot.py ├── README.md ├── fewshot.py └── pretraining.py /requirements.txt: -------------------------------------------------------------------------------- 1 | plyfile 2 | h5py 3 | lightly 4 | torch_optimizer 5 | open3d -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tyhuang0428/CLIP2Point/HEAD/pipeline.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | 4 | data/ 5 | exp_results/ 6 | pre_results/ 7 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip2point import CLIP2Point 2 | from .dpa import DPA 3 | 4 | 5 | __all__ = ['CLIP2Point', 'DPA'] 6 | -------------------------------------------------------------------------------- /render/__init__.py: -------------------------------------------------------------------------------- 1 | from .render import Renderer 2 | from .selector import Selector 3 | 4 | __all__ = ['Renderer', 'Selector'] 5 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .modelnet10 import ModelNet10 2 | from .modelnet40_align import ModelNet40Align, ModelNet40Ply 3 | from .scanobjectnn import ScanObjectNN 4 | from .shapenet import ShapeNetRender 5 | 6 | 7 | __all__ = ['ModelNet10', 'ModelNet40Align', 'ModelNet40Ply', 'ScanObjectNN', 'ShapeNetRender'] 8 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from plyfile import PlyData 4 | 5 | 6 | class IOStream(): 7 | def __init__(self, path): 8 | self.f = open(path, 'a') 9 | 10 | def cprint(self, text): 11 | print(text) 12 | self.f.write(text+'\n') 13 | self.f.flush() 14 | 15 | def close(self): 16 | self.f.close() 17 | 18 | 19 | def read_state_dict(path): 20 | ckpt = torch.load(path) 21 | base_ckpt = {k.replace("module.", ""): v for k, v in ckpt.items()} 22 | for key in list(base_ckpt.keys()): 23 | if key.startswith('point_model.'): 24 | base_ckpt[key[len('point_model.'):]] = base_ckpt[key] 25 | del base_ckpt[key] 26 | return base_ckpt 27 | 28 | 29 | def read_ply(filename): 30 | """ read XYZ point cloud from filename PLY file """ 31 | plydata = PlyData.read(filename) 32 | pc = plydata['vertex'].data 33 | pc_array = np.array([[x, y, z] for x,y,z in pc]) 34 | return pc_array 35 | -------------------------------------------------------------------------------- /models/dpa.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch.nn as nn 3 | import clip 4 | 5 | from .adapter import SimplifiedAdapter 6 | from render import Renderer, Selector 7 | from utils 
import read_state_dict 8 | 9 | clip_model, _ = clip.load("ViT-B/32", device='cpu') 10 | 11 | 12 | class DPA(nn.Module): 13 | def __init__(self, args, eval=False): 14 | super().__init__() 15 | self.views = args.views 16 | self.selector = Selector(self.views, args.dim, args.model) 17 | self.renderer = Renderer(points_radius=0.02) 18 | self.pre_model = deepcopy(clip_model.visual) 19 | self.ori_model = deepcopy(clip_model.visual) 20 | if not eval and args.ckpt is not None: 21 | print('loading from %s' % args.ckpt) 22 | self.pre_model.load_state_dict(read_state_dict(args.ckpt)) 23 | self.adapter1 = SimplifiedAdapter(num_views=args.views, in_features=512) 24 | self.adapter2 = SimplifiedAdapter(num_views=args.views, in_features=512) 25 | 26 | def forward(self, points): 27 | azim, elev, dist = self.selector(points) 28 | imgs = self.renderer(points, azim, elev, dist, self.views, rot=True) 29 | b, n, c, h, w = imgs.size() 30 | imgs = imgs.reshape(b * n, c, h, w) 31 | img_feat1 = self.adapter1(self.pre_model(imgs)) 32 | img_feat2 = self.adapter2(self.ori_model(imgs)) 33 | img_feats = (img_feat1 + img_feat2) * 0.5 34 | img_feats = img_feats / img_feats.norm(dim=-1, keepdim=True) 35 | return img_feats 36 | -------------------------------------------------------------------------------- /models/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BatchNormPoint(nn.Module): 6 | def __init__(self, feat_size): 7 | super().__init__() 8 | self.feat_size = feat_size 9 | self.bn = nn.BatchNorm1d(feat_size) 10 | 11 | def forward(self, x): 12 | assert len(x.shape) == 3 13 | s1, s2, s3 = x.shape[0], x.shape[1], x.shape[2] 14 | assert s3 == self.feat_size 15 | x = x.reshape(s1 * s2, self.feat_size) 16 | x = self.bn(x) 17 | return x.reshape(s1, s2, s3) 18 | 19 | 20 | class SimplifiedAdapter(nn.Module): 21 | def __init__(self, num_views=10, in_features=512): 22 | super().__init__() 23 | 24 | self.num_views = num_views 25 | self.in_features = in_features 26 | self.adapter_ratio = 0.6 27 | self.fusion_init = 0.5 28 | self.dropout = 0.075 29 | 30 | self.fusion_ratio = nn.Parameter(torch.tensor([self.fusion_init] * self.num_views), requires_grad=True) 31 | 32 | self.global_f = nn.Sequential( 33 | BatchNormPoint(self.in_features), 34 | nn.Dropout(self.dropout), 35 | nn.Flatten(), 36 | nn.Linear(in_features=self.in_features * self.num_views, 37 | out_features=self.in_features), 38 | nn.BatchNorm1d(self.in_features), 39 | nn.ReLU(), 40 | nn.Dropout(self.dropout), 41 | nn.Linear(in_features=self.in_features, out_features=self.in_features)) 42 | 43 | def forward(self, feat): 44 | img_feat = feat.reshape(-1, self.num_views, self.in_features) 45 | 46 | # Global feature 47 | return self.global_f(img_feat * self.fusion_ratio.reshape(1, -1, 1)) 48 | -------------------------------------------------------------------------------- /models/clip2point.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | import torch.nn as nn 4 | import clip 5 | from lightly.loss.ntx_ent_loss import NTXentLoss 6 | 7 | from render import Renderer, Selector 8 | 9 | clip_model, _ = clip.load("ViT-B/32", device='cpu') 10 | 11 | 12 | class CLIP2Point(nn.Module): 13 | def __init__(self, args): 14 | super().__init__() 15 | self.views = args.views 16 | self.selector = Selector(self.views, args.dim, args.model) 17 | self.renderer = Renderer(points_radius=0.02) 18 | 
self.point_model = deepcopy(clip_model.visual) 19 | self.image_model = deepcopy(clip_model.visual) 20 | self.weights = nn.Parameter(torch.ones([])) 21 | self.criterion = NTXentLoss(temperature = 0.07) 22 | 23 | def infer(self, points, rot=False): 24 | azim, elev, dist = self.selector(points) 25 | imgs = self.renderer(points, azim, elev, dist, self.views, rot=rot) 26 | b, n, c, h, w = imgs.size() 27 | imgs = imgs.reshape(b * n, c, h, w) 28 | imgs = self.point_model(imgs) 29 | img_feats = imgs / imgs.norm(dim=-1, keepdim=True) 30 | return img_feats 31 | 32 | def forward(self, points, images, a, e, d): 33 | batch_size = points.shape[0] 34 | depths = self.renderer(points, a, e, d, 1, aug=True, rot=False) 35 | image_feat = self.image_model(images.squeeze(1)).detach() 36 | depth1 = depths[:, 0] 37 | depth2 = depths[:, 1] 38 | depths = torch.cat([depth1, depth2], dim=0) 39 | depths = self.point_model(depths) 40 | depth1_feat = depths[: batch_size] 41 | depth2_feat = depths[batch_size: ] 42 | depth_feat = (depth1_feat + depth2_feat) * 0.5 43 | depth_loss = self.criterion(depth1_feat, depth2_feat) 44 | image_loss = self.criterion(depth_feat, image_feat) 45 | return image_loss + depth_loss / (self.weights ** 2) + torch.log(self.weights + 1), image_loss, depth_loss 46 | -------------------------------------------------------------------------------- /datasets/scanobjectnn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | from torch.utils.data import Dataset 4 | import random 5 | import torch 6 | import numpy as np 7 | 8 | from datasets.utils import pc_normalize 9 | 10 | 11 | class ScanObjectNN(Dataset): 12 | def __init__(self, partition='test', few_num=0, num_points=1024): 13 | assert partition in ('test', 'training') 14 | self._load_ScanObjectNN(partition) 15 | self.num_points = num_points 16 | self.partition = partition 17 | self.few_num = few_num 18 | self._preprocess() 19 | 20 | def __getitem__(self, index): 21 | point, label = self.points[index], self.labels[index] 22 | point = pc_normalize(point) 23 | if self.partition == 'train': 24 | pt_idxs = np.arange(point.shape[0]) 25 | np.random.shuffle(pt_idxs) 26 | point = point[pt_idxs] 27 | return point[: self.num_points], label 28 | return point, label 29 | 30 | def _load_ScanObjectNN(self, partition): 31 | BASE_DIR = '/data1/hty/h5_files/' 32 | DATA_DIR = os.path.join(BASE_DIR, 'main_split') 33 | h5_name = os.path.join(DATA_DIR, f'{partition}_objectdataset.h5') 34 | f = h5py.File(h5_name) 35 | self.points = torch.from_numpy(f['data'][:].astype('float32')) 36 | self.labels = torch.from_numpy(f['label'][:].astype('int64')) 37 | 38 | def _preprocess(self): 39 | if self.partition == 'training' and self.few_num > 0: 40 | num_dict = {i: 0 for i in range(15)} 41 | self.few_points = [] 42 | self.few_labels = [] 43 | random_list = [k for k in range(len(self.labels))] 44 | random.shuffle(random_list) 45 | for i in random_list: 46 | label = self.labels[i].item() 47 | if num_dict[label] == self.few_num: 48 | continue 49 | self.few_points.append(self.points[i]) 50 | self.few_labels.append(self.labels[i]) 51 | num_dict[label] += 1 52 | else: 53 | self.few_points = self.points 54 | self.few_labels = self.labels 55 | 56 | def __len__(self): 57 | return len(self.few_labels) 58 | -------------------------------------------------------------------------------- /datasets/modelnet10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 
import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from datasets.utils import pc_normalize, offread_uniformed 8 | 9 | cats = {'bathtub': 0, 'bed': 1, 'chair': 2, 'desk': 3, 'dresser': 4, 'monitor': 5, 'night_stand': 6, 'sofa': 7, 'table': 8, 'toilet': 9} 10 | 11 | 12 | class ModelNet10(Dataset): 13 | def __init__(self, partition='test', few_num=0, num_points=1024): 14 | assert partition in ('test', 'train') 15 | super().__init__() 16 | self.partition = partition 17 | self.few_num = few_num 18 | self.num_points = num_points 19 | self._load_data() 20 | if self.partition == 'train' and self.few_num > 0: 21 | self.paths, self.labels = self._few() 22 | 23 | def _load_data(self): 24 | DATA_DIR = '/data/ModelNet10' 25 | self.paths = [] 26 | self.labels = [] 27 | for cat in os.listdir(DATA_DIR): 28 | cat_path = os.path.join(DATA_DIR, cat, self.partition) 29 | for case in os.listdir(cat_path): 30 | if case.endswith('.off'): 31 | self.paths.append(os.path.join(cat_path, case)) 32 | self.labels.append(cats[cat]) 33 | 34 | def _few(self): 35 | num_dict = {i: 0 for i in range(10)} 36 | few_paths = [] 37 | few_labels = [] 38 | random_list = [k for k in range(len(self.labels))] 39 | random.shuffle(random_list) 40 | for i in random_list: 41 | label = self.labels[i].item() 42 | if num_dict[label] == self.few_num: 43 | continue 44 | few_paths.append(self.paths[i]) 45 | few_labels.append(self.labels[i]) 46 | num_dict[label] += 1 47 | return few_paths, few_labels 48 | 49 | def __getitem__(self, index): 50 | point = torch.from_numpy(offread_uniformed(self.paths[index], 1024)).to(torch.float32) 51 | label = self.labels[index] 52 | point = pc_normalize(point) 53 | if self.partition == 'train': 54 | pt_idxs = np.arange(point.shape[0]) 55 | np.random.shuffle(pt_idxs) 56 | point = point[pt_idxs] 57 | return point[: self.num_points], label 58 | return point, label 59 | 60 | def __len__(self): 61 | return len(self.labels) 62 | -------------------------------------------------------------------------------- /datasets/modelnet40_align.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | 7 | from utils import read_ply 8 | from datasets.utils import pc_normalize, offread_uniformed 9 | 10 | cats = {'airplane': 0, 'bathtub': 1, 'bed': 2, 'bench': 3, 'bookshelf': 4, 'bottle': 5, 'bowl': 6, 'car': 7, 'chair': 8, 'cone': 9, 'cup': 10, 'curtain': 11, 'desk': 12, 'door': 13, 'dresser': 14, 'flower_pot': 15, 'glass_box': 16, 'guitar': 17, 'keyboard': 18, 'lamp': 19, 'laptop': 20, 'mantel': 21, 'monitor': 22, 'night_stand': 23, 'person': 24, 'piano': 25, 'plant': 26, 'radio': 27, 'range_hood': 28, 'sink': 29, 'sofa': 30, 'stairs': 31, 'stool': 32, 'table': 33, 'tent': 34, 'toilet': 35, 'tv_stand': 36, 'vase': 37, 'wardrobe': 38, 'xbox': 39} 11 | 12 | 13 | class ModelNet40Align(Dataset): 14 | ''' 15 | points are randomly sampled from .off file, so the results of this dataset may be better or wrose than our claim results 16 | ''' 17 | def __init__(self, partition='test', few_num=0, num_points=1024): 18 | assert partition in ('test', 'train') 19 | super().__init__() 20 | self.partition = partition 21 | self.few_num = few_num 22 | self.num_points = num_points 23 | self._load_data() 24 | if self.partition == 'train' and self.few_num > 0: 25 | self.paths, self.labels = self._few() 26 | 27 | def _load_data(self): 28 | DATA_DIR = 
'./data/ModelNet40_manually_aligned' 29 | self.paths = [] 30 | self.labels = [] 31 | for cat in os.listdir(DATA_DIR): 32 | cat_path = os.path.join(DATA_DIR, cat, self.partition) 33 | for case in os.listdir(cat_path): 34 | if case.endswith('.off'): 35 | self.paths.append(os.path.join(cat_path, case)) 36 | self.labels.append(cats[cat]) 37 | 38 | def _few(self): 39 | num_dict = {i: 0 for i in range(40)} 40 | few_paths = [] 41 | few_labels = [] 42 | random_list = [k for k in range(len(self.labels))] 43 | random.shuffle(random_list) 44 | for i in random_list: 45 | label = self.labels[i] 46 | if num_dict[label] == self.few_num: 47 | continue 48 | few_paths.append(self.paths[i]) 49 | few_labels.append(self.labels[i]) 50 | num_dict[label] += 1 51 | return few_paths, few_labels 52 | 53 | def __getitem__(self, index): 54 | point = torch.from_numpy(offread_uniformed(self.paths[index], self.num_points)).to(torch.float32) 55 | label = self.labels[index] 56 | point = pc_normalize(point) 57 | if self.partition == 'train': 58 | pt_idxs = np.arange(point.shape[0]) 59 | np.random.shuffle(pt_idxs) 60 | point = point[pt_idxs] 61 | return point, label 62 | return point, label 63 | 64 | def __len__(self): 65 | return len(self.labels) 66 | 67 | 68 | class ModelNet40Ply(Dataset): 69 | ''' 70 | we save the random points in our few-shot learning, so the results are confirmed 71 | ''' 72 | def __init__(self, partition='test', few_num=0, num_points=1024): 73 | assert partition in ('test', 'train') 74 | super().__init__() 75 | self.partition = partition 76 | self.few_num = few_num 77 | self.num_points = num_points 78 | self._load_data() 79 | if self.partition == 'train' and self.few_num > 0: 80 | self.paths, self.labels = self._few() 81 | 82 | def _load_data(self): 83 | DATA_DIR = './data/ModelNet40_Ply' 84 | self.paths = [] 85 | self.labels = [] 86 | for case in os.listdir(DATA_DIR): 87 | self.paths.append(os.path.join(DATA_DIR, case)) 88 | self.labels.append(int(case.split('_')[0])) 89 | 90 | def __getitem__(self, index): 91 | label = self.labels[index] 92 | point = torch.from_numpy(read_ply(self.paths[index])) 93 | point = pc_normalize(point) 94 | return point, label 95 | 96 | def __len__(self): 97 | return len(self.labels) 98 | -------------------------------------------------------------------------------- /zeroshot.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | import argparse 4 | from torch.utils.data import DataLoader 5 | import clip 6 | from tqdm import tqdm 7 | from pointnet2_ops import pointnet2_utils 8 | 9 | from datasets import ModelNet10, ModelNet40Align, ModelNet40Ply, ScanObjectNN 10 | from render.selector import Selector 11 | from render.render import Renderer 12 | from utils import read_state_dict 13 | 14 | clip_model, _ = clip.load('ViT-B/32', device='cpu') 15 | 16 | 17 | def inference(args): 18 | if args.dataset == 'ModelNet10': 19 | dataset = ModelNet10() 20 | prompts = ['bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor', 'night_stand', 'sofa', 'table', 'toilet'] 21 | elif args.dataset == 'ModelNet40': 22 | dataset = ModelNet40Ply() 23 | prompts = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower pot', 'glass box', 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night stand', 'person', 'piano', 'plant', 'radio', 'range hood', 'sink', 'sofa', 'stairs', 'stool', 'table', 'tent', 'toilet', 'tv 
stand', 'vase', 'wardrobe', 'xbox'] 24 | else: 25 | dataset = ScanObjectNN() 26 | prompts = ['bag', 'bin', 'box', 'cabinet', 'chair', 'desk', 'display', 'door', 'shelf', 'table', 'bed', 'pillow', 'sink', 'sofa', 'toilet'] 27 | 28 | 29 | dataloader = DataLoader(dataset, batch_size=args.test_batch_size, num_workers=4, shuffle=True) 30 | prompts = ['image of a ' + prompts[i] for i in range(len(prompts))] 31 | prompts = clip.tokenize(prompts) 32 | prompts = clip_model.encode_text(prompts) 33 | prompts_feats = prompts / prompts.norm(dim=-1, keepdim=True) 34 | 35 | # =================================== INIT MODEL =========================================================== 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | model = deepcopy(clip_model.visual).to(device) 38 | if args.ckpt is not None: 39 | model.load_state_dict(read_state_dict(args.ckpt)) 40 | selector = Selector(args.views, 0).to(device) 41 | render = Renderer(points_per_pixel=1, points_radius=0.02).to(device) 42 | prompt_feats = prompts_feats.to(device) 43 | # ==================================== TESTING LOOP ===================================================== 44 | model.eval() 45 | with torch.no_grad(): 46 | correct_num = 0 47 | total = 0 48 | for (points, label) in tqdm(dataloader): 49 | points = points.to(device) 50 | if args.dataset == 'ScanObjectNN': 51 | fps_idx = pointnet2_utils.furthest_point_sample(points, 1024) 52 | points = pointnet2_utils.gather_operation(points.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous() 53 | c_views_azim, c_views_elev, c_views_dist = selector(points) 54 | if args.dataset == 'ScanObjectNN': 55 | images = render(points, c_views_azim, c_views_elev, c_views_dist, args.views, rot=False) 56 | else: 57 | images = render(points, c_views_azim, c_views_elev, c_views_dist, args.views, rot=True) 58 | b, n, c, h, w = images.shape 59 | images = images.reshape(-1, c, h, w) 60 | image_feats = model(images) 61 | image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True) 62 | logits = image_feats @ prompt_feats.t() 63 | logits = logits.reshape(b, n, -1) 64 | logits = torch.sum(logits, dim=1) 65 | probs = logits.softmax(dim=-1) 66 | index = torch.max(probs, dim=1).indices 67 | correct_num += torch.sum(torch.eq(index.detach().cpu(), label)).item() 68 | total += len(label) 69 | test_acc = correct_num / total 70 | print(test_acc) 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser(description='Zero-shot Point Cloud Classification') 75 | parser.add_argument('--dataset', type=str, choices=['ModelNet10', 'ModelNet40', 'ScanObjectNN']) 76 | parser.add_argument('--views', type=int, default=6) 77 | parser.add_argument('--ckpt', type=str, default=None) 78 | parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size', 79 | help='Size of batch)') 80 | args = parser.parse_args() 81 | 82 | inference(args) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLIP2Point 2 | This repo is the official implementation of ICCV 2023 paper ["CLIP2Point: Transfer CLIP to Point Cloud Classification with Image-Depth Pre-training"](https://arxiv.org/abs/2210.01055) 3 | 4 | ## Introduction 5 | CLIP2Point is an image-depth pre-training method by contrastive learning to transfer CLIP to the 3D domain, and is then adapted to point cloud classification. 
We introduce a new depth rendering setting that produces a better visual effect, and then render 52,460 pairs of images and depth maps from ShapeNet for pre-training. The pre-training scheme of CLIP2Point combines cross-modality learning, which enforces the depth features to capture expressive visual and textual information, with intra-modality learning, which enhances the invariance of depth aggregation. Additionally, we propose a novel Dual-Path Adapter (DPA) module, i.e., a dual-path structure with simplified adapters for few-shot learning. The dual-path structure allows the joint use of CLIP and CLIP2Point, and the simplified adapter fits few-shot tasks well without post-search. Experimental results show that CLIP2Point is effective in transferring CLIP knowledge to 3D vision. Our CLIP2Point outperforms PointCLIP and other self-supervised 3D networks, achieving state-of-the-art results on zero-shot and few-shot classification. 6 | 7 | ![pipeline](./pipeline.png) 8 | 9 | ## Requirements 10 | ### Installation 11 | PyTorch, PyTorch3D, CLIP, pointnet2_ops, etc., are required. We recommend creating a conda environment and installing the dependencies on Linux as follows: 12 | ```shell 13 | # create a conda environment 14 | conda create -n clip2point python=3.7 -y 15 | conda activate clip2point 16 | 17 | # install pytorch & pytorch3d 18 | conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch 19 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 20 | conda install -c bottler nvidiacub 21 | conda install pytorch3d -c pytorch3d 22 | pip install "git+https://github.com/facebookresearch/pytorch3d.git" 23 | 24 | # install CLIP 25 | pip install ftfy regex tqdm 26 | pip install git+https://github.com/openai/CLIP.git 27 | 28 | # install pointnet2 & other packages 29 | pip install "git+https://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib" 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ### Data preparation 34 | The overall directory structure should be: 35 | 36 | ``` 37 | │CLIP2Point/ 38 | ├──datasets/ 39 | ├──data/ 40 | │ ├──ModelNet40_Align/ 41 | │ ├──ModelNet40_Ply/ 42 | │ ├──Rendering/ 43 | │ ├──ShapeNet55/ 44 | │ ...... 45 | ├──....... 46 | ``` 47 | For pre-training, we use **Rendering** and **ShapeNet55**, which provide the rendered images and the point cloud data, respectively. We provide [**Rendering**](https://drive.google.com/file/d/1jMuYi4IoM6A80uPCohjGkNP6pLP1-0Gm/view?usp=sharing), and you can refer to [Point-BERT](https://github.com/lulutang0608/Point-BERT/blob/master/DATASET.md) for **ShapeNet55**. 48 | 49 | For downstream classification, we use [**ModelNet40_Align**](https://github.com/lmb-freiburg/orion). As we randomly sample points from the .off files, results on this dataset can be better or worse than the results we claimed. Thus, we also save and provide the point cloud data sampled during our training process, namely [**ModelNet40_Ply**](https://drive.google.com/file/d/1nEJYZ9QPBgYPMiVCJGKWY6pVaJmPxVRq/view?usp=sharing). Note that our claimed results are not necessarily the best achievable, but the best we obtained during our training runs. You can use **ModelNet40_Align** to search for a better result, or simply run inference with our pre-trained weights on **ModelNet40_Ply**. 50 | 51 | 52 | ## Get started 53 | We provide several config options for training and evaluation; you can follow our settings or modify some of them to suit your requirements.
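For reference, the snippet below spells out a typical few-shot invocation with the options exposed by the `fewshot.py` argument parser (the values shown are simply its defaults), together with a zero-shot run on a chosen dataset; the bracketed checkpoint path is a placeholder, exactly as in the commands further below.
```shell
# few-shot training with the argparse defaults written out explicitly
python fewshot.py --exp_name dpa --ckpt [pre-trained_ckpt_path] \
    --views 10 --model PointNet --dim 0 \
    --batch_size 32 --test_batch_size 32 --epoch 100

# zero-shot evaluation on a specific dataset (6 views by default)
python zeroshot.py --dataset ModelNet40 --views 6 --ckpt [pre-trained_ckpt_path]
```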
We provide the pre-trained checkpoint [best_eval.pth](https://drive.google.com/file/d/1ZAnIANNMqRRRmaVtk8Kp93s_NkGU51zv/view?usp=sharing) and few-shot checkpoint [best_test.pth](https://drive.google.com/file/d/1Jr1yXOu1yKmMs8K7XD8FnttPRHnZOZHx/view?usp=sharing) 54 | ### Pre-training 55 | If you want to pre-train a depth encoder, and then save the logs and checkpoints at ./pre_results/vit32/, 56 | ``` 57 | python pretraining.py --exp_name vit32 58 | ``` 59 | ### Zero-shot Classification 60 | If you want to evaluate zero-shot CLIP visual encoder, 61 | ``` 62 | python zeroshot.py 63 | ``` 64 | or you can use the checkpoint, 65 | ``` 66 | python zeroshot.py --ckpt [pre-trained_ckpt_path] 67 | ``` 68 | ### Few-shot Classification 69 | If you want to train a Dual-Path Adapter (DPA) for few-shot classification and save the logs and checkpoints at ./exp_results/dpa/, 70 | ``` 71 | python fewshot.py --exp_name dpa --ckpt [pre-trained_ckpt_path] 72 | ``` 73 | or simply evaluate it with the few-shot training checkpoint, 74 | ``` 75 | python fewshot.py --eval --ckpt [trained_ckpt_path] 76 | ``` 77 | 78 | 79 | ## Acknowledgement 80 | Our codes are built on [CLIP](https://github.com/openai/CLIP), [MVTN](https://github.com/ajhamdi/MVTN), and [CrossPoint](https://github.com/MohamedAfham/CrossPoint). 81 | 82 | ## Citation 83 | ``` 84 | @article{huang2022clip2point, 85 | title={CLIP2Point: Transfer CLIP to Point Cloud Classification with Image-Depth Pre-training}, 86 | author={Huang, Tianyu and Dong, Bowen and Yang, Yunhan and Huang, Xiaoshui and Lau, Rynson WH and Ouyang, Wanli and Zuo, Wangmeng}, 87 | journal={arXiv preprint arXiv:2210.01055}, 88 | year={2022} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /fewshot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | import clip 7 | from tqdm import tqdm 8 | from torch.utils.tensorboard import SummaryWriter 9 | from datasets.modelnet40_align import ModelNet40Ply 10 | 11 | from models import DPA 12 | from datasets import ModelNet40Align 13 | from utils import IOStream 14 | 15 | 16 | clip_model, _ = clip.load('ViT-B/32', device='cpu') 17 | prompts = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower pot', 'glass box', 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night stand', 'person', 'piano', 'plant', 'radio', 'range hood', 'sink', 'sofa', 'stairs', 'stool', 'table', 'tent', 'toilet', 'tv stand', 'vase', 'wardrobe', 'xbox'] 18 | prompts = ['image of a ' + prompts[i] for i in range(len(prompts))] 19 | prompts = clip.tokenize(prompts) 20 | prompts = clip_model.encode_text(prompts) 21 | prompts_feats = prompts / prompts.norm(dim=-1, keepdim=True) 22 | 23 | 24 | def _init_(): 25 | if not os.path.exists('exp_results'): 26 | os.makedirs('exp_results') 27 | if not os.path.exists('exp_results/'+args.exp_name): 28 | os.makedirs('exp_results/'+args.exp_name) 29 | 30 | 31 | def train(args, io): 32 | train_dataloader = DataLoader(ModelNet40Align('train', 16), batch_size=args.batch_size, num_workers=4, shuffle=True) 33 | test_dataloader = DataLoader(ModelNet40Align('test'), batch_size=args.test_batch_size, num_workers=4, shuffle=True) 34 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 35 | # 
=================================== INIT MODEL =========================================================== 36 | model = DPA(args).to(device) 37 | for name, param in model.named_parameters(): 38 | if 'adapter' not in name and 'selector' not in name and 'renderer' not in name: 39 | param.requires_grad_(False) 40 | prompt_feats = prompts_feats.to(device).detach() 41 | # ==================================== TRAINING LOOP ===================================================== 42 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4) 43 | n_epochs = args.epoch 44 | max_test_acc = 0 45 | summary_writer = SummaryWriter("exp_results/%s/tensorboard" % args.exp_name) 46 | for epoch in range(n_epochs): 47 | model.train() 48 | loss_sum = 0 49 | correct_num = 0 50 | total = 0 51 | for (points, label) in tqdm(train_dataloader): 52 | points = points.to(device) 53 | label = label.to(device) 54 | optimizer.zero_grad() 55 | img_feats = model(points) 56 | logits = img_feats @ prompt_feats.t() 57 | 58 | loss = F.cross_entropy(logits, label) 59 | loss_sum += loss.item() 60 | loss.backward() 61 | optimizer.step() 62 | probs = logits.softmax(dim=-1) 63 | index = torch.max(probs, dim=1).indices 64 | correct_num += torch.sum(torch.eq(index, label)).item() 65 | total += len(label) 66 | train_acc = correct_num / total 67 | 68 | model.eval() 69 | with torch.no_grad(): 70 | correct_num = 0 71 | total = 0 72 | for (points, label) in tqdm(test_dataloader): 73 | points = points.to(device) 74 | img_feats = model(points) 75 | logits = img_feats @ prompt_feats.t() 76 | probs = logits.softmax(dim=-1) 77 | index = torch.max(probs, dim=1).indices 78 | correct_num += torch.sum(torch.eq(index.detach().cpu(), label)).item() 79 | total += len(label) 80 | test_acc = correct_num / total 81 | 82 | mean_loss = loss_sum / len(train_dataloader) 83 | io.cprint('epoch%d total_loss: %.4f, train_acc: %.4f, test_acc: %.4f' % (epoch + 1, mean_loss, train_acc, test_acc)) 84 | summary_writer.add_scalar('train/loss', mean_loss, epoch + 1) 85 | summary_writer.add_scalar("train/acc", train_acc, epoch + 1) 86 | summary_writer.add_scalar("test/acc", test_acc, epoch + 1) 87 | if test_acc > max_test_acc: 88 | max_test_acc = test_acc 89 | torch.save(model.state_dict(), 'exp_results/%s/best.pth' % (args.exp_name)) 90 | io.cprint('save the best test acc at %d' % (epoch + 1)) 91 | 92 | 93 | def eval(args): 94 | assert args.ckpt is not None, 'load a checkpoint for evaluation' 95 | test_dataloader = DataLoader(ModelNet40Ply('test'), batch_size=args.test_batch_size, num_workers=4, shuffle=True) 96 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 97 | model = DPA(args, True).to(device) 98 | model.load_state_dict(torch.load(args.ckpt)) 99 | prompt_feats = prompts_feats.to(device).detach() 100 | 101 | model.eval() 102 | with torch.no_grad(): 103 | correct_num = 0 104 | total = 0 105 | for (points, label) in tqdm(test_dataloader): 106 | points = points.to(device) 107 | img_feats = model(points) 108 | logits = img_feats @ prompt_feats.t() 109 | probs = logits.softmax(dim=-1) 110 | index = torch.max(probs, dim=1).indices 111 | correct_num += torch.sum(torch.eq(index.detach().cpu(), label)).item() 112 | total += len(label) 113 | test_acc = correct_num / total 114 | print(test_acc) 115 | 116 | 117 | if __name__ == "__main__": 118 | # Training settings 119 | parser = argparse.ArgumentParser(description='Few-shot Point Cloud Classification') 120 | parser.add_argument('--exp_name', type=str, default='test', metavar='N', 121 
| help='Name of the experiment') 122 | parser.add_argument('--views', type=int, default=10) 123 | parser.add_argument('--ckpt', type=str, default=None) 124 | parser.add_argument('--dim', type=int, default=0, choices=[0, 512], help='0 if the view angle is not learnable') 125 | parser.add_argument('--model', type=str, default='PointNet', metavar='N', 126 | choices=['DGCNN', 'PointNet'], 127 | help='Model to use, [pointnet, dgcnn]') 128 | parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', 129 | help='Size of batch)') 130 | parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size', 131 | help='Size of batch)') 132 | parser.add_argument('--epoch', type=int, default=100, metavar='N', 133 | help='number of episode to train ') 134 | parser.add_argument('--eval', action='store_true') 135 | args = parser.parse_args() 136 | 137 | if not args.eval: 138 | _init_() 139 | io = IOStream('exp_results/' + args.exp_name + '/run.log') 140 | io.cprint(str(args)) 141 | train(args, io) 142 | else: 143 | eval(args) 144 | -------------------------------------------------------------------------------- /render/selector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from render.blocks import MLP, PointNet, SimpleDGCNN, load_point_ckpt 5 | 6 | 7 | class ViewSelector(nn.Module): 8 | def __init__(self, nb_views, canonical_distance=1., transform_distance=False, input_view_noise=0.0): 9 | super().__init__() 10 | self.nb_views = nb_views 11 | self.transform_distance = transform_distance 12 | self.canonical_distance = canonical_distance 13 | self.input_view_noise = input_view_noise 14 | views_dist = torch.ones((self.nb_views), dtype=torch.float, requires_grad=False) * canonical_distance 15 | if self.nb_views == 10: 16 | views_elev = torch.asarray((0, 90, 180, 270, 225, 225, 315, 315, 0, 0), dtype=torch.float, requires_grad=False) 17 | views_azim = torch.asarray((0, 0, 0, 0, -45, 45, -45, 45, -90, 90), dtype=torch.float, requires_grad=False) 18 | elif self.nb_views == 6: 19 | views_elev = torch.asarray((0, 0, 0, 0, 90, -90), dtype=torch.float, requires_grad=False) 20 | views_azim = torch.asarray((0, 90, 180, 270, 0, 180), dtype=torch.float, requires_grad=False) 21 | 22 | self.register_buffer('views_azim', views_azim) 23 | self.register_buffer('views_elev', views_elev) 24 | self.register_buffer('views_dist', views_dist) 25 | 26 | def forward(self, c_batch_size): 27 | c_views_azim = self.views_azim.expand(c_batch_size, self.nb_views) 28 | c_views_elev = self.views_elev.expand(c_batch_size, self.nb_views) 29 | c_views_dist = self.views_dist.expand(c_batch_size, self.nb_views) 30 | c_views_dist = c_views_dist + float(self.transform_distance) * 1.0 * c_views_dist * ( 31 | torch.rand((c_batch_size, self.nb_views), device=c_views_dist.device) - 0.5) 32 | if self.input_view_noise > 0.0 and self.training: 33 | c_views_azim = c_views_azim + \ 34 | torch.normal(0.0, 180.0 * self.input_view_noise, 35 | c_views_azim.size(), device=c_views_azim.device) 36 | c_views_elev = c_views_elev + \ 37 | torch.normal(0.0, 90.0 * self.input_view_noise, 38 | c_views_elev.size(), device=c_views_elev.device) 39 | c_views_dist = c_views_dist + \ 40 | torch.normal(0.0, self.canonical_distance * self.input_view_noise, 41 | c_views_dist.size(), device=c_views_dist.device) 42 | return c_views_azim, c_views_elev, c_views_dist 43 | 44 | 45 | class LearnedViewSelector(ViewSelector): 46 | def __init__(self, nb_views, 
shape_features_size=512, canonical_distance=1., transform_distance=False, input_view_noise=0.0): 47 | ViewSelector.__init__(self, nb_views, canonical_distance, transform_distance, input_view_noise) 48 | self.view_transformer = nn.Sequential( 49 | MLP([shape_features_size+3*self.nb_views, shape_features_size, shape_features_size, 5 * self.nb_views, 3*self.nb_views], dropout=0.5, norm=True), 50 | MLP([3*self.nb_views, 3*self.nb_views], act=None, dropout=0, norm=False), 51 | nn.Tanh()) if self.transform_distance \ 52 | else nn.Sequential( 53 | MLP([shape_features_size+2*self.nb_views, shape_features_size, shape_features_size, 5 * self.nb_views, 2*self.nb_views], dropout=0.5, norm=True), 54 | MLP([2*self.nb_views, 2*self.nb_views], act=None, dropout=0, norm=False), 55 | nn.Tanh()) 56 | 57 | def forward(self, shape_features): 58 | c_batch_size = shape_features.shape[0] 59 | c_views_azim = self.views_azim.expand(c_batch_size, self.nb_views) 60 | c_views_elev = self.views_elev.expand(c_batch_size, self.nb_views) 61 | c_views_dist = self.views_dist.expand(c_batch_size, self.nb_views) 62 | c_views_dist = c_views_dist + float(self.transform_distance) * 1.0 * c_views_dist * ( 63 | torch.rand((c_batch_size, self.nb_views), device=c_views_dist.device) - 0.5) 64 | if self.input_view_noise > 0.0 and self.training: 65 | c_views_azim = c_views_azim + \ 66 | torch.normal(0.0, 180.0 * self.input_view_noise, 67 | c_views_azim.size(), device=c_views_azim.device) 68 | c_views_elev = c_views_elev + \ 69 | torch.normal(0.0, 90.0 * self.input_view_noise, 70 | c_views_elev.size(), device=c_views_elev.device) 71 | c_views_dist = c_views_dist + \ 72 | torch.normal(0.0, self.canonical_distance * self.input_view_noise, 73 | c_views_dist.size(), device=c_views_dist.device) 74 | if not self.transform_distance: 75 | adjutment_vector = self.view_transformer( 76 | torch.cat([shape_features, c_views_azim, c_views_elev], dim=1)) 77 | adjutment_vector = torch.chunk(adjutment_vector, 2, dim=1) 78 | return c_views_azim + adjutment_vector[0] * 180.0/self.nb_views, c_views_elev + adjutment_vector[1] * 90.0, c_views_dist 79 | else: 80 | adjutment_vector = self.view_transformer( 81 | torch.cat([shape_features, c_views_azim, c_views_elev, c_views_dist], dim=1)) 82 | adjutment_vector = torch.chunk(adjutment_vector, 3, dim=1) 83 | return c_views_azim + adjutment_vector[0] * 180.0/self.nb_views, c_views_elev + adjutment_vector[1] * 90.0, c_views_dist + adjutment_vector[2] * self.canonical_distance + 0.1 84 | 85 | 86 | class FeatureExtractor(nn.Module): 87 | def __init__(self, shape_features_size, shape_extractor, screatch_feature_extractor): 88 | super().__init__() 89 | if shape_extractor == "PointNet": 90 | print('build PointNet selector') 91 | self.fe_model = PointNet(shape_features_size, alignment=True) 92 | elif shape_extractor == "DGCNN": 93 | print('build DGCNN selector') 94 | self.fe_model = SimpleDGCNN(shape_features_size) 95 | if screatch_feature_extractor: 96 | load_point_ckpt(self.fe_model, shape_extractor, 97 | ckpt_dir='./checkpoint') 98 | self.features_order = {"logits": 0, 99 | "post_max": 1, "transform_matrix": 2} 100 | 101 | def forward(self, points): 102 | batch_size, _, _ = points.shape 103 | points = points.transpose(1, 2) 104 | features = self.fe_model(points) 105 | return features[0].view(batch_size, -1) 106 | 107 | 108 | class Selector(nn.Module): 109 | def __init__(self, nb_views, shape_features_size=512, shape_extractor="PointNet", canonical_distance=1., transform_distance=False, input_view_noise=0.0, 
screatch_feature_extractor=False): 110 | super().__init__() 111 | self.learned = True if shape_features_size > 0 else False 112 | self.view_selector = LearnedViewSelector(nb_views, shape_features_size, canonical_distance, transform_distance, input_view_noise) if self.learned else ViewSelector(nb_views, canonical_distance, transform_distance, input_view_noise) 113 | if self.learned: 114 | self.feature_extractor = FeatureExtractor(shape_features_size=shape_features_size, shape_extractor=shape_extractor, screatch_feature_extractor=screatch_feature_extractor) 115 | 116 | def forward(self, points): 117 | if self.learned: 118 | shape_features = self.feature_extractor(points) 119 | return self.view_selector(shape_features) 120 | return self.view_selector(points.shape[0]) 121 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import copy 4 | import numpy 5 | from mpl_toolkits.mplot3d import Axes3D 6 | from mpl_toolkits.mplot3d.art3d import Poly3DCollection 7 | import matplotlib.pyplot 8 | 9 | # used to read ply files 10 | from plyfile import PlyData 11 | import open3d as o3d 12 | import numpy as np 13 | 14 | 15 | class Mesh: 16 | def __init__(self): 17 | self._vertices = [] # array-like (N, D) 18 | self._faces = [] # array-like (M, K) 19 | self._edges = [] # array-like (L, 2) 20 | 21 | def clone(self): 22 | other = copy.deepcopy(self) 23 | return other 24 | 25 | def clear(self): 26 | for key in self.__dict__: 27 | self.__dict__[key] = [] 28 | 29 | def add_attr(self, name): 30 | self.__dict__[name] = [] 31 | 32 | @property 33 | def vertex_array(self): 34 | return numpy.array(self._vertices) 35 | 36 | @property 37 | def vertex_list(self): 38 | return list(map(tuple, self._vertices)) 39 | 40 | @staticmethod 41 | def faces2polygons(faces, vertices): 42 | p = list(map(lambda face: \ 43 | list(map(lambda vidx: vertices[vidx], face)), faces)) 44 | return p 45 | 46 | @property 47 | def polygon_list(self): 48 | p = Mesh.faces2polygons(self._faces, self._vertices) 49 | return p 50 | 51 | def plot(self, fig=None, ax=None, *args, **kwargs): 52 | p = self.polygon_list 53 | v = self.vertex_array 54 | if fig is None: 55 | fig = matplotlib.pyplot.gcf() 56 | if ax is None: 57 | ax = Axes3D(fig) 58 | if p: 59 | ax.add_collection3d(Poly3DCollection(p)) 60 | if v.shape: 61 | ax.scatter(v[:, 0], v[:, 1], v[:, 2], *args, **kwargs) 62 | ax.set_xlabel('X') 63 | ax.set_ylabel('Y') 64 | ax.set_zlabel('Z') 65 | return fig, ax 66 | 67 | def on_unit_sphere(self, zero_mean=False): 68 | # radius == 1 69 | v = self.vertex_array # (N, D) 70 | if zero_mean: 71 | a = numpy.mean(v[:, 0:3], axis=0, keepdims=True) # (1, 3) 72 | v[:, 0:3] = v[:, 0:3] - a 73 | n = numpy.linalg.norm(v[:, 0:3], axis=1) # (N,) 74 | m = numpy.max(n) # scalar 75 | v[:, 0:3] = v[:, 0:3] / m 76 | self._vertices = v 77 | return self 78 | 79 | def on_unit_cube(self, zero_mean=False): 80 | # volume == 1 81 | v = self.vertex_array # (N, D) 82 | if zero_mean: 83 | a = numpy.mean(v[:, 0:3], axis=0, keepdims=True) # (1, 3) 84 | v[:, 0:3] = v[:, 0:3] - a 85 | m = numpy.max(numpy.abs(v)) # scalar 86 | v[:, 0:3] = v[:, 0:3] / (m * 2) 87 | self._vertices = v 88 | return self 89 | 90 | def rot_x(self): 91 | # camera local (up: +Y, front: -Z) -> model local (up: +Z, front: +Y). 
92 | v = self.vertex_array 93 | t = numpy.copy(v[:, 1]) 94 | v[:, 1] = -numpy.copy(v[:, 2]) 95 | v[:, 2] = t 96 | self._vertices = list(map(tuple, v)) 97 | return self 98 | 99 | def rot_zc(self): 100 | # R = [0, -1; 101 | # 1, 0] 102 | v = self.vertex_array 103 | x = numpy.copy(v[:, 0]) 104 | y = numpy.copy(v[:, 1]) 105 | v[:, 0] = -y 106 | v[:, 1] = x 107 | self._vertices = list(map(tuple, v)) 108 | return self 109 | 110 | def offread_uniformed(filepath, sampled_pt_num=1024): 111 | """ read OFF mesh file and uniformly sample points on the mesh. """ 112 | mesh = Mesh() 113 | input = o3d.io.read_triangle_mesh(filepath) 114 | pointCloud = input.sample_points_uniformly(sampled_pt_num) 115 | points = np.asarray(pointCloud.points) 116 | return points 117 | pts = tuple(map(tuple, points)) 118 | mesh._vertices = pts 119 | 120 | return mesh 121 | 122 | def offread(filepath, points_only=True): 123 | """ read Geomview OFF file. """ 124 | with open(filepath, 'r') as fin: 125 | mesh, fixme = _load_off(fin, points_only) 126 | if fixme: 127 | _fix_modelnet_broken_off(filepath) 128 | return mesh 129 | 130 | 131 | def _load_off(fin, points_only): 132 | """ read Geomview OFF file. """ 133 | mesh = Mesh() 134 | 135 | fixme = False 136 | sig = fin.readline().strip() 137 | if sig == 'OFF': 138 | line = fin.readline().strip() 139 | num_verts, num_faces, num_edges = tuple([int(s) for s in line.split(' ')]) 140 | elif sig[0:3] == 'OFF': # ...broken data in ModelNet (missing '\n')... 141 | line = sig[3:] 142 | num_verts, num_faces, num_edges = tuple([int(s) for s in line.split(' ')]) 143 | fixme = True 144 | else: 145 | raise RuntimeError('unknown format') 146 | 147 | for v in range(num_verts): 148 | vp = tuple(float(s) for s in fin.readline().strip().split(' ')) 149 | mesh._vertices.append(vp) 150 | 151 | if points_only: 152 | return mesh, fixme 153 | 154 | for f in range(num_faces): 155 | fc = tuple([int(s) for s in fin.readline().strip().split(' ')][1:]) 156 | mesh._faces.append(fc) 157 | 158 | return mesh, fixme 159 | 160 | 161 | def _fix_modelnet_broken_off(filepath): 162 | oldfile = '{}.orig'.format(filepath) 163 | os.rename(filepath, oldfile) 164 | with open(oldfile, 'r') as fin: 165 | with open(filepath, 'w') as fout: 166 | sig = fin.readline().strip() 167 | line = sig[3:] 168 | print('OFF', file=fout) 169 | print(line, file=fout) 170 | for line in fin: 171 | print(line.strip(), file=fout) 172 | 173 | 174 | def objread(filepath, points_only=True): 175 | """Loads a Wavefront OBJ file. 
""" 176 | _vertices = [] 177 | _normals = [] 178 | _texcoords = [] 179 | _faces = [] 180 | _mtl_name = None 181 | 182 | material = None 183 | for line in open(filepath, "r"): 184 | if line.startswith('#'): continue 185 | values = line.split() 186 | if not values: continue 187 | if values[0] == 'v': 188 | v = tuple(map(float, values[1:4])) 189 | _vertices.append(v) 190 | elif values[0] == 'vn': 191 | v = tuple(map(float, values[1:4])) 192 | _normals.append(v) 193 | elif values[0] == 'vt': 194 | _texcoords.append(tuple(map(float, values[1:3]))) 195 | elif values[0] in ('usemtl', 'usemat'): 196 | material = values[1] 197 | elif values[0] == 'mtllib': 198 | _mtl_name = values[1] 199 | elif values[0] == 'f': 200 | face_ = [] 201 | texcoords_ = [] 202 | norms_ = [] 203 | for v in values[1:]: 204 | w = v.split('/') 205 | face_.append(int(w[0]) - 1) 206 | if len(w) >= 2 and len(w[1]) > 0: 207 | texcoords_.append(int(w[1]) - 1) 208 | else: 209 | texcoords_.append(-1) 210 | if len(w) >= 3 and len(w[2]) > 0: 211 | norms_.append(int(w[2]) - 1) 212 | else: 213 | norms_.append(-1) 214 | # _faces.append((face_, norms_, texcoords_, material)) 215 | _faces.append(face_) 216 | 217 | mesh = Mesh() 218 | mesh._vertices = _vertices 219 | if points_only: 220 | return mesh 221 | 222 | mesh._faces = _faces 223 | 224 | return mesh 225 | 226 | 227 | def plyread(filepath, points_only=True): 228 | # read binary ply file and return [x, y, z] array 229 | data = PlyData.read(filepath) 230 | vertex = data['vertex'] 231 | 232 | (x, y, z) = (vertex[t] for t in ('x', 'y', 'z')) 233 | num_verts = len(x) 234 | 235 | mesh = Mesh() 236 | 237 | for v in range(num_verts): 238 | vp = tuple(float(s) for s in [x[v], y[v], z[v]]) 239 | mesh._vertices.append(vp) 240 | 241 | return mesh 242 | 243 | 244 | def pc_normalize(pc): 245 | centroid = torch.mean(pc, dim=0) 246 | pc = pc - centroid 247 | m = torch.max(torch.sqrt(torch.sum(pc**2, dim=1))) 248 | pc = pc / m 249 | return pc -------------------------------------------------------------------------------- /pretraining.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from pointnet2_ops import pointnet2_utils 7 | from tqdm import tqdm 8 | import clip 9 | import torch_optimizer as optim 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from models import CLIP2Point 13 | from datasets import ModelNet40Align, ShapeNetRender 14 | from utils import IOStream 15 | 16 | clip_model, _ = clip.load("ViT-B/32", device='cpu') 17 | 18 | 19 | def _init_(path): 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | if not os.path.exists(path + '/' + args.exp_name): 23 | os.makedirs(path + '/' + args.exp_name) 24 | 25 | 26 | def train(args, io): 27 | test_prompts = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower pot', 'glass box', 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night stand', 'person', 'piano', 'plant', 'radio', 'range hood', 'sink', 'sofa', 'stairs', 'stool', 'table', 'tent', 'toilet', 'tv stand', 'vase', 'wardrobe', 'xbox'] 28 | val_prompts = ['airplane', 'ashcan', 'bag', 'basket', 'bathtub', 'bed', 'bench', 'birdhouse', 'bookshelf', 'bottle', 'bowl', 'bus', 'cabinet', 'camera', 'can', 'cap', 'car', 'cellular telephone', 'chair', 'clock', 'computer keyboard', 'dishwasher', 'display', 
'earphone', 'faucet', 'file', 'guitar', 'helmet', 'jar', 'knife', 'lamp', 'laptop', 'loudspeaker', 'mailbox', 'microphone', 'microwave', 'motorcycle', 'mug', 'piano', 'pillow', 'pistol', 'pot', 'printer', 'remote control', 'rifle', 'rocket', 'skateboard', 'sofa', 'stove', 'table', 'telephone', 'tower', 'train', 'vessel', 'washer'] 29 | test_prompts = ['image of a ' + test_prompts[i] for i in range(len(test_prompts))] 30 | val_prompts = ['image of a ' + val_prompts[i] for i in range(len(val_prompts))] 31 | test_prompts_ = clip.tokenize(test_prompts) 32 | test_prompt_feats = clip_model.encode_text(test_prompts_) 33 | test_prompt_feats = test_prompt_feats / test_prompt_feats.norm(dim=-1, keepdim=True) 34 | test_prompt_feats = test_prompt_feats 35 | val_prompts_ = clip.tokenize(val_prompts) 36 | val_prompt_feats = clip_model.encode_text(val_prompts_) 37 | val_prompt_feats = val_prompt_feats / val_prompt_feats.norm(dim=-1, keepdim=True) 38 | val_prompt_feats = val_prompt_feats 39 | 40 | train_dataloader = DataLoader(ShapeNetRender(partition='train', num_points=args.num_points), batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True) 41 | val_loader = DataLoader(ShapeNetRender(partition='test', num_points=args.num_points), batch_size=args.test_batch_size, shuffle=True, num_workers=4) 42 | test_loader = DataLoader(ModelNet40Align(num_points=args.num_points), batch_size=args.test_batch_size, num_workers=4) 43 | gpu_num = torch.cuda.device_count() 44 | gpus = [i for i in range(gpu_num)] 45 | device = torch.device(f'cuda:{gpus[0]}' if torch.cuda.is_available() else 'cpu') 46 | # =================================== INIT MODEL ========================================================== 47 | summary_writer = SummaryWriter("pre_results/%s/tensorboard" % (args.exp_name)) 48 | model = CLIP2Point(args) 49 | model = nn.DataParallel(model, device_ids=gpus, output_device=gpus[0]) # 多卡训练修改 50 | model = model.to(device) 51 | for name, param in model.named_parameters(): 52 | if 'image_model' in name: 53 | param.requires_grad_(False) 54 | val_prompt_feats = val_prompt_feats.to(device) 55 | test_prompt_feats = test_prompt_feats.to(device) 56 | # ==================================== TRAINING LOOP ====================================================== 57 | optimizer = optim.Lamb(model.parameters(), lr=0.006, weight_decay=1e-4) 58 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( 59 | optimizer, 60 | T_0=2 * len(train_dataloader), 61 | T_mult=1, 62 | eta_min=max(1e-2 * 1e-3, 1e-6), 63 | last_epoch=-1, 64 | ) 65 | 66 | n_epochs = args.epoch 67 | max_val_acc = 0 68 | max_test_acc = 0 69 | for epoch in range(n_epochs): 70 | model.train() 71 | loss_sum = 0 72 | depth_sum = 0 73 | image_sum = 0 74 | 75 | for (image, points, a, e, d) in tqdm(train_dataloader): 76 | optimizer.zero_grad() 77 | image = image.to(device) 78 | points = points.to(device) 79 | a = a.unsqueeze(-1).to(device) 80 | e = e.unsqueeze(-1).to(device) 81 | d = d.unsqueeze(-1).to(device) 82 | loss, image_loss, depth_loss = model(points, image, a, e, d) 83 | loss = torch.mean(loss) 84 | image_sum += torch.mean(image_loss).item() 85 | depth_sum += torch.mean(depth_loss).item() 86 | loss_sum += loss.item() 87 | loss.backward() 88 | optimizer.step() 89 | scheduler.step() 90 | 91 | # Validation and Testing 92 | model.eval() 93 | with torch.no_grad(): 94 | correct_num = 0 95 | total = 0 96 | for (points, label) in tqdm(val_loader): 97 | b = points.shape[0] 98 | points = points.to(device) 99 | img_feats = 
model.module.infer(points) 100 | 101 | logits = img_feats @ val_prompt_feats.t() 102 | logits = logits.reshape(b, args.views, -1) 103 | logits = torch.sum(logits, dim=1) 104 | probs = logits.softmax(dim=-1) 105 | index = torch.max(probs, dim=1).indices 106 | correct_num += torch.sum(torch.eq(index.detach().cpu(), label)).item() 107 | total += len(label) 108 | val_acc = correct_num / total 109 | 110 | with torch.no_grad(): 111 | correct_num = 0 112 | total = 0 113 | for (points, label) in tqdm(test_loader): 114 | b = points.shape[0] 115 | points = points.to(device) 116 | img_feats = model.module.infer(points, True) 117 | logits = img_feats @ test_prompt_feats.t() 118 | logits = logits.reshape(b, args.views, -1) 119 | logits = torch.sum(logits, dim=1) 120 | probs = logits.softmax(dim=-1) 121 | index = torch.max(probs, dim=1).indices 122 | correct_num += torch.sum(torch.eq(index.detach().cpu(), label)).item() 123 | total += len(label) 124 | test_acc = correct_num / total 125 | 126 | depth_loss = depth_sum / len(train_dataloader) 127 | image_loss = image_sum / len(train_dataloader) 128 | mean_loss = loss_sum / len(train_dataloader) 129 | io.cprint('epoch%d total_loss: %.4f, image_loss: %.4f, depth_loss: %.4f, balance_weights: %.4f, val_acc: %.4f, test_acc: %.4f' % (epoch + 1, mean_loss, image_loss, depth_loss, model.module.weights, val_acc, test_acc)) 130 | summary_writer.add_scalar('train/loss', mean_loss, epoch + 1) 131 | summary_writer.add_scalar('train/depth_loss', depth_loss, epoch + 1) 132 | summary_writer.add_scalar('train/image_loss', image_loss, epoch + 1) 133 | summary_writer.add_scalar('train/weights', model.module.weights, epoch + 1) 134 | summary_writer.add_scalar("val/acc", val_acc, epoch + 1) 135 | summary_writer.add_scalar("test/acc", test_acc, epoch + 1) 136 | if val_acc > max_val_acc: 137 | max_val_acc = val_acc 138 | torch.save(model.state_dict(), '%s/%s/best_val.pth' % ('pre_results', args.exp_name)) 139 | io.cprint('save the best val acc at %d' % (epoch + 1)) 140 | if test_acc > max_test_acc: 141 | max_test_acc = test_acc 142 | torch.save(model.state_dict(), '%s/%s/best_test.pth' % ('pre_results', args.exp_name)) 143 | io.cprint('save the best test acc at %d' % (epoch + 1)) 144 | 145 | 146 | if __name__ == "__main__": 147 | # Training settings 148 | parser = argparse.ArgumentParser(description='Point Cloud Recognition') 149 | parser.add_argument('--exp_name', type=str, default='test', metavar='N', 150 | help='Name of the experiment') 151 | parser.add_argument('--views', type=int, default=10) 152 | parser.add_argument('--num_points', type=int, default=1024) 153 | parser.add_argument('--ckpt', type=str, default=None) 154 | parser.add_argument('--dim', type=int, default=0, choices=[0, 512], help='0 if the view angle is not learnable') 155 | parser.add_argument('--model', type=str, default='PointNet', metavar='N', 156 | choices=['DGCNN', 'PointNet'], 157 | help='Model to use, [pointnet, dgcnn]') 158 | parser.add_argument('--batch_size', type=int, default=256, metavar='batch_size', 159 | help='Size of batch)') 160 | parser.add_argument('--test_batch_size', type=int, default=32, metavar='batch_size', 161 | help='Size of batch)') 162 | parser.add_argument('--epoch', type=int, default=100, metavar='N', 163 | help='number of episode to train ') 164 | args = parser.parse_args() 165 | 166 | _init_('pre_results') 167 | io = IOStream('pre_results' + '/' + args.exp_name + '/run.log') 168 | io.cprint(str(args)) 169 | train(args, io) 170 | 
-------------------------------------------------------------------------------- /datasets/shapenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.utils.data as data 5 | import h5py 6 | from typing import Tuple 7 | import collections 8 | from pytorch3d.io import load_obj 9 | import random 10 | from torchvision.transforms import Normalize, ToTensor 11 | from PIL import Image 12 | 13 | from render.render import Renderer 14 | 15 | 16 | cat_labels = {'02691156': 0, '02747177': 1, '02773838': 2, '02801938': 3, '02808440': 4, '02818832': 5, '02828884': 6, '02843684': 7, '02871439': 8, '02876657': 9, '02880940': 10, '02924116': 11, '02933112': 12, '02942699': 13, '02946921': 14, '02954340': 15, '02958343': 16, 17 | '02992529': 17, '03001627': 18, '03046257': 19, '03085013': 20, '03207941': 21, '03211117': 22, '03261776': 23, '03325088': 24, '03337140': 25, '03467517': 26, '03513137': 27, '03593526': 28, '03624134': 29, '03636649': 30, '03642806': 31, '03691459': 32, '03710193': 33, 18 | '03759954': 34, '03761084': 35, '03790512': 36, '03797390': 37, '03928116': 38, '03938244': 39, '03948459': 40, '03991062': 41, '04004475': 42, '04074963': 43, '04090263': 44, '04099429': 45, '04225987': 46, '04256520': 47, '04330267': 48, '04379243': 49, '04401088': 50, 19 | '04460130': 51, '04468005': 52, '04530566': 53, '04554684': 54} 20 | 21 | 22 | class IO: 23 | @classmethod 24 | def get(cls, file_path): 25 | _, file_extension = os.path.splitext(file_path) 26 | 27 | if file_extension in ['.npy']: 28 | return cls._read_npy(file_path) 29 | # elif file_extension in ['.pcd']: 30 | # return cls._read_pcd(file_path) 31 | elif file_extension in ['.h5']: 32 | return cls._read_h5(file_path) 33 | elif file_extension in ['.txt']: 34 | return cls._read_txt(file_path) 35 | else: 36 | raise Exception('Unsupported file extension: %s' % file_extension) 37 | 38 | @classmethod 39 | def _read_npy(cls, file_path): 40 | return np.load(file_path) 41 | 42 | @classmethod 43 | def _read_txt(cls, file_path): 44 | return np.loadtxt(file_path) 45 | 46 | @classmethod 47 | def _read_h5(cls, file_path): 48 | f = h5py.File(file_path, 'r') 49 | return f['data'][()] 50 | 51 | 52 | def torch_center_and_normalize(points,p="inf"): 53 | """ 54 | a helper pytorch function that normalize and center 3D points clouds 55 | """ 56 | N = points.shape[0] 57 | center = points.mean(0) 58 | if p != "fro" and p!= "no": 59 | scale = torch.max(torch.norm(points - center, p=float(p),dim=1)) 60 | elif p=="fro" : 61 | scale = torch.norm(points - center, p=p ) 62 | elif p=="no": 63 | scale = 1.0 64 | points = points - center.expand(N, 3) 65 | points = points * (1.0 / float(scale)) 66 | return points 67 | 68 | 69 | class ShapeNet(data.Dataset): 70 | def __init__(self, partition='train', whole=False, num_points=1024): 71 | assert partition in ['train', 'test'] 72 | self.data_root = './data/ShapeNet55/ShapeNet-55' 73 | self.pc_path = './data/ShapeNet55/shapenet_pc' 74 | self.subset = partition 75 | self.npoints = 8192 76 | 77 | self.data_list_file = os.path.join(self.data_root, f'{self.subset}.txt') 78 | test_data_list_file = os.path.join(self.data_root, 'test.txt') 79 | 80 | self.sample_points_num = num_points 81 | self.whole = whole 82 | 83 | with open(self.data_list_file, 'r') as f: 84 | lines = f.readlines() 85 | if self.whole: 86 | with open(test_data_list_file, 'r') as f: 87 | test_lines = f.readlines() 88 | lines = test_lines + lines 89 | 
self.file_list = [] 90 | check_list = ['03001627-udf068a6b', '03001627-u6028f63e', '03001627-uca24feec', '04379243-', '02747177-', '03001627-u481ebf18', '03001627-u45c7b89f', '03001627-ub5d972a1', '03001627-u1e22cc04', '03001627-ue639c33f'] 91 | 92 | # flag = False 93 | for line in lines: 94 | line = line.strip() 95 | taxonomy_id = line.split('-')[0] 96 | model_id = line.split('-')[1].split('.')[0] 97 | 98 | if taxonomy_id + '-' + model_id not in check_list: 99 | self.file_list.append({ 100 | 'taxonomy_id': taxonomy_id, 101 | 'model_id': model_id, 102 | 'file_path': line 103 | }) 104 | 105 | self.permutation = np.arange(self.npoints) 106 | 107 | def _load_mesh(self, model_path) -> Tuple: 108 | verts, faces, aux = load_obj(model_path, create_texture_atlas=True, texture_wrap='clamp') 109 | textures = aux.texture_atlas 110 | return verts, faces.verts_idx, textures 111 | 112 | def pc_norm(self, pc): 113 | """ pc: NxC, return NxC """ 114 | centroid = np.mean(pc, axis=0) 115 | pc = pc - centroid 116 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 117 | pc = pc / m 118 | return pc 119 | 120 | def random_sample(self, pc, num): 121 | np.random.shuffle(self.permutation) 122 | pc = pc[self.permutation[:num]] 123 | return pc 124 | 125 | def __getitem__(self, idx): 126 | sample = self.file_list[idx] 127 | 128 | points = IO.get(os.path.join(self.pc_path, sample['file_path'])).astype(np.float32) 129 | 130 | # points = self.random_sample(points, self.sample_points_num) 131 | points = self.pc_norm(points) 132 | points = torch.from_numpy(points).float() 133 | 134 | verts, faces, textures = self._load_mesh(os.path.join('/data/ShapeNetCore.v2', sample['taxonomy_id'], sample['model_id'], 'models', 'model_normalized.obj')) 135 | verts = torch_center_and_normalize(verts.to(torch.float), '2.0') 136 | mesh = dict() 137 | mesh["verts"] = verts 138 | mesh["faces"] = faces 139 | mesh["textures"] = textures 140 | # label = cat_labels[sample['taxonomy_id']] 141 | label = sample['taxonomy_id'] + '_' + sample['model_id'] 142 | # return points, mesh, sample['taxonomy_id'], sample['model_id'] 143 | return points, mesh, label 144 | 145 | def __len__(self): 146 | return len(self.file_list) 147 | 148 | 149 | class ShapeNetDebug(ShapeNet): 150 | def __init__(self, partition='train', whole=False): 151 | super().__init__(partition, whole) 152 | 153 | def __getitem__(self, idx): 154 | sample = self.file_list[idx] 155 | return sample['taxonomy_id'] + '_' + sample['model_id'] 156 | 157 | 158 | class ShapeNetRender(ShapeNet): 159 | def __init__(self, partition='train', whole=False, num_points=1024): 160 | super().__init__(partition, whole, num_points) 161 | self.partition = partition 162 | self.views_dist = torch.ones((10), dtype=torch.float, requires_grad=False) 163 | self.views_elev = torch.asarray((0, 90, 180, 270, 225, 225, 315, 315, 0, 0), dtype=torch.float, requires_grad=False) 164 | self.views_azim = torch.asarray((0, 0, 0, 0, -45, 45, -45, 45, -90, 90), dtype=torch.float, requires_grad=False) 165 | self.render = Renderer() 166 | self.totensor = ToTensor() 167 | self.norm = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 168 | 169 | def __getitem__(self, idx): 170 | sample = self.file_list[idx] 171 | points = IO.get(os.path.join(self.pc_path, sample['file_path'])).astype(np.float32) 172 | points = self.random_sample(points, self.sample_points_num) 173 | points = self.pc_norm(points) 174 | points = torch.from_numpy(points).float() 175 | 176 | if self.partition == 'test': 177 | return points, 
cat_labels[sample['taxonomy_id']] 178 | 179 | name = sample['taxonomy_id'] + '_' + sample['model_id'] 180 | rand_idx = random.randint(0, 9) 181 | image = Image.open('./data/rendering/%s/%d.png' % (name, rand_idx)) 182 | image = self.norm(self.totensor(image)) 183 | return image, points, self.views_azim[rand_idx], self.views_elev[rand_idx], self.views_dist[rand_idx] 184 | 185 | 186 | def collate_fn(batch): 187 | r"""Puts each data field into a tensor with outer dimension batch size""" 188 | 189 | elem = batch[0] 190 | elem_type = type(elem) 191 | if isinstance(elem, torch.Tensor): 192 | return torch.stack(batch, 0) 193 | elif elem_type.__module__ == 'pytorch3d.structures.meshes': 194 | return batch 195 | elif isinstance(elem, dict): 196 | return batch 197 | elif isinstance(elem, float): 198 | return torch.tensor(batch, dtype=torch.float64) 199 | elif isinstance(elem, (int)): 200 | return torch.tensor(batch) 201 | elif isinstance(elem, (str, bytes)): 202 | return batch 203 | elif isinstance(elem, collections.abc.Mapping): 204 | return {key: collate_fn([d[key] for d in batch]) for key in elem} 205 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): 206 | return elem_type(*(collate_fn(samples) for samples in zip(*batch))) 207 | elif isinstance(elem, collections.abc.Sequence): 208 | 209 | it = iter(batch) 210 | elem_size = len(next(it)) 211 | if not all(len(elem) == elem_size for elem in it): 212 | raise RuntimeError( 213 | 'each element in list of batch should be of equal size') 214 | transposed = zip(*batch) 215 | return [collate_fn(samples) for samples in transposed] 216 | -------------------------------------------------------------------------------- /render/render.py: -------------------------------------------------------------------------------- 1 | from pytorch3d.renderer.cameras import camera_position_from_spherical_angles 2 | from pytorch3d.renderer import ( 3 | OpenGLPerspectiveCameras, look_at_view_transform, OpenGLOrthographicCameras, 4 | RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams, HardPhongShader, PointsRasterizationSettings, PointsRasterizer, DirectionalLights) 5 | from pytorch3d.transforms import axis_angle_to_matrix 6 | from pytorch3d.renderer.mesh import TexturesAtlas 7 | from pytorch3d.structures import Meshes, Pointclouds 8 | from torch import nn 9 | import numpy as np 10 | from torch.autograd import Variable 11 | import torch 12 | from torchvision.transforms import Normalize 13 | import sys 14 | import os 15 | 16 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 17 | 18 | 19 | ORTHOGONAL_THRESHOLD = 1e-6 20 | EXAHSTION_LIMIT = 20 21 | 22 | 23 | def batch_tensor(tensor, dim=1, squeeze=False): 24 | """ 25 | a function to reshape a PyTorch tensor `tensor` along some dimension `dim` into the batch dimension 0 so that the tensor can be processed in parallel. 26 | if `squeeze`=True, the dimension `dim` will be removed completely, otherwise it will be of size=1.
check `unbatch_tensor()` for the reverse function 27 | """ 28 | batch_size, dim_size = tensor.shape[0], tensor.shape[dim] 29 | returned_size = list(tensor.shape) 30 | returned_size[0] = batch_size*dim_size 31 | returned_size[dim] = 1 32 | if squeeze: 33 | return tensor.transpose(0, dim).reshape(returned_size).squeeze_(dim) 34 | else: 35 | return tensor.transpose(0, dim).reshape(returned_size) 36 | 37 | 38 | def unbatch_tensor(tensor, batch_size, dim=1, unsqueeze=False): 39 | """ 40 | a function to chunk a PyTorch tensor `tensor` along the batch dimension 0 and concatenate the chunks on dimension `dim` to recover from the `batch_tensor()` function. 41 | if `unsqueeze`=True, it will add a dimension `dim` before the unbatching 42 | """ 43 | fake_batch_size = tensor.shape[0] 44 | nb_chunks = int(fake_batch_size / batch_size) 45 | if unsqueeze: 46 | return torch.cat(torch.chunk(tensor.unsqueeze_(dim), nb_chunks, dim=0), dim=dim).contiguous() 47 | else: 48 | return torch.cat(torch.chunk(tensor, nb_chunks, dim=0), dim=dim).contiguous() 49 | 50 | 51 | def check_valid_rotation_matrix(R, tol: float = 1e-6): 52 | """ 53 | Determine if R is a valid rotation matrix by checking it satisfies the 54 | following conditions: 55 | ``RR^T = I and det(R) = 1`` 56 | Args: 57 | R: an (N, 3, 3) matrix 58 | Returns: 59 | bool 60 | True if every matrix in R is a valid rotation matrix, False otherwise. 61 | """ 62 | N = R.shape[0] 63 | eye = torch.eye(3, dtype=R.dtype, device=R.device) 64 | eye = eye.view(1, 3, 3).expand(N, -1, -1) 65 | orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol) 66 | det_R = torch.det(R) 67 | no_distortion = torch.allclose(det_R, torch.ones_like(det_R)) 68 | return orthogonal and no_distortion 69 | 70 | 71 | def check_and_correct_rotation_matrix(R, T, nb_trials, azim, elev, dist): 72 | exhastion = 0 73 | while not check_valid_rotation_matrix(R): 74 | exhastion += 1 75 | R, T = look_at_view_transform(dist=batch_tensor(dist.T, dim=1, squeeze=True), elev=batch_tensor(elev.T + 90.0 * torch.rand_like(elev.T, device=elev.device), 76 | dim=1, squeeze=True), azim=batch_tensor(azim.T + 180.0 * torch.rand_like(azim.T, device=elev.device), dim=1, squeeze=True)) 77 | 78 | if not check_valid_rotation_matrix(R) and exhastion > nb_trials: 79 | sys.exit("Remedy did not work") 80 | return R, T 81 | 82 | 83 | class Renderer(nn.Module): 84 | """ 85 | The multi-view differentiable renderer main class that renders multiple views differentiably from some given viewpoints. It can render both meshes and point clouds. 86 | Args: 87 | `nb_views` int , The number of views used in the multi-view setup 88 | `image_size` int , The image size of the rendered views. 89 | `pc_rendering` : bool , flag to use point cloud rendering instead of mesh rendering 90 | `object_color` : str , The color setup of the objects/points rendered. Choices: ["white", "random","black","red","green","blue", "custom"] 91 | `background_color` : str , The color setup of the rendering background. Choices: ["white", "random","black","red","green","blue", "custom"] 92 | `faces_per_pixel` int , The number of faces rendered per pixel when mesh rendering is used (`pc_rendering` == `False`) . 93 | `points_radius`: float , the radius of the points rendered. The more points in a specific `image_size`, the smaller the radius required for proper rendering. 94 | `points_per_pixel` int , The number of points rendered per pixel when point cloud rendering is used (`pc_rendering` == `True`) .
95 | `light_direction` : str , The setup of the light used in rendering when mesh rendering is available. Choices: ["fixed", "random", "relative"] 96 | `cull_backfaces` : bool , Allow backface-culling when rendering meshes (`pc_rendering` == `False`). 97 | 98 | Returns: 99 | a Renderer module that can render multiple views according to the predefined setup 100 | """ 101 | 102 | def __init__(self, image_size=224, points_radius=0.02, points_per_pixel=1): 103 | super().__init__() 104 | self.image_size = image_size 105 | self.points_radius = points_radius 106 | self.points_per_pixel = points_per_pixel 107 | self.normalize = Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 108 | self.light_direction_type = 'random' 109 | 110 | def norm(self, img): # [B, H, W] 111 | detached_img = img.detach() 112 | B, H, W = detached_img.shape 113 | 114 | mask = detached_img > 0 115 | batch_points = detached_img.reshape(B, -1) 116 | batch_max, _ = torch.max(batch_points, dim=1, keepdim=True) 117 | batch_max = batch_max.unsqueeze(-1).repeat(1, H, W) 118 | detached_img[~mask] = 1. 119 | batch_points = detached_img.reshape(B, -1) 120 | batch_min, _ = torch.min(batch_points, dim=1, keepdim=True) 121 | batch_min = batch_min.unsqueeze(-1).repeat(1, H, W) 122 | img = img.sub_(batch_min).div_(batch_max) * 200. / 255. 123 | img[~mask] = 1. 124 | return self.normalize(img.unsqueeze(1).repeat(1, 3, 1, 1)) 125 | 126 | def render_meshes(self, meshes, azim, elev, dist, view, lights, background_color=(1.0, 1.0, 1.0)): 127 | collated_dict = {} 128 | for k in meshes[0].keys(): 129 | collated_dict[k] = [d[k] for d in meshes] 130 | textures = TexturesAtlas(atlas=collated_dict["textures"]) 131 | 132 | new_meshes = Meshes( 133 | verts=collated_dict["verts"], 134 | faces=collated_dict["faces"], 135 | textures=textures, 136 | ).to(lights.device) 137 | 138 | R, T = look_at_view_transform(dist=batch_tensor(dist.T, dim=1, squeeze=True), elev=batch_tensor( 139 | elev.T, dim=1, squeeze=True), azim=batch_tensor(azim.T, dim=1, squeeze=True)) 140 | 141 | cameras = OpenGLPerspectiveCameras( 142 | device=lights.device, R=R, T=T) 143 | camera = OpenGLPerspectiveCameras(device=lights.device, R=R[None, 0, ...], 144 | T=T[None, 0, ...]) 145 | 146 | raster_settings = RasterizationSettings( 147 | image_size=self.image_size, 148 | blur_radius=0.0, 149 | faces_per_pixel=1, 150 | cull_backfaces=False, 151 | bin_size=0 152 | ) 153 | renderer = MeshRenderer( 154 | rasterizer=MeshRasterizer( 155 | cameras=camera, raster_settings=raster_settings), 156 | shader=HardPhongShader(blend_params=BlendParams(background_color=background_color), device=lights.device, cameras=camera, lights=lights) 157 | ) 158 | new_meshes = new_meshes.extend(view) 159 | rendered_images = renderer(new_meshes, cameras=cameras, lights=lights) 160 | 161 | rendered_images = unbatch_tensor( 162 | rendered_images, batch_size=view, dim=1, unsqueeze=True).transpose(0, 1) 163 | 164 | rendered_images = rendered_images[..., 165 | 0:3].transpose(2, 4).transpose(3, 4) 166 | return self.normalize(rendered_images) 167 | 168 | def light_direction(self, azim, elev, dist): 169 | if self.light_direction_type == "fixed": 170 | return ((0, 1.0, 0),) 171 | elif self.light_direction_type == "random" and self.training: 172 | return (tuple(1.0 - 2 * np.random.rand(3)),) 173 | else: 174 | relative_view = Variable(camera_position_from_spherical_angles(distance=batch_tensor(dist.T, dim=1, squeeze=True), elevation=batch_tensor( 175 | elev.T, dim=1, squeeze=True),
azimuth=batch_tensor(azim.T, dim=1, squeeze=True))).to(torch.float) 176 | 177 | return relative_view 178 | 179 | def render_points(self, points, azim, elev, dist, view, aug=False, rot=False): 180 | views = view * 2 if aug else view 181 | batch_size = points.shape[0] 182 | if aug: 183 | azim = azim.repeat(1, 2) 184 | elev = elev.repeat(1, 2) 185 | rand_dist1 = dist * (1 + (torch.rand((batch_size, 1), device=points.device) - 0.5) / 5) 186 | rand_dist2 = dist * (1 + (torch.rand((batch_size, 1), device=points.device) - 0.5) / 5) 187 | dist = torch.cat([rand_dist1, rand_dist2], dim=1) 188 | 189 | if rot: 190 | rota1 = axis_angle_to_matrix(torch.tensor([0.5 * np.pi, 0, 0])).to(points.device) 191 | rota2 = axis_angle_to_matrix(torch.tensor([0, -0.5 * np.pi, 0])).to(points.device) 192 | # rota1 = axis_angle_to_matrix(torch.tensor([0, - 0.5 * np.pi, 0])).to(points.device) 193 | # rota2 = axis_angle_to_matrix(torch.tensor([0, 0, -0.5 * np.pi])).to(points.device) 194 | points = points @ rota1 @ rota2 195 | 196 | point_cloud = Pointclouds(points=points.to(torch.float)) 197 | 198 | R, T = look_at_view_transform(dist=batch_tensor(dist.T, dim=1, squeeze=True), elev=batch_tensor( 199 | elev.T, dim=1, squeeze=True), azim=batch_tensor(azim.T, dim=1, squeeze=True)) 200 | 201 | cameras = OpenGLOrthographicCameras(device=points.device, R=R, T=T, znear=0.01) 202 | raster_settings = PointsRasterizationSettings( 203 | image_size=self.image_size, 204 | radius=self.points_radius, 205 | points_per_pixel=self.points_per_pixel, 206 | bin_size=0 207 | ) 208 | renderer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings) 209 | point_cloud = point_cloud.extend(views) 210 | point_cloud.scale_(batch_tensor(1.0/dist.T, dim=1, 211 | squeeze=True)[..., None][..., None].to(points.device)) 212 | 213 | rendered_images = torch.mean(renderer(point_cloud).zbuf, dim=-1) 214 | rendered_images = self.norm(rendered_images) 215 | rendered_images = unbatch_tensor( 216 | rendered_images, batch_size=views, dim=1, unsqueeze=True).transpose(0, 1) 217 | 218 | return rendered_images 219 | 220 | def forward(self, points, azim, elev, dist, view, mesh=None, aug=False, rot=False): 221 | """ 222 | The main rendering function of the Renderer class. It always renders depth maps from the input point clouds and, when `mesh` is not `None`, additionally renders the corresponding mesh images. 223 | Args: 224 | `mesh`: a list of B `Pytorch3D.Mesh` to be rendered, B batch size. In case not available, just pass `None`. 225 | `points`: B * N * 3 tensor, a batch of B point clouds to be rendered where each point cloud has N points and each point has X,Y,Z property. In case not available, just pass `None` . 226 | `azim`: B * M tensor, a B batch of M azimuth angles that represent the azimuth angles of the M view-points to render the points or meshes from. 227 | `elev`: B * M tensor, a B batch of M elevation angles that represent the elevation angles of the M view-points to render the points or meshes from. 228 | `dist`: B * M tensor, a B batch of M unit distances that represent the distances of the M view-points to render the points or meshes from. 229 | `color`: B * N * 3 tensor, the RGB colors of a batch of point clouds/meshes, where N is the number of points/vertices and B the batch size.
Only used if `self.object_color` == `custom`; otherwise this option is ignored. 230 | 231 | """ 232 | rendered_depthes = self.render_points(points=points, azim=azim, elev=elev, dist=dist, view=view, aug=aug, rot=rot) 233 | 234 | if mesh is not None: 235 | background_color = torch.tensor((1.0, 1.0, 1.0), device=points.device) 236 | lights = DirectionalLights(device=points.device, direction=self.light_direction(azim, elev, dist)) 237 | rendered_images = self.render_meshes(meshes=mesh, azim=azim, elev=elev, dist=dist * 2, view=view, lights=lights, background_color=background_color) 238 | return rendered_depthes, rendered_images 239 | 240 | return rendered_depthes 241 | -------------------------------------------------------------------------------- /render/blocks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | from torch.nn import Sequential as Seq, Linear as Lin, Conv1d 5 | 6 | 7 | def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1): 8 | """ 9 | activation layer 10 | :param act: 11 | :param inplace: 12 | :param neg_slope: 13 | :param n_prelu: 14 | :return: 15 | """ 16 | 17 | act = act.lower() 18 | if act == 'relu': 19 | layer = nn.ReLU(inplace) 20 | elif act == 'leakyrelu': 21 | layer = nn.LeakyReLU(neg_slope, inplace) 22 | elif act == 'prelu': 23 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 24 | else: 25 | raise NotImplementedError('activation layer [%s] is not found' % act) 26 | return layer 27 | 28 | 29 | # Now, let's implement a shared MLP layer. It is implemented using Conv1d with a kernel size of 1. 30 | class Conv1dLayer(Seq): 31 | def __init__(self, channels, act='relu', norm=True, bias=True): 32 | m = [] 33 | for i in range(1, len(channels)): 34 | m.append(Conv1d(channels[i - 1], channels[i], 1, bias=bias)) 35 | if norm: 36 | m.append(nn.BatchNorm1d(channels[i])) 37 | if act: 38 | m.append(act_layer(act)) 39 | super(Conv1dLayer, self).__init__(*m) 40 | 41 | 42 | class MLP(Seq): 43 | """ 44 | Given input with shape [B, C_in] 45 | return output with shape [B, C_out] 46 | """ 47 | 48 | def __init__(self, channels, act='relu', norm=True, bias=True, dropout=0.5): 49 | # todo: 50 | m = [] 51 | for i in range(1, len(channels)): 52 | m.append(Lin(channels[i - 1], channels[i], bias=bias)) 53 | if norm: 54 | m.append(nn.BatchNorm1d(channels[i])) 55 | if act: 56 | m.append(act_layer(act)) 57 | if dropout > 0: 58 | m.append(nn.Dropout(dropout)) 59 | super(MLP, self).__init__(*m) 60 | 61 | 62 | def knn(x, k): 63 | """ 64 | Given point features x [B, C, N, 1], and number of neighbors k (int) 65 | Return the idx for the k neighbors of each point. 66 | So, the shape of idx: [B, N, k] 67 | """ 68 | with torch.no_grad(): 69 | x = x.squeeze(-1) 70 | inner = -2 * torch.matmul(x.transpose(2, 1), x) 71 | xx = torch.sum(x ** 2, dim=1, keepdim=True) 72 | inner = -xx - inner - xx.transpose(2, 1) 73 | 74 | idx = inner.topk(k=k, dim=-1)[1] # (batch_size, num_points, k) 75 | return idx 76 | 77 | 78 | def batched_index_select(x, idx): 79 | """ 80 | This can be used for fetching neighbor features. 81 | Given a point cloud x, return the features of its k neighbors indicated by the index tensor idx.
82 | :param x: torch.Size([batch_size, num_dims, num_vertices, 1]) 83 | :param idx: torch.Size([batch_size, num_vertices, k]) 84 | :return: torch.Size([batch_size, num_dims, num_vertices, k]) 85 | """ 86 | 87 | batch_size, num_dims, num_vertices = x.shape[:3] 88 | k = idx.shape[-1] 89 | idx_base = torch.arange( 90 | 0, batch_size, device=idx.device).view(-1, 1, 1) * num_vertices 91 | idx = idx + idx_base 92 | idx = idx.view(-1) 93 | 94 | x = x.transpose(2, 1).contiguous() 95 | feature = x.view(batch_size * num_vertices, -1)[idx, :] 96 | feature = feature.view(batch_size, num_vertices, k, 97 | num_dims).permute(0, 3, 1, 2) 98 | return feature 99 | 100 | 101 | def get_center_feature(x, k): 102 | """ 103 | Given a point cloud and the number of neighbors k, return the center features. 104 | :param x: torch.Size([batch_size, num_dims, num_vertices, 1]) 105 | :param k: int 106 | :return: torch.Size([batch_size, num_dims, num_vertices, k]) 107 | """ 108 | x = x.repeat(1, 1, 1, k) 109 | return x 110 | 111 | 112 | class Transformation(nn.Module): 113 | def __init__(self, k=3): 114 | super(Transformation, self).__init__() 115 | self.k = k 116 | # Task 2.2.1 T-Net architecture 117 | 118 | # self.convs consists of 3 convolution layers. 119 | # please look at the description above. 120 | 121 | self.convs = Seq(*[Conv1dLayer([self.k, 64], act='relu', norm=True, bias=True), Conv1dLayer( 122 | [64, 128], act='relu', norm=True, bias=True), Conv1dLayer([128, 1024], act=None, norm=False, bias=True)]) 123 | self.fcs = Seq(*[Conv1dLayer([1024, 512], act='relu', norm=True, bias=True), Conv1dLayer([512, 256], act='relu', norm=True, 124 | bias=True), Conv1dLayer([256, self.k*self.k], act=None, norm=False, bias=True)]) # no relu or BN at the last layer. 125 | 126 | def forward(self, x): 127 | # Forward of T-Net architecture 128 | 129 | B, K, N = x.shape # batch-size, dim, number of points 130 | ## forward of shared mlp 131 | # input - B x K x N 132 | # output - B x 1024 x N 133 | 134 | x = self.convs(x) 135 | 136 | ## global max pooling 137 | # input - B x 1024 x N 138 | # output - B x 1024 139 | 140 | x, _ = torch.max(x, 2, keepdim=True) 141 | # print(x.size()) 142 | 143 | ## mlp 144 | # input - B x 1024 145 | # output - B x (K*K) 146 | 147 | x = self.fcs(x) 148 | 149 | ## reshape the transformation matrix to B x K x K 150 | identity = torch.eye(self.k, device=x.device) 151 | x = x.view(B, self.k, self.k) + identity[None] 152 | return x 153 | 154 | 155 | def stn(x, transform_matrix=None): 156 | # spatial transformation network. This is the matrix multiplication part inside the joint alignment network.
157 | x = x.transpose(2, 1) 158 | x = torch.bmm(x, transform_matrix) 159 | x = x.transpose(2, 1) 160 | return x 161 | 162 | 163 | class OrthoLoss(nn.Module): 164 | def __init__(self): 165 | super(OrthoLoss, self).__init__() 166 | 167 | def forward(self, x): 168 | ## hint: useful function `torch.bmm` or `torch.matmul` 169 | 170 | ## TASK 2.2.2 171 | ## compute the matrix product 172 | # print(x.size(),torch.transpose(x,1,2).size()) 173 | prod = torch.bmm(x, torch.transpose(x, 1, 2)) 174 | 175 | prod = torch.stack([torch.eye(prod.size()[1]) for ii in range( 176 | prod.size()[0])]).to(x.device) - prod # minus 177 | norm = torch.norm(prod, 'fro')**2 178 | return norm 179 | 180 | 181 | class PointNet(nn.Module): 182 | def __init__(self, num_classes=40, alignment=False): 183 | super(PointNet, self).__init__() 184 | # look at the description under 2.2 or refer to the paper if you need more details 185 | 186 | self.alignment = alignment 187 | 188 | ## `input_transform` calculates the input transform matrix of size `3 x 3` 189 | if self.alignment: 190 | self.input_transform = Transformation(3) 191 | 192 | ## TASK 2.3.1 193 | ## define your network layers here 194 | ## local feature 195 | ## one shared mlp layer (shared MLP is actually 1x1 convolution. You can use our Conv1dLayer) 196 | ## input size: B x 3 x N 197 | ## output size: B x 64 x N 198 | 199 | self.conv1 = Conv1dLayer([3, 64], act='relu', norm=True, bias=True) 200 | 201 | ## `feature_transform` calculates the feature transform matrix of size `64 x 64` 202 | if self.alignment: 203 | ## TASK 2.3.2 transformation layer 204 | self.feature_transform = Transformation(64) 205 | 206 | ## TASK 2.3.3 207 | ## define your network layers here 208 | ## global feature 209 | ## 2 layers of shared mlp. 64 -> 128 -> 1024 210 | ## input size: B x 64 x N 211 | ## output size: B x 1024 x N 212 | 213 | self.conv2s = Conv1dLayer( 214 | [64, 128, 1024], act='relu', norm=True, bias=True) 215 | 216 | # Task 2.3.4 classification layer 217 | # 3 MLP layers. 1024 -> 512 -> 256 -> num_classes. 218 | # there is a dropout in the second layer. dropout ratio = 0.5 219 | # no relu or BN at the last layer. 220 | # self.classifier = MLP([1024, 512, 256, num_classes], 221 | # act='relu', norm=True, bias=True, dropout=0.5) 222 | self.classifier = MLP([1024, 512], 223 | act='relu', norm=True, bias=True, dropout=0.5) 224 | 225 | def forward(self, x): 226 | 227 | ## task 2.3.5 apply the input transform in the coordinate domain 228 | if self.alignment: 229 | # get transformation matrix then apply to x 230 | transform = self.input_transform(x) 231 | # apply the transform to the input feature x 232 | x = torch.bmm(transform, x) 233 | 234 | ## forward of shared mlp 235 | # input - B x K x N 236 | # output - B x 64 x N 237 | x = self.conv1(x) 238 | 239 | ## task 2.3.7 another transform in the feature domain 240 | if self.alignment: 241 | transform = self.feature_transform(x) 242 | x = torch.bmm(transform, x) 243 | else: 244 | transform = None 245 | # local_feature = x # this can be used in the segmentation task; we comment it out here.
246 | 247 | ## TASK 2.3.8 248 | ## forward of shared mlp 249 | # input - B x 64 x N 250 | # output - B x 1024 x N 251 | x = self.conv2s(x) 252 | 253 | ## global max pooling 254 | # input - B x 1024 x N 255 | # output - B x 1024 256 | x = torch.max(x, dim=2, keepdim=True)[0] 257 | global_feature = x.view(-1, 1024) 258 | 259 | ## summary: 260 | ## global_feature: B x 1024 261 | ## local_feature: B x 64 x N 262 | ## transform: B x K x K 263 | 264 | # 2.3.10 classification 265 | out = self.classifier(global_feature) 266 | return out, global_feature, transform 267 | 268 | 269 | class Conv2dLayer(Seq): 270 | def __init__(self, channels, act='relu', norm=True, bias=False, kernel_size=1, stride=1, dilation=1, drop=0., groups=1): 271 | m = [] 272 | for i in range(1, len(channels)): 273 | m.append(nn.Conv2d(channels[i - 1], channels[i], bias=bias, 274 | kernel_size=kernel_size, stride=stride, dilation=dilation, groups=groups)) 275 | if norm: 276 | m.append(nn.BatchNorm2d(channels[i])) 277 | if act: 278 | m.append(act_layer(act)) 279 | if drop > 0: 280 | m.append(nn.Dropout2d(drop)) 281 | super(Conv2dLayer, self).__init__(*m) 282 | 283 | 284 | class EdgeConv2d(nn.Module): 285 | """ 286 | Static EdgeConv graph convolution layer (with activation, batch normalization) for point cloud [B, C, N, 1]. 287 | This operation performs the EdgeConv given the knn idx. 288 | input: B, C, N, 1 289 | return: B, C, N, 1 290 | """ 291 | 292 | def __init__(self, in_channels, out_channels, act='leakyrelu', norm=True, bias=False, aggr='max', groups=1): 293 | super(EdgeConv2d, self).__init__() 294 | self.nn = Conv2dLayer([in_channels * 2, out_channels], 295 | act, norm, bias, groups=groups) 296 | if aggr == 'mean': 297 | self.aggr = torch.mean 298 | else: 299 | self.aggr = torch.max 300 | 301 | def forward(self, x, edge_index): 302 | # TASK3.3: Write the forward pass of EdgeConv. 303 | # use x_j to indicate neighbor features. 304 | x_j = batched_index_select(x, edge_index) 305 | # use x_i to indicate center features. 306 | x_i = get_center_feature(x, edge_index.size()[-1]) 307 | x = self.aggr( 308 | self.nn(torch.cat([x_i, x_i-x_j], dim=1)), dim=3, keepdim=True)[0] 309 | return x 310 | 311 | 312 | class DynEdgeConv2d(EdgeConv2d): 313 | """ 314 | Dynamic EdgeConv graph convolution layer (with activation, batch normalization) for point cloud [B, C, N, 1] 315 | This operation first builds the knn graph, then performs the static EdgeConv 316 | input: B, C, N, 1 317 | return: B, C, N, 1 318 | """ 319 | 320 | def __init__(self, in_channels, out_channels, k=9, act='relu', 321 | norm=True, bias=False, aggr='max'): 322 | super(DynEdgeConv2d, self).__init__(in_channels, 323 | out_channels, act=act, norm=norm, bias=bias, aggr=aggr) 324 | self.k = k 325 | 326 | def forward(self, x): 327 | idx = knn(x, self.k) 328 | x = super(DynEdgeConv2d, self).forward(x, idx) 329 | return x 330 | # 331 | 332 | 333 | class SimpleDGCNN(nn.Module): 334 | def __init__(self, num_classes=40, k=9): 335 | super(SimpleDGCNN, self).__init__() 336 | self.k = k 337 | 338 | # Look at the PointNet backbone. 339 | # There are conv1d layers: 3 --> 64 --> 128 --> 1024. 340 | # Then an MLP classifier. 341 | 342 | # Here we keep the classifier part the same but change the backbone into dynamic EdgeConv. 343 | # k=9, use relu and batch normalization. Keep the default for other parameters.
344 | self.convs = Seq(*[DynEdgeConv2d(3, 64, k=self.k), DynEdgeConv2d(64, 345 | 128, k=self.k), DynEdgeConv2d(128, 1024, k=self.k)]) 346 | # self.classifier = Seq(*[MLP([1024, 512, 256], act='relu', norm=True, bias=True, dropout=0.5), 347 | # MLP([256, num_classes], act=None, norm=False, bias=True, dropout=0)]) 348 | self.classifier = Seq(*MLP([1024, 512], act='relu', norm=True, bias=True, dropout=0)) 349 | 350 | def forward(self, x): 351 | # x should be [B, C, N, 1] 352 | if len(x.shape) < 4: 353 | x = x.unsqueeze(-1) 354 | 355 | # dynamic EdgeConv layers 356 | x = self.convs(x) 357 | 358 | # max pooling layer 359 | x = torch.max(x, dim=2, keepdim=True)[0] 360 | global_feature = x.view(-1, 1024) 361 | out = self.classifier(global_feature) 362 | return out, global_feature, None 363 | 364 | 365 | def load_point_ckpt(model, network_name, ckpt_dir='./checkpoint', verbose=True): 366 | # ------------------ load ckpt 367 | filename = '{}/{}_model.pth'.format(ckpt_dir, network_name) 368 | if not os.path.exists(filename): 369 | print("No such checkpoint file as: {}".format(filename)) 370 | return None 371 | state = torch.load(filename) 372 | state['state_dict'] = {k: v.cuda() for k, v in state['state_dict'].items()} 373 | model.load_state_dict(state['state_dict'], strict=False) 374 | # optimizer.load_state_dict(state['optimizer_state_dict']) 375 | # scheduler.load_state_dict(state['scheduler_state_dict']) 376 | if verbose: 377 | print('Successfully loaded model from {}'.format(filename)) --------------------------------------------------------------------------------
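As a quick sanity check of the point backbones defined in render/blocks.py, the sketch below runs a random batch through PointNet and SimpleDGCNN and inspects the returned shapes. It is an illustrative example, not part of the repository, and assumes the repository root is on PYTHONPATH so that render.blocks (and its pytorch3d dependencies pulled in by render/__init__.py) can be imported. Both backbones return a 512-dimensional output (their classifier MLPs are truncated at 1024 -> 512), the 1024-dimensional global feature, and an optional alignment transform.

    import torch
    from render.blocks import PointNet, SimpleDGCNN

    points = torch.rand(4, 3, 1024)            # B x 3 x N dummy point clouds

    pointnet = PointNet(alignment=False).eval()
    dgcnn = SimpleDGCNN(k=9).eval()            # builds a knn graph (k=9) before each EdgeConv

    with torch.no_grad():
        out_pn, global_pn, transform_pn = pointnet(points)  # (4, 512), (4, 1024), None
        out_dg, global_dg, transform_dg = dgcnn(points)     # (4, 512), (4, 1024), None

    print(out_pn.shape, global_pn.shape)       # torch.Size([4, 512]) torch.Size([4, 1024])
    print(out_dg.shape, global_dg.shape)       # torch.Size([4, 512]) torch.Size([4, 1024])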