├── Figure └── FIG1.png ├── core ├── __pycache__ │ ├── oan.cpython-36.pyc │ ├── data.cpython-36.pyc │ ├── loss.cpython-36.pyc │ ├── test.cpython-36.pyc │ ├── train.cpython-36.pyc │ ├── utils.cpython-36.pyc │ ├── config.cpython-36.pyc │ ├── config.cpython-37.pyc │ ├── logger.cpython-36.pyc │ ├── evaluation.cpython-36.pyc │ └── transformations.cpython-36.pyc ├── utils.py ├── main.py ├── train.py ├── loss.py ├── logger.py ├── evaluation.py ├── config.py ├── data.py ├── test.py ├── pgf.py └── transformations.py └── README.md /Figure/FIG1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/Figure/FIG1.png -------------------------------------------------------------------------------- /core/__pycache__/oan.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/oan.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/data.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/data.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/test.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/evaluation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/evaluation.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/transformations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guobaoxiao/PGFNet/HEAD/core/__pycache__/transformations.cpython-36.pyc -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from multiprocessing import Pool as ThreadPool 4 | 5 | def tocuda(data): 6 | # convert tensor data in dictionary to cuda when it is a tensor 7 | for key in data.keys(): 8 | if type(data[key]) == torch.Tensor: 9 | data[key] = data[key].cuda() 10 | return data 11 | 12 | def get_pool_result(num_processor, fun, args): 13 | pool = ThreadPool(num_processor) 14 | pool_res = pool.map(fun, args) 15 | pool.close() 16 | pool.join() 17 | return pool_res 18 | 19 | def np_skew_symmetric(v): 20 | 21 | zero = np.zeros_like(v[:, 0]) 22 | 23 | M = np.stack([ 24 | zero, -v[:, 2], v[:, 1], 25 | v[:, 2], zero, -v[:, 0], 26 | -v[:, 1], v[:, 0], zero, 27 | ], axis=1) 28 | 29 | return M 30 | 31 | def torch_skew_symmetric(v): 32 | 33 | zero = torch.zeros_like(v[:, 0]) 34 | 35 | M = torch.stack([ 36 | zero, -v[:, 2], v[:, 1], 37 | v[:, 2], zero, -v[:, 0], 38 | -v[:, 1], v[:, 0], zero, 39 | ], dim=1) 40 | 41 | return M 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PGFNet 2 | PGFNet: Preference-Guided Filtering Network for Two-View Correspondence Learning 3 | 4 | # PGFNet implementation 5 | Pytorch implementation of PGFNet 6 | 7 | ## Abstract 8 | 9 | 10 | Accurate correspondence selection between two images is of great importance for numerous feature matching based vision tasks. The initial correspondences established by off-the-shelf feature extraction methods usually contain a large number of outliers, and this often leads to the difficulty in accurately and sufficiently capturing contextual information for the correspondence learning task. In this paper, we propose a Preference-Guided Filtering Network (PGFNet) to address this problem. The proposed PGFNet is able to effectively select correct correspondences and simultaneously recover the accurate camera pose of matching images. Specifically, we first design a novel iterative filtering structure to learn the preference scores of correspondences for guiding the correspondence filtering strategy. This structure explicitly alleviates the negative effects of outliers so that our network is able to capture more reliable contextual information encoded by the inliers for network learning. Then, to enhance the reliability of preference scores, we present a simple yet effective Grouped Residual Attention block as our network backbone, by designing a feature grouping strategy, a feature grouping manner, a hierarchical residual-like manner and two grouped attention operations. We evaluate PGFNet by extensive ablation studies and comparative experiments on the tasks of outlier removal and camera pose estimation. The results demonstrate outstanding performance gains over the existing state-of-the-art methods on different challenging scenes. 11 | 12 | ## Requirements 13 | 14 | Please use Python 3.6, opencv-contrib-python (3.4.0.12) and Pytorch (>= 1.1.0). Other dependencies should be easily installed through pip or conda. 15 | 16 | ## Explanation 17 | 18 | If you need YFCC100M and SUN3D datasets, You can visit the code at https://github.com/zjhthu/OANet.git. We have uploaded the main code on 'core' folder. 19 | 20 | # Citing PGFNet 21 | If you find the PGFNet code useful, please consider citing: 22 | 23 | ```bibtex 24 | @article{liu2023pgfnet, 25 | title={Pgfnet: Preference-guided filtering network for two-view correspondence learning}, 26 | author={Liu, Xin and Xiao, Guobao and Chen, Riqing and Ma, Jiayi}, 27 | journal={IEEE Transactions on Image Processing}, 28 | volume={32}, 29 | pages={1367--1378}, 30 | year={2023}, 31 | publisher={IEEE} 32 | } 33 | ``` 34 | -------------------------------------------------------------------------------- /core/main.py: -------------------------------------------------------------------------------- 1 | from config import get_config, print_usage 2 | config, unparsed = get_config() 3 | import os 4 | os.environ['CUDA_VISIBLE_DEVICES'] = config.gpu_id 5 | import torch.utils.data 6 | import sys 7 | from data import collate_fn, CorrespondencesDataset 8 | from pgf import PGFNet as Model 9 | from train import train 10 | from test import test 11 | 12 | 13 | print("-------------------------Deep Essential-------------------------") 14 | print("Note: To combine datasets, use .") 15 | 16 | def create_log_dir(config): 17 | if not os.path.isdir(config.log_base): 18 | os.makedirs(config.log_base) 19 | if config.log_suffix == "": 20 | suffix = "-".join(sys.argv) 21 | #result_path = config.log_base+'/'+suffix 22 | result_path = './log' 23 | if not os.path.isdir(result_path): 24 | os.makedirs(result_path) 25 | if not os.path.isdir(result_path+'/train'): 26 | os.makedirs(result_path+'/train') 27 | if not os.path.isdir(result_path+'/valid'): 28 | os.makedirs(result_path+'/valid') 29 | if not os.path.isdir(result_path+'/test'): 30 | os.makedirs(result_path+'/test') 31 | if os.path.exists(result_path+'/config.th'): 32 | print('warning: will overwrite config file') 33 | torch.save(config, result_path+'/config.th') 34 | 35 | # path for saving traning logs 36 | config.log_path = result_path+'/train' 37 | 38 | def main(config): 39 | """The main function.""" 40 | 41 | # Initialize network 42 | model = Model(config) 43 | 44 | # Run propper mode 45 | if config.run_mode == "train": 46 | create_log_dir(config) 47 | 48 | train_dataset = CorrespondencesDataset(config.data_tr, config) 49 | 50 | train_loader = torch.utils.data.DataLoader( 51 | train_dataset, batch_size=config.train_batch_size, shuffle=True, 52 | num_workers=16, pin_memory=False, collate_fn=collate_fn) 53 | 54 | valid_dataset = CorrespondencesDataset(config.data_va, config) 55 | valid_loader = torch.utils.data.DataLoader( 56 | valid_dataset, batch_size=config.train_batch_size, shuffle=False, 57 | num_workers=8, pin_memory=False, collate_fn=collate_fn) 58 | #valid_loader = None 59 | print('start training .....') 60 | train(model, train_loader, valid_loader, config) 61 | 62 | elif config.run_mode == "test": 63 | test_dataset = CorrespondencesDataset(config.data_te, config) 64 | test_loader = torch.utils.data.DataLoader( 65 | test_dataset, batch_size=1, shuffle=False, 66 | num_workers=8, pin_memory=False, collate_fn=collate_fn) 67 | 68 | test(test_loader, model, config) 69 | 70 | 71 | 72 | 73 | if __name__ == "__main__": 74 | 75 | # ---------------------------------------- 76 | # Parse configuration 77 | config, unparsed = get_config() 78 | # If we have unparsed arguments, print usage and exit 79 | if len(unparsed) > 0: 80 | print_usage() 81 | exit(1) 82 | 83 | main(config) 84 | 85 | # 86 | # main.py ends here 87 | -------------------------------------------------------------------------------- /core/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | import sys 5 | from tqdm import trange 6 | import os 7 | from logger import Logger 8 | from test import valid 9 | from loss import MatchLoss 10 | from utils import tocuda 11 | 12 | 13 | def train_step(step, optimizer, model, match_loss, data): 14 | model.train() 15 | 16 | res_logits, res_e_hat = model(data) 17 | loss = 0 18 | loss_val = [] 19 | for i in range(len(res_logits)): 20 | loss_i, geo_loss, cla_loss, l2_loss, _, _ = match_loss.run(step, data, res_logits[i], res_e_hat[i]) 21 | loss += loss_i 22 | loss_val += [geo_loss, cla_loss, l2_loss] 23 | optimizer.zero_grad() 24 | loss.backward() 25 | for name, param in model.named_parameters(): 26 | if torch.any(torch.isnan(param.grad)): 27 | print('skip because nan') 28 | return loss_val 29 | 30 | optimizer.step() 31 | return loss_val 32 | 33 | 34 | def train(model, train_loader, valid_loader, config): 35 | model.cuda() 36 | optimizer = optim.Adam(model.parameters(), lr=config.train_lr, weight_decay = config.weight_decay) 37 | match_loss = MatchLoss(config) 38 | 39 | checkpoint_path = os.path.join(config.log_path, 'checkpoint.pth') 40 | config.resume = os.path.isfile(checkpoint_path) 41 | if config.resume: 42 | print('==> Resuming from checkpoint..') 43 | checkpoint = torch.load(checkpoint_path) 44 | best_acc = checkpoint['best_acc'] 45 | start_epoch = checkpoint['epoch'] 46 | model.load_state_dict(checkpoint['state_dict']) 47 | optimizer.load_state_dict(checkpoint['optimizer']) 48 | logger_train = Logger(os.path.join(config.log_path, 'log_train.txt'), title='oan', resume=True) 49 | logger_valid = Logger(os.path.join(config.log_path, 'log_valid.txt'), title='oan', resume=True) 50 | else: 51 | best_acc = -1 52 | start_epoch = 0 53 | logger_train = Logger(os.path.join(config.log_path, 'log_train.txt'), title='oan') 54 | logger_train.set_names(['Learning Rate'] + ['Geo Loss', 'Classfi Loss', 'L2 Loss']*(config.iter_num+1)) 55 | logger_valid = Logger(os.path.join(config.log_path, 'log_valid.txt'), title='oan') 56 | logger_valid.set_names(['Valid Acc'] + ['Geo Loss', 'Clasfi Loss', 'L2 Loss']) 57 | train_loader_iter = iter(train_loader) 58 | for step in trange(start_epoch, config.train_iter, ncols=config.tqdm_width): 59 | try: 60 | train_data = next(train_loader_iter) 61 | except StopIteration: 62 | train_loader_iter = iter(train_loader) 63 | train_data = next(train_loader_iter) 64 | train_data = tocuda(train_data) 65 | 66 | # run training 67 | cur_lr = optimizer.param_groups[0]['lr'] 68 | loss_vals = train_step(step, optimizer, model, match_loss, train_data) #训练 69 | logger_train.append([cur_lr] + loss_vals) 70 | 71 | # Check if we want to write validation 72 | b_save = ((step + 1) % config.save_intv) == 0 73 | b_validate = ((step + 1) % config.val_intv) == 0 74 | if b_validate: 75 | va_res, geo_loss, cla_loss, l2_loss, _, _, _ = valid(valid_loader, model, step, config) # 验证 76 | logger_valid.append([va_res, geo_loss, cla_loss, l2_loss]) 77 | if va_res > best_acc: 78 | print("Saving best model with va_res = {}".format(va_res)) 79 | best_acc = va_res 80 | torch.save({ 81 | 'epoch': step + 1, 82 | 'state_dict': model.state_dict(), 83 | 'best_acc': best_acc, 84 | 'optimizer' : optimizer.state_dict(), 85 | }, os.path.join(config.log_path, 'model_best.pth')) 86 | 87 | if b_save: 88 | torch.save({ 89 | 'epoch': step + 1, 90 | 'state_dict': model.state_dict(), 91 | 'best_acc': best_acc, 92 | 'optimizer' : optimizer.state_dict(), 93 | }, checkpoint_path) 94 | 95 | -------------------------------------------------------------------------------- /core/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import torch_skew_symmetric 3 | import numpy as np 4 | 5 | def batch_episym(x1, x2, F): 6 | batch_size, num_pts = x1.shape[0], x1.shape[1] 7 | x1 = torch.cat([x1, x1.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1) 8 | x2 = torch.cat([x2, x2.new_ones(batch_size, num_pts,1)], dim=-1).reshape(batch_size, num_pts,3,1) 9 | F = F.reshape(-1,1,3,3).repeat(1,num_pts,1,1) 10 | x2Fx1 = torch.matmul(x2.transpose(2,3), torch.matmul(F, x1)).reshape(batch_size,num_pts) 11 | Fx1 = torch.matmul(F,x1).reshape(batch_size,num_pts,3) 12 | Ftx2 = torch.matmul(F.transpose(2,3),x2).reshape(batch_size,num_pts,3) 13 | ys = x2Fx1**2 * ( 14 | 1.0 / (Fx1[:, :, 0]**2 + Fx1[:, :, 1]**2 + 1e-15) + 15 | 1.0 / (Ftx2[:, :, 0]**2 + Ftx2[:, :, 1]**2 + 1e-15)) 16 | return ys 17 | 18 | class MatchLoss(object): 19 | def __init__(self, config): 20 | self.loss_essential = config.loss_essential 21 | self.loss_classif = config.loss_classif 22 | self.use_fundamental = config.use_fundamental 23 | self.obj_geod_th = config.obj_geod_th 24 | self.geo_loss_margin = config.geo_loss_margin 25 | self.loss_essential_init_iter = config.loss_essential_init_iter 26 | 27 | def run(self, global_step, data, logits, e_hat): 28 | R_in, t_in, y_in, pts_virt = data['Rs'], data['ts'], data['ys'], data['virtPts'] 29 | 30 | # Get groundtruth Essential matrix 31 | e_gt_unnorm = torch.reshape(torch.matmul( 32 | torch.reshape(torch_skew_symmetric(t_in), (-1, 3, 3)), 33 | torch.reshape(R_in, (-1, 3, 3)) 34 | ), (-1, 9)) 35 | 36 | e_gt = e_gt_unnorm / torch.norm(e_gt_unnorm, dim=1, keepdim=True) 37 | 38 | ess_hat = e_hat 39 | if self.use_fundamental: 40 | ess_hat = torch.matmul(torch.matmul(data['T2s'].transpose(1,2), ess_hat.reshape(-1,3,3)),data['T1s']) 41 | # get essential matrix from fundamental matrix 42 | ess_hat = torch.matmul(torch.matmul(data['K2s'].transpose(1,2), ess_hat.reshape(-1,3,3)),data['K1s']).reshape(-1,9) 43 | ess_hat = ess_hat / torch.norm(ess_hat, dim=1, keepdim=True) 44 | 45 | 46 | # Essential/Fundamental matrix loss 47 | pts1_virts, pts2_virts = pts_virt[:, :, :2], pts_virt[:,:,2:] 48 | geod = batch_episym(pts1_virts, pts2_virts, e_hat) 49 | essential_loss = torch.min(geod, self.geo_loss_margin*geod.new_ones(geod.shape)) 50 | essential_loss = essential_loss.mean() 51 | # we do not use the l2 loss, just save the value for convenience 52 | L2_loss = torch.mean(torch.min( 53 | torch.sum(torch.pow(ess_hat - e_gt, 2), dim=1), 54 | torch.sum(torch.pow(ess_hat + e_gt, 2), dim=1) 55 | )) 56 | 57 | 58 | # Classification loss 59 | # The groundtruth epi sqr 60 | gt_geod_d = y_in[:, :, 0] 61 | is_pos = (gt_geod_d < self.obj_geod_th).type(logits.type()) 62 | is_neg = (gt_geod_d >= self.obj_geod_th).type(logits.type()) 63 | c = is_pos - is_neg 64 | classif_losses = -torch.log(torch.sigmoid(c * logits) + np.finfo(float).eps.item()) 65 | # balance 66 | num_pos = torch.relu(torch.sum(is_pos, dim=1) - 1.0) + 1.0 67 | num_neg = torch.relu(torch.sum(is_neg, dim=1) - 1.0) + 1.0 68 | classif_loss_p = torch.sum(classif_losses * is_pos, dim=1) 69 | classif_loss_n = torch.sum(classif_losses * is_neg, dim=1) 70 | classif_loss = torch.mean(classif_loss_p * 0.5 / num_pos + classif_loss_n * 0.5 / num_neg) 71 | 72 | 73 | precision = torch.mean( 74 | torch.sum((logits > 0).type(is_pos.type()) * is_pos, dim=1) / 75 | torch.sum((logits > 0).type(is_pos.type()) * (is_pos + is_neg), dim=1) 76 | ) 77 | recall = torch.mean( 78 | torch.sum((logits > 0).type(is_pos.type()) * is_pos, dim=1) / 79 | torch.sum(is_pos, dim=1) 80 | ) 81 | 82 | loss = 0 83 | # Check global_step and add essential loss 84 | if self.loss_essential > 0 and global_step >= self.loss_essential_init_iter: 85 | loss += self.loss_essential * essential_loss 86 | if self.loss_classif > 0: 87 | loss += self.loss_classif * classif_loss 88 | 89 | return [loss, (self.loss_essential * essential_loss).item(), (self.loss_classif * classif_loss).item(), L2_loss.item(), precision.item(), recall.item()] 90 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__ (self, paths): 86 | '''paths is a distionary with {name:filepath} pair''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') 128 | -------------------------------------------------------------------------------- /core/evaluation.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from transformations import quaternion_from_matrix 3 | import numpy as np 4 | 5 | 6 | def evaluate_R_t(R_gt, t_gt, R, t, q_gt=None): 7 | t = t.flatten() 8 | t_gt = t_gt.flatten() 9 | 10 | eps = 1e-15 11 | 12 | if q_gt is None: 13 | q_gt = quaternion_from_matrix(R_gt) 14 | q = quaternion_from_matrix(R) 15 | q = q / (np.linalg.norm(q) + eps) 16 | q_gt = q_gt / (np.linalg.norm(q_gt) + eps) 17 | loss_q = np.maximum(eps, (1.0 - np.sum(q * q_gt)**2)) 18 | err_q = np.arccos(1 - 2 * loss_q) 19 | 20 | # dR = np.dot(R, R_gt.T) 21 | # dt = t - np.dot(dR, t_gt) 22 | # dR = np.dot(R, R_gt.T) 23 | # dt = t - t_gt 24 | t = t / (np.linalg.norm(t) + eps) 25 | t_gt = t_gt / (np.linalg.norm(t_gt) + eps) 26 | loss_t = np.maximum(eps, (1.0 - np.sum(t * t_gt)**2)) 27 | err_t = np.arccos(np.sqrt(1 - loss_t)) 28 | 29 | if np.sum(np.isnan(err_q)) or np.sum(np.isnan(err_t)): 30 | # This should never happen! Debug here 31 | import IPython 32 | IPython.embed() 33 | 34 | return err_q, err_t 35 | 36 | 37 | def eval_nondecompose(p1s, p2s, E_hat, dR, dt, scores): 38 | 39 | # Use only the top 10% in terms of score to decompose, we can probably 40 | # implement a better way of doing this, but this should be just fine. 41 | num_top = len(scores) // 10 42 | num_top = max(1, num_top) 43 | th = np.sort(scores)[::-1][num_top] 44 | mask = scores >= th 45 | 46 | p1s_good = p1s[mask] 47 | p2s_good = p2s[mask] 48 | 49 | # Match types 50 | E_hat = E_hat.reshape(3, 3).astype(p1s.dtype) 51 | R, t = None, None 52 | if p1s_good.shape[0] >= 5: 53 | # Get the best E just in case we get multipl E from findEssentialMat 54 | num_inlier, R, t, mask_new = cv2.recoverPose( 55 | E_hat, p1s_good, p2s_good) 56 | try: 57 | err_q, err_t = evaluate_R_t(dR, dt, R, t) 58 | except: 59 | print("Failed in evaluation") 60 | print(E_hat) 61 | print(R) 62 | print(t) 63 | err_q = np.pi 64 | err_t = np.pi / 2 65 | #import pdb;pdb.set_trace() 66 | else: 67 | err_q = np.pi 68 | err_t = np.pi / 2 69 | 70 | loss_q = np.sqrt(0.5 * (1 - np.cos(err_q))) 71 | loss_t = np.sqrt(1.0 - np.cos(err_t)**2) 72 | 73 | # Change mask type 74 | mask = mask.flatten().astype(bool) 75 | 76 | mask_updated = mask.copy() 77 | if mask_new is not None: 78 | # Change mask type 79 | mask_new = mask_new.flatten().astype(bool) 80 | mask_updated[mask] = mask_new 81 | 82 | return err_q, err_t, loss_q, loss_t, np.sum(num_inlier), mask_updated, R, t 83 | 84 | 85 | def eval_decompose(p1s, p2s, dR, dt, mask=None, method=cv2.LMEDS, probs=None, 86 | weighted=False, use_prob=True): 87 | if mask is None: 88 | mask = np.ones((len(p1s),), dtype=bool) 89 | # Change mask type 90 | mask = mask.flatten().astype(bool) 91 | 92 | # Mask the ones that will not be used 93 | p1s_good = p1s[mask] 94 | p2s_good = p2s[mask] 95 | probs_good = None 96 | if probs is not None: 97 | probs_good = probs[mask] 98 | 99 | num_inlier = 0 100 | mask_new2 = None 101 | R, t = None, None 102 | if p1s_good.shape[0] >= 5: 103 | if probs is None and method != "MLESAC": 104 | # Change the threshold from 0.01 to 0.001 can largely imporve the results 105 | # For fundamental matrix estimation evaluation, we also transform the matrix to essential matrix. 106 | # This gives better results than using findFundamentalMat 107 | E, mask_new = cv2.findEssentialMat(p1s_good, p2s_good, method=method, threshold=0.001) 108 | 109 | else: 110 | pass 111 | if E is not None: 112 | new_RT = False 113 | # Get the best E just in case we get multipl E from 114 | # findEssentialMat 115 | for _E in np.split(E, len(E) / 3): 116 | _num_inlier, _R, _t, _mask_new2 = cv2.recoverPose( 117 | _E, p1s_good, p2s_good, mask=mask_new) 118 | if _num_inlier > num_inlier: 119 | num_inlier = _num_inlier 120 | R = _R 121 | t = _t 122 | mask_new2 = _mask_new2 123 | new_RT = True 124 | if new_RT: 125 | err_q, err_t = evaluate_R_t(dR, dt, R, t) 126 | else: 127 | err_q = np.pi 128 | err_t = np.pi / 2 129 | 130 | else: 131 | err_q = np.pi 132 | err_t = np.pi / 2 133 | else: 134 | err_q = np.pi 135 | err_t = np.pi / 2 136 | 137 | loss_q = np.sqrt(0.5 * (1 - np.cos(err_q))) 138 | loss_t = np.sqrt(1.0 - np.cos(err_t)**2) 139 | 140 | mask_updated = mask.copy() 141 | if mask_new2 is not None: 142 | # Change mask type 143 | mask_new2 = mask_new2.flatten().astype(bool) 144 | mask_updated[mask] = mask_new2 145 | 146 | return err_q, err_t, loss_q, loss_t, np.sum(num_inlier), mask_updated, R, t 147 | -------------------------------------------------------------------------------- /core/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | 5 | def str2bool(v): 6 | return v.lower() in ("true", "1") 7 | 8 | 9 | arg_lists = [] 10 | parser = argparse.ArgumentParser() 11 | 12 | 13 | def add_argument_group(name): 14 | arg = parser.add_argument_group(name) 15 | arg_lists.append(arg) 16 | return arg 17 | 18 | 19 | # ----------------------------------------------------------------------------- 20 | # Network 21 | net_arg = add_argument_group("Network") 22 | net_arg.add_argument( 23 | "--net_depth", type=int, default=12, help="" 24 | "number of layers. Default: 12") 25 | net_arg.add_argument( 26 | "--clusters", type=int, default=500, help="" 27 | "cluster number in OANet. Default: 500") 28 | net_arg.add_argument( 29 | "--iter_num", type=int, default=1, help="" 30 | "iteration number in the iterative network. Default: 1") 31 | net_arg.add_argument( 32 | "--net_channels", type=int, default=256, help="" 33 | "number of channels in a layer. Default: 256") 34 | net_arg.add_argument( 35 | "--use_fundamental", type=str2bool, default=False, help="" 36 | "train fundamental matrix estimation. Default: False") 37 | net_arg.add_argument( 38 | "--share", type=str2bool, default=False, help="" 39 | "share the parameter in iterative network. Default: False") 40 | net_arg.add_argument( 41 | "--use_ratio", type=int, default=0, help="" 42 | "use ratio test. 0: not use, 1: use before network, 2: use as side information. Default: 0") 43 | net_arg.add_argument( 44 | "--use_mutual", type=int, default=0, help="" 45 | "use matual nearest neighbor check. 0: not use, 1: use before network, 2: use as side information. Default: 0") 46 | net_arg.add_argument( 47 | "--ratio_test_th", type=float, default=0.8, help="" 48 | "ratio test threshold. Default: 0.8") 49 | 50 | # ----------------------------------------------------------------------------- 51 | # Data 52 | data_arg = add_argument_group("Data") 53 | data_arg.add_argument( 54 | "--data_tr", type=str, default='', help="" 55 | "name of the dataset for train") 56 | data_arg.add_argument( 57 | "--data_va", type=str, default='', help="" 58 | "name of the dataset for valid") 59 | data_arg.add_argument( 60 | "--data_te", type=str, default='', help="" 61 | "name of the unseen dataset for test") 62 | 63 | 64 | # ----------------------------------------------------------------------------- 65 | # Objective 66 | obj_arg = add_argument_group("obj") 67 | obj_arg.add_argument( 68 | "--obj_num_kp", type=int, default=2000, help="" 69 | "number of keypoints per image") 70 | obj_arg.add_argument( 71 | "--obj_top_k", type=int, default=-1, help="" 72 | "number of keypoints above the threshold to use for " 73 | "essential matrix estimation. put -1 to use all. ") 74 | obj_arg.add_argument( 75 | "--obj_geod_type", type=str, default="episym", 76 | choices=["sampson", "episqr", "episym"], help="" 77 | "type of geodesic distance") 78 | obj_arg.add_argument( 79 | "--obj_geod_th", type=float, default=1e-4, help="" 80 | "theshold for the good geodesic distance") 81 | 82 | 83 | # ----------------------------------------------------------------------------- 84 | # Loss 85 | loss_arg = add_argument_group("loss") 86 | loss_arg.add_argument( 87 | "--weight_decay", type=float, default=0, help="" 88 | "l2 decay") 89 | loss_arg.add_argument( 90 | "--momentum", type=float, default=0.9, help="" 91 | "momentum") 92 | loss_arg.add_argument( 93 | "--loss_classif", type=float, default=1.0, help="" 94 | "weight of the classification loss") 95 | loss_arg.add_argument( 96 | "--loss_essential", type=float, default=0.5, help="" 97 | "weight of the essential loss") 98 | loss_arg.add_argument( 99 | "--loss_essential_init_iter", type=int, default=20000, help="" 100 | "initial iterations to run only the classification loss") 101 | loss_arg.add_argument( 102 | "--geo_loss_margin", type=float, default=0.1, help="" 103 | "clamping margin in geometry loss") 104 | 105 | # ----------------------------------------------------------------------------- 106 | # Training 107 | train_arg = add_argument_group("Train") 108 | train_arg.add_argument( 109 | "--run_mode", type=str, default="train", help="" 110 | "run_mode") 111 | train_arg.add_argument( 112 | "--train_lr", type=float, default=1e-3, help="" 113 | "learning rate") 114 | train_arg.add_argument( 115 | "--train_batch_size", type=int, default=32, help="" 116 | "batch size") 117 | train_arg.add_argument( 118 | "--gpu_id", type=str, default='0', help='id(s) for CUDA_VISIBLE_DEVICES') 119 | train_arg.add_argument( 120 | "--num_processor", type=int, default=8, help='numbers of used cpu') 121 | train_arg.add_argument( 122 | "--train_iter", type=int, default=500000, help="" 123 | "training iterations to perform") 124 | train_arg.add_argument( 125 | "--log_base", type=str, default="./log/", help="" 126 | "save directory name inside results") 127 | train_arg.add_argument( 128 | "--log_suffix", type=str, default="", help="" 129 | "suffix of log dir") 130 | train_arg.add_argument( 131 | "--val_intv", type=int, default=10000, help="" 132 | "validation interval") 133 | train_arg.add_argument( 134 | "--save_intv", type=int, default=1000, help="" 135 | "summary interval") 136 | 137 | # ----------------------------------------------------------------------------- 138 | # Testing 139 | test_arg = add_argument_group("Test") 140 | test_arg.add_argument( 141 | "--use_ransac", type=str2bool, default=True, help="" 142 | "use ransac when testing?") 143 | test_arg.add_argument( 144 | "--model_path", type=str, default="", help="" 145 | "saved best model path for test") 146 | test_arg.add_argument( 147 | "--res_path", type=str, default="", help="" 148 | "path for saving results") 149 | 150 | 151 | # ----------------------------------------------------------------------------- 152 | # Visualization 153 | vis_arg = add_argument_group('Visualization') 154 | vis_arg.add_argument( 155 | "--tqdm_width", type=int, default=79, help="" 156 | "width of the tqdm bar" 157 | ) 158 | 159 | 160 | def get_config(): 161 | config, unparsed = parser.parse_known_args() 162 | return config, unparsed 163 | 164 | 165 | def print_usage(): 166 | parser.print_usage() 167 | 168 | # 169 | # config.py ends here 170 | -------------------------------------------------------------------------------- /core/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import h5py 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import torch.utils.data as data 8 | from utils import np_skew_symmetric 9 | 10 | 11 | def collate_fn(batch): 12 | batch_size = len(batch) 13 | numkps = np.array([sample['xs'].shape[1] for sample in batch]) 14 | cur_num_kp = int(numkps.min()) 15 | 16 | data = {} 17 | data['K1s'], data['K2s'], data['Rs'], \ 18 | data['ts'], data['xs'], data['ys'], data['T1s'], data['T2s'], data['virtPts'], data['sides'] = [], [], [], [], [], [], [], [], [], [] 19 | for sample in batch: 20 | data['K1s'].append(sample['K1']) 21 | data['K2s'].append(sample['K2']) 22 | data['T1s'].append(sample['T1']) 23 | data['T2s'].append(sample['T2']) 24 | data['Rs'].append(sample['R']) 25 | data['ts'].append(sample['t']) 26 | data['virtPts'].append(sample['virtPt']) 27 | if sample['xs'].shape[1] > cur_num_kp: 28 | sub_idx = np.random.choice(sample['xs'].shape[1], cur_num_kp) 29 | data['xs'].append(sample['xs'][:,sub_idx,:]) 30 | data['ys'].append(sample['ys'][sub_idx,:]) 31 | if sample['side'] != []: 32 | data['sides'].append(sample['side'][sub_idx,:]) 33 | else: 34 | data['xs'].append(sample['xs']) 35 | data['ys'].append(sample['ys']) 36 | if sample['side'] != []: 37 | data['sides'].append(sample['side']) 38 | 39 | 40 | for key in ['K1s', 'K2s', 'Rs', 'ts', 'xs', 'ys', 'T1s', 'T2s','virtPts']: 41 | data[key] = torch.from_numpy(np.stack(data[key])).float() 42 | if data['sides'] != []: 43 | data['sides'] = torch.from_numpy(np.stack(data['sides'])).float() 44 | return data 45 | 46 | 47 | 48 | class CorrespondencesDataset(data.Dataset): 49 | def __init__(self, filename, config): 50 | self.config = config 51 | self.filename = filename 52 | self.data = None 53 | 54 | def correctMatches(self, e_gt): 55 | step = 0.1 56 | xx,yy = np.meshgrid(np.arange(-1, 1, step), np.arange(-1, 1, step)) 57 | # Points in first image before projection 58 | pts1_virt_b = np.float32(np.vstack((xx.flatten(), yy.flatten())).T) 59 | # Points in second image before projection 60 | pts2_virt_b = np.float32(pts1_virt_b) 61 | pts1_virt_b, pts2_virt_b = pts1_virt_b.reshape(1,-1,2), pts2_virt_b.reshape(1,-1,2) 62 | 63 | pts1_virt_b, pts2_virt_b = cv2.correctMatches(e_gt.reshape(3,3), pts1_virt_b, pts2_virt_b) 64 | 65 | return pts1_virt_b.squeeze(), pts2_virt_b.squeeze() 66 | 67 | def norm_input(self, x): 68 | x_mean = np.mean(x, axis=0) 69 | dist = x - x_mean 70 | meandist = np.sqrt((dist**2).sum(axis=1)).mean() 71 | scale = np.sqrt(2) / meandist 72 | T = np.zeros([3,3]) 73 | T[0,0], T[1,1], T[2,2] = scale, scale, 1 74 | T[0,2], T[1,2] = -scale*x_mean[0], -scale*x_mean[1] 75 | x = x * np.asarray([T[0,0], T[1,1]]) + np.array([T[0,2], T[1,2]]) 76 | return x, T 77 | 78 | def __getitem__(self, index): 79 | if self.data is None: 80 | self.data = h5py.File(self.filename,'r') 81 | 82 | xs = np.asarray(self.data['xs'][str(index)]) 83 | ys = np.asarray(self.data['ys'][str(index)]) 84 | R = np.asarray(self.data['Rs'][str(index)]) 85 | t = np.asarray(self.data['ts'][str(index)]) 86 | side = [] 87 | if self.config.use_ratio == 0 and self.config.use_mutual == 0: 88 | pass 89 | elif self.config.use_ratio == 1 and self.config.use_mutual == 0: 90 | mask = np.asarray(self.data['ratios'][str(index)]).reshape(-1) < config.ratio_test_th 91 | xs = xs[:,mask,:] 92 | ys = ys[:,mask] 93 | elif self.config.use_ratio == 0 and self.config.use_mutual == 1: 94 | mask = np.asarray(self.data['mutuals'][str(index)]).reshape(-1).astype(bool) 95 | xs = xs[:,mask,:] 96 | ys = ys[:,mask] 97 | elif self.config.use_ratio == 2 and self.config.use_mutual == 2: 98 | side.append(np.asarray(self.data['ratios'][str(index)]).reshape(-1,1)) 99 | side.append(np.asarray(self.data['mutuals'][str(index)]).reshape(-1,1)) 100 | side = np.concatenate(side,axis=-1) 101 | else: 102 | raise NotImplementedError 103 | 104 | 105 | e_gt_unnorm = np.reshape(np.matmul( 106 | np.reshape(np_skew_symmetric(t.astype('float64').reshape(1,3)), (3, 3)), np.reshape(R.astype('float64'), (3, 3))), (3, 3)) 107 | e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm) 108 | 109 | if self.config.use_fundamental: 110 | cx1 = np.asarray(self.data['cx1s'][str(index)]) 111 | cy1 = np.asarray(self.data['cy1s'][str(index)]) 112 | cx2 = np.asarray(self.data['cx2s'][str(index)]) 113 | cy2 = np.asarray(self.data['cy2s'][str(index)]) 114 | f1 = np.asarray(self.data['f1s'][str(index)]) 115 | f2 = np.asarray(self.data['f2s'][str(index)]) 116 | K1 = np.asarray([ 117 | [f1[0], 0, cx1[0]], 118 | [0, f1[1], cy1[0]], 119 | [0, 0, 1] 120 | ]) 121 | K2 = np.asarray([ 122 | [f2[0], 0, cx2[0]], 123 | [0, f2[1], cy2[0]], 124 | [0, 0, 1] 125 | ]) 126 | x1, x2 = xs[0,:,:2], xs[0,:,2:4] 127 | x1 = x1 * np.asarray([K1[0,0], K1[1,1]]) + np.array([K1[0,2], K1[1,2]]) 128 | x2 = x2 * np.asarray([K2[0,0], K2[1,1]]) + np.array([K2[0,2], K2[1,2]]) 129 | # norm input 130 | x1, T1 = self.norm_input(x1) 131 | x2, T2 = self.norm_input(x2) 132 | 133 | xs = np.concatenate([x1,x2],axis=-1).reshape(1,-1,4) 134 | # get F 135 | e_gt = np.matmul(np.matmul(np.linalg.inv(K2).T, e_gt), np.linalg.inv(K1)) 136 | # get F after norm 137 | e_gt_unnorm = np.matmul(np.matmul(np.linalg.inv(T2).T, e_gt), np.linalg.inv(T1)) 138 | e_gt = e_gt_unnorm / np.linalg.norm(e_gt_unnorm) 139 | else: 140 | K1, K2 = np.zeros(1), np.zeros(1) 141 | T1, T2 = np.zeros(1), np.zeros(1) 142 | 143 | pts1_virt, pts2_virt = self.correctMatches(e_gt) 144 | 145 | pts_virt = np.concatenate([pts1_virt, pts2_virt], axis=1).astype('float64') 146 | return {'K1':K1, 'K2':K2, 'R':R, 't':t, \ 147 | 'xs':xs, 'ys':ys, 'T1':T1, 'T2':T2, 'virtPt':pts_virt, 'side':side} 148 | 149 | def reset(self): 150 | if self.data is not None: 151 | self.data.close() 152 | self.data = None 153 | 154 | def __len__(self): 155 | if self.data is None: 156 | self.data = h5py.File(self.filename,'r') 157 | _len = len(self.data['xs']) 158 | self.data.close() 159 | self.data = None 160 | else: 161 | _len = len(self.data['xs']) 162 | return _len 163 | 164 | def __del__(self): 165 | if self.data is not None: 166 | self.data.close() 167 | 168 | -------------------------------------------------------------------------------- /core/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | import cv2 5 | from six.moves import xrange 6 | from loss import MatchLoss 7 | from evaluation import eval_nondecompose, eval_decompose 8 | from utils import tocuda, get_pool_result 9 | 10 | 11 | def test_sample(args): 12 | _xs, _dR, _dt, _e_hat, _y_hat, _y_gt, config, = args 13 | _xs = _xs.reshape(-1, 4).astype('float64') 14 | _dR, _dt = _dR.astype('float64').reshape(3,3), _dt.astype('float64') 15 | _y_hat_out = _y_hat.flatten().astype('float64') 16 | e_hat_out = _e_hat.flatten().astype('float64') 17 | 18 | _x1 = _xs[:, :2] 19 | _x2 = _xs[:, 2:] 20 | # current validity from network 21 | _valid = _y_hat_out 22 | # choose top ones (get validity threshold) 23 | _valid_th = np.sort(_valid)[::-1][config.obj_top_k] 24 | _mask_before = _valid >= max(0, _valid_th) 25 | 26 | if not config.use_ransac: 27 | _err_q, _err_t, _, _, _num_inlier, _mask_updated, _R_hat, _t_hat = \ 28 | eval_nondecompose(_x1, _x2, e_hat_out, _dR, _dt, _y_hat_out) 29 | else: 30 | # actually not use prob here since probs is None 31 | _err_q, _err_t, _, _, _num_inlier, _mask_updated, _R_hat, _t_hat = \ 32 | eval_decompose(_x1, _x2, _dR, _dt, mask=_mask_before, method=cv2.RANSAC, \ 33 | probs=None, weighted=False, use_prob=True) 34 | if _R_hat is None: 35 | _R_hat = np.random.randn(3,3) 36 | _t_hat = np.random.randn(3,1) 37 | return [float(_err_q), float(_err_t), float(_num_inlier), _R_hat.reshape(1,-1), _t_hat.reshape(1,-1)] 38 | 39 | def dump_res(measure_list, res_path, eval_res, tag): 40 | # dump test results 41 | for sub_tag in measure_list: 42 | # For median error 43 | ofn = os.path.join(res_path, "median_{}_{}.txt".format(sub_tag, tag)) 44 | with open(ofn, "w") as ofp: 45 | ofp.write("{}\n".format(np.median(eval_res[sub_tag]))) 46 | 47 | ths = np.arange(7) * 5 48 | cur_err_q = np.array(eval_res["err_q"]) * 180.0 / np.pi 49 | cur_err_t = np.array(eval_res["err_t"]) * 180.0 / np.pi 50 | # Get histogram 51 | q_acc_hist, _ = np.histogram(cur_err_q, ths) 52 | t_acc_hist, _ = np.histogram(cur_err_t, ths) 53 | qt_acc_hist, _ = np.histogram(np.maximum(cur_err_q, cur_err_t), ths) 54 | num_pair = float(len(cur_err_q)) 55 | q_acc_hist = q_acc_hist.astype(float) / num_pair 56 | t_acc_hist = t_acc_hist.astype(float) / num_pair 57 | qt_acc_hist = qt_acc_hist.astype(float) / num_pair 58 | q_acc = np.cumsum(q_acc_hist) 59 | t_acc = np.cumsum(t_acc_hist) 60 | qt_acc = np.cumsum(qt_acc_hist) 61 | # Store return val 62 | for _idx_th in xrange(1, len(ths)): 63 | ofn = os.path.join(res_path, "acc_q_auc{}_{}.txt".format(ths[_idx_th], tag)) 64 | with open(ofn, "w") as ofp: 65 | ofp.write("{}\n".format(np.mean(q_acc[:_idx_th]))) 66 | ofn = os.path.join(res_path, "acc_t_auc{}_{}.txt".format(ths[_idx_th], tag)) 67 | with open(ofn, "w") as ofp: 68 | ofp.write("{}\n".format(np.mean(t_acc[:_idx_th]))) 69 | ofn = os.path.join(res_path, "acc_qt_auc{}_{}.txt".format(ths[_idx_th], tag)) 70 | with open(ofn, "w") as ofp: 71 | ofp.write("{}\n".format(np.mean(qt_acc[:_idx_th]))) 72 | 73 | ofn = os.path.join(res_path, "all_acc_qt_auc20_{}.txt".format(tag)) 74 | np.savetxt(ofn, np.maximum(cur_err_q, cur_err_t)) 75 | ofn = os.path.join(res_path, "all_acc_q_auc20_{}.txt".format(tag)) 76 | np.savetxt(ofn, cur_err_q) 77 | ofn = os.path.join(res_path, "all_acc_t_auc20_{}.txt".format(tag)) 78 | np.savetxt(ofn, cur_err_t) 79 | 80 | # Return qt_auc20 81 | ret_val = np.mean(qt_acc[:4]) # 1 == 5 82 | return ret_val 83 | 84 | def denorm(x, T): 85 | x = (x - np.array([T[0,2], T[1,2]])) / np.asarray([T[0,0], T[1,1]]) 86 | return x 87 | 88 | def test_process(mode, model, cur_global_step, data_loader, config): 89 | model.eval() 90 | match_loss = MatchLoss(config) 91 | loader_iter = iter(data_loader) 92 | 93 | # save info given by the network 94 | network_infor_list = ["geo_losses", "cla_losses", "l2_losses", 'precisions', 'recalls', 'f_scores'] 95 | network_info = {info:[] for info in network_infor_list} 96 | 97 | results, pool_arg = [], [] 98 | eval_step, eval_step_i, num_processor = 100, 0, 8 99 | with torch.no_grad(): 100 | for test_data in loader_iter: 101 | test_data = tocuda(test_data) 102 | res_logits, res_e_hat = model(test_data) 103 | y_hat, e_hat = res_logits[-1], res_e_hat[-1] 104 | loss, geo_loss, cla_loss, l2_loss, prec, rec = match_loss.run(cur_global_step, test_data, y_hat, e_hat) 105 | info = [geo_loss, cla_loss, l2_loss, prec, rec, 2*prec*rec/(prec+rec+1e-15)] 106 | for info_idx, value in enumerate(info): 107 | network_info[network_infor_list[info_idx]].append(value) 108 | 109 | if config.use_fundamental: 110 | # unnorm F 111 | e_hat = torch.matmul(torch.matmul(test_data['T2s'].transpose(1,2), e_hat.reshape(-1,3,3)),test_data['T1s']) 112 | # get essential matrix from fundamental matrix 113 | e_hat = torch.matmul(torch.matmul(test_data['K2s'].transpose(1,2), e_hat.reshape(-1,3,3)),test_data['K1s']).reshape(-1,9) 114 | e_hat = e_hat / torch.norm(e_hat, dim=1, keepdim=True) 115 | 116 | for batch_idx in range(e_hat.shape[0]): 117 | test_xs = test_data['xs'][batch_idx].detach().cpu().numpy() 118 | if config.use_fundamental: # back to original 119 | x1, x2 = test_xs[0,:,:2], test_xs[0,:,2:4] 120 | T1, T2 = test_data['T1s'][batch_idx].cpu().numpy(), test_data['T2s'][batch_idx].cpu().numpy() 121 | x1, x2 = denorm(x1, T1), denorm(x2, T2) # denormalize coordinate 122 | K1, K2 = test_data['K1s'][batch_idx].cpu().numpy(), test_data['K2s'][batch_idx].cpu().numpy() 123 | x1, x2 = denorm(x1, K1), denorm(x2, K2) # normalize coordiante with intrinsic 124 | test_xs = np.concatenate([x1,x2],axis=-1).reshape(1,-1,4) 125 | 126 | pool_arg += [(test_xs, test_data['Rs'][batch_idx].detach().cpu().numpy(), \ 127 | test_data['ts'][batch_idx].detach().cpu().numpy(), e_hat[batch_idx].detach().cpu().numpy(), \ 128 | y_hat[batch_idx].detach().cpu().numpy(), \ 129 | test_data['ys'][batch_idx,:,0].detach().cpu().numpy(), config)] 130 | 131 | eval_step_i += 1 132 | if eval_step_i % eval_step == 0: 133 | results += get_pool_result(num_processor, test_sample, pool_arg) 134 | pool_arg = [] 135 | if len(pool_arg) > 0: 136 | results += get_pool_result(num_processor, test_sample, pool_arg) 137 | 138 | measure_list = ["err_q", "err_t", "num", 'R_hat', 't_hat'] 139 | eval_res = {} 140 | for measure_idx, measure in enumerate(measure_list): 141 | eval_res[measure] = np.asarray([result[measure_idx] for result in results]) 142 | 143 | if config.res_path == '': 144 | config.res_path = os.path.join(config.log_path[:-5], mode) 145 | tag = "ours" if not config.use_ransac else "ours_ransac" 146 | ret_val = dump_res(measure_list, config.res_path, eval_res, tag) 147 | return [ret_val, np.mean(np.asarray(network_info['geo_losses'])), np.mean(np.asarray(network_info['cla_losses'])), \ 148 | np.mean(np.asarray(network_info['l2_losses'])), np.mean(np.asarray(network_info['precisions'])), \ 149 | np.mean(np.asarray(network_info['recalls'])), np.mean(np.asarray(network_info['f_scores']))] 150 | 151 | 152 | def test(data_loader, model, config): 153 | save_file_best = os.path.join(config.model_path, 'model_best.pth') 154 | if not os.path.exists(save_file_best): 155 | print("Model File {} does not exist! Quiting".format(save_file_best)) 156 | exit(1) 157 | # Restore model 158 | checkpoint = torch.load(save_file_best) 159 | start_epoch = checkpoint['epoch'] 160 | model.load_state_dict(checkpoint['state_dict']) 161 | model.cuda() 162 | print("Restoring from " + str(save_file_best) + ', ' + str(start_epoch) + "epoch...\n") 163 | if config.res_path == '': 164 | config.res_path = config.model_path[:-5]+'test' 165 | print('save result to '+config.res_path) 166 | va_res = test_process("test", model, 0, data_loader, config) 167 | print('test result '+str(va_res)) 168 | def valid(data_loader, model, step, config): 169 | config.use_ransac = False 170 | return test_process("valid", model, step, data_loader, config) 171 | 172 | -------------------------------------------------------------------------------- /core/pgf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from loss import batch_episym 4 | from torch.nn import functional as F 5 | import numpy as np 6 | 7 | 8 | class GRA_block(nn.Module): 9 | def __init__(self, channels, out_channels=None): 10 | nn.Module.__init__(self) 11 | sub_channels = channels // 4 12 | 13 | 14 | self.spatial_att = nn.Sequential( 15 | nn.Conv2d(2, 1, kernel_size=1), 16 | nn.BatchNorm2d(1) 17 | ) 18 | 19 | self.channel_att = nn.Sequential( 20 | nn.Conv2d(channels, sub_channels, kernel_size=1), 21 | nn.BatchNorm2d(sub_channels), 22 | nn.ReLU(), 23 | nn.Conv2d(sub_channels, channels, kernel_size=1), 24 | nn.BatchNorm2d(channels), 25 | ) 26 | 27 | self.conv1 = nn.Sequential( 28 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 29 | nn.BatchNorm2d(sub_channels), 30 | nn.ReLU(), 31 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1), 32 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 33 | nn.BatchNorm2d(sub_channels), 34 | nn.ReLU(), 35 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1) 36 | ) 37 | self.conv2 = nn.Sequential( 38 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 39 | nn.BatchNorm2d(sub_channels), 40 | nn.ReLU(), 41 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1), 42 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 43 | nn.BatchNorm2d(sub_channels), 44 | nn.ReLU(), 45 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1) 46 | ) 47 | self.conv3 = nn.Sequential( 48 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 49 | nn.BatchNorm2d(sub_channels), 50 | nn.ReLU(), 51 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1), 52 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 53 | nn.BatchNorm2d(sub_channels), 54 | nn.ReLU(), 55 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1) 56 | ) 57 | self.conv4 = nn.Sequential( 58 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 59 | nn.BatchNorm2d(sub_channels), 60 | nn.ReLU(), 61 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1), 62 | nn.InstanceNorm2d(sub_channels, eps=1e-3), 63 | nn.BatchNorm2d(sub_channels), 64 | nn.ReLU(), 65 | nn.Conv2d(sub_channels, sub_channels, kernel_size=1) 66 | ) 67 | 68 | def forward(self, x): 69 | spx = torch.split(x, 64, 1) 70 | #x_spatial = self.spatial_attpre(spx[0]) 71 | 72 | x_spatial = spx[0] 73 | x_spatial = torch.cat((torch.max(x_spatial, 1)[0].unsqueeze(1), torch.mean(x_spatial, 1).unsqueeze(1)), dim=1) 74 | x_spatial = self.spatial_att(x_spatial) 75 | scale_sa = torch.sigmoid(x_spatial) 76 | 77 | sp1 = spx[0] * scale_sa 78 | sp1 = self.conv1(sp1) 79 | 80 | sp2 = sp1 + spx[1] * scale_sa 81 | sp2 = self.conv2(sp2) 82 | 83 | sp3 = sp2 + spx[2] * scale_sa 84 | sp3 = self.conv3(sp3) 85 | 86 | sp4 = sp3 + spx[3] * scale_sa 87 | sp4 = self.conv4(sp4) 88 | 89 | cat = torch.cat((sp1, sp2, sp3, sp4), 1) 90 | out = cat 91 | 92 | xag = F.avg_pool2d(out, (out.size(2), x.size(3)), stride=(out.size(2), x.size(3))) 93 | xmg = F.max_pool2d(out, (out.size(2), x.size(3)), stride=(out.size(2), x.size(3))) 94 | 95 | xam = self.channel_att(xag + xmg) 96 | 97 | scale_ca = torch.sigmoid(xam) 98 | out = out * scale_ca 99 | 100 | out = out + x 101 | out = shuffle_chnls(out, 4) 102 | return out 103 | 104 | class PointCN_down128(nn.Module): 105 | def __init__(self, channels, out_channels=None): 106 | nn.Module.__init__(self) 107 | self.conv = nn.Sequential( 108 | nn.InstanceNorm2d(channels , eps=1e-3), 109 | nn.BatchNorm2d(channels), 110 | nn.ReLU(), 111 | nn.Conv2d(channels, channels // 2, kernel_size=1) 112 | ) 113 | 114 | def forward(self, x): 115 | out = self.conv(x) 116 | 117 | return out 118 | 119 | 120 | class PointCN_256(nn.Module): 121 | def __init__(self, channels, out_channels=None): 122 | nn.Module.__init__(self) 123 | self.conv = nn.Sequential( 124 | nn.InstanceNorm2d(channels, eps=1e-3), 125 | nn.BatchNorm2d(channels), 126 | nn.ReLU(), 127 | nn.Conv2d(channels, channels, kernel_size=1) 128 | ) 129 | 130 | def forward(self, x): 131 | out = self.conv(x) 132 | 133 | return out 134 | 135 | 136 | class diff_pool(nn.Module): 137 | def __init__(self, in_channel, output_points): 138 | nn.Module.__init__(self) 139 | self.output_points = output_points 140 | self.conv = nn.Sequential( 141 | nn.InstanceNorm2d(in_channel, eps=1e-3), 142 | nn.BatchNorm2d(in_channel), 143 | nn.ReLU(), 144 | nn.Conv2d(in_channel, output_points, kernel_size=1)) 145 | 146 | def forward(self, x): 147 | embed = self.conv(x) # b*k*n*1 148 | S = torch.softmax(embed, dim=2).squeeze(3) 149 | out = torch.matmul(x.squeeze(3), S.transpose(1, 2)).unsqueeze(3) 150 | return out 151 | 152 | 153 | class diff_unpool(nn.Module): 154 | def __init__(self, in_channel, output_points): 155 | nn.Module.__init__(self) 156 | self.output_points = output_points 157 | self.conv = nn.Sequential( 158 | nn.InstanceNorm2d(in_channel, eps=1e-3), 159 | nn.BatchNorm2d(in_channel), 160 | nn.ReLU(), 161 | nn.Conv2d(in_channel, output_points, kernel_size=1)) 162 | 163 | def forward(self, x_up, x_down): 164 | # x_up: b*c*n*1 165 | # x_down: b*c*k*1 166 | embed = self.conv(x_up) # b*k*n*1 167 | S = torch.softmax(embed, dim=1).squeeze(3) # b*k*n 168 | out = torch.matmul(x_down.squeeze(3), S).unsqueeze(3) 169 | return out 170 | 171 | 172 | class trans(nn.Module): 173 | def __init__(self, dim1, dim2): 174 | nn.Module.__init__(self) 175 | self.dim1 = dim1 176 | self.dim2 = dim2 177 | 178 | def forward(self, x): 179 | return x.transpose(self.dim1, self.dim2) 180 | 181 | 182 | class OAFilter(nn.Module): 183 | def __init__(self, channels, points, out_channels=None): 184 | nn.Module.__init__(self) 185 | if not out_channels: 186 | out_channels = channels 187 | self.shot_cut = None 188 | if out_channels != channels: 189 | self.shot_cut = nn.Conv2d(channels, out_channels, kernel_size=1) 190 | self.conv1 = nn.Sequential( 191 | nn.InstanceNorm2d(channels, eps=1e-3), 192 | nn.BatchNorm2d(channels), 193 | nn.ReLU(), 194 | nn.Conv2d(channels, out_channels, kernel_size=1), # b*c*n*1 195 | trans(1, 2)) 196 | # Spatial Correlation Layer 197 | self.conv2 = nn.Sequential( 198 | nn.BatchNorm2d(points), 199 | nn.ReLU(), 200 | nn.Conv2d(points, points, kernel_size=1) 201 | ) 202 | self.conv3 = nn.Sequential( 203 | trans(1, 2), 204 | nn.InstanceNorm2d(out_channels, eps=1e-3), 205 | nn.BatchNorm2d(out_channels), 206 | nn.ReLU(), 207 | nn.Conv2d(out_channels, out_channels, kernel_size=1) 208 | ) 209 | 210 | def forward(self, x): 211 | out = self.conv1(x) 212 | out = out + self.conv2(out) 213 | out = self.conv3(out) 214 | if self.shot_cut: 215 | out = out + self.shot_cut(x) 216 | else: 217 | out = out + x 218 | return out 219 | 220 | 221 | class GRAModule(nn.Module): 222 | def __init__(self, net_channels, input_channel, depth, clusters): 223 | nn.Module.__init__(self) 224 | channels = net_channels 225 | self.layer_num = depth 226 | print('channels:' + str(channels) + ', layer_num:' + str(self.layer_num)) 227 | self.conv1 = nn.Conv2d(input_channel, channels, kernel_size=1) 228 | 229 | l2_nums = clusters 230 | 231 | self.l1_1 = [] 232 | for _ in range(self.layer_num // 2): 233 | self.l1_1.append(GRA_block(channels)) 234 | 235 | self.l1_down128 = PointCN_down128(channels) 236 | self.down1 = diff_pool(channels//2, l2_nums) 237 | 238 | self.l2 = [] 239 | for _ in range(self.layer_num // 2): 240 | self.l2.append(OAFilter(channels //2, l2_nums)) 241 | 242 | self.up1 = diff_unpool(channels//2, l2_nums) 243 | 244 | self.l1_2 = [] 245 | self.l1_exchange = PointCN_256(channels) 246 | 247 | for _ in range(self.layer_num // 2): 248 | self.l1_2.append(GRA_block(channels)) 249 | 250 | self.l1_1 = nn.Sequential(*self.l1_1) 251 | self.l1_2 = nn.Sequential(*self.l1_2) 252 | self.l2 = nn.Sequential(*self.l2) 253 | 254 | self.output = nn.Conv2d(channels, 1, kernel_size=1) 255 | 256 | def forward(self, data, xs): 257 | # data: b*c*n*1 258 | batch_size, num_pts = data.shape[0], data.shape[2] 259 | x1_1 = self.conv1(data) 260 | x1_1 = self.l1_1(x1_1) 261 | 262 | x1_1 = self.l1_down128(x1_1) 263 | 264 | x_down = self.down1(x1_1) 265 | x2 = self.l2(x_down) 266 | x_up = self.up1(x1_1, x2) 267 | x1_2 = self.l1_exchange(torch.cat([x1_1, x_up], dim=1)) 268 | out = self.l1_2(x1_2) 269 | 270 | logits = torch.squeeze(torch.squeeze(self.output(out), 3), 1) 271 | e_hat = weighted_8points(xs, logits) 272 | 273 | x1, x2 = xs[:, 0, :, :2], xs[:, 0, :, 2:4] 274 | e_hat_norm = e_hat 275 | residual = batch_episym(x1, x2, e_hat_norm).reshape(batch_size, 1, num_pts, 1) 276 | 277 | return logits, e_hat, residual 278 | 279 | 280 | class PGFNet(nn.Module): 281 | def __init__(self, config): 282 | nn.Module.__init__(self) 283 | self.iter_num = config.iter_num 284 | depth_each_stage = config.net_depth // (config.iter_num + 1) 285 | self.side_channel = (config.use_ratio == 2) + (config.use_mutual == 2) 286 | self.weights_init = GRAModule(config.net_channels, 4 + self.side_channel, depth_each_stage, config.clusters) 287 | self.weights_iter = [GRAModule(config.net_channels, 6 + self.side_channel, depth_each_stage, config.clusters) for 288 | _ in range(config.iter_num)] 289 | self.weights_iter = nn.Sequential(*self.weights_iter) 290 | self.gamma = nn.Parameter(torch.zeros(1)) 291 | 292 | def forward(self, data): 293 | assert data['xs'].dim() == 4 and data['xs'].shape[1] == 1 294 | batch_size, num_pts = data['xs'].shape[0], data['xs'].shape[2] 295 | # data: b*1*n*c 296 | # x_weight = self.positon(data['xs']) 297 | input = data['xs'].transpose(1, 3) 298 | if self.side_channel > 0: 299 | sides = data['sides'].transpose(1, 2).unsqueeze(3) 300 | input = torch.cat([input, sides], dim=1) 301 | 302 | res_logits, res_e_hat = [], [] 303 | logits, e_hat, residual = self.weights_init(input, data['xs']) 304 | res_logits.append(logits), res_e_hat.append(e_hat) 305 | logits_1 = logits 306 | logits_2 = torch.zeros(logits.shape).cuda() 307 | 308 | index_k = logits.topk(k=num_pts//2, dim=-1)[1] 309 | input_new = torch.stack( 310 | [ input[i].squeeze().transpose(0, 1)[index_k[i]] for i in range(input.size(0))]).unsqueeze(-1).transpose(1, 2) 311 | 312 | residual_new = torch.stack( 313 | [ residual[i].squeeze(0)[index_k[i]] for i in range(residual.size(0))]).unsqueeze(1) 314 | 315 | logits_new = logits.reshape(residual.shape) 316 | logits_new = torch.stack( 317 | [ logits_new[i].squeeze(0)[index_k[i]] for i in range(logits_new.size(0))]).unsqueeze(1) 318 | 319 | data_new = torch.stack( 320 | [ data['xs'][i].squeeze(0)[index_k[i]] for i in range(input.size(0))]).unsqueeze(1) 321 | 322 | 323 | 324 | for i in range(self.iter_num): 325 | logits, e_hat, residual = self.weights_iter[i]( 326 | torch.cat([input_new, residual_new.detach(), torch.relu(torch.tanh(logits_new)).detach()], 327 | dim=1), 328 | data_new) 329 | '''for i in range(logits_2.size(0)): 330 | for j in range(logits.size(-1)): 331 | logits_2[i][index_k[i][j]] = logits[i][j]''' 332 | logits_2.scatter_(1, index_k, logits) 333 | 334 | logits_2 = logits_2 + self.gamma*logits_1 335 | e_hat = weighted_8points(data['xs'], logits_2) 336 | 337 | 338 | res_logits.append(logits_2), res_e_hat.append(e_hat) 339 | #print(self.gamma) 340 | return res_logits, res_e_hat 341 | 342 | def batch_symeig(X): 343 | # it is much faster to run symeig on CPU 344 | X = X.cpu() 345 | b, d, _ = X.size() 346 | bv = X.new(b, d, d) 347 | for batch_idx in range(X.shape[0]): 348 | e, v = torch.symeig(X[batch_idx, :, :].squeeze(), True) 349 | bv[batch_idx, :, :] = v 350 | bv = bv.cuda() 351 | return bv 352 | 353 | 354 | def weighted_8points(x_in, logits): 355 | # x_in: batch * 1 * N * 4 356 | x_shp = x_in.shape 357 | # Turn into weights for each sample 358 | weights = torch.relu(torch.tanh(logits)) 359 | x_in = x_in.squeeze(1) 360 | 361 | # Make input data (num_img_pair x num_corr x 4) 362 | xx = torch.reshape(x_in, (x_shp[0], x_shp[2], 4)).permute(0, 2, 1) 363 | 364 | # Create the matrix to be used for the eight-point algorithm 365 | X = torch.stack([ 366 | xx[:, 2] * xx[:, 0], xx[:, 2] * xx[:, 1], xx[:, 2], 367 | xx[:, 3] * xx[:, 0], xx[:, 3] * xx[:, 1], xx[:, 3], 368 | xx[:, 0], xx[:, 1], torch.ones_like(xx[:, 0]) 369 | ], dim=1).permute(0, 2, 1) 370 | wX = torch.reshape(weights, (x_shp[0], x_shp[2], 1)) * X 371 | XwX = torch.matmul(X.permute(0, 2, 1), wX) 372 | 373 | # Recover essential matrix from self-adjoing eigen 374 | v = batch_symeig(XwX) 375 | e_hat = torch.reshape(v[:, :, 0], (x_shp[0], 9)) 376 | 377 | # Make unit norm just in case 378 | e_hat = e_hat / torch.norm(e_hat, dim=1, keepdim=True) 379 | return e_hat 380 | 381 | 382 | def shuffle_chnls(x, groups=4): 383 | """Channel Shuffle""" 384 | 385 | bs, chnls, h, w = x.data.size() 386 | if chnls % groups: 387 | return x 388 | chnls_per_group = chnls // groups 389 | x = x.view(bs, groups, chnls_per_group, h, w) 390 | x = torch.transpose(x, 1, 2).contiguous() 391 | x = x.view(bs, -1, h, w) 392 | 393 | return x 394 | -------------------------------------------------------------------------------- /core/transformations.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # transformations.py 3 | 4 | # Copyright (c) 2006-2015, Christoph Gohlke 5 | # Copyright (c) 2006-2015, The Regents of the University of California 6 | # Produced at the Laboratory for Fluorescence Dynamics 7 | # All rights reserved. 8 | # 9 | # Redistribution and use in source and binary forms, with or without 10 | # modification, are permitted provided that the following conditions are met: 11 | # 12 | # * Redistributions of source code must retain the above copyright 13 | # notice, this list of conditions and the following disclaimer. 14 | # * Redistributions in binary form must reproduce the above copyright 15 | # notice, this list of conditions and the following disclaimer in the 16 | # documentation and/or other materials provided with the distribution. 17 | # * Neither the name of the copyright holders nor the names of any 18 | # contributors may be used to endorse or promote products derived 19 | # from this software without specific prior written permission. 20 | # 21 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 25 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | # POSSIBILITY OF SUCH DAMAGE. 32 | 33 | """Homogeneous Transformation Matrices and Quaternions. 34 | 35 | A library for calculating 4x4 matrices for translating, rotating, reflecting, 36 | scaling, shearing, projecting, orthogonalizing, and superimposing arrays of 37 | 3D homogeneous coordinates as well as for converting between rotation matrices, 38 | Euler angles, and quaternions. Also includes an Arcball control object and 39 | functions to decompose transformation matrices. 40 | 41 | :Author: 42 | `Christoph Gohlke `_ 43 | 44 | :Organization: 45 | Laboratory for Fluorescence Dynamics, University of California, Irvine 46 | 47 | :Version: 2015.07.18 48 | 49 | Requirements 50 | ------------ 51 | * `CPython 2.7 or 3.4 `_ 52 | * `Numpy 1.9 `_ 53 | * `Transformations.c 2015.07.18 `_ 54 | (recommended for speedup of some functions) 55 | 56 | Notes 57 | ----- 58 | The API is not stable yet and is expected to change between revisions. 59 | 60 | This Python code is not optimized for speed. Refer to the transformations.c 61 | module for a faster implementation of some functions. 62 | 63 | Documentation in HTML format can be generated with epydoc. 64 | 65 | Matrices (M) can be inverted using numpy.linalg.inv(M), be concatenated using 66 | numpy.dot(M0, M1), or transform homogeneous coordinate arrays (v) using 67 | numpy.dot(M, v) for shape (4, \*) column vectors, respectively 68 | numpy.dot(v, M.T) for shape (\*, 4) row vectors ("array of points"). 69 | 70 | This module follows the "column vectors on the right" and "row major storage" 71 | (C contiguous) conventions. The translation components are in the right column 72 | of the transformation matrix, i.e. M[:3, 3]. 73 | The transpose of the transformation matrices may have to be used to interface 74 | with other graphics systems, e.g. with OpenGL's glMultMatrixd(). See also [16]. 75 | 76 | Calculations are carried out with numpy.float64 precision. 77 | 78 | Vector, point, quaternion, and matrix function arguments are expected to be 79 | "array like", i.e. tuple, list, or numpy arrays. 80 | 81 | Return types are numpy arrays unless specified otherwise. 82 | 83 | Angles are in radians unless specified otherwise. 84 | 85 | Quaternions w+ix+jy+kz are represented as [w, x, y, z]. 86 | 87 | A triple of Euler angles can be applied/interpreted in 24 ways, which can 88 | be specified using a 4 character string or encoded 4-tuple: 89 | 90 | *Axes 4-string*: e.g. 'sxyz' or 'ryxy' 91 | 92 | - first character : rotations are applied to 's'tatic or 'r'otating frame 93 | - remaining characters : successive rotation axis 'x', 'y', or 'z' 94 | 95 | *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1) 96 | 97 | - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix. 98 | - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed 99 | by 'z', or 'z' is followed by 'x'. Otherwise odd (1). 100 | - repetition : first and last axis are same (1) or different (0). 101 | - frame : rotations are applied to static (0) or rotating (1) frame. 102 | 103 | Other Python packages and modules for 3D transformations and quaternions: 104 | 105 | * `Transforms3d `_ 106 | includes most code of this module. 107 | * `Blender.mathutils `_ 108 | * `numpy-dtypes `_ 109 | 110 | References 111 | ---------- 112 | (1) Matrices and transformations. Ronald Goldman. 113 | In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990. 114 | (2) More matrices and transformations: shear and pseudo-perspective. 115 | Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. 116 | (3) Decomposing a matrix into simple transformations. Spencer Thomas. 117 | In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991. 118 | (4) Recovering the data from the transformation matrix. Ronald Goldman. 119 | In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991. 120 | (5) Euler angle conversion. Ken Shoemake. 121 | In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994. 122 | (6) Arcball rotation control. Ken Shoemake. 123 | In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994. 124 | (7) Representing attitude: Euler angles, unit quaternions, and rotation 125 | vectors. James Diebel. 2006. 126 | (8) A discussion of the solution for the best rotation to relate two sets 127 | of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828. 128 | (9) Closed-form solution of absolute orientation using unit quaternions. 129 | BKP Horn. J Opt Soc Am A. 1987. 4(4):629-642. 130 | (10) Quaternions. Ken Shoemake. 131 | http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf 132 | (11) From quaternion to matrix and back. JMP van Waveren. 2005. 133 | http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm 134 | (12) Uniform random rotations. Ken Shoemake. 135 | In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992. 136 | (13) Quaternion in molecular modeling. CFF Karney. 137 | J Mol Graph Mod, 25(5):595-604 138 | (14) New method for extracting the quaternion from a rotation matrix. 139 | Itzhack Y Bar-Itzhack, J Guid Contr Dynam. 2000. 23(6): 1085-1087. 140 | (15) Multiple View Geometry in Computer Vision. Hartley and Zissermann. 141 | Cambridge University Press; 2nd Ed. 2004. Chapter 4, Algorithm 4.7, p 130. 142 | (16) Column Vectors vs. Row Vectors. 143 | http://steve.hollasch.net/cgindex/math/matrix/column-vec.html 144 | 145 | Examples 146 | -------- 147 | >>> alpha, beta, gamma = 0.123, -1.234, 2.345 148 | >>> origin, xaxis, yaxis, zaxis = [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1] 149 | >>> I = identity_matrix() 150 | >>> Rx = rotation_matrix(alpha, xaxis) 151 | >>> Ry = rotation_matrix(beta, yaxis) 152 | >>> Rz = rotation_matrix(gamma, zaxis) 153 | >>> R = concatenate_matrices(Rx, Ry, Rz) 154 | >>> euler = euler_from_matrix(R, 'rxyz') 155 | >>> numpy.allclose([alpha, beta, gamma], euler) 156 | True 157 | >>> Re = euler_matrix(alpha, beta, gamma, 'rxyz') 158 | >>> is_same_transform(R, Re) 159 | True 160 | >>> al, be, ga = euler_from_matrix(Re, 'rxyz') 161 | >>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz')) 162 | True 163 | >>> qx = quaternion_about_axis(alpha, xaxis) 164 | >>> qy = quaternion_about_axis(beta, yaxis) 165 | >>> qz = quaternion_about_axis(gamma, zaxis) 166 | >>> q = quaternion_multiply(qx, qy) 167 | >>> q = quaternion_multiply(q, qz) 168 | >>> Rq = quaternion_matrix(q) 169 | >>> is_same_transform(R, Rq) 170 | True 171 | >>> S = scale_matrix(1.23, origin) 172 | >>> T = translation_matrix([1, 2, 3]) 173 | >>> Z = shear_matrix(beta, xaxis, origin, zaxis) 174 | >>> R = random_rotation_matrix(numpy.random.rand(3)) 175 | >>> M = concatenate_matrices(T, R, Z, S) 176 | >>> scale, shear, angles, trans, persp = decompose_matrix(M) 177 | >>> numpy.allclose(scale, 1.23) 178 | True 179 | >>> numpy.allclose(trans, [1, 2, 3]) 180 | True 181 | >>> numpy.allclose(shear, [0, math.tan(beta), 0]) 182 | True 183 | >>> is_same_transform(R, euler_matrix(axes='sxyz', *angles)) 184 | True 185 | >>> M1 = compose_matrix(scale, shear, angles, trans, persp) 186 | >>> is_same_transform(M, M1) 187 | True 188 | >>> v0, v1 = random_vector(3), random_vector(3) 189 | >>> M = rotation_matrix(angle_between_vectors(v0, v1), vector_product(v0, v1)) 190 | >>> v2 = numpy.dot(v0, M[:3,:3].T) 191 | >>> numpy.allclose(unit_vector(v1), unit_vector(v2)) 192 | True 193 | 194 | """ 195 | 196 | from __future__ import division, print_function 197 | 198 | import math 199 | 200 | import numpy 201 | 202 | __version__ = '2015.07.18' 203 | __docformat__ = 'restructuredtext en' 204 | __all__ = () 205 | 206 | 207 | def identity_matrix(): 208 | """Return 4x4 identity/unit matrix. 209 | 210 | >>> I = identity_matrix() 211 | >>> numpy.allclose(I, numpy.dot(I, I)) 212 | True 213 | >>> numpy.sum(I), numpy.trace(I) 214 | (4.0, 4.0) 215 | >>> numpy.allclose(I, numpy.identity(4)) 216 | True 217 | 218 | """ 219 | return numpy.identity(4) 220 | 221 | 222 | def translation_matrix(direction): 223 | """Return matrix to translate by direction vector. 224 | 225 | >>> v = numpy.random.random(3) - 0.5 226 | >>> numpy.allclose(v, translation_matrix(v)[:3, 3]) 227 | True 228 | 229 | """ 230 | M = numpy.identity(4) 231 | M[:3, 3] = direction[:3] 232 | return M 233 | 234 | 235 | def translation_from_matrix(matrix): 236 | """Return translation vector from translation matrix. 237 | 238 | >>> v0 = numpy.random.random(3) - 0.5 239 | >>> v1 = translation_from_matrix(translation_matrix(v0)) 240 | >>> numpy.allclose(v0, v1) 241 | True 242 | 243 | """ 244 | return numpy.array(matrix, copy=False)[:3, 3].copy() 245 | 246 | 247 | def reflection_matrix(point, normal): 248 | """Return matrix to mirror at plane defined by point and normal vector. 249 | 250 | >>> v0 = numpy.random.random(4) - 0.5 251 | >>> v0[3] = 1. 252 | >>> v1 = numpy.random.random(3) - 0.5 253 | >>> R = reflection_matrix(v0, v1) 254 | >>> numpy.allclose(2, numpy.trace(R)) 255 | True 256 | >>> numpy.allclose(v0, numpy.dot(R, v0)) 257 | True 258 | >>> v2 = v0.copy() 259 | >>> v2[:3] += v1 260 | >>> v3 = v0.copy() 261 | >>> v2[:3] -= v1 262 | >>> numpy.allclose(v2, numpy.dot(R, v3)) 263 | True 264 | 265 | """ 266 | normal = unit_vector(normal[:3]) 267 | M = numpy.identity(4) 268 | M[:3, :3] -= 2.0 * numpy.outer(normal, normal) 269 | M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal 270 | return M 271 | 272 | 273 | def reflection_from_matrix(matrix): 274 | """Return mirror plane point and normal vector from reflection matrix. 275 | 276 | >>> v0 = numpy.random.random(3) - 0.5 277 | >>> v1 = numpy.random.random(3) - 0.5 278 | >>> M0 = reflection_matrix(v0, v1) 279 | >>> point, normal = reflection_from_matrix(M0) 280 | >>> M1 = reflection_matrix(point, normal) 281 | >>> is_same_transform(M0, M1) 282 | True 283 | 284 | """ 285 | M = numpy.array(matrix, dtype=numpy.float64, copy=False) 286 | # normal: unit eigenvector corresponding to eigenvalue -1 287 | w, V = numpy.linalg.eig(M[:3, :3]) 288 | i = numpy.where(abs(numpy.real(w) + 1.0) < 1e-8)[0] 289 | if not len(i): 290 | raise ValueError("no unit eigenvector corresponding to eigenvalue -1") 291 | normal = numpy.real(V[:, i[0]]).squeeze() 292 | # point: any unit eigenvector corresponding to eigenvalue 1 293 | w, V = numpy.linalg.eig(M) 294 | i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] 295 | if not len(i): 296 | raise ValueError("no unit eigenvector corresponding to eigenvalue 1") 297 | point = numpy.real(V[:, i[-1]]).squeeze() 298 | point /= point[3] 299 | return point, normal 300 | 301 | 302 | def rotation_matrix(angle, direction, point=None): 303 | """Return matrix to rotate about axis defined by point and direction. 304 | 305 | >>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0]) 306 | >>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1]) 307 | True 308 | >>> angle = (random.random() - 0.5) * (2*math.pi) 309 | >>> direc = numpy.random.random(3) - 0.5 310 | >>> point = numpy.random.random(3) - 0.5 311 | >>> R0 = rotation_matrix(angle, direc, point) 312 | >>> R1 = rotation_matrix(angle-2*math.pi, direc, point) 313 | >>> is_same_transform(R0, R1) 314 | True 315 | >>> R0 = rotation_matrix(angle, direc, point) 316 | >>> R1 = rotation_matrix(-angle, -direc, point) 317 | >>> is_same_transform(R0, R1) 318 | True 319 | >>> I = numpy.identity(4, numpy.float64) 320 | >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc)) 321 | True 322 | >>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2, 323 | ... direc, point))) 324 | True 325 | 326 | """ 327 | sina = math.sin(angle) 328 | cosa = math.cos(angle) 329 | direction = unit_vector(direction[:3]) 330 | # rotation matrix around unit vector 331 | R = numpy.diag([cosa, cosa, cosa]) 332 | R += numpy.outer(direction, direction) * (1.0 - cosa) 333 | direction *= sina 334 | R += numpy.array([[ 0.0, -direction[2], direction[1]], 335 | [ direction[2], 0.0, -direction[0]], 336 | [-direction[1], direction[0], 0.0]]) 337 | M = numpy.identity(4) 338 | M[:3, :3] = R 339 | if point is not None: 340 | # rotation not around origin 341 | point = numpy.array(point[:3], dtype=numpy.float64, copy=False) 342 | M[:3, 3] = point - numpy.dot(R, point) 343 | return M 344 | 345 | 346 | def rotation_from_matrix(matrix): 347 | """Return rotation angle and axis from rotation matrix. 348 | 349 | >>> angle = (random.random() - 0.5) * (2*math.pi) 350 | >>> direc = numpy.random.random(3) - 0.5 351 | >>> point = numpy.random.random(3) - 0.5 352 | >>> R0 = rotation_matrix(angle, direc, point) 353 | >>> angle, direc, point = rotation_from_matrix(R0) 354 | >>> R1 = rotation_matrix(angle, direc, point) 355 | >>> is_same_transform(R0, R1) 356 | True 357 | 358 | """ 359 | R = numpy.array(matrix, dtype=numpy.float64, copy=False) 360 | R33 = R[:3, :3] 361 | # direction: unit eigenvector of R33 corresponding to eigenvalue of 1 362 | w, W = numpy.linalg.eig(R33.T) 363 | i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] 364 | if not len(i): 365 | raise ValueError("no unit eigenvector corresponding to eigenvalue 1") 366 | direction = numpy.real(W[:, i[-1]]).squeeze() 367 | # point: unit eigenvector of R33 corresponding to eigenvalue of 1 368 | w, Q = numpy.linalg.eig(R) 369 | i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] 370 | if not len(i): 371 | raise ValueError("no unit eigenvector corresponding to eigenvalue 1") 372 | point = numpy.real(Q[:, i[-1]]).squeeze() 373 | point /= point[3] 374 | # rotation angle depending on direction 375 | cosa = (numpy.trace(R33) - 1.0) / 2.0 376 | if abs(direction[2]) > 1e-8: 377 | sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2] 378 | elif abs(direction[1]) > 1e-8: 379 | sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1] 380 | else: 381 | sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0] 382 | angle = math.atan2(sina, cosa) 383 | return angle, direction, point 384 | 385 | 386 | def scale_matrix(factor, origin=None, direction=None): 387 | """Return matrix to scale by factor around origin in direction. 388 | 389 | Use factor -1 for point symmetry. 390 | 391 | >>> v = (numpy.random.rand(4, 5) - 0.5) * 20 392 | >>> v[3] = 1 393 | >>> S = scale_matrix(-1.234) 394 | >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3]) 395 | True 396 | >>> factor = random.random() * 10 - 5 397 | >>> origin = numpy.random.random(3) - 0.5 398 | >>> direct = numpy.random.random(3) - 0.5 399 | >>> S = scale_matrix(factor, origin) 400 | >>> S = scale_matrix(factor, origin, direct) 401 | 402 | """ 403 | if direction is None: 404 | # uniform scaling 405 | M = numpy.diag([factor, factor, factor, 1.0]) 406 | if origin is not None: 407 | M[:3, 3] = origin[:3] 408 | M[:3, 3] *= 1.0 - factor 409 | else: 410 | # nonuniform scaling 411 | direction = unit_vector(direction[:3]) 412 | factor = 1.0 - factor 413 | M = numpy.identity(4) 414 | M[:3, :3] -= factor * numpy.outer(direction, direction) 415 | if origin is not None: 416 | M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction 417 | return M 418 | 419 | 420 | def scale_from_matrix(matrix): 421 | """Return scaling factor, origin and direction from scaling matrix. 422 | 423 | >>> factor = random.random() * 10 - 5 424 | >>> origin = numpy.random.random(3) - 0.5 425 | >>> direct = numpy.random.random(3) - 0.5 426 | >>> S0 = scale_matrix(factor, origin) 427 | >>> factor, origin, direction = scale_from_matrix(S0) 428 | >>> S1 = scale_matrix(factor, origin, direction) 429 | >>> is_same_transform(S0, S1) 430 | True 431 | >>> S0 = scale_matrix(factor, origin, direct) 432 | >>> factor, origin, direction = scale_from_matrix(S0) 433 | >>> S1 = scale_matrix(factor, origin, direction) 434 | >>> is_same_transform(S0, S1) 435 | True 436 | 437 | """ 438 | M = numpy.array(matrix, dtype=numpy.float64, copy=False) 439 | M33 = M[:3, :3] 440 | factor = numpy.trace(M33) - 2.0 441 | try: 442 | # direction: unit eigenvector corresponding to eigenvalue factor 443 | w, V = numpy.linalg.eig(M33) 444 | i = numpy.where(abs(numpy.real(w) - factor) < 1e-8)[0][0] 445 | direction = numpy.real(V[:, i]).squeeze() 446 | direction /= vector_norm(direction) 447 | except IndexError: 448 | # uniform scaling 449 | factor = (factor + 2.0) / 3.0 450 | direction = None 451 | # origin: any eigenvector corresponding to eigenvalue 1 452 | w, V = numpy.linalg.eig(M) 453 | i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] 454 | if not len(i): 455 | raise ValueError("no eigenvector corresponding to eigenvalue 1") 456 | origin = numpy.real(V[:, i[-1]]).squeeze() 457 | origin /= origin[3] 458 | return factor, origin, direction 459 | 460 | 461 | def projection_matrix(point, normal, direction=None, 462 | perspective=None, pseudo=False): 463 | """Return matrix to project onto plane defined by point and normal. 464 | 465 | Using either perspective point, projection direction, or none of both. 466 | 467 | If pseudo is True, perspective projections will preserve relative depth 468 | such that Perspective = dot(Orthogonal, PseudoPerspective). 469 | 470 | >>> P = projection_matrix([0, 0, 0], [1, 0, 0]) 471 | >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:]) 472 | True 473 | >>> point = numpy.random.random(3) - 0.5 474 | >>> normal = numpy.random.random(3) - 0.5 475 | >>> direct = numpy.random.random(3) - 0.5 476 | >>> persp = numpy.random.random(3) - 0.5 477 | >>> P0 = projection_matrix(point, normal) 478 | >>> P1 = projection_matrix(point, normal, direction=direct) 479 | >>> P2 = projection_matrix(point, normal, perspective=persp) 480 | >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True) 481 | >>> is_same_transform(P2, numpy.dot(P0, P3)) 482 | True 483 | >>> P = projection_matrix([3, 0, 0], [1, 1, 0], [1, 0, 0]) 484 | >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20 485 | >>> v0[3] = 1 486 | >>> v1 = numpy.dot(P, v0) 487 | >>> numpy.allclose(v1[1], v0[1]) 488 | True 489 | >>> numpy.allclose(v1[0], 3-v1[1]) 490 | True 491 | 492 | """ 493 | M = numpy.identity(4) 494 | point = numpy.array(point[:3], dtype=numpy.float64, copy=False) 495 | normal = unit_vector(normal[:3]) 496 | if perspective is not None: 497 | # perspective projection 498 | perspective = numpy.array(perspective[:3], dtype=numpy.float64, 499 | copy=False) 500 | M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal) 501 | M[:3, :3] -= numpy.outer(perspective, normal) 502 | if pseudo: 503 | # preserve relative depth 504 | M[:3, :3] -= numpy.outer(normal, normal) 505 | M[:3, 3] = numpy.dot(point, normal) * (perspective+normal) 506 | else: 507 | M[:3, 3] = numpy.dot(point, normal) * perspective 508 | M[3, :3] = -normal 509 | M[3, 3] = numpy.dot(perspective, normal) 510 | elif direction is not None: 511 | # parallel projection 512 | direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False) 513 | scale = numpy.dot(direction, normal) 514 | M[:3, :3] -= numpy.outer(direction, normal) / scale 515 | M[:3, 3] = direction * (numpy.dot(point, normal) / scale) 516 | else: 517 | # orthogonal projection 518 | M[:3, :3] -= numpy.outer(normal, normal) 519 | M[:3, 3] = numpy.dot(point, normal) * normal 520 | return M 521 | 522 | 523 | def projection_from_matrix(matrix, pseudo=False): 524 | """Return projection plane and perspective point from projection matrix. 525 | 526 | Return values are same as arguments for projection_matrix function: 527 | point, normal, direction, perspective, and pseudo. 528 | 529 | >>> point = numpy.random.random(3) - 0.5 530 | >>> normal = numpy.random.random(3) - 0.5 531 | >>> direct = numpy.random.random(3) - 0.5 532 | >>> persp = numpy.random.random(3) - 0.5 533 | >>> P0 = projection_matrix(point, normal) 534 | >>> result = projection_from_matrix(P0) 535 | >>> P1 = projection_matrix(*result) 536 | >>> is_same_transform(P0, P1) 537 | True 538 | >>> P0 = projection_matrix(point, normal, direct) 539 | >>> result = projection_from_matrix(P0) 540 | >>> P1 = projection_matrix(*result) 541 | >>> is_same_transform(P0, P1) 542 | True 543 | >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False) 544 | >>> result = projection_from_matrix(P0, pseudo=False) 545 | >>> P1 = projection_matrix(*result) 546 | >>> is_same_transform(P0, P1) 547 | True 548 | >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True) 549 | >>> result = projection_from_matrix(P0, pseudo=True) 550 | >>> P1 = projection_matrix(*result) 551 | >>> is_same_transform(P0, P1) 552 | True 553 | 554 | """ 555 | M = numpy.array(matrix, dtype=numpy.float64, copy=False) 556 | M33 = M[:3, :3] 557 | w, V = numpy.linalg.eig(M) 558 | i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] 559 | if not pseudo and len(i): 560 | # point: any eigenvector corresponding to eigenvalue 1 561 | point = numpy.real(V[:, i[-1]]).squeeze() 562 | point /= point[3] 563 | # direction: unit eigenvector corresponding to eigenvalue 0 564 | w, V = numpy.linalg.eig(M33) 565 | i = numpy.where(abs(numpy.real(w)) < 1e-8)[0] 566 | if not len(i): 567 | raise ValueError("no eigenvector corresponding to eigenvalue 0") 568 | direction = numpy.real(V[:, i[0]]).squeeze() 569 | direction /= vector_norm(direction) 570 | # normal: unit eigenvector of M33.T corresponding to eigenvalue 0 571 | w, V = numpy.linalg.eig(M33.T) 572 | i = numpy.where(abs(numpy.real(w)) < 1e-8)[0] 573 | if len(i): 574 | # parallel projection 575 | normal = numpy.real(V[:, i[0]]).squeeze() 576 | normal /= vector_norm(normal) 577 | return point, normal, direction, None, False 578 | else: 579 | # orthogonal projection, where normal equals direction vector 580 | return point, direction, None, None, False 581 | else: 582 | # perspective projection 583 | i = numpy.where(abs(numpy.real(w)) > 1e-8)[0] 584 | if not len(i): 585 | raise ValueError( 586 | "no eigenvector not corresponding to eigenvalue 0") 587 | point = numpy.real(V[:, i[-1]]).squeeze() 588 | point /= point[3] 589 | normal = - M[3, :3] 590 | perspective = M[:3, 3] / numpy.dot(point[:3], normal) 591 | if pseudo: 592 | perspective -= normal 593 | return point, normal, None, perspective, pseudo 594 | 595 | 596 | def clip_matrix(left, right, bottom, top, near, far, perspective=False): 597 | """Return matrix to obtain normalized device coordinates from frustum. 598 | 599 | The frustum bounds are axis-aligned along x (left, right), 600 | y (bottom, top) and z (near, far). 601 | 602 | Normalized device coordinates are in range [-1, 1] if coordinates are 603 | inside the frustum. 604 | 605 | If perspective is True the frustum is a truncated pyramid with the 606 | perspective point at origin and direction along z axis, otherwise an 607 | orthographic canonical view volume (a box). 608 | 609 | Homogeneous coordinates transformed by the perspective clip matrix 610 | need to be dehomogenized (divided by w coordinate). 611 | 612 | >>> frustum = numpy.random.rand(6) 613 | >>> frustum[1] += frustum[0] 614 | >>> frustum[3] += frustum[2] 615 | >>> frustum[5] += frustum[4] 616 | >>> M = clip_matrix(perspective=False, *frustum) 617 | >>> numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1]) 618 | array([-1., -1., -1., 1.]) 619 | >>> numpy.dot(M, [frustum[1], frustum[3], frustum[5], 1]) 620 | array([ 1., 1., 1., 1.]) 621 | >>> M = clip_matrix(perspective=True, *frustum) 622 | >>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1]) 623 | >>> v / v[3] 624 | array([-1., -1., -1., 1.]) 625 | >>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[4], 1]) 626 | >>> v / v[3] 627 | array([ 1., 1., -1., 1.]) 628 | 629 | """ 630 | if left >= right or bottom >= top or near >= far: 631 | raise ValueError("invalid frustum") 632 | if perspective: 633 | if near <= _EPS: 634 | raise ValueError("invalid frustum: near <= 0") 635 | t = 2.0 * near 636 | M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0], 637 | [0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0], 638 | [0.0, 0.0, (far+near)/(near-far), t*far/(far-near)], 639 | [0.0, 0.0, -1.0, 0.0]] 640 | else: 641 | M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)], 642 | [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)], 643 | [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)], 644 | [0.0, 0.0, 0.0, 1.0]] 645 | return numpy.array(M) 646 | 647 | 648 | def shear_matrix(angle, direction, point, normal): 649 | """Return matrix to shear by angle along direction vector on shear plane. 650 | 651 | The shear plane is defined by a point and normal vector. The direction 652 | vector must be orthogonal to the plane's normal vector. 653 | 654 | A point P is transformed by the shear matrix into P" such that 655 | the vector P-P" is parallel to the direction vector and its extent is 656 | given by the angle of P-P'-P", where P' is the orthogonal projection 657 | of P onto the shear plane. 658 | 659 | >>> angle = (random.random() - 0.5) * 4*math.pi 660 | >>> direct = numpy.random.random(3) - 0.5 661 | >>> point = numpy.random.random(3) - 0.5 662 | >>> normal = numpy.cross(direct, numpy.random.random(3)) 663 | >>> S = shear_matrix(angle, direct, point, normal) 664 | >>> numpy.allclose(1, numpy.linalg.det(S)) 665 | True 666 | 667 | """ 668 | normal = unit_vector(normal[:3]) 669 | direction = unit_vector(direction[:3]) 670 | if abs(numpy.dot(normal, direction)) > 1e-6: 671 | raise ValueError("direction and normal vectors are not orthogonal") 672 | angle = math.tan(angle) 673 | M = numpy.identity(4) 674 | M[:3, :3] += angle * numpy.outer(direction, normal) 675 | M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction 676 | return M 677 | 678 | 679 | def shear_from_matrix(matrix): 680 | """Return shear angle, direction and plane from shear matrix. 681 | 682 | >>> angle = (random.random() - 0.5) * 4*math.pi 683 | >>> direct = numpy.random.random(3) - 0.5 684 | >>> point = numpy.random.random(3) - 0.5 685 | >>> normal = numpy.cross(direct, numpy.random.random(3)) 686 | >>> S0 = shear_matrix(angle, direct, point, normal) 687 | >>> angle, direct, point, normal = shear_from_matrix(S0) 688 | >>> S1 = shear_matrix(angle, direct, point, normal) 689 | >>> is_same_transform(S0, S1) 690 | True 691 | 692 | """ 693 | M = numpy.array(matrix, dtype=numpy.float64, copy=False) 694 | M33 = M[:3, :3] 695 | # normal: cross independent eigenvectors corresponding to the eigenvalue 1 696 | w, V = numpy.linalg.eig(M33) 697 | i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-4)[0] 698 | if len(i) < 2: 699 | raise ValueError("no two linear independent eigenvectors found %s" % w) 700 | V = numpy.real(V[:, i]).squeeze().T 701 | lenorm = -1.0 702 | for i0, i1 in ((0, 1), (0, 2), (1, 2)): 703 | n = numpy.cross(V[i0], V[i1]) 704 | w = vector_norm(n) 705 | if w > lenorm: 706 | lenorm = w 707 | normal = n 708 | normal /= lenorm 709 | # direction and angle 710 | direction = numpy.dot(M33 - numpy.identity(3), normal) 711 | angle = vector_norm(direction) 712 | direction /= angle 713 | angle = math.atan(angle) 714 | # point: eigenvector corresponding to eigenvalue 1 715 | w, V = numpy.linalg.eig(M) 716 | i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0] 717 | if not len(i): 718 | raise ValueError("no eigenvector corresponding to eigenvalue 1") 719 | point = numpy.real(V[:, i[-1]]).squeeze() 720 | point /= point[3] 721 | return angle, direction, point, normal 722 | 723 | 724 | def decompose_matrix(matrix): 725 | """Return sequence of transformations from transformation matrix. 726 | 727 | matrix : array_like 728 | Non-degenerative homogeneous transformation matrix 729 | 730 | Return tuple of: 731 | scale : vector of 3 scaling factors 732 | shear : list of shear factors for x-y, x-z, y-z axes 733 | angles : list of Euler angles about static x, y, z axes 734 | translate : translation vector along x, y, z axes 735 | perspective : perspective partition of matrix 736 | 737 | Raise ValueError if matrix is of wrong type or degenerative. 738 | 739 | >>> T0 = translation_matrix([1, 2, 3]) 740 | >>> scale, shear, angles, trans, persp = decompose_matrix(T0) 741 | >>> T1 = translation_matrix(trans) 742 | >>> numpy.allclose(T0, T1) 743 | True 744 | >>> S = scale_matrix(0.123) 745 | >>> scale, shear, angles, trans, persp = decompose_matrix(S) 746 | >>> scale[0] 747 | 0.123 748 | >>> R0 = euler_matrix(1, 2, 3) 749 | >>> scale, shear, angles, trans, persp = decompose_matrix(R0) 750 | >>> R1 = euler_matrix(*angles) 751 | >>> numpy.allclose(R0, R1) 752 | True 753 | 754 | """ 755 | M = numpy.array(matrix, dtype=numpy.float64, copy=True).T 756 | if abs(M[3, 3]) < _EPS: 757 | raise ValueError("M[3, 3] is zero") 758 | M /= M[3, 3] 759 | P = M.copy() 760 | P[:, 3] = 0.0, 0.0, 0.0, 1.0 761 | if not numpy.linalg.det(P): 762 | raise ValueError("matrix is singular") 763 | 764 | scale = numpy.zeros((3, )) 765 | shear = [0.0, 0.0, 0.0] 766 | angles = [0.0, 0.0, 0.0] 767 | 768 | if any(abs(M[:3, 3]) > _EPS): 769 | perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T)) 770 | M[:, 3] = 0.0, 0.0, 0.0, 1.0 771 | else: 772 | perspective = numpy.array([0.0, 0.0, 0.0, 1.0]) 773 | 774 | translate = M[3, :3].copy() 775 | M[3, :3] = 0.0 776 | 777 | row = M[:3, :3].copy() 778 | scale[0] = vector_norm(row[0]) 779 | row[0] /= scale[0] 780 | shear[0] = numpy.dot(row[0], row[1]) 781 | row[1] -= row[0] * shear[0] 782 | scale[1] = vector_norm(row[1]) 783 | row[1] /= scale[1] 784 | shear[0] /= scale[1] 785 | shear[1] = numpy.dot(row[0], row[2]) 786 | row[2] -= row[0] * shear[1] 787 | shear[2] = numpy.dot(row[1], row[2]) 788 | row[2] -= row[1] * shear[2] 789 | scale[2] = vector_norm(row[2]) 790 | row[2] /= scale[2] 791 | shear[1:] /= scale[2] 792 | 793 | if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0: 794 | numpy.negative(scale, scale) 795 | numpy.negative(row, row) 796 | 797 | angles[1] = math.asin(-row[0, 2]) 798 | if math.cos(angles[1]): 799 | angles[0] = math.atan2(row[1, 2], row[2, 2]) 800 | angles[2] = math.atan2(row[0, 1], row[0, 0]) 801 | else: 802 | #angles[0] = math.atan2(row[1, 0], row[1, 1]) 803 | angles[0] = math.atan2(-row[2, 1], row[1, 1]) 804 | angles[2] = 0.0 805 | 806 | return scale, shear, angles, translate, perspective 807 | 808 | 809 | def compose_matrix(scale=None, shear=None, angles=None, translate=None, 810 | perspective=None): 811 | """Return transformation matrix from sequence of transformations. 812 | 813 | This is the inverse of the decompose_matrix function. 814 | 815 | Sequence of transformations: 816 | scale : vector of 3 scaling factors 817 | shear : list of shear factors for x-y, x-z, y-z axes 818 | angles : list of Euler angles about static x, y, z axes 819 | translate : translation vector along x, y, z axes 820 | perspective : perspective partition of matrix 821 | 822 | >>> scale = numpy.random.random(3) - 0.5 823 | >>> shear = numpy.random.random(3) - 0.5 824 | >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi) 825 | >>> trans = numpy.random.random(3) - 0.5 826 | >>> persp = numpy.random.random(4) - 0.5 827 | >>> M0 = compose_matrix(scale, shear, angles, trans, persp) 828 | >>> result = decompose_matrix(M0) 829 | >>> M1 = compose_matrix(*result) 830 | >>> is_same_transform(M0, M1) 831 | True 832 | 833 | """ 834 | M = numpy.identity(4) 835 | if perspective is not None: 836 | P = numpy.identity(4) 837 | P[3, :] = perspective[:4] 838 | M = numpy.dot(M, P) 839 | if translate is not None: 840 | T = numpy.identity(4) 841 | T[:3, 3] = translate[:3] 842 | M = numpy.dot(M, T) 843 | if angles is not None: 844 | R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz') 845 | M = numpy.dot(M, R) 846 | if shear is not None: 847 | Z = numpy.identity(4) 848 | Z[1, 2] = shear[2] 849 | Z[0, 2] = shear[1] 850 | Z[0, 1] = shear[0] 851 | M = numpy.dot(M, Z) 852 | if scale is not None: 853 | S = numpy.identity(4) 854 | S[0, 0] = scale[0] 855 | S[1, 1] = scale[1] 856 | S[2, 2] = scale[2] 857 | M = numpy.dot(M, S) 858 | M /= M[3, 3] 859 | return M 860 | 861 | 862 | def orthogonalization_matrix(lengths, angles): 863 | """Return orthogonalization matrix for crystallographic cell coordinates. 864 | 865 | Angles are expected in degrees. 866 | 867 | The de-orthogonalization matrix is the inverse. 868 | 869 | >>> O = orthogonalization_matrix([10, 10, 10], [90, 90, 90]) 870 | >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10) 871 | True 872 | >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7]) 873 | >>> numpy.allclose(numpy.sum(O), 43.063229) 874 | True 875 | 876 | """ 877 | a, b, c = lengths 878 | angles = numpy.radians(angles) 879 | sina, sinb, _ = numpy.sin(angles) 880 | cosa, cosb, cosg = numpy.cos(angles) 881 | co = (cosa * cosb - cosg) / (sina * sinb) 882 | return numpy.array([ 883 | [ a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0], 884 | [-a*sinb*co, b*sina, 0.0, 0.0], 885 | [ a*cosb, b*cosa, c, 0.0], 886 | [ 0.0, 0.0, 0.0, 1.0]]) 887 | 888 | 889 | def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True): 890 | """Return affine transform matrix to register two point sets. 891 | 892 | v0 and v1 are shape (ndims, \*) arrays of at least ndims non-homogeneous 893 | coordinates, where ndims is the dimensionality of the coordinate space. 894 | 895 | If shear is False, a similarity transformation matrix is returned. 896 | If also scale is False, a rigid/Euclidean transformation matrix 897 | is returned. 898 | 899 | By default the algorithm by Hartley and Zissermann [15] is used. 900 | If usesvd is True, similarity and Euclidean transformation matrices 901 | are calculated by minimizing the weighted sum of squared deviations 902 | (RMSD) according to the algorithm by Kabsch [8]. 903 | Otherwise, and if ndims is 3, the quaternion based algorithm by Horn [9] 904 | is used, which is slower when using this Python implementation. 905 | 906 | The returned matrix performs rotation, translation and uniform scaling 907 | (if specified). 908 | 909 | >>> v0 = [[0, 1031, 1031, 0], [0, 0, 1600, 1600]] 910 | >>> v1 = [[675, 826, 826, 677], [55, 52, 281, 277]] 911 | >>> affine_matrix_from_points(v0, v1) 912 | array([[ 0.14549, 0.00062, 675.50008], 913 | [ 0.00048, 0.14094, 53.24971], 914 | [ 0. , 0. , 1. ]]) 915 | >>> T = translation_matrix(numpy.random.random(3)-0.5) 916 | >>> R = random_rotation_matrix(numpy.random.random(3)) 917 | >>> S = scale_matrix(random.random()) 918 | >>> M = concatenate_matrices(T, R, S) 919 | >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20 920 | >>> v0[3] = 1 921 | >>> v1 = numpy.dot(M, v0) 922 | >>> v0[:3] += numpy.random.normal(0, 1e-8, 300).reshape(3, -1) 923 | >>> M = affine_matrix_from_points(v0[:3], v1[:3]) 924 | >>> numpy.allclose(v1, numpy.dot(M, v0)) 925 | True 926 | 927 | More examples in superimposition_matrix() 928 | 929 | """ 930 | v0 = numpy.array(v0, dtype=numpy.float64, copy=True) 931 | v1 = numpy.array(v1, dtype=numpy.float64, copy=True) 932 | 933 | ndims = v0.shape[0] 934 | if ndims < 2 or v0.shape[1] < ndims or v0.shape != v1.shape: 935 | raise ValueError("input arrays are of wrong shape or type") 936 | 937 | # move centroids to origin 938 | t0 = -numpy.mean(v0, axis=1) 939 | M0 = numpy.identity(ndims+1) 940 | M0[:ndims, ndims] = t0 941 | v0 += t0.reshape(ndims, 1) 942 | t1 = -numpy.mean(v1, axis=1) 943 | M1 = numpy.identity(ndims+1) 944 | M1[:ndims, ndims] = t1 945 | v1 += t1.reshape(ndims, 1) 946 | 947 | if shear: 948 | # Affine transformation 949 | A = numpy.concatenate((v0, v1), axis=0) 950 | u, s, vh = numpy.linalg.svd(A.T) 951 | vh = vh[:ndims].T 952 | B = vh[:ndims] 953 | C = vh[ndims:2*ndims] 954 | t = numpy.dot(C, numpy.linalg.pinv(B)) 955 | t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1) 956 | M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,))) 957 | elif usesvd or ndims != 3: 958 | # Rigid transformation via SVD of covariance matrix 959 | u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T)) 960 | # rotation matrix from SVD orthonormal bases 961 | R = numpy.dot(u, vh) 962 | if numpy.linalg.det(R) < 0.0: 963 | # R does not constitute right handed system 964 | R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0) 965 | s[-1] *= -1.0 966 | # homogeneous transformation matrix 967 | M = numpy.identity(ndims+1) 968 | M[:ndims, :ndims] = R 969 | else: 970 | # Rigid transformation matrix via quaternion 971 | # compute symmetric matrix N 972 | xx, yy, zz = numpy.sum(v0 * v1, axis=1) 973 | xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1) 974 | xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1) 975 | N = [[xx+yy+zz, 0.0, 0.0, 0.0], 976 | [yz-zy, xx-yy-zz, 0.0, 0.0], 977 | [zx-xz, xy+yx, yy-xx-zz, 0.0], 978 | [xy-yx, zx+xz, yz+zy, zz-xx-yy]] 979 | # quaternion: eigenvector corresponding to most positive eigenvalue 980 | w, V = numpy.linalg.eigh(N) 981 | q = V[:, numpy.argmax(w)] 982 | q /= vector_norm(q) # unit quaternion 983 | # homogeneous transformation matrix 984 | M = quaternion_matrix(q) 985 | 986 | if scale and not shear: 987 | # Affine transformation; scale is ratio of RMS deviations from centroid 988 | v0 *= v0 989 | v1 *= v1 990 | M[:ndims, :ndims] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0)) 991 | 992 | # move centroids back 993 | M = numpy.dot(numpy.linalg.inv(M1), numpy.dot(M, M0)) 994 | M /= M[ndims, ndims] 995 | return M 996 | 997 | 998 | def superimposition_matrix(v0, v1, scale=False, usesvd=True): 999 | """Return matrix to transform given 3D point set into second point set. 1000 | 1001 | v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 points. 1002 | 1003 | The parameters scale and usesvd are explained in the more general 1004 | affine_matrix_from_points function. 1005 | 1006 | The returned matrix is a similarity or Euclidean transformation matrix. 1007 | This function has a fast C implementation in transformations.c. 1008 | 1009 | >>> v0 = numpy.random.rand(3, 10) 1010 | >>> M = superimposition_matrix(v0, v0) 1011 | >>> numpy.allclose(M, numpy.identity(4)) 1012 | True 1013 | >>> R = random_rotation_matrix(numpy.random.random(3)) 1014 | >>> v0 = [[1,0,0], [0,1,0], [0,0,1], [1,1,1]] 1015 | >>> v1 = numpy.dot(R, v0) 1016 | >>> M = superimposition_matrix(v0, v1) 1017 | >>> numpy.allclose(v1, numpy.dot(M, v0)) 1018 | True 1019 | >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20 1020 | >>> v0[3] = 1 1021 | >>> v1 = numpy.dot(R, v0) 1022 | >>> M = superimposition_matrix(v0, v1) 1023 | >>> numpy.allclose(v1, numpy.dot(M, v0)) 1024 | True 1025 | >>> S = scale_matrix(random.random()) 1026 | >>> T = translation_matrix(numpy.random.random(3)-0.5) 1027 | >>> M = concatenate_matrices(T, R, S) 1028 | >>> v1 = numpy.dot(M, v0) 1029 | >>> v0[:3] += numpy.random.normal(0, 1e-9, 300).reshape(3, -1) 1030 | >>> M = superimposition_matrix(v0, v1, scale=True) 1031 | >>> numpy.allclose(v1, numpy.dot(M, v0)) 1032 | True 1033 | >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False) 1034 | >>> numpy.allclose(v1, numpy.dot(M, v0)) 1035 | True 1036 | >>> v = numpy.empty((4, 100, 3)) 1037 | >>> v[:, :, 0] = v0 1038 | >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False) 1039 | >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0])) 1040 | True 1041 | 1042 | """ 1043 | v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3] 1044 | v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3] 1045 | return affine_matrix_from_points(v0, v1, shear=False, 1046 | scale=scale, usesvd=usesvd) 1047 | 1048 | 1049 | def euler_matrix(ai, aj, ak, axes='sxyz'): 1050 | """Return homogeneous rotation matrix from Euler angles and axis sequence. 1051 | 1052 | ai, aj, ak : Euler's roll, pitch and yaw angles 1053 | axes : One of 24 axis sequences as string or encoded tuple 1054 | 1055 | >>> R = euler_matrix(1, 2, 3, 'syxz') 1056 | >>> numpy.allclose(numpy.sum(R[0]), -1.34786452) 1057 | True 1058 | >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1)) 1059 | >>> numpy.allclose(numpy.sum(R[0]), -0.383436184) 1060 | True 1061 | >>> ai, aj, ak = (4*math.pi) * (numpy.random.random(3) - 0.5) 1062 | >>> for axes in _AXES2TUPLE.keys(): 1063 | ... R = euler_matrix(ai, aj, ak, axes) 1064 | >>> for axes in _TUPLE2AXES.keys(): 1065 | ... R = euler_matrix(ai, aj, ak, axes) 1066 | 1067 | """ 1068 | try: 1069 | firstaxis, parity, repetition, frame = _AXES2TUPLE[axes] 1070 | except (AttributeError, KeyError): 1071 | _TUPLE2AXES[axes] # validation 1072 | firstaxis, parity, repetition, frame = axes 1073 | 1074 | i = firstaxis 1075 | j = _NEXT_AXIS[i+parity] 1076 | k = _NEXT_AXIS[i-parity+1] 1077 | 1078 | if frame: 1079 | ai, ak = ak, ai 1080 | if parity: 1081 | ai, aj, ak = -ai, -aj, -ak 1082 | 1083 | si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak) 1084 | ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak) 1085 | cc, cs = ci*ck, ci*sk 1086 | sc, ss = si*ck, si*sk 1087 | 1088 | M = numpy.identity(4) 1089 | if repetition: 1090 | M[i, i] = cj 1091 | M[i, j] = sj*si 1092 | M[i, k] = sj*ci 1093 | M[j, i] = sj*sk 1094 | M[j, j] = -cj*ss+cc 1095 | M[j, k] = -cj*cs-sc 1096 | M[k, i] = -sj*ck 1097 | M[k, j] = cj*sc+cs 1098 | M[k, k] = cj*cc-ss 1099 | else: 1100 | M[i, i] = cj*ck 1101 | M[i, j] = sj*sc-cs 1102 | M[i, k] = sj*cc+ss 1103 | M[j, i] = cj*sk 1104 | M[j, j] = sj*ss+cc 1105 | M[j, k] = sj*cs-sc 1106 | M[k, i] = -sj 1107 | M[k, j] = cj*si 1108 | M[k, k] = cj*ci 1109 | return M 1110 | 1111 | 1112 | def euler_from_matrix(matrix, axes='sxyz'): 1113 | """Return Euler angles from rotation matrix for specified axis sequence. 1114 | 1115 | axes : One of 24 axis sequences as string or encoded tuple 1116 | 1117 | Note that many Euler angle triplets can describe one matrix. 1118 | 1119 | >>> R0 = euler_matrix(1, 2, 3, 'syxz') 1120 | >>> al, be, ga = euler_from_matrix(R0, 'syxz') 1121 | >>> R1 = euler_matrix(al, be, ga, 'syxz') 1122 | >>> numpy.allclose(R0, R1) 1123 | True 1124 | >>> angles = (4*math.pi) * (numpy.random.random(3) - 0.5) 1125 | >>> for axes in _AXES2TUPLE.keys(): 1126 | ... R0 = euler_matrix(axes=axes, *angles) 1127 | ... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes)) 1128 | ... if not numpy.allclose(R0, R1): print(axes, "failed") 1129 | 1130 | """ 1131 | try: 1132 | firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] 1133 | except (AttributeError, KeyError): 1134 | _TUPLE2AXES[axes] # validation 1135 | firstaxis, parity, repetition, frame = axes 1136 | 1137 | i = firstaxis 1138 | j = _NEXT_AXIS[i+parity] 1139 | k = _NEXT_AXIS[i-parity+1] 1140 | 1141 | M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3] 1142 | if repetition: 1143 | sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k]) 1144 | if sy > _EPS: 1145 | ax = math.atan2( M[i, j], M[i, k]) 1146 | ay = math.atan2( sy, M[i, i]) 1147 | az = math.atan2( M[j, i], -M[k, i]) 1148 | else: 1149 | ax = math.atan2(-M[j, k], M[j, j]) 1150 | ay = math.atan2( sy, M[i, i]) 1151 | az = 0.0 1152 | else: 1153 | cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i]) 1154 | if cy > _EPS: 1155 | ax = math.atan2( M[k, j], M[k, k]) 1156 | ay = math.atan2(-M[k, i], cy) 1157 | az = math.atan2( M[j, i], M[i, i]) 1158 | else: 1159 | ax = math.atan2(-M[j, k], M[j, j]) 1160 | ay = math.atan2(-M[k, i], cy) 1161 | az = 0.0 1162 | 1163 | if parity: 1164 | ax, ay, az = -ax, -ay, -az 1165 | if frame: 1166 | ax, az = az, ax 1167 | return ax, ay, az 1168 | 1169 | 1170 | def euler_from_quaternion(quaternion, axes='sxyz'): 1171 | """Return Euler angles from quaternion for specified axis sequence. 1172 | 1173 | >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0]) 1174 | >>> numpy.allclose(angles, [0.123, 0, 0]) 1175 | True 1176 | 1177 | """ 1178 | return euler_from_matrix(quaternion_matrix(quaternion), axes) 1179 | 1180 | 1181 | def quaternion_from_euler(ai, aj, ak, axes='sxyz'): 1182 | """Return quaternion from Euler angles and axis sequence. 1183 | 1184 | ai, aj, ak : Euler's roll, pitch and yaw angles 1185 | axes : One of 24 axis sequences as string or encoded tuple 1186 | 1187 | >>> q = quaternion_from_euler(1, 2, 3, 'ryxz') 1188 | >>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435]) 1189 | True 1190 | 1191 | """ 1192 | try: 1193 | firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()] 1194 | except (AttributeError, KeyError): 1195 | _TUPLE2AXES[axes] # validation 1196 | firstaxis, parity, repetition, frame = axes 1197 | 1198 | i = firstaxis + 1 1199 | j = _NEXT_AXIS[i+parity-1] + 1 1200 | k = _NEXT_AXIS[i-parity] + 1 1201 | 1202 | if frame: 1203 | ai, ak = ak, ai 1204 | if parity: 1205 | aj = -aj 1206 | 1207 | ai /= 2.0 1208 | aj /= 2.0 1209 | ak /= 2.0 1210 | ci = math.cos(ai) 1211 | si = math.sin(ai) 1212 | cj = math.cos(aj) 1213 | sj = math.sin(aj) 1214 | ck = math.cos(ak) 1215 | sk = math.sin(ak) 1216 | cc = ci*ck 1217 | cs = ci*sk 1218 | sc = si*ck 1219 | ss = si*sk 1220 | 1221 | q = numpy.empty((4, )) 1222 | if repetition: 1223 | q[0] = cj*(cc - ss) 1224 | q[i] = cj*(cs + sc) 1225 | q[j] = sj*(cc + ss) 1226 | q[k] = sj*(cs - sc) 1227 | else: 1228 | q[0] = cj*cc + sj*ss 1229 | q[i] = cj*sc - sj*cs 1230 | q[j] = cj*ss + sj*cc 1231 | q[k] = cj*cs - sj*sc 1232 | if parity: 1233 | q[j] *= -1.0 1234 | 1235 | return q 1236 | 1237 | 1238 | def quaternion_about_axis(angle, axis): 1239 | """Return quaternion for rotation about axis. 1240 | 1241 | >>> q = quaternion_about_axis(0.123, [1, 0, 0]) 1242 | >>> numpy.allclose(q, [0.99810947, 0.06146124, 0, 0]) 1243 | True 1244 | 1245 | """ 1246 | q = numpy.array([0.0, axis[0], axis[1], axis[2]]) 1247 | qlen = vector_norm(q) 1248 | if qlen > _EPS: 1249 | q *= math.sin(angle/2.0) / qlen 1250 | q[0] = math.cos(angle/2.0) 1251 | return q 1252 | 1253 | 1254 | def quaternion_matrix(quaternion): 1255 | """Return homogeneous rotation matrix from quaternion. 1256 | 1257 | >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0]) 1258 | >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0])) 1259 | True 1260 | >>> M = quaternion_matrix([1, 0, 0, 0]) 1261 | >>> numpy.allclose(M, numpy.identity(4)) 1262 | True 1263 | >>> M = quaternion_matrix([0, 1, 0, 0]) 1264 | >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1])) 1265 | True 1266 | 1267 | """ 1268 | q = numpy.array(quaternion, dtype=numpy.float64, copy=True) 1269 | n = numpy.dot(q, q) 1270 | if n < _EPS: 1271 | return numpy.identity(4) 1272 | q *= math.sqrt(2.0 / n) 1273 | q = numpy.outer(q, q) 1274 | return numpy.array([ 1275 | [1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0], 1276 | [ q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0], 1277 | [ q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0], 1278 | [ 0.0, 0.0, 0.0, 1.0]]) 1279 | 1280 | 1281 | def quaternion_from_matrix(matrix, isprecise=False): 1282 | """Return quaternion from rotation matrix. 1283 | 1284 | If isprecise is True, the input matrix is assumed to be a precise rotation 1285 | matrix and a faster algorithm is used. 1286 | 1287 | >>> q = quaternion_from_matrix(numpy.identity(4), True) 1288 | >>> numpy.allclose(q, [1, 0, 0, 0]) 1289 | True 1290 | >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1])) 1291 | >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0]) 1292 | True 1293 | >>> R = rotation_matrix(0.123, (1, 2, 3)) 1294 | >>> q = quaternion_from_matrix(R, True) 1295 | >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786]) 1296 | True 1297 | >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0], 1298 | ... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]] 1299 | >>> q = quaternion_from_matrix(R) 1300 | >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611]) 1301 | True 1302 | >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0], 1303 | ... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]] 1304 | >>> q = quaternion_from_matrix(R) 1305 | >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603]) 1306 | True 1307 | >>> R = random_rotation_matrix() 1308 | >>> q = quaternion_from_matrix(R) 1309 | >>> is_same_transform(R, quaternion_matrix(q)) 1310 | True 1311 | >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0) 1312 | >>> numpy.allclose(quaternion_from_matrix(R, isprecise=False), 1313 | ... quaternion_from_matrix(R, isprecise=True)) 1314 | True 1315 | 1316 | """ 1317 | M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4] 1318 | if isprecise: 1319 | q = numpy.empty((4, )) 1320 | t = numpy.trace(M) 1321 | if t > M[3, 3]: 1322 | q[0] = t 1323 | q[3] = M[1, 0] - M[0, 1] 1324 | q[2] = M[0, 2] - M[2, 0] 1325 | q[1] = M[2, 1] - M[1, 2] 1326 | else: 1327 | i, j, k = 1, 2, 3 1328 | if M[1, 1] > M[0, 0]: 1329 | i, j, k = 2, 3, 1 1330 | if M[2, 2] > M[i, i]: 1331 | i, j, k = 3, 1, 2 1332 | t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] 1333 | q[i] = t 1334 | q[j] = M[i, j] + M[j, i] 1335 | q[k] = M[k, i] + M[i, k] 1336 | q[3] = M[k, j] - M[j, k] 1337 | q *= 0.5 / math.sqrt(t * M[3, 3]) 1338 | else: 1339 | m00 = M[0, 0] 1340 | m01 = M[0, 1] 1341 | m02 = M[0, 2] 1342 | m10 = M[1, 0] 1343 | m11 = M[1, 1] 1344 | m12 = M[1, 2] 1345 | m20 = M[2, 0] 1346 | m21 = M[2, 1] 1347 | m22 = M[2, 2] 1348 | # symmetric matrix K 1349 | K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0], 1350 | [m01+m10, m11-m00-m22, 0.0, 0.0], 1351 | [m02+m20, m12+m21, m22-m00-m11, 0.0], 1352 | [m21-m12, m02-m20, m10-m01, m00+m11+m22]]) 1353 | K /= 3.0 1354 | # quaternion is eigenvector of K that corresponds to largest eigenvalue 1355 | w, V = numpy.linalg.eigh(K) 1356 | q = V[[3, 0, 1, 2], numpy.argmax(w)] 1357 | if q[0] < 0.0: 1358 | numpy.negative(q, q) 1359 | return q 1360 | 1361 | 1362 | def quaternion_multiply(quaternion1, quaternion0): 1363 | """Return multiplication of two quaternions. 1364 | 1365 | >>> q = quaternion_multiply([4, 1, -2, 3], [8, -5, 6, 7]) 1366 | >>> numpy.allclose(q, [28, -44, -14, 48]) 1367 | True 1368 | 1369 | """ 1370 | w0, x0, y0, z0 = quaternion0 1371 | w1, x1, y1, z1 = quaternion1 1372 | return numpy.array([-x1*x0 - y1*y0 - z1*z0 + w1*w0, 1373 | x1*w0 + y1*z0 - z1*y0 + w1*x0, 1374 | -x1*z0 + y1*w0 + z1*x0 + w1*y0, 1375 | x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64) 1376 | 1377 | 1378 | def quaternion_conjugate(quaternion): 1379 | """Return conjugate of quaternion. 1380 | 1381 | >>> q0 = random_quaternion() 1382 | >>> q1 = quaternion_conjugate(q0) 1383 | >>> q1[0] == q0[0] and all(q1[1:] == -q0[1:]) 1384 | True 1385 | 1386 | """ 1387 | q = numpy.array(quaternion, dtype=numpy.float64, copy=True) 1388 | numpy.negative(q[1:], q[1:]) 1389 | return q 1390 | 1391 | 1392 | def quaternion_inverse(quaternion): 1393 | """Return inverse of quaternion. 1394 | 1395 | >>> q0 = random_quaternion() 1396 | >>> q1 = quaternion_inverse(q0) 1397 | >>> numpy.allclose(quaternion_multiply(q0, q1), [1, 0, 0, 0]) 1398 | True 1399 | 1400 | """ 1401 | q = numpy.array(quaternion, dtype=numpy.float64, copy=True) 1402 | numpy.negative(q[1:], q[1:]) 1403 | return q / numpy.dot(q, q) 1404 | 1405 | 1406 | def quaternion_real(quaternion): 1407 | """Return real part of quaternion. 1408 | 1409 | >>> quaternion_real([3, 0, 1, 2]) 1410 | 3.0 1411 | 1412 | """ 1413 | return float(quaternion[0]) 1414 | 1415 | 1416 | def quaternion_imag(quaternion): 1417 | """Return imaginary part of quaternion. 1418 | 1419 | >>> quaternion_imag([3, 0, 1, 2]) 1420 | array([ 0., 1., 2.]) 1421 | 1422 | """ 1423 | return numpy.array(quaternion[1:4], dtype=numpy.float64, copy=True) 1424 | 1425 | 1426 | def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True): 1427 | """Return spherical linear interpolation between two quaternions. 1428 | 1429 | >>> q0 = random_quaternion() 1430 | >>> q1 = random_quaternion() 1431 | >>> q = quaternion_slerp(q0, q1, 0) 1432 | >>> numpy.allclose(q, q0) 1433 | True 1434 | >>> q = quaternion_slerp(q0, q1, 1, 1) 1435 | >>> numpy.allclose(q, q1) 1436 | True 1437 | >>> q = quaternion_slerp(q0, q1, 0.5) 1438 | >>> angle = math.acos(numpy.dot(q0, q)) 1439 | >>> numpy.allclose(2, math.acos(numpy.dot(q0, q1)) / angle) or \ 1440 | numpy.allclose(2, math.acos(-numpy.dot(q0, q1)) / angle) 1441 | True 1442 | 1443 | """ 1444 | q0 = unit_vector(quat0[:4]) 1445 | q1 = unit_vector(quat1[:4]) 1446 | if fraction == 0.0: 1447 | return q0 1448 | elif fraction == 1.0: 1449 | return q1 1450 | d = numpy.dot(q0, q1) 1451 | if abs(abs(d) - 1.0) < _EPS: 1452 | return q0 1453 | if shortestpath and d < 0.0: 1454 | # invert rotation 1455 | d = -d 1456 | numpy.negative(q1, q1) 1457 | angle = math.acos(d) + spin * math.pi 1458 | if abs(angle) < _EPS: 1459 | return q0 1460 | isin = 1.0 / math.sin(angle) 1461 | q0 *= math.sin((1.0 - fraction) * angle) * isin 1462 | q1 *= math.sin(fraction * angle) * isin 1463 | q0 += q1 1464 | return q0 1465 | 1466 | 1467 | def random_quaternion(rand=None): 1468 | """Return uniform random unit quaternion. 1469 | 1470 | rand: array like or None 1471 | Three independent random variables that are uniformly distributed 1472 | between 0 and 1. 1473 | 1474 | >>> q = random_quaternion() 1475 | >>> numpy.allclose(1, vector_norm(q)) 1476 | True 1477 | >>> q = random_quaternion(numpy.random.random(3)) 1478 | >>> len(q.shape), q.shape[0]==4 1479 | (1, True) 1480 | 1481 | """ 1482 | if rand is None: 1483 | rand = numpy.random.rand(3) 1484 | else: 1485 | assert len(rand) == 3 1486 | r1 = numpy.sqrt(1.0 - rand[0]) 1487 | r2 = numpy.sqrt(rand[0]) 1488 | pi2 = math.pi * 2.0 1489 | t1 = pi2 * rand[1] 1490 | t2 = pi2 * rand[2] 1491 | return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1, 1492 | numpy.cos(t1)*r1, numpy.sin(t2)*r2]) 1493 | 1494 | 1495 | def random_rotation_matrix(rand=None): 1496 | """Return uniform random rotation matrix. 1497 | 1498 | rand: array like 1499 | Three independent random variables that are uniformly distributed 1500 | between 0 and 1 for each returned quaternion. 1501 | 1502 | >>> R = random_rotation_matrix() 1503 | >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4)) 1504 | True 1505 | 1506 | """ 1507 | return quaternion_matrix(random_quaternion(rand)) 1508 | 1509 | 1510 | class Arcball(object): 1511 | """Virtual Trackball Control. 1512 | 1513 | >>> ball = Arcball() 1514 | >>> ball = Arcball(initial=numpy.identity(4)) 1515 | >>> ball.place([320, 320], 320) 1516 | >>> ball.down([500, 250]) 1517 | >>> ball.drag([475, 275]) 1518 | >>> R = ball.matrix() 1519 | >>> numpy.allclose(numpy.sum(R), 3.90583455) 1520 | True 1521 | >>> ball = Arcball(initial=[1, 0, 0, 0]) 1522 | >>> ball.place([320, 320], 320) 1523 | >>> ball.setaxes([1, 1, 0], [-1, 1, 0]) 1524 | >>> ball.constrain = True 1525 | >>> ball.down([400, 200]) 1526 | >>> ball.drag([200, 400]) 1527 | >>> R = ball.matrix() 1528 | >>> numpy.allclose(numpy.sum(R), 0.2055924) 1529 | True 1530 | >>> ball.next() 1531 | 1532 | """ 1533 | def __init__(self, initial=None): 1534 | """Initialize virtual trackball control. 1535 | 1536 | initial : quaternion or rotation matrix 1537 | 1538 | """ 1539 | self._axis = None 1540 | self._axes = None 1541 | self._radius = 1.0 1542 | self._center = [0.0, 0.0] 1543 | self._vdown = numpy.array([0.0, 0.0, 1.0]) 1544 | self._constrain = False 1545 | if initial is None: 1546 | self._qdown = numpy.array([1.0, 0.0, 0.0, 0.0]) 1547 | else: 1548 | initial = numpy.array(initial, dtype=numpy.float64) 1549 | if initial.shape == (4, 4): 1550 | self._qdown = quaternion_from_matrix(initial) 1551 | elif initial.shape == (4, ): 1552 | initial /= vector_norm(initial) 1553 | self._qdown = initial 1554 | else: 1555 | raise ValueError("initial not a quaternion or matrix") 1556 | self._qnow = self._qpre = self._qdown 1557 | 1558 | def place(self, center, radius): 1559 | """Place Arcball, e.g. when window size changes. 1560 | 1561 | center : sequence[2] 1562 | Window coordinates of trackball center. 1563 | radius : float 1564 | Radius of trackball in window coordinates. 1565 | 1566 | """ 1567 | self._radius = float(radius) 1568 | self._center[0] = center[0] 1569 | self._center[1] = center[1] 1570 | 1571 | def setaxes(self, *axes): 1572 | """Set axes to constrain rotations.""" 1573 | if axes is None: 1574 | self._axes = None 1575 | else: 1576 | self._axes = [unit_vector(axis) for axis in axes] 1577 | 1578 | @property 1579 | def constrain(self): 1580 | """Return state of constrain to axis mode.""" 1581 | return self._constrain 1582 | 1583 | @constrain.setter 1584 | def constrain(self, value): 1585 | """Set state of constrain to axis mode.""" 1586 | self._constrain = bool(value) 1587 | 1588 | def down(self, point): 1589 | """Set initial cursor window coordinates and pick constrain-axis.""" 1590 | self._vdown = arcball_map_to_sphere(point, self._center, self._radius) 1591 | self._qdown = self._qpre = self._qnow 1592 | if self._constrain and self._axes is not None: 1593 | self._axis = arcball_nearest_axis(self._vdown, self._axes) 1594 | self._vdown = arcball_constrain_to_axis(self._vdown, self._axis) 1595 | else: 1596 | self._axis = None 1597 | 1598 | def drag(self, point): 1599 | """Update current cursor window coordinates.""" 1600 | vnow = arcball_map_to_sphere(point, self._center, self._radius) 1601 | if self._axis is not None: 1602 | vnow = arcball_constrain_to_axis(vnow, self._axis) 1603 | self._qpre = self._qnow 1604 | t = numpy.cross(self._vdown, vnow) 1605 | if numpy.dot(t, t) < _EPS: 1606 | self._qnow = self._qdown 1607 | else: 1608 | q = [numpy.dot(self._vdown, vnow), t[0], t[1], t[2]] 1609 | self._qnow = quaternion_multiply(q, self._qdown) 1610 | 1611 | def next(self, acceleration=0.0): 1612 | """Continue rotation in direction of last drag.""" 1613 | q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False) 1614 | self._qpre, self._qnow = self._qnow, q 1615 | 1616 | def matrix(self): 1617 | """Return homogeneous rotation matrix.""" 1618 | return quaternion_matrix(self._qnow) 1619 | 1620 | 1621 | def arcball_map_to_sphere(point, center, radius): 1622 | """Return unit sphere coordinates from window coordinates.""" 1623 | v0 = (point[0] - center[0]) / radius 1624 | v1 = (center[1] - point[1]) / radius 1625 | n = v0*v0 + v1*v1 1626 | if n > 1.0: 1627 | # position outside of sphere 1628 | n = math.sqrt(n) 1629 | return numpy.array([v0/n, v1/n, 0.0]) 1630 | else: 1631 | return numpy.array([v0, v1, math.sqrt(1.0 - n)]) 1632 | 1633 | 1634 | def arcball_constrain_to_axis(point, axis): 1635 | """Return sphere point perpendicular to axis.""" 1636 | v = numpy.array(point, dtype=numpy.float64, copy=True) 1637 | a = numpy.array(axis, dtype=numpy.float64, copy=True) 1638 | v -= a * numpy.dot(a, v) # on plane 1639 | n = vector_norm(v) 1640 | if n > _EPS: 1641 | if v[2] < 0.0: 1642 | numpy.negative(v, v) 1643 | v /= n 1644 | return v 1645 | if a[2] == 1.0: 1646 | return numpy.array([1.0, 0.0, 0.0]) 1647 | return unit_vector([-a[1], a[0], 0.0]) 1648 | 1649 | 1650 | def arcball_nearest_axis(point, axes): 1651 | """Return axis, which arc is nearest to point.""" 1652 | point = numpy.array(point, dtype=numpy.float64, copy=False) 1653 | nearest = None 1654 | mx = -1.0 1655 | for axis in axes: 1656 | t = numpy.dot(arcball_constrain_to_axis(point, axis), point) 1657 | if t > mx: 1658 | nearest = axis 1659 | mx = t 1660 | return nearest 1661 | 1662 | 1663 | # epsilon for testing whether a number is close to zero 1664 | _EPS = numpy.finfo(float).eps * 4.0 1665 | 1666 | # axis sequences for Euler angles 1667 | _NEXT_AXIS = [1, 2, 0, 1] 1668 | 1669 | # map axes strings to/from tuples of inner axis, parity, repetition, frame 1670 | _AXES2TUPLE = { 1671 | 'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0), 1672 | 'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0), 1673 | 'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0), 1674 | 'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0), 1675 | 'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1), 1676 | 'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1), 1677 | 'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1), 1678 | 'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)} 1679 | 1680 | _TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items()) 1681 | 1682 | 1683 | def vector_norm(data, axis=None, out=None): 1684 | """Return length, i.e. Euclidean norm, of ndarray along axis. 1685 | 1686 | >>> v = numpy.random.random(3) 1687 | >>> n = vector_norm(v) 1688 | >>> numpy.allclose(n, numpy.linalg.norm(v)) 1689 | True 1690 | >>> v = numpy.random.rand(6, 5, 3) 1691 | >>> n = vector_norm(v, axis=-1) 1692 | >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2))) 1693 | True 1694 | >>> n = vector_norm(v, axis=1) 1695 | >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) 1696 | True 1697 | >>> v = numpy.random.rand(5, 4, 3) 1698 | >>> n = numpy.empty((5, 3)) 1699 | >>> vector_norm(v, axis=1, out=n) 1700 | >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1))) 1701 | True 1702 | >>> vector_norm([]) 1703 | 0.0 1704 | >>> vector_norm([1]) 1705 | 1.0 1706 | 1707 | """ 1708 | data = numpy.array(data, dtype=numpy.float64, copy=True) 1709 | if out is None: 1710 | if data.ndim == 1: 1711 | return math.sqrt(numpy.dot(data, data)) 1712 | data *= data 1713 | out = numpy.atleast_1d(numpy.sum(data, axis=axis)) 1714 | numpy.sqrt(out, out) 1715 | return out 1716 | else: 1717 | data *= data 1718 | numpy.sum(data, axis=axis, out=out) 1719 | numpy.sqrt(out, out) 1720 | 1721 | 1722 | def unit_vector(data, axis=None, out=None): 1723 | """Return ndarray normalized by length, i.e. Euclidean norm, along axis. 1724 | 1725 | >>> v0 = numpy.random.random(3) 1726 | >>> v1 = unit_vector(v0) 1727 | >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0)) 1728 | True 1729 | >>> v0 = numpy.random.rand(5, 4, 3) 1730 | >>> v1 = unit_vector(v0, axis=-1) 1731 | >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2) 1732 | >>> numpy.allclose(v1, v2) 1733 | True 1734 | >>> v1 = unit_vector(v0, axis=1) 1735 | >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1) 1736 | >>> numpy.allclose(v1, v2) 1737 | True 1738 | >>> v1 = numpy.empty((5, 4, 3)) 1739 | >>> unit_vector(v0, axis=1, out=v1) 1740 | >>> numpy.allclose(v1, v2) 1741 | True 1742 | >>> list(unit_vector([])) 1743 | [] 1744 | >>> list(unit_vector([1])) 1745 | [1.0] 1746 | 1747 | """ 1748 | if out is None: 1749 | data = numpy.array(data, dtype=numpy.float64, copy=True) 1750 | if data.ndim == 1: 1751 | data /= math.sqrt(numpy.dot(data, data)) 1752 | return data 1753 | else: 1754 | if out is not data: 1755 | out[:] = numpy.array(data, copy=False) 1756 | data = out 1757 | length = numpy.atleast_1d(numpy.sum(data*data, axis)) 1758 | numpy.sqrt(length, length) 1759 | if axis is not None: 1760 | length = numpy.expand_dims(length, axis) 1761 | data /= length 1762 | if out is None: 1763 | return data 1764 | 1765 | 1766 | def random_vector(size): 1767 | """Return array of random doubles in the half-open interval [0.0, 1.0). 1768 | 1769 | >>> v = random_vector(10000) 1770 | >>> numpy.all(v >= 0) and numpy.all(v < 1) 1771 | True 1772 | >>> v0 = random_vector(10) 1773 | >>> v1 = random_vector(10) 1774 | >>> numpy.any(v0 == v1) 1775 | False 1776 | 1777 | """ 1778 | return numpy.random.random(size) 1779 | 1780 | 1781 | def vector_product(v0, v1, axis=0): 1782 | """Return vector perpendicular to vectors. 1783 | 1784 | >>> v = vector_product([2, 0, 0], [0, 3, 0]) 1785 | >>> numpy.allclose(v, [0, 0, 6]) 1786 | True 1787 | >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]] 1788 | >>> v1 = [[3], [0], [0]] 1789 | >>> v = vector_product(v0, v1) 1790 | >>> numpy.allclose(v, [[0, 0, 0, 0], [0, 0, 6, 6], [0, -6, 0, -6]]) 1791 | True 1792 | >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]] 1793 | >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]] 1794 | >>> v = vector_product(v0, v1, axis=1) 1795 | >>> numpy.allclose(v, [[0, 0, 6], [0, -6, 0], [6, 0, 0], [0, -6, 6]]) 1796 | True 1797 | 1798 | """ 1799 | return numpy.cross(v0, v1, axis=axis) 1800 | 1801 | 1802 | def angle_between_vectors(v0, v1, directed=True, axis=0): 1803 | """Return angle between vectors. 1804 | 1805 | If directed is False, the input vectors are interpreted as undirected axes, 1806 | i.e. the maximum angle is pi/2. 1807 | 1808 | >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3]) 1809 | >>> numpy.allclose(a, math.pi) 1810 | True 1811 | >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3], directed=False) 1812 | >>> numpy.allclose(a, 0) 1813 | True 1814 | >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]] 1815 | >>> v1 = [[3], [0], [0]] 1816 | >>> a = angle_between_vectors(v0, v1) 1817 | >>> numpy.allclose(a, [0, 1.5708, 1.5708, 0.95532]) 1818 | True 1819 | >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]] 1820 | >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]] 1821 | >>> a = angle_between_vectors(v0, v1, axis=1) 1822 | >>> numpy.allclose(a, [1.5708, 1.5708, 1.5708, 0.95532]) 1823 | True 1824 | 1825 | """ 1826 | v0 = numpy.array(v0, dtype=numpy.float64, copy=False) 1827 | v1 = numpy.array(v1, dtype=numpy.float64, copy=False) 1828 | dot = numpy.sum(v0 * v1, axis=axis) 1829 | dot /= vector_norm(v0, axis=axis) * vector_norm(v1, axis=axis) 1830 | return numpy.arccos(dot if directed else numpy.fabs(dot)) 1831 | 1832 | 1833 | def inverse_matrix(matrix): 1834 | """Return inverse of square transformation matrix. 1835 | 1836 | >>> M0 = random_rotation_matrix() 1837 | >>> M1 = inverse_matrix(M0.T) 1838 | >>> numpy.allclose(M1, numpy.linalg.inv(M0.T)) 1839 | True 1840 | >>> for size in range(1, 7): 1841 | ... M0 = numpy.random.rand(size, size) 1842 | ... M1 = inverse_matrix(M0) 1843 | ... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size) 1844 | 1845 | """ 1846 | return numpy.linalg.inv(matrix) 1847 | 1848 | 1849 | def concatenate_matrices(*matrices): 1850 | """Return concatenation of series of transformation matrices. 1851 | 1852 | >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5 1853 | >>> numpy.allclose(M, concatenate_matrices(M)) 1854 | True 1855 | >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T)) 1856 | True 1857 | 1858 | """ 1859 | M = numpy.identity(4) 1860 | for i in matrices: 1861 | M = numpy.dot(M, i) 1862 | return M 1863 | 1864 | 1865 | def is_same_transform(matrix0, matrix1): 1866 | """Return True if two matrices perform same transformation. 1867 | 1868 | >>> is_same_transform(numpy.identity(4), numpy.identity(4)) 1869 | True 1870 | >>> is_same_transform(numpy.identity(4), random_rotation_matrix()) 1871 | False 1872 | 1873 | """ 1874 | matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True) 1875 | matrix0 /= matrix0[3, 3] 1876 | matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True) 1877 | matrix1 /= matrix1[3, 3] 1878 | return numpy.allclose(matrix0, matrix1) 1879 | 1880 | 1881 | def _import_module(name, package=None, warn=False, prefix='_py_', ignore='_'): 1882 | """Try import all public attributes from module into global namespace. 1883 | 1884 | Existing attributes with name clashes are renamed with prefix. 1885 | Attributes starting with underscore are ignored by default. 1886 | 1887 | Return True on successful import. 1888 | 1889 | """ 1890 | import warnings 1891 | from importlib import import_module 1892 | try: 1893 | if not package: 1894 | module = import_module(name) 1895 | else: 1896 | module = import_module('.' + name, package=package) 1897 | except ImportError: 1898 | if warn: 1899 | warnings.warn("failed to import module %s" % name) 1900 | else: 1901 | for attr in dir(module): 1902 | if ignore and attr.startswith(ignore): 1903 | continue 1904 | if prefix: 1905 | if attr in globals(): 1906 | globals()[prefix + attr] = globals()[attr] 1907 | elif warn: 1908 | warnings.warn("no Python implementation of " + attr) 1909 | globals()[attr] = getattr(module, attr) 1910 | return True 1911 | 1912 | 1913 | _import_module('_transformations') 1914 | 1915 | if __name__ == "__main__": 1916 | import doctest 1917 | import random # used in doctests 1918 | numpy.set_printoptions(suppress=True, precision=5) 1919 | doctest.testmod() 1920 | 1921 | --------------------------------------------------------------------------------