├── LICENSE ├── README.md ├── configs ├── datasets │ ├── BaselineEuroc │ │ └── Euroc_1000.conf │ ├── KITTIDataset │ │ └── KITTI26.conf │ └── TUMVI │ │ └── room_1000.conf ├── exp │ ├── EuRoC │ │ ├── baseline.conf │ │ └── codenet.conf │ ├── KITTI │ │ └── codenet.conf │ └── TUMVI │ │ └── codenet.conf └── train │ └── train.conf ├── datasets ├── EuRoCdataset.py ├── KITTIdataset.py ├── TUMdataset.py ├── __init__.py ├── dataset.py └── dataset_utils.py ├── doc ├── alto.gif └── model.png ├── eval.py ├── evaluation └── evaluate_state.py ├── inference.py ├── model ├── __init__.py ├── cnn.py ├── code.py ├── loss_func.py ├── losses.py ├── net.py └── others.py ├── train.py └── utils ├── __init__.py ├── deferentiate_vel.py ├── integrate.py ├── utils.py ├── visualize.py └── visualize_state.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Carnegie Mellon University 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AirIMU : Learning Uncertainty Propagation for Inertial Odometry 2 | [![License: BSD 3-Clause](https://img.shields.io/badge/License-BSD%203--Clause-yellow.svg)](./LICENSE) 3 | [![YouTube](https://img.shields.io/badge/YouTube-b31b1b?style=flat&logo=youtube&logoColor=white)](https://www.youtube.com/watch?v=fTX1u-e7wtU) 4 | [![arXiv](https://img.shields.io/badge/arXiv-AirIMU-orange.svg)](https://arxiv.org/abs/2310.04874) 5 | [![githubio](https://img.shields.io/badge/-homepage-blue?logo=Github&color=FF0000)](https://airimu.github.io/) 6 | 7 | 8 | ![AirIMU](./doc/model.png) 9 | ## 📢 Latest News 10 | - [2025-02-01] Introducing Our New Project!
11 | 🚀 [**AirIO : Learning Inertial Odometry with Enhanced IMU Feature Observability**](https://github.com/Air-IO/Air-IO)
12 | ``` 13 | AirIO achieves up to 86.6% performance boost over SOTA methods: 14 | 15 | - ✅ Tailored specifically for drones 16 | - ✅ No external sensors or control inputs required 17 | - ✅ Generalizes to unseen trajectories 18 | - ✅ Explicitly encodes UAV attitude and predicts velocity in body-frame representation 19 | ``` 20 | 21 | ## Installation 22 | 23 | This work is based on pypose. Follow the instructions there to install the latest release of pypose: 24 | https://github.com/pypose/pypose 25 | 26 | 27 | ## Dataset 28 | > **Note**: Remember to replace the `data_root` placeholder (e.g. `PATH_TO_EuRoC`) in the dataset config under `configs/datasets/${DATASET}/` (e.g. `configs/datasets/BaselineEuroc/Euroc_1000.conf`). 29 | 30 | Download the EuRoC dataset from: 31 | https://projects.asl.ethz.ch/datasets/doku.php?id=kmavvisualinertialdatasets 32 | 33 | Download the TUM VI dataset from: 34 | https://cvg.cit.tum.de/data/datasets/visual-inertial-dataset 35 | 36 | Download the KITTI dataset from: 37 | https://www.cvlibs.net/datasets/kitti/ 38 | 39 | ## Pretrained Model 40 | > **Note**: You can download our trained checkpoints here: 41 | 42 | 43 | [KITTI](https://github.com/sleepycan/AirIMU/releases/download/pretrained_model/KITTI_odom_model.zip) 44 | 45 | [EuRoC](https://github.com/sleepycan/AirIMU/releases/download/pretrained_model_euroc/EuRoCWholeaug.zip) 46 | ## Train 47 | 48 | The easiest way to start training is to use an existing configuration. 49 | > **Note**: You can also create your own configuration file for different datasets and set the parameters accordingly. 50 | 51 | ``` 52 | python train.py --config configs/exp/EuRoC/codenet.conf 53 | ``` 54 | 55 | More specific options: 56 | 57 | ``` 58 | usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--load_ckpt] [--log] 59 | 60 | optional arguments: 61 | -h, --help show this help message and exit 62 | --config CONFIG config file path 63 | --device DEVICE cuda or cpu, Default is cuda:0 64 | --load_ckpt If True, try to load the newest.ckpt in the exp_dir specified in the config file.
65 | --log if True, save the meta data with wandb, Default is True 66 | ``` 67 | 68 | ## Evaluation 69 | 70 | To evaluate the model and generate network inference file net_output.pickle, run the following command: 71 | ``` 72 | python inference.py --config configs/exp/EuRoC/codenet.conf 73 | ``` 74 | 75 |
76 | 77 | To assess your model's performance from net_output.pickle, run the evaluation tool with the following command. 78 | > **Note**: Make sure to replace path/to/net_output_directory with the directory path where your network output pickle file is stored. 79 | 80 | ``` 81 | python evaluation/evaluate_state.py --dataconf configs/datasets/${DATASET}/${DATASET}.conf --exp path/to/net_output_directory 82 | ``` 83 | 84 |
85 | More specific options for the evaluation tool: 86 | 87 | ``` 88 | usage: evaluation/evaluate_state.py [-h] [--dataconf] [--device] [--exp] [--seqlen] [--savedir] [--usegtrot] [--mask] 89 | 90 | optional arguments: 91 | -h, --help show this help message and exit 92 | --dataconf dataset config file path 93 | --device cuda or cpu, Default is cuda:0 94 | --exp the directory path where your network output pickle file is stored 95 | --seqlen the length of the integration sequence 96 | --savedir the save directory for the evaluation results, default path is "./result/loss_result" 97 | --usegtrot use ground truth rotation for gravity compensation, default is true 98 | --mask mask the segments if needed. 99 | ``` 100 | 101 | 102 | 103 | 104 | 105 | ### Cite Our Work 106 | 107 | Thanks for using our work. You can cite it as: 108 | 109 | ```bib 110 | @article{qiu2023airimu, 111 | title={AirIMU: Learning Uncertainty Propagation for Inertial Odometry}, 112 | author={Yuheng Qiu and Chen Wang and Can Xu and Yutian Chen and Xunfei Zhou and Youjie Xia and Sebastian Scherer}, 113 | year={2023}, 114 | eprint={2310.04874}, 115 | archivePrefix={arXiv}, 116 | primaryClass={cs.RO} 117 | } 118 | ``` 119 | -------------------------------------------------------------------------------- /configs/datasets/BaselineEuroc/Euroc_1000.conf: -------------------------------------------------------------------------------- 1 | train: 2 | { 3 | mode: train 4 | calib: False 5 | data_list: 6 | [ 7 | {name: Euroc 8 | window_size: 1000 9 | step_size: 10 10 | data_root: PATH_TO_EuRoC 11 | data_drive: [V1_02_medium, V2_01_easy, V2_03_difficult, MH_03_medium, MH_05_difficult, MH_01_easy] 12 | }, 13 | ] 14 | } 15 | 16 | test: 17 | { 18 | mode: test 19 | calib: False 20 | data_list: 21 | [ 22 | {name: Euroc 23 | window_size: 1000 24 | step_size: 200 25 | data_root: PATH_TO_EuRoC 26 | data_drive: [V1_02_medium, V2_01_easy, V2_03_difficult, MH_03_medium, MH_05_difficult, MH_01_easy] 27 | }, 28 | ] 29 | } 30 | 31 |
eval: 32 | { 33 | mode: evaluate 34 | calib: False 35 | data_list: 36 | [{ 37 | name: Euroc 38 | window_size: 1000 39 | step_size: 200 40 | data_root: PATH_TO_EuRoC 41 | data_drive: [MH_02_easy, MH_04_difficult, V1_03_difficult, V2_02_medium, V1_01_easy] 42 | }, 43 | ] 44 | 45 | } 46 | 47 | 48 | inference: 49 | { 50 | mode: evaluate 51 | calib: False 52 | data_list: 53 | [{ 54 | name: Euroc 55 | window_size: 1000 56 | step_size: 1000 57 | data_root: PATH_TO_EuRoC 58 | data_drive: [MH_02_easy, MH_04_difficult, V1_03_difficult, V2_02_medium, V1_01_easy, V1_02_medium, V2_01_easy, V2_03_difficult, MH_03_medium, MH_05_difficult, MH_01_easy] 59 | }, 60 | ] 61 | } 62 | -------------------------------------------------------------------------------- /configs/datasets/KITTIDataset/KITTI26.conf: -------------------------------------------------------------------------------- 1 | train: 2 | { 3 | mode: train 4 | calib: False 5 | data_list: 6 | [ 7 | { 8 | name: KITTI, 9 | data_root: PATH_TO_KITTI_2011_09_26, 10 | data_drive: ["0027", "0059", "0023", "0070", "0095", "0064", "0091", "0051", "0104", "0036", "0101", "0032", "0061", "0084", "0013", "0096", "0005", "0056", "0001", "0019", "0028", "0086", "0015", "0046", "0011", "0009", "0117", "0087"] 11 | window_size: 50, 12 | step_size: 25 13 | } 14 | ] 15 | } 16 | 17 | test: 18 | { 19 | mode: evaluate 20 | calib: False 21 | data_list: 22 | [ 23 | { 24 | name: KITTI, 25 | data_root: PATH_TO_KITTI_2011_09_26, 26 | data_drive: ["0018", "0106", "0029", "0014", "0039", "0035", "0022"], 27 | window_size: 50, 28 | step_size: 25 29 | } 30 | ] 31 | } 32 | 33 | eval: 34 | { 35 | mode: evaluate 36 | calib: False 37 | data_list: 38 | [ 39 | { 40 | name: KITTI, 41 | data_root: PATH_TO_KITTI_2011_09_26, 42 | data_drive: ["0018", "0106", "0029", "0014", "0039", "0035", "0022"], 43 | window_size: 50, 44 | step_size: 25 45 | } 46 | ] 47 | } 48 | -------------------------------------------------------------------------------- 
/configs/datasets/TUMVI/room_1000.conf: -------------------------------------------------------------------------------- 1 | train: 2 | { 3 | mode: train 4 | calib: False 5 | data_list: 6 | [ 7 | {name: TUMVI 8 | window_size: 1000 9 | step_size: 10 10 | data_root: PATH_TO_TUM 11 | data_drive: [dataset-room1_512_16, dataset-room3_512_16, dataset-room5_512_16] 12 | }, 13 | ] 14 | } 15 | 16 | test: 17 | { 18 | mode: test 19 | calib: False 20 | data_list: 21 | [ 22 | {name: TUMVI 23 | window_size: 1000 24 | step_size: 100 25 | data_root: PATH_TO_TUM 26 | data_drive: [dataset-room1_512_16, dataset-room3_512_16, dataset-room5_512_16] 27 | }, 28 | ] 29 | } 30 | 31 | eval: 32 | { 33 | mode: evaluate 34 | calib: False 35 | data_list: 36 | [ 37 | {name: TUMVI 38 | window_size: 1000 39 | step_size: 100 40 | data_root: PATH_TO_TUM 41 | data_drive: [dataset-room2_512_16, dataset-room4_512_16, dataset-room6_512_16] 42 | }, 43 | ] 44 | } 45 | 46 | inference: 47 | { 48 | mode: evaluate 49 | calib: False 50 | data_list: 51 | [ 52 | {name: TUMVI 53 | window_size: 1000 54 | step_size: 100 55 | data_root: PATH_TO_TUM 56 | data_drive: [dataset-room2_512_16, dataset-room4_512_16, dataset-room6_512_16,dataset-room1_512_16, dataset-room3_512_16, dataset-room5_512_16, dataset-room1_512_16] 57 | }, 58 | ] 59 | } -------------------------------------------------------------------------------- /configs/exp/EuRoC/baseline.conf: -------------------------------------------------------------------------------- 1 | general: 2 | { 3 | exp_dir: experiments/EuRoC 4 | } 5 | 6 | dataset: 7 | { 8 | include "../../datasets/BaselineEuroc/Euroc_1000_half.conf" 9 | } 10 | 11 | train: 12 | { 13 | include "../../train/train.conf" 14 | batch_size: 128 15 | rot_weight: 1e2 16 | pos_weight: 1e2 17 | vel_weight: 1e1 18 | propcov: False 19 | 20 | network: iden 21 | sampling: 50 22 | gtrot:True 23 | 24 | sampling: 50 25 | loss: Huber_loss0005 26 | rotloss: Huber_loss005 27 | lr: 1e-2 28 | } 
-------------------------------------------------------------------------------- /configs/exp/EuRoC/codenet.conf: -------------------------------------------------------------------------------- 1 | general: 2 | { 3 | exp_dir: experiments/EuRoC 4 | } 5 | 6 | dataset: 7 | { 8 | include "../../datasets/BaselineEuroc/Euroc_1000.conf" 9 | collate: padding9 10 | } 11 | 12 | train: 13 | { 14 | include "../../train/train.conf" 15 | lr: 1e-3 16 | batch_size: 128 17 | rot_weight: 1e3 18 | pos_weight: 1e2 19 | vel_weight: 1e1 20 | cov_weight: 1e-4 21 | 22 | network: codenet 23 | gtrot:True 24 | propcov:True 25 | covaug:True 26 | 27 | sampling: 50 28 | loss: Huber_loss005 29 | rotloss: Huber_loss005 30 | } -------------------------------------------------------------------------------- /configs/exp/KITTI/codenet.conf: -------------------------------------------------------------------------------- 1 | general: 2 | { 3 | exp_dir: experiments/KITTI 4 | } 5 | 6 | dataset: 7 | { 8 | include "../../datasets/KITTIDataset/KITTI26.conf" 9 | } 10 | 11 | train: 12 | { 13 | network: codenetkitti 14 | lr: 1e-3 15 | weight_decay: 1e-3, 16 | batch_size: 128 17 | min_lr: 1e-7 18 | max_epoches: 500 19 | patience: 50 20 | factor: 0.5 21 | 22 | save_freq: 5 23 | eval_freq: 1 24 | 25 | gtinit: True 26 | propcov: True 27 | gtrot:True 28 | 29 | rot_weight: 1e1 30 | pos_weight: 1e1 31 | vel_weight: 1e1 32 | cov_weight: 1e1 33 | 34 | sampling: False 35 | loss: Huber_loss0005 36 | rotloss: Huber_loss005 37 | } 38 | -------------------------------------------------------------------------------- /configs/exp/TUMVI/codenet.conf: -------------------------------------------------------------------------------- 1 | general: 2 | { 3 | exp_dir: /user/yqiu/data/yuhengq/AirIMU/TUMVI 4 | } 5 | 6 | dataset: 7 | { 8 | 9 | include "../../datasets/TUMVI/room_1000.conf" 10 | collate: padding9 11 | } 12 | 13 | train: 14 | { 15 | include "../../train/train.conf" 16 | lr: 1e-3 17 | batch_size: 128 18 | rot_weight: 
1e2 19 | pos_weight: 1e2 20 | vel_weight: 1e2 21 | cov_weight: 1e-4 22 | 23 | network: codenet 24 | covaug:True 25 | gtrot:True 26 | propcov:True 27 | 28 | sampling: 50 29 | loss: Huber_loss005 30 | rotloss: Huber_loss005 31 | } -------------------------------------------------------------------------------- /configs/train/train.conf: -------------------------------------------------------------------------------- 1 | lr: 1e-3 2 | weight_decay: 1e-4 3 | 4 | batch_size: 128 5 | min_lr: 1e-5 6 | max_epoches: 100 7 | patience: 10 8 | factor: 0.1 9 | 10 | save_freq: 5 11 | eval_freq: 1 12 | 13 | cov_weight: None 14 | sampling: 50 15 | gtinit: True 16 | -------------------------------------------------------------------------------- /datasets/EuRoCdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pypose as pp 5 | from utils import qinterp, lookAt 6 | from .dataset import Sequence 7 | 8 | class Euroc(Sequence): 9 | """ 10 | Output: 11 | acce: the accelaration in **world frame** 12 | """ 13 | def __init__(self, data_root, data_name, intepolate = True, calib = False, load_vicon = False, glob_coord=False, **kwargs): 14 | super(Euroc, self).__init__() 15 | ( 16 | self.data_root, self.data_name, 17 | self.data, 18 | self.ts, 19 | self.targets, 20 | self.orientations, 21 | self.gt_pos, 22 | self.gt_ori, 23 | ) = (data_root, data_name, dict(), None, None, None, None, None) 24 | 25 | self.camera_ext_R = pp.mat2SO3(np.array([[0.0148655429818, -0.999880929698, 0.00414029679422,], 26 | [0.999557249008, 0.0149672133247, 0.025715529948, ], 27 | [-0.0257744366974, 0.00375618835797, 0.999660727178,],]), check= False) 28 | self.camera_ext_t = torch.tensor(np.array([-0.0216401454975, -0.064676986768, 0.00981073058949,])) 29 | self.vicon_ext_R = pp.mat2SO3(np.array([[0.33638, -0.01749, 0.94156],[-0.02078, -0.99972, -0.01114],[0.94150, -0.01582, -0.33665]]), check= False) 30 | 
self.vicon_ext_t = torch.tensor(np.array([0.06901, -0.02781,-0.12395])) 31 | self.ext_T = pp.SE3(torch.cat((self.camera_ext_t, self.camera_ext_R))) 32 | self.gravity = torch.tensor([0., 0., 9.81007], dtype=torch.float64) 33 | 34 | data_path = os.path.join(data_root, data_name) 35 | self.load_imu(data_path) 36 | self.load_gt(data_path) 37 | if load_vicon: 38 | self.load_vicon(data_path) 39 | 40 | # EUROC require an interpolation 41 | if intepolate: 42 | t_start = np.max([self.data['gt_time'][0], self.data['time'][0]]) 43 | t_end = np.min([self.data['gt_time'][-1], self.data['time'][-1]]) 44 | 45 | # find the index of the start and end 46 | idx_start_imu = np.searchsorted(self.data['time'], t_start) 47 | idx_start_gt = np.searchsorted(self.data['gt_time'], t_start) 48 | 49 | idx_end_imu = np.searchsorted(self.data['time'], t_end, 'right') 50 | idx_end_gt = np.searchsorted(self.data['gt_time'], t_end, 'right') 51 | 52 | for k in ['gt_time', 'pos', 'quat', 'velocity', 'b_acc', 'b_gyro']: 53 | self.data[k] = self.data[k][idx_start_gt:idx_end_gt] 54 | 55 | for k in ['time', 'acc', 'gyro']: 56 | self.data[k] = self.data[k][idx_start_imu:idx_end_imu] 57 | 58 | ## start interpotation 59 | self.data["gt_orientation"] = self.interp_rot(self.data['time'], self.data['gt_time'], self.data['quat']) 60 | self.data["gt_translation"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data['pos']) 61 | 62 | self.data["b_acc"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data["b_acc"]) 63 | self.data["b_gyro"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data["b_gyro"]) 64 | self.data["velocity"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data["velocity"]) 65 | 66 | else: 67 | self.data["gt_orientation"] = pp.SO3(torch.tensor(self.data['pose'][:,3:])) 68 | self.data['gt_translation'] = torch.tensor(self.data['pose'][:,:3]) 69 | 70 | # move the time to torch 71 | self.data["time"] = torch.tensor(self.data["time"]) 
72 | self.data["gt_time"] = torch.tensor(self.data["gt_time"]) 73 | self.data['dt'] = (self.data["time"][1:] - self.data["time"][:-1])[:,None] 74 | self.data["mask"] = torch.ones(self.data["time"].shape[0], dtype=torch.bool) 75 | 76 | # Calibration for evaluation 77 | if calib == "head": 78 | self.data["gyro"] = torch.tensor(self.data["gyro"]) - self.data["b_gyro"][0] 79 | self.data["acc"] = torch.tensor(self.data["acc"]) - self.data["b_acc"][0] 80 | elif calib == "full": 81 | self.data["gyro"] = torch.tensor(self.data["gyro"]) - self.data["b_gyro"] 82 | self.data["acc"] = torch.tensor(self.data["acc"]) - self.data["b_acc"] 83 | elif calib == "aligngravity": 84 | ## Find the nearest static point 85 | nl_point = np.where(self.data['velocity'].norm(dim=-1) < 0.001)[0][0] 86 | avg_acc = self.data['acc'][nl_point+10:nl_point+100].mean(axis=-2) 87 | avg_gyro = self.data['gyro'][nl_point+10:nl_point+100].mean(axis=-2) 88 | 89 | gr = lookAt(avg_acc) 90 | g_IMU = gr.T @ self.gravity 91 | gl_acc_b = avg_acc - g_IMU.numpy() 92 | gl_gyro_b = avg_gyro 93 | 94 | self.data["acc"] = torch.tensor(self.data["acc"]) - gl_acc_b 95 | self.data["gyro"] = torch.tensor(self.data["gyro"]) - gl_gyro_b 96 | else: 97 | self.data["gyro"] = torch.tensor(self.data["gyro"]) 98 | self.data["acc"] = torch.tensor(self.data["acc"]) 99 | 100 | # change the acc and gyro scope into the global coordinate. 
101 | if glob_coord: 102 | self.data['gyro'] = self.data["gt_orientation"] * self.data['gyro'] 103 | self.data['acc'] = self.data["gt_orientation"] * self.data['acc'] 104 | 105 | print("loaded: ", data_path, "calib: ", calib, "interpolate: ", intepolate) 106 | 107 | def get_length(self): 108 | return self.data['time'].shape[0] 109 | 110 | def load_imu(self, folder): 111 | imu_data = np.loadtxt(os.path.join(folder, "mav0/imu0/data.csv"), dtype=float, delimiter=',') 112 | self.data["time"] = imu_data[:,0] / 1e9 113 | self.data["gyro"] = imu_data[:,1:4] # w_RS_S_x [rad s^-1],w_RS_S_y [rad s^-1],w_RS_S_z [rad s^-1] 114 | self.data["acc"] = imu_data[:,4:]# acc a_RS_S_x [m s^-2],a_RS_S_y [m s^-2],a_RS_S_z [m s^-2] 115 | 116 | def load_gt(self, folder): 117 | gt_data = np.loadtxt(os.path.join(folder, "mav0/state_groundtruth_estimate0/data.csv"), dtype=float, delimiter=',') 118 | self.data["gt_time"] = gt_data[:,0] / 1e9 119 | self.data["pos"] = gt_data[:,1:4] 120 | self.data['quat'] = gt_data[:,4:8] # w, x, y, z 121 | self.data["b_acc"] = gt_data[:,-3:] 122 | self.data["b_gyro"] = gt_data[:,-6:-3] 123 | self.data["velocity"] = gt_data[:,-9:-6] 124 | 125 | def interp_rot(self, time, opt_time, quat): 126 | # interpolation in the log space 127 | imu_dt = torch.Tensor(time - opt_time[0]) 128 | gt_dt = torch.Tensor(opt_time - opt_time[0]) 129 | 130 | quat = torch.tensor(quat) 131 | quat = qinterp(quat, gt_dt, imu_dt).double() 132 | self.data['rot_wxyz'] = quat 133 | rot = torch.zeros_like(quat) 134 | rot[:,3] = quat[:,0] 135 | rot[:,:3] = quat[:,1:] 136 | 137 | return pp.SO3(rot) 138 | 139 | def interp_xyz(self, time, opt_time, xyz): 140 | 141 | intep_x = np.interp(time, xp=opt_time, fp = xyz[:,0]) 142 | intep_y = np.interp(time, xp=opt_time, fp = xyz[:,1]) 143 | intep_z = np.interp(time, xp=opt_time, fp = xyz[:,2]) 144 | 145 | inte_xyz = np.stack([intep_x, intep_y, intep_z]).transpose() 146 | return torch.tensor(inte_xyz) 147 | 148 | 
# --------------------------------------------------------------------------------
# /datasets/KITTIdataset.py
# --------------------------------------------------------------------------------
import torch
import numpy as np
import pypose as pp

import pykitti
from datetime import datetime

from .dataset import Sequence


class KITTI(Sequence):
    """KITTI raw-data drive loaded through ``pykitti.raw``.

    ``data_root`` must end with the recording-date directory
    (e.g. ``.../2011_09_26``): its parent is used as the pykitti base
    directory and its last path component as the date. ``data_drive`` is the
    drive number string (e.g. ``"0027"``).

    Entries written to ``self.data`` (n = number of timestamps - 1):
        "time"           - (n+1, 1) float64, ``datetime.timestamp`` of each frame
        "acc"            - (n, 3)   float64, OXTS ax, ay, az
        "gyro"           - (n, 3)   float64, OXTS wx, wy, wz
        "dt"             - (n, 1)   differences of consecutive "time" entries
        "gt_translation" - (n, 3)   float64, translation part of ``T_w_imu``
        "gt_orientation" - ``pp.SO3`` LieTensor with n elements, from OXTS roll/pitch/yaw
        "velocity"       - (n, 3)   OXTS (vf, vl, vu) rotated by "gt_orientation"
        "mask"           - (n+1,)   bool, all True
    """

    def __init__(self, data_root, data_drive, **kwargs) -> None:
        # Split ".../<base_dir>/<date>" into pykitti's base directory and date.
        # Trailing slashes are stripped first so ".../2011_09_26/" still yields
        # the date instead of an empty string.
        base_dir, _, data_date = data_root.rstrip("/").rpartition("/")
        (
            self.data_root, self.data_date, self.drive,
            self.data,
            self.ts,            # Not used
            self.targets,       # Not used
            self.orientations,  # Not used
            self.gt_pos,        # Tensor (n, 3), filled by load_data
            self.gt_ori,        # pp.SO3, filled by load_data
        ) = base_dir, data_date, data_drive, dict(), None, None, None, None, None
        print(f"Loading KITTI {data_root} @ {data_drive}")
        self.load_data()
        self.data_name = self.data_date + "_" + self.drive
        print(f"KITTI Sequence {data_drive} - length: {self.get_length()}")

    def get_length(self):
        """Return the number of IMU samples (one fewer than the number of timestamps)."""
        return self.data["time"].size(0) - 1

    @staticmethod
    def _oxts_fields(oxts, names):
        """Stack packet fields ``names`` of every OXTS record into a (len(oxts), len(names)) float64 tensor.

        The tensor is created directly as float64. Building it with the default
        float32 dtype and calling ``.double()`` afterwards would first round
        every value to single precision.
        """
        return torch.tensor(
            [[getattr(o.packet, name) for name in names] for o in oxts],
            dtype=torch.double,
        ).reshape(-1, len(names))

    def load_data(self):
        """Read timestamps, OXTS packets and poses of the drive into ``self.data``."""
        raw_data = pykitti.raw(self.data_root, self.data_date, self.drive)
        raw_len = len(raw_data.timestamps) - 1
        # IMU / ground-truth samples stop one short of the timestamps so that
        # every sample has a following timestamp, and therefore a dt.
        oxts = [raw_data.oxts[i] for i in range(raw_len)]

        self.data["time"] = torch.tensor(
            [datetime.timestamp(t) for t in raw_data.timestamps],
            dtype=torch.double,
        ).unsqueeze(-1)
        self.data["acc"] = self._oxts_fields(oxts, ("ax", "ay", "az"))
        self.data["gyro"] = self._oxts_fields(oxts, ("wx", "wy", "wz"))
        self.data["dt"] = self.data["time"][1:] - self.data["time"][:-1]

        self.data["gt_translation"] = torch.tensor(
            np.array([o.T_w_imu[0:3, 3] for o in oxts])
        ).double()

        self.data["gt_orientation"] = pp.euler2SO3(
            self._oxts_fields(oxts, ("roll", "pitch", "yaw"))
        ).double()
        # NOTE(review): vf/vl/vu are presumably forward/left/up velocities in
        # the OXTS body frame; rotating them by gt_orientation expresses them
        # in the frame of the orientation estimate.
        self.data["velocity"] = self.data["gt_orientation"] @ self._oxts_fields(
            oxts, ("vf", "vl", "vu")
        )
        self.data["mask"] = torch.ones(self.data["time"].shape[0], dtype=torch.bool)
        self.gt_pos = self.data["gt_translation"].clone()
        self.gt_ori = self.data["gt_orientation"].clone()


# --------------------------------------------------------------------------------
# /datasets/TUMdataset.py
# --------------------------------------------------------------------------------
import os
import torch
import numpy as np
import pypose as pp
from utils import qinterp
from .dataset import Sequence


class TUMVI(Sequence):
    """TUM-VI sequence stored in the ``mav0`` folder layout.

    Reads ``mav0/imu0/data.csv`` (IMU), ``mav0/mocap0/data.csv`` (mocap pose)
    and ``mav0/mocap0/grad_velo.txt`` (velocity) from
    ``os.path.join(data_root, data_name)``.

    Output:
        acc / gyro: IMU readings as stored in the csv; they are rotated by the
            ground-truth orientation only when ``glob_coord=True``.
    """

    def __init__(self, data_root, data_name, intepolate=True, calib=False, glob_coord=False, **kwargs):
        """Load one TUM-VI sequence and optionally align it with the mocap ground truth.

        Args:
            data_root: directory that contains the sequence folders.
            data_name: sequence folder name, e.g. ``dataset-room1_512_16``.
            intepolate: crop IMU and mocap to their common time span and
                interpolate the mocap pose onto the IMU timestamps. When False,
                the ground truth stays at the mocap timestamps.
            calib: only reported in the load message; no calibration is
                applied for this dataset.
            glob_coord: rotate ``gyro`` / ``acc`` by ``gt_orientation``.
            **kwargs: ignored.
        """
        super().__init__()
        (
            self.data_root, self.data_name,
            self.data,
            self.ts,
            self.targets,
            self.orientations,
            self.gt_pos,
            self.gt_ori,
        ) = (data_root, data_name, dict(), None, None, None, None, None)
        data_path = os.path.join(self.data_root, self.data_name)
        self.load_imu(data_path)
        self.load_gt(data_path)

        if intepolate:
            # Keep only the time span covered by both the IMU and the mocap.
            t_start = np.max([self.data['gt_time'][0], self.data['time'][0]])
            t_end = np.min([self.data['gt_time'][-1], self.data['time'][-1]])

            idx_start_imu = np.searchsorted(self.data['time'], t_start)
            idx_start_gt = np.searchsorted(self.data['gt_time'], t_start)

            idx_end_imu = np.searchsorted(self.data['time'], t_end, 'right')
            idx_end_gt = np.searchsorted(self.data['gt_time'], t_end, 'right')

            # Ground-truth streams.
            for k in ['gt_time', 'pos', 'quat']:
                self.data[k] = self.data[k][idx_start_gt:idx_end_gt]

            # IMU streams.
            for k in ['time', 'acc', 'gyro']:
                self.data[k] = self.data[k][idx_start_imu:idx_end_imu]

            # Resample the mocap pose onto the IMU timestamps.
            self.data["gt_orientation"] = self.interp_rot(self.data['time'], self.data['gt_time'], self.data['quat'])
            self.data["gt_translation"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data['pos'])
        else:
            # Ground truth at the mocap timestamps, built from the loaded
            # 'pos' / 'quat'. 'quat' is stored w, x, y, z while pp.SO3 takes
            # x, y, z, w (same reordering as in interp_rot).
            self.data["gt_orientation"] = pp.SO3(torch.tensor(self.data['quat'][:, [1, 2, 3, 0]]))
            self.data['gt_translation'] = torch.tensor(self.data['pos'])

        # Move the timestamps to torch.
        self.data["time"] = torch.tensor(self.data["time"])
        self.data["gt_time"] = torch.tensor(self.data["gt_time"])
        self.data['dt'] = (self.data["time"][1:] - self.data["time"][:-1])[:, None]

        # TUM-VI mocap has untracked stretches. For every IMU sample, take the
        # first mocap timestamp at or after it and mask the sample when that
        # timestamp is 0.01 s or more away. The index is clamped so IMU samples
        # after the last mocap timestamp (possible without interpolation) are
        # compared with the last one instead of indexing out of bounds.
        gt_indexing = torch.searchsorted(self.data['gt_time'], self.data['time'])
        gt_indexing = gt_indexing.clamp(max=self.data['gt_time'].shape[0] - 1)
        time_dist = (self.data['time'] - self.data['gt_time'][gt_indexing]).abs()
        self.data["mask"] = time_dist < 0.01

        self.data["gyro"] = torch.tensor(self.data["gyro"])
        self.data["acc"] = torch.tensor(self.data["acc"])

        # Optionally rotate acc and gyro by the ground-truth orientation.
        if glob_coord:  # For the other methods
            self.data['gyro'] = self.data["gt_orientation"] * self.data['gyro']
            self.data['acc'] = self.data["gt_orientation"] * self.data['acc']

        print("loaded: ", data_path, "calib: ", calib, "interpolate: ", intepolate)

    def get_length(self):
        """Return the number of IMU samples."""
        return self.data['time'].shape[0]

    def load_imu(self, folder):
        """Load IMU timestamps (ns in the csv, stored divided by 1e9), gyro and acc."""
        imu_data = np.loadtxt(os.path.join(folder, "mav0/imu0/data.csv"), dtype=float, delimiter=',')
        self.data["time"] = imu_data[:, 0] / 1e9
        self.data["gyro"] = imu_data[:, 1:4]  # w_RS_S_x [rad s^-1], w_RS_S_y [rad s^-1], w_RS_S_z [rad s^-1]
        self.data["acc"] = imu_data[:, 4:]  # a_RS_S_x [m s^-2], a_RS_S_y [m s^-2], a_RS_S_z [m s^-2]

    def load_gt(self, folder):
        """Load mocap timestamps, positions, w-x-y-z quaternions and velocity."""
        gt_data = np.loadtxt(os.path.join(folder, "mav0/mocap0/data.csv"), dtype=float, delimiter=',')
        self.data["gt_time"] = gt_data[:, 0] / 1e9
        self.data["pos"] = gt_data[:, 1:4]
        self.data['quat'] = gt_data[:, 4:8]  # w, x, y, z
        # NOTE(review): velocity is neither cropped nor interpolated here, but
        # it is later indexed per IMU sample; this assumes grad_velo.txt is
        # already aligned with the (interpolated) IMU timestamps — confirm
        # against utils/deferentiate_vel.py. Column 0 is dropped (presumably a
        # timestamp).
        velo_data = np.loadtxt(os.path.join(folder, "mav0/mocap0/grad_velo.txt"), dtype=float)
        self.data["velocity"] = torch.tensor(velo_data[:, 1:])

    def interp_rot(self, time, opt_time, quat):
        """Interpolate w-x-y-z quaternions sampled at ``opt_time`` onto ``time``.

        The interpolated w-x-y-z quaternions are also stored in
        ``self.data['rot_wxyz']``; the return value is a ``pp.SO3`` built from
        the x-y-z-w reordering.
        """
        # Interpolation in the log space (done by utils.qinterp) on times
        # relative to the first ground-truth timestamp.
        imu_dt = torch.Tensor(time - opt_time[0])
        gt_dt = torch.Tensor(opt_time - opt_time[0])

        quat = torch.tensor(quat)
        quat = qinterp(quat, gt_dt, imu_dt).double()
        self.data['rot_wxyz'] = quat
        rot = torch.zeros_like(quat)
        rot[:, 3] = quat[:, 0]
        rot[:, :3] = quat[:, 1:]

        return pp.SO3(rot)

    def interp_xyz(self, time, opt_time, xyz):
        """Linearly interpolate each of the 3 columns of ``xyz`` (sampled at ``opt_time``) onto ``time``."""
        inte_xyz = np.stack(
            [np.interp(time, xp=opt_time, fp=xyz[:, axis]) for axis in range(3)],
            axis=-1,
        )
        return torch.tensor(inte_xyz)
-------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .dataset_utils import * 3 | from .EuRoCdataset import * 4 | from .KITTIdataset import * 5 | from .TUMdataset import * -------------------------------------------------------------------------------- /datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from abc import ABC, abstractmethod 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as Data 7 | from pyhocon import ConfigFactory 8 | 9 | 10 | class Sequence(ABC): 11 | # Dictionary to keep track of subclasses 12 | subclasses = {} 13 | 14 | def __init_subclass__(cls, **kwargs): 15 | super().__init_subclass__(**kwargs) 16 | cls.subclasses[cls.__name__] = cls 17 | 18 | class SeqDataset(Data.Dataset): 19 | def __init__(self, root, dataname, devive = 'cpu', name='Nav', duration=200, step_size=200, mode='inference', 20 | drop_last = True, conf = {}): 21 | super().__init__() 22 | 23 | self.DataClass = Sequence.subclasses 24 | 25 | self.conf = conf 26 | self.seq = self.DataClass[name](root, dataname, **self.conf) 27 | self.data = self.seq.data 28 | self.seqlen = self.seq.get_length()-1 29 | self.gravity = conf.gravity if "gravity" in conf.keys() else 9.81007 30 | if duration is None: self.duration = self.seqlen 31 | else: self.duration = duration 32 | 33 | if step_size is None: self.step_size = self.seqlen 34 | else: self.step_size = step_size 35 | 36 | self.data['acc_cov'] = 0.08 * torch.ones_like(self.data['acc']) 37 | self.data['gyro_cov'] = 0.006 * torch.ones_like(self.data['gyro']) 38 | 39 | start_frame = 0 40 | end_frame = self.seqlen 41 | 42 | self.index_map = [[i, i + self.duration] for i in range( 43 | 0, end_frame - start_frame - self.duration, self.step_size)] 44 | if (self.index_map[-1][-1] < 
end_frame) and (not drop_last): 45 | self.index_map.append([self.index_map[-1][-1], end_frame]) 46 | 47 | self.index_map = np.array(self.index_map) 48 | 49 | def __len__(self): 50 | return len(self.index_map) 51 | 52 | def __getitem__(self, i): 53 | frame_id, end_frame_id = self.index_map[i] 54 | return { 55 | 'dt': self.data['dt'][frame_id: end_frame_id], 56 | 'acc': self.data['acc'][frame_id: end_frame_id], 57 | 'gyro': self.data['gyro'][frame_id: end_frame_id], 58 | 'rot': self.data['gt_orientation'][frame_id: end_frame_id], 59 | 'gt_pos': self.data['gt_translation'][frame_id+1: end_frame_id+1], 60 | 'gt_rot': self.data['gt_orientation'][frame_id+1: end_frame_id+1], 61 | 'gt_vel': self.data['velocity'][frame_id+1: end_frame_id+1], 62 | 'init_pos': self.data['gt_translation'][frame_id][None, ...], 63 | 'init_rot': self.data['gt_orientation'][frame_id: end_frame_id], 64 | 'init_vel': self.data['velocity'][frame_id][None, ...], 65 | } 66 | 67 | def get_init_value(self): 68 | return {'pos': self.data['gt_translation'][:1], 69 | 'rot': self.data['gt_orientation'][:1], 70 | 'vel': self.data['velocity'][:1]} 71 | 72 | def get_mask(self): 73 | return self.data['mask'] 74 | 75 | def get_gravity(self): 76 | return self.gravity 77 | 78 | 79 | class SeqInfDataset(SeqDataset): 80 | def __init__(self, root, dataname, inference_state, device = 'cpu', name='Nav', duration=200, step_size=200, 81 | drop_last = True, mode='inference', usecov = True, useraw = False,conf={}): 82 | super().__init__(root, dataname, device, name, duration, step_size, mode, drop_last, conf) 83 | self.data['acc'][:-1] += inference_state['correction_acc'].cpu()[0] 84 | self.data['gyro'][:-1] += inference_state['correction_gyro'].cpu()[0] 85 | 86 | if 'acc_cov' in inference_state.keys() and usecov: 87 | self.data['acc_cov'] = inference_state['acc_cov'][0] 88 | 89 | if 'gyro_cov' in inference_state.keys() and usecov: 90 | self.data['gyro_cov'] = inference_state['gyro_cov'][0] 91 | 92 | 93 | class 
SeqeuncesDataset(Data.Dataset): 94 | """ 95 | For the purpose of training and inferering 96 | 1. Abandon the features of the last time frame, since there are no ground truth pose and dt 97 | to integrate the imu data of the last frame. So the length of the dataset is seq.get_length() - 1 98 | """ 99 | def __init__(self, data_set_config, mode = None, data_path = None, data_root = None, device= "cuda:0"): 100 | super(SeqeuncesDataset, self).__init__() 101 | ( 102 | self.ts, 103 | self.dt, 104 | self.acc, 105 | self.gyro, 106 | self.gt_pos, 107 | self.gt_ori, 108 | self.gt_velo, 109 | self.index_map, 110 | self.seq_idx, 111 | ) = ([], [], [], [], [], [], [], [], 0) 112 | self.uni = torch.distributions.uniform.Uniform(-torch.ones(1), torch.ones(1)) 113 | self.device = device 114 | self.conf = data_set_config 115 | self.gravity = conf.gravity if "gravity" in conf.keys() else 9.81007 116 | if mode is None: 117 | self.mode = data_set_config.mode 118 | else: 119 | self.mode = mode 120 | 121 | self.DataClass = Sequence.subclasses 122 | 123 | ## the design of datapath provide a quick way to revisit a specific sequence, but introduce some inconsistency 124 | if data_path is None: 125 | for conf in data_set_config.data_list: 126 | for path in conf.data_drive: 127 | self.construct_index_map(conf, conf["data_root"], path, self.seq_idx) 128 | self.seq_idx += 1 129 | ## the design of dataroot provide a quick way to introduce multiple sequences in eval set, but introduce some inconsistency 130 | elif data_root is None: 131 | conf = data_set_config.data_list[0] 132 | self.construct_index_map(conf, conf["data_root"], data_path, self.seq_idx) 133 | self.seq_idx += 1 134 | else: 135 | conf = data_set_config.data_list[0] 136 | self.construct_index_map(conf, data_root, data_path, self.seq_idx) 137 | self.seq_idx += 1 138 | 139 | def load_data(self, seq, start_frame, end_frame): 140 | if "time" in seq.data.keys(): 141 | self.ts.append(seq.data["time"][start_frame:end_frame]) 142 | 
self.acc.append(seq.data["acc"][start_frame:end_frame]) 143 | self.gyro.append(seq.data["gyro"][start_frame:end_frame]) 144 | # the groud truth state should include the init state and integrated state, thus has one more frame than imu data 145 | self.dt.append(seq.data["dt"][start_frame:end_frame+1]) 146 | self.gt_pos.append(seq.data["gt_translation"][start_frame:end_frame+1]) 147 | self.gt_ori.append(seq.data["gt_orientation"][start_frame:end_frame+1]) 148 | self.gt_velo.append(seq.data["velocity"][start_frame:end_frame+1]) 149 | 150 | def construct_index_map(self, conf, data_root, data_name, seq_id): 151 | seq = self.DataClass[conf.name](data_root, data_name, intepolate = True, **self.conf) 152 | seq_len = seq.get_length() -1 # abandon the last imu features 153 | window_size, step_size = conf.window_size, conf.step_size 154 | ## seting the starting and ending duration with different trianing mode 155 | start_frame, end_frame = 0, seq_len 156 | 157 | if self.mode == 'train_half': 158 | end_frame = np.floor(seq_len * 0.5).astype(int) 159 | elif self.mode == 'test_half': 160 | start_frame = np.floor(seq_len * 0.5).astype(int) 161 | elif self.mode == 'train_1m': 162 | end_frame = 12000 163 | elif self.mode == 'test_1m': 164 | start_frame = 12000 165 | elif self.mode == 'mini':# For the purpse of debug 166 | end_frame = 1000 167 | 168 | _duration = end_frame - start_frame 169 | if self.mode == "inference": 170 | window_size = seq_len 171 | step_size = seq_len 172 | self.index_map = [[seq_id, 0, seq_len]] 173 | elif self.mode == "infevaluate": 174 | self.index_map +=[ 175 | [seq_id, j, j+window_size] for j in range( 176 | 0, _duration - window_size, step_size) 177 | ] 178 | if self.index_map[-1][2] < _duration: 179 | print(self.index_map[-1][2]) 180 | self.index_map += [[seq_id, self.index_map[-1][2], seq_len]] 181 | elif self.mode == 'evaluate': 182 | # adding the last piece for evaluation 183 | self.index_map +=[ 184 | [seq_id, j, j+window_size] for j in range( 185 | 
0, _duration - window_size, step_size) 186 | ] 187 | elif self.mode == 'train_half_random': 188 | np.random.seed(1) 189 | window_group_size = 3000 190 | selected_indices = [j for j in range(0, _duration-window_group_size, window_group_size)] 191 | np.random.shuffle(selected_indices) 192 | indices_num = len(selected_indices) 193 | for w in selected_indices[:np.floor(indices_num * 0.5).astype(int)]: 194 | self.index_map +=[[seq_id, j, j + window_size] for j in range(w, w+window_group_size-window_size,step_size)] 195 | elif self.mode == 'test_half_random': 196 | np.random.seed(1) 197 | window_group_size = 3000 198 | selected_indices = [j for j in range(0, _duration-window_group_size, window_group_size)] 199 | np.random.shuffle(selected_indices) 200 | indices_num = len(selected_indices) 201 | for w in selected_indices[np.floor(indices_num * 0.5).astype(int):]: 202 | self.index_map +=[[seq_id, j, j + window_size] for j in range(w, w+window_group_size-window_size,step_size)] 203 | else: 204 | ## applied the mask if we need the training. 
205 | self.index_map +=[ 206 | [seq_id, j, j+window_size] for j in range( 207 | 0, _duration - window_size, step_size) 208 | if torch.all(seq.data["mask"][j: j+window_size]) 209 | ] 210 | 211 | ## Loading the data from each sequence into 212 | self.load_data(seq, start_frame, end_frame) 213 | 214 | def __len__(self): 215 | return len(self.index_map) 216 | 217 | def __getitem__(self, item): 218 | seq_id, frame_id, end_frame_id = self.index_map[item][0], self.index_map[item][1], self.index_map[item][2] 219 | data = { 220 | 'dt': self.dt[seq_id][frame_id: end_frame_id], 221 | 'acc': self.acc[seq_id][frame_id: end_frame_id], 222 | 'gyro': self.gyro[seq_id][frame_id: end_frame_id], 223 | 'rot': self.gt_ori[seq_id][frame_id: end_frame_id] 224 | } 225 | init_state = { 226 | 'init_rot': self.gt_ori[seq_id][frame_id][None, ...], 227 | 'init_pos': self.gt_pos[seq_id][frame_id][None, ...], 228 | 'init_vel': self.gt_velo[seq_id][frame_id][None, ...], 229 | } 230 | label = { 231 | 'gt_pos': self.gt_pos[seq_id][frame_id+1 : end_frame_id+1], 232 | 'gt_rot': self.gt_ori[seq_id][frame_id+1 : end_frame_id+1], 233 | 'gt_vel': self.gt_velo[seq_id][frame_id+1 : end_frame_id+1], 234 | } 235 | 236 | return {**data, **init_state, **label} 237 | 238 | def get_dtype(self): 239 | return self.acc[0].dtype 240 | 241 | 242 | 243 | if __name__ == '__main__': 244 | from datasets.dataset_utils import custom_collate 245 | parser = argparse.ArgumentParser() 246 | parser.add_argument('--config', type=str, default='configs/datasets/BaselineEuRoC.conf', help='config file path, i.e., configs/Euroc.conf') 247 | parser.add_argument("--device", type=str, default='cuda:0', help="cuda or cpu") 248 | 249 | args = parser.parse_args(); print(args) 250 | conf = ConfigFactory.parse_file(args.config) 251 | 252 | dataset = SeqeuncesDataset(data_set_config=conf.train) 253 | loader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, collate_fn=custom_collate) 254 | 255 | for i, (data, init, label) in 
enumerate(loader): 256 | for k in data: print(k, ":", data[k].shape) 257 | for k in init: print(k, ":", init[k].shape) 258 | for k in label: print(k, ":", label[k].shape) 259 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def imu_seq_collate(data): 4 | acc = torch.stack([d['acc'] for d in data]) 5 | gyro = torch.stack([d['gyro'] for d in data]) 6 | 7 | gt_pos = torch.stack([d['gt_pos'] for d in data]) 8 | gt_rot = torch.stack([d['gt_rot'] for d in data]) 9 | gt_vel = torch.stack([d['gt_vel'] for d in data]) 10 | 11 | init_pos = torch.stack([d['init_pos'] for d in data]) 12 | init_rot = torch.stack([d['init_rot'] for d in data]) 13 | init_vel = torch.stack([d['init_vel'] for d in data]) 14 | 15 | dt = torch.stack([d['dt'] for d in data]) 16 | 17 | return { 18 | 'dt': dt, 19 | 'acc': acc, 20 | 'gyro': gyro, 21 | 22 | 'gt_pos': gt_pos, 23 | 'gt_vel': gt_vel, 24 | 'gt_rot': gt_rot, 25 | 26 | 'init_pos': init_pos, 27 | 'init_vel': init_vel, 28 | 'init_rot': init_rot, 29 | } 30 | 31 | def custom_collate(data): 32 | dt = torch.stack([d['dt'] for d in data]) 33 | acc = torch.stack([d['acc'] for d in data]) 34 | gyro = torch.stack([d['gyro'] for d in data]) 35 | rot = torch.stack([d['rot'] for d in data]) 36 | 37 | gt_pos = torch.stack([d['gt_pos'] for d in data]) 38 | gt_rot = torch.stack([d['gt_rot'] for d in data]) 39 | gt_vel = torch.stack([d['gt_vel'] for d in data]) 40 | 41 | init_pos = torch.stack([d['init_pos'] for d in data]) 42 | init_rot = torch.stack([d['init_rot'] for d in data]) 43 | init_vel = torch.stack([d['init_vel'] for d in data]) 44 | 45 | return {'dt': dt, 'acc': acc, 'gyro': gyro, 'rot':rot,}, \ 46 | {'pos': init_pos, 'vel': init_vel, 'rot': init_rot,}, \ 47 | {'gt_pos': gt_pos, 'gt_vel': gt_vel, 'gt_rot': gt_rot, } 48 | 49 | def padding_collate(data, pad_len = 1, use_gravity = True): 
50 | B = len(data) 51 | input_data, init_state, label = custom_collate(data) 52 | 53 | if use_gravity: 54 | iden_acc_vector = torch.tensor([0.,0.,9.81007], dtype=input_data['dt'].dtype).repeat(B,pad_len,1) 55 | else: 56 | iden_acc_vector = torch.zeros(B, pad_len, 3, dtype=input_data['dt'].dtype) 57 | 58 | pad_acc = init_state['rot'].Inv() * iden_acc_vector 59 | pad_gyro = torch.zeros(B, pad_len, 3, dtype=input_data['dt'].dtype) 60 | 61 | input_data["acc"] = torch.cat([pad_acc, input_data['acc']], dim =1) 62 | input_data["gyro"] = torch.cat([pad_gyro, input_data['gyro']], dim =1) 63 | 64 | return input_data, init_state, label 65 | 66 | collate_fcs ={ 67 | "base": custom_collate, 68 | "padding": padding_collate, 69 | "padding9": lambda data: padding_collate(data, pad_len = 9), 70 | "padding1": lambda data: padding_collate(data, pad_len = 1), 71 | "Gpadding": lambda data: padding_collate(data, pad_len = 9, use_gravity = False), 72 | } -------------------------------------------------------------------------------- /doc/alto.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haleqiu/AirIMU/c69afb9d1dfa5acf13a5cc1c15dca370f7440636/doc/alto.gif -------------------------------------------------------------------------------- /doc/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haleqiu/AirIMU/c69afb9d1dfa5acf13a5cc1c15dca370f7440636/doc/model.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | import torch.utils.data as Data 6 | import argparse 7 | import pickle 8 | 9 | import tqdm, yaml 10 | from utils import move_to, save_state, cat_state, vis_corrections 11 | from model import net_dict 12 | from pyhocon import ConfigFactory 13 | 14 | from 
datasets import SeqeuncesDataset, collate_fcs 15 | from model.losses import get_RMSE 16 | 17 | 18 | def get_metrics(eval_state): 19 | metrics = {} 20 | net_dist = (eval_state['evaluate']["rot"].Inv() * eval_state['labels']["gt_rot"]).Log() 21 | metrics['pos'], metrics['rot'], metrics['vel'] = eval_state['loss']['pos'].mean().item(), eval_state['loss']['rot'].mean().item(), eval_state['loss']['vel'].mean().item() 22 | metrics['rot_deg'] = 180./np.pi * metrics['rot'] 23 | 24 | return metrics 25 | 26 | 27 | def evaluate(network, loader, confs, silent_tqdm=False): 28 | network.eval() 29 | evaluate_cov_states, evaluate_states, loss_states, labels = {}, {}, {}, {} 30 | pred_rot_covs, pred_vel_covs, pred_pos_covs = [], [], [] 31 | 32 | with torch.no_grad(): 33 | inte_state = None 34 | for i, (data, init_state, label) in enumerate(tqdm.tqdm(loader, disable=silent_tqdm)): 35 | data, init_state, label = move_to([data, init_state, label], confs.device) 36 | # Use the gt init state while there is no integration. 
37 | if inte_state is not None and confs.gtinit is False: 38 | init_state ={ 39 | "pos": inte_state['pos'][:,-1], 40 | "rot": inte_state['rot'][:,-1], 41 | "vel": inte_state['vel'][:,-1], 42 | } 43 | inte_state = network(data, init_state) 44 | loss_state = get_RMSE(inte_state, label) 45 | 46 | save_state(loss_states, loss_state) 47 | save_state(evaluate_states, inte_state) 48 | save_state(labels, label) 49 | 50 | if 'cov' in inte_state and inte_state['cov'] is not None: 51 | cov_diag = torch.diagonal(inte_state['cov'], dim1=-2, dim2=-1) # Shape: (B, 9) 52 | 53 | pred_rot_covs.append(cov_diag[..., :3]) 54 | pred_pos_covs.append(cov_diag[...,-3:]) 55 | pred_vel_covs.append(cov_diag[...,3:6]) 56 | 57 | if 'cov' in inte_state and inte_state['cov'] is not None: 58 | evaluate_cov_states["pred_rot_covs"] = torch.cat(pred_rot_covs, dim = -2) 59 | evaluate_cov_states["pred_vel_covs"] = torch.cat(pred_vel_covs, dim = -2) 60 | evaluate_cov_states["pred_pos_covs"] = torch.cat(pred_pos_covs, dim = -2) 61 | 62 | for k, v in loss_states.items(): 63 | loss_states[k] = torch.stack(v, dim=0) 64 | cat_state(evaluate_states) 65 | cat_state(labels) 66 | 67 | print("evaluating: position loss %f, rotation loss %f, vel losses %f"\ 68 | %(loss_states['pos'].mean(), loss_states['rot'].mean(), loss_states['vel'].mean())) 69 | 70 | return {'evaluate': evaluate_states, 'evaluate_cov': evaluate_cov_states, 'loss': loss_states, 'labels': labels} 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--config', type=str, default='configs/exp/EuRoC/codenet.conf', help='config file path') 76 | parser.add_argument('--load', type=str, default=None, help='path for model check point') 77 | parser.add_argument("--device", type=str, default="cuda:0", help="cuda or cpu") 78 | parser.add_argument('--seqlen', type=int, default=None, help='the length of the integration sequence.') 79 | parser.add_argument('--batch_size', type=int, default=1, help='batch 
size.') 80 | parser.add_argument('--steplen', type=int, default=None, help='the length of the step we take.') 81 | parser.add_argument('--gtrot', default=True, action="store_false", help='if set False, we will not use ground truth orientation to compensate the gravity') 82 | parser.add_argument('--gtinit', default=True, action="store_false", help='if set False, we will use the integrated pose as the intial pose for the next integral') 83 | parser.add_argument('--posonly', default=False, action="store_true", help='if True, ground truth rotation will be applied in the integration.') 84 | parser.add_argument('--train', default=False, action="store_true", help='if True, We will evaluate the training set (may be removed in the future).') 85 | parser.add_argument('--whole', default=False, action="store_true", help='(may be removed in the future).') 86 | 87 | args = parser.parse_args(); print(args) 88 | conf = ConfigFactory.parse_file(args.config) 89 | conf.train.device = args.device 90 | conf_name = os.path.split(args.config)[-1].split(".")[0] 91 | conf['general']['exp_dir'] = os.path.join(conf.general.exp_dir, conf_name) 92 | 93 | if args.posonly: 94 | conf.train["posonly"] = True 95 | 96 | conf.train["gtrot"] = args.gtrot 97 | conf.train["gtinit"] = args.gtinit 98 | conf.train['sampling'] = False 99 | network = net_dict[conf.train.network](conf.train).to(args.device).double() 100 | 101 | save_folder = os.path.join(conf.general.exp_dir, "evaluate") 102 | os.makedirs(save_folder, exist_ok=True) 103 | 104 | if args.load is None: 105 | ckpt_path = os.path.join(conf.general.exp_dir, "ckpt/best_model.ckpt") 106 | else: 107 | ckpt_path = os.path.join(conf.general.exp_dir, "ckpt", args.load) 108 | 109 | if os.path.exists(ckpt_path): 110 | checkpoint = torch.load(ckpt_path, map_location=torch.device(args.device)) 111 | print("loaded state dict %s in epoch %i"%(ckpt_path, checkpoint["epoch"])) 112 | network.load_state_dict(checkpoint["model_state_dict"]) 113 | else: 114 | 
print("no model loaded", ckpt_path) 115 | 116 | if 'collate' in conf.dataset.keys(): 117 | collate_fn = collate_fcs[conf.dataset.collate] 118 | else: 119 | collate_fn = collate_fcs['base'] 120 | 121 | if args.train: 122 | dataset_conf = conf.dataset.train 123 | else: 124 | dataset_conf = conf.dataset.eval 125 | 126 | if args.posonly: 127 | dataset_conf['calib'] = "posonly" 128 | 129 | for data_conf in dataset_conf.data_list: 130 | if args.seqlen is not None: 131 | data_conf["window_size"] = args.seqlen 132 | data_conf["step_size"] = args.seqlen if args.steplen is None else args.steplen 133 | if args.whole: 134 | data_conf["mode"] = "inference" 135 | 136 | pos_loss_xyzs = [] 137 | pred_pos_cov = [] 138 | 139 | all_metrics = {} 140 | for path in data_conf.data_drive: 141 | eval_dataset = SeqeuncesDataset(data_set_config=dataset_conf, data_path=path) 142 | eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, drop_last = False) 143 | eval_state = evaluate(network=network, loader = eval_loader, confs=conf.train) 144 | ## Save the state in 145 | net_result_path = os.path.join(conf.general.exp_dir, path + '_eval_state.pickle') 146 | with open(net_result_path, 'wb') as handle: 147 | pickle.dump(eval_state, handle, protocol=pickle.HIGHEST_PROTOCOL) 148 | 149 | if "pred_pos_covs" in eval_state['evaluate_cov'].keys(): 150 | pred_pos_cov.append(eval_state['evaluate_cov']['pred_pos_covs']) 151 | 152 | title = "$SO(3)$ orientation error" 153 | if 'correction_acc' in eval_state['evaluate'].keys(): 154 | correction = torch.cat((eval_state['evaluate']['correction_acc'][0], eval_state['evaluate']['correction_gyro'][0]), dim=-1) 155 | vis_corrections(correction.cpu(), save_folder=os.path.join(save_folder, path)) 156 | 157 | all_metrics[path] = get_metrics(eval_state) 158 | 159 | with open(os.path.join(conf.general.exp_dir, 'result_%s_init%s_gt%s.yaml'%(str(args.seqlen),str(args.gtinit),str(args.gtrot))), 'w') as file: 
160 | yaml.dump(all_metrics, file, default_flow_style=False) -------------------------------------------------------------------------------- /evaluation/evaluate_state.py: -------------------------------------------------------------------------------- 1 | # output the trajctory in the world frame for visualization and evaluation 2 | import os, sys 3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))) 4 | 5 | import os 6 | import json 7 | import argparse 8 | import numpy as np 9 | import pypose as pp 10 | 11 | import torch 12 | import torch.utils.data as Data 13 | 14 | from pyhocon import ConfigFactory 15 | from datasets import SeqInfDataset, SeqDataset, imu_seq_collate 16 | 17 | from utils import CPU_Unpickler, integrate 18 | from utils.visualize_state import visualize_rotations, visualize_state_error 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--device", type=str, default="cpu", help="cuda or cpu, Default is cuda:0") 23 | parser.add_argument("--exp", type=str, default=None, help="the directory path where your network output pickle file is stored") 24 | parser.add_argument("--seqlen", type=int, default="200", help="the length of the integration sequence") 25 | parser.add_argument("--dataconf", type=str, default="configs/datasets/BaselineEuroc/Euroc_1000.conf", help="the configuration of the dataset") 26 | parser.add_argument("--savedir",type=str,default = "./result/loss_result",help = "the save diretory for the evaluation results") 27 | parser.add_argument("--usegtrot", action="store_true", help="Use ground truth rotation for gravity compensation") 28 | parser.add_argument("--mask", action="store_true", help="Mask the segments if needed") 29 | 30 | args = parser.parse_args(); 31 | print(("\n"*3) + str(args) + ("\n"*3)) 32 | config = ConfigFactory.parse_file(args.dataconf) 33 | dataset_conf = config.inference 34 | 35 | if args.exp is not None: 36 | net_result_path = 
os.path.join(args.exp, 'net_output.pickle') 37 | if os.path.isfile(net_result_path): 38 | with open(net_result_path, 'rb') as handle: 39 | inference_state_load = CPU_Unpickler(handle).load() 40 | else: 41 | raise Exception(f"Unable to load the network result: {net_result_path}") 42 | 43 | folder = args.savedir 44 | os.makedirs(folder, exist_ok=True) 45 | 46 | AllResults = [] 47 | 48 | for data_conf in dataset_conf.data_list: 49 | print(data_conf) 50 | for data_name in data_conf.data_drive: 51 | print(data_conf.data_root, data_name) 52 | print("data_conf.dataroot", data_conf.data_root) 53 | print("data_name", data_name) 54 | print("data_conf.name", data_conf.name) 55 | 56 | dataset = SeqDataset(data_conf.data_root, data_name, args.device, name = data_conf.name, duration=args.seqlen, step_size=args.seqlen, drop_last=False, conf = dataset_conf) 57 | loader = Data.DataLoader(dataset=dataset, batch_size=1, collate_fn=imu_seq_collate, shuffle=False, drop_last=False) 58 | 59 | init = dataset.get_init_value() 60 | gravity = dataset.get_gravity() 61 | integrator_outstate = pp.module.IMUPreintegrator( 62 | init['pos'], init['rot'], init['vel'],gravity=gravity, 63 | reset=False 64 | ).to(args.device).double() 65 | 66 | integrator_reset = pp.module.IMUPreintegrator( 67 | init['pos'], init['rot'], init['vel'],gravity = gravity, 68 | reset=True 69 | ).to(args.device).double() 70 | 71 | outstate = integrate( 72 | integrator_outstate, loader, init, 73 | device=args.device, gtinit=False, save_full_traj=True, 74 | use_gt_rot=args.usegtrot 75 | ) 76 | relative_outstate = integrate( 77 | integrator_reset, loader, init, 78 | device=args.device, gtinit=True, 79 | use_gt_rot=args.usegtrot 80 | ) 81 | 82 | if args.exp is not None: 83 | inference_state = inference_state_load[data_name] 84 | dataset_inf = SeqInfDataset(data_conf.data_root, data_name, inference_state, device = args.device, name = data_conf.name,duration=args.seqlen, step_size=args.seqlen, drop_last=False, conf = 
dataset_conf) 85 | infloader = Data.DataLoader(dataset=dataset_inf, batch_size=1, 86 | collate_fn=imu_seq_collate, 87 | shuffle=False, drop_last=True) 88 | 89 | integrator_infstate = pp.module.IMUPreintegrator( 90 | init['pos'], init['rot'], init['vel'], gravity = gravity, 91 | reset=False 92 | ).to(args.device).double() 93 | 94 | infstate = integrate( 95 | integrator_infstate, infloader, init, 96 | device=args.device, gtinit=False, save_full_traj=True, 97 | use_gt_rot=args.usegtrot 98 | ) 99 | relative_infstate = integrate( 100 | integrator_reset, infloader, init, 101 | device=args.device, gtinit=True, 102 | use_gt_rot=args.usegtrot 103 | ) 104 | 105 | index_id = dataset.index_map[:, -1] 106 | mask = torch.ones(dataset.seqlen, dtype = torch.bool) 107 | select_mask = torch.ones_like(dataset.get_mask()[index_id], dtype = torch.bool) 108 | 109 | ### For the datasets with mask like TUMVI 110 | if args.mask: 111 | mask = dataset.get_mask()[:dataset.seqlen] 112 | select_mask = dataset.get_mask()[index_id] 113 | select_mask[-1] = False #3 drop last 114 | 115 | #save loss result 116 | result_dic = { 117 | 'name': data_name, 118 | 'use_gt_rot': args.usegtrot, 119 | 'AOE(raw)':180./np.pi * outstate['rot_dist'][0, mask].mean().numpy(), 120 | 'ATE(raw)':outstate['pos_dist'][0, mask].mean().item(), 121 | 'AVE(raw)':outstate['vel_dist'][0, mask].mean().item(), 122 | 123 | 'ROE(raw)':180./np.pi *relative_outstate['rot_dist'][0, select_mask].mean().numpy(), 124 | 'RTE(raw)':relative_outstate['pos_dist'][0, select_mask].mean().item(), 125 | 'RVE(raw)':relative_outstate['vel_dist'][0, select_mask].mean().item(), 126 | 127 | 'RP_RMSE(raw)': np.sqrt((relative_outstate['pos_dist'][0, select_mask]**2).mean()).numpy().item(), 128 | 'RV_RMSE(raw)': np.sqrt((relative_outstate['vel_dist'][0, select_mask]**2).mean()).numpy().item(), 129 | 'RO_RMSE(raw)':180./np.pi * torch.sqrt((relative_outstate['rot_dist'][0, select_mask]**2).mean()).item(), 130 | 'O_RMSE(raw)':180./np.pi * 
torch.sqrt((outstate['rot_dist'][0, mask]**2).mean()).item(), 131 | 132 | 133 | 'AOE(AirIMU)':180./np.pi * infstate['rot_dist'][0, mask].mean().numpy(), 134 | 'ATE(AirIMU)':infstate['pos_dist'][0, mask].mean().item(), 135 | 'AVE(AirIMU)':infstate['vel_dist'][0, mask].mean().item(), 136 | 137 | 'ROE(AirIMU)':180./np.pi * relative_infstate['rot_dist'][0, select_mask].mean().numpy(), 138 | 'RTE(AirIMU)':relative_infstate['pos_dist'][0, select_mask].mean().item(), 139 | 'RVE(AirIMU)':relative_infstate['vel_dist'][0, select_mask].mean().item(), 140 | 141 | 'RP_RMSE(AirIMU)': np.sqrt((relative_infstate['pos_dist'][0, select_mask]**2).mean()).item(), 142 | 'RV_RMSE(AirIMU)': np.sqrt((relative_infstate['vel_dist'][0, select_mask]**2).mean()).item(), 143 | 'RO_RMSE(AirIMU)':180./np.pi * torch.sqrt((relative_infstate['rot_dist'][0, select_mask]**2).mean()).numpy(), 144 | 'O_RMSE(AirIMU)': 180./np.pi * torch.sqrt((infstate['rot_dist'][0, mask]**2).mean()).numpy(), 145 | } 146 | 147 | AllResults.append(result_dic) 148 | 149 | print("==============Integration==============") 150 | print("outstate:") 151 | print("pos_err: ", outstate['pos_dist'].mean()) 152 | print("rot_err: ", outstate['rot_dist'].mean()) 153 | print("vel_err: ", outstate['vel_dist'].mean()) 154 | 155 | print("relative_outstate") 156 | print("pos_err: ", relative_outstate['pos_dist'].mean()) 157 | print("rot_err: ", relative_outstate['rot_dist'].mean()) 158 | print("vel_err: ", relative_outstate['vel_dist'].mean()) 159 | 160 | print("==============AirIMU==============") 161 | print("infstate:") 162 | print("pos_err: ", infstate['pos_dist'].mean()) 163 | print("rot_err: ", infstate['rot_dist'].mean()) 164 | print("vel_err: ", infstate['vel_dist'].mean()) 165 | 166 | print("relatvie_infstate") 167 | print("pos_err: ", relative_infstate['pos_dist'].mean()) 168 | print("rot_err: ", relative_infstate['rot_dist'].mean()) 169 | print("vel_err: ", relative_infstate['vel_dist'].mean()) 170 | 171 | 
visualize_state_error(data_name,outstate,infstate,save_folder=folder,mask=mask,file_name="inte_error_compare.png") 172 | visualize_state_error(data_name,relative_outstate,relative_infstate,mask=select_mask,save_folder=folder) 173 | visualize_rotations(data_name,outstate['orientations_gt'][0],outstate['orientations'][0],infstate['orientations'][0],save_folder=folder) 174 | 175 | file_path = os.path.join(folder, "loss_result.json") 176 | with open(file_path, 'w') as f: 177 | json.dump(AllResults, f, indent=4) 178 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import torch.utils.data as Data 5 | import argparse 6 | import pickle 7 | 8 | import tqdm 9 | from utils import move_to, save_state 10 | from pyhocon import ConfigFactory 11 | 12 | from datasets import collate_fcs, SeqeuncesDataset 13 | from model import net_dict 14 | from utils import * 15 | 16 | 17 | 18 | def inference(network, loader, confs): 19 | ''' 20 | Correction inference 21 | save the corrections generated from the network. 22 | ''' 23 | network.eval() 24 | evaluate_states = {} 25 | 26 | with torch.no_grad(): 27 | inte_state = None 28 | for data, _, _ in tqdm.tqdm(loader): 29 | data = move_to(data, confs.device) 30 | # Use the gt init state while there is no integration. 
31 | inte_state = network.inference(data) 32 | # update the corected acc and gyro 33 | save_state(evaluate_states, inte_state) 34 | 35 | for k, v in evaluate_states.items(): 36 | evaluate_states[k] = torch.cat(v, dim=-2) 37 | 38 | return evaluate_states 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--config', type=str, default='configs/exp/EuRoC/codenet.conf', help='config file path') 44 | parser.add_argument('--load', type=str, default=None, help='path for model check point') 45 | parser.add_argument("--device", type=str, default="cuda:0", help="cuda or cpu") 46 | parser.add_argument('--batch_size', type=int, default=1, help='batch size.') 47 | parser.add_argument('--seqlen', type=int, default=1000, help='the length of the segment') 48 | parser.add_argument('--train', default=False, action="store_true", help='if True, We will evaluate the training set (may be removed in the future).') 49 | parser.add_argument('--gtinit', default=True, action="store_false", help='if set False, we will use the integrated pose as the intial pose for the next integral') 50 | parser.add_argument('--whole', default=False, action="store_true", help='(may be removed in the future).') 51 | 52 | 53 | args = parser.parse_args(); print(args) 54 | conf = ConfigFactory.parse_file(args.config) 55 | conf.train.device = args.device 56 | conf_name = os.path.split(args.config)[-1].split(".")[0] 57 | conf['general']['exp_dir'] = os.path.join(conf.general.exp_dir, conf_name) 58 | conf.train['sampling'] = False 59 | conf["gtinit"] = args.gtinit 60 | conf['device'] = args.device 61 | 62 | ''' 63 | Load the pretrained model 64 | ''' 65 | network = net_dict[conf.train.network](conf.train).to(args.device).double() 66 | save_folder = os.path.join(conf.general.exp_dir, "evaluate") 67 | os.makedirs(save_folder, exist_ok=True) 68 | 69 | if args.load is None: 70 | ckpt_path = os.path.join(conf.general.exp_dir, "ckpt/best_model.ckpt") 71 | else: 72 | 
ckpt_path = os.path.join(conf.general.exp_dir, "ckpt", args.load) 73 | 74 | if os.path.exists(ckpt_path): 75 | checkpoint = torch.load(ckpt_path, map_location=torch.device(args.device)) 76 | print("loaded state dict %s in epoch %i"%(ckpt_path, checkpoint["epoch"])) 77 | network.load_state_dict(checkpoint["model_state_dict"]) 78 | else: 79 | raise Exception(f"No model loaded {ckpt_path}") 80 | 81 | if 'collate' in conf.dataset.keys(): 82 | collate_fn = collate_fcs[conf.dataset.collate] 83 | else: 84 | collate_fn = collate_fcs['base'] 85 | 86 | print(conf.dataset) 87 | dataset_conf = conf.dataset.inference 88 | 89 | ''' 90 | Run and save the IMU correction 91 | ''' 92 | cov_result, rmse = [], [] 93 | net_out_result = {} 94 | evals = {} 95 | dataset_conf.data_list[0]["window_size"] = args.seqlen 96 | dataset_conf.data_list[0]["step_size"] = args.seqlen 97 | for data_conf in dataset_conf.data_list: 98 | for path in data_conf.data_drive: 99 | if args.whole: 100 | dataset_conf["mode"] = "inference" 101 | else: 102 | dataset_conf["mode"] = "infevaluate" 103 | dataset_conf["exp_dir"] = conf.general.exp_dir 104 | print("\n"*3 + str(dataset_conf)) 105 | eval_dataset = SeqeuncesDataset(data_set_config=dataset_conf, data_path=path, data_root=data_conf["data_root"]) 106 | eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=args.batch_size, 107 | shuffle=False, collate_fn=collate_fn, drop_last = False) 108 | 109 | inference_state = inference(network=network, loader = eval_loader, confs=conf.train) 110 | if not "acc_cov" in inference_state.keys(): 111 | inference_state["acc_cov"] = torch.zeros_like(inference_state["correction_acc"]) 112 | if not "gyro_cov" in inference_state.keys(): 113 | inference_state["gyro_cov"] = torch.zeros_like(inference_state["correction_gyro"]) 114 | 115 | inference_state['corrected_acc'] = eval_dataset.acc[0] + inference_state['correction_acc'].squeeze(0).cpu() 116 | inference_state['corrected_gyro'] = eval_dataset.gyro[0] + 
inference_state['correction_gyro'].squeeze(0).cpu() 117 | inference_state['rot'] = eval_dataset.gt_ori[0] 118 | inference_state['dt'] = eval_dataset.dt[0] 119 | 120 | net_out_result[path] = inference_state 121 | 122 | #### RPE and Cov analysis 123 | rpe_pos, rpe_rot, mse_pos = [], [], [] 124 | relative_cov, relative_sigma_x, relative_sigma_y, relative_sigma_z = [], [], [], [] 125 | dataset_conf["mode"] = "evaluate" 126 | 127 | net_result_path = os.path.join(conf.general.exp_dir, 'net_output.pickle') 128 | print("save netout, ", net_result_path) 129 | with open(net_result_path, 'wb') as handle: 130 | pickle.dump(net_out_result, handle, protocol=pickle.HIGHEST_PROTOCOL) 131 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .net import ModelBase 2 | from .cnn import CNNPOS 3 | from .others import Identity, ParamNet 4 | from .code import * 5 | 6 | net_dict = { 7 | 'codeposenet': CodePoseNet, 8 | 'codenetkitti': CodeNetKITTI, 9 | 'iden': Identity, 10 | 'cnnpos': CNNPOS, 11 | 'codenet': CodeNet, 12 | 'param': ParamNet, 13 | } 14 | -------------------------------------------------------------------------------- /model/cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pypose as pp 4 | import torch.nn as nn 5 | 6 | from model.net import ModelBase 7 | 8 | class CNNEncoder(nn.Module): 9 | def __init__(self, duration = 1, k_list = [7, 7, 7, 7], c_list = [6, 16, 32, 64, 128], 10 | s_list = [1, 1, 1, 1], p_list = [3, 3, 3, 3]): 11 | super(CNNEncoder, self).__init__() 12 | self.duration = duration 13 | self.k_list, self.c_list, self.s_list, self.p_list = k_list, c_list, s_list, p_list 14 | layers = [] 15 | 16 | for i in range(len(self.c_list) - 1): 17 | layers.append(torch.nn.Conv1d(self.c_list[i], self.c_list[i+1], self.k_list[i], \ 18 | 
stride=self.s_list[i], padding=self.p_list[i])) 19 | layers.append(torch.nn.BatchNorm1d(self.c_list[i+1])) 20 | layers.append(torch.nn.GELU()) 21 | layers.append(torch.nn.Dropout(0.1)) 22 | 23 | self.net = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.net(x) 27 | 28 | 29 | class CNNcorrection(ModelBase): 30 | ''' 31 | The input feature shape [B, F, Duration, 6] 32 | ''' 33 | def __init__(self, conf): 34 | super(CNNcorrection, self).__init__(conf) 35 | self.k_list = [7, 7, 7, 7] 36 | self.c_list = [6, 32, 64, 128, 256] 37 | 38 | self.cnn = CNNEncoder(c_list=self.c_list, k_list=self.k_list) 39 | 40 | self.accdecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 41 | self.acccov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 42 | 43 | self.gyrodecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 44 | self.gyrocov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 45 | 46 | gyro_std = np.pi/180 47 | if "gyro_std" in conf: 48 | print(" The gyro std is set to ", conf.gyro_std, " rad/s") 49 | gyro_std = conf.gyro_std 50 | self.register_buffer('gyro_std', torch.tensor(gyro_std)) 51 | 52 | acc_std = 0.1 53 | if "acc_std" in conf: 54 | print(" The acc std is set to ", conf.acc_std, " m/s^2") 55 | acc_std = conf.acc_std 56 | self.register_buffer('acc_std', torch.tensor(acc_std)) 57 | 58 | def encoder(self, x): 59 | return self.cnn(x.transpose(-1,-2)).transpose(-1,-2) 60 | 61 | def decoder(self, x): 62 | acc = self.accdecoder(x) * self.acc_std 63 | gyro = self.gyrodecoder(x) * self.gyro_std 64 | coorections = torch.cat([acc, gyro], dim = -1) 65 | 66 | return coorections 67 | 68 | def cov_decoder(self, x): 69 | return self.cov_head(x).transpose(-1,-2) 70 | 71 | 72 | class CNNPOS(CNNcorrection): 73 | """ 74 | Only correct the accelerometer 75 | """ 76 | def __init__(self, conf): 77 | super(CNNPOS, self).__init__(conf) 78 | 79 | def decoder(self, x): 80 | acc = 
self.accdecoder(x) * self.acc_std 81 | gyro = torch.zeros_like(acc) 82 | coorections = torch.cat([acc, gyro], dim = -1) 83 | 84 | return coorections -------------------------------------------------------------------------------- /model/code.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pypose as pp 4 | 5 | import torch.nn as nn 6 | from model.net import ModelBase 7 | from model.cnn import CNNEncoder 8 | 9 | 10 | class CodeNet(ModelBase): 11 | def __init__(self, conf): 12 | super().__init__(conf) 13 | self.conf = conf 14 | 15 | gyro_std = np.pi/180 16 | if "gyro_std" in conf: 17 | print(" The gyro std is set to ", conf.gyro_std, " rad/s") 18 | gyro_std = conf.gyro_std 19 | self.register_buffer('gyro_std', torch.tensor(gyro_std)) 20 | 21 | acc_std = 0.1 22 | if "acc_std" in conf: 23 | print(" The acc std is set to ", conf.acc_std, " m/s^2") 24 | acc_std = conf.acc_std 25 | self.register_buffer('acc_std', torch.tensor(acc_std)) 26 | 27 | ## the encoder have the same correction in one interval 28 | self.interval = 9 29 | self.inter_head = np.floor(self.interval/2.).astype(int) 30 | self.inter_tail = self.interval - self.inter_head 31 | 32 | self.cnn = CNNEncoder(c_list=[6, 32, 64], k_list=[7, 7], s_list=[3, 3])# (N,F/8,64) 33 | 34 | self.gru1 = nn.GRU(input_size = 64, hidden_size = 128, num_layers = 1, batch_first = True) 35 | self.gru2 = nn.GRU(input_size = 128, hidden_size = 256, num_layers = 1, batch_first = True) 36 | 37 | self.accdecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 38 | self.acccov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 39 | 40 | self.gyrodecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 41 | self.gyrocov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3)) 42 | 43 | def encoder(self, x): 44 | x = self.cnn(x.transpose(-1,-2)).transpose(-1,-2) 45 | x, _ = 
self.gru1(x) 46 | x, _ = self.gru2(x) 47 | 48 | return x 49 | 50 | def cov_decoder(self, x): 51 | acc = torch.exp(self.acccov_decoder(x) - 5.) 52 | gyro = torch.exp(self.gyrocov_decoder(x) - 5.) 53 | 54 | return torch.cat([acc, gyro], dim = -1) 55 | 56 | def decoder(self, x): 57 | acc = self.accdecoder(x) * self.acc_std 58 | gyro = self.gyrodecoder(x) * self.gyro_std 59 | 60 | return torch.cat([acc, gyro], dim = -1) 61 | 62 | def _update(self, to_update, feat, frame_len): 63 | ### Note: This will change the data in the to_update !!!!!! 64 | def _clip(x,l): 65 | if x > l: 66 | return l 67 | elif x < 0: 68 | return 0 69 | else: 70 | return x 71 | 72 | _feat_range = np.ceil((frame_len-self.inter_head)/self.interval).astype(int) + 1 ## not equivalent to features shape 73 | 74 | for i in range(_feat_range): 75 | s_p = _clip(i*self.interval-self.inter_head, frame_len) 76 | e_p = _clip(i*self.interval+self.inter_tail, frame_len) 77 | idx = _clip(i, feat.shape[1]-1) 78 | 79 | # skip the first padded input 80 | to_update[:,s_p:e_p,:] += feat[:,idx:idx+1,:] 81 | 82 | return to_update 83 | 84 | def inference(self, data): 85 | frame_len = data["acc"].shape[1] - self.interval 86 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1) 87 | feature = self.encoder(feature)[:,1:,:] 88 | correction = self.decoder(feature) 89 | zero_signal = torch.zeros_like(data['acc'][:,self.interval:,:]) 90 | 91 | # a referenced size 1000 92 | correction_acc = self._update(zero_signal.clone(), correction[...,:3], frame_len) 93 | correction_gyro = self._update(zero_signal.clone(), correction[...,3:], frame_len) 94 | 95 | # covariance propagation 96 | cov_state = {'acc_cov':None, 'gyro_cov': None,} 97 | if self.conf.propcov: 98 | cov = self.cov_decoder(feature) 99 | cov_state['acc_cov'] = self._update(torch.zeros_like(correction_acc, device=correction_acc.device), 100 | cov[...,:3], frame_len) 101 | cov_state['gyro_cov'] = self._update(torch.zeros_like(correction_gyro, 
device=correction_gyro.device), 102 | cov[...,3:], frame_len) 103 | 104 | return {"cov_state": cov_state, 'correction_acc': correction_acc, 'correction_gyro': correction_gyro} 105 | 106 | def forward(self, data, init_state): 107 | inference_state = self.inference(data) 108 | 109 | data['corrected_acc'] = data['acc'][:,self.interval:,:] + inference_state['correction_acc'] 110 | data['corrected_gyro'] = data['gyro'][:,self.interval:,:] + inference_state['correction_gyro'] 111 | 112 | out_state = self.integrate(init_state = init_state, data = data, cov_state = inference_state['cov_state']) 113 | 114 | return {**out_state, 'correction_acc': inference_state['correction_acc'], 'correction_gyro': inference_state['correction_gyro'], 115 | 'corrected_acc': data['corrected_acc'], 'corrected_gyro': data['corrected_gyro']} 116 | 117 | 118 | class CodePoseNet(CodeNet): 119 | def __init__(self, conf): 120 | super().__init__(conf) 121 | 122 | def inference(self, data): 123 | frame_len = data["acc"].shape[1] - self.interval 124 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1) 125 | feature = self.encoder(feature)[:,1:,:] 126 | correction = self.decoder(feature) 127 | zero_signal = torch.zeros_like(data['acc'][:,self.interval:,:]) 128 | 129 | # a referenced size 1000 130 | correction_acc = self._update(zero_signal.clone(), correction[...,:3], frame_len) 131 | correction_gyro = zero_signal.clone() 132 | 133 | # covariance propagation 134 | cov_state = {'acc_cov':None, 'gyro_cov': None,} 135 | if self.conf.propcov: 136 | cov = self.cov_decoder(feature) 137 | cov_state['acc_cov'] = self._update(torch.zeros_like(correction_acc, device=correction_acc.device), 138 | cov[...,:3], frame_len) 139 | cov_state['gyro_cov'] = self._update(torch.zeros_like(correction_gyro, device=correction_gyro.device), 140 | cov[...,3:], frame_len) 141 | 142 | return {"cov_state": cov_state, 'correction_acc': correction_acc, 'correction_gyro': correction_gyro} 143 | 144 | 145 | class 
CodeNetKITTI(torch.nn.Module): 146 | def __init__(self, conf): 147 | super().__init__() 148 | self.conf = conf 149 | self.integrator = pp.module.IMUPreintegrator(prop_cov=conf.propcov, reset=True).double() 150 | 151 | self.accEncoder = CNNEncoder(k_list=[7, 3, 3], p_list=[3, 1, 1], c_list=[3, 32, 64, 128]) 152 | self.gyroEncoder = CNNEncoder(k_list=[7, 3, 3], p_list=[3, 1, 1], c_list=[3, 32, 64, 128]) 153 | 154 | self.accDecoder = nn.Sequential( 155 | nn.Linear(128, 64), nn.GELU(), nn.Linear(64, 32), nn.GELU(), nn.Linear(32, 3) 156 | ) 157 | self.gyroDecoder = nn.Sequential( 158 | nn.Linear(128, 64), nn.GELU(), nn.Linear(64, 32), nn.GELU(), nn.Linear(32, 3) 159 | ) 160 | self.accCovDecoder = nn.Sequential( 161 | nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 32), nn.GELU(), nn.Linear(32, 3) 162 | ) 163 | self.gyroCovDecoder = nn.Sequential( 164 | nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 32), nn.GELU(), nn.Linear(32, 3) 165 | ) 166 | 167 | gyro_std = np.pi/180 168 | self.register_buffer('gyro_std', torch.tensor(gyro_std)) 169 | 170 | acc_std = 0.1 171 | self.register_buffer('acc_std', torch.tensor(acc_std)) 172 | 173 | def integrate(self, init_state, data, cov_state, use_gtrot): 174 | gt_rot = None 175 | if self.conf.gtrot: gt_rot = data['rot'].double() 176 | if not use_gtrot: gt_rot = None 177 | 178 | if self.conf.propcov: 179 | out_state = self.integrator( 180 | init_state = init_state, 181 | dt = data['dt'].double(), 182 | gyro = data['corrected_gyro'].double(), 183 | acc = data['corrected_acc'].double(), 184 | rot = gt_rot, 185 | acc_cov = cov_state['acc_cov'].double(), 186 | gyro_cov = cov_state['gyro_cov'].double() 187 | ) 188 | else: 189 | out_state = self.integrator( 190 | init_state = init_state, 191 | dt = data['dt'].double(), 192 | gyro = data['corrected_gyro'].double(), 193 | acc = data['corrected_acc'].double(), 194 | rot = gt_rot, 195 | ) 196 | 197 | return {**out_state, **cov_state} 198 | 199 | def inference(self, data): 200 | feature_acc = 
self.accEncoder(data["acc"].transpose(-1,-2)).transpose(-1,-2) 201 | feature_gyro = self.gyroEncoder(data["gyro"].transpose(-1,-2)).transpose(-1,-2) 202 | 203 | correction_acc = self.accDecoder(feature_acc) 204 | correction_gyro = self.gyroDecoder(feature_gyro) 205 | 206 | cov_state = {'acc_cov':None, 'gyro_cov': None} 207 | if self.conf.propcov: 208 | feature = torch.cat([feature_acc, feature_gyro], dim = -1) 209 | cov_state['acc_cov'] = self.accCovDecoder(feature).exp() 210 | cov_state['gyro_cov'] = self.gyroCovDecoder(feature).exp() 211 | 212 | return {"cov_state": cov_state, 'correction_acc': correction_acc, 'correction_gyro': correction_gyro} 213 | 214 | def forward(self, data, init_state, use_gtrot=True): 215 | init_state_ = { 216 | "pos": init_state["pos"], 217 | "rot": init_state["rot"][:,:1,:], 218 | "vel": init_state["vel"], 219 | } 220 | inference_state = self.inference(data) 221 | 222 | data['corrected_acc'] = data['acc'] + inference_state['correction_acc'] 223 | data['corrected_gyro'] = data['gyro'] + inference_state['correction_gyro'] 224 | 225 | out_state = self.integrate(init_state=init_state_, data = data, cov_state = inference_state['cov_state'], use_gtrot=use_gtrot) 226 | 227 | return { 228 | **out_state, 229 | 'correction_acc': inference_state['correction_acc'], 230 | 'correction_gyro': inference_state['correction_gyro'], 231 | 'corrected_acc': data['corrected_acc'], 232 | 'corrected_gyro': data['corrected_gyro'] 233 | } 234 | -------------------------------------------------------------------------------- /model/loss_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | EPSILON = 1e-7 4 | 5 | def diag_cov_loss(dist, pred_cov): 6 | error = (dist).pow(2) 7 | return torch.mean(error / 2*(torch.exp(2 * pred_cov)) + pred_cov) 8 | 9 | def diag_ln_cov_loss(dist, pred_cov, use_epsilon=False): 10 | error = (dist).pow(2) 11 | if use_epsilon: l = ((error / pred_cov) + torch.log(pred_cov + 
EPSILON)) 12 | else: l = ((error / pred_cov) + torch.log(pred_cov)) 13 | return l.mean() 14 | 15 | def L2(dist): 16 | error = dist.pow(2) 17 | return torch.mean(error) 18 | 19 | def L1(dist): 20 | error = (dist).abs().mean() 21 | return error 22 | 23 | def loss_weight_decay(error, decay_rate = 0.95): 24 | F = error.shape[-2] 25 | decay_list = decay_rate * torch.ones(F, device=error.device, dtype=error.dtype) 26 | decay_list[0] = 1. 27 | decay_list = torch.cumprod(decay_list, 0) 28 | error = torch.einsum('bfc, f -> bfc', error, decay_list) 29 | return error 30 | 31 | def loss_weight_decrease(error, decay_rate = 0.95): 32 | F = error.shape[-2] 33 | decay_list = torch.tensor([1./i for i in range(1, F+1)]) 34 | error = torch.einsum('bfc, f -> bfc', error, decay_list) 35 | return error 36 | 37 | def Huber(dist, delta=0.005): 38 | error = torch.nn.functional.huber_loss(dist, torch.zeros_like(dist, device=dist.device), delta=delta) 39 | return error 40 | 41 | 42 | loss_fc_list = { 43 | "L2": L2, 44 | "L1": L1, 45 | "diag_cov_ln": diag_ln_cov_loss, 46 | "Huber_loss005":lambda dist: Huber(dist, delta = 0.005), 47 | "Huber_loss05":lambda dist: Huber(dist, delta = 0.05), 48 | } -------------------------------------------------------------------------------- /model/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .loss_func import loss_fc_list, diag_ln_cov_loss 3 | from utils import report_hasNan 4 | 5 | 6 | def loss_(fc, pred, targ, sampling = None, dtype = 'trans'): 7 | ## reshape or sample the input targ and pred 8 | ## cov and error is for reference 9 | if sampling: 10 | pred = pred[:,sampling-1::sampling,:] 11 | targ = targ[:,sampling-1::sampling,:] 12 | else: 13 | pred = pred[:,-1:,:] 14 | targ = targ[:,-1:,:] 15 | 16 | if dtype == 'rot': 17 | dist = (pred * targ.Inv()).Log() 18 | else: 19 | dist = pred - targ 20 | loss = fc(dist) 21 | return loss, dist 22 | 23 | 24 | def get_loss(inte_state, data, 
confs): 25 | ## The state loss for evaluation 26 | loss, state_losses, cov_losses = 0, {}, {} 27 | loss_fc = loss_fc_list[confs.loss] 28 | rotloss_fc = loss_fc_list[confs.rotloss] 29 | 30 | rot_loss, rot_dist = loss_(rotloss_fc, inte_state['rot'], data['gt_rot'], sampling = confs.sampling, dtype='rot') 31 | vel_loss, vel_dist = loss_(loss_fc, inte_state['vel'], data['gt_vel'], sampling = confs.sampling) 32 | pos_loss, pos_dist = loss_(loss_fc, inte_state['pos'], data['gt_pos'], sampling = confs.sampling) 33 | 34 | state_losses['pos'] = pos_dist[:,-1,:].norm(dim=-1).mean() 35 | state_losses['rot'] = rot_dist[:,-1,:].norm(dim=-1).mean() 36 | state_losses['vel'] = vel_dist[:,-1,:].norm(dim=-1).mean() 37 | 38 | # Apply the covariance loss 39 | if confs.propcov: 40 | cov_diag = torch.diagonal(inte_state['cov'], dim1=-2, dim2=-1) 41 | cov_losses['pred_cov_rot'] = cov_diag[...,:3].mean() 42 | cov_losses['pred_cov_vel'] = cov_diag[...,3:6].mean() 43 | cov_losses['pred_cov_pos'] = cov_diag[...,-3:].mean() 44 | 45 | if "covaug" in confs and confs["covaug"] is True: 46 | rot_loss += confs.cov_weight * diag_ln_cov_loss(rot_dist, cov_diag[...,:3]) 47 | vel_loss += confs.cov_weight * diag_ln_cov_loss(vel_dist, cov_diag[...,3:6]) 48 | pos_loss += confs.cov_weight * diag_ln_cov_loss(pos_dist, cov_diag[...,-3:]) 49 | else: 50 | rot_loss += confs.cov_weight * diag_ln_cov_loss(rot_dist.detach(), cov_diag[...,:3]) 51 | vel_loss += confs.cov_weight * diag_ln_cov_loss(vel_dist.detach(), cov_diag[...,3:6]) 52 | pos_loss += confs.cov_weight * diag_ln_cov_loss(pos_dist.detach(), cov_diag[...,-3:]) 53 | 54 | loss += (confs.pos_weight * pos_loss + confs.rot_weight * rot_loss + confs.vel_weight * vel_loss) 55 | # report_hasNan(loss) 56 | 57 | return {'loss':loss, **state_losses, **cov_losses} 58 | 59 | 60 | def get_RMSE(inte_state, data): 61 | ''' 62 | get the RMSE of the last state in one segment 63 | ''' 64 | def _RMSE(x): 65 | return torch.sqrt((x.norm(dim=-1)**2).mean()) 66 | 67 | 
dist_pos = (inte_state['pos'][:,-1,:] - data['gt_pos'][:,-1,:]) 68 | dist_vel = (inte_state['vel'][:,-1,:] - data['gt_vel'][:,-1,:]) 69 | dist_rot = (data['gt_rot'][:,-1,:] * inte_state['rot'][:,-1,:].Inv()).Log() 70 | 71 | pos_loss = _RMSE(dist_pos)[None,...] 72 | vel_loss = _RMSE(dist_vel)[None,...] 73 | rot_loss = _RMSE(dist_rot)[None,...] 74 | 75 | ## Relative pos error 76 | return {'pos': pos_loss, 'rot': rot_loss, 'vel': vel_loss, 77 | 'pos_dist': dist_pos.norm(dim=-1).mean(), 78 | 'vel_dist': dist_vel.norm(dim=-1).mean(), 79 | 'rot_dist': dist_rot.norm(dim=-1).mean(),} 80 | -------------------------------------------------------------------------------- /model/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pypose as pp 3 | import torch.nn as nn 4 | 5 | class ModelBase(nn.Module): 6 | def __init__(self, conf): 7 | super().__init__() 8 | self.conf = conf 9 | if "ngravity" in conf.keys(): 10 | self.integrator = pp.module.IMUPreintegrator(prop_cov=conf.propcov, reset=True, gravity = 0.0) 11 | print("conf.ngravity", conf.ngravity, self.integrator.gravity) 12 | else: 13 | self.integrator = pp.module.IMUPreintegrator(prop_cov=conf.propcov, reset=True) 14 | print("network constructed: ", self.conf.network, "gtrot: ", self.conf.gtrot) 15 | 16 | def _select(self, data, start, end): 17 | 18 | select = {} 19 | for k in data.keys(): 20 | if data[k] is None: 21 | select[k] = None 22 | else: 23 | select[k] = data[k][:, start:end] 24 | return select 25 | 26 | def integrate(self, init_state, data, cov_state): 27 | B, F = data["corrected_acc"].shape[:2] 28 | inte_pos, inte_vel, inte_rot, inte_cov = [], [], [], [] 29 | gt_rot = None 30 | if self.conf.gtrot: 31 | gt_rot = data['rot'] 32 | if "posonly" in self.conf.keys(): 33 | data['corrected_gyro'] = data['gyro'] 34 | 35 | if self.conf.sampling: 36 | inte_state = None 37 | for iter in range(0, F, self.conf.sampling): 38 | if (F - iter) < self.conf.sampling: 
continue 39 | start, end = iter, iter + self.conf.sampling 40 | selected_data = self._select(data, start, end) 41 | selected_cov_state = self._select(cov_state, start, end) 42 | 43 | # take the init sate from last frame as the init state of the next frame 44 | if inte_state is not None: 45 | init_state = { 46 | "pos": inte_state["pos"][:,-1:,:], 47 | "vel": inte_state["vel"][:,-1:,:], 48 | "rot": inte_state["rot"][:,-1:,:], 49 | } 50 | if self.conf.propcov: 51 | init_state["Rij"] = inte_state["Rij"] 52 | init_state["cov"] = inte_state["cov"] 53 | 54 | if self.conf.gtrot: 55 | gt_rot = selected_data['rot'] 56 | 57 | ## starting point and ending point 58 | inte_state = self.integrator(init_state = init_state, dt = selected_data['dt'], gyro = selected_data['corrected_gyro'], 59 | acc = selected_data['corrected_acc'], rot = gt_rot, acc_cov = selected_cov_state['acc_cov'], gyro_cov = selected_cov_state['gyro_cov']) 60 | 61 | inte_pos.append(inte_state['pos']) 62 | inte_rot.append(inte_state['rot']) 63 | inte_vel.append(inte_state['vel']) 64 | inte_cov.append(inte_state['cov']) 65 | 66 | out_state ={ 67 | 'pos': torch.cat(inte_pos, dim =1), 68 | 'vel': torch.cat(inte_vel, dim =1), 69 | 'rot': torch.cat(inte_rot, dim =1), 70 | } 71 | if self.conf.propcov: 72 | out_state['cov'] = torch.stack(inte_cov, dim =1) 73 | else: 74 | 75 | out_state = self.integrator(init_state = init_state, dt = data['dt'], gyro = data['corrected_gyro'], 76 | acc = data['corrected_acc'], rot = gt_rot, acc_cov = cov_state['acc_cov'], gyro_cov = cov_state['gyro_cov']) 77 | 78 | return {**out_state, **cov_state} 79 | 80 | def inference(self, data): 81 | ''' 82 | Pure inference, generate the network output. 
83 | ''' 84 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1) 85 | feature = self.encoder(feature) 86 | correction = self.decoder(feature) 87 | 88 | # Correction update 89 | data['corrected_acc'] = correction[...,:3] + data["acc"] 90 | data['corrected_gyro'] = correction[...,3:] + data["gyro"] 91 | 92 | # covariance propagation 93 | cov_state = {'acc_cov':None, 'gyro_cov': None,} 94 | if self.conf.propcov: 95 | cov = self.cov_decoder(feature) 96 | cov_state['acc_cov'] = cov[...,:3]; cov_state['gyro_cov'] = cov[...,3:] 97 | 98 | return {**cov_state, 'correction_acc': correction[...,:3], 'correction_gyro': correction[...,3:]} 99 | 100 | ## For reference 101 | def forward(self, data, init_state): 102 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1) 103 | feature = self.encoder(feature) 104 | correction = self.decoder(feature) 105 | 106 | # Correction update 107 | data['corrected_acc'] = correction[...,:3] + data["acc"] 108 | data['corrected_gyro'] = correction[...,3:] + data["gyro"] 109 | 110 | # covariance propagation 111 | cov_state = {'acc_cov':None, 'gyro_cov': None,} 112 | if self.conf.propcov: 113 | cov = self.cov_decoder(feature) 114 | cov_state['acc_cov'] = cov[...,:3]; cov_state['gyro_cov'] = cov[...,3:] 115 | 116 | out_state = self.integrate(init_state = init_state, data = data, cov_state = cov_state) 117 | return {**out_state, 'correction_acc': correction[...,:3], 'correction_gyro': correction[...,3:]} 118 | -------------------------------------------------------------------------------- /model/others.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from model.net import ModelBase 5 | 6 | ## For the Baseline 7 | class Identity(ModelBase): 8 | def __init__(self, conf, acc_cov = None, gyro_cov=None): 9 | super().__init__(conf) 10 | self.register_parameter('param', nn.Parameter(torch.zeros(1), requires_grad=True)) 11 | self.acc_cov, self.gyro_cov = 
acc_cov, gyro_cov 12 | 13 | def inference(self, data): 14 | return {'acc_cov':self.acc_cov, 'gyro_cov': self.gyro_cov, 15 | 'correction_acc': torch.zeros_like(data['acc']), 16 | 'correction_gyro': torch.zeros_like(data['gyro'])} 17 | 18 | def forward(self, data, init_state): 19 | data['corrected_acc'] = data["acc"] 20 | data['corrected_gyro'] = data["gyro"] 21 | cov_state = {'acc_cov':self.acc_cov, 'gyro_cov': self.gyro_cov} 22 | out_state = self.integrate(init_state = init_state, data = data, cov_state = cov_state) 23 | return out_state 24 | 25 | 26 | class ParamNet(ModelBase): 27 | def __init__(self, conf, acc_cov = None, gyro_cov=None): 28 | super().__init__(conf) 29 | self.acc_cov, self.gyro_cov = acc_cov, gyro_cov 30 | self.gyro_bias, self.acc_bias = torch.nn.Parameter(torch.zeros(3)), torch.nn.Parameter(torch.zeros(3)) 31 | self.gyro_cov, self.acc_cov = torch.nn.Parameter(torch.ones(3)), torch.nn.Parameter(torch.ones(3)) 32 | 33 | def forward(self, data, init_state): 34 | data['corrected_acc'] = data["acc"] + self.acc_bias 35 | data['corrected_gyro'] = data["gyro"] + self.gyro_bias 36 | 37 | # covariance propagation 38 | cov_state = {'acc_cov':None, 'gyro_cov': None,} 39 | if self.conf.propcov: 40 | cov_state['acc_cov'] = torch.zeros_like(data['corrected_acc']) + self.acc_cov**2 41 | cov_state['gyro_cov'] = torch.zeros_like(data['corrected_gyro']) + self.gyro_cov**2 42 | 43 | out_state = self.integrate(init_state = init_state, data = data, cov_state = cov_state) 44 | return {**out_state, 'correction_acc': self.acc_bias, 'correction_gyro': self.gyro_bias} 45 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | import torch.utils.data as Data 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | import argparse 8 | 9 | import tqdm, wandb 10 | from utils import move_to 11 
| from model import net_dict 12 | from pyhocon import ConfigFactory 13 | from pyhocon import HOCONConverter as conf_convert 14 | 15 | from datasets import SeqeuncesDataset, collate_fcs 16 | from model.losses import get_loss 17 | from eval import evaluate 18 | 19 | torch.autograd.set_detect_anomaly(True) 20 | 21 | def train(network, loader, confs, epoch, optimizer): 22 | """ 23 | Train network for one epoch using a specified data loader 24 | Outputs all targets, predicts, predicted covariance params, and losses 25 | """ 26 | network.train() 27 | losses, pos_losses, rot_losses, vel_losses = 0, 0, 0, 0 28 | pred_cov_rot, pred_cov_vel, pred_cov_pos = 0, 0, 0 29 | acc_covs, gyro_covs = 0, 0 30 | 31 | t_range = tqdm.tqdm(loader) 32 | for i, (data, init_state, label) in enumerate(t_range): 33 | data, init_state, label = move_to([data, init_state, label], confs.device) 34 | inte_state = network(data, init_state) 35 | loss_state = get_loss(inte_state, label, confs) 36 | 37 | # statistics 38 | losses += loss_state['loss'].item() 39 | pos_losses += loss_state['pos'].item() 40 | rot_losses += loss_state['rot'].item() 41 | vel_losses += loss_state['vel'].item() 42 | 43 | if confs.propcov: 44 | acc_covs += inte_state["acc_cov"].mean().item() 45 | gyro_covs += inte_state["gyro_cov"].mean().item() 46 | pred_cov_pos += loss_state['pred_cov_pos'].mean().item() 47 | pred_cov_rot += loss_state['pred_cov_rot'].mean().item() 48 | pred_cov_vel += loss_state['pred_cov_vel'].mean().item() 49 | t_range.set_description(f'training epoch: %03d, losses: %.06f, position, %.06f rotation %.06f, pred_rot %.06f, pred_cov%.06f'%(epoch, \ 50 | loss_state['loss'], (pos_losses/(i+1)), (rot_losses/(i+1)), \ 51 | loss_state['pred_cov_rot'], loss_state['pred_cov_pos'])) 52 | 53 | else: 54 | t_range.set_description(f'training epoch: %03d, losses: %.06f, position, %.06f rotation %.06f, velocity %.06f'%(epoch, \ 55 | loss_state['loss'], (pos_losses/(i+1)), (rot_losses/(i+1)), loss_state['vel'])) 56 | 57 | 
t_range.refresh() 58 | optimizer.zero_grad() 59 | loss_state['loss'].backward() 60 | optimizer.step() 61 | 62 | return {"loss": (losses/(i+1)), "pos_loss": (pos_losses/(i+1)), "rot_loss": (rot_losses/(i+1)), "vel_loss":((vel_losses)/(i+1)), 63 | "pred_cov_rot": (pred_cov_rot/(i+1)), "pred_cov_vel": (pred_cov_vel/(i+1)), "pred_cov_pos": (pred_cov_pos/(i+1))} 64 | 65 | 66 | def test(network, loader, confs): 67 | network.eval() 68 | with torch.no_grad(): 69 | losses, pos_losses, rot_losses, vel_losses = 0, 0, 0, 0 70 | pred_cov_rot, pred_cov_vel, pred_cov_pos = 0, 0, 0 71 | acc_covs, gyro_covs = [], [] 72 | 73 | t_range = tqdm.tqdm(loader) 74 | for i, (data, init_state, label) in enumerate(t_range): 75 | 76 | data, init_state, label = move_to([data, init_state, label], confs.device) 77 | inte_state = network(data, init_state) 78 | 79 | loss_state = get_loss(inte_state, label, confs) 80 | # statistics 81 | losses += loss_state['loss'].item() 82 | pos_losses += loss_state["pos"].item() 83 | rot_losses += loss_state["rot"].item() 84 | vel_losses += loss_state['vel'].item() 85 | 86 | if confs.propcov: 87 | acc_covs.append(inte_state["acc_cov"].reshape(-1)) 88 | gyro_covs.append(inte_state["gyro_cov"].reshape(-1)) 89 | pred_cov_pos += loss_state['pred_cov_pos'].mean().item() 90 | pred_cov_rot += loss_state['pred_cov_rot'].mean().item() 91 | pred_cov_vel += loss_state['pred_cov_vel'].mean().item() 92 | 93 | t_range.set_description(f'testing losses: %.06f, position, %.06f rotation %.06f, vel %.06f'%(losses/(i+1), \ 94 | pos_losses/(i+1), rot_losses/(i+1), vel_losses/(i+1))) 95 | t_range.refresh() 96 | 97 | if acc_covs: 98 | acc_covs = torch.cat(acc_covs) 99 | if gyro_covs: 100 | gyro_covs = torch.cat(gyro_covs) 101 | 102 | return {"loss": (losses/(i+1)), "pos_loss":(pos_losses/(i+1)), "rot_loss":(rot_losses/(i+1)), "vel_loss":(vel_losses/(i+1)), 103 | "pred_cov_rot": (pred_cov_rot/(i+1)), "pred_cov_vel": (pred_cov_vel/(i+1)), "pred_cov_pos": (pred_cov_pos/(i+1)), "acc_covs": 
acc_covs, "gyro_covs": gyro_covs} 104 | 105 | 106 | def write_wandb(header, objs, epoch_i): 107 | if isinstance(objs, dict): 108 | for k, v in objs.items(): 109 | if isinstance(v, float): 110 | wandb.log({os.path.join(header, k): v}, epoch_i) 111 | else: 112 | wandb.log({header: objs}, step = epoch_i) 113 | 114 | 115 | def save_ckpt(network, optimizer, scheduler, epoch_i, test_loss, conf, save_best = False): 116 | if epoch_i%conf.train.save_freq==conf.train.save_freq-1: 117 | torch.save({ 118 | 'epoch': epoch_i, 119 | 'model_state_dict': network.state_dict(), 120 | 'optimizer_state_dict': optimizer.state_dict(), 121 | 'scheduler_state_dict': scheduler.state_dict(), 122 | 'best_loss': test_loss, 123 | }, os.path.join(conf.general.exp_dir, "ckpt/%04d.ckpt"%epoch_i)) 124 | 125 | if save_best: 126 | print("saving the best model", test_loss) 127 | torch.save({ 128 | 'epoch': epoch_i, 129 | 'model_state_dict': network.state_dict(), 130 | 'optimizer_state_dict': optimizer.state_dict(), 131 | 'scheduler_state_dict': scheduler.state_dict(), 132 | 'best_loss': test_loss, 133 | }, os.path.join(conf.general.exp_dir, "ckpt/best_model.ckpt")) 134 | 135 | torch.save({ 136 | 'epoch': epoch_i, 137 | 'model_state_dict': network.state_dict(), 138 | 'optimizer_state_dict': optimizer.state_dict(), 139 | 'scheduler_state_dict': scheduler.state_dict(), 140 | 'best_loss': test_loss, 141 | }, os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt")) 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument('--config', type=str, default='configs/exp/EuRoC/codenet.conf', help='config file path') 147 | parser.add_argument('--device', type=str, default="cuda:0", help="cuda or cpu, Default is cuda:0") 148 | parser.add_argument('--load_ckpt', default=False, action="store_true", help="If True, try to load the newest.ckpt in the \ 149 | exp_dir specificed in our config file.") 150 | parser.add_argument('--log', default=True, action="store_false", 
help="if True, save the meta data with wandb") 151 | args = parser.parse_args(); print(args) 152 | conf = ConfigFactory.parse_file(args.config) 153 | # torch.cuda.set_device(args.device) 154 | 155 | conf.train.device = args.device 156 | exp_folder = os.path.split(conf.general.exp_dir)[-1] 157 | conf_name = os.path.split(args.config)[-1].split(".")[0] 158 | conf['general']['exp_dir'] = os.path.join(conf.general.exp_dir, conf_name) 159 | 160 | if 'collate' in conf.dataset.keys(): 161 | collate_fn = collate_fcs[conf.dataset.collate] 162 | else: 163 | collate_fn = collate_fcs['base'] 164 | 165 | train_dataset = SeqeuncesDataset(data_set_config=conf.dataset.train) 166 | test_dataset = SeqeuncesDataset(data_set_config=conf.dataset.test) 167 | eval_dataset = SeqeuncesDataset(data_set_config=conf.dataset.eval) 168 | train_loader = Data.DataLoader(dataset=train_dataset, batch_size=conf.train.batch_size, shuffle=True, collate_fn=collate_fn) 169 | test_loader = Data.DataLoader(dataset=test_dataset, batch_size=conf.train.batch_size, shuffle=False, collate_fn=collate_fn) 170 | eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn) 171 | 172 | os.makedirs(os.path.join(conf.general.exp_dir, "ckpt"), exist_ok=True) 173 | with open(os.path.join(conf.general.exp_dir, "parameters.yaml"), "w") as f: 174 | f.write(conf_convert.to_yaml(conf)) 175 | 176 | if not args.log: 177 | wandb.disabled = True 178 | print("wandb is disabled") 179 | else: 180 | wandb.init(project= "AirIMU_" + exp_folder, 181 | config= conf.train, 182 | group = conf.train.network, 183 | name = conf_name,) 184 | 185 | ## optimizer 186 | network = net_dict[conf.train.network](conf.train).to(device = args.device, dtype = train_dataset.get_dtype()) 187 | optimizer = torch.optim.Adam(network.parameters(), lr = conf.train.lr, weight_decay=conf.train.weight_decay) # to use with ViTs 188 | scheduler = ReduceLROnPlateau(optimizer, 'min', factor = conf.train.factor, patience = 
conf.train.patience, min_lr = conf.train.min_lr) 189 | best_loss = np.inf 190 | epoch = 0 191 | 192 | ## load the chkp if there exist 193 | if args.load_ckpt: 194 | if os.path.isfile(os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt")): 195 | checkpoint = torch.load(os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt"), map_location = args.device) 196 | network.load_state_dict(checkpoint["model_state_dict"]) 197 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 198 | scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) 199 | epoch = checkpoint['epoch'] 200 | best_loss = checkpoint['best_loss'] 201 | print("loaded state dict %s best_loss %f"%(os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt"), best_loss)) 202 | else: 203 | print("Can't find the checkpoint") 204 | 205 | for epoch_i in range(epoch, conf.train.max_epoches): 206 | train_loss = train(network, train_loader, conf.train, epoch_i, optimizer) 207 | test_loss = test(network, test_loader, conf.train) 208 | print("train loss: %f test loss: %f"%(train_loss["loss"], test_loss["loss"])) 209 | 210 | # save the training meta information 211 | if args.log: 212 | write_wandb("train", train_loss, epoch_i) 213 | write_wandb("test", test_loss, epoch_i) 214 | write_wandb("lr", scheduler.optimizer.param_groups[0]['lr'], epoch_i) 215 | 216 | if epoch_i%conf.train.eval_freq == conf.train.eval_freq-1: 217 | eval_state = evaluate(network=network, loader = eval_loader, confs=conf.train) 218 | if args.log: 219 | write_wandb('eval/pos_loss', eval_state['loss']['pos'].mean(), epoch_i) 220 | write_wandb('eval/rot_loss', eval_state['loss']['rot'].mean(), epoch_i) 221 | write_wandb('eval/vel_loss', eval_state['loss']['vel'].mean(), epoch_i) 222 | write_wandb('eval/rot_dist', eval_state['loss']['rot_dist'].mean(), epoch_i) 223 | write_wandb('eval/vel_dist', eval_state['loss']['vel_dist'].mean(), epoch_i) 224 | write_wandb('eval/pos_dist', eval_state['loss']['pos_dist'].mean(), epoch_i) 225 | 226 | 
print("eval pos: %f eval rot: %f"%(eval_state['loss']['pos'].mean(), eval_state['loss']['rot'].mean())) 227 | 228 | scheduler.step(test_loss['loss']) 229 | if test_loss['loss'] < best_loss: 230 | best_loss = test_loss['loss'];save_best = True 231 | else: 232 | save_best = False 233 | 234 | save_ckpt(network, optimizer, scheduler, epoch_i, best_loss, conf, save_best=save_best,) 235 | 236 | wandb.finish() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .visualize import * 3 | from .integrate import * -------------------------------------------------------------------------------- /utils/deferentiate_vel.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | import numpy as np 3 | 4 | def interp_xyz(time, opt_time, xyz): 5 | 6 | intep_x = np.interp(time, xp=opt_time, fp = xyz[:,0]) 7 | intep_y = np.interp(time, xp=opt_time, fp = xyz[:,1]) 8 | intep_z = np.interp(time, xp=opt_time, fp = xyz[:,2]) 9 | 10 | inte_xyz = np.stack([intep_x, intep_y, intep_z]).transpose() 11 | return inte_xyz 12 | 13 | def gradientvelo(xyz, imu_time, time): 14 | 15 | inte_xyz = interp_xyz(imu_time, time, xyz) 16 | time_interval = imu_time[1:] - imu_time[:-1] 17 | time_interval = np.append(time_interval, time_interval.mean()) 18 | velo_d = np.einsum('nd, n -> nd', np.gradient(inte_xyz, axis=0), 1/time_interval) 19 | 20 | return velo_d 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--root', type=str, default='/data/datasets/yuhengq/tumvio') 26 | parser.add_argument('--seq', nargs='+', default=["dataset-room1_512_16", "dataset-room2_512_16", "dataset-room3_512_16", 27 | "dataset-room4_512_16", "dataset-room5_512_16", "dataset-room6_512_16"]) 28 | parser.add_argument("--device", type=str, default='cuda:0', 
help="cuda or cpu") 29 | parser.add_argument('--load_ckpt', default=False, action="store_true") 30 | 31 | args = parser.parse_args(); print(args) 32 | 33 | for seq in args.seq: 34 | print(seq) 35 | gt_data = np.loadtxt(os.path.join(args.root, seq, "mav0/mocap0/data.csv"), dtype=float, delimiter=',') 36 | imu_data = np.loadtxt(os.path.join(args.root, seq, "mav0/imu0/data.csv"), dtype=float, delimiter=',') 37 | 38 | gt_time = gt_data[:,0]*1e-9 39 | xyz = gt_data[:,1:4] 40 | 41 | imu_time = imu_data[:, 0]*1e-9 42 | acc = imu_data[:, 4:] 43 | 44 | gt_velo = gradientvelo(xyz, imu_time, gt_time) 45 | to_save = np.concatenate([imu_time[:,None], gt_velo], 1) 46 | print("saving to ", os.path.join(args.root, seq, "mav0/mocap0/grad_velo.txt")) 47 | np.savetxt(os.path.join(args.root, seq, "mav0/mocap0/grad_velo.txt"), to_save) 48 | 49 | -------------------------------------------------------------------------------- /utils/integrate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from utils import move_to 4 | 5 | 6 | def integrate(integrator, loader, init, device="cpu", gtinit=False, save_full_traj=False, use_gt_rot=True): 7 | """ 8 | gtinit: 9 | If gtinit is True, use the ground truth initial state to integrate the small segments. 10 | This is used for the evaluation of the local pattern of the trajectory. 11 | If gtinit is False, use the predicted initial state to integrate the segments. 12 | Which is equivalent to integrate through the entire trajectory. 13 | save_full_traj: 14 | If save_full_traj is True, save the full trajectory. 15 | If save_full_traj is False, save the last frame of each segment. 
16 | """ 17 | # states to ouput 18 | integrator.eval() 19 | out_state = dict() 20 | poses, poses_gt = [init['pos'][None,:]], [init['pos'][None,:]] 21 | orientations,orientations_gt = [init['rot'][None,:]], [init['rot'][None,:]] 22 | vel, vel_gt = [init['vel'][None,:]], [init['vel'][None,:]] 23 | covs = [torch.zeros(9, 9)] 24 | for idx, data in tqdm.tqdm(enumerate(loader)): 25 | data = move_to(data, device) 26 | if gtinit: 27 | init_state = { 28 | "pos": data["init_pos"][:,:1,:], 29 | "vel": data["init_vel"][:,:1,:], 30 | "rot": data["init_rot"][:,:1,:], 31 | } 32 | else: 33 | init_state = None 34 | 35 | init_rot = data['init_rot'] if use_gt_rot else None 36 | state = integrator( 37 | init_state = init_state, dt=data['dt'], 38 | gyro=data['gyro'], acc=data['acc'], 39 | rot=init_rot 40 | ) 41 | 42 | if save_full_traj: 43 | vel.append(state['vel'][..., :, :].cpu()) 44 | vel_gt.append(data['gt_vel'][..., :, :].cpu()) 45 | orientations.append(state['rot'][..., :, :].cpu()) 46 | orientations_gt.append(data['gt_rot'][..., :, :].cpu()) 47 | poses_gt.append(data['gt_pos'][..., :, :].cpu()) 48 | poses.append(state['pos'][..., :, :].cpu()) 49 | else: 50 | vel.append(state['vel'][..., -1:, :].cpu()) 51 | vel_gt.append(data['gt_vel'][..., -1:, :].cpu()) 52 | orientations.append(state['rot'][..., -1:, :].cpu()) 53 | orientations_gt.append(data['gt_rot'][..., -1:, :].cpu()) 54 | poses_gt.append(data['gt_pos'][..., -1:, :].cpu()) 55 | poses.append(state['pos'][..., -1:, :].cpu()) 56 | 57 | 58 | covs.append(state['cov'][..., -1, :, :].cpu()) 59 | out_state['vel'] = torch.cat(vel, dim=-2) 60 | out_state['vel_gt'] = torch.cat(vel_gt, dim=-2) 61 | 62 | out_state['orientations'] = torch.cat(orientations, dim=-2) 63 | out_state['orientations_gt'] = torch.cat(orientations_gt, dim=-2) 64 | 65 | out_state['poses'] = torch.cat(poses, dim=-2) 66 | out_state['poses_gt'] = torch.cat(poses_gt, dim=-2) 67 | 68 | out_state['covs'] = torch.stack(covs, dim=0) 69 | out_state['pos_dist'] = 
(out_state['poses'][:, 1:, :] - out_state['poses_gt'][:, 1:, :]).norm(dim=-1) 70 | out_state['vel_dist'] = (out_state['vel'][:, 1:, :] - out_state['vel_gt'][:, 1:, :]).norm(dim=-1) 71 | out_state['rot_dist'] = ((out_state['orientations_gt'][:, 1:, :].Inv() @ out_state['orientations'][:, 1:, :]).Log()).norm(dim=-1) 72 | return out_state -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os, warnings 2 | import torch 3 | import io, pickle 4 | import numpy as np 5 | from pyproj import Proj, transform 6 | from inspect import currentframe, getframeinfo 7 | 8 | from scipy.spatial.transform import Rotation as R 9 | from scipy.spatial.transform import Slerp 10 | 11 | def save_state(out_states:dict, in_state:dict): 12 | for k, v in in_state.items(): 13 | if v is None: 14 | continue 15 | elif isinstance(v, dict): 16 | save_state(out_states=out_states, in_state=v) 17 | elif k in out_states.keys(): 18 | out_states[k].append(v) 19 | else: 20 | out_states[k] = [v] 21 | 22 | def trans_ecef2wgs(traj): 23 | # Define the coordinate reference systems 24 | ecef = Proj(proj='geocent', ellps='WGS84', datum='WGS84') 25 | wgs84 = Proj(proj='latlong', datum='WGS84') 26 | 27 | # Convert ECEF coordinates to geographic coordinates 28 | trajectory_geo = [transform(ecef, wgs84, x, y, z, radians=False) for x, y, z in traj] 29 | 30 | return trajectory_geo 31 | 32 | def Gaussian_noise(num_nodes, sigma_x=0.05 ,sigma_y=0.05, sigma_z=0.05): 33 | std = torch.stack([torch.ones(num_nodes)*sigma_x, torch.ones(num_nodes)*sigma_y, torch.ones(num_nodes)*sigma_z], dim=-1) 34 | return torch.normal(mean = 0, std = std) 35 | 36 | def move_to(obj, device): 37 | if torch.is_tensor(obj):return obj.to(device) 38 | elif obj is None: 39 | return None 40 | elif isinstance(obj, dict): 41 | res = {} 42 | for k, v in obj.items(): 43 | res[k] = move_to(v, device) 44 | return res 45 | 
elif isinstance(obj, list): 46 | res = [] 47 | for v in obj: 48 | res.append(move_to(v, device)) 49 | return res 50 | elif isinstance(obj, np.ndarray): 51 | return torch.tensor(obj).to(device) 52 | else: 53 | raise TypeError("Invalid type for move_to", type(obj)) 54 | 55 | def qinterp(qs, t, t_int): 56 | qs = R.from_quat(qs.numpy()) 57 | slerp = Slerp(t, qs) 58 | interp_rot = slerp(t_int).as_quat() 59 | return torch.tensor(interp_rot) 60 | 61 | def lookAt(dir_vec, up = torch.tensor([0.,0.,1.], dtype=torch.float64), source = torch.tensor([0.,0.,0.], dtype=torch.float64)): 62 | ''' 63 | dir_vec: the tensor shall be (1) 64 | return the rotation matrix of the 65 | ''' 66 | if not isinstance(dir_vec, torch.Tensor): 67 | dir_vec = torch.tensor(dir_vec) 68 | def normalize(x): 69 | length = x.norm() 70 | if length< 0.005: 71 | length = 1 72 | warnings.warn('Normlization error that the norm is too small') 73 | return x/length 74 | 75 | zaxis = normalize(dir_vec - source) 76 | xaxis = normalize(torch.cross(zaxis, up)) 77 | yaxis = torch.cross(xaxis, zaxis) 78 | 79 | m = torch.tensor([ 80 | [xaxis[0], xaxis[1], xaxis[2]], 81 | [yaxis[0], yaxis[1], yaxis[2]], 82 | [zaxis[0], zaxis[1], zaxis[2]], 83 | ]) 84 | 85 | return m 86 | 87 | def cat_state(in_state:dict): 88 | pop_list = [] 89 | for k, v in in_state.items(): 90 | if len(v[0].shape) > 2: 91 | in_state[k] = torch.cat(v, dim=-2) 92 | else: 93 | pop_list.append(k) 94 | for k in pop_list: 95 | in_state.pop(k) 96 | 97 | class CPU_Unpickler(pickle.Unpickler): 98 | def find_class(self, module, name): 99 | if module == 'torch.storage' and name == '_load_from_bytes': 100 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 101 | else: 102 | return super().find_class(module, name) 103 | 104 | def write_board(writer, objs, epoch_i, header = ''): 105 | # writer = SummaryWriter(log_dir=conf.general.exp_dir) 106 | if isinstance(objs, dict): 107 | for k, v in objs.items(): 108 | if isinstance(v, float): 109 | 
writer.add_scalar(os.path.join(header, k), v, epoch_i) 110 | elif isinstance(objs, float): 111 | writer.add_scalar(header, v, epoch_i) 112 | 113 | def report_hasNan(x): 114 | cf = currentframe().f_back 115 | res = torch.any(torch.isnan(x)).cpu().item() 116 | if res: print(f"[hasnan!] {getframeinfo(cf).filename}:{cf.f_lineno}") 117 | 118 | def report_hasNeg(x): 119 | cf = currentframe().f_back 120 | res = torch.any(x < 0).cpu().item() 121 | if res: print(f"[hasneg!] {getframeinfo(cf).filename}:{cf.f_lineno}") 122 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import pypose as pp 5 | from scipy.spatial.transform import Rotation as T 6 | 7 | import matplotlib.pyplot as plt 8 | import os 9 | import argparse 10 | 11 | def visualize_rotations(rotations, rotations_gt, save_folder, save_prefix = ""): 12 | ## Visualize the euler angle 13 | gt_r_euler = rotations_gt.euler().cpu().numpy() 14 | r_euler = rotations.data.euler().cpu().numpy() 15 | 16 | fig, axs = plt.subplots(3,) 17 | 18 | fig.suptitle("integrated orientation v.s. gt orientation") 19 | axs[0].plot(r_euler[:,0]) 20 | axs[0].plot(gt_r_euler[:,0]) 21 | axs[0].legend(["euler_x", "euler_x_gt"]) 22 | 23 | axs[1].plot(r_euler[:,1]) 24 | axs[1].plot(gt_r_euler[:,1]) 25 | axs[1].legend(["euler_y", "euler_y_gt"]) 26 | 27 | axs[2].plot(r_euler[:,2]) 28 | axs[2].plot(gt_r_euler[:,2]) 29 | axs[2].legend(["euler_z", "euler_z_gt"]) 30 | plt.savefig(os.path.join(save_folder, save_prefix + "orientation.png")) 31 | 32 | r = rotations_gt[0] 33 | gt_r_euler = ((r.Inv()@rotations_gt).euler()).cpu().numpy() 34 | r_euler = ((r.Inv()@rotations).euler()).cpu().numpy() 35 | 36 | fig, axs = plt.subplots(3,) 37 | 38 | fig.suptitle("Incremental orientation v.s. 
gt orientation") 39 | axs[0].plot(r_euler[:,0]) 40 | axs[0].plot(gt_r_euler[:,0]) 41 | axs[0].legend(["euler_x", "euler_x_gt"]) 42 | 43 | axs[1].plot(r_euler[:,1]) 44 | axs[1].plot(gt_r_euler[:,1]) 45 | axs[1].legend(["euler_y", "euler_y_gt"]) 46 | 47 | axs[2].plot(r_euler[:,2]) 48 | axs[2].plot(gt_r_euler[:,2]) 49 | axs[2].legend(["euler_z", "euler_z_gt"]) 50 | plt.savefig(os.path.join(save_folder, save_prefix + "incremental_orientation.png")) 51 | 52 | 53 | def visualize_velocities(velocities, gt_velocities, save_folder, save_prefix = ""): 54 | fig, axs = plt.subplots(3,) 55 | velocities = velocities.detach().numpy() 56 | 57 | fig.suptitle("integrated velocity v.s. gt velocity") 58 | axs[0].plot(velocities[:,0]) 59 | axs[0].plot(gt_velocities[:,0]) 60 | axs[0].legend(["velocity", "gt velocity"]) 61 | 62 | axs[1].plot(velocities[:,1]) 63 | axs[1].plot(gt_velocities[:,1]) 64 | axs[1].legend(["velocity", "gt velocity"]) 65 | 66 | axs[2].plot(velocities[:,2]) 67 | axs[2].plot(gt_velocities[:,2]) 68 | axs[2].legend(["velocity", "gt velocity"]) 69 | plt.savefig(os.path.join(save_folder, save_prefix + "velocity.png")) 70 | 71 | 72 | def plot_2d_traj(trajectory, trajectory_gt, save_folder, vis_length = None, save_prefix = ""): 73 | 74 | if torch.is_tensor(trajectory): 75 | trajectory = trajectory.detach().cpu().numpy() 76 | trajectory_gt = trajectory_gt.detach().cpu().numpy() 77 | 78 | plt.clf() 79 | plt.figure(figsize=(3, 3),facecolor=(1, 1, 1)) 80 | 81 | ax = plt.axes() 82 | ax.plot(trajectory[:,0][:vis_length], trajectory[:,1][:vis_length], 'b') 83 | ax.plot(trajectory_gt[:,0][:vis_length], trajectory_gt[:,1][:vis_length], 'r') 84 | plt.title("PyPose IMU Integrator") 85 | plt.legend(["PyPose", "Ground Truth"],loc='right') 86 | 87 | plt.savefig(os.path.join(save_folder, save_prefix + "poses.png")) 88 | 89 | def plot_poses(points1, points2, title='', axlim=None, savefig = None): 90 | if torch.is_tensor(points1): 91 | points1 = points1.detach().cpu().numpy() 92 | if 
torch.is_tensor(points2): 93 | points2 = points2.detach().cpu().numpy() 94 | 95 | plt.figure(figsize=(7, 7)) 96 | ax = plt.axes(projection='3d') 97 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b', label = "KF") 98 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r', label = "ground truth") 99 | 100 | plt.title(title) 101 | ax.legend() 102 | if axlim is not None: 103 | ax.set_xlim(axlim[0]) 104 | ax.set_ylim(axlim[1]) 105 | ax.set_zlim(axlim[2]) 106 | if savefig is not None: 107 | plt.savefig(savefig) 108 | print('Saving to', savefig) 109 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim() 110 | 111 | 112 | def vis_cov_error_diag(pos_loss_xyzs, pred_pos_cov, save_folder, save_prefix = ""): 113 | """ 114 | pred_pos_covs 115 | """ 116 | 117 | for i, axis in enumerate(["x", "y", "z"]): 118 | if torch.is_tensor(pos_loss_xyzs): 119 | pos_loss_xyzs = pos_loss_xyzs.detach().cpu().numpy() 120 | pred_pos_cov = pred_pos_cov.detach().cpu().numpy() 121 | 122 | plt.clf() 123 | plt.grid() 124 | plt.ylim(0.0, 1) 125 | plt.scatter(pos_loss_xyzs[:,i].cpu().numpy(), torch.sqrt(pred_pos_cov)[:,i].cpu().numpy(), marker="o", s = 2) 126 | plt.legend(["covariance","error"], loc='left') 127 | 128 | plt.savefig(os.path.join(save_folder, save_prefix + "cov-error_%s.png"%axis)) 129 | 130 | 131 | def vis_rotation_error(ts, error, save_folder): 132 | title = "$SO(3)$ orientation error" 133 | ts = ts -ts[0] 134 | 135 | fig, axs = plt.subplots(3, 1, sharex=True, figsize=(20, 12)) 136 | axs[0].set(ylabel='roll (deg)', title=title) 137 | axs[1].set(ylabel='pitch (deg)') 138 | axs[2].set(xlabel='$t$ (s)', ylabel='yaw (deg)') 139 | 140 | for i in range(3): 141 | # axs[i].plot(ts, raw_err[:, i], color='red', label=r'raw IMU') 142 | axs[i].plot(ts, 180./np.pi * error[:, i].detach().cpu().numpy(), color='blue', label=r'net IMU') 143 | axs[i].set_ylim(-10, 10) 144 | axs[i].set_xlim(ts[0], ts[-1]) 145 | 146 | for i in range(len(axs)): 147 | axs[i].grid() 148 | axs[i].legend() 149 | 
fig.tight_layout() 150 | 151 | plt.savefig(save_folder + '_orientation_error.png') 152 | 153 | 154 | def vis_corrections( error, save_folder): 155 | title = "$acc & gyro correction" 156 | 157 | fig, axs = plt.subplots(6, 1, sharex=True, figsize=(20, 24)) 158 | axs[0].set(ylabel='x (m)', title=title) 159 | axs[1].set(ylabel='y (m)') 160 | axs[2].set(ylabel='z (m)') 161 | axs[3].set(ylabel='roll (deg)') 162 | axs[4].set(ylabel='pitch (deg)') 163 | axs[5].set(ylabel='yaw (deg)') 164 | 165 | for i in range(3): 166 | # axs[i].plot(ts, raw_err[:, i], color='red', label=r'raw IMU') 167 | axs[i].plot(error[:,i], color='blue', label=r'net IMU') 168 | # axs[i].set_ylim(-0.2, 0.2) 169 | 170 | for i in range(3,6): 171 | # axs[i].plot(ts, raw_err[:, i], color='red', label=r'raw IMU') 172 | axs[i].plot(180./np.pi * error[:, i].detach().cpu().numpy(), color='blue', label=r'net IMU') 173 | # axs[i].set_ylim(-2, 2) 174 | 175 | for i in range(len(axs)): 176 | axs[i].grid() 177 | axs[i].legend() 178 | fig.tight_layout() 179 | 180 | plt.savefig(save_folder + 'corrections.png') 181 | 182 | 183 | 184 | def plot_and_save(points, title='', axlim=None, savefig = None): 185 | points = points.detach().cpu().numpy() 186 | plt.figure(figsize=(7, 7)) 187 | ax = plt.axes(projection='3d') 188 | ax.plot3D(points[:,0], points[:,1], points[:,2], 'b') 189 | plt.title(title) 190 | if axlim is not None: 191 | ax.set_xlim(axlim[0]) 192 | ax.set_ylim(axlim[1]) 193 | ax.set_zlim(axlim[2]) 194 | if savefig is not None: 195 | plt.savefig(savefig) 196 | print('Saving to', savefig) 197 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim() 198 | 199 | def plot2_and_save(points1, points2, title='', axlim=None, savefig = None): 200 | points1 = points1.detach().cpu().numpy() 201 | points2 = points2.detach().cpu().numpy() 202 | plt.figure(figsize=(7, 7)) 203 | ax = plt.axes(projection='3d') 204 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b') 205 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r') 
206 | plt.title(title) 207 | if axlim is not None: 208 | ax.set_xlim(axlim[0]) 209 | ax.set_ylim(axlim[1]) 210 | ax.set_zlim(axlim[2]) 211 | if savefig is not None: 212 | plt.savefig(savefig) 213 | print('Saving to', savefig) 214 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim() 215 | 216 | # .plot(x, y, marker="o", markersize=20, markeredgecolor="red", markerfacecolor="green") 217 | 218 | def plot_nodes(points1, points2, point_nodes, title='', axlim=None, savefig = None): 219 | points1 = points1.detach().cpu().numpy() 220 | points2 = points2.detach().cpu().numpy() 221 | point_nodes = point_nodes.detach().cpu().numpy() 222 | plt.figure(figsize=(7, 7)) 223 | ax = plt.axes(projection='3d') 224 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b', label = "KF") 225 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r', label = "ground truth") 226 | ax.scatter(point_nodes[:,0], point_nodes[:,1], point_nodes[:,2], marker="o", 227 | facecolor = "yellow", edgecolor="green", label = "GPS signal") 228 | 229 | plt.title(title) 230 | ax.legend() 231 | if axlim is not None: 232 | ax.set_xlim(axlim[0]) 233 | ax.set_ylim(axlim[1]) 234 | ax.set_zlim(axlim[2]) 235 | if savefig is not None: 236 | plt.savefig(savefig) 237 | print('Saving to', savefig) 238 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim() 239 | 240 | def plot_trajs(points, labels, title='', axlim=None, savefig = None): 241 | for i, p in enumerate(points): 242 | if torch.is_tensor(p): 243 | points[i] = points[i].detach().cpu().numpy() 244 | 245 | plt.figure(figsize=(7, 7)) 246 | ax = plt.axes(projection='3d') 247 | for i, p in enumerate(points): 248 | ax.plot3D(p[:,0], p[:,1], p[:,2], label = labels[i]) 249 | 250 | plt.title(title) 251 | ax.legend() 252 | if axlim is not None: 253 | ax.set_xlim(axlim[0]) 254 | ax.set_ylim(axlim[1]) 255 | ax.set_zlim(axlim[2]) 256 | if savefig is not None: 257 | plt.savefig(savefig) 258 | print('Saving to', savefig) 259 | return ax.get_xlim(), ax.get_ylim(), 
ax.get_zlim() 260 | 261 | def plot_nodes_2d(points1, points2, point_nodes, title='', axlim=None, savefig = None): 262 | points1 = points1.detach().cpu().numpy() 263 | points2 = points2.detach().cpu().numpy() 264 | point_nodes = point_nodes.detach().cpu().numpy() 265 | plt.figure(figsize=(7, 7)) 266 | ax = plt.axes(projection='3d') 267 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b', label = "KF") 268 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r', label = "ground truth") 269 | ax.scatter(point_nodes[:,0], point_nodes[:,1], point_nodes[:,2], marker="o", 270 | facecolor = "yellow", edgecolor="green", label = "GPS signal") 271 | 272 | plt.title(title) 273 | ax.legend() 274 | if axlim is not None: 275 | ax.set_xlim(axlim[0]) 276 | ax.set_ylim(axlim[1]) 277 | ax.set_zlim(axlim[2]) 278 | if savefig is not None: 279 | plt.savefig(savefig) 280 | print('Saving to', savefig) 281 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim() 282 | 283 | 284 | 285 | def plot_trajs(points, labels, title='', axlim=None, savefig = None): 286 | for i, p in enumerate(points): 287 | if torch.is_tensor(p): 288 | points[i] = points[i].detach().cpu().numpy() 289 | 290 | plt.figure(figsize=(7, 7)) 291 | ax = plt.axes(projection='3d') 292 | for i, p in enumerate(points): 293 | ax.plot3D(p[:,0], p[:,1], p[:,2], label = labels[i]) 294 | 295 | plt.title(title) 296 | ax.legend() 297 | 298 | ## Take the largest range 299 | x_len = ax.get_xlim()[1] - ax.get_xlim()[0] 300 | y_len = ax.get_ylim()[1] - ax.get_ylim()[0] 301 | z_len = ax.get_zlim()[1] - ax.get_zlim()[0] 302 | x_mean = np.mean(ax.get_xlim()) 303 | y_mean = np.mean(ax.get_ylim()) 304 | z_mean = np.mean(ax.get_zlim()) 305 | _len = np.max([x_len, y_len, z_len]) 306 | 307 | ax.set_xlim(x_mean - _len / 2, x_mean + _len / 2) 308 | ax.set_ylim(y_mean - _len / 2, y_mean + _len / 2) 309 | ax.set_zlim(z_mean - _len / 2, z_mean + _len / 2) 310 | 311 | if axlim is not None: 312 | ax.set_xlim(axlim[0]) 313 | ax.set_ylim(axlim[1]) 314 
| ax.set_zlim(axlim[2]) 315 | if savefig is not None: 316 | plt.savefig(savefig) 317 | print('Saving to', savefig) 318 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim() -------------------------------------------------------------------------------- /utils/visualize_state.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import pypose as pp 5 | import matplotlib.pyplot as plt 6 | import matplotlib.patches as mpatches 7 | 8 | 9 | def visualize_state_error(save_prefix, relative_outstate, relative_infstate, \ 10 | save_folder=None, mask=None, file_name="state_error_compare.png"): 11 | if mask is None: 12 | outstate_pos_err = relative_outstate['pos_dist'][0] 13 | outstate_vel_err = relative_outstate['vel_dist'][0] 14 | outstate_rot_err = relative_outstate['rot_dist'][0] 15 | 16 | infstate_pos_err = relative_infstate['pos_dist'][0] 17 | infstate_vel_err = relative_infstate['vel_dist'][0] 18 | infstate_rot_err = relative_infstate['rot_dist'][0] 19 | else: 20 | outstate_pos_err = relative_outstate['pos_dist'][0, mask] 21 | outstate_vel_err = relative_outstate['vel_dist'][0, mask] 22 | outstate_rot_err = relative_outstate['rot_dist'][0, mask] 23 | 24 | infstate_pos_err = relative_infstate['pos_dist'][0, mask] 25 | infstate_vel_err = relative_infstate['vel_dist'][0, mask] 26 | infstate_rot_err = relative_infstate['rot_dist'][0, mask] 27 | 28 | fig, axs = plt.subplots(3,) 29 | fig.suptitle("Integration error vs AirIMU Integration error") 30 | 31 | axs[0].plot(outstate_pos_err,color = 'b',linewidth=1) 32 | axs[0].plot(infstate_pos_err,color = 'red',linewidth=1) 33 | axs[0].legend(["integration_pos_error", "AirIMU_pos_error"]) 34 | axs[0].grid(True) 35 | 36 | axs[1].plot(outstate_vel_err,color = 'b',linewidth=1) 37 | axs[1].plot(infstate_vel_err,color = 'red',linewidth=1) 38 | axs[1].legend(["integration_vel_error", "AirIMU_vel_error"]) 39 | axs[1].grid(True) 40 | 41 | 
axs[2].plot(outstate_rot_err,color = 'b',linewidth=1) 42 | axs[2].plot(infstate_rot_err,color = 'red',linewidth=1) 43 | axs[2].legend(["integration_rot_error", "AirIMU_rot_error"]) 44 | axs[2].grid(True) 45 | 46 | plt.tight_layout() 47 | if save_folder is not None: 48 | plt.savefig(os.path.join(save_folder, save_prefix + file_name), dpi = 300) 49 | plt.show() 50 | 51 | 52 | def visualize_rotations(save_prefix, gt_rot, out_rot, inf_rot = None,save_folder=None): 53 | 54 | gt_euler = 180./np.pi* pp.SO3(gt_rot).euler() 55 | outstate_euler = 180./np.pi* pp.SO3(out_rot).euler() 56 | 57 | legend_list = ["roll","pitch", "yaw"] 58 | fig, axs = plt.subplots(3,) 59 | fig.suptitle("integrated orientation") 60 | for i in range(3): 61 | axs[i].plot(outstate_euler[:,i],color = 'b',linewidth=0.9) 62 | axs[i].plot(gt_euler[:,i],color = 'mediumseagreen',linewidth=0.9) 63 | axs[i].legend(["Integrated_"+legend_list[i],"gt_"+legend_list[i]]) 64 | axs[i].grid(True) 65 | 66 | if inf_rot is not None: 67 | infstate_euler = 180./np.pi* pp.SO3(inf_rot).euler() 68 | print(infstate_euler.shape) 69 | for i in range(3): 70 | axs[i].plot(infstate_euler[:,i],color = 'red',linewidth=0.9) 71 | axs[i].legend(["Integrated_"+legend_list[i],"gt_"+legend_list[i],"AirIMU_"+legend_list[i]]) 72 | plt.tight_layout() 73 | if save_folder is not None: 74 | plt.savefig(os.path.join(save_folder, save_prefix+ "_orientation_compare.png"), dpi = 300) 75 | plt.show() 76 | 77 | 78 | def visualize_trajectory(save_prefix, save_folder, outstate, infstate): 79 | gt_x, gt_y, gt_z = torch.split(outstate["poses_gt"][0].cpu(), 1, dim=1) 80 | rawTraj_x, rawTraj_y, rawTraj_z = torch.split(outstate["poses"][0].cpu(), 1, dim=1) 81 | airTraj_x, airTraj_y, airTraj_z = torch.split(infstate["poses"][0].cpu(), 1, dim=1) 82 | 83 | fig, ax = plt.subplots() 84 | ax.plot(rawTraj_x, rawTraj_y, label="Raw") 85 | ax.plot(airTraj_x, airTraj_y, label="AirIMU") 86 | ax.plot(gt_x , gt_y , label="Ground Truth") 87 | 88 | ax.set_xlabel('X axis') 
89 | ax.set_ylabel('Y axis') 90 | ax.legend() 91 | ax.set_aspect('equal', adjustable='box') 92 | 93 | plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_xy.png"), dpi = 300) 94 | plt.close() 95 | 96 | ########################################################### 97 | 98 | fig, ax = plt.subplots() 99 | ax.plot(rawTraj_x, rawTraj_z, label="Raw") 100 | ax.plot(airTraj_x, airTraj_z, label="AirIMU") 101 | ax.plot(gt_x , gt_z , label="Ground Truth") 102 | 103 | ax.set_xlabel('X axis') 104 | ax.set_ylabel('Z axis') 105 | ax.legend() 106 | ax.set_aspect('equal', adjustable='box') 107 | plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_xz.png"), dpi = 300) 108 | plt.close() 109 | 110 | ########################################################### 111 | 112 | fig, ax = plt.subplots() 113 | ax.plot(rawTraj_y, rawTraj_z, label="Raw") 114 | ax.plot(airTraj_y, airTraj_z, label="AirIMU") 115 | ax.plot(gt_y , gt_z , label="Ground Truth") 116 | 117 | ax.set_xlabel('Y axis') 118 | ax.set_ylabel('Z axis') 119 | ax.legend() 120 | ax.set_aspect('equal', adjustable='box') 121 | plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_yz.png"), dpi = 300) 122 | plt.close() 123 | 124 | ########################################################### 125 | 126 | fig = plt.figure() 127 | ax = fig.add_subplot(111, projection='3d') 128 | 129 | elevation_angle = 20 # Change the elevation angle (view from above/below) 130 | azimuthal_angle = 30 # Change the azimuthal angle (rotate around z-axis) 131 | 132 | ax.view_init(elevation_angle, azimuthal_angle) # Set the view 133 | 134 | # Plotting the ground truth and inferred poses 135 | ax.plot(rawTraj_x, rawTraj_y, rawTraj_z, label="Raw") 136 | ax.plot(airTraj_x, airTraj_y, airTraj_z, label="AirIMU") 137 | ax.plot(gt_x , gt_y , gt_z , label="Ground Truth") 138 | 139 | # Adding labels 140 | ax.set_xlabel('X axis') 141 | ax.set_ylabel('Y axis') 142 | ax.set_zlabel('Z axis') 143 | ax.legend() 144 | 145 | 
plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_3d.png"), dpi = 300) 146 | plt.close() 147 | 148 | 149 | def box_plot_wrapper(ax, data, edge_color, fill_color, **kwargs): 150 | bp = ax.boxplot(data, **kwargs) 151 | 152 | for element in ['boxes', 'whiskers', 'fliers', 'means', 'medians', 'caps']: 153 | plt.setp(bp[element], color=edge_color) 154 | 155 | for patch in bp['boxes']: 156 | patch.set(facecolor=fill_color) 157 | 158 | return bp 159 | 160 | 161 | def plot_boxes(folder, input_data, metrics, show_metrics): 162 | fig, ax = plt.subplots(dpi=300) 163 | raw_ticks = [_-0.12 for _ in range(1, len(metrics) + 1)] 164 | air_ticks = [_+0.12 for _ in range(1, len(metrics) + 1)] 165 | label_ticks = [_ for _ in range(1, len(metrics) + 1)] 166 | 167 | raw_data = [input_data[metric + "(raw)" ] for metric in metrics] 168 | air_data = [input_data[metric + "(AirIMU)"] for metric in metrics] 169 | 170 | # ax.boxplot(data, patch_artist=True, positions=ticks, widths=.2) 171 | box_plot_wrapper(ax, raw_data, edge_color="black", fill_color="royalblue", positions=raw_ticks, patch_artist=True, widths=.2) 172 | box_plot_wrapper(ax, air_data, edge_color="black", fill_color="gold", positions=air_ticks, patch_artist=True, widths=.2) 173 | ax.set_xticks(label_ticks) 174 | ax.set_xticklabels(show_metrics) 175 | 176 | # Create color patches for legend 177 | gold_patch = mpatches.Patch(color='gold', label='AirIMU') 178 | royalblue_patch = mpatches.Patch(color='royalblue', label='Raw') 179 | ax.legend(handles=[gold_patch, royalblue_patch]) 180 | 181 | plt.savefig(os.path.join(folder, "Metrics.png"), dpi = 300) 182 | plt.close() 183 | 184 | --------------------------------------------------------------------------------