├── __init__.py
├── utils
│   ├── __init__.py
│   ├── pcp_name_filter.py
│   ├── training_utils.py
│   ├── o3d_draw.py
│   └── preprocess_pcp_knn_patches.py
├── transplant_attn
│   ├── .__init__.py
│   ├── transformer_from_torch.py
│   └── MultiheadAttention_from_torch.py
├── .gitignore
├── model.py
├── metrics.py
├── PcpKnnPatchesDataset.py
├── test_vis_attn_map_3d.py
├── README.md
├── test.py
└── train.py

/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/transplant_attn/.__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | .idea/
3 | 
4 | __pycache__/
5 | 
6 | dataset_dir/
7 | 
8 | debug_output/
9 | 
10 | logs/
11 | 
12 | paper_ckpts
13 | 
--------------------------------------------------------------------------------
/utils/pcp_name_filter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | 
4 | def get_pt_clouds_path(datafolder_read, obj_names, noise_type, noise_intensity='none'):
5 |     if noise_type == 'none':
6 |         result = [os.path.join(datafolder_read, name) for name in obj_names]
7 |     elif noise_type == 'gradient':
8 |         result = [os.path.join(datafolder_read, name + '_ddist_minmax') for name in obj_names]
9 |     elif noise_type == 'striped':
10 |         result = [os.path.join(datafolder_read, name + '_ddist_minmax_layers') for name in obj_names]
11 |     elif noise_type == 'white':
12 |         result = [os.path.join(datafolder_read, name + '_noise_white_' + '{0:.2e}'.format(float(noise_intensity))) for name in obj_names]
13 |     elif noise_type == 'brown':
14 |         result = [os.path.join(datafolder_read, name + '_noise_brown_' + '{0:.2e}'.format(float(noise_intensity))) for name in obj_names]
15 |     else:
16 |         print('noise type not implemented. 
exit now') 17 | exit() 18 | return result 19 | -------------------------------------------------------------------------------- /utils/training_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import numpy as np 4 | import random 5 | import torch 6 | 7 | 8 | def set_randomness(args): 9 | if args.truerand is False: 10 | random.seed(args.randseed) 11 | np.random.seed(args.randseed) 12 | torch.manual_seed(args.randseed) 13 | torch.backends.cudnn.deterministic = True 14 | torch.backends.cudnn.benchmark = False 15 | 16 | 17 | def save_checkpoint(epoch, train_loss, test_loss, model, optimizer, path, modelnet='checkpoint'): 18 | savepath = path + '/%s.pth' % modelnet 19 | state = { 20 | 'epoch': epoch, 21 | 'train_loss': train_loss, 22 | 'eval_loss': test_loss, 23 | 'model_state_dict': model.state_dict(), 24 | 'optimizer_state_dict': optimizer.state_dict(), 25 | } 26 | torch.save(state, savepath) 27 | 28 | 29 | def load_ckpt_to_net(ckpt_path, net): 30 | # ckpt = torch.load(ckpt_path) 31 | ckpt = torch.load(ckpt_path, map_location='cuda:0') 32 | weights = ckpt['model_state_dict'] 33 | net.load_state_dict(weights) 34 | return net 35 | -------------------------------------------------------------------------------- /utils/o3d_draw.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import numpy as np 3 | 4 | 5 | def draw_object(pts): 6 | """ 7 | :param pts: (N, 3) 8 | """ 9 | pcd = o3d.geometry.PointCloud() 10 | pcd.points = o3d.utility.Vector3dVector(pts) 11 | o3d.visualization.draw_geometries([pcd]) 12 | 13 | 14 | def draw_object_and_color(pts, colors): 15 | """ 16 | :param pts: (N, 3) 17 | :param colors: (N, 3) 18 | """ 19 | pcd = o3d.geometry.PointCloud() 20 | pcd.points = o3d.utility.Vector3dVector(pts) 21 | pcd.colors = o3d.utility.Vector3dVector(colors) 22 | o3d.visualization.draw_geometries([pcd]) 23 | 24 | 25 | def draw_object_and_normal(pts, normals): 26 | """ 27 | :param pts: (N, 3) 28 | :param normals: (N, 3) 29 | """ 30 | pcd = o3d.geometry.PointCloud() 31 | pcd.points = o3d.utility.Vector3dVector(pts) 32 | pcd.normals = o3d.utility.Vector3dVector(normals) 33 | o3d.visualization.draw_geometries([pcd]) 34 | 35 | 36 | def draw_object_and_color_and_normal(pts, colors, normals): 37 | """ 38 | :param pts: (N, 3) 39 | :param colors: (N, 3) 40 | :param normals: (N, 3) 41 | """ 42 | pcd = o3d.geometry.PointCloud() 43 | pcd.points = o3d.utility.Vector3dVector(pts) 44 | pcd.colors = o3d.utility.Vector3dVector(colors) 45 | pcd.normals = o3d.utility.Vector3dVector(normals) 46 | o3d.visualization.draw_geometries([pcd]) 47 | 48 | 49 | def draw_two_object(pts0, pts1, offset=np.array([2.5, 0, 0])): 50 | """ 51 | :param pts0/1: (N, 3) 52 | :param colors0/1: (N, 3) 53 | :param offset: (1, 3) display the second point cloud by the side. Default to the right. 54 | """ 55 | pcd0 = o3d.geometry.PointCloud() 56 | pcd0.points = o3d.utility.Vector3dVector(pts0) 57 | 58 | pcd1 = o3d.geometry.PointCloud() 59 | pcd1.points = o3d.utility.Vector3dVector(pts1 + offset) 60 | o3d.visualization.draw_geometries([pcd0, pcd1]) 61 | 62 | 63 | def draw_two_object_and_color(pts0, pts1, colors0, colors1, offset=np.array([2.5, 0, 0])): 64 | """ 65 | :param pts0/1: (N, 3) 66 | :param colors0/1: (N, 3) 67 | :param offset: (1, 3) display the second point cloud by the side. Default to the right. 
68 | """ 69 | pcd0 = o3d.geometry.PointCloud() 70 | pcd0.points = o3d.utility.Vector3dVector(pts0) 71 | pcd0.colors = o3d.utility.Vector3dVector(colors0) 72 | 73 | pcd1 = o3d.geometry.PointCloud() 74 | pcd1.points = o3d.utility.Vector3dVector(pts1 + offset) 75 | pcd1.colors = o3d.utility.Vector3dVector(colors1) 76 | o3d.visualization.draw_geometries([pcd0, pcd1]) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | # We copied the TransformerEncoder, TransformerEncoderLayer and MultiheadAttention code from pytorch 1.3 code base 9 | # so we can run with PyTorch >= 1.1. 10 | from transplant_attn.transformer_from_torch import TransformerEncoder, TransformerEncoderLayer 11 | from transplant_attn.MultiheadAttention_from_torch import MultiheadAttention 12 | 13 | 14 | def run_attn(attn, x, use_ffn): 15 | """ 16 | :param attn: Attention functions, currently support nn.Multihead and TransformerEncoder 17 | :param x: Input embeddings in shape (B, K, N, C), C denotes the No. of channels for each point, e.g. 512. 18 | :param use_ffn: if choose use_ffn, the input is only x 19 | :return: Soft attn output () and weights. The returned weights is None if use TransformerEncoder 20 | """ 21 | attn_out_list = [] 22 | weights = None 23 | B = x.shape[0] 24 | 25 | if use_ffn: 26 | for b in range(B): 27 | attn_out, weights = attn(x[b]) # x: (K, N, 512), attn_out: (K, N, 512), weights: (N, K, K) 28 | attn_out_list.append(attn_out) 29 | else: 30 | for b in range(B): 31 | attn_out, weights = attn(x[b], x[b], x[b]) # x: (K, N, 512), attn_out: (K, N, 512), weights: (N, K, K) 32 | attn_out_list.append(attn_out) 33 | 34 | # The weights are only for debug and visualisation. 35 | # We just return the (N, K) matrix, not the full (N, K, K) tensor. 
36 | if weights is not None: 37 | weights = weights[:, 0, :] 38 | 39 | x = torch.stack(attn_out_list) 40 | return x, weights 41 | 42 | 43 | class NINormalNet(nn.Module): 44 | def __init__(self): 45 | super(NINormalNet, self).__init__() 46 | 47 | self.conv0 = nn.Conv2d(3, 64, kernel_size=(1, 1)) 48 | self.conv1 = nn.Conv2d(64, 256, kernel_size=(1, 1)) 49 | self.conv2 = nn.Conv2d(256, 512, kernel_size=(1, 1)) 50 | 51 | self.bn0 = nn.BatchNorm2d(64) 52 | self.bn1 = nn.BatchNorm2d(256) 53 | self.bn2 = nn.BatchNorm2d(512) 54 | 55 | encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048, dropout=0.0) 56 | self.attn = TransformerEncoder(encoder_layer, num_layers=1) 57 | 58 | self.fc0 = nn.Conv1d(512, 256, kernel_size=1) 59 | self.fc1 = nn.Conv1d(256, 64, kernel_size=1) 60 | self.fc2 = nn.Conv1d(64, 3, kernel_size=1) 61 | 62 | self.bn_fc0 = nn.BatchNorm1d(256) 63 | self.bn_fc1 = nn.BatchNorm1d(64) 64 | 65 | # the temperature is just a scalar learnable that controls the softmax strength 66 | self.temp = nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=True) # (1, ) 67 | 68 | def forward(self, pts): 69 | """ 70 | :param pts: (B, K, N, 3) input points 71 | :return: (B, N, 3) normals 72 | """ 73 | x = pts.permute(0, 3, 1, 2) # (B, 3, K, N) 74 | x = F.relu(self.bn0(self.conv0(x))) # (B, C, K, N) 75 | x = F.relu(self.bn1(self.conv1(x))) # (B, C, K, N) 76 | x = F.relu(self.bn2(self.conv2(x))) # (B, C, K, N) 77 | 78 | # learn a temperature 79 | x = x / self.temp 80 | x, weights = run_attn(attn=self.attn, x=x.permute(0, 2, 3, 1), use_ffn=True) 81 | 82 | x, _ = torch.max(x, dim=1) # (B, K, N, 512) -> (B, N, 512) 83 | x = x.transpose(1, 2) # (B, C, N) 84 | x = F.relu(self.bn_fc0(self.fc0(x))) # (B, C, N) 85 | x = F.relu(self.bn_fc1(self.fc1(x))) # (B, C, N) 86 | x = self.fc2(x) # (B, 3, N) 87 | x = x.transpose(1, 2) # (B, N, 3) 88 | 89 | # normalise all normal predictions to unit length, we only care about angle in normal estimation task. 90 | x = F.normalize(x, dim=2) 91 | 92 | return x, weights 93 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | 6 | def comp_ang(pred_n, gt_n): 7 | """ 8 | :param pred_n: (N, 3) 9 | :param gt_n: (N, 3) 10 | :return: a scalar, average angle between predicted normals and gt normals 11 | """ 12 | # for un-orient normal vector, it's fine if it's flipped, cos(theta) = -1 means correct normal. 13 | # clamp() because sometime cosine_similarity generate value slightly larger than 1. 14 | cos_dists = torch.abs(torch.nn.functional.cosine_similarity(pred_n, gt_n, dim=1, eps=1e-8)).clamp(-1.0, +1.0) 15 | angles_rad = torch.acos(cos_dists) 16 | angles_deg = 180 * angles_rad / math.pi 17 | avg_angles = torch.mean(angles_deg) 18 | return avg_angles 19 | 20 | 21 | def comp_ang_batch(pred_n, gt_n): 22 | """ 23 | :param pred_n: (B, N, 3) 24 | :param gt_n: (B, N, 3) 25 | :return: a scalar, average angle between predicted normals and gt normals 26 | """ 27 | # for un-orient normal vector, it's fine if it's flipped, cos(theta) = -1 means correct normal. 28 | # clamp() because sometime cosine_similarity generate value slightly larger than 1. 
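    # Without the clamp, a value such as 1.0000001 would make torch.acos() return NaN,
    # which then propagates through the mean over the whole batch.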
29 | cos_dists = torch.abs(torch.nn.functional.cosine_similarity(pred_n, gt_n, dim=2, eps=1e-8)).clamp(-1.0, +1.0) 30 | angles_rad = torch.acos(cos_dists) 31 | angles_deg = 180 * angles_rad / float(math.pi) 32 | avg_angles = torch.mean(angles_deg) # this has already divided by B and N 33 | 34 | return avg_angles 35 | 36 | 37 | def comp_pgp(pred_n, gt_n): 38 | """ 39 | :param pred_n: (N, 3) 40 | :param gt_n: (N, 3) 41 | :return: number of normals within certain degree thresholds 42 | """ 43 | cos_dists = torch.abs(torch.nn.functional.cosine_similarity(pred_n, gt_n, dim=1, eps=1e-8)).clamp(-1.0, +1.0) 44 | angles_rad = torch.acos(cos_dists) # always produce angle between (0, pi) 45 | angles_deg = 180 * angles_rad / math.pi 46 | angles_deg = torch.where(angles_deg > 90, 180 - angles_deg, angles_deg) # convert 150 degree -> 30 degree 47 | 48 | pgp003 = torch.sum(angles_deg <= 3) 49 | pgp005 = torch.sum(angles_deg <= 5) 50 | pgp010 = torch.sum(angles_deg <= 10) 51 | pgp030 = torch.sum(angles_deg <= 30) 52 | pgp060 = torch.sum(angles_deg <= 60) 53 | pgp090 = torch.sum(angles_deg <= 90) 54 | 55 | return { 56 | 'pgp003': pgp003, 57 | 'pgp005': pgp005, 58 | 'pgp010': pgp010, 59 | 'pgp030': pgp030, 60 | 'pgp060': pgp060, 61 | 'pgp090': pgp090, 62 | } 63 | 64 | 65 | def comp_pgp_batch_unori(pred_n, gt_n): 66 | """ 67 | :param pred_n: (B, N, 3) 68 | :param gt_n: (B, N, 3) 69 | :return: 70 | """ 71 | cos_dists = torch.abs(torch.nn.functional.cosine_similarity(pred_n, gt_n, dim=2, eps=1e-8)).clamp(-1.0, +1.0) 72 | angles_rad = torch.acos(cos_dists) # always produce angle between (0, pi) 73 | angles_deg = 180 * angles_rad / math.pi 74 | angles_deg = torch.where(angles_deg > 90, 180 - angles_deg, angles_deg) # convert 150 degree -> 30 degree 75 | 76 | pgp003 = torch.sum(angles_deg <= 3) 77 | pgp005 = torch.sum(angles_deg <= 5) 78 | pgp010 = torch.sum(angles_deg <= 10) 79 | pgp030 = torch.sum(angles_deg <= 30) 80 | pgp060 = torch.sum(angles_deg <= 60) 81 | pgp080 = torch.sum(angles_deg <= 80) 82 | pgp090 = torch.sum(angles_deg <= 90) 83 | 84 | return { 85 | 'pgp003': pgp003, 86 | 'pgp005': pgp005, 87 | 'pgp010': pgp010, 88 | 'pgp030': pgp030, 89 | 'pgp060': pgp060, 90 | 'pgp080': pgp080, 91 | 'pgp090': pgp090, 92 | } 93 | 94 | 95 | def comp_rms_angle_batch(pred_n, gt_n): 96 | """ 97 | :param pred_n: (B, N, 3) 98 | :param gt_n: (B, N, 3) 99 | :return: 100 | """ 101 | cos_dists = torch.abs(torch.nn.functional.cosine_similarity(pred_n, gt_n, dim=2, eps=1e-8)).clamp(-1.0, +1.0) 102 | angles_rad = torch.acos(cos_dists) # always produce angle between (0, pi) 103 | angles_deg = 180 * angles_rad / math.pi 104 | angles_deg = torch.where(angles_deg > 90, 180 - angles_deg, angles_deg) # convert 150 degree -> 30 degree 105 | 106 | rms_angle_error = torch.sqrt(angles_deg.pow(2).sum() / (angles_deg.view(-1).shape[0])) 107 | return rms_angle_error.item() 108 | -------------------------------------------------------------------------------- /PcpKnnPatchesDataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | 5 | sys.path.append(os.path.join(sys.path[0], '..')) 6 | # from utils import o3d_draw 7 | import torch 8 | import torch.utils.data 9 | import numpy as np 10 | import h5py 11 | from tqdm import tqdm 12 | 13 | 14 | class PcpKnnPatchesDataset(torch.utils.data.Dataset): 15 | def __init__(self, datafolder, dataset_name, dataset_type, fastdebug, noise_level): 16 | """ 17 | :param datafolder: path contains h5 file 18 | :param dataset_name: 
xxx_patch_k_xx.h5 19 | :param dataset_type: 'train', 'test', 'eval' 20 | :param fastdebug: True/False to load only a small portion for fast debugging 21 | :param noise_level: 'none', '0.01', '0.05', '0.1', 'all' 22 | 23 | The h5 file has many groups index with a 10 digit patch_id, each group (patch) has: 24 | - 'pts': (N, 3) np.float64 25 | - 'normals': (N, 3) np.float64 26 | - 'knn_pt_list': (N, K, 3) np.float32 27 | 28 | Group name should be object names like in '*_set.txt' in the official PCPNet dataset files. 29 | """ 30 | self.datafolder = datafolder 31 | self.dataset_name = dataset_name 32 | self.dataset_type = dataset_type 33 | self.fastdebug = fastdebug 34 | self.noise_level = noise_level 35 | 36 | if dataset_type == 'train': 37 | self.obj_names = np.genfromtxt(os.path.join(datafolder, 'trainingset_no_noise.txt'), dtype='str') 38 | elif dataset_type == 'test': 39 | self.obj_names = np.genfromtxt(os.path.join(datafolder, 'testset_no_noise.txt'), dtype='str') 40 | elif dataset_type == 'eval': 41 | self.obj_names = np.genfromtxt(os.path.join(datafolder, 'validationset_no_noise.txt'), dtype='str') 42 | 43 | # We are slight abusing the 'patch' word here, this patch only denotes 2k points and their neighbours in 44 | # pre-processing context, this is not the local knn patch that we use to estimate a normal. 45 | 46 | # 100k each, 2k pts per patch, we have 100k/2k = 50 patches per objects and we have 4 noise levels. 47 | if dataset_type == 'train': 48 | # 8 obj * 50 patches (* 4 levels) 49 | self.patches_per_noise_level = 400 50 | elif dataset_type == 'eval': 51 | # 3 obj * 50 patches (* 4 levels) 52 | self.patches_per_noise_level = 150 53 | elif dataset_type == 'test': 54 | # 19 obj * 50 patches (* 4 levels) 55 | self.patches_per_noise_level = 950 56 | 57 | self.h5f_path = os.path.join(self.datafolder, self.dataset_name) 58 | 59 | if dataset_type not in self.dataset_name: 60 | print('dataset name does not match dataset type. exit.') 61 | exit() 62 | 63 | def __len__(self): 64 | if self.fastdebug: 65 | return 16 # 16 patches 66 | else: 67 | return self.patches_per_noise_level if self.noise_level != 'all' else self.patches_per_noise_level*4 68 | 69 | def remap_index_for_noise_level(self, index, noise_level): 70 | if noise_level == 'none' or noise_level == 'all': 71 | return index 72 | elif noise_level == '0.01': 73 | return index + self.patches_per_noise_level 74 | elif noise_level == '0.05': 75 | return index + self.patches_per_noise_level*2 76 | elif noise_level == '0.1': 77 | return index + self.patches_per_noise_level*3 78 | 79 | def __getitem__(self, item): 80 | item = self.remap_index_for_noise_level(item, noise_level=self.noise_level) # picking noise level in h5py file. 81 | with h5py.File(self.h5f_path, 'r', libver='latest') as h5f: 82 | patch = h5f[str(item).zfill(10)] 83 | pts = np.array(patch['pts'], dtype=np.float32) 84 | gt_normals = np.array(patch['normals'], dtype=np.float32) 85 | knn_pt_list = np.array(patch['knn_pt_list'], dtype=np.float32) 86 | 87 | # Slicing first few neighbours and shift them to the origin. 88 | # This is useful when use a dataset has more neighbours. 
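        # For example, an h5 file generated with K=50 can be reused for a K=25 experiment by
        # uncommenting the two lines below: the neighbour lists are truncated and re-centred
        # on the centroid of the points that remain.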
89 | # knn_pt_list = knn_pt_list[:, :25] # (N, 25, 3) 90 | # knn_pt_list = knn_pt_list - np.mean(knn_pt_list, axis=1, keepdims=True) 91 | 92 | data_input = { 93 | 'pts': pts, 94 | 'knn_pt_list': knn_pt_list, 95 | 'gt_normals': gt_normals 96 | } 97 | 98 | return data_input 99 | 100 | 101 | if __name__ == '__main__': 102 | dataset = PcpKnnPatchesDataset(datafolder='./dataset_dir/pcp_knn_patch_h5', 103 | dataset_type='train', 104 | dataset_name='train_patchsize_2000_k_25.h5', 105 | fastdebug=False, 106 | noise_level='none') 107 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0) 108 | 109 | 110 | for batch_id, data in enumerate(dataloader): 111 | print(batch_id) 112 | # if batch_id % 50 == 0: 113 | # pts = data['pts'] 114 | # gt_normals = data['gt_normals'] 115 | # o3d_draw.draw_object_and_normal(pts.squeeze(), gt_normals.squeeze()) 116 | knn_pt_list = data['knn_pt_list'] 117 | print(knn_pt_list) 118 | exit() -------------------------------------------------------------------------------- /test_vis_attn_map_3d.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import open3d as o3d 6 | import numpy as np 7 | import sklearn.neighbors 8 | import torch 9 | import torch.nn.parallel 10 | from tqdm import tqdm 11 | 12 | from utils.pcp_name_filter import get_pt_clouds_path 13 | from utils.training_utils import set_randomness, load_ckpt_to_net 14 | from metrics import comp_rms_angle_batch 15 | from model import NINormalNet 16 | 17 | 18 | def norm_pts_to_unit_sphere(pts): 19 | """ 20 | :param pts: N, 3 21 | :return: 22 | pts: N, 3 23 | radius: scalar 24 | """ 25 | pts_range = np.max(pts, axis=0) - np.min(pts, axis=0) 26 | max_range = np.max(pts_range) # this is the max diameter 27 | radius = max_range / 2.0 28 | pts /= radius 29 | return pts, radius 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser('NINormalNetTestVis3DPatch') 34 | parser.add_argument('--gpu_id', default=0, type=int) 35 | parser.add_argument('--multi_gpu', default=True, type=bool) 36 | 37 | parser.add_argument('--datafolder', type=str, default='./dataset_dir/PCPNet_official_dataset', 38 | help='folder contains h5 dataset') 39 | parser.add_argument('--num_neighbours', type=int, default=50) 40 | parser.add_argument('--datatype', type=str, default='test') 41 | 42 | parser.add_argument('--truerand', type=bool, default=False, help="whether we want true randomness") 43 | parser.add_argument('--randseed', default=20, help="set random seed for np, python, and torch") 44 | parser.add_argument('--fastdebug', default=False, action='store_true', help="debug with very small portion data") 45 | parser.add_argument('--ckpt_path', type=str, help="checkpoint_folder", 46 | default='./paper_ckpts/nb50') 47 | return parser.parse_args() 48 | 49 | 50 | def test_one_epoch(model): 51 | model.eval() 52 | 53 | K = args.num_neighbours 54 | noise_type = 'none' 55 | noise_intensity = '0.0' 56 | 57 | obj_names = None 58 | if args.datatype is 'train': 59 | obj_names = np.genfromtxt(os.path.join(args.datafolder, 'trainingset_no_noise.txt'), dtype='str') 60 | elif args.datatype is 'test': 61 | obj_names = np.genfromtxt(os.path.join(args.datafolder, 'testset_no_noise.txt'), dtype='str') 62 | elif args.datatype is 'eval': 63 | obj_names = np.genfromtxt(os.path.join(args.datafolder, 'validationset_no_noise.txt'), dtype='str') 64 | 65 | obj_paths = get_pt_clouds_path(args.datafolder, obj_names, noise_type=noise_type, 
noise_intensity=noise_intensity) 66 | obj_pts_files = [p + '.xyz' for p in obj_paths] 67 | obj_normals_files = [p + '.normals' for p in obj_paths] 68 | 69 | for i in range(len(obj_pts_files)): 70 | obj_name = obj_names[i] 71 | 72 | # Uncomment this to visualise a specific object. 73 | # if obj_name != 'netsuke100k': 74 | # continue 75 | 76 | print(noise_type, obj_name, obj_pts_files[i]) 77 | 78 | # shift the centre of the point cloud to origin, and normalise the entire point cloud to unit sphere 79 | pts = np.genfromtxt(obj_pts_files[i]) 80 | pts = pts - np.mean(pts, axis=0) 81 | pts, _ = norm_pts_to_unit_sphere(pts) 82 | normals = np.genfromtxt(obj_normals_files[i]) 83 | 84 | N = pts.shape[0] 85 | tree = sklearn.neighbors.KDTree(pts[:N], leaf_size=50) 86 | 87 | # this controls how many patches we would like to vis for each object. 88 | counter = 0 89 | max_count = 7 90 | 91 | for r in tqdm(range(N)): 92 | '''Estimate normal and attn weights for a patch''' 93 | # get neighbours and shift to the origin 94 | _, idx = tree.query(pts[r].reshape(1, 3), k=K) # the first one is itself 95 | knn = pts[idx.squeeze()] 96 | normal_gt = normals[idx.squeeze()][0] # (3, ) 97 | centroid = np.mean(knn, axis=0) 98 | knn_centred = knn - centroid 99 | 100 | knn_model_input = torch.from_numpy(knn_centred) # (K, 3) 101 | knn_model_input = knn_model_input.view(1, K, 1, 3) # (B, K, N, 3) 102 | knn_model_input = knn_model_input.cuda().float() 103 | 104 | # pred_normals: (B, N, 3), weights: (N, K) 105 | pred_normals, weights = model(knn_model_input) # (1, 1, 3), (1, 50) 106 | normal_gt = torch.from_numpy(normal_gt).float().view(1, 1, 3).cuda() 107 | rms_angle = comp_rms_angle_batch(pred_normals, normal_gt) # (1, 1) 108 | 109 | tqdm.write('Angle err {0:.4f}'.format(rms_angle)) 110 | 111 | '''This is vis the patch in the entire point cloud''' 112 | pcd_entire = o3d.geometry.PointCloud() 113 | color = np.zeros_like(pts) 114 | color[:, 1] = 0.3 # all points are green 115 | color[idx.squeeze(), 1] = 0 116 | color[idx.squeeze(), 0] = 1 # selected points are red 117 | color[r] = 0 118 | color[r, 2] = 1 # the point is blue 119 | 120 | pcd_entire.points = o3d.utility.Vector3dVector(pts) 121 | pcd_entire.colors = o3d.utility.Vector3dVector(color) 122 | pcd_entire.normals = o3d.utility.Vector3dVector(normals) 123 | 124 | o3d.visualization.draw_geometries([pcd_entire]) 125 | 126 | '''This is just vis the patch''' 127 | pcd_knn = o3d.geometry.PointCloud() 128 | color = np.zeros_like(knn_centred) 129 | 130 | # we need to amplify all attention weights, otherwise the colour-coding looks dull. 
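            # The K weights of a patch come out of a softmax, so they sum to 1; with K=50 a
            # near-uniform map has entries around 0.02, which is too dark to read as a red
            # channel, hence the rescaling by 1/max_weight (and the extra factor of 7 below).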
131 | max_weight = weights.max().item() 132 | enlarge_ratio = 1.0 / max_weight 133 | color[:, 0] = weights.squeeze().cpu().numpy() * enlarge_ratio * 7 # selected points are red 134 | color[0] = 0 135 | color[0, 2] = 1 # the point is blue 136 | 137 | pcd_knn.points = o3d.utility.Vector3dVector(knn_centred) 138 | pcd_knn.colors = o3d.utility.Vector3dVector(color) 139 | 140 | o3d.visualization.draw_geometries([pcd_knn]) 141 | 142 | if counter >= max_count: 143 | break 144 | counter += 1 145 | 146 | 147 | def main(args): 148 | '''Model Loading''' 149 | ckpt_file = os.path.join(args.ckpt_path, 'ni_normal_net.pth') 150 | model = NINormalNet() 151 | if args.multi_gpu: 152 | model = torch.nn.DataParallel(model).to(device='cuda:' + str(args.gpu_id)) 153 | else: 154 | model = model.to(device='cuda:'+str(args.gpu_id)) 155 | load_ckpt_to_net(ckpt_file, model) 156 | 157 | '''Testing''' 158 | test_one_epoch(model) 159 | 160 | 161 | if __name__ == '__main__': 162 | args = parse_args() 163 | set_randomness(args) 164 | with torch.no_grad(): 165 | main(args) 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neighbourhood-Insensitive Point Cloud Normal Estimation Network 2 | 3 | **[Project Page](http://ninormal.active.vision/) | 4 | [Paper](https://arxiv.org/abs/2008.09965) | 5 | [Video](https://youtu.be/gxBeR2LBB0k) | 6 | [Supp](http://www.robots.ox.ac.uk/~ryan/bmvc2020/0028_supp.pdf) | 7 | [Data](https://huggingface.co/datasets/active-vision-lab/NINormal/tree/main) | 8 | [Pretrained Models](https://huggingface.co/datasets/active-vision-lab/NINormal/tree/main)** 9 | 10 | Zirui Wang and [Victor Adrian Prisacariu](http://www.robots.ox.ac.uk/~victor/). Active Vision Lab, University of Oxford. 11 | BMVC 2020 (Oral Presentation). 12 | 13 | **Update 30 June 2025**: data and pretrained checkpoints are now available at HuggingFace ([link](https://huggingface.co/datasets/active-vision-lab/NINormal/tree/main)). 14 | 15 | ~**Update**: We use our university's OneDrive to store our pretrained models and the preprocessed dataset. The university just changed the access policy and this stops us sharing our data through a public link so the data and pretrained links above are broken. The easiest fix for now is **you can send your email address to ryan[AT]robots.ox.ac.uk and I'll share it through email**. We will try to find out a way to share it with a link properly later.~ 16 | 17 | ## Environment: 18 | ``` 19 | Python == 3.7 20 | PyTorch >= 1.1.0 21 | CUDA >= 9.0 22 | h5py == 2.10 23 | ``` 24 | 25 | We tried PyTorch 1.1/1.3/1.4/1.5 and CUDA 9.0/9.2/10.0/10.1/10.2. So the code should be able to run as long as you have a modern PyTorch and CUDA installed. 26 | 27 | Setup env: 28 | ``` 29 | conda create -n ninormal python=3.7 30 | conda activate ninormal 31 | conda install pytorch=1.1.0 torchvision cudatoolkit=9.0 -c pytorch # you can change the pytorch and cuda version here. 32 | conda install tqdm future 33 | conda install -c anaconda scikit-learn 34 | pip install h5py==2.10 35 | pip install tensorflow # this is the cpu version, we just need the tensorboard. 36 | ``` 37 | 38 | (Optional) Install `open3d` for visualisation. You might need a physical monitor to install this lib. 
39 | ``` 40 | conda install -c open3d-admin open3d 41 | ``` 42 | 43 | Clone this repo: 44 | ``` 45 | git clone https://github.com/ActiveVisionLab/NINormal.git 46 | ``` 47 | 48 | ## Dataset: 49 | We use the PCPNet dataset in our paper. The official PCPNet dataset is available at [here](https://geometry.cs.ucl.ac.uk/projects/2018/pcpnet/). We pre-processed the official PCPNet dataset using scikit-learn's KDTree and wrapped the processed knn point patches in h5 files. For the best training efficiency, we produce an h5 file for each K and each train/test/eval split. 50 | 51 | To simply reproduce our paper results with k=20 or k=50, we provide a subset of our full pre-processed dataset [here](https://unioxfordnexus-my.sharepoint.com/:u:/g/personal/lina3315_ox_ac_uk/Ech7GImZcnhLvawKcARkwCoBbMEa5_I6qLZRsQvkRYCztQ?e=T1sJP1). 52 | 53 | 54 | To fully reproduce our results from k=3 to k=50 (paper Fig. 2), the full pre-processed dataset is available [here](https://unioxfordnexus-my.sharepoint.com/:u:/g/personal/lina3315_ox_ac_uk/EQzIvFRy1PNOnB_aFo6qLQYBT7cr7hygZsom2a87wfukuQ?e=kMj1lK). 55 | 56 | ## Training 57 | Untar the dataset. 58 | ``` 59 | tar -xvf path/to/the/tar.gz 60 | ``` 61 | 62 | We train all our models (except k=40 and k=50) using 3 Nvidia 1080Ti GPUs. For k=40 and k=50, we use 3 Nvidia Titan-RTX GPUs. All models are trained with batch size 6. To reproduce our paper results, set the `--batchsize_train=6` and `--batchsize_eval=6`. Reduce the batch size when out of memory. 63 | 64 | Train with 20 neighbours: 65 | ``` 66 | python train.py \ 67 | --datafolder='path/to/the/folder/contains/h5/files' \ 68 | --batchsize_train=6 \ 69 | --batchsize_eval=6 70 | ``` 71 | 72 | Train with 50 neighbours: 73 | ``` 74 | python train.py \ 75 | --datafolder='path/to/the/folder/contains/h5/files' \ 76 | --train_dataset_name='train_patchsize_2000_k_50.h5' \ 77 | --eval_dataset_name='eval_patchsize_2000_k_50.h5' \ 78 | --batchsize_train=6 \ 79 | --batchsize_eval=6 80 | ``` 81 | 82 | #### Optional: use a symlink 83 | Alternatively, you can create a symlink that points to the downloaded dataset: 84 | ``` 85 | cd NINormal # our repo 86 | mkdir dataset_dir 87 | cd dataset_dir 88 | ln -s path/to/the/folder/contains/h5/files ./pcp_knn_patch_h5_files 89 | ``` 90 | 91 | and train with: 92 | ``` 93 | python train.py 94 | ``` 95 | 96 | ## Note on Batch Size 97 | 98 | The batch size 6 is the batch size that the `Conv2D()` function processes. Our network can be implemented using the `Conv1D()` or `Linear()` but we use the 1x1 `Conv2D()` along with our pre-processed dataset to achieve the best balance between data loading and training. When setting the batch size to 6, the actual batch size our network processes is 6 x 2000 = 12000, as mentioned at the end of Sec. 3 in our paper. The number 2000 is the number of knn patches we packed in a subgroup in an h5 file. See the pre-processing script in `./utils` and the `PcpKnnPatchesDataset` for more details. 99 | 100 | 101 | ## Pretrained Models 102 | Similar to the dataset, we provide a tar file that contains models trained with k=20 and k=50 [here](https://unioxfordnexus-my.sharepoint.com/:u:/g/personal/lina3315_ox_ac_uk/ETepIC914XVPnAbUm1BESTABZb3pOOeOU2JYLlnAWjxeeg?e=1hsGOe). 103 | 104 | To evaluate all models that we present in Fig.2 (k=3 to k=50), download all models [here](https://unioxfordnexus-my.sharepoint.com/:u:/g/personal/lina3315_ox_ac_uk/EZCD7wK19bVMvXK6NGxbPMoBlvaJ_GzOq1szOF4ay7PcDg?e=5Q7ngq). 105 | 106 | ## Testing 107 | Untar the downloaded checkpoints file. 
108 | ```
109 | tar -xvf path/to/the/ckpts/tar.gz
110 | ```
111 | 
112 | **IMPORTANT NOTE**: 
113 | The k for trained checkpoints and the k for a dataset must match, e.g. use the nb20 ckpt with the nb20 dataset:
114 | ```
115 | python test.py \
116 | --ckpt_path='/path/to/the/ckpts/nb_20' \
117 | --test_dataset_name='test_patchsize_2000_k_20.h5'
118 | ```
119 | 
120 | #### Optional: use a symlink
121 | As with the dataset, you can also create a symlink that points to the downloaded checkpoint folder:
122 | ```
123 | cd NINormal # our repo
124 | ln -s path/to/the/folder/just/extracted ./paper_ckpts
125 | ```
126 | 
127 | and run test with just:
128 | ```
129 | python test.py
130 | ```
131 | 
132 | ## Attention Weights Visualisation in 3D (paper Fig. 5)
133 | We recommend visualising attention weights using k=50 (with the model trained with k=50, of course) to see how our network pays extra attention to the boundary of a patch.
134 | 
135 | Install the open3d lib. You might need a PC with a physical monitor to install this library.
136 | ```
137 | conda install -c open3d-admin open3d
138 | ```
139 | 
140 | Similar to the testing procedure, after getting the dataset, run:
141 | ```
142 | python test_vis_attn_map_3d.py --ckpt_path='/path/to/the/ckpts/nb_50'
143 | ```
144 | 
145 | 
146 | ## ICP Iteration Experiment (paper Sec. 4.5)
147 | We aim to release it soon.
148 | 
149 | ## Acknowledgement
150 | The authors would like to thank
151 | [Min Chen](https://sites.google.com/site/drminchen/home),
152 | [Tengda Han](https://tengdahan.github.io/),
153 | [Shuda Li](https://lishuda.wordpress.com/),
154 | [Tim Yuqing Tang](https://scholar.google.co.uk/citations?user=kQB_dOoAAAAJ&hl=en) and
155 | [Shangzhe Wu](https://elliottwu.com/)
156 | for insightful discussions and proofreading.
157 | 158 | ## Citation 159 | ``` 160 | @inproceedings{wang2020ninormal, 161 | title={Neighbourhood-Insensitive Point Cloud Normal Estimation Network}, 162 | author={Wang, Zirui and Prisacariu, Victor Adrian}, 163 | booktitle={BMVC}, 164 | year={2020} 165 | } 166 | ``` 167 | -------------------------------------------------------------------------------- /utils/preprocess_pcp_knn_patches.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | 5 | import numpy as np 6 | import sklearn.neighbors 7 | import h5py 8 | 9 | sys.path.append(os.path.join(sys.path[0], '..')) 10 | 11 | from utils.pcp_name_filter import get_pt_clouds_path 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser('PCPNet Dataset KNN Preprocessing For Normal') 16 | parser.add_argument('--datafolder_read', type=str, default='./dataset_dir/PCPNet_official_dataset') 17 | parser.add_argument('--datafolder_write', type=str, default='./') 18 | parser.add_argument('--K', default=25, type=int) 19 | 20 | return parser.parse_args() 21 | 22 | 23 | def norm_pts_to_unit_sphere(pts): 24 | """ 25 | :param pts: (N, 3) 26 | :return: 27 | pts: (N, 3) 28 | radius: scalar 29 | """ 30 | pts_range = np.max(pts, axis=0) - np.min(pts, axis=0) 31 | max_range = np.max(pts_range) # this is the max diameter 32 | radius = max_range / 2.0 33 | pts /= radius 34 | return pts, radius 35 | 36 | 37 | def comp_knn_pts(pts, K): 38 | """ 39 | :param pts: (N, 3) np array 40 | :param K: python int 41 | :return: (N, K, 3) np array 42 | """ 43 | N = pts.shape[0] 44 | 45 | tree = sklearn.neighbors.KDTree(pts[:N], leaf_size=50) 46 | 47 | knn_pts_list = np.zeros((N, K, 3), dtype=np.float32) 48 | _, idx = tree.query(pts, k=K) # idx is (N, K) 49 | 50 | for r in range(N): 51 | knn_pts_list[r] = pts[idx[r]] 52 | 53 | centroids = np.mean(knn_pts_list, axis=1, keepdims=True) # (N, 1, 3) 54 | knn_pts_list = knn_pts_list - centroids # (N, K, 3) 55 | 56 | return knn_pts_list # (N, K, 3) 57 | 58 | 59 | def store_obj_pt_cloud(h5f, pts, normals, knn_pt_list, global_patch_counter, patch_size): 60 | N = pts.shape[0] # 100K points 61 | 62 | num_patches = int(N/patch_size) 63 | for i in range(num_patches): 64 | grp = h5f.create_group(str(global_patch_counter).zfill(10)) 65 | grp.create_dataset('pts', data=pts[i*patch_size:(i+1)*patch_size]) 66 | grp.create_dataset('normals', data=normals[i*patch_size:(i+1)*patch_size]) 67 | grp.create_dataset('knn_pt_list', data=knn_pt_list[i*patch_size:(i+1)*patch_size]) 68 | global_patch_counter += 1 69 | return global_patch_counter 70 | 71 | 72 | if __name__ == '__main__': 73 | args = parse_args() 74 | 75 | dataset_types = ['train', 'test', 'eval'] 76 | # noise_types = ['none', 'white', 'gradient', 'striped'] 77 | noise_types = ['none', 'white'] 78 | # noise_types = ['gradient', 'striped'] 79 | noise_levels = [0.01, 0.05, 0.1] 80 | 81 | # # Use this setting to debug the preprocessing code faster. 82 | # dataset_types = ['train'] 83 | # noise_types = ['none'] 84 | 85 | # We have 2000 points and their knn neighbours in a h5 sub-group. 
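    # A generated file can be sanity-checked with a few lines of h5py (a quick sketch, not
    # used by the pipeline; the filename assumes the default K=25 and the naming scheme below):
    #
    #     import h5py
    #     with h5py.File('train_patchsize_2000_k_25.h5', 'r') as f:
    #         grp = f['0000000000']                # first 2000-point sub-group
    #         print(grp['pts'].shape)              # (2000, 3)
    #         print(grp['normals'].shape)          # (2000, 3)
    #         print(grp['knn_pt_list'].shape)      # (2000, 25, 3)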
86 | patch_size = 2000 87 | 88 | for dataset_type in dataset_types: 89 | if dataset_type is 'train': 90 | obj_names = np.genfromtxt(os.path.join(args.datafolder_read, 'trainingset_no_noise.txt'), dtype='str') 91 | elif dataset_type is 'test': 92 | obj_names = np.genfromtxt(os.path.join(args.datafolder_read, 'testset_no_noise.txt'), dtype='str') 93 | elif dataset_type is 'eval': 94 | obj_names = np.genfromtxt(os.path.join(args.datafolder_read, 'validationset_no_noise.txt'), dtype='str') 95 | 96 | dataset_name = dataset_type + '_patchsize_' + str(patch_size) + '_k_' + str(args.K) + '.h5' 97 | if os.path.exists(os.path.join(args.datafolder_write, dataset_name)): 98 | print("have dataset file, exit to avoid overwriting.") 99 | exit() 100 | else: 101 | h5f = h5py.File(os.path.join(args.datafolder_write, dataset_name), 'w', libver='latest') 102 | 103 | global_patch_counter = 0 104 | 105 | for noise_type in noise_types: 106 | if noise_type == 'none': 107 | obj_paths = get_pt_clouds_path(args.datafolder_read, obj_names, noise_type=noise_type) 108 | 109 | obj_pts_files = [p + '.xyz' for p in obj_paths] 110 | obj_normals_files = [p + '.normals' for p in obj_paths] 111 | 112 | for i in range(len(obj_pts_files)): 113 | obj_name = obj_names[i] 114 | print(noise_type, obj_names[i], obj_pts_files[i]) 115 | 116 | # shift the centre of the point cloud to origin 117 | pts = np.genfromtxt(obj_pts_files[i]) 118 | pts = pts - np.mean(pts, axis=0) 119 | pts, _ = norm_pts_to_unit_sphere(pts) 120 | knn_pts_list = comp_knn_pts(pts, args.K) 121 | normals = np.genfromtxt(obj_normals_files[i]) 122 | 123 | global_patch_counter = store_obj_pt_cloud(h5f, pts, normals, knn_pts_list, global_patch_counter, patch_size) 124 | print("global counter: ", global_patch_counter) 125 | 126 | if noise_type == 'white' or noise_type == 'brown': 127 | for noise_level in noise_levels: 128 | obj_paths = get_pt_clouds_path(args.datafolder_read, obj_names, noise_type=noise_type, noise_intensity=noise_level) 129 | 130 | obj_pts_files = [p + '.xyz' for p in obj_paths] 131 | obj_normals_files = [p + '.normals' for p in obj_paths] 132 | 133 | for i in range(len(obj_pts_files)): 134 | obj_name = obj_names[i] 135 | print(noise_type, noise_level, obj_names[i], obj_pts_files[i]) 136 | 137 | # shift the centre of the point cloud to origin 138 | pts = np.genfromtxt(obj_pts_files[i]) 139 | pts = pts - np.mean(pts, axis=0) 140 | pts, _ = norm_pts_to_unit_sphere(pts) 141 | knn_pts_list = comp_knn_pts(pts, args.K) 142 | normals = np.genfromtxt(obj_normals_files[i]) 143 | 144 | global_patch_counter = store_obj_pt_cloud(h5f, pts, normals, knn_pts_list, global_patch_counter, patch_size) 145 | print("global counter: ", global_patch_counter) 146 | 147 | if noise_type == 'gradient' or noise_type == 'striped': 148 | obj_paths = get_pt_clouds_path(args.datafolder_read, obj_names, noise_type=noise_type) 149 | 150 | obj_pts_files = [p + '.xyz' for p in obj_paths] 151 | obj_normals_files = [p + '.normals' for p in obj_paths] 152 | 153 | # the analytic ones does not have gradient and striped version 154 | obj_pts_files = [f for f in obj_pts_files if 'analytic' not in f] 155 | obj_normals_files = [f for f in obj_normals_files if 'analytic' not in f] 156 | 157 | for i in range(len(obj_pts_files)): 158 | obj_name = obj_names[i] 159 | print(noise_type, obj_names[i], obj_pts_files[i]) 160 | 161 | # shift the centre of the point cloud to origin 162 | pts = np.genfromtxt(obj_pts_files[i]) 163 | pts = pts - np.mean(pts, axis=0) 164 | pts, _ = norm_pts_to_unit_sphere(pts) 
165 | knn_pts_list = comp_knn_pts(pts, args.K) 166 | normals = np.genfromtxt(obj_normals_files[i]) 167 | 168 | global_patch_counter = store_obj_pt_cloud(h5f, pts, normals, knn_pts_list, global_patch_counter, patch_size) 169 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import math 5 | from timeit import default_timer as timer 6 | from pathlib import Path 7 | import shutil 8 | import logging 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.parallel 13 | from torch.utils.data import DataLoader 14 | 15 | from PcpKnnPatchesDataset import PcpKnnPatchesDataset 16 | from utils.training_utils import set_randomness, load_ckpt_to_net 17 | import metrics 18 | from model import NINormalNet 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser('NINormalNetTest') 23 | parser.add_argument('--batchsize_test', type=int, default=10, 24 | help='batch size in testing, this number should be dividable by 50, e.g, 1, 2, 5, 10, 25') 25 | parser.add_argument('--gpu_id', default=0, type=int) 26 | parser.add_argument('--multi_gpu', default=True, type=bool) 27 | 28 | parser.add_argument('--datafolder', type=str, default='./dataset_dir/pcp_knn_patch_h5_files', 29 | help='folder contains h5 dataset') 30 | parser.add_argument('--datatype', type=str, default='test') 31 | parser.add_argument('--test_dataset_name', type=str, default='test_patchsize_2000_k_20.h5') 32 | parser.add_argument('--test_noise_level', type=str, default='none', choices=['none', '0.01', '0.05', '0.1', 'all']) 33 | 34 | parser.add_argument('--truerand', type=bool, default=False, help="whether we want true randomness") 35 | parser.add_argument('--randseed', default=20, help="set random seed for np, python, and torch") 36 | parser.add_argument('--fastdebug', default=False, action='store_true', help="debug with very small portion data") 37 | parser.add_argument('--ckpt_path', type=str, help="checkpoint folder", 38 | default='./paper_ckpts/nb20') 39 | 40 | parser.add_argument('--normal_out', type=bool, default=False, help="whether output estimated normals to files") 41 | return parser.parse_args() 42 | 43 | 44 | def test_one_epoch(test_dataloader, model, normal_out_path): 45 | model.eval() 46 | test_pgp003_epoch, test_pgp005_epoch, test_pgp010_epoch, test_pgp030_epoch, test_pgp060_epoch, test_pgp080_epoch, test_pgp090_epoch = 0, 0, 0, 0, 0, 0, 0 47 | test_pgp003_obj, test_pgp005_obj, test_pgp010_obj, test_pgp030_obj, test_pgp060_obj, test_pgp080_obj, test_pgp090_obj = 0, 0, 0, 0, 0, 0, 0 48 | obj_names = test_dataloader.dataset.obj_names 49 | obj_count = 0 50 | rms_angle_list_obj = [] 51 | rms_angle_list_patch = [] 52 | 53 | patch_count = 0 54 | 55 | # accumulate normal predictions for one object and write it to a txt file. 56 | pred_normals_one_obj = [] 57 | 58 | # profile inference time for each object, and compute an average time. 
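    # Each PCPNet object has 100k points stored as 50 sub-groups of 2000 points, so the
    # `patch_count == 50` check further down marks the end of one complete object; that is
    # also why the per-object pgp counters are divided by 100000 to obtain percentages.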
59 | object_time = 0 60 | time_list = [] 61 | 62 | print('Obj, pgp003, pgp005, pgp010, pgp030, pgp060, pgp080, pgp090, rms_angle, time') 63 | for batch_id, data in enumerate(test_dataloader): 64 | gt_normals = data['gt_normals'] # (B, N, 3) 65 | knn_pts = data['knn_pt_list'] # (B, N, K, 3) 66 | 67 | knn_pts = knn_pts.transpose(1, 2).to(device='cuda:' + str(args.gpu_id)) # (B, K, N, 3) 68 | gt_normals = gt_normals.to(device='cuda:' + str(args.gpu_id)) 69 | 70 | start = timer() 71 | pred_normals, weights = model(knn_pts) # weights: (N, K) 72 | end = timer() 73 | duration = end - start 74 | object_time += duration 75 | 76 | # for txt normal file writing 77 | pred_normals_one_obj.append(pred_normals.reshape(-1, 3)) 78 | 79 | pgp_ang_dic = metrics.comp_pgp_batch_unori(pred_normals, gt_normals) 80 | rms_angle = metrics.comp_rms_angle_batch(pred_normals, gt_normals) 81 | rms_angle_list_patch.append(rms_angle) 82 | 83 | test_pgp003_obj += pgp_ang_dic['pgp003'].item() 84 | test_pgp005_obj += pgp_ang_dic['pgp005'].item() 85 | test_pgp010_obj += pgp_ang_dic['pgp010'].item() 86 | test_pgp030_obj += pgp_ang_dic['pgp030'].item() 87 | test_pgp060_obj += pgp_ang_dic['pgp060'].item() 88 | test_pgp080_obj += pgp_ang_dic['pgp080'].item() 89 | test_pgp090_obj += pgp_ang_dic['pgp090'].item() 90 | 91 | patch_count += test_dataloader.batch_size 92 | if patch_count == 50: 93 | 94 | if args.normal_out: 95 | pred_normals_one_obj = torch.cat(pred_normals_one_obj, axis=0).cpu().numpy() 96 | np.savetxt(os.path.join(normal_out_path, obj_names[obj_count]+'.normals'), pred_normals_one_obj) 97 | pred_normals_one_obj = [] # clean it and prepare for the next object. 98 | 99 | rms_angle_obj = np.mean(rms_angle_list_patch) 100 | rms_angle_list_obj.append(rms_angle_obj) 101 | print('{0:25s}, {1:.2%}, {2:.2%}, {3:.2%}, {4:.2%}, {5:.2%}, {6:.2%}, {7:.2%}, {8:.2f}, {9:.2f}sec'.format(obj_names[obj_count], 102 | test_pgp003_obj / float(100000), 103 | test_pgp005_obj / float(100000), 104 | test_pgp010_obj / float(100000), 105 | test_pgp030_obj / float(100000), 106 | test_pgp060_obj / float(100000), 107 | test_pgp080_obj / float(100000), 108 | test_pgp090_obj / float(100000), 109 | rms_angle_obj, 110 | object_time)) 111 | test_pgp003_epoch += test_pgp003_obj 112 | test_pgp005_epoch += test_pgp005_obj 113 | test_pgp010_epoch += test_pgp010_obj 114 | test_pgp030_epoch += test_pgp030_obj 115 | test_pgp060_epoch += test_pgp060_obj 116 | test_pgp080_epoch += test_pgp080_obj 117 | test_pgp090_epoch += test_pgp090_obj 118 | 119 | test_pgp003_obj = 0 120 | test_pgp005_obj = 0 121 | test_pgp010_obj = 0 122 | test_pgp030_obj = 0 123 | test_pgp060_obj = 0 124 | test_pgp080_obj = 0 125 | test_pgp090_obj = 0 126 | 127 | obj_count += 1 128 | patch_count = 0 129 | rms_angle_list_patch = [] 130 | 131 | time_list.append(object_time) 132 | object_time = 0.0 133 | 134 | test_pgp003_epoch /= float(len(test_dataloader.dataset) * 2000) # 100k is the number of points in each object 135 | test_pgp005_epoch /= float(len(test_dataloader.dataset) * 2000) # 100k is the number of points in each object 136 | test_pgp010_epoch /= float(len(test_dataloader.dataset) * 2000) # 100k is the number of points in each object 137 | test_pgp030_epoch /= float(len(test_dataloader.dataset) * 2000) # 100k is the number of points in each object 138 | test_pgp060_epoch /= float(len(test_dataloader.dataset) * 2000) # 100k is the number of points in each object 139 | test_pgp080_epoch /= float(len(test_dataloader.dataset) * 2000) # 100k is the number of points in each object 140 
| test_pgp090_epoch /= float(len(test_dataloader.dataset) * 2000) # 100k is the number of points in each object 141 | 142 | # print("------------------------") 143 | print('Total, {0:.2%}, {1:.2%}, {2:.2%}, {3:.2%}, {4:.2%}, {5:.2%}, {6:.2%}'.format(test_pgp003_epoch, 144 | test_pgp005_epoch, 145 | test_pgp010_epoch, 146 | test_pgp030_epoch, 147 | test_pgp060_epoch, 148 | test_pgp080_epoch, 149 | test_pgp090_epoch)) 150 | 151 | print('Mean shape RMS angle error: {0:.2f}'.format(np.mean(rms_angle_list_obj))) 152 | print('Mean time for all objects: {0:.2f}'.format(np.mean(time_list))) 153 | 154 | if obj_count != len(obj_names) or len(obj_names) != len(rms_angle_list_obj): 155 | print('Warning: number of object is not correct. Need to double check. Exit now.') 156 | exit() 157 | 158 | def main(args): 159 | normal_out_path = None 160 | if args.normal_out: 161 | # Prepare the path to write normal predictions to txt files. 162 | # If the normal txt folder exists, we remove all '.normal' files and 'test.py' and 'testlog.txt' file inside it, 163 | # otherwise we create the folder straightway. 164 | # Finally we copy this test file to the folder for record. 165 | normal_out_path = Path(os.path.join(args.ckpt_path, 'normal_txts')) 166 | if os.path.exists(normal_out_path): 167 | ans = input('Already have normal estimation txt folder, OVERWRITE? (yes/n) ') 168 | if ans == 'yes': 169 | filelist = [f for f in os.listdir(normal_out_path) if f.endswith('.normals') or f =='test.py' or f == 'testlog.txt'] 170 | for f in filelist: 171 | print('removed file: ', os.path.join(normal_out_path, f)) 172 | os.remove(os.path.join(normal_out_path, f)) 173 | else: 174 | print("have normal estimation txt folder, exit to prevent overwrite.") 175 | exit() 176 | else: 177 | normal_out_path.mkdir(parents=False, exist_ok=False) 178 | shutil.copy('./test.py', normal_out_path) 179 | 180 | '''LOG''' 181 | logger = logging.getLogger("NINormalNetTest") 182 | logger.setLevel(logging.INFO) 183 | file_handler = logging.FileHandler(os.path.join(normal_out_path, 'testlog.txt')) 184 | logger.addHandler(file_handler) 185 | logger.info(args) 186 | 187 | '''Data Loading''' 188 | test_dataset = PcpKnnPatchesDataset(datafolder=args.datafolder, 189 | dataset_type=args.datatype, 190 | dataset_name=args.test_dataset_name, 191 | fastdebug=args.fastdebug, 192 | noise_level=args.test_noise_level) 193 | test_dataloader = DataLoader(test_dataset, batch_size=args.batchsize_test, shuffle=False, num_workers=8) 194 | 195 | '''Model Loading''' 196 | ckpt_file = os.path.join(args.ckpt_path, 'ni_normal_net.pth') 197 | model = NINormalNet() 198 | 199 | if args.multi_gpu: 200 | model = torch.nn.DataParallel(model).to(device='cuda:' + str(args.gpu_id)) 201 | else: 202 | model = model.to(device='cuda:'+str(args.gpu_id)) 203 | 204 | load_ckpt_to_net(ckpt_file, model) 205 | 206 | '''Testing''' 207 | test_one_epoch(test_dataloader, model, normal_out_path) 208 | 209 | 210 | if __name__ == '__main__': 211 | args = parse_args() 212 | set_randomness(args) 213 | with torch.no_grad(): 214 | main(args) 215 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import logging 5 | from pathlib import Path 6 | import datetime 7 | import shutil 8 | 9 | import torch 10 | import torch.nn.parallel 11 | from torch.utils.data import DataLoader 12 | from torch.utils.tensorboard import SummaryWriter 13 | from 
tqdm import tqdm 14 | 15 | from PcpKnnPatchesDataset import PcpKnnPatchesDataset 16 | from utils.training_utils import set_randomness, save_checkpoint 17 | import metrics 18 | from model import NINormalNet 19 | 20 | 21 | def gen_detail_name(args): 22 | num_nb = int(os.path.splitext(args.train_dataset_name)[0].split('_')[-1]) 23 | outstr = 'lr_' + str(args.learning_rate) + \ 24 | '_train_batch_' + str(args.batchsize_train) + \ 25 | '_train_noise_' + args.train_noise_level + \ 26 | '_eval_noise_' + args.eval_noise_level + \ 27 | '_L2_reg_' + str(args.L2_reg) + \ 28 | '_gpu' + str(args.gpu_id) + \ 29 | '_randseed_' + str(args.randseed) + \ 30 | '_num_nb_' + str(num_nb) + \ 31 | '_' + str(args.alias) + \ 32 | '_' + str(datetime.datetime.now().strftime('%Y%m%d_%H%M')) 33 | return outstr 34 | 35 | 36 | def comp_loss_terms(pred, gt): 37 | """ 38 | :param pred: (B, N, 3) 39 | :param gt: (B, N, 3) 40 | :return: (1, ) 41 | """ 42 | sin_loss = torch.mean(torch.norm(torch.cross(pred, gt, dim=2), dim=2) / (torch.norm(pred, dim=2) * torch.norm(gt, dim=2))) 43 | return sin_loss 44 | 45 | 46 | def parse_args(): 47 | parser = argparse.ArgumentParser('NINormalNetTrain') 48 | parser.add_argument('--batchsize_train', type=int, default=6, 49 | help='set to 6 (if you have 3x1080Ti) to reproduce paper results, reduce if OOM') 50 | parser.add_argument('--batchsize_eval', type=int, default=6, help='default to 6, reduce if OOM') 51 | parser.add_argument('--epoch', default=900, type=int) 52 | parser.add_argument('--learning_rate', default=0.0005, type=float) 53 | parser.add_argument('--milestones', default=[400, 800], type=int, nargs='+', 54 | help='lr schedule milestones in training') 55 | parser.add_argument('--lr_gamma', type=float, default=0.1, help="lr milestones gamma") 56 | parser.add_argument('--fastdebug', default=False, action='store_true', help="debug with very small portion data") 57 | parser.add_argument('--gpu_id', default=0, type=int) 58 | parser.add_argument('--multi_gpu', default=True, type=bool) 59 | 60 | parser.add_argument('--datafolder', type=str, default='./dataset_dir/pcp_knn_patch_h5_files', 61 | help='folder contains h5 dataset') 62 | parser.add_argument('--train_dataset_name', type=str, default='train_patchsize_2000_k_20.h5') 63 | parser.add_argument('--eval_dataset_name', type=str, default='eval_patchsize_2000_k_20.h5') 64 | parser.add_argument('--train_noise_level', type=str, default='none', choices=['none', '0.01', '0.05', '0.1', 'all']) 65 | parser.add_argument('--eval_noise_level', type=str, default='none', choices=['none', '0.01', '0.05', '0.1', 'all']) 66 | 67 | parser.add_argument('--L2_reg', type=float, default=1e-4, help='L2 regularisation for network weights') 68 | parser.add_argument('--model_name', default='ni_normal_net', help='for checkpoint saving') 69 | 70 | parser.add_argument('--truerand', default=False, type=bool, help="whether we want true randomness") 71 | parser.add_argument('--randseed', default=20, help="set random seed for np, python, and torch") 72 | 73 | parser.add_argument('--alias', type=str, default='', help="specify experiments") 74 | return parser.parse_args() 75 | 76 | 77 | def train_one_epoch(train_dataloader, optimizer, model, args): 78 | model.train() 79 | 80 | train_loss_epoch = 0 81 | train_pgp003_epoch, train_pgp005_epoch, train_pgp010_epoch, train_pgp030_epoch, train_pgp060_epoch, train_pgp080_epoch, train_pgp090_epoch = 0, 0, 0, 0, 0, 0, 0 82 | 83 | for batch_id, data in enumerate(tqdm(train_dataloader, desc='train_batch')): 84 | knn_pts = 
data['knn_pt_list'] # (B, N, K, 3) 85 | knn_pts = knn_pts.transpose(1, 2).contiguous() # (B, K, N, 3) 86 | knn_pts = knn_pts.to(device='cuda:' + str(args.gpu_id), non_blocking=True) # (B, K, N, 3) 87 | 88 | gt_normals = data['gt_normals'] # (B, N, 3) 89 | gt_normals = gt_normals.to(device='cuda:'+str(args.gpu_id), non_blocking=True) # (B, N, 3) 90 | 91 | pred_normals, weights = model(knn_pts) 92 | 93 | loss = comp_loss_terms(pred_normals, gt_normals) 94 | loss.backward() 95 | optimizer.step() 96 | optimizer.zero_grad() 97 | 98 | '''Error analysis''' 99 | pgp_ang_dic = metrics.comp_pgp_batch_unori(pred_normals, gt_normals) 100 | train_loss_epoch += loss.item() 101 | 102 | train_pgp003_epoch += pgp_ang_dic['pgp003'].item() 103 | train_pgp005_epoch += pgp_ang_dic['pgp005'].item() 104 | train_pgp010_epoch += pgp_ang_dic['pgp010'].item() 105 | train_pgp030_epoch += pgp_ang_dic['pgp030'].item() 106 | train_pgp060_epoch += pgp_ang_dic['pgp060'].item() 107 | train_pgp080_epoch += pgp_ang_dic['pgp080'].item() 108 | train_pgp090_epoch += pgp_ang_dic['pgp090'].item() 109 | 110 | train_loss_epoch /= len(train_dataloader) 111 | 112 | train_pgp003_epoch /= float(len(train_dataloader.dataset) * 2000) # 2k is the number of points in each patch 113 | train_pgp005_epoch /= float(len(train_dataloader.dataset) * 2000) # 2k is the number of points in each patch 114 | train_pgp010_epoch /= float(len(train_dataloader.dataset) * 2000) # 2k is the number of points in each patch 115 | train_pgp030_epoch /= float(len(train_dataloader.dataset) * 2000) # 2k is the number of points in each patch 116 | train_pgp060_epoch /= float(len(train_dataloader.dataset) * 2000) # 2k is the number of points in each patch 117 | train_pgp080_epoch /= float(len(train_dataloader.dataset) * 2000) # 2k is the number of points in each patch 118 | train_pgp090_epoch /= float(len(train_dataloader.dataset) * 2000) # 2k is the number of points in each patch 119 | 120 | return { 121 | 'loss': train_loss_epoch, 122 | 'pgp003': train_pgp003_epoch, 123 | 'pgp005': train_pgp005_epoch, 124 | 'pgp010': train_pgp010_epoch, 125 | 'pgp030': train_pgp030_epoch, 126 | 'pgp060': train_pgp060_epoch, 127 | 'pgp080': train_pgp080_epoch, 128 | 'pgp090': train_pgp090_epoch, 129 | } 130 | 131 | 132 | def eval_one_epoch(eval_dataloader, model, args): 133 | model.eval() 134 | eval_loss_epoch = 0 135 | eval_pgp003_epoch, eval_pgp005_epoch, eval_pgp010_epoch, eval_pgp030_epoch, eval_pgp060_epoch, eval_pgp080_epoch, eval_pgp090_epoch = 0, 0, 0, 0, 0, 0, 0 136 | 137 | for batch_id, data in enumerate(eval_dataloader): 138 | knn_pts = data['knn_pt_list'] # (B, N, K, 3) 139 | knn_pts = knn_pts.transpose(1, 2).contiguous() # (B, K, N, 3) 140 | knn_pts = knn_pts.to(device='cuda:' + str(args.gpu_id), non_blocking=True) # (B, K, N, 3) 141 | 142 | gt_normals = data['gt_normals'] # (B, N, 3) 143 | gt_normals = gt_normals.to(device='cuda:' + str(args.gpu_id), non_blocking=True) 144 | 145 | pred_normals, weights = model(knn_pts) 146 | 147 | loss = comp_loss_terms(pred_normals, gt_normals) 148 | 149 | '''Error analysis''' 150 | pgp_ang_dic = metrics.comp_pgp_batch_unori(pred_normals, gt_normals) 151 | eval_loss_epoch += loss.item() 152 | 153 | eval_pgp003_epoch += pgp_ang_dic['pgp003'].item() 154 | eval_pgp005_epoch += pgp_ang_dic['pgp005'].item() 155 | eval_pgp010_epoch += pgp_ang_dic['pgp010'].item() 156 | eval_pgp030_epoch += pgp_ang_dic['pgp030'].item() 157 | eval_pgp060_epoch += pgp_ang_dic['pgp060'].item() 158 | eval_pgp080_epoch += pgp_ang_dic['pgp080'].item() 159 | 
eval_pgp090_epoch += pgp_ang_dic['pgp090'].item() 160 | 161 | eval_loss_epoch /= len(eval_dataloader) 162 | 163 | eval_pgp003_epoch /= float(len(eval_dataloader.dataset) * 2000) # 100k is the number of points in each object 164 | eval_pgp005_epoch /= float(len(eval_dataloader.dataset) * 2000) # 100k is the number of points in each object 165 | eval_pgp010_epoch /= float(len(eval_dataloader.dataset) * 2000) # 100k is the number of points in each object 166 | eval_pgp030_epoch /= float(len(eval_dataloader.dataset) * 2000) # 100k is the number of points in each object 167 | eval_pgp060_epoch /= float(len(eval_dataloader.dataset) * 2000) # 100k is the number of points in each object 168 | eval_pgp080_epoch /= float(len(eval_dataloader.dataset) * 2000) # 100k is the number of points in each object 169 | eval_pgp090_epoch /= float(len(eval_dataloader.dataset) * 2000) # 100k is the number of points in each object 170 | 171 | return { 172 | 'loss': eval_loss_epoch, 173 | 'pgp003': eval_pgp003_epoch, 174 | 'pgp005': eval_pgp005_epoch, 175 | 'pgp010': eval_pgp010_epoch, 176 | 'pgp030': eval_pgp030_epoch, 177 | 'pgp060': eval_pgp060_epoch, 178 | 'pgp080': eval_pgp080_epoch, 179 | 'pgp090': eval_pgp090_epoch, 180 | } 181 | 182 | 183 | def main(args): 184 | '''Create the log dir for this run''' 185 | exp_root_dir = Path('./logs/') 186 | exp_root_dir.mkdir(parents=True, exist_ok=True) 187 | experiment_dir = Path(os.path.join(exp_root_dir, gen_detail_name(args))) 188 | experiment_dir.mkdir(parents=True, exist_ok=True) 189 | 190 | # copy train and model file to ensure reproducibility. 191 | shutil.copy('./model.py', experiment_dir) 192 | shutil.copy('./train.py', experiment_dir) 193 | 194 | '''Logger''' 195 | logger = logging.getLogger("NINormalNetTrain") 196 | logger.setLevel(logging.INFO) 197 | file_handler = logging.FileHandler(os.path.join(experiment_dir, 'log.txt')) 198 | file_handler.setLevel(logging.INFO) 199 | logger.addHandler(file_handler) 200 | logger.info(args) 201 | 202 | '''Summary Writer''' 203 | writer = SummaryWriter(log_dir=experiment_dir) 204 | 205 | '''Data Loading''' 206 | logger.info('Load dataset ...') 207 | train_dataset = PcpKnnPatchesDataset(datafolder=args.datafolder, 208 | dataset_type='train', 209 | dataset_name=args.train_dataset_name, 210 | fastdebug=args.fastdebug, 211 | noise_level=args.train_noise_level) 212 | 213 | eval_dataset = PcpKnnPatchesDataset(datafolder=args.datafolder, 214 | dataset_type='eval', 215 | dataset_name=args.eval_dataset_name, 216 | fastdebug=args.fastdebug, 217 | noise_level=args.eval_noise_level) 218 | 219 | train_dataloader = DataLoader(train_dataset, batch_size=args.batchsize_train, 220 | shuffle=True, num_workers=10, pin_memory=True) 221 | eval_dataloader = DataLoader(eval_dataset, batch_size=args.batchsize_eval, 222 | shuffle=False, num_workers=10, pin_memory=True) 223 | 224 | '''Model Loading''' 225 | model = NINormalNet() 226 | if args.multi_gpu: 227 | model = torch.nn.DataParallel(model).to(device='cuda:' + str(args.gpu_id)) 228 | else: 229 | model = model.to(device='cuda:'+str(args.gpu_id)) 230 | 231 | '''Set Optimiser''' 232 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.L2_reg) 233 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_gamma) 234 | 235 | '''Training''' 236 | best_temp = 0.0 237 | best_eval_pgp010 = 0.0 238 | sys.stdout.flush() 239 | logger.info('Start training...') 240 | for epoch in tqdm(range(args.epoch), desc='epochs'): 241 | 
logger.info('Epoch (%d/%s):', epoch + 1, args.epoch) 242 | 243 | train_epoch_metric = train_one_epoch(train_dataloader, optimizer, model, args) 244 | with torch.no_grad(): 245 | eval_epoch_metric = eval_one_epoch(eval_dataloader, model, args) 246 | scheduler.step() 247 | 248 | tqdm.write('pgp010: train: {0:.4f}, eval: {1:.4f}'.format(train_epoch_metric['pgp010'], eval_epoch_metric['pgp010'])) 249 | logger.info('pgp010: train: {0:.4f}, eval: {1:.4f}'.format(train_epoch_metric['pgp010'], eval_epoch_metric['pgp010'])) 250 | 251 | writer.add_scalar('train/loss', train_epoch_metric['loss'], epoch) 252 | writer.add_scalar('train/pgp003', train_epoch_metric['pgp003'], epoch) 253 | writer.add_scalar('train/pgp005', train_epoch_metric['pgp005'], epoch) 254 | writer.add_scalar('train/pgp010', train_epoch_metric['pgp010'], epoch) 255 | writer.add_scalar('train/pgp030', train_epoch_metric['pgp030'], epoch) 256 | writer.add_scalar('train/pgp060', train_epoch_metric['pgp060'], epoch) 257 | writer.add_scalar('train/pgp080', train_epoch_metric['pgp080'], epoch) 258 | writer.add_scalar('train/pgp090', train_epoch_metric['pgp090'], epoch) # sanity check: this should always be 1.0 for un-oriented normal estimation 259 | writer.add_scalar('train/lr', scheduler.get_lr()[0], epoch) 260 | writer.add_scalar('temp', model.module.temp if args.multi_gpu else model.temp, epoch) 261 | 262 | writer.add_scalar('eval/loss', eval_epoch_metric['loss'], epoch) 263 | writer.add_scalar('eval/pgp003', eval_epoch_metric['pgp003'], epoch) 264 | writer.add_scalar('eval/pgp005', eval_epoch_metric['pgp005'], epoch) 265 | writer.add_scalar('eval/pgp010', eval_epoch_metric['pgp010'], epoch) 266 | writer.add_scalar('eval/pgp030', eval_epoch_metric['pgp030'], epoch) 267 | writer.add_scalar('eval/pgp060', eval_epoch_metric['pgp060'], epoch) 268 | writer.add_scalar('eval/pgp080', eval_epoch_metric['pgp080'], epoch) 269 | writer.add_scalar('eval/pgp090', eval_epoch_metric['pgp090'], epoch) # sanity check: this should always be 1.0 for un-oriented normal estimation 270 | 271 | if eval_epoch_metric['pgp010'] >= best_eval_pgp010: 272 | best_eval_pgp010 = eval_epoch_metric['pgp010'] 273 | best_temp = model.module.temp if args.multi_gpu else model.temp 274 | 275 | logger.info('Saving model with the best pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp)) 276 | tqdm.write('Saving model with the best pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp)) 277 | save_checkpoint(epoch, train_epoch_metric['pgp010'], eval_epoch_metric['pgp010'], model, 278 | optimizer, str(experiment_dir), args.model_name) 279 | 280 | print('Best eval pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp)) 281 | logger.info('Best eval pgp010: {0:.4%} at temp {1:.4f}'.format(best_eval_pgp010, best_temp)) 282 | 283 | print('Final temp: {0:.4f}'.format(model.module.temp if args.multi_gpu else model.temp)) 284 | logger.info('Final temp: {0:.4f}'.format(model.module.temp if args.multi_gpu else model.temp)) 285 | 286 | return 287 | 288 | 289 | if __name__ == '__main__': 290 | torch.backends.cudnn.enabled = False 291 | args = parse_args() 292 | set_randomness(args) 293 | main(args) 294 | 295 | -------------------------------------------------------------------------------- /transplant_attn/transformer_from_torch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From Zirui Wang: 3 | This file is copied from PyTorch 1.3 torch.nn.modules.transformer.py.
4 | We only changed the imports: torch.nn.MultiheadAttention() is used on PyTorch 1.3.0, and the local copy in MultiheadAttention_from_torch.py is used otherwise. 5 | ''' 6 | 7 | # import torch 8 | # import copy 9 | # from .. import functional as F 10 | # from .module import Module 11 | # from .activation import MultiheadAttention 12 | # from .container import ModuleList 13 | # from ..init import xavier_uniform_ 14 | # from .dropout import Dropout 15 | # from .linear import Linear 16 | # from .normalization import LayerNorm 17 | 18 | import torch 19 | import copy 20 | import torch.nn.functional as F 21 | from torch.nn.modules.module import Module 22 | from torch.nn.modules.container import ModuleList 23 | from torch.nn.init import xavier_uniform_ 24 | from torch.nn.modules.dropout import Dropout 25 | from torch.nn.modules.linear import Linear 26 | from torch.nn.modules.normalization import LayerNorm 27 | 28 | if torch.__version__ == '1.3.0': 29 | MultiheadAttention = torch.nn.MultiheadAttention 30 | else: 31 | from .MultiheadAttention_from_torch import MultiheadAttention 32 | 33 | 34 | class Transformer(Module): 35 | r"""A transformer model. User is able to modify the attributes as needed. The architecture 36 | is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer, 37 | Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and 38 | Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information 39 | Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805) 40 | model with corresponding parameters. 41 | 42 | Args: 43 | d_model: the number of expected features in the encoder/decoder inputs (default=512). 44 | nhead: the number of heads in the multiheadattention models (default=8). 45 | num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6). 46 | num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6). 47 | dim_feedforward: the dimension of the feedforward network model (default=2048). 48 | dropout: the dropout value (default=0.1). 49 | activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu). 50 | custom_encoder: custom encoder (default=None). 51 | custom_decoder: custom decoder (default=None).
52 | 53 | Examples:: 54 | >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12) 55 | >>> src = torch.rand((10, 32, 512)) 56 | >>> tgt = torch.rand((20, 32, 512)) 57 | >>> out = transformer_model(src, tgt) 58 | 59 | Note: A full example to apply nn.Transformer module for the word language model is available in 60 | https://github.com/pytorch/examples/tree/master/word_language_model 61 | """ 62 | 63 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 64 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 65 | activation="relu", custom_encoder=None, custom_decoder=None): 66 | super(Transformer, self).__init__() 67 | 68 | if custom_encoder is not None: 69 | self.encoder = custom_encoder 70 | else: 71 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 72 | encoder_norm = LayerNorm(d_model) 73 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 74 | 75 | if custom_decoder is not None: 76 | self.decoder = custom_decoder 77 | else: 78 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation) 79 | decoder_norm = LayerNorm(d_model) 80 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm) 81 | 82 | self._reset_parameters() 83 | 84 | self.d_model = d_model 85 | self.nhead = nhead 86 | 87 | def forward(self, src, tgt, src_mask=None, tgt_mask=None, 88 | memory_mask=None, src_key_padding_mask=None, 89 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 90 | r"""Take in and process masked source/target sequences. 91 | 92 | Args: 93 | src: the sequence to the encoder (required). 94 | tgt: the sequence to the decoder (required). 95 | src_mask: the additive mask for the src sequence (optional). 96 | tgt_mask: the additive mask for the tgt sequence (optional). 97 | memory_mask: the additive mask for the encoder output (optional). 98 | src_key_padding_mask: the ByteTensor mask for src keys per batch (optional). 99 | tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional). 100 | memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional). 101 | 102 | Shape: 103 | - src: :math:`(S, N, E)`. 104 | - tgt: :math:`(T, N, E)`. 105 | - src_mask: :math:`(S, S)`. 106 | - tgt_mask: :math:`(T, T)`. 107 | - memory_mask: :math:`(T, S)`. 108 | - src_key_padding_mask: :math:`(N, S)`. 109 | - tgt_key_padding_mask: :math:`(N, T)`. 110 | - memory_key_padding_mask: :math:`(N, S)`. 111 | 112 | Note: [src/tgt/memory]_mask should be filled with 113 | float('-inf') for the masked positions and float(0.0) else. These masks 114 | ensure that predictions for position i depend only on the unmasked positions 115 | j and are applied identically for each sequence in a batch. 116 | [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions 117 | that should be masked with float('-inf') and False values will be unchanged. 118 | This mask ensures that no information will be taken from position i if 119 | it is masked, and has a separate mask for each sequence in a batch. 120 | 121 | - output: :math:`(T, N, E)`. 122 | 123 | Note: Due to the multi-head attention architecture in the transformer model, 124 | the output sequence length of a transformer is same as the input sequence 125 | (i.e. target) length of the decode. 
126 | 127 | where S is the source sequence length, T is the target sequence length, N is the 128 | batch size, E is the feature number 129 | 130 | Examples: 131 | >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask) 132 | """ 133 | 134 | if src.size(1) != tgt.size(1): 135 | raise RuntimeError("the batch number of src and tgt must be equal") 136 | 137 | if src.size(2) != self.d_model or tgt.size(2) != self.d_model: 138 | raise RuntimeError("the feature number of src and tgt must be equal to d_model") 139 | 140 | memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) 141 | output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, 142 | tgt_key_padding_mask=tgt_key_padding_mask, 143 | memory_key_padding_mask=memory_key_padding_mask) 144 | return output 145 | 146 | def generate_square_subsequent_mask(self, sz): 147 | r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf'). 148 | Unmasked positions are filled with float(0.0). 149 | """ 150 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 151 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 152 | return mask 153 | 154 | def _reset_parameters(self): 155 | r"""Initiate parameters in the transformer model.""" 156 | 157 | for p in self.parameters(): 158 | if p.dim() > 1: 159 | xavier_uniform_(p) 160 | 161 | 162 | class TransformerEncoder(Module): 163 | r"""TransformerEncoder is a stack of N encoder layers 164 | 165 | Args: 166 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 167 | num_layers: the number of sub-encoder-layers in the encoder (required). 168 | norm: the layer normalization component (optional). 169 | 170 | Examples:: 171 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 172 | >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 173 | >>> src = torch.rand(10, 32, 512) 174 | >>> out = transformer_encoder(src) 175 | """ 176 | 177 | def __init__(self, encoder_layer, num_layers, norm=None): 178 | super(TransformerEncoder, self).__init__() 179 | self.layers = _get_clones(encoder_layer, num_layers) 180 | self.num_layers = num_layers 181 | self.norm = norm 182 | print('==============') 183 | print('Using the Transformer file copied from PyTorch 1.3.0') 184 | print('==============') 185 | 186 | def forward(self, src, mask=None, src_key_padding_mask=None): 187 | r"""Pass the input through the endocder layers in turn. 188 | 189 | Args: 190 | src: the sequnce to the encoder (required). 191 | mask: the mask for the src sequence (optional). 192 | src_key_padding_mask: the mask for the src keys per batch (optional). 193 | 194 | Shape: 195 | see the docs in Transformer class. 196 | """ 197 | output = src 198 | 199 | for i in range(self.num_layers): 200 | output, my_weight = self.layers[i](output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) 201 | 202 | if self.norm: 203 | output = self.norm(output) 204 | 205 | return output, my_weight 206 | 207 | 208 | class TransformerDecoder(Module): 209 | r"""TransformerDecoder is a stack of N decoder layers 210 | 211 | Args: 212 | decoder_layer: an instance of the TransformerDecoderLayer() class (required). 213 | num_layers: the number of sub-decoder-layers in the decoder (required). 214 | norm: the layer normalization component (optional). 
215 | 216 | Examples:: 217 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 218 | >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6) 219 | >>> memory = torch.rand(10, 32, 512) 220 | >>> tgt = torch.rand(20, 32, 512) 221 | >>> out = transformer_decoder(tgt, memory) 222 | """ 223 | 224 | def __init__(self, decoder_layer, num_layers, norm=None): 225 | super(TransformerDecoder, self).__init__() 226 | self.layers = _get_clones(decoder_layer, num_layers) 227 | self.num_layers = num_layers 228 | self.norm = norm 229 | 230 | def forward(self, tgt, memory, tgt_mask=None, 231 | memory_mask=None, tgt_key_padding_mask=None, 232 | memory_key_padding_mask=None): 233 | r"""Pass the inputs (and mask) through the decoder layer in turn. 234 | 235 | Args: 236 | tgt: the sequence to the decoder (required). 237 | memory: the sequnce from the last layer of the encoder (required). 238 | tgt_mask: the mask for the tgt sequence (optional). 239 | memory_mask: the mask for the memory sequence (optional). 240 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 241 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 242 | 243 | Shape: 244 | see the docs in Transformer class. 245 | """ 246 | output = tgt 247 | 248 | for i in range(self.num_layers): 249 | output = self.layers[i](output, memory, tgt_mask=tgt_mask, 250 | memory_mask=memory_mask, 251 | tgt_key_padding_mask=tgt_key_padding_mask, 252 | memory_key_padding_mask=memory_key_padding_mask) 253 | 254 | if self.norm: 255 | output = self.norm(output) 256 | 257 | return output 258 | 259 | class TransformerEncoderLayer(Module): 260 | r"""TransformerEncoderLayer is made up of self-attn and feedforward network. 261 | This standard encoder layer is based on the paper "Attention Is All You Need". 262 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 263 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 264 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 265 | in a different way during application. 266 | 267 | Args: 268 | d_model: the number of expected features in the input (required). 269 | nhead: the number of heads in the multiheadattention models (required). 270 | dim_feedforward: the dimension of the feedforward network model (default=2048). 271 | dropout: the dropout value (default=0.1). 272 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 273 | 274 | Examples:: 275 | >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) 276 | >>> src = torch.rand(10, 32, 512) 277 | >>> out = encoder_layer(src) 278 | """ 279 | 280 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 281 | super(TransformerEncoderLayer, self).__init__() 282 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 283 | # Implementation of Feedforward model 284 | self.linear1 = Linear(d_model, dim_feedforward) 285 | self.dropout = Dropout(dropout) 286 | self.linear2 = Linear(dim_feedforward, d_model) 287 | 288 | self.norm1 = LayerNorm(d_model) 289 | self.norm2 = LayerNorm(d_model) 290 | self.dropout1 = Dropout(dropout) 291 | self.dropout2 = Dropout(dropout) 292 | 293 | self.activation = _get_activation_fn(activation) 294 | 295 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 296 | r"""Pass the input through the endocder layer. 
297 | 298 | Args: 299 | src: the sequnce to the encoder layer (required). 300 | src_mask: the mask for the src sequence (optional). 301 | src_key_padding_mask: the mask for the src keys per batch (optional). 302 | 303 | Shape: 304 | see the docs in Transformer class. 305 | """ 306 | src2, my_weight = self.self_attn(src, src, src, attn_mask=src_mask, 307 | key_padding_mask=src_key_padding_mask) 308 | src = src + self.dropout1(src2) 309 | src = self.norm1(src) 310 | if hasattr(self, "activation"): 311 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 312 | else: # for backward compatibility 313 | src2 = self.linear2(self.dropout(F.relu(self.linear1(src)))) 314 | src = src + self.dropout2(src2) 315 | src = self.norm2(src) 316 | return src, my_weight 317 | 318 | 319 | class TransformerDecoderLayer(Module): 320 | r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. 321 | This standard decoder layer is based on the paper "Attention Is All You Need". 322 | Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, 323 | Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in 324 | Neural Information Processing Systems, pages 6000-6010. Users may modify or implement 325 | in a different way during application. 326 | 327 | Args: 328 | d_model: the number of expected features in the input (required). 329 | nhead: the number of heads in the multiheadattention models (required). 330 | dim_feedforward: the dimension of the feedforward network model (default=2048). 331 | dropout: the dropout value (default=0.1). 332 | activation: the activation function of intermediate layer, relu or gelu (default=relu). 333 | 334 | Examples:: 335 | >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) 336 | >>> memory = torch.rand(10, 32, 512) 337 | >>> tgt = torch.rand(20, 32, 512) 338 | >>> out = decoder_layer(tgt, memory) 339 | """ 340 | 341 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 342 | super(TransformerDecoderLayer, self).__init__() 343 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 344 | self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 345 | # Implementation of Feedforward model 346 | self.linear1 = Linear(d_model, dim_feedforward) 347 | self.dropout = Dropout(dropout) 348 | self.linear2 = Linear(dim_feedforward, d_model) 349 | 350 | self.norm1 = LayerNorm(d_model) 351 | self.norm2 = LayerNorm(d_model) 352 | self.norm3 = LayerNorm(d_model) 353 | self.dropout1 = Dropout(dropout) 354 | self.dropout2 = Dropout(dropout) 355 | self.dropout3 = Dropout(dropout) 356 | 357 | self.activation = _get_activation_fn(activation) 358 | 359 | def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, 360 | tgt_key_padding_mask=None, memory_key_padding_mask=None): 361 | r"""Pass the inputs (and mask) through the decoder layer. 362 | 363 | Args: 364 | tgt: the sequence to the decoder layer (required). 365 | memory: the sequnce from the last layer of the encoder (required). 366 | tgt_mask: the mask for the tgt sequence (optional). 367 | memory_mask: the mask for the memory sequence (optional). 368 | tgt_key_padding_mask: the mask for the tgt keys per batch (optional). 369 | memory_key_padding_mask: the mask for the memory keys per batch (optional). 370 | 371 | Shape: 372 | see the docs in Transformer class. 
373 | """ 374 | tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask, 375 | key_padding_mask=tgt_key_padding_mask)[0] 376 | tgt = tgt + self.dropout1(tgt2) 377 | tgt = self.norm1(tgt) 378 | tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask, 379 | key_padding_mask=memory_key_padding_mask)[0] 380 | tgt = tgt + self.dropout2(tgt2) 381 | tgt = self.norm2(tgt) 382 | if hasattr(self, "activation"): 383 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 384 | else: # for backward compatibility 385 | tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt)))) 386 | tgt = tgt + self.dropout3(tgt2) 387 | tgt = self.norm3(tgt) 388 | return tgt 389 | 390 | 391 | def _get_clones(module, N): 392 | return ModuleList([copy.deepcopy(module) for i in range(N)]) 393 | 394 | 395 | def _get_activation_fn(activation): 396 | if activation == "relu": 397 | return F.relu 398 | elif activation == "gelu": 399 | return F.gelu 400 | else: 401 | raise RuntimeError("activation should be relu/gelu, not %s." % activation) 402 | -------------------------------------------------------------------------------- /transplant_attn/MultiheadAttention_from_torch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From Zirui Wang: 3 | The MultiheadAttention class is copied from PyTorch1.3 torch.nn.modules.activation.py 4 | The multi_head_attention_forward is copied from PyTorch1.3 torch.nn.functional.py 5 | ''' 6 | 7 | # import warnings 8 | # import torch 9 | # from . import Linear 10 | # from torch.nn.init import xavier_uniform_ 11 | # from torch.nn.init import constant_ 12 | # from torch.nn.init import xavier_normal_ 13 | # from torch.nn.parameter import Parameter 14 | # from .module import Module 15 | # from .. import functional as F 16 | 17 | import warnings 18 | import torch 19 | from torch.nn.modules import Linear 20 | from torch.nn.init import xavier_uniform_ 21 | from torch.nn.init import constant_ 22 | from torch.nn.init import xavier_normal_ 23 | from torch.nn.parameter import Parameter 24 | from torch.nn.modules.module import Module 25 | import torch.nn.functional as F 26 | 27 | 28 | def multi_head_attention_forward(query, # type: Tensor 29 | key, # type: Tensor 30 | value, # type: Tensor 31 | embed_dim_to_check, # type: int 32 | num_heads, # type: int 33 | in_proj_weight, # type: Tensor 34 | in_proj_bias, # type: Tensor 35 | bias_k, # type: Optional[Tensor] 36 | bias_v, # type: Optional[Tensor] 37 | add_zero_attn, # type: bool 38 | dropout_p, # type: float 39 | out_proj_weight, # type: Tensor 40 | out_proj_bias, # type: Tensor 41 | training=True, # type: bool 42 | key_padding_mask=None, # type: Optional[Tensor] 43 | need_weights=True, # type: bool 44 | attn_mask=None, # type: Optional[Tensor] 45 | use_separate_proj_weight=False, # type: bool 46 | q_proj_weight=None, # type: Optional[Tensor] 47 | k_proj_weight=None, # type: Optional[Tensor] 48 | v_proj_weight=None, # type: Optional[Tensor] 49 | static_k=None, # type: Optional[Tensor] 50 | static_v=None # type: Optional[Tensor] 51 | ): 52 | # type: (...) -> Tuple[Tensor, Optional[Tensor]] 53 | r""" 54 | Args: 55 | query, key, value: map a query and a set of key-value pairs to an output. 56 | See "Attention Is All You Need" for more details. 57 | embed_dim_to_check: total dimension of the model. 58 | num_heads: parallel attention heads. 59 | in_proj_weight, in_proj_bias: input projection weight and bias. 60 | bias_k, bias_v: bias of the key and value sequences to be added at dim=0. 
61 | add_zero_attn: add a new batch of zeros to the key and 62 | value sequences at dim=1. 63 | dropout_p: probability of an element to be zeroed. 64 | out_proj_weight, out_proj_bias: the output projection weight and bias. 65 | training: apply dropout if is ``True``. 66 | key_padding_mask: if provided, specified padding elements in the key will 67 | be ignored by the attention. This is an binary mask. When the value is True, 68 | the corresponding value on the attention layer will be filled with -inf. 69 | need_weights: output attn_output_weights. 70 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 71 | (i.e. the values will be added to the attention layer). 72 | use_separate_proj_weight: the function accept the proj. weights for query, key, 73 | and value in differnt forms. If false, in_proj_weight will be used, which is 74 | a combination of q_proj_weight, k_proj_weight, v_proj_weight. 75 | q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. 76 | static_k, static_v: static key and value used for attention operators. 77 | 78 | 79 | Shape: 80 | Inputs: 81 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 82 | the embedding dimension. 83 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 84 | the embedding dimension. 85 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 86 | the embedding dimension. 87 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. 88 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 89 | - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 90 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 91 | - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 92 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 93 | 94 | Outputs: 95 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 96 | E is the embedding dimension. 97 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 98 | L is the target sequence length, S is the source sequence length. 
99 | """ 100 | 101 | qkv_same = torch.equal(query, key) and torch.equal(key, value) 102 | kv_same = torch.equal(key, value) 103 | 104 | tgt_len, bsz, embed_dim = query.size() 105 | assert embed_dim == embed_dim_to_check 106 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 107 | assert key.size() == value.size() 108 | 109 | head_dim = embed_dim // num_heads 110 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 111 | scaling = float(head_dim) ** -0.5 112 | 113 | if use_separate_proj_weight is not True: 114 | if qkv_same: 115 | # self-attention 116 | q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 117 | 118 | elif kv_same: 119 | # encoder-decoder attention 120 | # This is inline in_proj function with in_proj_weight and in_proj_bias 121 | _b = in_proj_bias 122 | _start = 0 123 | _end = embed_dim 124 | _w = in_proj_weight[_start:_end, :] 125 | if _b is not None: 126 | _b = _b[_start:_end] 127 | q = F.linear(query, _w, _b) 128 | 129 | if key is None: 130 | assert value is None 131 | k = None 132 | v = None 133 | else: 134 | 135 | # This is inline in_proj function with in_proj_weight and in_proj_bias 136 | _b = in_proj_bias 137 | _start = embed_dim 138 | _end = None 139 | _w = in_proj_weight[_start:, :] 140 | if _b is not None: 141 | _b = _b[_start:] 142 | k, v = F.linear(key, _w, _b).chunk(2, dim=-1) 143 | 144 | else: 145 | # This is inline in_proj function with in_proj_weight and in_proj_bias 146 | _b = in_proj_bias 147 | _start = 0 148 | _end = embed_dim 149 | _w = in_proj_weight[_start:_end, :] 150 | if _b is not None: 151 | _b = _b[_start:_end] 152 | q = F.linear(query, _w, _b) 153 | 154 | # This is inline in_proj function with in_proj_weight and in_proj_bias 155 | _b = in_proj_bias 156 | _start = embed_dim 157 | _end = embed_dim * 2 158 | _w = in_proj_weight[_start:_end, :] 159 | if _b is not None: 160 | _b = _b[_start:_end] 161 | k = F.linear(key, _w, _b) 162 | 163 | # This is inline in_proj function with in_proj_weight and in_proj_bias 164 | _b = in_proj_bias 165 | _start = embed_dim * 2 166 | _end = None 167 | _w = in_proj_weight[_start:, :] 168 | if _b is not None: 169 | _b = _b[_start:] 170 | v = F.linear(value, _w, _b) 171 | else: 172 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 173 | len1, len2 = q_proj_weight_non_opt.size() 174 | assert len1 == embed_dim and len2 == query.size(-1) 175 | 176 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 177 | len1, len2 = k_proj_weight_non_opt.size() 178 | assert len1 == embed_dim and len2 == key.size(-1) 179 | 180 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 181 | len1, len2 = v_proj_weight_non_opt.size() 182 | assert len1 == embed_dim and len2 == value.size(-1) 183 | 184 | if in_proj_bias is not None: 185 | q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 186 | k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 187 | v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 188 | else: 189 | q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) 190 | k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) 191 | v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) 192 | q = q * scaling 193 | 194 | if bias_k is not None and bias_v is not None: 195 | if static_k is None and static_v is None: 196 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 197 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 198 | if attn_mask is not None: 199 | attn_mask = 
torch.cat([attn_mask, 200 | torch.zeros((attn_mask.size(0), 1), 201 | dtype=attn_mask.dtype, 202 | device=attn_mask.device)], dim=1) 203 | if key_padding_mask is not None: 204 | key_padding_mask = torch.cat( 205 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 206 | dtype=key_padding_mask.dtype, 207 | device=key_padding_mask.device)], dim=1) 208 | else: 209 | assert static_k is None, "bias cannot be added to static key." 210 | assert static_v is None, "bias cannot be added to static value." 211 | else: 212 | assert bias_k is None 213 | assert bias_v is None 214 | 215 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 216 | if k is not None: 217 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 218 | if v is not None: 219 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 220 | 221 | if static_k is not None: 222 | assert static_k.size(0) == bsz * num_heads 223 | assert static_k.size(2) == head_dim 224 | k = static_k 225 | 226 | if static_v is not None: 227 | assert static_v.size(0) == bsz * num_heads 228 | assert static_v.size(2) == head_dim 229 | v = static_v 230 | 231 | src_len = k.size(1) 232 | 233 | if key_padding_mask is not None: 234 | assert key_padding_mask.size(0) == bsz 235 | assert key_padding_mask.size(1) == src_len 236 | 237 | if add_zero_attn: 238 | src_len += 1 239 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 240 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 241 | if attn_mask is not None: 242 | attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), 243 | dtype=attn_mask.dtype, 244 | device=attn_mask.device)], dim=1) 245 | if key_padding_mask is not None: 246 | key_padding_mask = torch.cat( 247 | [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), 248 | dtype=key_padding_mask.dtype, 249 | device=key_padding_mask.device)], dim=1) 250 | 251 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 252 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 253 | 254 | if attn_mask is not None: 255 | attn_mask = attn_mask.unsqueeze(0) 256 | attn_output_weights += attn_mask 257 | 258 | if key_padding_mask is not None: 259 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 260 | attn_output_weights = attn_output_weights.masked_fill( 261 | key_padding_mask.unsqueeze(1).unsqueeze(2), 262 | float('-inf'), 263 | ) 264 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 265 | 266 | attn_output_weights = F.softmax( 267 | attn_output_weights, dim=-1) 268 | attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training) 269 | 270 | attn_output = torch.bmm(attn_output_weights, v) 271 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 272 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 273 | attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) 274 | 275 | if need_weights: 276 | # average attention weights over heads 277 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 278 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 279 | else: 280 | return attn_output, None 281 | 282 | 283 | class MultiheadAttention(Module): 284 | r"""Allows the model to jointly attend to information 285 | from different representation subspaces. 
286 | See reference: Attention Is All You Need 287 | 288 | .. math:: 289 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 290 | \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 291 | 292 | Args: 293 | embed_dim: total dimension of the model. 294 | num_heads: parallel attention heads. 295 | dropout: a Dropout layer on attn_output_weights. Default: 0.0. 296 | bias: add bias as module parameter. Default: True. 297 | add_bias_kv: add bias to the key and value sequences at dim=0. 298 | add_zero_attn: add a new batch of zeros to the key and 299 | value sequences at dim=1. 300 | kdim: total number of features in key. Default: None. 301 | vdim: total number of features in key. Default: None. 302 | 303 | Note: if kdim and vdim are None, they will be set to embed_dim such that 304 | query, key, and value have the same number of features. 305 | 306 | Examples:: 307 | 308 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 309 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 310 | """ 311 | 312 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): 313 | super(MultiheadAttention, self).__init__() 314 | self.embed_dim = embed_dim 315 | self.kdim = kdim if kdim is not None else embed_dim 316 | self.vdim = vdim if vdim is not None else embed_dim 317 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 318 | 319 | self.num_heads = num_heads 320 | self.dropout = dropout 321 | self.head_dim = embed_dim // num_heads 322 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 323 | 324 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 325 | 326 | if self._qkv_same_embed_dim is False: 327 | self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) 328 | self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) 329 | self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) 330 | 331 | if bias: 332 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) 333 | else: 334 | self.register_parameter('in_proj_bias', None) 335 | self.out_proj = Linear(embed_dim, embed_dim, bias=bias) 336 | 337 | if add_bias_kv: 338 | self.bias_k = Parameter(torch.empty(1, 1, embed_dim)) 339 | self.bias_v = Parameter(torch.empty(1, 1, embed_dim)) 340 | else: 341 | self.bias_k = self.bias_v = None 342 | 343 | self.add_zero_attn = add_zero_attn 344 | 345 | self._reset_parameters() 346 | print('==============') 347 | print('Using the MultiheadAttention file copied from PyTorch 1.3.0') 348 | print('==============') 349 | 350 | def _reset_parameters(self): 351 | if self._qkv_same_embed_dim: 352 | xavier_uniform_(self.in_proj_weight) 353 | else: 354 | xavier_uniform_(self.q_proj_weight) 355 | xavier_uniform_(self.k_proj_weight) 356 | xavier_uniform_(self.v_proj_weight) 357 | 358 | if self.in_proj_bias is not None: 359 | constant_(self.in_proj_bias, 0.) 360 | constant_(self.out_proj.bias, 0.) 361 | if self.bias_k is not None: 362 | xavier_normal_(self.bias_k) 363 | if self.bias_v is not None: 364 | xavier_normal_(self.bias_v) 365 | 366 | def forward(self, query, key, value, key_padding_mask=None, 367 | need_weights=True, attn_mask=None): 368 | r""" 369 | Args: 370 | query, key, value: map a query and a set of key-value pairs to an output. 371 | See "Attention Is All You Need" for more details. 
372 | key_padding_mask: if provided, specified padding elements in the key will 373 | be ignored by the attention. This is an binary mask. When the value is True, 374 | the corresponding value on the attention layer will be filled with -inf. 375 | need_weights: output attn_output_weights. 376 | attn_mask: mask that prevents attention to certain positions. This is an additive mask 377 | (i.e. the values will be added to the attention layer). 378 | 379 | Shape: 380 | - Inputs: 381 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 382 | the embedding dimension. 383 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 384 | the embedding dimension. 385 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 386 | the embedding dimension. 387 | - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. 388 | - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 389 | 390 | - Outputs: 391 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 392 | E is the embedding dimension. 393 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 394 | L is the target sequence length, S is the source sequence length. 395 | """ 396 | if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False: 397 | return multi_head_attention_forward( 398 | query, key, value, self.embed_dim, self.num_heads, 399 | self.in_proj_weight, self.in_proj_bias, 400 | self.bias_k, self.bias_v, self.add_zero_attn, 401 | self.dropout, self.out_proj.weight, self.out_proj.bias, 402 | training=self.training, 403 | key_padding_mask=key_padding_mask, need_weights=need_weights, 404 | attn_mask=attn_mask, use_separate_proj_weight=True, 405 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 406 | v_proj_weight=self.v_proj_weight) 407 | else: 408 | if not hasattr(self, '_qkv_same_embed_dim'): 409 | warnings.warn('A new version of MultiheadAttention module has been implemented. \ 410 | Please re-train your model with the new module', 411 | UserWarning) 412 | 413 | return multi_head_attention_forward( 414 | query, key, value, self.embed_dim, self.num_heads, 415 | self.in_proj_weight, self.in_proj_bias, 416 | self.bias_k, self.bias_v, self.add_zero_attn, 417 | self.dropout, self.out_proj.weight, self.out_proj.bias, 418 | training=self.training, 419 | key_padding_mask=key_padding_mask, need_weights=need_weights, 420 | attn_mask=attn_mask) --------------------------------------------------------------------------------
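
Note on the transplanted attention modules: the sketch below shows how the copied TransformerEncoder / TransformerEncoderLayer stack can be driven to obtain averaged attention weights alongside the encoder output, which is the behavioural difference from the stock nn.TransformerEncoder (the stock class returns only the output tensor). This is a minimal sketch; the import path, d_model, nhead, num_layers and tensor sizes are illustrative assumptions, not settings taken from this repository.

import torch
from transplant_attn.transformer_from_torch import TransformerEncoder, TransformerEncoderLayer

# Hypothetical sizes, chosen only for illustration.
encoder_layer = TransformerEncoderLayer(d_model=64, nhead=4)
encoder = TransformerEncoder(encoder_layer, num_layers=2)

src = torch.rand(100, 8, 64)         # (S, N, E): sequence length, batch size, embedding dimension
output, attn_weights = encoder(src)  # the copied encoder also returns the attention weights of its last layer

print(output.shape)        # torch.Size([100, 8, 64])
print(attn_weights.shape)  # torch.Size([8, 100, 100]) -- (N, L, S), averaged over attention heads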