├── LICENSE
├── README.md
├── configs
│   ├── datasets
│   │   ├── BaselineEuroc
│   │   │   └── Euroc_1000.conf
│   │   ├── KITTIDataset
│   │   │   └── KITTI26.conf
│   │   └── TUMVI
│   │       └── room_1000.conf
│   ├── exp
│   │   ├── EuRoC
│   │   │   ├── baseline.conf
│   │   │   └── codenet.conf
│   │   ├── KITTI
│   │   │   └── codenet.conf
│   │   └── TUMVI
│   │       └── codenet.conf
│   └── train
│       └── train.conf
├── datasets
│   ├── EuRoCdataset.py
│   ├── KITTIdataset.py
│   ├── TUMdataset.py
│   ├── __init__.py
│   ├── dataset.py
│   └── dataset_utils.py
├── doc
│   ├── alto.gif
│   └── model.png
├── eval.py
├── evaluation
│   └── evaluate_state.py
├── inference.py
├── model
│   ├── __init__.py
│   ├── cnn.py
│   ├── code.py
│   ├── loss_func.py
│   ├── losses.py
│   ├── net.py
│   └── others.py
├── train.py
└── utils
    ├── __init__.py
    ├── deferentiate_vel.py
    ├── integrate.py
    ├── utils.py
    ├── visualize.py
    └── visualize_state.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, Carnegie Mellon University
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AirIMU : Learning Uncertainty Propagation for Inertial Odometry
2 | [License](./LICENSE)
3 | [Video](https://www.youtube.com/watch?v=fTX1u-e7wtU)
4 | [arXiv](https://arxiv.org/abs/2310.04874)
5 | [Project page](https://airimu.github.io/)
6 |
7 |
8 | 
9 | ## 📢 Latest News
10 | - [2025-02-01] Introducing Our New Project!
11 | 🚀 [**AirIO : Learning Inertial Odometry with Enhanced IMU Feature Observability**](https://github.com/Air-IO/Air-IO)
12 | ```
13 | AirIO achieves up to 86.6% performance boost over SOTA methods:
14 |
15 | - ✅ Tailored specifically for drones
16 | - ✅ No external sensors or control inputs required
17 | - ✅ Generalizes to unseen trajectories
18 | - ✅ Explicitly encodes UAV attitude and predicts velocity in body-frame representation
19 | ```
20 |
21 | ## Installation
22 |
23 | This work is based on PyPose. Follow the installation instructions and install the newest release of PyPose:
24 | https://github.com/pypose/pypose
25 |
26 |
27 | ## Dataset
28 | > **Note**: Remember to reset the `data_root` in `configs/datasets/${DATASET}/${DATASET}.conf`.
29 |
30 | Download the EuRoC dataset from:
31 | https://projects.asl.ethz.ch/datasets/doku.php?id=kmavvisualinertialdatasets
32 |
33 | Download the TUM VI dataset from:
34 | https://cvg.cit.tum.de/data/datasets/visual-inertial-dataset
35 |
36 | Download the KITTI dataset from:
37 | https://www.cvlibs.net/datasets/kitti/
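
The `data_root` placeholders in the dataset configs (`PATH_TO_EuRoC`, `PATH_TO_TUM`, `PATH_TO_KITTI_2011_09_26`) can be edited by hand. As an alternative, here is a minimal sketch, not part of this repository, that rewrites every `data_root` entry in the text of a dataset config. The helper name `set_data_root` and the path `/data/euroc` are hypothetical, and the regular expression assumes the path itself contains no comma.

```python
import re

def set_data_root(conf_text: str, new_root: str) -> str:
    """Replace the value of every `data_root:` entry, keeping a trailing comma if present."""
    pattern = re.compile(r"^([ \t]*data_root:[ \t]*)[^,\n]+(,?)[ \t]*$", re.MULTILINE)
    # A function replacement avoids backslash-escape surprises in new_root.
    return pattern.sub(lambda m: f"{m.group(1)}{new_root}{m.group(2)}", conf_text)

# Small in-memory example; on a real checkout, read the .conf file,
# apply set_data_root, and write the result back.
sample = "name: Euroc\ndata_root: PATH_TO_EuRoC\nstep_size: 10\n"
print(set_data_root(sample, "/data/euroc"))
```

Every split in a config file (`train`, `test`, `eval`, `inference`) carries its own `data_root`, so replacing all occurrences at once keeps them consistent.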
38 |
39 | ## Pretrained Model
40 | > **Note**: Download our pretrained checkpoints from the links below.
41 |
42 |
43 | [KITTI](https://github.com/sleepycan/AirIMU/releases/download/pretrained_model/KITTI_odom_model.zip)
44 |
45 | [EuRoC](https://github.com/sleepycan/AirIMU/releases/download/pretrained_model_euroc/EuRoCWholeaug.zip)
46 | ## Train
47 |
48 | The easiest way to start training is to use an existing configuration.
49 | > **Note**: You can also create your own configuration file for a different dataset and set the parameters accordingly.
50 |
51 | ```
52 | python train.py --config configs/exp/EuRoC/codenet.conf
53 | ```
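
Each experiment file pulls shared defaults from `configs/train/train.conf` with `include` and then sets its own keys. In HOCON, a key written after an include overrides the included value. The sketch below mimics that merge order with plain dictionaries, using a subset of the values from `configs/train/train.conf` and `configs/exp/EuRoC/codenet.conf`. It only illustrates the override rule and is not how the repository parses configs; the function name `resolve` is hypothetical.

```python
# Shared defaults, copied from configs/train/train.conf (subset).
train_defaults = {
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "batch_size": 128,
    "max_epoches": 100,
    "cov_weight": "None",  # bare token `None` in train.conf; kept as a string here
    "sampling": 50,
}

# Keys set after the include in configs/exp/EuRoC/codenet.conf (subset).
euroc_codenet = {
    "lr": 1e-3,
    "batch_size": 128,
    "rot_weight": 1e3,
    "cov_weight": 1e-4,
    "network": "codenet",
}

def resolve(defaults: dict, overrides: dict) -> dict:
    """Later keys win, as with keys written after a HOCON include."""
    return {**defaults, **overrides}

train_conf = resolve(train_defaults, euroc_codenet)
print(train_conf["cov_weight"], train_conf["weight_decay"])
```

A custom experiment file therefore only needs to list the keys that differ from `train.conf`.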
54 |
55 | More specific options:
56 |
57 | ```
58 | usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--load_ckpt] [--log]
59 |
60 | optional arguments:
61 | -h, --help show this help message and exit
62 | --config CONFIG config file path
63 | --device DEVICE cuda or cpu, Default is cuda:0
64 | --load_ckpt If True, try to load the newest.ckpt in the exp_dir specified in the config file.
65 | --log If True, save the metadata with wandb. Default is True
66 | ```
67 |
68 | ## Evaluation
69 |
70 | To evaluate the model and generate the network inference file `net_output.pickle`, run the following command:
71 | ```
72 | python inference.py --config configs/exp/EuRoC/codenet.conf
73 | ```
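
It can help to check what was written to `net_output.pickle` before going further. This README does not document the layout of that file, so the sketch below assumes only that it is a pickled Python object: it loads the file and reports the type of each top-level entry. The helper name `summarize_pickle` is hypothetical. If the file stores PyTorch tensors, `torch` must be importable for `pickle.load` to succeed, and only trusted files should be unpickled.

```python
import pickle

def summarize_pickle(path: str) -> dict:
    """Load a pickle file and map each top-level key to the type name of its value."""
    with open(path, "rb") as f:
        obj = pickle.load(f)
    if isinstance(obj, dict):
        return {str(k): type(v).__name__ for k, v in obj.items()}
    # Not a dict: report the type of the root object instead.
    return {"<root>": type(obj).__name__}
```

For example, `summarize_pickle("path/to/net_output_directory/net_output.pickle")` returns a small dictionary that can be printed to see which entries the file contains.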
74 |
75 |
76 |
77 | To assess model performance from `net_output.pickle`, run the evaluation tool with the following command.
78 | > **Note**: Make sure to replace path/to/net_output_directory with the directory path where your network output pickle file is stored.
79 |
80 | ```
81 | python evaluation/evaluate_state.py --dataconf configs/datasets/${DATASET}/${DATASET}.conf --exp path/to/net_output_directory
82 | ```
83 |
84 |
85 | More specific options for the evaluation tool:
86 |
87 | ```
88 | usage: evaluation/evaluate_state.py [-h] [--dataconf] [--device] [--exp] [--seqlen] [--savedir] [--usegtrot] [--mask]
89 |
90 | optional arguments:
91 | -h, --help show this help message and exit
92 | --dataconf dataset config file path
93 | --device cuda or cpu, Default is cuda:0
94 | --exp the directory path where your network output pickle file is stored
95 | --seqlen the length of the integration sequence
96 | --savedir the save directory for the evaluation results, default path is "./result/loss_result"
97 | --usegtrot use ground truth rotation for gravity compensation, default is true
98 | --mask mask the segments if needed.
99 | ```
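
The `window_size` and `step_size` fields in the dataset configs, and the `duration` and `step_size` arguments of `SeqDataset` in `datasets/dataset.py`, appear to describe a sliding window over each sequence. The sketch below reproduces the list comprehension that `SeqDataset.__init__` uses to build its index map, as a standalone function. The name `window_index_map` is hypothetical, and the optional tail window added when `drop_last` is false is left out.

```python
def window_index_map(seq_len: int, window: int, step: int) -> list:
    """Start/end index pairs, mirroring the range() call in SeqDataset.__init__."""
    return [[i, i + window] for i in range(0, seq_len - window, step)]

# Non-overlapping windows, as in the EuRoC `inference` split (window 1000, step 1000).
print(window_index_map(3500, 1000, 1000))
# A sequence no longer than one window yields no windows at all.
print(window_index_map(1000, 1000, 200))
```

With a step smaller than the window, as in the `train` splits (`step_size: 10`), consecutive windows overlap heavily, which multiplies the number of training samples drawn from each sequence.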
100 |
101 |
102 |
103 |
104 |
105 | ### Cite Our Work
106 |
107 | Thanks for using our work. You can cite it as:
108 |
109 | ```bib
110 | @article{qiu2023airimu,
111 | title={AirIMU: Learning Uncertainty Propagation for Inertial Odometry},
112 | author={Yuheng Qiu and Chen Wang and Can Xu and Yutian Chen and Xunfei Zhou and Youjie Xia and Sebastian Scherer},
113 | year={2023},
114 | eprint={2310.04874},
115 | archivePrefix={arXiv},
116 | primaryClass={cs.RO}
117 | }
118 | ```
119 |
--------------------------------------------------------------------------------
/configs/datasets/BaselineEuroc/Euroc_1000.conf:
--------------------------------------------------------------------------------
1 | train:
2 | {
3 | mode: train
4 | calib: False
5 | data_list:
6 | [
7 | {name: Euroc
8 | window_size: 1000
9 | step_size: 10
10 | data_root: PATH_TO_EuRoC
11 | data_drive: [V1_02_medium, V2_01_easy, V2_03_difficult, MH_03_medium, MH_05_difficult, MH_01_easy]
12 | },
13 | ]
14 | }
15 |
16 | test:
17 | {
18 | mode: test
19 | calib: False
20 | data_list:
21 | [
22 | {name: Euroc
23 | window_size: 1000
24 | step_size: 200
25 | data_root: PATH_TO_EuRoC
26 | data_drive: [V1_02_medium, V2_01_easy, V2_03_difficult, MH_03_medium, MH_05_difficult, MH_01_easy]
27 | },
28 | ]
29 | }
30 |
31 | eval:
32 | {
33 | mode: evaluate
34 | calib: False
35 | data_list:
36 | [{
37 | name: Euroc
38 | window_size: 1000
39 | step_size: 200
40 | data_root: PATH_TO_EuRoC
41 | data_drive: [MH_02_easy, MH_04_difficult, V1_03_difficult, V2_02_medium, V1_01_easy]
42 | },
43 | ]
44 |
45 | }
46 |
47 |
48 | inference:
49 | {
50 | mode: evaluate
51 | calib: False
52 | data_list:
53 | [{
54 | name: Euroc
55 | window_size: 1000
56 | step_size: 1000
57 | data_root: PATH_TO_EuRoC
58 | data_drive: [MH_02_easy, MH_04_difficult, V1_03_difficult, V2_02_medium, V1_01_easy, V1_02_medium, V2_01_easy, V2_03_difficult, MH_03_medium, MH_05_difficult, MH_01_easy]
59 | },
60 | ]
61 | }
62 |
--------------------------------------------------------------------------------
/configs/datasets/KITTIDataset/KITTI26.conf:
--------------------------------------------------------------------------------
1 | train:
2 | {
3 | mode: train
4 | calib: False
5 | data_list:
6 | [
7 | {
8 | name: KITTI,
9 | data_root: PATH_TO_KITTI_2011_09_26,
10 | data_drive: ["0027", "0059", "0023", "0070", "0095", "0064", "0091", "0051", "0104", "0036", "0101", "0032", "0061", "0084", "0013", "0096", "0005", "0056", "0001", "0019", "0028", "0086", "0015", "0046", "0011", "0009", "0117", "0087"]
11 | window_size: 50,
12 | step_size: 25
13 | }
14 | ]
15 | }
16 |
17 | test:
18 | {
19 | mode: evaluate
20 | calib: False
21 | data_list:
22 | [
23 | {
24 | name: KITTI,
25 | data_root: PATH_TO_KITTI_2011_09_26,
26 | data_drive: ["0018", "0106", "0029", "0014", "0039", "0035", "0022"],
27 | window_size: 50,
28 | step_size: 25
29 | }
30 | ]
31 | }
32 |
33 | eval:
34 | {
35 | mode: evaluate
36 | calib: False
37 | data_list:
38 | [
39 | {
40 | name: KITTI,
41 | data_root: PATH_TO_KITTI_2011_09_26,
42 | data_drive: ["0018", "0106", "0029", "0014", "0039", "0035", "0022"],
43 | window_size: 50,
44 | step_size: 25
45 | }
46 | ]
47 | }
48 |
--------------------------------------------------------------------------------
/configs/datasets/TUMVI/room_1000.conf:
--------------------------------------------------------------------------------
1 | train:
2 | {
3 | mode: train
4 | calib: False
5 | data_list:
6 | [
7 | {name: TUMVI
8 | window_size: 1000
9 | step_size: 10
10 | data_root: PATH_TO_TUM
11 | data_drive: [dataset-room1_512_16, dataset-room3_512_16, dataset-room5_512_16]
12 | },
13 | ]
14 | }
15 |
16 | test:
17 | {
18 | mode: test
19 | calib: False
20 | data_list:
21 | [
22 | {name: TUMVI
23 | window_size: 1000
24 | step_size: 100
25 | data_root: PATH_TO_TUM
26 | data_drive: [dataset-room1_512_16, dataset-room3_512_16, dataset-room5_512_16]
27 | },
28 | ]
29 | }
30 |
31 | eval:
32 | {
33 | mode: evaluate
34 | calib: False
35 | data_list:
36 | [
37 | {name: TUMVI
38 | window_size: 1000
39 | step_size: 100
40 | data_root: PATH_TO_TUM
41 | data_drive: [dataset-room2_512_16, dataset-room4_512_16, dataset-room6_512_16]
42 | },
43 | ]
44 | }
45 |
46 | inference:
47 | {
48 | mode: evaluate
49 | calib: False
50 | data_list:
51 | [
52 | {name: TUMVI
53 | window_size: 1000
54 | step_size: 100
55 | data_root: PATH_TO_TUM
56 | data_drive: [dataset-room2_512_16, dataset-room4_512_16, dataset-room6_512_16, dataset-room1_512_16, dataset-room3_512_16, dataset-room5_512_16]
57 | },
58 | ]
59 | }
--------------------------------------------------------------------------------
/configs/exp/EuRoC/baseline.conf:
--------------------------------------------------------------------------------
1 | general:
2 | {
3 | exp_dir: experiments/EuRoC
4 | }
5 |
6 | dataset:
7 | {
8 | include "../../datasets/BaselineEuroc/Euroc_1000.conf"
9 | }
10 |
11 | train:
12 | {
13 | include "../../train/train.conf"
14 | batch_size: 128
15 | rot_weight: 1e2
16 | pos_weight: 1e2
17 | vel_weight: 1e1
18 | propcov: False
19 |
20 | network: iden
21 | sampling: 50
22 | gtrot:True
23 |
24 | sampling: 50
25 | loss: Huber_loss0005
26 | rotloss: Huber_loss005
27 | lr: 1e-2
28 | }
--------------------------------------------------------------------------------
/configs/exp/EuRoC/codenet.conf:
--------------------------------------------------------------------------------
1 | general:
2 | {
3 | exp_dir: experiments/EuRoC
4 | }
5 |
6 | dataset:
7 | {
8 | include "../../datasets/BaselineEuroc/Euroc_1000.conf"
9 | collate: padding9
10 | }
11 |
12 | train:
13 | {
14 | include "../../train/train.conf"
15 | lr: 1e-3
16 | batch_size: 128
17 | rot_weight: 1e3
18 | pos_weight: 1e2
19 | vel_weight: 1e1
20 | cov_weight: 1e-4
21 |
22 | network: codenet
23 | gtrot:True
24 | propcov:True
25 | covaug:True
26 |
27 | sampling: 50
28 | loss: Huber_loss005
29 | rotloss: Huber_loss005
30 | }
--------------------------------------------------------------------------------
/configs/exp/KITTI/codenet.conf:
--------------------------------------------------------------------------------
1 | general:
2 | {
3 | exp_dir: experiments/KITTI
4 | }
5 |
6 | dataset:
7 | {
8 | include "../../datasets/KITTIDataset/KITTI26.conf"
9 | }
10 |
11 | train:
12 | {
13 | network: codenetkitti
14 | lr: 1e-3
15 | weight_decay: 1e-3,
16 | batch_size: 128
17 | min_lr: 1e-7
18 | max_epoches: 500
19 | patience: 50
20 | factor: 0.5
21 |
22 | save_freq: 5
23 | eval_freq: 1
24 |
25 | gtinit: True
26 | propcov: True
27 | gtrot:True
28 |
29 | rot_weight: 1e1
30 | pos_weight: 1e1
31 | vel_weight: 1e1
32 | cov_weight: 1e1
33 |
34 | sampling: False
35 | loss: Huber_loss0005
36 | rotloss: Huber_loss005
37 | }
38 |
--------------------------------------------------------------------------------
/configs/exp/TUMVI/codenet.conf:
--------------------------------------------------------------------------------
1 | general:
2 | {
3 | exp_dir: experiments/TUMVI
4 | }
5 |
6 | dataset:
7 | {
8 |
9 | include "../../datasets/TUMVI/room_1000.conf"
10 | collate: padding9
11 | }
12 |
13 | train:
14 | {
15 | include "../../train/train.conf"
16 | lr: 1e-3
17 | batch_size: 128
18 | rot_weight: 1e2
19 | pos_weight: 1e2
20 | vel_weight: 1e2
21 | cov_weight: 1e-4
22 |
23 | network: codenet
24 | covaug:True
25 | gtrot:True
26 | propcov:True
27 |
28 | sampling: 50
29 | loss: Huber_loss005
30 | rotloss: Huber_loss005
31 | }
--------------------------------------------------------------------------------
/configs/train/train.conf:
--------------------------------------------------------------------------------
1 | lr: 1e-3
2 | weight_decay: 1e-4
3 |
4 | batch_size: 128
5 | min_lr: 1e-5
6 | max_epoches: 100
7 | patience: 10
8 | factor: 0.1
9 |
10 | save_freq: 5
11 | eval_freq: 1
12 |
13 | cov_weight: None
14 | sampling: 50
15 | gtinit: True
16 |
--------------------------------------------------------------------------------
/datasets/EuRoCdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pypose as pp
5 | from utils import qinterp, lookAt
6 | from .dataset import Sequence
7 |
8 | class Euroc(Sequence):
9 | """
10 | Output:
11 | acce: the acceleration in **world frame**
12 | """
13 | def __init__(self, data_root, data_name, intepolate = True, calib = False, load_vicon = False, glob_coord=False, **kwargs):
14 | super(Euroc, self).__init__()
15 | (
16 | self.data_root, self.data_name,
17 | self.data,
18 | self.ts,
19 | self.targets,
20 | self.orientations,
21 | self.gt_pos,
22 | self.gt_ori,
23 | ) = (data_root, data_name, dict(), None, None, None, None, None)
24 |
25 | self.camera_ext_R = pp.mat2SO3(np.array([[0.0148655429818, -0.999880929698, 0.00414029679422,],
26 | [0.999557249008, 0.0149672133247, 0.025715529948, ],
27 | [-0.0257744366974, 0.00375618835797, 0.999660727178,],]), check= False)
28 | self.camera_ext_t = torch.tensor(np.array([-0.0216401454975, -0.064676986768, 0.00981073058949,]))
29 | self.vicon_ext_R = pp.mat2SO3(np.array([[0.33638, -0.01749, 0.94156],[-0.02078, -0.99972, -0.01114],[0.94150, -0.01582, -0.33665]]), check= False)
30 | self.vicon_ext_t = torch.tensor(np.array([0.06901, -0.02781,-0.12395]))
31 | self.ext_T = pp.SE3(torch.cat((self.camera_ext_t, self.camera_ext_R)))
32 | self.gravity = torch.tensor([0., 0., 9.81007], dtype=torch.float64)
33 |
34 | data_path = os.path.join(data_root, data_name)
35 | self.load_imu(data_path)
36 | self.load_gt(data_path)
37 | if load_vicon:
38 | self.load_vicon(data_path)
39 |
40 | # EuRoC requires interpolation: ground truth and IMU are sampled at different times
41 | if intepolate:
42 | t_start = np.max([self.data['gt_time'][0], self.data['time'][0]])
43 | t_end = np.min([self.data['gt_time'][-1], self.data['time'][-1]])
44 |
45 | # find the index of the start and end
46 | idx_start_imu = np.searchsorted(self.data['time'], t_start)
47 | idx_start_gt = np.searchsorted(self.data['gt_time'], t_start)
48 |
49 | idx_end_imu = np.searchsorted(self.data['time'], t_end, 'right')
50 | idx_end_gt = np.searchsorted(self.data['gt_time'], t_end, 'right')
51 |
52 | for k in ['gt_time', 'pos', 'quat', 'velocity', 'b_acc', 'b_gyro']:
53 | self.data[k] = self.data[k][idx_start_gt:idx_end_gt]
54 |
55 | for k in ['time', 'acc', 'gyro']:
56 | self.data[k] = self.data[k][idx_start_imu:idx_end_imu]
57 |
58 | ## start interpolation
59 | self.data["gt_orientation"] = self.interp_rot(self.data['time'], self.data['gt_time'], self.data['quat'])
60 | self.data["gt_translation"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data['pos'])
61 |
62 | self.data["b_acc"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data["b_acc"])
63 | self.data["b_gyro"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data["b_gyro"])
64 | self.data["velocity"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data["velocity"])
65 |
66 | else:
67 | self.data["gt_orientation"] = pp.SO3(torch.tensor(self.data['pose'][:,3:]))
68 | self.data['gt_translation'] = torch.tensor(self.data['pose'][:,:3])
69 |
70 | # move the time to torch
71 | self.data["time"] = torch.tensor(self.data["time"])
72 | self.data["gt_time"] = torch.tensor(self.data["gt_time"])
73 | self.data['dt'] = (self.data["time"][1:] - self.data["time"][:-1])[:,None]
74 | self.data["mask"] = torch.ones(self.data["time"].shape[0], dtype=torch.bool)
75 |
76 | # Calibration for evaluation
77 | if calib == "head":
78 | self.data["gyro"] = torch.tensor(self.data["gyro"]) - self.data["b_gyro"][0]
79 | self.data["acc"] = torch.tensor(self.data["acc"]) - self.data["b_acc"][0]
80 | elif calib == "full":
81 | self.data["gyro"] = torch.tensor(self.data["gyro"]) - self.data["b_gyro"]
82 | self.data["acc"] = torch.tensor(self.data["acc"]) - self.data["b_acc"]
83 | elif calib == "aligngravity":
84 | ## Find the nearest static point
85 | nl_point = np.where(self.data['velocity'].norm(dim=-1) < 0.001)[0][0]
86 | avg_acc = self.data['acc'][nl_point+10:nl_point+100].mean(axis=-2)
87 | avg_gyro = self.data['gyro'][nl_point+10:nl_point+100].mean(axis=-2)
88 |
89 | gr = lookAt(avg_acc)
90 | g_IMU = gr.T @ self.gravity
91 | gl_acc_b = avg_acc - g_IMU.numpy()
92 | gl_gyro_b = avg_gyro
93 |
94 | self.data["acc"] = torch.tensor(self.data["acc"]) - gl_acc_b
95 | self.data["gyro"] = torch.tensor(self.data["gyro"]) - gl_gyro_b
96 | else:
97 | self.data["gyro"] = torch.tensor(self.data["gyro"])
98 | self.data["acc"] = torch.tensor(self.data["acc"])
99 |
100 | # rotate acc and gyro from the body frame into the global (world) frame.
101 | if glob_coord:
102 | self.data['gyro'] = self.data["gt_orientation"] * self.data['gyro']
103 | self.data['acc'] = self.data["gt_orientation"] * self.data['acc']
104 |
105 | print("loaded: ", data_path, "calib: ", calib, "interpolate: ", intepolate)
106 |
107 | def get_length(self):
108 | return self.data['time'].shape[0]
109 |
110 | def load_imu(self, folder):
111 | imu_data = np.loadtxt(os.path.join(folder, "mav0/imu0/data.csv"), dtype=float, delimiter=',')
112 | self.data["time"] = imu_data[:,0] / 1e9
113 | self.data["gyro"] = imu_data[:,1:4] # w_RS_S_x [rad s^-1],w_RS_S_y [rad s^-1],w_RS_S_z [rad s^-1]
114 | self.data["acc"] = imu_data[:,4:]# acc a_RS_S_x [m s^-2],a_RS_S_y [m s^-2],a_RS_S_z [m s^-2]
115 |
116 | def load_gt(self, folder):
117 | gt_data = np.loadtxt(os.path.join(folder, "mav0/state_groundtruth_estimate0/data.csv"), dtype=float, delimiter=',')
118 | self.data["gt_time"] = gt_data[:,0] / 1e9
119 | self.data["pos"] = gt_data[:,1:4]
120 | self.data['quat'] = gt_data[:,4:8] # w, x, y, z
121 | self.data["b_acc"] = gt_data[:,-3:]
122 | self.data["b_gyro"] = gt_data[:,-6:-3]
123 | self.data["velocity"] = gt_data[:,-9:-6]
124 |
125 | def interp_rot(self, time, opt_time, quat):
126 | # interpolation in the log space
127 | imu_dt = torch.Tensor(time - opt_time[0])
128 | gt_dt = torch.Tensor(opt_time - opt_time[0])
129 |
130 | quat = torch.tensor(quat)
131 | quat = qinterp(quat, gt_dt, imu_dt).double()
132 | self.data['rot_wxyz'] = quat
133 | rot = torch.zeros_like(quat)
134 | rot[:,3] = quat[:,0]
135 | rot[:,:3] = quat[:,1:]
136 |
137 | return pp.SO3(rot)
138 |
139 | def interp_xyz(self, time, opt_time, xyz):
140 |
141 | intep_x = np.interp(time, xp=opt_time, fp = xyz[:,0])
142 | intep_y = np.interp(time, xp=opt_time, fp = xyz[:,1])
143 | intep_z = np.interp(time, xp=opt_time, fp = xyz[:,2])
144 |
145 | inte_xyz = np.stack([intep_x, intep_y, intep_z]).transpose()
146 | return torch.tensor(inte_xyz)
147 |
148 |
--------------------------------------------------------------------------------
/datasets/KITTIdataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pypose as pp
4 |
5 | import pykitti
6 | from datetime import datetime
7 |
8 | from .dataset import Sequence
9 | class KITTI(Sequence):
10 | def __init__(self, data_root, data_drive, **kwargs) -> None:
11 | (
12 | self.data_root, self.data_date, self.drive,
13 | ## self.data
14 | # "time" - (nx1)
15 | # "acc" - (nx3)
16 | # "gyro" - (nx3)
17 | # "dt" - (nx1)
18 | # "gt_translation" - (nx3)
19 | # "gt_orientation" - SO3 (nx4 quaternion)
20 | # "velocity" - (nx3)
21 | # "mask" -
22 | ##
23 | self.data,
24 | self.ts, # Not used
25 | self.targets, # Not used
26 | self.orientations, # Not used
27 | self.gt_pos, # Tensor (nx3)
28 | self.gt_ori, # SO3
29 | ) = "/".join(data_root.split("/")[:-1]), data_root.split("/")[-1], data_drive, dict(), None, None, None, None, None
30 | print(f"Loading KITTI {data_root} @ {data_drive}")
31 | self.load_data()
32 | self.data_name = self.data_date + "_" + self.drive
33 | print(f"KITTI Sequence {data_drive} - length: {self.get_length()}")
34 |
35 | def get_length(self):
36 | return self.data["time"].size(0) - 1
37 |
38 | def load_data(self):
39 | raw_data = pykitti.raw(self.data_root, self.data_date, self.drive)
40 | raw_len = len(raw_data.timestamps) - 1
41 |
42 | self.data["time"] = torch.tensor(
43 | [datetime.timestamp(raw_data.timestamps[i]) for i in range(raw_len + 1)],
44 | dtype=torch.double
45 | ).unsqueeze(-1).double()
46 | self.data["acc"] = torch.tensor(
47 | [
48 | [raw_data.oxts[i].packet.ax,
49 | raw_data.oxts[i].packet.ay,
50 | raw_data.oxts[i].packet.az]
51 | for i in range(raw_len)
52 | ]
53 | ).double()
54 | self.data["gyro"] = torch.tensor(
55 | [
56 | [raw_data.oxts[i].packet.wx,
57 | raw_data.oxts[i].packet.wy,
58 | raw_data.oxts[i].packet.wz]
59 | for i in range(raw_len)
60 | ]
61 | ).double()
62 | self.data["dt"] = self.data["time"][1:] - self.data["time"][:-1]
63 |
64 | self.data["gt_translation"] = torch.tensor(
65 | np.array([raw_data.oxts[i].T_w_imu[0:3, 3]
66 | for i in range(raw_len)])
67 | ).double()
68 |
69 | self.data["gt_orientation"] = pp.euler2SO3(torch.tensor(
70 | [
71 | [raw_data.oxts[i].packet.roll,
72 | raw_data.oxts[i].packet.pitch,
73 | raw_data.oxts[i].packet.yaw]
74 | for i in range(raw_len)
75 | ]
76 | )).double()
77 | self.data["velocity"] = self.data["gt_orientation"] @ torch.tensor(
78 | [[raw_data.oxts[i].packet.vf,
79 | raw_data.oxts[i].packet.vl,
80 | raw_data.oxts[i].packet.vu]
81 | for i in range(raw_len)]
82 | ).double()
83 | self.data["mask"] = torch.ones(self.data["time"].shape[0], dtype=torch.bool)
84 | self.gt_pos = self.data["gt_translation"].clone()
85 | self.gt_ori = self.data["gt_orientation"].clone()
86 |
87 |
--------------------------------------------------------------------------------
/datasets/TUMdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pypose as pp
5 | from utils import qinterp
6 | from .dataset import Sequence
7 |
8 | class TUMVI(Sequence):
9 | """
10 | Output:
11 | acce: the acceleration in **world frame**
12 | """
13 | def __init__(self, data_root, data_name, intepolate = True, calib = False, glob_coord=False, **kwargs):
14 | super(TUMVI, self).__init__()
15 | (
16 | self.data_root, self.data_name,
17 | self.data,
18 | self.ts,
19 | self.targets,
20 | self.orientations,
21 | self.gt_pos,
22 | self.gt_ori,
23 | ) = (data_root, data_name, dict(), None, None, None, None, None)
24 | data_path = os.path.join(self.data_root, self.data_name)
25 | self.load_imu(data_path)
26 | self.load_gt(data_path)
27 |
28 | # interpolate the ground truth pose
29 | if intepolate:
30 | t_start = np.max([self.data['gt_time'][0], self.data['time'][0]])
31 | t_end = np.min([self.data['gt_time'][-1], self.data['time'][-1]])
32 |
33 | idx_start_imu = np.searchsorted(self.data['time'], t_start)
34 | idx_start_gt = np.searchsorted(self.data['gt_time'], t_start)
35 |
36 | idx_end_imu = np.searchsorted(self.data['time'], t_end, 'right')
37 | idx_end_gt = np.searchsorted(self.data['gt_time'], t_end, 'right')
38 |
39 | ## GT data
40 | for k in ['gt_time', 'pos', 'quat']:
41 | self.data[k] = self.data[k][idx_start_gt:idx_end_gt]
42 |
43 | # ## imu data
44 | for k in ['time', 'acc', 'gyro']:
45 | self.data[k] = self.data[k][idx_start_imu:idx_end_imu]
46 |
47 | ## start interpolation
48 | self.data["gt_orientation"] = self.interp_rot(self.data['time'], self.data['gt_time'], self.data['quat'])
49 | self.data["gt_translation"] = self.interp_xyz(self.data['time'], self.data['gt_time'], self.data['pos'])
50 | else:
51 | self.data["gt_orientation"] = pp.SO3(torch.tensor(self.data['pose'][:,3:]))
52 | self.data['gt_translation'] = torch.tensor(self.data['pose'][:,:3])
53 |
54 | # move the time to torch
55 | self.data["time"] = torch.tensor(self.data["time"])
56 | self.data["gt_time"] = torch.tensor(self.data["gt_time"])
57 | self.data['dt'] = (self.data["time"][1:] - self.data["time"][:-1])[:,None]
58 |
59 | ## The TUM VI ground truth has gaps where tracking was lost
60 | gt_indexing = torch.searchsorted(self.data['gt_time'], self.data['time']) # index of the first gt sample at or after each imu sample.
61 | time_dist = (self.data['time'] - self.data['gt_time'][gt_indexing]).abs()
62 | self.data["mask"] = time_dist < 0.01
63 |
64 | # Calibration for evaluation
65 | self.data["gyro"] = torch.tensor(self.data["gyro"])
66 | self.data["acc"] = torch.tensor(self.data["acc"])
67 |
68 | # rotate acc and gyro from the body frame into the global (world) frame.
69 | if glob_coord: # For the other methods
70 | self.data['gyro'] = self.data["gt_orientation"] * self.data['gyro']
71 | self.data['acc'] = self.data["gt_orientation"] * self.data['acc']
72 |
73 | print("loaded: ", data_path, "calib: ", calib, "interpolate: ", intepolate)
74 | # self.save_hdf5(data_path)
75 |
76 | def get_length(self):
77 | return self.data['time'].shape[0]
78 |
79 | def load_imu(self, folder):
80 | imu_data = np.loadtxt(os.path.join(folder, "mav0/imu0/data.csv"), dtype=float, delimiter=',')
81 | self.data["time"] = imu_data[:,0] / 1e9
82 | self.data["gyro"] = imu_data[:,1:4] # w_RS_S_x [rad s^-1],w_RS_S_y [rad s^-1],w_RS_S_z [rad s^-1]
83 | self.data["acc"] = imu_data[:,4:]# acc a_RS_S_x [m s^-2],a_RS_S_y [m s^-2],a_RS_S_z [m s^-2]
84 |
85 | def load_gt(self, folder):
86 | gt_data = np.loadtxt(os.path.join(folder, "mav0/mocap0/data.csv"), dtype=float, delimiter=',')
87 | self.data["gt_time"] = gt_data[:,0] / 1e9
88 | self.data["pos"] = gt_data[:,1:4]
89 | self.data['quat'] = gt_data[:,4:8] # w, x, y, z
90 | velo_data = np.loadtxt(os.path.join(folder, "mav0/mocap0/grad_velo.txt"), dtype=float)
91 | self.data["velocity"] = torch.tensor(velo_data[:,1:])
92 |
93 | def interp_rot(self, time, opt_time, quat):
94 | # interpolation in the log space
95 | imu_dt = torch.Tensor(time - opt_time[0])
96 | gt_dt = torch.Tensor(opt_time - opt_time[0])
97 |
98 | quat = torch.tensor(quat)
99 | quat = qinterp(quat, gt_dt, imu_dt).double()
100 | self.data['rot_wxyz'] = quat
101 | rot = torch.zeros_like(quat)
102 | rot[:,3] = quat[:,0]
103 | rot[:,:3] = quat[:,1:]
104 |
105 | return pp.SO3(rot)
106 |
107 | def interp_xyz(self, time, opt_time, xyz):
108 |
109 | intep_x = np.interp(time, xp=opt_time, fp = xyz[:,0])
110 | intep_y = np.interp(time, xp=opt_time, fp = xyz[:,1])
111 | intep_z = np.interp(time, xp=opt_time, fp = xyz[:,2])
112 | inte_xyz = np.stack([intep_x, intep_y, intep_z]).transpose()
113 |
114 | return torch.tensor(inte_xyz)
115 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import *
2 | from .dataset_utils import *
3 | from .EuRoCdataset import *
4 | from .KITTIdataset import *
5 | from .TUMdataset import *
--------------------------------------------------------------------------------
/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from abc import ABC, abstractmethod
3 |
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as Data
7 | from pyhocon import ConfigFactory
8 |
9 |
10 | class Sequence(ABC):
11 | # Dictionary to keep track of subclasses
12 | subclasses = {}
13 |
14 | def __init_subclass__(cls, **kwargs):
15 | super().__init_subclass__(**kwargs)
16 | cls.subclasses[cls.__name__] = cls
17 |
18 | class SeqDataset(Data.Dataset):
19 | def __init__(self, root, dataname, device = 'cpu', name='Nav', duration=200, step_size=200, mode='inference',
20 | drop_last = True, conf = {}):
21 | super().__init__()
22 |
23 | self.DataClass = Sequence.subclasses
24 |
25 | self.conf = conf
26 | self.seq = self.DataClass[name](root, dataname, **self.conf)
27 | self.data = self.seq.data
28 | self.seqlen = self.seq.get_length()-1
29 | self.gravity = conf.gravity if "gravity" in conf.keys() else 9.81007
30 | if duration is None: self.duration = self.seqlen
31 | else: self.duration = duration
32 |
33 | if step_size is None: self.step_size = self.seqlen
34 | else: self.step_size = step_size
35 |
36 | self.data['acc_cov'] = 0.08 * torch.ones_like(self.data['acc'])
37 | self.data['gyro_cov'] = 0.006 * torch.ones_like(self.data['gyro'])
38 |
39 | start_frame = 0
40 | end_frame = self.seqlen
41 |
42 | self.index_map = [[i, i + self.duration] for i in range(
43 | 0, end_frame - start_frame - self.duration, self.step_size)]
44 | if (self.index_map[-1][-1] < end_frame) and (not drop_last):
45 | self.index_map.append([self.index_map[-1][-1], end_frame])
46 |
47 | self.index_map = np.array(self.index_map)
48 |
49 | def __len__(self):
50 | return len(self.index_map)
51 |
52 | def __getitem__(self, i):
53 | frame_id, end_frame_id = self.index_map[i]
54 | return {
55 | 'dt': self.data['dt'][frame_id: end_frame_id],
56 | 'acc': self.data['acc'][frame_id: end_frame_id],
57 | 'gyro': self.data['gyro'][frame_id: end_frame_id],
58 | 'rot': self.data['gt_orientation'][frame_id: end_frame_id],
59 | 'gt_pos': self.data['gt_translation'][frame_id+1: end_frame_id+1],
60 | 'gt_rot': self.data['gt_orientation'][frame_id+1: end_frame_id+1],
61 | 'gt_vel': self.data['velocity'][frame_id+1: end_frame_id+1],
62 | 'init_pos': self.data['gt_translation'][frame_id][None, ...],
63 | 'init_rot': self.data['gt_orientation'][frame_id: end_frame_id],
64 | 'init_vel': self.data['velocity'][frame_id][None, ...],
65 | }
66 |
67 | def get_init_value(self):
68 | return {'pos': self.data['gt_translation'][:1],
69 | 'rot': self.data['gt_orientation'][:1],
70 | 'vel': self.data['velocity'][:1]}
71 |
72 | def get_mask(self):
73 | return self.data['mask']
74 |
75 | def get_gravity(self):
76 | return self.gravity
77 |
78 |
79 | class SeqInfDataset(SeqDataset):
80 | def __init__(self, root, dataname, inference_state, device = 'cpu', name='Nav', duration=200, step_size=200,
81 | drop_last = True, mode='inference', usecov = True, useraw = False,conf={}):
82 | super().__init__(root, dataname, device, name, duration, step_size, mode, drop_last, conf)
83 | self.data['acc'][:-1] += inference_state['correction_acc'].cpu()[0]
84 | self.data['gyro'][:-1] += inference_state['correction_gyro'].cpu()[0]
85 |
86 | if 'acc_cov' in inference_state.keys() and usecov:
87 | self.data['acc_cov'] = inference_state['acc_cov'][0]
88 |
89 | if 'gyro_cov' in inference_state.keys() and usecov:
90 | self.data['gyro_cov'] = inference_state['gyro_cov'][0]
91 |
92 |
93 | class SeqeuncesDataset(Data.Dataset):
94 | """
95 | Dataset for training and inference.
96 | 1. The features of the last time frame are dropped, since there is no ground truth pose or dt
97 | to integrate the IMU data of the last frame. The length of each sequence is therefore seq.get_length() - 1.
98 | """
99 | def __init__(self, data_set_config, mode = None, data_path = None, data_root = None, device= "cuda:0"):
100 | super(SeqeuncesDataset, self).__init__()
101 | (
102 | self.ts,
103 | self.dt,
104 | self.acc,
105 | self.gyro,
106 | self.gt_pos,
107 | self.gt_ori,
108 | self.gt_velo,
109 | self.index_map,
110 | self.seq_idx,
111 | ) = ([], [], [], [], [], [], [], [], 0)
112 | self.uni = torch.distributions.uniform.Uniform(-torch.ones(1), torch.ones(1))
113 | self.device = device
114 | self.conf = data_set_config
115 | self.gravity = data_set_config.gravity if "gravity" in data_set_config.keys() else 9.81007
116 | if mode is None:
117 | self.mode = data_set_config.mode
118 | else:
119 | self.mode = mode
120 |
121 | self.DataClass = Sequence.subclasses
122 |
123 | ## data_path provides a quick way to revisit a specific sequence, at the cost of some inconsistency
124 | if data_path is None:
125 | for conf in data_set_config.data_list:
126 | for path in conf.data_drive:
127 | self.construct_index_map(conf, conf["data_root"], path, self.seq_idx)
128 | self.seq_idx += 1
129 | ## data_root provides a quick way to introduce multiple sequences in the eval set, at the cost of some inconsistency
130 | elif data_root is None:
131 | conf = data_set_config.data_list[0]
132 | self.construct_index_map(conf, conf["data_root"], data_path, self.seq_idx)
133 | self.seq_idx += 1
134 | else:
135 | conf = data_set_config.data_list[0]
136 | self.construct_index_map(conf, data_root, data_path, self.seq_idx)
137 | self.seq_idx += 1
138 |
139 | def load_data(self, seq, start_frame, end_frame):
140 | if "time" in seq.data.keys():
141 | self.ts.append(seq.data["time"][start_frame:end_frame])
142 | self.acc.append(seq.data["acc"][start_frame:end_frame])
143 | self.gyro.append(seq.data["gyro"][start_frame:end_frame])
144 | # the ground truth state includes the initial state and the integrated states, so it has one more frame than the IMU data
145 | self.dt.append(seq.data["dt"][start_frame:end_frame+1])
146 | self.gt_pos.append(seq.data["gt_translation"][start_frame:end_frame+1])
147 | self.gt_ori.append(seq.data["gt_orientation"][start_frame:end_frame+1])
148 | self.gt_velo.append(seq.data["velocity"][start_frame:end_frame+1])
149 |
150 | def construct_index_map(self, conf, data_root, data_name, seq_id):
151 | seq = self.DataClass[conf.name](data_root, data_name, intepolate = True, **self.conf)
152 | seq_len = seq.get_length() -1 # abandon the last imu features
153 | window_size, step_size = conf.window_size, conf.step_size
154 | ## set the start and end frames according to the training mode
155 | start_frame, end_frame = 0, seq_len
156 |
157 | if self.mode == 'train_half':
158 | end_frame = np.floor(seq_len * 0.5).astype(int)
159 | elif self.mode == 'test_half':
160 | start_frame = np.floor(seq_len * 0.5).astype(int)
161 | elif self.mode == 'train_1m':
162 | end_frame = 12000
163 | elif self.mode == 'test_1m':
164 | start_frame = 12000
165 | elif self.mode == 'mini': # for debugging
166 | end_frame = 1000
167 |
168 | _duration = end_frame - start_frame
169 | if self.mode == "inference":
170 | window_size = seq_len
171 | step_size = seq_len
172 | self.index_map = [[seq_id, 0, seq_len]]
173 | elif self.mode == "infevaluate":
174 | self.index_map +=[
175 | [seq_id, j, j+window_size] for j in range(
176 | 0, _duration - window_size, step_size)
177 | ]
178 | if self.index_map[-1][2] < _duration:
179 | print(self.index_map[-1][2])
180 | self.index_map += [[seq_id, self.index_map[-1][2], seq_len]]
181 | elif self.mode == 'evaluate':
182 | # adding the last piece for evaluation
183 | self.index_map +=[
184 | [seq_id, j, j+window_size] for j in range(
185 | 0, _duration - window_size, step_size)
186 | ]
187 | elif self.mode == 'train_half_random':
188 | np.random.seed(1)
189 | window_group_size = 3000
190 | selected_indices = [j for j in range(0, _duration-window_group_size, window_group_size)]
191 | np.random.shuffle(selected_indices)
192 | indices_num = len(selected_indices)
193 | for w in selected_indices[:np.floor(indices_num * 0.5).astype(int)]:
194 | self.index_map +=[[seq_id, j, j + window_size] for j in range(w, w+window_group_size-window_size,step_size)]
195 | elif self.mode == 'test_half_random':
196 | np.random.seed(1)
197 | window_group_size = 3000
198 | selected_indices = [j for j in range(0, _duration-window_group_size, window_group_size)]
199 | np.random.shuffle(selected_indices)
200 | indices_num = len(selected_indices)
201 | for w in selected_indices[np.floor(indices_num * 0.5).astype(int):]:
202 | self.index_map +=[[seq_id, j, j + window_size] for j in range(w, w+window_group_size-window_size,step_size)]
203 | else:
204 | ## apply the mask for training: keep only the windows whose frames are all valid
205 | self.index_map +=[
206 | [seq_id, j, j+window_size] for j in range(
207 | 0, _duration - window_size, step_size)
208 | if torch.all(seq.data["mask"][j: j+window_size])
209 | ]
210 |
211 | ## Load the data of this sequence into the per-sequence lists
212 | self.load_data(seq, start_frame, end_frame)
213 |
214 | def __len__(self):
215 | return len(self.index_map)
216 |
217 | def __getitem__(self, item):
218 | seq_id, frame_id, end_frame_id = self.index_map[item][0], self.index_map[item][1], self.index_map[item][2]
219 | data = {
220 | 'dt': self.dt[seq_id][frame_id: end_frame_id],
221 | 'acc': self.acc[seq_id][frame_id: end_frame_id],
222 | 'gyro': self.gyro[seq_id][frame_id: end_frame_id],
223 | 'rot': self.gt_ori[seq_id][frame_id: end_frame_id]
224 | }
225 | init_state = {
226 | 'init_rot': self.gt_ori[seq_id][frame_id][None, ...],
227 | 'init_pos': self.gt_pos[seq_id][frame_id][None, ...],
228 | 'init_vel': self.gt_velo[seq_id][frame_id][None, ...],
229 | }
230 | label = {
231 | 'gt_pos': self.gt_pos[seq_id][frame_id+1 : end_frame_id+1],
232 | 'gt_rot': self.gt_ori[seq_id][frame_id+1 : end_frame_id+1],
233 | 'gt_vel': self.gt_velo[seq_id][frame_id+1 : end_frame_id+1],
234 | }
235 |
236 | return {**data, **init_state, **label}
237 |
238 | def get_dtype(self):
239 | return self.acc[0].dtype
240 |
241 |
242 |
243 | if __name__ == '__main__':
244 | from datasets.dataset_utils import custom_collate
245 | parser = argparse.ArgumentParser()
246 | parser.add_argument('--config', type=str, default='configs/datasets/BaselineEuroc/Euroc_1000.conf', help='config file path, e.g., configs/datasets/BaselineEuroc/Euroc_1000.conf')
247 | parser.add_argument("--device", type=str, default='cuda:0', help="cuda or cpu")
248 |
249 | args = parser.parse_args(); print(args)
250 | conf = ConfigFactory.parse_file(args.config)
251 |
252 | dataset = SeqeuncesDataset(data_set_config=conf.train)
253 | loader = Data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, collate_fn=custom_collate)
254 |
255 | for i, (data, init, label) in enumerate(loader):
256 | for k in data: print(k, ":", data[k].shape)
257 | for k in init: print(k, ":", init[k].shape)
258 | for k in label: print(k, ":", label[k].shape)
259 |
--------------------------------------------------------------------------------
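The window construction in `construct_index_map` can be tried out in isolation. The following is a minimal sketch, not the repository's code: `build_index_map` and its `keep_tail` argument are hypothetical names, and only the `range(0, duration - window_size, step_size)` pattern is taken from the file above. It shows that the exclusive stop of `range` skips the final full window when the duration is an exact multiple of the window size, unless a tail window is appended as in the "infevaluate" branch.

```python
def build_index_map(seq_id, duration, window_size, step_size, keep_tail=False):
    """Return [seq_id, start, end] windows over `duration` frames (hypothetical helper)."""
    index_map = [
        [seq_id, j, j + window_size]
        for j in range(0, duration - window_size, step_size)
    ]
    # Optionally append the leftover frames as one extra window,
    # similar in spirit to the "infevaluate" branch.
    if keep_tail and index_map and index_map[-1][2] < duration:
        index_map.append([seq_id, index_map[-1][2], duration])
    return index_map


windows = build_index_map(seq_id=0, duration=10, window_size=4, step_size=4)
tail = build_index_map(seq_id=0, duration=10, window_size=4, step_size=4, keep_tail=True)
exact = build_index_map(seq_id=0, duration=8, window_size=4, step_size=4)

print(windows)  # [[0, 0, 4], [0, 4, 8]]
print(tail)     # [[0, 0, 4], [0, 4, 8], [0, 8, 10]]
print(exact)    # [[0, 0, 4]]  (the window [4, 8] is not generated)
```

Each label slice in `__getitem__` is shifted by one frame relative to these indices, which is why the ground truth lists hold one more frame than the IMU lists.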
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def imu_seq_collate(data):
4 | acc = torch.stack([d['acc'] for d in data])
5 | gyro = torch.stack([d['gyro'] for d in data])
6 |
7 | gt_pos = torch.stack([d['gt_pos'] for d in data])
8 | gt_rot = torch.stack([d['gt_rot'] for d in data])
9 | gt_vel = torch.stack([d['gt_vel'] for d in data])
10 |
11 | init_pos = torch.stack([d['init_pos'] for d in data])
12 | init_rot = torch.stack([d['init_rot'] for d in data])
13 | init_vel = torch.stack([d['init_vel'] for d in data])
14 |
15 | dt = torch.stack([d['dt'] for d in data])
16 |
17 | return {
18 | 'dt': dt,
19 | 'acc': acc,
20 | 'gyro': gyro,
21 |
22 | 'gt_pos': gt_pos,
23 | 'gt_vel': gt_vel,
24 | 'gt_rot': gt_rot,
25 |
26 | 'init_pos': init_pos,
27 | 'init_vel': init_vel,
28 | 'init_rot': init_rot,
29 | }
30 |
31 | def custom_collate(data):
32 | dt = torch.stack([d['dt'] for d in data])
33 | acc = torch.stack([d['acc'] for d in data])
34 | gyro = torch.stack([d['gyro'] for d in data])
35 | rot = torch.stack([d['rot'] for d in data])
36 |
37 | gt_pos = torch.stack([d['gt_pos'] for d in data])
38 | gt_rot = torch.stack([d['gt_rot'] for d in data])
39 | gt_vel = torch.stack([d['gt_vel'] for d in data])
40 |
41 | init_pos = torch.stack([d['init_pos'] for d in data])
42 | init_rot = torch.stack([d['init_rot'] for d in data])
43 | init_vel = torch.stack([d['init_vel'] for d in data])
44 |
45 | return {'dt': dt, 'acc': acc, 'gyro': gyro, 'rot':rot,}, \
46 | {'pos': init_pos, 'vel': init_vel, 'rot': init_rot,}, \
47 | {'gt_pos': gt_pos, 'gt_vel': gt_vel, 'gt_rot': gt_rot, }
48 |
49 | def padding_collate(data, pad_len = 1, use_gravity = True):
50 | B = len(data)
51 | input_data, init_state, label = custom_collate(data)
52 |
53 | if use_gravity:
54 | iden_acc_vector = torch.tensor([0.,0.,9.81007], dtype=input_data['dt'].dtype).repeat(B,pad_len,1)
55 | else:
56 | iden_acc_vector = torch.zeros(B, pad_len, 3, dtype=input_data['dt'].dtype)
57 |
58 | pad_acc = init_state['rot'].Inv() * iden_acc_vector
59 | pad_gyro = torch.zeros(B, pad_len, 3, dtype=input_data['dt'].dtype)
60 |
61 | input_data["acc"] = torch.cat([pad_acc, input_data['acc']], dim =1)
62 | input_data["gyro"] = torch.cat([pad_gyro, input_data['gyro']], dim =1)
63 |
64 | return input_data, init_state, label
65 |
66 | collate_fcs ={
67 | "base": custom_collate,
68 | "padding": padding_collate,
69 | "padding9": lambda data: padding_collate(data, pad_len = 9),
70 | "padding1": lambda data: padding_collate(data, pad_len = 1),
71 | "Gpadding": lambda data: padding_collate(data, pad_len = 9, use_gravity = False),
72 | }
--------------------------------------------------------------------------------
/doc/alto.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haleqiu/AirIMU/c69afb9d1dfa5acf13a5cc1c15dca370f7440636/doc/alto.gif
--------------------------------------------------------------------------------
/doc/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haleqiu/AirIMU/c69afb9d1dfa5acf13a5cc1c15dca370f7440636/doc/model.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | import torch.utils.data as Data
6 | import argparse
7 | import pickle
8 |
9 | import tqdm, yaml
10 | from utils import move_to, save_state, cat_state, vis_corrections
11 | from model import net_dict
12 | from pyhocon import ConfigFactory
13 |
14 | from datasets import SeqeuncesDataset, collate_fcs
15 | from model.losses import get_RMSE
16 |
17 |
18 | def get_metrics(eval_state):
19 | metrics = {}
20 | net_dist = (eval_state['evaluate']["rot"].Inv() * eval_state['labels']["gt_rot"]).Log()
21 | metrics['pos'], metrics['rot'], metrics['vel'] = eval_state['loss']['pos'].mean().item(), eval_state['loss']['rot'].mean().item(), eval_state['loss']['vel'].mean().item()
22 | metrics['rot_deg'] = 180./np.pi * metrics['rot']
23 |
24 | return metrics
25 |
26 |
27 | def evaluate(network, loader, confs, silent_tqdm=False):
28 | network.eval()
29 | evaluate_cov_states, evaluate_states, loss_states, labels = {}, {}, {}, {}
30 | pred_rot_covs, pred_vel_covs, pred_pos_covs = [], [], []
31 |
32 | with torch.no_grad():
33 | inte_state = None
34 | for i, (data, init_state, label) in enumerate(tqdm.tqdm(loader, disable=silent_tqdm)):
35 | data, init_state, label = move_to([data, init_state, label], confs.device)
36 | # Use the gt init state while there is no integration.
37 | if inte_state is not None and confs.gtinit is False:
38 | init_state ={
39 | "pos": inte_state['pos'][:,-1],
40 | "rot": inte_state['rot'][:,-1],
41 | "vel": inte_state['vel'][:,-1],
42 | }
43 | inte_state = network(data, init_state)
44 | loss_state = get_RMSE(inte_state, label)
45 |
46 | save_state(loss_states, loss_state)
47 | save_state(evaluate_states, inte_state)
48 | save_state(labels, label)
49 |
50 | if 'cov' in inte_state and inte_state['cov'] is not None:
51 | cov_diag = torch.diagonal(inte_state['cov'], dim1=-2, dim2=-1) # Shape: (B, 9)
52 |
53 | pred_rot_covs.append(cov_diag[..., :3])
54 | pred_pos_covs.append(cov_diag[...,-3:])
55 | pred_vel_covs.append(cov_diag[...,3:6])
56 |
57 | if 'cov' in inte_state and inte_state['cov'] is not None:
58 | evaluate_cov_states["pred_rot_covs"] = torch.cat(pred_rot_covs, dim = -2)
59 | evaluate_cov_states["pred_vel_covs"] = torch.cat(pred_vel_covs, dim = -2)
60 | evaluate_cov_states["pred_pos_covs"] = torch.cat(pred_pos_covs, dim = -2)
61 |
62 | for k, v in loss_states.items():
63 | loss_states[k] = torch.stack(v, dim=0)
64 | cat_state(evaluate_states)
65 | cat_state(labels)
66 |
67 | print("evaluating: position loss %f, rotation loss %f, vel losses %f"\
68 | %(loss_states['pos'].mean(), loss_states['rot'].mean(), loss_states['vel'].mean()))
69 |
70 | return {'evaluate': evaluate_states, 'evaluate_cov': evaluate_cov_states, 'loss': loss_states, 'labels': labels}
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--config', type=str, default='configs/exp/EuRoC/codenet.conf', help='config file path')
76 | parser.add_argument('--load', type=str, default=None, help='path for model check point')
77 | parser.add_argument("--device", type=str, default="cuda:0", help="cuda or cpu")
78 | parser.add_argument('--seqlen', type=int, default=None, help='the length of the integration sequence.')
79 | parser.add_argument('--batch_size', type=int, default=1, help='batch size.')
80 | parser.add_argument('--steplen', type=int, default=None, help='the length of the step we take.')
81 | parser.add_argument('--gtrot', default=True, action="store_false", help='if set, ground truth orientation is not used to compensate the gravity (default: it is used)')
82 | parser.add_argument('--gtinit', default=True, action="store_false", help='if set, the integrated pose is used as the initial pose for the next integration window (default: ground truth initial state)')
83 | parser.add_argument('--posonly', default=False, action="store_true", help='if set, ground truth rotation will be applied in the integration.')
84 | parser.add_argument('--train', default=False, action="store_true", help='if set, the training set is evaluated (may be removed in the future).')
85 | parser.add_argument('--whole', default=False, action="store_true", help='(may be removed in the future).')
86 |
87 | args = parser.parse_args(); print(args)
88 | conf = ConfigFactory.parse_file(args.config)
89 | conf.train.device = args.device
90 | conf_name = os.path.split(args.config)[-1].split(".")[0]
91 | conf['general']['exp_dir'] = os.path.join(conf.general.exp_dir, conf_name)
92 |
93 | if args.posonly:
94 | conf.train["posonly"] = True
95 |
96 | conf.train["gtrot"] = args.gtrot
97 | conf.train["gtinit"] = args.gtinit
98 | conf.train['sampling'] = False
99 | network = net_dict[conf.train.network](conf.train).to(args.device).double()
100 |
101 | save_folder = os.path.join(conf.general.exp_dir, "evaluate")
102 | os.makedirs(save_folder, exist_ok=True)
103 |
104 | if args.load is None:
105 | ckpt_path = os.path.join(conf.general.exp_dir, "ckpt/best_model.ckpt")
106 | else:
107 | ckpt_path = os.path.join(conf.general.exp_dir, "ckpt", args.load)
108 |
109 | if os.path.exists(ckpt_path):
110 | checkpoint = torch.load(ckpt_path, map_location=torch.device(args.device))
111 | print("loaded state dict %s in epoch %i"%(ckpt_path, checkpoint["epoch"]))
112 | network.load_state_dict(checkpoint["model_state_dict"])
113 | else:
114 | print("no model loaded", ckpt_path)
115 |
116 | if 'collate' in conf.dataset.keys():
117 | collate_fn = collate_fcs[conf.dataset.collate]
118 | else:
119 | collate_fn = collate_fcs['base']
120 |
121 | if args.train:
122 | dataset_conf = conf.dataset.train
123 | else:
124 | dataset_conf = conf.dataset.eval
125 |
126 | if args.posonly:
127 | dataset_conf['calib'] = "posonly"
128 |
129 | for data_conf in dataset_conf.data_list:
130 | if args.seqlen is not None:
131 | data_conf["window_size"] = args.seqlen
132 | data_conf["step_size"] = args.seqlen if args.steplen is None else args.steplen
133 | if args.whole:
134 | data_conf["mode"] = "inference"
135 |
136 | pos_loss_xyzs = []
137 | pred_pos_cov = []
138 |
139 | all_metrics = {}
140 | for path in data_conf.data_drive:
141 | eval_dataset = SeqeuncesDataset(data_set_config=dataset_conf, data_path=path)
142 | eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn, drop_last = False)
143 | eval_state = evaluate(network=network, loader = eval_loader, confs=conf.train)
144 | ## Save the evaluation state as a pickle file
145 | net_result_path = os.path.join(conf.general.exp_dir, path + '_eval_state.pickle')
146 | with open(net_result_path, 'wb') as handle:
147 | pickle.dump(eval_state, handle, protocol=pickle.HIGHEST_PROTOCOL)
148 |
149 | if "pred_pos_covs" in eval_state['evaluate_cov'].keys():
150 | pred_pos_cov.append(eval_state['evaluate_cov']['pred_pos_covs'])
151 |
152 | title = "$SO(3)$ orientation error"
153 | if 'correction_acc' in eval_state['evaluate'].keys():
154 | correction = torch.cat((eval_state['evaluate']['correction_acc'][0], eval_state['evaluate']['correction_gyro'][0]), dim=-1)
155 | vis_corrections(correction.cpu(), save_folder=os.path.join(save_folder, path))
156 |
157 | all_metrics[path] = get_metrics(eval_state)
158 |
159 | with open(os.path.join(conf.general.exp_dir, 'result_%s_init%s_gt%s.yaml'%(str(args.seqlen),str(args.gtinit),str(args.gtrot))), 'w') as file:
160 | yaml.dump(all_metrics, file, default_flow_style=False)
--------------------------------------------------------------------------------
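The `--gtrot` and `--gtinit` options of eval.py use `action="store_false"` with `default=True`, so passing the flag switches the behaviour off rather than on. A small standalone sketch using only the standard `argparse` module (the three option names are copied from eval.py, everything else is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# Same pattern as in eval.py: the option defaults to True and the flag turns it off.
parser.add_argument("--gtrot", default=True, action="store_false")
parser.add_argument("--gtinit", default=True, action="store_false")
# The usual pattern: the option defaults to False and the flag turns it on.
parser.add_argument("--posonly", default=False, action="store_true")

defaults = parser.parse_args([])
flagged = parser.parse_args(["--gtinit", "--posonly"])

print(defaults.gtrot, defaults.gtinit, defaults.posonly)  # True True False
print(flagged.gtrot, flagged.gtinit, flagged.posonly)     # True False True
```

With this reading, running eval.py with `--gtinit` makes `evaluate` chain each window from the previously integrated state, because the `confs.gtinit is False` branch becomes active.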
/evaluation/evaluate_state.py:
--------------------------------------------------------------------------------
1 | # output the trajectory in the world frame for visualization and evaluation
2 | import os, sys
3 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)))
4 |
5 | import os
6 | import json
7 | import argparse
8 | import numpy as np
9 | import pypose as pp
10 |
11 | import torch
12 | import torch.utils.data as Data
13 |
14 | from pyhocon import ConfigFactory
15 | from datasets import SeqInfDataset, SeqDataset, imu_seq_collate
16 |
17 | from utils import CPU_Unpickler, integrate
18 | from utils.visualize_state import visualize_rotations, visualize_state_error
19 |
20 | if __name__ == '__main__':
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--device", type=str, default="cpu", help="cuda or cpu, default is cpu")
23 | parser.add_argument("--exp", type=str, default=None, help="the directory path where your network output pickle file is stored")
24 | parser.add_argument("--seqlen", type=int, default=200, help="the length of the integration sequence")
25 | parser.add_argument("--dataconf", type=str, default="configs/datasets/BaselineEuroc/Euroc_1000.conf", help="the configuration of the dataset")
26 | parser.add_argument("--savedir",type=str,default = "./result/loss_result",help = "the save directory for the evaluation results")
27 | parser.add_argument("--usegtrot", action="store_true", help="Use ground truth rotation for gravity compensation")
28 | parser.add_argument("--mask", action="store_true", help="Mask the segments if needed")
29 |
30 | args = parser.parse_args()
31 | print(("\n"*3) + str(args) + ("\n"*3))
32 | config = ConfigFactory.parse_file(args.dataconf)
33 | dataset_conf = config.inference
34 |
35 | if args.exp is not None:
36 | net_result_path = os.path.join(args.exp, 'net_output.pickle')
37 | if os.path.isfile(net_result_path):
38 | with open(net_result_path, 'rb') as handle:
39 | inference_state_load = CPU_Unpickler(handle).load()
40 | else:
41 | raise Exception(f"Unable to load the network result: {net_result_path}")
42 |
43 | folder = args.savedir
44 | os.makedirs(folder, exist_ok=True)
45 |
46 | AllResults = []
47 |
48 | for data_conf in dataset_conf.data_list:
49 | print(data_conf)
50 | for data_name in data_conf.data_drive:
51 | print(data_conf.data_root, data_name)
52 | print("data_conf.dataroot", data_conf.data_root)
53 | print("data_name", data_name)
54 | print("data_conf.name", data_conf.name)
55 |
56 | dataset = SeqDataset(data_conf.data_root, data_name, args.device, name = data_conf.name, duration=args.seqlen, step_size=args.seqlen, drop_last=False, conf = dataset_conf)
57 | loader = Data.DataLoader(dataset=dataset, batch_size=1, collate_fn=imu_seq_collate, shuffle=False, drop_last=False)
58 |
59 | init = dataset.get_init_value()
60 | gravity = dataset.get_gravity()
61 | integrator_outstate = pp.module.IMUPreintegrator(
62 | init['pos'], init['rot'], init['vel'],gravity=gravity,
63 | reset=False
64 | ).to(args.device).double()
65 |
66 | integrator_reset = pp.module.IMUPreintegrator(
67 | init['pos'], init['rot'], init['vel'],gravity = gravity,
68 | reset=True
69 | ).to(args.device).double()
70 |
71 | outstate = integrate(
72 | integrator_outstate, loader, init,
73 | device=args.device, gtinit=False, save_full_traj=True,
74 | use_gt_rot=args.usegtrot
75 | )
76 | relative_outstate = integrate(
77 | integrator_reset, loader, init,
78 | device=args.device, gtinit=True,
79 | use_gt_rot=args.usegtrot
80 | )
81 |
82 | if args.exp is not None:
83 | inference_state = inference_state_load[data_name]
84 | dataset_inf = SeqInfDataset(data_conf.data_root, data_name, inference_state, device = args.device, name = data_conf.name,duration=args.seqlen, step_size=args.seqlen, drop_last=False, conf = dataset_conf)
85 | infloader = Data.DataLoader(dataset=dataset_inf, batch_size=1,
86 | collate_fn=imu_seq_collate,
87 | shuffle=False, drop_last=True)
88 |
89 | integrator_infstate = pp.module.IMUPreintegrator(
90 | init['pos'], init['rot'], init['vel'], gravity = gravity,
91 | reset=False
92 | ).to(args.device).double()
93 |
94 | infstate = integrate(
95 | integrator_infstate, infloader, init,
96 | device=args.device, gtinit=False, save_full_traj=True,
97 | use_gt_rot=args.usegtrot
98 | )
99 | relative_infstate = integrate(
100 | integrator_reset, infloader, init,
101 | device=args.device, gtinit=True,
102 | use_gt_rot=args.usegtrot
103 | )
104 |
105 | index_id = dataset.index_map[:, -1]
106 | mask = torch.ones(dataset.seqlen, dtype = torch.bool)
107 | select_mask = torch.ones_like(dataset.get_mask()[index_id], dtype = torch.bool)
108 |
109 | ### For the datasets with mask like TUMVI
110 | if args.mask:
111 | mask = dataset.get_mask()[:dataset.seqlen]
112 | select_mask = dataset.get_mask()[index_id]
113 | select_mask[-1] = False # drop last
114 |
115 | #save loss result
116 | result_dic = {
117 | 'name': data_name,
118 | 'use_gt_rot': args.usegtrot,
119 | 'AOE(raw)':180./np.pi * outstate['rot_dist'][0, mask].mean().numpy(),
120 | 'ATE(raw)':outstate['pos_dist'][0, mask].mean().item(),
121 | 'AVE(raw)':outstate['vel_dist'][0, mask].mean().item(),
122 |
123 | 'ROE(raw)':180./np.pi *relative_outstate['rot_dist'][0, select_mask].mean().numpy(),
124 | 'RTE(raw)':relative_outstate['pos_dist'][0, select_mask].mean().item(),
125 | 'RVE(raw)':relative_outstate['vel_dist'][0, select_mask].mean().item(),
126 |
127 | 'RP_RMSE(raw)': np.sqrt((relative_outstate['pos_dist'][0, select_mask]**2).mean()).numpy().item(),
128 | 'RV_RMSE(raw)': np.sqrt((relative_outstate['vel_dist'][0, select_mask]**2).mean()).numpy().item(),
129 | 'RO_RMSE(raw)':180./np.pi * torch.sqrt((relative_outstate['rot_dist'][0, select_mask]**2).mean()).item(),
130 | 'O_RMSE(raw)':180./np.pi * torch.sqrt((outstate['rot_dist'][0, mask]**2).mean()).item(),
131 |
132 |
133 | 'AOE(AirIMU)':180./np.pi * infstate['rot_dist'][0, mask].mean().numpy(),
134 | 'ATE(AirIMU)':infstate['pos_dist'][0, mask].mean().item(),
135 | 'AVE(AirIMU)':infstate['vel_dist'][0, mask].mean().item(),
136 |
137 | 'ROE(AirIMU)':180./np.pi * relative_infstate['rot_dist'][0, select_mask].mean().numpy(),
138 | 'RTE(AirIMU)':relative_infstate['pos_dist'][0, select_mask].mean().item(),
139 | 'RVE(AirIMU)':relative_infstate['vel_dist'][0, select_mask].mean().item(),
140 |
141 | 'RP_RMSE(AirIMU)': np.sqrt((relative_infstate['pos_dist'][0, select_mask]**2).mean()).item(),
142 | 'RV_RMSE(AirIMU)': np.sqrt((relative_infstate['vel_dist'][0, select_mask]**2).mean()).item(),
143 | 'RO_RMSE(AirIMU)':180./np.pi * torch.sqrt((relative_infstate['rot_dist'][0, select_mask]**2).mean()).numpy(),
144 | 'O_RMSE(AirIMU)': 180./np.pi * torch.sqrt((infstate['rot_dist'][0, mask]**2).mean()).numpy(),
145 | }
146 |
147 | AllResults.append(result_dic)
148 |
149 | print("==============Integration==============")
150 | print("outstate:")
151 | print("pos_err: ", outstate['pos_dist'].mean())
152 | print("rot_err: ", outstate['rot_dist'].mean())
153 | print("vel_err: ", outstate['vel_dist'].mean())
154 |
155 | print("relative_outstate")
156 | print("pos_err: ", relative_outstate['pos_dist'].mean())
157 | print("rot_err: ", relative_outstate['rot_dist'].mean())
158 | print("vel_err: ", relative_outstate['vel_dist'].mean())
159 |
160 | print("==============AirIMU==============")
161 | print("infstate:")
162 | print("pos_err: ", infstate['pos_dist'].mean())
163 | print("rot_err: ", infstate['rot_dist'].mean())
164 | print("vel_err: ", infstate['vel_dist'].mean())
165 |
166 | print("relative_infstate")
167 | print("pos_err: ", relative_infstate['pos_dist'].mean())
168 | print("rot_err: ", relative_infstate['rot_dist'].mean())
169 | print("vel_err: ", relative_infstate['vel_dist'].mean())
170 |
171 | visualize_state_error(data_name,outstate,infstate,save_folder=folder,mask=mask,file_name="inte_error_compare.png")
172 | visualize_state_error(data_name,relative_outstate,relative_infstate,mask=select_mask,save_folder=folder)
173 | visualize_rotations(data_name,outstate['orientations_gt'][0],outstate['orientations'][0],infstate['orientations'][0],save_folder=folder)
174 |
175 | file_path = os.path.join(folder, "loss_result.json")
176 | with open(file_path, 'w') as f:
177 | json.dump(AllResults, f, indent=4)
178 |
--------------------------------------------------------------------------------
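evaluate_state.py integrates each sequence twice: once with `gtinit=False` and `reset=False` (absolute errors, AOE/ATE/AVE) and once with `gtinit=True` and `reset=True` (relative errors, ROE/RTE/RVE). The sketch below illustrates the difference with a toy 1D integrator. It is an assumption-laden illustration, not the behaviour of `utils.integrate` or `pp.module.IMUPreintegrator`: `integrate_window`, `run`, and the constant accelerometer bias are all hypothetical.

```python
def integrate_window(acc, dt, pos0, vel0):
    """Integrate 1D acceleration samples; return the per-frame (pos, vel) states."""
    pos, vel, out = pos0, vel0, []
    for a in acc:
        pos = pos + vel * dt + 0.5 * a * dt * dt
        vel = vel + a * dt
        out.append((pos, vel))
    return out


def run(acc_windows, gt_windows, dt, init, gtinit):
    """Integrate window by window and return the final position error of each window.

    gtinit=True : every window starts from the ground truth state (relative error).
    gtinit=False: every window starts from the previous estimate (absolute error).
    """
    errors, state = [], init
    for k, acc in enumerate(acc_windows):
        if gtinit and k > 0:
            state = gt_windows[k - 1][-1]  # last ground truth state of the previous window
        traj = integrate_window(acc, dt, *state)
        state = traj[-1]
        errors.append(abs(traj[-1][0] - gt_windows[k][-1][0]))
    return errors


dt, n_windows, window = 0.1, 3, 10
true_acc = [[1.0] * window for _ in range(n_windows)]
meas_acc = [[1.1] * window for _ in range(n_windows)]  # constant bias, purely illustrative

# Ground truth trajectory: integrate the true acceleration without resets.
init = (0.0, 0.0)
gt_windows, state = [], init
for acc in true_acc:
    traj = integrate_window(acc, dt, *state)
    gt_windows.append(traj)
    state = traj[-1]

relative = run(meas_acc, gt_windows, dt, init, gtinit=True)
absolute = run(meas_acc, gt_windows, dt, init, gtinit=False)
print("relative:", relative)
print("absolute:", absolute)
```

In this toy setting the relative error stays the same for every window, while the absolute error grows as the bias accumulates across windows, which is why the script reports both families of metrics.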
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | import torch.utils.data as Data
5 | import argparse
6 | import pickle
7 |
8 | import tqdm
9 | from utils import move_to, save_state
10 | from pyhocon import ConfigFactory
11 |
12 | from datasets import collate_fcs, SeqeuncesDataset
13 | from model import net_dict
14 | from utils import *
15 |
16 |
17 |
18 | def inference(network, loader, confs):
19 | '''
20 | Correction inference
21 | save the corrections generated from the network.
22 | '''
23 | network.eval()
24 | evaluate_states = {}
25 |
26 | with torch.no_grad():
27 | inte_state = None
28 | for data, _, _ in tqdm.tqdm(loader):
29 | data = move_to(data, confs.device)
30 | # Use the gt init state while there is no integration.
31 | inte_state = network.inference(data)
32 | # update the corrected acc and gyro
33 | save_state(evaluate_states, inte_state)
34 |
35 | for k, v in evaluate_states.items():
36 | evaluate_states[k] = torch.cat(v, dim=-2)
37 |
38 | return evaluate_states
39 |
40 |
41 | if __name__ == '__main__':
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument('--config', type=str, default='configs/exp/EuRoC/codenet.conf', help='config file path')
44 | parser.add_argument('--load', type=str, default=None, help='path for model check point')
45 | parser.add_argument("--device", type=str, default="cuda:0", help="cuda or cpu")
46 | parser.add_argument('--batch_size', type=int, default=1, help='batch size.')
47 | parser.add_argument('--seqlen', type=int, default=1000, help='the length of the segment')
48 | parser.add_argument('--train', default=False, action="store_true", help='if set, the training set is evaluated (may be removed in the future).')
49 | parser.add_argument('--gtinit', default=True, action="store_false", help='if set, the integrated pose is used as the initial pose for the next integration window')
50 | parser.add_argument('--whole', default=False, action="store_true", help='(may be removed in the future).')
51 |
52 |
53 | args = parser.parse_args(); print(args)
54 | conf = ConfigFactory.parse_file(args.config)
55 | conf.train.device = args.device
56 | conf_name = os.path.split(args.config)[-1].split(".")[0]
57 | conf['general']['exp_dir'] = os.path.join(conf.general.exp_dir, conf_name)
58 | conf.train['sampling'] = False
59 | conf["gtinit"] = args.gtinit
60 | conf['device'] = args.device
61 |
62 | '''
63 | Load the pretrained model
64 | '''
65 | network = net_dict[conf.train.network](conf.train).to(args.device).double()
66 | save_folder = os.path.join(conf.general.exp_dir, "evaluate")
67 | os.makedirs(save_folder, exist_ok=True)
68 |
69 | if args.load is None:
70 | ckpt_path = os.path.join(conf.general.exp_dir, "ckpt/best_model.ckpt")
71 | else:
72 | ckpt_path = os.path.join(conf.general.exp_dir, "ckpt", args.load)
73 |
74 | if os.path.exists(ckpt_path):
75 | checkpoint = torch.load(ckpt_path, map_location=torch.device(args.device))
76 | print("loaded state dict %s in epoch %i"%(ckpt_path, checkpoint["epoch"]))
77 | network.load_state_dict(checkpoint["model_state_dict"])
78 | else:
79 | raise Exception(f"No model loaded {ckpt_path}")
80 |
81 | if 'collate' in conf.dataset.keys():
82 | collate_fn = collate_fcs[conf.dataset.collate]
83 | else:
84 | collate_fn = collate_fcs['base']
85 |
86 | print(conf.dataset)
87 | dataset_conf = conf.dataset.inference
88 |
89 | '''
90 | Run and save the IMU correction
91 | '''
92 | cov_result, rmse = [], []
93 | net_out_result = {}
94 | evals = {}
95 | dataset_conf.data_list[0]["window_size"] = args.seqlen
96 | dataset_conf.data_list[0]["step_size"] = args.seqlen
97 | for data_conf in dataset_conf.data_list:
98 | for path in data_conf.data_drive:
99 | if args.whole:
100 | dataset_conf["mode"] = "inference"
101 | else:
102 | dataset_conf["mode"] = "infevaluate"
103 | dataset_conf["exp_dir"] = conf.general.exp_dir
104 | print("\n"*3 + str(dataset_conf))
105 | eval_dataset = SeqeuncesDataset(data_set_config=dataset_conf, data_path=path, data_root=data_conf["data_root"])
106 | eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=args.batch_size,
107 | shuffle=False, collate_fn=collate_fn, drop_last = False)
108 |
109 | inference_state = inference(network=network, loader = eval_loader, confs=conf.train)
110 | if "acc_cov" not in inference_state.keys():
111 | inference_state["acc_cov"] = torch.zeros_like(inference_state["correction_acc"])
112 | if "gyro_cov" not in inference_state.keys():
113 | inference_state["gyro_cov"] = torch.zeros_like(inference_state["correction_gyro"])
114 |
115 | inference_state['corrected_acc'] = eval_dataset.acc[0] + inference_state['correction_acc'].squeeze(0).cpu()
116 | inference_state['corrected_gyro'] = eval_dataset.gyro[0] + inference_state['correction_gyro'].squeeze(0).cpu()
117 | inference_state['rot'] = eval_dataset.gt_ori[0]
118 | inference_state['dt'] = eval_dataset.dt[0]
119 |
120 | net_out_result[path] = inference_state
121 |
122 | #### RPE and Cov analysis
123 | rpe_pos, rpe_rot, mse_pos = [], [], []
124 | relative_cov, relative_sigma_x, relative_sigma_y, relative_sigma_z = [], [], [], []
125 | dataset_conf["mode"] = "evaluate"
126 |
127 | net_result_path = os.path.join(conf.general.exp_dir, 'net_output.pickle')
128 | print("save netout, ", net_result_path)
129 | with open(net_result_path, 'wb') as handle:
130 | pickle.dump(net_out_result, handle, protocol=pickle.HIGHEST_PROTOCOL)
131 |
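
The pickle written at the end of inference.py can be read back with the standard `pickle` module. The sketch below is a minimal round trip under stated assumptions: the dictionary contents and the temporary directory are hypothetical stand-ins, and plain lists replace the torch tensors that the real file holds (so torch must be importable when loading the real `net_output.pickle`).

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for net_out_result: {sequence path: inference_state}.
# The real values are torch tensors; plain lists keep this sketch dependency-free.
net_out_result = {
    "seq_a": {
        "correction_acc": [[0.0, 0.1, -0.1]],
        "correction_gyro": [[0.0, 0.0, 0.01]],
    },
}

exp_dir = tempfile.mkdtemp()  # stands in for conf.general.exp_dir
net_result_path = os.path.join(exp_dir, "net_output.pickle")

# Same write call as in inference.py.
with open(net_result_path, "wb") as handle:
    pickle.dump(net_out_result, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Reading it back yields one entry per evaluated sequence.
with open(net_result_path, "rb") as handle:
    loaded = pickle.load(handle)

sequence_names = sorted(loaded)
```

Only load pickle files you produced yourself, since unpickling can execute arbitrary code.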
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .net import ModelBase
2 | from .cnn import CNNPOS
3 | from .others import Identity, ParamNet
4 | from .code import *
5 |
6 | net_dict = {
7 | 'codeposenet': CodePoseNet,
8 | 'codenetkitti': CodeNetKITTI,
9 | 'iden': Identity,
10 | 'cnnpos': CNNPOS,
11 | 'codenet': CodeNet,
12 | 'param': ParamNet,
13 | }
14 |
--------------------------------------------------------------------------------
/model/cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pypose as pp
4 | import torch.nn as nn
5 |
6 | from model.net import ModelBase
7 |
8 | class CNNEncoder(nn.Module):
9 | def __init__(self, duration = 1, k_list = [7, 7, 7, 7], c_list = [6, 16, 32, 64, 128],
10 | s_list = [1, 1, 1, 1], p_list = [3, 3, 3, 3]):
11 | super(CNNEncoder, self).__init__()
12 | self.duration = duration
13 | self.k_list, self.c_list, self.s_list, self.p_list = k_list, c_list, s_list, p_list
14 | layers = []
15 |
16 | for i in range(len(self.c_list) - 1):
17 | layers.append(torch.nn.Conv1d(self.c_list[i], self.c_list[i+1], self.k_list[i], \
18 | stride=self.s_list[i], padding=self.p_list[i]))
19 | layers.append(torch.nn.BatchNorm1d(self.c_list[i+1]))
20 | layers.append(torch.nn.GELU())
21 | layers.append(torch.nn.Dropout(0.1))
22 |
23 | self.net = nn.Sequential(*layers)
24 |
25 | def forward(self, x):
26 | return self.net(x)
27 |
28 |
29 | class CNNcorrection(ModelBase):
30 | '''
31 | The input feature shape [B, F, Duration, 6]
32 | '''
33 | def __init__(self, conf):
34 | super(CNNcorrection, self).__init__(conf)
35 | self.k_list = [7, 7, 7, 7]
36 | self.c_list = [6, 32, 64, 128, 256]
37 |
38 | self.cnn = CNNEncoder(c_list=self.c_list, k_list=self.k_list)
39 |
40 | self.accdecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
41 | self.acccov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
42 |
43 | self.gyrodecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
44 | self.gyrocov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
45 |
46 | gyro_std = np.pi/180
47 | if "gyro_std" in conf:
48 | print(" The gyro std is set to ", conf.gyro_std, " rad/s")
49 | gyro_std = conf.gyro_std
50 | self.register_buffer('gyro_std', torch.tensor(gyro_std))
51 |
52 | acc_std = 0.1
53 | if "acc_std" in conf:
54 | print(" The acc std is set to ", conf.acc_std, " m/s^2")
55 | acc_std = conf.acc_std
56 | self.register_buffer('acc_std', torch.tensor(acc_std))
57 |
58 | def encoder(self, x):
59 | return self.cnn(x.transpose(-1,-2)).transpose(-1,-2)
60 |
61 | def decoder(self, x):
62 | acc = self.accdecoder(x) * self.acc_std
63 | gyro = self.gyrodecoder(x) * self.gyro_std
64 | corrections = torch.cat([acc, gyro], dim = -1)
65 |
66 | return corrections
67 |
68 | def cov_decoder(self, x):
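69 | return torch.exp(torch.cat([self.acccov_decoder(x), self.gyrocov_decoder(x)], dim = -1) - 5.) # cov_head is never defined; assumes the same parameterisation as CodeNet.cov_decoder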
70 |
71 |
72 | class CNNPOS(CNNcorrection):
73 | """
74 | Only correct the accelerometer
75 | """
76 | def __init__(self, conf):
77 | super(CNNPOS, self).__init__(conf)
78 |
79 | def decoder(self, x):
80 | acc = self.accdecoder(x) * self.acc_std
81 | gyro = torch.zeros_like(acc)
82 | corrections = torch.cat([acc, gyro], dim = -1)
83 |
84 | return corrections
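
The sequence length that leaves `CNNEncoder` follows from the usual 1-D convolution length formula (dilation 1). The sketch below is illustrative only: the helper names are hypothetical, and the input lengths 1000 and 1009 are assumed values chosen to show the default stride-1 setting and the stride-3 setting used by `CodeNet`.

```python
def conv1d_out_len(length, kernel, stride=1, padding=0):
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel) // stride + 1


def encoder_out_len(length, k_list, s_list, p_list):
    """Apply the length formula once per convolution layer."""
    for kernel, stride, padding in zip(k_list, s_list, p_list):
        length = conv1d_out_len(length, kernel, stride, padding)
    return length


# Default CNNEncoder: kernel 7, stride 1, padding 3 keeps the length unchanged.
same_len = encoder_out_len(1000, [7, 7, 7, 7], [1, 1, 1, 1], [3, 3, 3, 3])

# CodeNet's encoder: two layers with stride 3 shrink the sequence roughly 9x.
strided_len = encoder_out_len(1009, [7, 7], [3, 3], [3, 3])
```

With these assumed lengths the stride-1 encoder returns 1000 frames and the stride-3 encoder returns 113 features.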
--------------------------------------------------------------------------------
/model/code.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pypose as pp
4 |
5 | import torch.nn as nn
6 | from model.net import ModelBase
7 | from model.cnn import CNNEncoder
8 |
9 |
10 | class CodeNet(ModelBase):
11 | def __init__(self, conf):
12 | super().__init__(conf)
13 | self.conf = conf
14 |
15 | gyro_std = np.pi/180
16 | if "gyro_std" in conf:
17 | print(" The gyro std is set to ", conf.gyro_std, " rad/s")
18 | gyro_std = conf.gyro_std
19 | self.register_buffer('gyro_std', torch.tensor(gyro_std))
20 |
21 | acc_std = 0.1
22 | if "acc_std" in conf:
23 | print(" The acc std is set to ", conf.acc_std, " m/s^2")
24 | acc_std = conf.acc_std
25 | self.register_buffer('acc_std', torch.tensor(acc_std))
26 |
27 | ## each encoder feature yields one correction that is shared by all frames in an interval
28 | self.interval = 9
29 | self.inter_head = np.floor(self.interval/2.).astype(int)
30 | self.inter_tail = self.interval - self.inter_head
31 |
32 | self.cnn = CNNEncoder(c_list=[6, 32, 64], k_list=[7, 7], s_list=[3, 3])# (N, ~F/9, 64): two stride-3 layers
33 |
34 | self.gru1 = nn.GRU(input_size = 64, hidden_size = 128, num_layers = 1, batch_first = True)
35 | self.gru2 = nn.GRU(input_size = 128, hidden_size = 256, num_layers = 1, batch_first = True)
36 |
37 | self.accdecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
38 | self.acccov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
39 |
40 | self.gyrodecoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
41 | self.gyrocov_decoder = nn.Sequential(nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 3))
42 |
43 | def encoder(self, x):
44 | x = self.cnn(x.transpose(-1,-2)).transpose(-1,-2)
45 | x, _ = self.gru1(x)
46 | x, _ = self.gru2(x)
47 |
48 | return x
49 |
50 | def cov_decoder(self, x):
51 | acc = torch.exp(self.acccov_decoder(x) - 5.)
52 | gyro = torch.exp(self.gyrocov_decoder(x) - 5.)
53 |
54 | return torch.cat([acc, gyro], dim = -1)
55 |
56 | def decoder(self, x):
57 | acc = self.accdecoder(x) * self.acc_std
58 | gyro = self.gyrodecoder(x) * self.gyro_std
59 |
60 | return torch.cat([acc, gyro], dim = -1)
61 |
62 | def _update(self, to_update, feat, frame_len):
63 | ### Note: to_update is modified in place; pass a clone if the original must be kept
64 | def _clip(x,l):
65 | if x > l:
66 | return l
67 | elif x < 0:
68 | return 0
69 | else:
70 | return x
71 |
72 | _feat_range = np.ceil((frame_len-self.inter_head)/self.interval).astype(int) + 1 ## not equivalent to features shape
73 |
74 | for i in range(_feat_range):
75 | s_p = _clip(i*self.interval-self.inter_head, frame_len)
76 | e_p = _clip(i*self.interval+self.inter_tail, frame_len)
77 | idx = _clip(i, feat.shape[1]-1)
78 |
79 | # skip the first padded input
80 | to_update[:,s_p:e_p,:] += feat[:,idx:idx+1,:]
81 |
82 | return to_update
83 |
84 | def inference(self, data):
85 | frame_len = data["acc"].shape[1] - self.interval
86 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1)
87 | feature = self.encoder(feature)[:,1:,:]
88 | correction = self.decoder(feature)
89 | zero_signal = torch.zeros_like(data['acc'][:,self.interval:,:])
90 |
91 | # a referenced size 1000
92 | correction_acc = self._update(zero_signal.clone(), correction[...,:3], frame_len)
93 | correction_gyro = self._update(zero_signal.clone(), correction[...,3:], frame_len)
94 |
95 | # covariance propagation
96 | cov_state = {'acc_cov':None, 'gyro_cov': None,}
97 | if self.conf.propcov:
98 | cov = self.cov_decoder(feature)
99 | cov_state['acc_cov'] = self._update(torch.zeros_like(correction_acc, device=correction_acc.device),
100 | cov[...,:3], frame_len)
101 | cov_state['gyro_cov'] = self._update(torch.zeros_like(correction_gyro, device=correction_gyro.device),
102 | cov[...,3:], frame_len)
103 |
104 | return {"cov_state": cov_state, 'correction_acc': correction_acc, 'correction_gyro': correction_gyro}
105 |
106 | def forward(self, data, init_state):
107 | inference_state = self.inference(data)
108 |
109 | data['corrected_acc'] = data['acc'][:,self.interval:,:] + inference_state['correction_acc']
110 | data['corrected_gyro'] = data['gyro'][:,self.interval:,:] + inference_state['correction_gyro']
111 |
112 | out_state = self.integrate(init_state = init_state, data = data, cov_state = inference_state['cov_state'])
113 |
114 | return {**out_state, 'correction_acc': inference_state['correction_acc'], 'correction_gyro': inference_state['correction_gyro'],
115 | 'corrected_acc': data['corrected_acc'], 'corrected_gyro': data['corrected_gyro']}
116 |
117 |
118 | class CodePoseNet(CodeNet):
119 | def __init__(self, conf):
120 | super().__init__(conf)
121 |
122 | def inference(self, data):
123 | frame_len = data["acc"].shape[1] - self.interval
124 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1)
125 | feature = self.encoder(feature)[:,1:,:]
126 | correction = self.decoder(feature)
127 | zero_signal = torch.zeros_like(data['acc'][:,self.interval:,:])
128 |
129 | # a referenced size 1000
130 | correction_acc = self._update(zero_signal.clone(), correction[...,:3], frame_len)
131 | correction_gyro = zero_signal.clone()
132 |
133 | # covariance propagation
134 | cov_state = {'acc_cov':None, 'gyro_cov': None,}
135 | if self.conf.propcov:
136 | cov = self.cov_decoder(feature)
137 | cov_state['acc_cov'] = self._update(torch.zeros_like(correction_acc, device=correction_acc.device),
138 | cov[...,:3], frame_len)
139 | cov_state['gyro_cov'] = self._update(torch.zeros_like(correction_gyro, device=correction_gyro.device),
140 | cov[...,3:], frame_len)
141 |
142 | return {"cov_state": cov_state, 'correction_acc': correction_acc, 'correction_gyro': correction_gyro}
143 |
144 |
145 | class CodeNetKITTI(torch.nn.Module):
146 | def __init__(self, conf):
147 | super().__init__()
148 | self.conf = conf
149 | self.integrator = pp.module.IMUPreintegrator(prop_cov=conf.propcov, reset=True).double()
150 |
151 | self.accEncoder = CNNEncoder(k_list=[7, 3, 3], p_list=[3, 1, 1], c_list=[3, 32, 64, 128])
152 | self.gyroEncoder = CNNEncoder(k_list=[7, 3, 3], p_list=[3, 1, 1], c_list=[3, 32, 64, 128])
153 |
154 | self.accDecoder = nn.Sequential(
155 | nn.Linear(128, 64), nn.GELU(), nn.Linear(64, 32), nn.GELU(), nn.Linear(32, 3)
156 | )
157 | self.gyroDecoder = nn.Sequential(
158 | nn.Linear(128, 64), nn.GELU(), nn.Linear(64, 32), nn.GELU(), nn.Linear(32, 3)
159 | )
160 | self.accCovDecoder = nn.Sequential(
161 | nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 32), nn.GELU(), nn.Linear(32, 3)
162 | )
163 | self.gyroCovDecoder = nn.Sequential(
164 | nn.Linear(256, 128), nn.GELU(), nn.Linear(128, 32), nn.GELU(), nn.Linear(32, 3)
165 | )
166 |
167 | gyro_std = np.pi/180
168 | self.register_buffer('gyro_std', torch.tensor(gyro_std))
169 |
170 | acc_std = 0.1
171 | self.register_buffer('acc_std', torch.tensor(acc_std))
172 |
173 | def integrate(self, init_state, data, cov_state, use_gtrot):
174 | gt_rot = None
175 | if self.conf.gtrot: gt_rot = data['rot'].double()
176 | if not use_gtrot: gt_rot = None
177 |
178 | if self.conf.propcov:
179 | out_state = self.integrator(
180 | init_state = init_state,
181 | dt = data['dt'].double(),
182 | gyro = data['corrected_gyro'].double(),
183 | acc = data['corrected_acc'].double(),
184 | rot = gt_rot,
185 | acc_cov = cov_state['acc_cov'].double(),
186 | gyro_cov = cov_state['gyro_cov'].double()
187 | )
188 | else:
189 | out_state = self.integrator(
190 | init_state = init_state,
191 | dt = data['dt'].double(),
192 | gyro = data['corrected_gyro'].double(),
193 | acc = data['corrected_acc'].double(),
194 | rot = gt_rot,
195 | )
196 |
197 | return {**out_state, **cov_state}
198 |
199 | def inference(self, data):
200 | feature_acc = self.accEncoder(data["acc"].transpose(-1,-2)).transpose(-1,-2)
201 | feature_gyro = self.gyroEncoder(data["gyro"].transpose(-1,-2)).transpose(-1,-2)
202 |
203 | correction_acc = self.accDecoder(feature_acc)
204 | correction_gyro = self.gyroDecoder(feature_gyro)
205 |
206 | cov_state = {'acc_cov':None, 'gyro_cov': None}
207 | if self.conf.propcov:
208 | feature = torch.cat([feature_acc, feature_gyro], dim = -1)
209 | cov_state['acc_cov'] = self.accCovDecoder(feature).exp()
210 | cov_state['gyro_cov'] = self.gyroCovDecoder(feature).exp()
211 |
212 | return {"cov_state": cov_state, 'correction_acc': correction_acc, 'correction_gyro': correction_gyro}
213 |
214 | def forward(self, data, init_state, use_gtrot=True):
215 | init_state_ = {
216 | "pos": init_state["pos"],
217 | "rot": init_state["rot"][:,:1,:],
218 | "vel": init_state["vel"],
219 | }
220 | inference_state = self.inference(data)
221 |
222 | data['corrected_acc'] = data['acc'] + inference_state['correction_acc']
223 | data['corrected_gyro'] = data['gyro'] + inference_state['correction_gyro']
224 |
225 | out_state = self.integrate(init_state=init_state_, data = data, cov_state = inference_state['cov_state'], use_gtrot=use_gtrot)
226 |
227 | return {
228 | **out_state,
229 | 'correction_acc': inference_state['correction_acc'],
230 | 'correction_gyro': inference_state['correction_gyro'],
231 | 'corrected_acc': data['corrected_acc'],
232 | 'corrected_gyro': data['corrected_gyro']
233 | }
234 |
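
`CodeNet._update` spreads each low-rate feature over a block of `interval` frames. The index arithmetic can be checked without torch. The sketch below is a plain-Python restatement with hypothetical names (`update_spans`, `clip`); it returns the `(start, end, feature_index)` triples the loop would use rather than touching any tensor.

```python
import math


def clip(x, upper):
    """Clamp x into [0, upper], as the nested _clip helper does."""
    return max(0, min(x, upper))


def update_spans(frame_len, n_feat, interval=9):
    """List the (start, end, feature index) triples visited by _update."""
    head = interval // 2      # inter_head
    tail = interval - head    # inter_tail
    n_iter = math.ceil((frame_len - head) / interval) + 1
    spans = []
    for i in range(n_iter):
        start = clip(i * interval - head, frame_len)
        end = clip(i * interval + tail, frame_len)
        idx = clip(i, n_feat - 1)  # reuse the last feature if we run out
        spans.append((start, end, idx))
    return spans


spans = update_spans(frame_len=20, n_feat=2)
# -> [(0, 5, 0), (5, 14, 1), (14, 20, 1)]
```

The first block is shorter because it starts half an interval early and is clipped at zero. Consecutive blocks are contiguous, so every output frame receives exactly one correction.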
--------------------------------------------------------------------------------
/model/loss_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | EPSILON = 1e-7
4 |
5 | def diag_cov_loss(dist, pred_cov):
6 | error = (dist).pow(2)
7 | return torch.mean(error / (2 * torch.exp(2 * pred_cov)) + pred_cov)
8 |
9 | def diag_ln_cov_loss(dist, pred_cov, use_epsilon=False):
10 | error = (dist).pow(2)
11 | if use_epsilon: l = ((error / pred_cov) + torch.log(pred_cov + EPSILON))
12 | else: l = ((error / pred_cov) + torch.log(pred_cov))
13 | return l.mean()
14 |
15 | def L2(dist):
16 | error = dist.pow(2)
17 | return torch.mean(error)
18 |
19 | def L1(dist):
20 | error = (dist).abs().mean()
21 | return error
22 |
23 | def loss_weight_decay(error, decay_rate = 0.95):
24 | F = error.shape[-2]
25 | decay_list = decay_rate * torch.ones(F, device=error.device, dtype=error.dtype)
26 | decay_list[0] = 1.
27 | decay_list = torch.cumprod(decay_list, 0)
28 | error = torch.einsum('bfc, f -> bfc', error, decay_list)
29 | return error
30 |
31 | def loss_weight_decrease(error, decay_rate = 0.95):
32 | F = error.shape[-2]
33 | decay_list = torch.tensor([1./i for i in range(1, F+1)], device=error.device, dtype=error.dtype)
34 | error = torch.einsum('bfc, f -> bfc', error, decay_list)
35 | return error
36 |
37 | def Huber(dist, delta=0.005):
38 | error = torch.nn.functional.huber_loss(dist, torch.zeros_like(dist, device=dist.device), delta=delta)
39 | return error
40 |
41 |
42 | loss_fc_list = {
43 | "L2": L2,
44 | "L1": L1,
45 | "diag_cov_ln": diag_ln_cov_loss,
46 | "Huber_loss005":lambda dist: Huber(dist, delta = 0.005),
47 | "Huber_loss05":lambda dist: Huber(dist, delta = 0.05),
48 | }
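
Two of the ideas in loss_func.py can be checked with scalar arithmetic. The sketch below uses hypothetical helper names and plain floats instead of tensors: `decay_weights` reproduces the geometric weights that `loss_weight_decay` builds with `cumprod`, and `ln_cov_term` is the per-element term of `diag_ln_cov_loss`, which is smallest when the predicted covariance equals the squared error.

```python
import math


def decay_weights(n_frames, decay_rate=0.95):
    """Weights 1, r, r^2, ... as produced by the cumprod in loss_weight_decay."""
    weights = [1.0]
    for _ in range(n_frames - 1):
        weights.append(weights[-1] * decay_rate)
    return weights


def ln_cov_term(err, cov):
    """Per-element term of diag_ln_cov_loss: err^2 / cov + log(cov)."""
    return err ** 2 / cov + math.log(cov)


weights = decay_weights(4, decay_rate=0.5)

# Among these candidates the term is smallest at cov == err**2.
err = 0.5
candidates = [0.1, err ** 2, 1.0]
best_cov = min(candidates, key=lambda c: ln_cov_term(err, c))
```

Setting the derivative of `err**2 / cov + log(cov)` with respect to `cov` to zero gives `cov = err**2`, which is why the loss pushes the predicted covariance toward the observed squared error.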
--------------------------------------------------------------------------------
/model/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .loss_func import loss_fc_list, diag_ln_cov_loss
3 | from utils import report_hasNan
4 |
5 |
6 | def loss_(fc, pred, targ, sampling = None, dtype = 'trans'):
7 | ## reshape or sample the input targ and pred
8 | ## cov and error is for reference
9 | if sampling:
10 | pred = pred[:,sampling-1::sampling,:]
11 | targ = targ[:,sampling-1::sampling,:]
12 | else:
13 | pred = pred[:,-1:,:]
14 | targ = targ[:,-1:,:]
15 |
16 | if dtype == 'rot':
17 | dist = (pred * targ.Inv()).Log()
18 | else:
19 | dist = pred - targ
20 | loss = fc(dist)
21 | return loss, dist
22 |
23 |
24 | def get_loss(inte_state, data, confs):
25 | ## The state loss for evaluation
26 | loss, state_losses, cov_losses = 0, {}, {}
27 | loss_fc = loss_fc_list[confs.loss]
28 | rotloss_fc = loss_fc_list[confs.rotloss]
29 |
30 | rot_loss, rot_dist = loss_(rotloss_fc, inte_state['rot'], data['gt_rot'], sampling = confs.sampling, dtype='rot')
31 | vel_loss, vel_dist = loss_(loss_fc, inte_state['vel'], data['gt_vel'], sampling = confs.sampling)
32 | pos_loss, pos_dist = loss_(loss_fc, inte_state['pos'], data['gt_pos'], sampling = confs.sampling)
33 |
34 | state_losses['pos'] = pos_dist[:,-1,:].norm(dim=-1).mean()
35 | state_losses['rot'] = rot_dist[:,-1,:].norm(dim=-1).mean()
36 | state_losses['vel'] = vel_dist[:,-1,:].norm(dim=-1).mean()
37 |
38 | # Apply the covariance loss
39 | if confs.propcov:
40 | cov_diag = torch.diagonal(inte_state['cov'], dim1=-2, dim2=-1)
41 | cov_losses['pred_cov_rot'] = cov_diag[...,:3].mean()
42 | cov_losses['pred_cov_vel'] = cov_diag[...,3:6].mean()
43 | cov_losses['pred_cov_pos'] = cov_diag[...,-3:].mean()
44 |
45 | if "covaug" in confs and confs["covaug"] is True:
46 | rot_loss += confs.cov_weight * diag_ln_cov_loss(rot_dist, cov_diag[...,:3])
47 | vel_loss += confs.cov_weight * diag_ln_cov_loss(vel_dist, cov_diag[...,3:6])
48 | pos_loss += confs.cov_weight * diag_ln_cov_loss(pos_dist, cov_diag[...,-3:])
49 | else:
50 | rot_loss += confs.cov_weight * diag_ln_cov_loss(rot_dist.detach(), cov_diag[...,:3])
51 | vel_loss += confs.cov_weight * diag_ln_cov_loss(vel_dist.detach(), cov_diag[...,3:6])
52 | pos_loss += confs.cov_weight * diag_ln_cov_loss(pos_dist.detach(), cov_diag[...,-3:])
53 |
54 | loss += (confs.pos_weight * pos_loss + confs.rot_weight * rot_loss + confs.vel_weight * vel_loss)
55 | # report_hasNan(loss)
56 |
57 | return {'loss':loss, **state_losses, **cov_losses}
58 |
59 |
60 | def get_RMSE(inte_state, data):
61 | '''
62 | get the RMSE of the last state in one segment
63 | '''
64 | def _RMSE(x):
65 | return torch.sqrt((x.norm(dim=-1)**2).mean())
66 |
67 | dist_pos = (inte_state['pos'][:,-1,:] - data['gt_pos'][:,-1,:])
68 | dist_vel = (inte_state['vel'][:,-1,:] - data['gt_vel'][:,-1,:])
69 | dist_rot = (data['gt_rot'][:,-1,:] * inte_state['rot'][:,-1,:].Inv()).Log()
70 |
71 | pos_loss = _RMSE(dist_pos)[None,...]
72 | vel_loss = _RMSE(dist_vel)[None,...]
73 | rot_loss = _RMSE(dist_rot)[None,...]
74 |
75 | ## Relative pos error
76 | return {'pos': pos_loss, 'rot': rot_loss, 'vel': vel_loss,
77 | 'pos_dist': dist_pos.norm(dim=-1).mean(),
78 | 'vel_dist': dist_vel.norm(dim=-1).mean(),
79 | 'rot_dist': dist_rot.norm(dim=-1).mean(),}
80 |
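
The slice `[:, sampling-1::sampling, :]` in `loss_` selects the last frame of every chunk of `sampling` frames, and the fallback `[:, -1:, :]` keeps only the final frame. A plain-Python sketch of the same indexing (the function name is hypothetical):

```python
def supervised_frames(n_frames, sampling=None):
    """Frame indices that loss_ would keep along the time axis."""
    frames = list(range(n_frames))
    if sampling:
        return frames[sampling - 1::sampling]
    return frames[-1:]


every_third = supervised_frames(10, sampling=3)  # last frame of each 3-frame chunk
last_only = supervised_frames(10)                # no sampling: final frame only
```

For 10 frames and `sampling=3` this keeps frames 2, 5 and 8. Frame 9 belongs to an incomplete chunk and is not supervised.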
--------------------------------------------------------------------------------
/model/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pypose as pp
3 | import torch.nn as nn
4 |
5 | class ModelBase(nn.Module):
6 | def __init__(self, conf):
7 | super().__init__()
8 | self.conf = conf
9 | if "ngravity" in conf.keys():
10 | self.integrator = pp.module.IMUPreintegrator(prop_cov=conf.propcov, reset=True, gravity = 0.0)
11 | print("conf.ngravity", conf.ngravity, self.integrator.gravity)
12 | else:
13 | self.integrator = pp.module.IMUPreintegrator(prop_cov=conf.propcov, reset=True)
14 | print("network constructed: ", self.conf.network, "gtrot: ", self.conf.gtrot)
15 |
16 | def _select(self, data, start, end):
17 |
18 | select = {}
19 | for k in data.keys():
20 | if data[k] is None:
21 | select[k] = None
22 | else:
23 | select[k] = data[k][:, start:end]
24 | return select
25 |
26 | def integrate(self, init_state, data, cov_state):
27 | B, F = data["corrected_acc"].shape[:2]
28 | inte_pos, inte_vel, inte_rot, inte_cov = [], [], [], []
29 | gt_rot = None
30 | if self.conf.gtrot:
31 | gt_rot = data['rot']
32 | if "posonly" in self.conf.keys():
33 | data['corrected_gyro'] = data['gyro']
34 |
35 | if self.conf.sampling:
36 | inte_state = None
37 | for iter in range(0, F, self.conf.sampling):
38 | if (F - iter) < self.conf.sampling: continue
39 | start, end = iter, iter + self.conf.sampling
40 | selected_data = self._select(data, start, end)
41 | selected_cov_state = self._select(cov_state, start, end)
42 |
43 | # use the last state of the previous chunk as the init state of the next chunk
44 | if inte_state is not None:
45 | init_state = {
46 | "pos": inte_state["pos"][:,-1:,:],
47 | "vel": inte_state["vel"][:,-1:,:],
48 | "rot": inte_state["rot"][:,-1:,:],
49 | }
50 | if self.conf.propcov:
51 | init_state["Rij"] = inte_state["Rij"]
52 | init_state["cov"] = inte_state["cov"]
53 |
54 | if self.conf.gtrot:
55 | gt_rot = selected_data['rot']
56 |
57 | ## starting point and ending point
58 | inte_state = self.integrator(init_state = init_state, dt = selected_data['dt'], gyro = selected_data['corrected_gyro'],
59 | acc = selected_data['corrected_acc'], rot = gt_rot, acc_cov = selected_cov_state['acc_cov'], gyro_cov = selected_cov_state['gyro_cov'])
60 |
61 | inte_pos.append(inte_state['pos'])
62 | inte_rot.append(inte_state['rot'])
63 | inte_vel.append(inte_state['vel'])
64 | inte_cov.append(inte_state['cov'])
65 |
66 | out_state ={
67 | 'pos': torch.cat(inte_pos, dim =1),
68 | 'vel': torch.cat(inte_vel, dim =1),
69 | 'rot': torch.cat(inte_rot, dim =1),
70 | }
71 | if self.conf.propcov:
72 | out_state['cov'] = torch.stack(inte_cov, dim =1)
73 | else:
74 |
75 | out_state = self.integrator(init_state = init_state, dt = data['dt'], gyro = data['corrected_gyro'],
76 | acc = data['corrected_acc'], rot = gt_rot, acc_cov = cov_state['acc_cov'], gyro_cov = cov_state['gyro_cov'])
77 |
78 | return {**out_state, **cov_state}
79 |
80 | def inference(self, data):
81 | '''
82 | Pure inference, generate the network output.
83 | '''
84 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1)
85 | feature = self.encoder(feature)
86 | correction = self.decoder(feature)
87 |
88 | # Correction update
89 | data['corrected_acc'] = correction[...,:3] + data["acc"]
90 | data['corrected_gyro'] = correction[...,3:] + data["gyro"]
91 |
92 | # covariance propagation
93 | cov_state = {'acc_cov':None, 'gyro_cov': None,}
94 | if self.conf.propcov:
95 | cov = self.cov_decoder(feature)
96 | cov_state['acc_cov'] = cov[...,:3]; cov_state['gyro_cov'] = cov[...,3:]
97 |
98 | return {**cov_state, 'correction_acc': correction[...,:3], 'correction_gyro': correction[...,3:]}
99 |
100 | ## For reference
101 | def forward(self, data, init_state):
102 | feature = torch.cat([data["acc"], data["gyro"]], dim = -1)
103 | feature = self.encoder(feature)
104 | correction = self.decoder(feature)
105 |
106 | # Correction update
107 | data['corrected_acc'] = correction[...,:3] + data["acc"]
108 | data['corrected_gyro'] = correction[...,3:] + data["gyro"]
109 |
110 | # covariance propagation
111 | cov_state = {'acc_cov':None, 'gyro_cov': None,}
112 | if self.conf.propcov:
113 | cov = self.cov_decoder(feature)
114 | cov_state['acc_cov'] = cov[...,:3]; cov_state['gyro_cov'] = cov[...,3:]
115 |
116 | out_state = self.integrate(init_state = init_state, data = data, cov_state = cov_state)
117 | return {**out_state, 'correction_acc': correction[...,:3], 'correction_gyro': correction[...,3:]}
118 |
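
When `conf.sampling` is set, `ModelBase.integrate` cuts the sequence into chunks, seeds each chunk with the last state of the previous one, and skips an incomplete tail. The toy below illustrates only that chaining, under strong assumptions: it is a 1-D constant-acceleration integrator with hypothetical names, not pypose's `IMUPreintegrator`, and it ignores rotation, gravity and covariance.

```python
def integrate_1d(pos, vel, acc, dt):
    """Integrate a 1-D acceleration sequence; return the final (pos, vel)."""
    for a in acc:
        pos = pos + vel * dt + 0.5 * a * dt * dt
        vel = vel + a * dt
    return pos, vel


def integrate_chunked(pos, vel, acc, dt, sampling):
    """Integrate chunk by chunk, chaining the end state into the next chunk."""
    n_frames = len(acc)
    for start in range(0, n_frames, sampling):
        if n_frames - start < sampling:
            continue  # incomplete tail is skipped, as in ModelBase.integrate
        pos, vel = integrate_1d(pos, vel, acc[start:start + sampling], dt)
    return pos, vel


acc = [0.1 * k for k in range(10)]
chunked = integrate_chunked(0.0, 0.0, acc, dt=0.01, sampling=3)
whole = integrate_1d(0.0, 0.0, acc[:9], dt=0.01)  # same 9 frames in one pass
```

Chaining the chunks gives the same end state as one pass over the first nine frames. The tenth frame falls in the incomplete tail and is never integrated.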
--------------------------------------------------------------------------------
/model/others.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from model.net import ModelBase
5 |
6 | ## For the Baseline
7 | class Identity(ModelBase):
8 | def __init__(self, conf, acc_cov = None, gyro_cov=None):
9 | super().__init__(conf)
10 | self.register_parameter('param', nn.Parameter(torch.zeros(1), requires_grad=True))
11 | self.acc_cov, self.gyro_cov = acc_cov, gyro_cov
12 |
13 | def inference(self, data):
14 | return {'acc_cov':self.acc_cov, 'gyro_cov': self.gyro_cov,
15 | 'correction_acc': torch.zeros_like(data['acc']),
16 | 'correction_gyro': torch.zeros_like(data['gyro'])}
17 |
18 | def forward(self, data, init_state):
19 | data['corrected_acc'] = data["acc"]
20 | data['corrected_gyro'] = data["gyro"]
21 | cov_state = {'acc_cov':self.acc_cov, 'gyro_cov': self.gyro_cov}
22 | out_state = self.integrate(init_state = init_state, data = data, cov_state = cov_state)
23 | return out_state
24 |
25 |
26 | class ParamNet(ModelBase):
27 | def __init__(self, conf, acc_cov = None, gyro_cov=None):
28 | super().__init__(conf)
29 | self.acc_cov, self.gyro_cov = acc_cov, gyro_cov
30 | self.gyro_bias, self.acc_bias = torch.nn.Parameter(torch.zeros(3)), torch.nn.Parameter(torch.zeros(3))
31 | self.gyro_cov, self.acc_cov = torch.nn.Parameter(torch.ones(3)), torch.nn.Parameter(torch.ones(3))
32 |
33 | def forward(self, data, init_state):
34 | data['corrected_acc'] = data["acc"] + self.acc_bias
35 | data['corrected_gyro'] = data["gyro"] + self.gyro_bias
36 |
37 | # covariance propagation
38 | cov_state = {'acc_cov':None, 'gyro_cov': None,}
39 | if self.conf.propcov:
40 | cov_state['acc_cov'] = torch.zeros_like(data['corrected_acc']) + self.acc_cov**2
41 | cov_state['gyro_cov'] = torch.zeros_like(data['corrected_gyro']) + self.gyro_cov**2
42 |
43 | out_state = self.integrate(init_state = init_state, data = data, cov_state = cov_state)
44 | return {**out_state, 'correction_acc': self.acc_bias, 'correction_gyro': self.gyro_bias}
45 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | import torch.utils.data as Data
6 | from torch.optim.lr_scheduler import ReduceLROnPlateau
7 | import argparse
8 |
9 | import tqdm, wandb
10 | from utils import move_to
11 | from model import net_dict
12 | from pyhocon import ConfigFactory
13 | from pyhocon import HOCONConverter as conf_convert
14 |
15 | from datasets import SeqeuncesDataset, collate_fcs
16 | from model.losses import get_loss
17 | from eval import evaluate
18 |
19 | torch.autograd.set_detect_anomaly(True)
20 |
21 | def train(network, loader, confs, epoch, optimizer):
22 | """
23 | Train the network for one epoch using the given data loader.
24 | Returns the epoch-averaged losses and predicted covariance statistics.
25 | """
26 | network.train()
27 | losses, pos_losses, rot_losses, vel_losses = 0, 0, 0, 0
28 | pred_cov_rot, pred_cov_vel, pred_cov_pos = 0, 0, 0
29 | acc_covs, gyro_covs = 0, 0
30 |
31 | t_range = tqdm.tqdm(loader)
32 | for i, (data, init_state, label) in enumerate(t_range):
33 | data, init_state, label = move_to([data, init_state, label], confs.device)
34 | inte_state = network(data, init_state)
35 | loss_state = get_loss(inte_state, label, confs)
36 |
37 | # statistics
38 | losses += loss_state['loss'].item()
39 | pos_losses += loss_state['pos'].item()
40 | rot_losses += loss_state['rot'].item()
41 | vel_losses += loss_state['vel'].item()
42 |
43 | if confs.propcov:
44 | acc_covs += inte_state["acc_cov"].mean().item()
45 | gyro_covs += inte_state["gyro_cov"].mean().item()
46 | pred_cov_pos += loss_state['pred_cov_pos'].mean().item()
47 | pred_cov_rot += loss_state['pred_cov_rot'].mean().item()
48 | pred_cov_vel += loss_state['pred_cov_vel'].mean().item()
49 | t_range.set_description(f'training epoch: %03d, losses: %.06f, position, %.06f rotation %.06f, pred_rot %.06f, pred_cov%.06f'%(epoch, \
50 | loss_state['loss'], (pos_losses/(i+1)), (rot_losses/(i+1)), \
51 | loss_state['pred_cov_rot'], loss_state['pred_cov_pos']))
52 |
53 | else:
54 | t_range.set_description(f'training epoch: %03d, losses: %.06f, position, %.06f rotation %.06f, velocity %.06f'%(epoch, \
55 | loss_state['loss'], (pos_losses/(i+1)), (rot_losses/(i+1)), loss_state['vel']))
56 |
57 | t_range.refresh()
58 | optimizer.zero_grad()
59 | loss_state['loss'].backward()
60 | optimizer.step()
61 |
62 | return {"loss": (losses/(i+1)), "pos_loss": (pos_losses/(i+1)), "rot_loss": (rot_losses/(i+1)), "vel_loss":((vel_losses)/(i+1)),
63 | "pred_cov_rot": (pred_cov_rot/(i+1)), "pred_cov_vel": (pred_cov_vel/(i+1)), "pred_cov_pos": (pred_cov_pos/(i+1))}
64 |
65 |
66 | def test(network, loader, confs):
67 | network.eval()
68 | with torch.no_grad():
69 | losses, pos_losses, rot_losses, vel_losses = 0, 0, 0, 0
70 | pred_cov_rot, pred_cov_vel, pred_cov_pos = 0, 0, 0
71 | acc_covs, gyro_covs = [], []
72 |
73 | t_range = tqdm.tqdm(loader)
74 | for i, (data, init_state, label) in enumerate(t_range):
75 |
76 | data, init_state, label = move_to([data, init_state, label], confs.device)
77 | inte_state = network(data, init_state)
78 |
79 | loss_state = get_loss(inte_state, label, confs)
80 | # statistics
81 | losses += loss_state['loss'].item()
82 | pos_losses += loss_state["pos"].item()
83 | rot_losses += loss_state["rot"].item()
84 | vel_losses += loss_state['vel'].item()
85 |
86 | if confs.propcov:
87 | acc_covs.append(inte_state["acc_cov"].reshape(-1))
88 | gyro_covs.append(inte_state["gyro_cov"].reshape(-1))
89 | pred_cov_pos += loss_state['pred_cov_pos'].mean().item()
90 | pred_cov_rot += loss_state['pred_cov_rot'].mean().item()
91 | pred_cov_vel += loss_state['pred_cov_vel'].mean().item()
92 |
93 | t_range.set_description(f'testing losses: %.06f, position, %.06f rotation %.06f, vel %.06f'%(losses/(i+1), \
94 | pos_losses/(i+1), rot_losses/(i+1), vel_losses/(i+1)))
95 | t_range.refresh()
96 |
97 | if acc_covs:
98 | acc_covs = torch.cat(acc_covs)
99 | if gyro_covs:
100 | gyro_covs = torch.cat(gyro_covs)
101 |
102 | return {"loss": (losses/(i+1)), "pos_loss":(pos_losses/(i+1)), "rot_loss":(rot_losses/(i+1)), "vel_loss":(vel_losses/(i+1)),
103 | "pred_cov_rot": (pred_cov_rot/(i+1)), "pred_cov_vel": (pred_cov_vel/(i+1)), "pred_cov_pos": (pred_cov_pos/(i+1)), "acc_covs": acc_covs, "gyro_covs": gyro_covs}
104 |
105 |
106 | def write_wandb(header, objs, epoch_i):
107 | if isinstance(objs, dict):
108 | for k, v in objs.items():
109 | if isinstance(v, float):
110 | wandb.log({os.path.join(header, k): v}, epoch_i)
111 | else:
112 | wandb.log({header: objs}, step = epoch_i)
113 |
114 |
115 | def save_ckpt(network, optimizer, scheduler, epoch_i, test_loss, conf, save_best = False):
116 | if epoch_i%conf.train.save_freq==conf.train.save_freq-1:
117 | torch.save({
118 | 'epoch': epoch_i,
119 | 'model_state_dict': network.state_dict(),
120 | 'optimizer_state_dict': optimizer.state_dict(),
121 | 'scheduler_state_dict': scheduler.state_dict(),
122 | 'best_loss': test_loss,
123 | }, os.path.join(conf.general.exp_dir, "ckpt/%04d.ckpt"%epoch_i))
124 |
125 | if save_best:
126 | print("saving the best model", test_loss)
127 | torch.save({
128 | 'epoch': epoch_i,
129 | 'model_state_dict': network.state_dict(),
130 | 'optimizer_state_dict': optimizer.state_dict(),
131 | 'scheduler_state_dict': scheduler.state_dict(),
132 | 'best_loss': test_loss,
133 | }, os.path.join(conf.general.exp_dir, "ckpt/best_model.ckpt"))
134 |
135 | torch.save({
136 | 'epoch': epoch_i,
137 | 'model_state_dict': network.state_dict(),
138 | 'optimizer_state_dict': optimizer.state_dict(),
139 | 'scheduler_state_dict': scheduler.state_dict(),
140 | 'best_loss': test_loss,
141 | }, os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt"))
142 |
143 |
144 | if __name__ == '__main__':
145 | parser = argparse.ArgumentParser()
146 | parser.add_argument('--config', type=str, default='configs/exp/EuRoC/codenet.conf', help='config file path')
147 | parser.add_argument('--device', type=str, default="cuda:0", help="cuda or cpu, Default is cuda:0")
148 | parser.add_argument('--load_ckpt', default=False, action="store_true", help="If set, try to load newest.ckpt from the "
149 | "exp_dir specified in the config file.")
150 | parser.add_argument('--log', default=True, action="store_false", help="wandb logging is on by default; pass --log to turn it off")
151 | args = parser.parse_args(); print(args)
152 | conf = ConfigFactory.parse_file(args.config)
153 | # torch.cuda.set_device(args.device)
154 |
155 | conf.train.device = args.device
156 | exp_folder = os.path.split(conf.general.exp_dir)[-1]
157 | conf_name = os.path.split(args.config)[-1].split(".")[0]
158 | conf['general']['exp_dir'] = os.path.join(conf.general.exp_dir, conf_name)
159 |
160 | if 'collate' in conf.dataset.keys():
161 | collate_fn = collate_fcs[conf.dataset.collate]
162 | else:
163 | collate_fn = collate_fcs['base']
164 |
165 | train_dataset = SeqeuncesDataset(data_set_config=conf.dataset.train)
166 | test_dataset = SeqeuncesDataset(data_set_config=conf.dataset.test)
167 | eval_dataset = SeqeuncesDataset(data_set_config=conf.dataset.eval)
168 | train_loader = Data.DataLoader(dataset=train_dataset, batch_size=conf.train.batch_size, shuffle=True, collate_fn=collate_fn)
169 | test_loader = Data.DataLoader(dataset=test_dataset, batch_size=conf.train.batch_size, shuffle=False, collate_fn=collate_fn)
170 | eval_loader = Data.DataLoader(dataset=eval_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
171 |
172 | os.makedirs(os.path.join(conf.general.exp_dir, "ckpt"), exist_ok=True)
173 | with open(os.path.join(conf.general.exp_dir, "parameters.yaml"), "w") as f:
174 | f.write(conf_convert.to_yaml(conf))
175 |
176 | if not args.log:
177 | wandb.disabled = True
178 | print("wandb is disabled")
179 | else:
180 | wandb.init(project= "AirIMU_" + exp_folder,
181 | config= conf.train,
182 | group = conf.train.network,
183 | name = conf_name,)
184 |
185 | ## optimizer
186 | network = net_dict[conf.train.network](conf.train).to(device = args.device, dtype = train_dataset.get_dtype())
187 | optimizer = torch.optim.Adam(network.parameters(), lr = conf.train.lr, weight_decay=conf.train.weight_decay) # to use with ViTs
188 | scheduler = ReduceLROnPlateau(optimizer, 'min', factor = conf.train.factor, patience = conf.train.patience, min_lr = conf.train.min_lr)
189 | best_loss = np.inf
190 | epoch = 0
191 |
192 | ## load the checkpoint if it exists
193 | if args.load_ckpt:
194 | if os.path.isfile(os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt")):
195 | checkpoint = torch.load(os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt"), map_location = args.device)
196 | network.load_state_dict(checkpoint["model_state_dict"])
197 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
198 | scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
199 | epoch = checkpoint['epoch'] + 1 # newest.ckpt is written after epoch_i finishes, so resume at the next epoch
200 | best_loss = checkpoint['best_loss']
201 | print("loaded state dict %s best_loss %f"%(os.path.join(conf.general.exp_dir, "ckpt/newest.ckpt"), best_loss))
202 | else:
203 | print("Can't find the checkpoint")
204 |
205 | for epoch_i in range(epoch, conf.train.max_epoches):
206 | train_loss = train(network, train_loader, conf.train, epoch_i, optimizer)
207 | test_loss = test(network, test_loader, conf.train)
208 | print("train loss: %f test loss: %f"%(train_loss["loss"], test_loss["loss"]))
209 |
210 | # save the training meta information
211 | if args.log:
212 | write_wandb("train", train_loss, epoch_i)
213 | write_wandb("test", test_loss, epoch_i)
214 | write_wandb("lr", scheduler.optimizer.param_groups[0]['lr'], epoch_i)
215 |
216 | if epoch_i%conf.train.eval_freq == conf.train.eval_freq-1:
217 | eval_state = evaluate(network=network, loader = eval_loader, confs=conf.train)
218 | if args.log:
219 | write_wandb('eval/pos_loss', eval_state['loss']['pos'].mean(), epoch_i)
220 | write_wandb('eval/rot_loss', eval_state['loss']['rot'].mean(), epoch_i)
221 | write_wandb('eval/vel_loss', eval_state['loss']['vel'].mean(), epoch_i)
222 | write_wandb('eval/rot_dist', eval_state['loss']['rot_dist'].mean(), epoch_i)
223 | write_wandb('eval/vel_dist', eval_state['loss']['vel_dist'].mean(), epoch_i)
224 | write_wandb('eval/pos_dist', eval_state['loss']['pos_dist'].mean(), epoch_i)
225 |
226 | print("eval pos: %f eval rot: %f"%(eval_state['loss']['pos'].mean(), eval_state['loss']['rot'].mean()))
227 |
228 | scheduler.step(test_loss['loss'])
229 | if test_loss['loss'] < best_loss:
230 | best_loss = test_loss['loss'];save_best = True
231 | else:
232 | save_best = False
233 |
234 | save_ckpt(network, optimizer, scheduler, epoch_i, best_loss, conf, save_best=save_best,)
235 |
236 | wandb.finish()
--------------------------------------------------------------------------------
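The checkpoint schedule in train.py can be replayed without PyTorch. The sketch below is an illustration only, not code from the repository: `checkpoint_plan` is a hypothetical helper that walks a list of test losses and reports which epochs would trigger the periodic save (`epoch_i % save_freq == save_freq - 1`) and which would overwrite best_model.ckpt. In the training loop, newest.ckpt is written on every epoch regardless.

```python
import math

def checkpoint_plan(test_losses, save_freq):
    """Hypothetical helper: replay the save rules used in the training loop."""
    best_loss = math.inf
    periodic, best = [], []
    for epoch_i, loss in enumerate(test_losses):
        # Periodic snapshot on the last epoch of every save_freq-sized window.
        if epoch_i % save_freq == save_freq - 1:
            periodic.append(epoch_i)
        # The best checkpoint is rewritten only on a strict improvement.
        if loss < best_loss:
            best_loss = loss
            best.append(epoch_i)
    return periodic, best

periodic, best = checkpoint_plan([0.9, 0.7, 0.8, 0.6, 0.6, 0.5], save_freq=2)
print(periodic)  # epochs that would write a numbered checkpoint
print(best)      # epochs that would overwrite best_model.ckpt
```

Because the comparison is strict, an epoch that only ties the best loss does not rewrite best_model.ckpt.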
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .visualize import *
3 | from .integrate import *
--------------------------------------------------------------------------------
/utils/deferentiate_vel.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | import numpy as np
3 |
4 | def interp_xyz(time, opt_time, xyz):
5 |
6 | intep_x = np.interp(time, xp=opt_time, fp = xyz[:,0])
7 | intep_y = np.interp(time, xp=opt_time, fp = xyz[:,1])
8 | intep_z = np.interp(time, xp=opt_time, fp = xyz[:,2])
9 |
10 | inte_xyz = np.stack([intep_x, intep_y, intep_z]).transpose()
11 | return inte_xyz
12 |
13 | def gradientvelo(xyz, imu_time, time):
14 |
15 | inte_xyz = interp_xyz(imu_time, time, xyz)
16 | time_interval = imu_time[1:] - imu_time[:-1]
17 | time_interval = np.append(time_interval, time_interval.mean())
18 | velo_d = np.einsum('nd, n -> nd', np.gradient(inte_xyz, axis=0), 1/time_interval)
19 |
20 | return velo_d
21 |
22 |
23 | if __name__ == '__main__':
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument('--root', type=str, default='/data/datasets/yuhengq/tumvio')
26 | parser.add_argument('--seq', nargs='+', default=["dataset-room1_512_16", "dataset-room2_512_16", "dataset-room3_512_16",
27 | "dataset-room4_512_16", "dataset-room5_512_16", "dataset-room6_512_16"])
28 | parser.add_argument("--device", type=str, default='cuda:0', help="cuda or cpu")
29 | parser.add_argument('--load_ckpt', default=False, action="store_true")
30 |
31 | args = parser.parse_args(); print(args)
32 |
33 | for seq in args.seq:
34 | print(seq)
35 | gt_data = np.loadtxt(os.path.join(args.root, seq, "mav0/mocap0/data.csv"), dtype=float, delimiter=',')
36 | imu_data = np.loadtxt(os.path.join(args.root, seq, "mav0/imu0/data.csv"), dtype=float, delimiter=',')
37 |
38 | gt_time = gt_data[:,0]*1e-9
39 | xyz = gt_data[:,1:4]
40 |
41 | imu_time = imu_data[:, 0]*1e-9
42 | acc = imu_data[:, 4:]
43 |
44 | gt_velo = gradientvelo(xyz, imu_time, gt_time)
45 | to_save = np.concatenate([imu_time[:,None], gt_velo], 1)
46 | print("saving to ", os.path.join(args.root, seq, "mav0/mocap0/grad_velo.txt"))
47 | np.savetxt(os.path.join(args.root, seq, "mav0/mocap0/grad_velo.txt"), to_save)
48 |
49 |
--------------------------------------------------------------------------------
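The velocity computation in deferentiate_vel.py interpolates mocap positions onto the IMU timestamps and then takes finite differences. A minimal sketch of the same idea follows, assuming NumPy is available. The function name `finite_difference_velocity` is hypothetical, and unlike the script it passes the timestamps directly to `np.gradient`, which accounts for non-uniform sample spacing.

```python
import numpy as np

def finite_difference_velocity(t_query, t_src, xyz):
    """Interpolate positions onto t_query, then differentiate with respect to time."""
    # Linear interpolation of each axis onto the query timestamps.
    interp = np.stack(
        [np.interp(t_query, t_src, xyz[:, k]) for k in range(3)], axis=1
    )
    # Passing the timestamps lets np.gradient handle uneven spacing.
    return np.gradient(interp, t_query, axis=0)

# Synthetic check: constant-velocity motion sampled at 10 Hz.
t_src = np.linspace(0.0, 1.0, 11)
xyz = np.stack([2.0 * t_src, -1.0 * t_src, 0.5 * t_src], axis=1)
# Unevenly spaced query times, standing in for IMU timestamps.
t_query = np.array([0.0, 0.05, 0.2, 0.25, 0.6, 1.0])
vel = finite_difference_velocity(t_query, t_src, xyz)
print(vel)
```

For constant-velocity motion every row of `vel` should match the true velocity, which makes this a convenient sanity check before running on real sequences.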
/utils/integrate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 | from utils import move_to
4 |
5 |
6 | def integrate(integrator, loader, init, device="cpu", gtinit=False, save_full_traj=False, use_gt_rot=True):
7 | """
8 | gtinit:
9 | If gtinit is True, integrate each segment from its ground-truth initial state.
10 | This evaluates the local accuracy of the trajectory.
11 | If gtinit is False, integrate each segment from the predicted state,
12 | which is equivalent to integrating through the entire trajectory.
13 | save_full_traj:
14 | If save_full_traj is True, save the full trajectory.
15 | If save_full_traj is False, save only the last frame of each segment.
16 | """
17 | # states to output
18 | integrator.eval()
19 | out_state = dict()
20 | poses, poses_gt = [init['pos'][None,:]], [init['pos'][None,:]]
21 | orientations,orientations_gt = [init['rot'][None,:]], [init['rot'][None,:]]
22 | vel, vel_gt = [init['vel'][None,:]], [init['vel'][None,:]]
23 | covs = [torch.zeros(9, 9)]
24 | for idx, data in tqdm.tqdm(enumerate(loader)):
25 | data = move_to(data, device)
26 | if gtinit:
27 | init_state = {
28 | "pos": data["init_pos"][:,:1,:],
29 | "vel": data["init_vel"][:,:1,:],
30 | "rot": data["init_rot"][:,:1,:],
31 | }
32 | else:
33 | init_state = None
34 |
35 | init_rot = data['init_rot'] if use_gt_rot else None
36 | state = integrator(
37 | init_state = init_state, dt=data['dt'],
38 | gyro=data['gyro'], acc=data['acc'],
39 | rot=init_rot
40 | )
41 |
42 | if save_full_traj:
43 | vel.append(state['vel'][..., :, :].cpu())
44 | vel_gt.append(data['gt_vel'][..., :, :].cpu())
45 | orientations.append(state['rot'][..., :, :].cpu())
46 | orientations_gt.append(data['gt_rot'][..., :, :].cpu())
47 | poses_gt.append(data['gt_pos'][..., :, :].cpu())
48 | poses.append(state['pos'][..., :, :].cpu())
49 | else:
50 | vel.append(state['vel'][..., -1:, :].cpu())
51 | vel_gt.append(data['gt_vel'][..., -1:, :].cpu())
52 | orientations.append(state['rot'][..., -1:, :].cpu())
53 | orientations_gt.append(data['gt_rot'][..., -1:, :].cpu())
54 | poses_gt.append(data['gt_pos'][..., -1:, :].cpu())
55 | poses.append(state['pos'][..., -1:, :].cpu())
56 |
57 |
58 | covs.append(state['cov'][..., -1, :, :].cpu())
59 | out_state['vel'] = torch.cat(vel, dim=-2)
60 | out_state['vel_gt'] = torch.cat(vel_gt, dim=-2)
61 |
62 | out_state['orientations'] = torch.cat(orientations, dim=-2)
63 | out_state['orientations_gt'] = torch.cat(orientations_gt, dim=-2)
64 |
65 | out_state['poses'] = torch.cat(poses, dim=-2)
66 | out_state['poses_gt'] = torch.cat(poses_gt, dim=-2)
67 |
68 | out_state['covs'] = torch.stack(covs, dim=0)
69 | out_state['pos_dist'] = (out_state['poses'][:, 1:, :] - out_state['poses_gt'][:, 1:, :]).norm(dim=-1)
70 | out_state['vel_dist'] = (out_state['vel'][:, 1:, :] - out_state['vel_gt'][:, 1:, :]).norm(dim=-1)
71 | out_state['rot_dist'] = ((out_state['orientations_gt'][:, 1:, :].Inv() @ out_state['orientations'][:, 1:, :]).Log()).norm(dim=-1)
72 | return out_state
--------------------------------------------------------------------------------
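The difference between the two `gtinit` modes of `integrate` can be shown with a one-dimensional toy problem. The sketch below is not the PyPose integrator. `integrate_segment` and `segment_end_errors` are hypothetical helpers that integrate a biased accelerometer signal in segments, either restarting every segment from the true state (`gtinit=True`) or chaining the estimated state across segments (`gtinit=False`).

```python
def integrate_segment(pos, vel, acc_samples, dt):
    """Constant-acceleration update applied once per sample."""
    for a in acc_samples:
        pos += vel * dt + 0.5 * a * dt * dt
        vel += a * dt
    return pos, vel

def segment_end_errors(acc_meas_segments, acc_true_segments, dt, gtinit):
    """Position error at the end of each segment."""
    true_pos, true_vel = 0.0, 0.0
    est_pos, est_vel = 0.0, 0.0
    errors = []
    for acc_meas, acc_true in zip(acc_meas_segments, acc_true_segments):
        if gtinit:
            # Restart from the ground-truth state: error stays local.
            est_pos, est_vel = true_pos, true_vel
        est_pos, est_vel = integrate_segment(est_pos, est_vel, acc_meas, dt)
        true_pos, true_vel = integrate_segment(true_pos, true_vel, acc_true, dt)
        errors.append(abs(est_pos - true_pos))
    return errors

# Three 1-second segments at 10 Hz; the sensor reads a 0.1 m/s^2 bias at rest.
meas = [[0.1] * 10 for _ in range(3)]
true = [[0.0] * 10 for _ in range(3)]
local_errors = segment_end_errors(meas, true, dt=0.1, gtinit=True)
chained_errors = segment_end_errors(meas, true, dt=0.1, gtinit=False)
print(local_errors)
print(chained_errors)
```

With ground-truth restarts the error is the same in every segment, while the chained run grows quadratically with elapsed time. This is why the local mode isolates per-segment quality and the chained mode reflects whole-trajectory drift.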
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os, warnings
2 | import torch
3 | import io, pickle
4 | import numpy as np
5 | from pyproj import Proj, transform
6 | from inspect import currentframe, getframeinfo
7 |
8 | from scipy.spatial.transform import Rotation as R
9 | from scipy.spatial.transform import Slerp
10 |
11 | def save_state(out_states:dict, in_state:dict):
12 | for k, v in in_state.items():
13 | if v is None:
14 | continue
15 | elif isinstance(v, dict):
16 | save_state(out_states=out_states, in_state=v)
17 | elif k in out_states.keys():
18 | out_states[k].append(v)
19 | else:
20 | out_states[k] = [v]
21 |
22 | def trans_ecef2wgs(traj):
23 | # Define the coordinate reference systems
24 | ecef = Proj(proj='geocent', ellps='WGS84', datum='WGS84')
25 | wgs84 = Proj(proj='latlong', datum='WGS84')
26 |
27 | # Convert ECEF coordinates to geographic coordinates
28 | trajectory_geo = [transform(ecef, wgs84, x, y, z, radians=False) for x, y, z in traj]
29 |
30 | return trajectory_geo
31 |
32 | def Gaussian_noise(num_nodes, sigma_x=0.05 ,sigma_y=0.05, sigma_z=0.05):
33 | std = torch.stack([torch.ones(num_nodes)*sigma_x, torch.ones(num_nodes)*sigma_y, torch.ones(num_nodes)*sigma_z], dim=-1)
34 | return torch.normal(mean = 0, std = std)
35 |
36 | def move_to(obj, device):
37 | if torch.is_tensor(obj):return obj.to(device)
38 | elif obj is None:
39 | return None
40 | elif isinstance(obj, dict):
41 | res = {}
42 | for k, v in obj.items():
43 | res[k] = move_to(v, device)
44 | return res
45 | elif isinstance(obj, list):
46 | res = []
47 | for v in obj:
48 | res.append(move_to(v, device))
49 | return res
50 | elif isinstance(obj, np.ndarray):
51 | return torch.tensor(obj).to(device)
52 | else:
53 | raise TypeError("Invalid type for move_to", type(obj))
54 |
55 | def qinterp(qs, t, t_int):
56 | qs = R.from_quat(qs.numpy())
57 | slerp = Slerp(t, qs)
58 | interp_rot = slerp(t_int).as_quat()
59 | return torch.tensor(interp_rot)
60 |
61 | def lookAt(dir_vec, up = torch.tensor([0.,0.,1.], dtype=torch.float64), source = torch.tensor([0.,0.,0.], dtype=torch.float64)):
62 | '''
63 | dir_vec: the tensor shall be (1)
64 | return the rotation matrix of the
65 | '''
66 | if not isinstance(dir_vec, torch.Tensor):
67 | dir_vec = torch.tensor(dir_vec)
68 | def normalize(x):
69 | length = x.norm()
70 | if length< 0.005:
71 | length = 1
72 | warnings.warn('Normalization error: the norm is too small')
73 | return x/length
74 |
75 | zaxis = normalize(dir_vec - source)
76 | xaxis = normalize(torch.cross(zaxis, up))
77 | yaxis = torch.cross(xaxis, zaxis)
78 |
79 | m = torch.tensor([
80 | [xaxis[0], xaxis[1], xaxis[2]],
81 | [yaxis[0], yaxis[1], yaxis[2]],
82 | [zaxis[0], zaxis[1], zaxis[2]],
83 | ])
84 |
85 | return m
86 |
87 | def cat_state(in_state:dict):
88 | pop_list = []
89 | for k, v in in_state.items():
90 | if len(v[0].shape) > 2:
91 | in_state[k] = torch.cat(v, dim=-2)
92 | else:
93 | pop_list.append(k)
94 | for k in pop_list:
95 | in_state.pop(k)
96 |
97 | class CPU_Unpickler(pickle.Unpickler):
98 | def find_class(self, module, name):
99 | if module == 'torch.storage' and name == '_load_from_bytes':
100 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu')
101 | else:
102 | return super().find_class(module, name)
103 |
104 | def write_board(writer, objs, epoch_i, header = ''):
105 | # writer = SummaryWriter(log_dir=conf.general.exp_dir)
106 | if isinstance(objs, dict):
107 | for k, v in objs.items():
108 | if isinstance(v, float):
109 | writer.add_scalar(os.path.join(header, k), v, epoch_i)
110 | elif isinstance(objs, float):
111 | writer.add_scalar(header, objs, epoch_i) # log the scalar itself; `v` is undefined in this branch
112 |
113 | def report_hasNan(x):
114 | cf = currentframe().f_back
115 | res = torch.any(torch.isnan(x)).cpu().item()
116 | if res: print(f"[hasnan!] {getframeinfo(cf).filename}:{cf.f_lineno}")
117 |
118 | def report_hasNeg(x):
119 | cf = currentframe().f_back
120 | res = torch.any(x < 0).cpu().item()
121 | if res: print(f"[hasneg!] {getframeinfo(cf).filename}:{cf.f_lineno}")
122 |
--------------------------------------------------------------------------------
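A short usage example for `save_state` from utils.py may help, since its handling of nested dictionaries is easy to miss. The function body below is copied from the listing, and the input dictionaries are made-up placeholders (plain integers instead of tensors) so the example runs without PyTorch.

```python
def save_state(out_states: dict, in_state: dict):
    for k, v in in_state.items():
        if v is None:
            # Missing entries are skipped entirely.
            continue
        elif isinstance(v, dict):
            # Nested dicts are flattened into the same output dict.
            save_state(out_states=out_states, in_state=v)
        elif k in out_states.keys():
            out_states[k].append(v)
        else:
            out_states[k] = [v]

out = {}
save_state(out, {"pos": 1, "aux": {"cov": 2, "skip": None}})
save_state(out, {"pos": 3, "aux": {"cov": 4}})
print(out)
```

Note that the nesting key (`"aux"` here) does not appear in the output, so two nested dicts that share an inner key name would be merged into one list.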
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | import pypose as pp
5 | from scipy.spatial.transform import Rotation as T
6 |
7 | import matplotlib.pyplot as plt
8 | import os
9 | import argparse
10 |
11 | def visualize_rotations(rotations, rotations_gt, save_folder, save_prefix = ""):
12 | ## Visualize the euler angle
13 | gt_r_euler = rotations_gt.euler().cpu().numpy()
14 | r_euler = rotations.data.euler().cpu().numpy()
15 |
16 | fig, axs = plt.subplots(3,)
17 |
18 | fig.suptitle("integrated orientation v.s. gt orientation")
19 | axs[0].plot(r_euler[:,0])
20 | axs[0].plot(gt_r_euler[:,0])
21 | axs[0].legend(["euler_x", "euler_x_gt"])
22 |
23 | axs[1].plot(r_euler[:,1])
24 | axs[1].plot(gt_r_euler[:,1])
25 | axs[1].legend(["euler_y", "euler_y_gt"])
26 |
27 | axs[2].plot(r_euler[:,2])
28 | axs[2].plot(gt_r_euler[:,2])
29 | axs[2].legend(["euler_z", "euler_z_gt"])
30 | plt.savefig(os.path.join(save_folder, save_prefix + "orientation.png"))
31 |
32 | r = rotations_gt[0]
33 | gt_r_euler = ((r.Inv()@rotations_gt).euler()).cpu().numpy()
34 | r_euler = ((r.Inv()@rotations).euler()).cpu().numpy()
35 |
36 | fig, axs = plt.subplots(3,)
37 |
38 | fig.suptitle("Incremental orientation v.s. gt orientation")
39 | axs[0].plot(r_euler[:,0])
40 | axs[0].plot(gt_r_euler[:,0])
41 | axs[0].legend(["euler_x", "euler_x_gt"])
42 |
43 | axs[1].plot(r_euler[:,1])
44 | axs[1].plot(gt_r_euler[:,1])
45 | axs[1].legend(["euler_y", "euler_y_gt"])
46 |
47 | axs[2].plot(r_euler[:,2])
48 | axs[2].plot(gt_r_euler[:,2])
49 | axs[2].legend(["euler_z", "euler_z_gt"])
50 | plt.savefig(os.path.join(save_folder, save_prefix + "incremental_orientation.png"))
51 |
52 |
53 | def visualize_velocities(velocities, gt_velocities, save_folder, save_prefix = ""):
54 | fig, axs = plt.subplots(3,)
55 | velocities = velocities.detach().cpu().numpy()
56 |
57 | fig.suptitle("integrated velocity v.s. gt velocity")
58 | axs[0].plot(velocities[:,0])
59 | axs[0].plot(gt_velocities[:,0])
60 | axs[0].legend(["velocity", "gt velocity"])
61 |
62 | axs[1].plot(velocities[:,1])
63 | axs[1].plot(gt_velocities[:,1])
64 | axs[1].legend(["velocity", "gt velocity"])
65 |
66 | axs[2].plot(velocities[:,2])
67 | axs[2].plot(gt_velocities[:,2])
68 | axs[2].legend(["velocity", "gt velocity"])
69 | plt.savefig(os.path.join(save_folder, save_prefix + "velocity.png"))
70 |
71 |
72 | def plot_2d_traj(trajectory, trajectory_gt, save_folder, vis_length = None, save_prefix = ""):
73 |
74 | if torch.is_tensor(trajectory):
75 | trajectory = trajectory.detach().cpu().numpy()
76 | trajectory_gt = trajectory_gt.detach().cpu().numpy()
77 |
78 | plt.clf()
79 | plt.figure(figsize=(3, 3),facecolor=(1, 1, 1))
80 |
81 | ax = plt.axes()
82 | ax.plot(trajectory[:,0][:vis_length], trajectory[:,1][:vis_length], 'b')
83 | ax.plot(trajectory_gt[:,0][:vis_length], trajectory_gt[:,1][:vis_length], 'r')
84 | plt.title("PyPose IMU Integrator")
85 | plt.legend(["PyPose", "Ground Truth"],loc='right')
86 |
87 | plt.savefig(os.path.join(save_folder, save_prefix + "poses.png"))
88 |
89 | def plot_poses(points1, points2, title='', axlim=None, savefig = None):
90 | if torch.is_tensor(points1):
91 | points1 = points1.detach().cpu().numpy()
92 | if torch.is_tensor(points2):
93 | points2 = points2.detach().cpu().numpy()
94 |
95 | plt.figure(figsize=(7, 7))
96 | ax = plt.axes(projection='3d')
97 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b', label = "KF")
98 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r', label = "ground truth")
99 |
100 | plt.title(title)
101 | ax.legend()
102 | if axlim is not None:
103 | ax.set_xlim(axlim[0])
104 | ax.set_ylim(axlim[1])
105 | ax.set_zlim(axlim[2])
106 | if savefig is not None:
107 | plt.savefig(savefig)
108 | print('Saving to', savefig)
109 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
110 |
111 |
112 | def vis_cov_error_diag(pos_loss_xyzs, pred_pos_cov, save_folder, save_prefix = ""):
113 | """
114 | pred_pos_covs
115 | """
116 |
117 | for i, axis in enumerate(["x", "y", "z"]):
118 | if torch.is_tensor(pos_loss_xyzs):
119 | pos_loss_xyzs = pos_loss_xyzs.detach().cpu().numpy()
120 | pred_pos_cov = pred_pos_cov.detach().cpu().numpy()
121 |
122 | plt.clf()
123 | plt.grid()
124 | plt.ylim(0.0, 1)
125 | plt.scatter(pos_loss_xyzs[:,i], np.sqrt(pred_pos_cov)[:,i], marker="o", s = 2) # inputs are already NumPy arrays here
126 | plt.legend(["covariance","error"], loc='best') # 'left' is not a valid legend location
127 |
128 | plt.savefig(os.path.join(save_folder, save_prefix + "cov-error_%s.png"%axis))
129 |
130 |
131 | def vis_rotation_error(ts, error, save_folder):
132 | title = "$SO(3)$ orientation error"
133 | ts = ts -ts[0]
134 |
135 | fig, axs = plt.subplots(3, 1, sharex=True, figsize=(20, 12))
136 | axs[0].set(ylabel='roll (deg)', title=title)
137 | axs[1].set(ylabel='pitch (deg)')
138 | axs[2].set(xlabel='$t$ (s)', ylabel='yaw (deg)')
139 |
140 | for i in range(3):
141 | # axs[i].plot(ts, raw_err[:, i], color='red', label=r'raw IMU')
142 | axs[i].plot(ts, 180./np.pi * error[:, i].detach().cpu().numpy(), color='blue', label=r'net IMU')
143 | axs[i].set_ylim(-10, 10)
144 | axs[i].set_xlim(ts[0], ts[-1])
145 |
146 | for i in range(len(axs)):
147 | axs[i].grid()
148 | axs[i].legend()
149 | fig.tight_layout()
150 |
151 | plt.savefig(save_folder + '_orientation_error.png')
152 |
153 |
154 | def vis_corrections( error, save_folder):
155 | title = "acc & gyro correction"
156 |
157 | fig, axs = plt.subplots(6, 1, sharex=True, figsize=(20, 24))
158 | axs[0].set(ylabel='x (m)', title=title)
159 | axs[1].set(ylabel='y (m)')
160 | axs[2].set(ylabel='z (m)')
161 | axs[3].set(ylabel='roll (deg)')
162 | axs[4].set(ylabel='pitch (deg)')
163 | axs[5].set(ylabel='yaw (deg)')
164 |
165 | for i in range(3):
166 | # axs[i].plot(ts, raw_err[:, i], color='red', label=r'raw IMU')
167 | axs[i].plot(error[:,i].detach().cpu().numpy(), color='blue', label=r'net IMU')
168 | # axs[i].set_ylim(-0.2, 0.2)
169 |
170 | for i in range(3,6):
171 | # axs[i].plot(ts, raw_err[:, i], color='red', label=r'raw IMU')
172 | axs[i].plot(180./np.pi * error[:, i].detach().cpu().numpy(), color='blue', label=r'net IMU')
173 | # axs[i].set_ylim(-2, 2)
174 |
175 | for i in range(len(axs)):
176 | axs[i].grid()
177 | axs[i].legend()
178 | fig.tight_layout()
179 |
180 | plt.savefig(save_folder + 'corrections.png')
181 |
182 |
183 |
184 | def plot_and_save(points, title='', axlim=None, savefig = None):
185 | points = points.detach().cpu().numpy()
186 | plt.figure(figsize=(7, 7))
187 | ax = plt.axes(projection='3d')
188 | ax.plot3D(points[:,0], points[:,1], points[:,2], 'b')
189 | plt.title(title)
190 | if axlim is not None:
191 | ax.set_xlim(axlim[0])
192 | ax.set_ylim(axlim[1])
193 | ax.set_zlim(axlim[2])
194 | if savefig is not None:
195 | plt.savefig(savefig)
196 | print('Saving to', savefig)
197 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
198 |
199 | def plot2_and_save(points1, points2, title='', axlim=None, savefig = None):
200 | points1 = points1.detach().cpu().numpy()
201 | points2 = points2.detach().cpu().numpy()
202 | plt.figure(figsize=(7, 7))
203 | ax = plt.axes(projection='3d')
204 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b')
205 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r')
206 | plt.title(title)
207 | if axlim is not None:
208 | ax.set_xlim(axlim[0])
209 | ax.set_ylim(axlim[1])
210 | ax.set_zlim(axlim[2])
211 | if savefig is not None:
212 | plt.savefig(savefig)
213 | print('Saving to', savefig)
214 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
215 |
216 | # .plot(x, y, marker="o", markersize=20, markeredgecolor="red", markerfacecolor="green")
217 |
218 | def plot_nodes(points1, points2, point_nodes, title='', axlim=None, savefig = None):
219 | points1 = points1.detach().cpu().numpy()
220 | points2 = points2.detach().cpu().numpy()
221 | point_nodes = point_nodes.detach().cpu().numpy()
222 | plt.figure(figsize=(7, 7))
223 | ax = plt.axes(projection='3d')
224 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b', label = "KF")
225 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r', label = "ground truth")
226 | ax.scatter(point_nodes[:,0], point_nodes[:,1], point_nodes[:,2], marker="o",
227 | facecolor = "yellow", edgecolor="green", label = "GPS signal")
228 |
229 | plt.title(title)
230 | ax.legend()
231 | if axlim is not None:
232 | ax.set_xlim(axlim[0])
233 | ax.set_ylim(axlim[1])
234 | ax.set_zlim(axlim[2])
235 | if savefig is not None:
236 | plt.savefig(savefig)
237 | print('Saving to', savefig)
238 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
239 |
240 | def plot_trajs(points, labels, title='', axlim=None, savefig = None):
241 | for i, p in enumerate(points):
242 | if torch.is_tensor(p):
243 | points[i] = points[i].detach().cpu().numpy()
244 |
245 | plt.figure(figsize=(7, 7))
246 | ax = plt.axes(projection='3d')
247 | for i, p in enumerate(points):
248 | ax.plot3D(p[:,0], p[:,1], p[:,2], label = labels[i])
249 |
250 | plt.title(title)
251 | ax.legend()
252 | if axlim is not None:
253 | ax.set_xlim(axlim[0])
254 | ax.set_ylim(axlim[1])
255 | ax.set_zlim(axlim[2])
256 | if savefig is not None:
257 | plt.savefig(savefig)
258 | print('Saving to', savefig)
259 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
260 |
261 | def plot_nodes_2d(points1, points2, point_nodes, title='', axlim=None, savefig = None):
262 | points1 = points1.detach().cpu().numpy()
263 | points2 = points2.detach().cpu().numpy()
264 | point_nodes = point_nodes.detach().cpu().numpy()
265 | plt.figure(figsize=(7, 7))
266 | ax = plt.axes(projection='3d')
267 | ax.plot3D(points1[:,0], points1[:,1], points1[:,2], 'b', label = "KF")
268 | ax.plot3D(points2[:,0], points2[:,1], points2[:,2], 'r', label = "ground truth")
269 | ax.scatter(point_nodes[:,0], point_nodes[:,1], point_nodes[:,2], marker="o",
270 | facecolor = "yellow", edgecolor="green", label = "GPS signal")
271 |
272 | plt.title(title)
273 | ax.legend()
274 | if axlim is not None:
275 | ax.set_xlim(axlim[0])
276 | ax.set_ylim(axlim[1])
277 | ax.set_zlim(axlim[2])
278 | if savefig is not None:
279 | plt.savefig(savefig)
280 | print('Saving to', savefig)
281 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
282 |
283 |
284 |
285 | def plot_trajs(points, labels, title='', axlim=None, savefig = None):
286 | for i, p in enumerate(points):
287 | if torch.is_tensor(p):
288 | points[i] = points[i].detach().cpu().numpy()
289 |
290 | plt.figure(figsize=(7, 7))
291 | ax = plt.axes(projection='3d')
292 | for i, p in enumerate(points):
293 | ax.plot3D(p[:,0], p[:,1], p[:,2], label = labels[i])
294 |
295 | plt.title(title)
296 | ax.legend()
297 |
298 | ## Take the largest range
299 | x_len = ax.get_xlim()[1] - ax.get_xlim()[0]
300 | y_len = ax.get_ylim()[1] - ax.get_ylim()[0]
301 | z_len = ax.get_zlim()[1] - ax.get_zlim()[0]
302 | x_mean = np.mean(ax.get_xlim())
303 | y_mean = np.mean(ax.get_ylim())
304 | z_mean = np.mean(ax.get_zlim())
305 | _len = np.max([x_len, y_len, z_len])
306 |
307 | ax.set_xlim(x_mean - _len / 2, x_mean + _len / 2)
308 | ax.set_ylim(y_mean - _len / 2, y_mean + _len / 2)
309 | ax.set_zlim(z_mean - _len / 2, z_mean + _len / 2)
310 |
311 | if axlim is not None:
312 | ax.set_xlim(axlim[0])
313 | ax.set_ylim(axlim[1])
314 | ax.set_zlim(axlim[2])
315 | if savefig is not None:
316 | plt.savefig(savefig)
317 | print('Saving to', savefig)
318 | return ax.get_xlim(), ax.get_ylim(), ax.get_zlim()
--------------------------------------------------------------------------------
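The "take the largest range" step in the second `plot_trajs` definition makes the three axes span the same length so the 3D trajectory is not visually distorted. The arithmetic can be isolated as below. `equal_axis_limits` is a hypothetical helper written for illustration and does not depend on Matplotlib.

```python
def equal_axis_limits(xlim, ylim, zlim):
    """Return per-axis limits that share the largest span, centred on each axis."""
    limits = [xlim, ylim, zlim]
    # The longest axis decides the common span.
    span = max(hi - lo for lo, hi in limits)
    # Re-centre every axis on its own midpoint with that span.
    return [
        ((lo + hi) / 2 - span / 2, (lo + hi) / 2 + span / 2)
        for lo, hi in limits
    ]

cube = equal_axis_limits((0.0, 2.0), (0.0, 10.0), (-1.0, 1.0))
print(cube)
```

In the listing, a caller-supplied `axlim` is applied after this step and therefore overrides the equalised limits.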
/utils/visualize_state.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import pypose as pp
5 | import matplotlib.pyplot as plt
6 | import matplotlib.patches as mpatches
7 |
8 |
9 | def visualize_state_error(save_prefix, relative_outstate, relative_infstate, \
10 | save_folder=None, mask=None, file_name="state_error_compare.png"):
11 | if mask is None:
12 | outstate_pos_err = relative_outstate['pos_dist'][0]
13 | outstate_vel_err = relative_outstate['vel_dist'][0]
14 | outstate_rot_err = relative_outstate['rot_dist'][0]
15 |
16 | infstate_pos_err = relative_infstate['pos_dist'][0]
17 | infstate_vel_err = relative_infstate['vel_dist'][0]
18 | infstate_rot_err = relative_infstate['rot_dist'][0]
19 | else:
20 | outstate_pos_err = relative_outstate['pos_dist'][0, mask]
21 | outstate_vel_err = relative_outstate['vel_dist'][0, mask]
22 | outstate_rot_err = relative_outstate['rot_dist'][0, mask]
23 |
24 | infstate_pos_err = relative_infstate['pos_dist'][0, mask]
25 | infstate_vel_err = relative_infstate['vel_dist'][0, mask]
26 | infstate_rot_err = relative_infstate['rot_dist'][0, mask]
27 |
28 | fig, axs = plt.subplots(3,)
29 | fig.suptitle("Integration error vs AirIMU Integration error")
30 |
31 | axs[0].plot(outstate_pos_err,color = 'b',linewidth=1)
32 | axs[0].plot(infstate_pos_err,color = 'red',linewidth=1)
33 | axs[0].legend(["integration_pos_error", "AirIMU_pos_error"])
34 | axs[0].grid(True)
35 |
36 | axs[1].plot(outstate_vel_err,color = 'b',linewidth=1)
37 | axs[1].plot(infstate_vel_err,color = 'red',linewidth=1)
38 | axs[1].legend(["integration_vel_error", "AirIMU_vel_error"])
39 | axs[1].grid(True)
40 |
41 | axs[2].plot(outstate_rot_err,color = 'b',linewidth=1)
42 | axs[2].plot(infstate_rot_err,color = 'red',linewidth=1)
43 | axs[2].legend(["integration_rot_error", "AirIMU_rot_error"])
44 | axs[2].grid(True)
45 |
46 | plt.tight_layout()
47 | if save_folder is not None:
48 | plt.savefig(os.path.join(save_folder, save_prefix + file_name), dpi = 300)
49 | plt.show()
50 |
51 |
52 | def visualize_rotations(save_prefix, gt_rot, out_rot, inf_rot = None,save_folder=None):
53 |
54 | gt_euler = 180./np.pi* pp.SO3(gt_rot).euler()
55 | outstate_euler = 180./np.pi* pp.SO3(out_rot).euler()
56 |
57 | legend_list = ["roll","pitch", "yaw"]
58 | fig, axs = plt.subplots(3,)
59 | fig.suptitle("integrated orientation")
60 | for i in range(3):
61 | axs[i].plot(outstate_euler[:,i],color = 'b',linewidth=0.9)
62 | axs[i].plot(gt_euler[:,i],color = 'mediumseagreen',linewidth=0.9)
63 | axs[i].legend(["Integrated_"+legend_list[i],"gt_"+legend_list[i]])
64 | axs[i].grid(True)
65 |
66 | if inf_rot is not None:
67 | infstate_euler = 180./np.pi* pp.SO3(inf_rot).euler()
68 | print(infstate_euler.shape)
69 | for i in range(3):
70 | axs[i].plot(infstate_euler[:,i],color = 'red',linewidth=0.9)
71 | axs[i].legend(["Integrated_"+legend_list[i],"gt_"+legend_list[i],"AirIMU_"+legend_list[i]])
72 | plt.tight_layout()
73 | if save_folder is not None:
74 | plt.savefig(os.path.join(save_folder, save_prefix+ "_orientation_compare.png"), dpi = 300)
75 | plt.show()
76 |
77 |
78 | def visualize_trajectory(save_prefix, save_folder, outstate, infstate):
79 | gt_x, gt_y, gt_z = torch.split(outstate["poses_gt"][0].cpu(), 1, dim=1)
80 | rawTraj_x, rawTraj_y, rawTraj_z = torch.split(outstate["poses"][0].cpu(), 1, dim=1)
81 | airTraj_x, airTraj_y, airTraj_z = torch.split(infstate["poses"][0].cpu(), 1, dim=1)
82 |
83 | fig, ax = plt.subplots()
84 | ax.plot(rawTraj_x, rawTraj_y, label="Raw")
85 | ax.plot(airTraj_x, airTraj_y, label="AirIMU")
86 | ax.plot(gt_x , gt_y , label="Ground Truth")
87 |
88 | ax.set_xlabel('X axis')
89 | ax.set_ylabel('Y axis')
90 | ax.legend()
91 | ax.set_aspect('equal', adjustable='box')
92 |
93 | plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_xy.png"), dpi = 300)
94 | plt.close()
95 |
96 | ###########################################################
97 |
98 | fig, ax = plt.subplots()
99 | ax.plot(rawTraj_x, rawTraj_z, label="Raw")
100 | ax.plot(airTraj_x, airTraj_z, label="AirIMU")
101 | ax.plot(gt_x , gt_z , label="Ground Truth")
102 |
103 | ax.set_xlabel('X axis')
104 | ax.set_ylabel('Z axis')
105 | ax.legend()
106 | ax.set_aspect('equal', adjustable='box')
107 | plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_xz.png"), dpi = 300)
108 | plt.close()
109 |
110 | ###########################################################
111 |
112 | fig, ax = plt.subplots()
113 | ax.plot(rawTraj_y, rawTraj_z, label="Raw")
114 | ax.plot(airTraj_y, airTraj_z, label="AirIMU")
115 | ax.plot(gt_y , gt_z , label="Ground Truth")
116 |
117 | ax.set_xlabel('Y axis')
118 | ax.set_ylabel('Z axis')
119 | ax.legend()
120 | ax.set_aspect('equal', adjustable='box')
121 | plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_yz.png"), dpi = 300)
122 | plt.close()
123 |
124 | ###########################################################
125 |
126 | fig = plt.figure()
127 | ax = fig.add_subplot(111, projection='3d')
128 |
129 | elevation_angle = 20 # Change the elevation angle (view from above/below)
130 | azimuthal_angle = 30 # Change the azimuthal angle (rotate around z-axis)
131 |
132 | ax.view_init(elevation_angle, azimuthal_angle) # Set the view
133 |
134 | # Plotting the ground truth and inferred poses
135 | ax.plot(rawTraj_x, rawTraj_y, rawTraj_z, label="Raw")
136 | ax.plot(airTraj_x, airTraj_y, airTraj_z, label="AirIMU")
137 | ax.plot(gt_x , gt_y , gt_z , label="Ground Truth")
138 |
139 | # Adding labels
140 | ax.set_xlabel('X axis')
141 | ax.set_ylabel('Y axis')
142 | ax.set_zlabel('Z axis')
143 | ax.legend()
144 |
145 | plt.savefig(os.path.join(save_folder, save_prefix+ "_trajectory_3d.png"), dpi = 300)
146 | plt.close()
147 |
148 |
149 | def box_plot_wrapper(ax, data, edge_color, fill_color, **kwargs):
150 | bp = ax.boxplot(data, **kwargs)
151 |
152 | for element in ['boxes', 'whiskers', 'fliers', 'means', 'medians', 'caps']:
153 | plt.setp(bp[element], color=edge_color)
154 |
155 | for patch in bp['boxes']:
156 | patch.set(facecolor=fill_color)
157 |
158 | return bp
159 |
160 |
161 | def plot_boxes(folder, input_data, metrics, show_metrics):
162 | fig, ax = plt.subplots(dpi=300)
163 | raw_ticks = [_-0.12 for _ in range(1, len(metrics) + 1)]
164 | air_ticks = [_+0.12 for _ in range(1, len(metrics) + 1)]
165 | label_ticks = [_ for _ in range(1, len(metrics) + 1)]
166 |
167 | raw_data = [input_data[metric + "(raw)" ] for metric in metrics]
168 | air_data = [input_data[metric + "(AirIMU)"] for metric in metrics]
169 |
170 | # ax.boxplot(data, patch_artist=True, positions=ticks, widths=.2)
171 | box_plot_wrapper(ax, raw_data, edge_color="black", fill_color="royalblue", positions=raw_ticks, patch_artist=True, widths=.2)
172 | box_plot_wrapper(ax, air_data, edge_color="black", fill_color="gold", positions=air_ticks, patch_artist=True, widths=.2)
173 | ax.set_xticks(label_ticks)
174 | ax.set_xticklabels(show_metrics)
175 |
176 | # Create color patches for legend
177 | gold_patch = mpatches.Patch(color='gold', label='AirIMU')
178 | royalblue_patch = mpatches.Patch(color='royalblue', label='Raw')
179 | ax.legend(handles=[gold_patch, royalblue_patch])
180 |
181 | plt.savefig(os.path.join(folder, "Metrics.png"), dpi = 300)
182 | plt.close()
183 |
184 |
--------------------------------------------------------------------------------