├── .gitignore ├── Doc ├── Algorithm Description.html ├── Algorithm Description.md ├── Manual.html ├── Manual.md └── dataflow.svg ├── README.md ├── Ref ├── A Drift Eliminated Attitude & Position Estimation Algorithm In 3D.pdf ├── Adaptive Zero Velocity Update Based On Velocity Classification For Pedestrian.pdf ├── Drift reduction in IMU-only pedestrian navigation system in unstructured environment.pdf └── Using Inertial Sensors for Position and Orientation Estimation.pdf ├── butter.py ├── data_receiver.py ├── main.ipynb ├── main.py ├── mathlib.py └── plotlib.py /.gitignore: -------------------------------------------------------------------------------- 1 | /__pycache__/ 2 | /.ipynb_checkpoints/ 3 | .vscode/.ropeproject/config.py 4 | .vscode/.ropeproject/objectdb 5 | data.txt 6 | -------------------------------------------------------------------------------- /Doc/Algorithm Description.md: -------------------------------------------------------------------------------- 1 | # Algorithm 2 | ## Definitions 3 | Quaternion is defined as $q=[q_{scalar}\ q_{vector}^T]^T=[q_0\ q_1\ q_2\ q_3]^T$ 4 | 5 | The rotation matrix of a quaternion is 6 | $$ 7 | R(q)= 8 | \begin{bmatrix} 9 | q_0^2+q_1^2-q_2^2-q_3^2 & 2q_1q_2-2q_0q_3 & 2q_1q_3+2q_0q_2 \\ 10 | 2q_1q_2+2q_0q_3 & q_0^2-q_1^2+q_2^2-q_3^2 & 2q_2q_3-2q_0q_1 \\ 11 | 2q_1q_3-2q_0q_2 & 2q_2q_3+2q_0q_1 & q_0^2-q_1^2-q_2^2+q_3^2 12 | \end{bmatrix} 13 | $$ 14 | 15 | Define 16 | $$ 17 | \Omega(\omega)= 18 | \begin{bmatrix} 19 | 0 & -\omega_x & -\omega_y & -\omega_z \\ 20 | \omega_x & 0 & \omega_z & -\omega_y \\ 21 | \omega_y & -\omega_z & 0 & \omega_x \\ 22 | \omega_z & \omega_y & -\omega_x & 0 23 | \end{bmatrix} 24 | $$ 25 | 26 | $$ 27 | G(q)=\frac{1}{2} 28 | \begin{bmatrix} 29 | -q_1 & -q_2 & -q_3 \\ 30 | q_0 & -q_3 & q_2 \\ 31 | q_3 & q_0 & -q_1 \\ 32 | -q_2 & q_1 & q_0 33 | \end{bmatrix} 34 | $$ 35 | 36 | ## Initialization 37 | $$q=[1\ 0\ 0\ 0]^T$$ 38 | $$P=1\times10^{-8}*I_4$$ 39 | 40 | I assume the device stays still for a 
certain period of time during initialization, so we can get an initial gravity vector
$$g_n$$
and magnetic field vector
$$m_n$$
The notation $\cdot_n$ means navigation frame.

This can also be done using the QUEST algorithm and many other algorithms.


## Propagation
The state vector is the quaternion $q$.

State transition matrix:
$$F_t=I_4+\dfrac{1}{2}dt*\Omega(\omega_t)$$

Then we derive the process noise, with $G=G(q)$:
$$Q=(GyroNoise*dt)^2*GG^T$$

Make the estimation:
$$q=F_tq$$
$$P=F_tPF_t^T+Q$$

and finally normalize $q$ to reduce error:
$$q=\dfrac{q}{||q||_2}$$

## Measurement Update

Here we use only the unit vectors of our measurements to avoid accumulation of noise errors.

$$ea=\dfrac{a_t}{||a_t||_2}$$
$$em=\dfrac{m_t}{||m_t||_2}$$

and the predictions of these unit vectors are as follows:
$$pa=Normalize(-R(q)g_n)$$
$$pm=Normalize(R(q)m_n)$$

so the error is:
$$
\epsilon_t=
\begin{bmatrix}
ea\\em
\end{bmatrix}_{6\times1}
-
\begin{bmatrix}
pa\\pm
\end{bmatrix}_{6\times1}
$$

Note that $a_t$ and $m_t$ come in as $3\times1$ vectors, so don't get confused over dimensions.

We then use the measurement matrix to calculate the Kalman gain. The measurement matrix $H$ is defined as
$$
H=
\begin{bmatrix}
-\dfrac{\partial}{\partial q}(R(q)g_n) \\
\dfrac{\partial}{\partial q}(R(q)m_n)
\end{bmatrix}_{6\times4}
$$
which is annoying to calculate, but it's done in `mathlib.py`.

As for the sensor noise $R$, you can either just use $R=C*I_6$ or some other fancy definition.
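As a rough illustration of the Propagation section above, the step can be sketched in NumPy as follows. This is only a sketch under stated assumptions: the function names `omega`, `g_matrix` and `propagate`, the noise value `gyro_noise=0.01`, and the use of a flat 4-element array for $q$ are all made up for this example and are not taken from `mathlib.py`.

```python
import numpy as np


def omega(w):
    """Omega(w) as defined in the Definitions section (w in rad/s)."""
    wx, wy, wz = w
    return np.array([
        [0.0, -wx, -wy, -wz],
        [wx, 0.0, wz, -wy],
        [wy, -wz, 0.0, wx],
        [wz, wy, -wx, 0.0],
    ])


def g_matrix(q):
    """G(q) as defined in the Definitions section."""
    q0, q1, q2, q3 = q
    return 0.5 * np.array([
        [-q1, -q2, -q3],
        [q0, -q3, q2],
        [q3, q0, -q1],
        [-q2, q1, q0],
    ])


def propagate(q, P, w, dt, gyro_noise):
    """One propagation step: q = F q, P = F P F^T + Q, then normalize q."""
    F = np.eye(4) + 0.5 * dt * omega(w)
    G = g_matrix(q)  # evaluated at the quaternion before the update
    Q = (gyro_noise * dt) ** 2 * (G @ G.T)
    q = F @ q
    P = F @ P @ F.T + Q
    q = q / np.linalg.norm(q)
    return q, P


# Hypothetical usage: rotate at pi/2 rad/s about z for 1 s (100 steps of 10 ms).
q = np.array([1.0, 0.0, 0.0, 0.0])
P = 1e-8 * np.eye(4)
for _ in range(100):
    q, P = propagate(q, P, w=(0.0, 0.0, np.pi / 2), dt=0.01, gyro_noise=0.01)
print(q)  # roughly a 90 degree rotation about the z axis
```

Normalizing after every step keeps the first-order integration from slowly inflating the quaternion norm, which is why the result stays close to a pure rotation.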
Then it's just the usual Kalman filter update:
$$S=HPH^T+R$$
$$K=PH^TS^{-1}$$
$$q=q+K\epsilon_t$$
$$P=P-KHP$$

## Post Correction

Normalize $q$ again
$$q=\dfrac{q}{||q||_2}$$
and make sure $P$ is symmetric
$$P=\dfrac{1}{2}(P+P^T)$$

## Double Integration

Body frame acceleration:
$$a_b=a_t+R(q)g_n$$
and we can now update position and velocity
$$position=position + velocity*dt + \dfrac{1}{2}a_bdt^2$$
$$velocity=velocity+a_bdt$$
--------------------------------------------------------------------------------
/Doc/Manual.html:
--------------------------------------------------------------------------------
Manual

Manual

54 |

Dependencies

55 | 60 |

Project Structure

61 | 74 |

main.py

75 |

API

76 |

IMUTracker class

77 | 204 |

Usage

205 |
    206 |
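  1. Initialize the IMUTracker class
  2. Feed the initialization data to the initialize method
  3. Pass the list returned by initialize, together with the sensor data, to the attitudeTrack method
  4. Use the removeAccErr method to correct the acceleration data returned by attitudeTrack
  5. Use the zupt method to compute the velocity
  6. Use the positionTrack method to integrate the velocity and acceleration data to get the displacement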
213 |

Example

214 |
data = sensor.getData()
tracker = IMUTracker(sampling=100)

# init
init_list = tracker.initialize(data[5:30])
# EKF
a_nav, orix, oriy, oriz = tracker.attitudeTrack(data[30:], init_list)
# filter a_nav
a_nav_filtered = tracker.removeAccErr(a_nav, filter=False)
# get velocity
v = tracker.zupt(a_nav_filtered, threshold=0.2)
# get position
p = tracker.positionTrack(a_nav_filtered, v)
228 |
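The last step of the usage above, integrating velocity and acceleration into displacement, can be sketched as below. This is a minimal sketch and not the code of `positionTrack` itself: the function name `integrate_position`, the zero starting position, and the convention that `velocities[t]` is the velocity at the start of step `t` are assumptions made for illustration only.

```python
import numpy as np


def integrate_position(a_nav, velocities, dt):
    """Hypothetical stand-in for positionTrack, not the repository's code.

    Applies p[t+1] = p[t] + v[t]*dt + 0.5*a[t]*dt**2 to (n x 3) arrays and
    returns the (n x 3) positions reached after each step, starting from 0.
    """
    a_nav = np.asarray(a_nav, dtype=float)
    velocities = np.asarray(velocities, dtype=float)
    steps = velocities * dt + 0.5 * a_nav * dt**2
    return np.cumsum(steps, axis=0)


# Constant 2 m/s^2 along x for 1 s, sampled at 10 Hz.
dt = 0.1
a_nav = np.tile([2.0, 0.0, 0.0], (10, 1))
# Velocity at the start of each step: v[t] = a * t * dt.
velocities = a_nav * (np.arange(10)[:, None] * dt)
p = integrate_position(a_nav, velocities, dt)
print(p[-1])
```

With constant acceleration this update is exact, so the final x position should agree with the closed form 0.5 * a * t^2 up to floating-point rounding.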

Flowchart

229 |
230 | 231 |
232 |

plotlib.py

233 |

API

234 | 338 |

Example

339 |

Pass the data in directly to plot:

340 |
plot3([acc1, acc2])
plot3D([[position, 'position'], [orientation, 'orientation']])
343 |

plot3() uses a 3-row layout by default; a 3-column layout can also be used:

344 |
import matplotlib.pyplot as plt
fig, ax = plt.subplots(nrows=1, ncols=3)
plot3([acc1, acc2], ax=ax)
--------------------------------------------------------------------------------
/Doc/Manual.md:
--------------------------------------------------------------------------------
- [Manual](#manual)
  - [Dependencies](#dependencies)
  - [Project Structure](#project-structure)
- [main.py](#mainpy)
  - [API](#api)
    - [IMUTracker class](#imutracker-class)
  - [Usage](#usage)
  - [Example](#example)
  - [Flowchart](#flowchart)
- [plotlib.py](#plotlibpy)
  - [API](#api-1)
  - [Example](#example-1)

# Manual

## Dependencies
- numpy
- scipy
- matplotlib

## Project Structure
- `main.py`: the main algorithm, with an example that receives data from a phone and visualizes it. Each run automatically saves the data to `data.txt`; set the `receive_data()` parameter `mode='file'` to read from that file instead.
- `plotlib.py`: simple wrappers for visualization.
- `mathlib.py`: wrappers for matrix operations and filters.
- `butter.py`: a realtime Butterworth filter, see [here](https://github.com/keikun555/Butter) for details. Currently unused.
- `main.ipynb`: used for development.
- `/Ref`:
  - *Using Inertial Sensors for Position and Orientation Estimation* is a basic tutorial that includes a fairly detailed description of the Kalman filter.
  - The others are papers related to the correction algorithms.


# main.py

## API
### IMUTracker class
- `__init__(self, sampling, data_order={'w': 1, 'a': 2, 'm': 3})`
  - `sampling`
    - Sampling rate, in Hz.
  - `data_order`
    - Order of the data: `w` for angular velocity, `a` for acceleration, `m` for the magnetometer. Default: `{'w': 1, 'a': 2, 'm': 3}`.


- `initialize(self, data, noise_coefficient={'w': 100, 'a': 100, 'm': 10})`

  Returns a list containing all the initialization values the EKF algorithm needs.

  - `data`
    - Sensor data, an $(n\times 9)$ numpy array.
    - Sensors sometimes produce useless data right after measurement starts, so preprocessing may be needed to drop obviously unrealistic points, e.g. the first 10 samples.
  - `noise_coefficient`
    - The sensor noise value covers both actual noise and measurement error. The noise is computed from the variance of the initialization data and then multiplied by a coefficient to give the sensor noise used in the algorithm.
    - The larger this value, the less the algorithm trusts the sensor, so that sensor gets a smaller weight in the filter correction.


- `attitudeTrack(self, data, init_list)`

  Computes attitude with an Extended Kalman Filter (EKF); the algorithm is described in `/Doc/Algorithm Description.html`. Returns the acceleration in the ground frame (with the gravity component removed) and the device orientation.

  Orientation is represented by three $n\times 3$ numpy arrays, the direction vectors (unit vectors) of the $XYZ$ axes. The initial state of the device is:
  $$\hat{x}=[1,0,0]^T\ \hat{y}=[0,1,0]^T\ \hat{z}=[0,0,1]^T$$

  After a right-handed $90\degree$ rotation about the $Z$ axis:
  $$\hat{x}=[0,1,0]^T\ \hat{y}=[-1,0,0]^T\ \hat{z}=[0,0,1]^T$$

  - `data`
    - Sensor data, an $(n\times 9)$ numpy array.
  - `init_list`
    - List of initialization values. The return value of `initialize` can be used directly, or a custom list can be supplied.
    - Order:
      - gravity vector in the ground frame
      - gravity magnitude
      - magnetic field direction **unit** vector
      - gyroscope noise
      - gyroscope bias
      - accelerometer noise
      - magnetometer noise


- `removeAccErr(self, a_nav, threshold=0.2, filter=False, wn=(0.01, 15))`

  Assuming the device is stationary before and after the measurement, removes the bias from the ground-frame acceleration and optionally passes it through a band-pass filter. Returns the corrected acceleration data.

  - `a_nav`
    - Ground-frame acceleration data, an $(n\times 3)$ numpy array.
  - `threshold`
    - Acceleration threshold for detecting the stationary state; setting it too low is not recommended.
  - `filter`
    - Filter switch.
    - Low-pass filtering removes spikes and high-pass filtering removes the DC component (a constant bias), but filtering is sometimes counterproductive, so use it as the situation requires.
  - `wn`
    - Cutoff frequencies of the filter.


- `zupt(self, a_nav, threshold)`

  Uses the Zero velocity UPdaTe algorithm to correct the velocity. Returns the corrected velocity data.

  - `a_nav`
    - Ground-frame acceleration data, an $(n\times 3)$ numpy array.
  - `threshold`
    - Acceleration threshold for detecting the stationary state; the more vigorous the motion, the higher the threshold should be.


- `positionTrack(self, a_nav, velocities)`

  Integrates the acceleration and velocity to get the displacement. Returns the displacement data.

  - `a_nav`
    - Ground-frame acceleration data, an $(n\times 3)$ numpy array.
  - `velocities`
    - Ground-frame velocity data, an $(n\times 3)$ numpy array.


## Usage
1. Initialize the `IMUTracker` class
2. Feed the initialization data to the `initialize` method
3. Pass the list returned by `initialize`, together with the sensor data, to the `attitudeTrack` method
4. Use the `removeAccErr` method to correct the acceleration data returned by `attitudeTrack`
5. Use the `zupt` method to compute the velocity
6. Use the `positionTrack` method to integrate the velocity and acceleration data to get the displacement

## Example
```python
data = sensor.getData()
tracker = IMUTracker(sampling=100)

# init
init_list = tracker.initialize(data[5:30])
# EKF
a_nav, orix, oriy, oriz = tracker.attitudeTrack(data[30:], init_list)
# filter a_nav
a_nav_filtered = tracker.removeAccErr(a_nav, filter=False)
# get velocity
v = tracker.zupt(a_nav_filtered, threshold=0.2)
# get position
p = tracker.positionTrack(a_nav_filtered, v)
```

## Flowchart

(Figure: `dataflow.svg`)


# plotlib.py

## API
- `plot3(data, ax=None, lims=None, labels=None, show=False, show_legend=False)`

  Accepts multiple $(n\times 3)$ data sets, e.g. acceleration along $XYZ$, and draws the 3 components in 3 separate plots. No other parameters need to be set by default. Returns the matplotlib axes object used.

  - `data`
    - List containing the data: `[data1, data2, ...]`. Each element is an $(n\times 3)$ numpy array.
  - `ax`
    - matplotlib axes object; created separately by default.
  - `lims`
    - Axis limits: `[[[xl, xh], [yl, yh]], ...]`.
    - Nested 3 levels deep.
  - `labels`
    - Labels used in the legend: `[[x_label1, y_label1, z_label1], [x_label2, y_label2, z_label2], ...]`.
    - One-to-one with the data sets in `data`.
  - `show`
    - Whether to call `plt.show()`; of little use and can be ignored.
  - `show_legend`
    - Whether to show the legend; an empty legend is shown if no labels are defined.
    - The legend sometimes covers the curves, so this switch is defined separately for convenience.

- `plot3D(data, lim=None, ax=None)`

  Accepts multiple $(n\times 3)$ data sets and draws a 3D plot.

  - `data`
    - List containing the data: `[[data1, label1], [data2, label2], ...]`.
    - Unlike `plot3`, labels are mandatory here.
  - `lim`
    - Limits of the $XYZ$ axes: `[[xl, xh], [yl, yh], [zl, zh]]`.
    - Setting a single axis to `None` uses the default for that axis, e.g. `[[xl, xh], [yl, yh], None]`.
  - `ax`
    - matplotlib axes object used for plotting; users normally do not need to set it.

- `plot3DAnimated(data, lim=[[-1, 1], [-1, 1], [-1, 1]], label=None, interval=10, show=True, repeat=False)`

  Generates a 3D animation.

  Adjusting the viewing angle is somewhat fiddly; see [axes3d.view_init](https://matplotlib.org/mpl_toolkits/mplot3d/api.html#mpl_toolkits.mplot3d.axes3d.Axes3D.view_init).

  Saving the animation requires ffmpeg and then `ani.save()`; see [this matplotlib documentation example](https://matplotlib.org/gallery/animation/basic_example_writer_sgskip.html).

  - `data`
    - An $(n\times 3)$ numpy array.
  - `lim`
    - Axis ranges.
    - These are not set automatically from the data, so it is best to specify them manually.
  - `label`
    - Legend label, a string.
  - `interval`
    - Time between frames, in ms.
  - `show`
    - Controls whether `plt.show()` is called.
    - If it is not called, this function can be called several times to draw multiple curves in the same figure. `plt.show()` then has to be called manually at the end.
  - `repeat`
    - Controls whether the animation loops.


## Example

Pass the data in directly to plot:
```python
plot3([acc1, acc2])
plot3D([[position, 'position'], [orientation, 'orientation']])
```

`plot3()` uses a 3-row layout by default; a 3-column layout can also be used:
```python
import matplotlib.pyplot as plt
fig, ax = plt.subplots(nrows=1, ncols=3)
plot3([acc1, acc2], ax=ax)
```

--------------------------------------------------------------------------------
/Doc/dataflow.svg:
--------------------------------------------------------------------------------
[Figure: data-flow diagram. Node labels: sensor; acceleration data; angular velocity data; magnetic field data; initialization data; initialize(); init_list; attitudeTrack(); return values (acceleration data, device orientation data); removeAccErr(); corrected acceleration data; zupt(); velocity data; positionTrack(); displacement data.]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# IMU Position Tracking
3D position tracking based on data from a 9-degree-of-freedom IMU (accelerometer, gyroscope and magnetometer). It tracks orientation pretty accurately; position is tracked as well, but with significant accumulated error from double integration of acceleration.

## Project Structure
- `main.py`: where the main Extended Kalman Filter (EKF) and other algorithms sit.
- `butter.py`: a digital realtime Butterworth filter implementation from [this repo](https://github.com/keikun555/Butter) with minor fixes. Realtime filtering is not used at the moment.
- `mathlib.py`: contains matrix definitions for the EKF and a filter helper function.
- `plotlib.py`: some wrappers for visualization used in prototyping.
- `main.ipynb`: almost the same as `main.py`, just used for prototyping.
- `/Ref`: some helpful papers found on the internet.
- `/Doc`: an algorithm description (view the HTML version, as GitHub doesn't support the markdown LaTeX extension) and an API manual (originally written in Chinese).

## Data Source
I use an app called [HyperIMU](https://play.google.com/store/apps/details?id=com.ianovir.hyper_imu) to pull (uncalibrated) data from my phone. Data is sent over TCP and received using `data_receiver.py`.
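To make the data path concrete, here is a small sketch of how received lines might be split into gyroscope, accelerometer and magnetometer samples. It is an illustration under assumptions, not code from this repository: the helper name `parse_imu_lines` is hypothetical, and the layout of 9 comma-separated values per line in the order gyroscope, accelerometer, magnetometer is taken from the prototyping notebook; the actual layout depends on how the app is configured.

```python
import io


def parse_imu_lines(lines):
    """Split comma-separated IMU lines into gyro, accel and mag samples.

    Assumes 9 values per line ordered gyroscope, accelerometer, magnetometer;
    lines that do not fit this layout are skipped.
    """
    w, a, m = [], [], []
    for line in lines:
        fields = line.strip().split(",")
        if len(fields) != 9:
            continue  # blank or truncated line
        try:
            values = [float(x) for x in fields]
        except ValueError:
            continue  # non-numeric field
        w.append(values[0:3])
        a.append(values[3:6])
        m.append(values[6:9])
    return w, a, m


# io.StringIO stands in for the file-like object returned by Receiver.receive().
stream = io.StringIO(
    "0.01,0.02,0.03,0.0,0.0,9.8,20.0,-5.0,40.0\n"
    "garbage line\n"
    "0.00,0.01,0.02,0.1,0.0,9.7,21.0,-5.5,39.0\n"
)
w, a, m = parse_imu_lines(stream)
print(len(w), a[0], m[1])
```

Skipping malformed lines instead of raising keeps a single corrupted packet from aborting a whole recording; whether that is the right trade-off depends on how the data is used afterwards.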
15 | -------------------------------------------------------------------------------- /Ref/A Drift Eliminated Attitude & Position Estimation Algorithm In 3D.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john2zy/IMU-Position-Tracking/48f9908d068c5d04fd9272e3a3d22acd142cc869/Ref/A Drift Eliminated Attitude & Position Estimation Algorithm In 3D.pdf -------------------------------------------------------------------------------- /Ref/Adaptive Zero Velocity Update Based On Velocity Classification For Pedestrian.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john2zy/IMU-Position-Tracking/48f9908d068c5d04fd9272e3a3d22acd142cc869/Ref/Adaptive Zero Velocity Update Based On Velocity Classification For Pedestrian.pdf -------------------------------------------------------------------------------- /Ref/Drift reduction in IMU-only pedestrian navigation system in unstructured environment.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john2zy/IMU-Position-Tracking/48f9908d068c5d04fd9272e3a3d22acd142cc869/Ref/Drift reduction in IMU-only pedestrian navigation system in unstructured environment.pdf -------------------------------------------------------------------------------- /Ref/Using Inertial Sensors for Position and Orientation Estimation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john2zy/IMU-Position-Tracking/48f9908d068c5d04fd9272e3a3d22acd142cc869/Ref/Using Inertial Sensors for Position and Orientation Estimation.pdf -------------------------------------------------------------------------------- /butter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Kei Imada 3 | 20170801 4 | A Butterworth signal filter worth using 5 | """ 6 | 7 | 
__all__ = ["Butter"]
__version__ = "1.0"
__author__ = "Kei Imada"


import numpy as np
from numba import jit


@jit(nopython=True, cache=True)
def _filterHelper(x, w, f, N):
    """
    x a float
    w an array of arrays of floats
    f an array of arrays of floats
    N an int
    """
    w[0][4] = x
    # integer division: range() does not accept a float
    for m in range(N // 2):
        previous_x = w[m]
        previous_y = w[m + 1]

        ym = f[0][m] * (
            previous_x[4]
            + f[1][m] * previous_x[3]
            + f[2][m] * previous_x[2]
            + f[3][m] * previous_x[1]
            + f[4][m] * previous_x[0]
        ) - (
            f[5][m] * previous_y[3]
            + f[6][m] * previous_y[2]
            + f[7][m] * previous_y[1]
            + f[8][m] * previous_y[0]
        )

        previous_y[4] = ym

        # shift the history of this stage by one sample
        for i in range(len(previous_x) - 1):
            previous_x[i] = previous_x[i + 1]
    # shift the history of the last stage once, after the loop
    for i in range(len(previous_y) - 1):
        previous_y[i] = previous_y[i + 1]
    return ym


class Butter(object):
    def __init__(self, btype="lowpass", cutoff=None,
                 cutoff1=None, cutoff2=None,
                 rolloff=48, sampling=None):
        """The constructor for the butter filter object
        @param btype string type of filter, default lowpass
            lowpass
            highpass
            bandpass
            notch
            bandstop
        filter required arguments
            @param rolloff float measured in dB/Oct, default 48
            @param sampling float measured in Hz
        lowpass filter required arguments
            @param cutoff float measured in Hz
        highpass filter required arguments
            @param cutoff float measured in Hz
        bandpass filter required arguments
            @param cutoff1 float measured in Hz
            @param cutoff2 float measured in Hz
            cutoff1 < cutoff2
        notch filter required arguments
            @param cutoff float measured in Hz
        bandstop filter required arguments
            @param cutoff1 float measured in Hz
            @param cutoff2 float measured in Hz
            cutoff1 < cutoff2
        """
        # input checking
        valid = map(lambda
k: k[0], 82 | filter(lambda k: type(k[1]) in [int, float], 83 | zip(["cutoff", "cutoff1", "cutoff2", "rolloff", "sampling"], 84 | [cutoff, cutoff1, cutoff2, rolloff, sampling]) 85 | ) 86 | ) 87 | valid = list(valid) 88 | if None in [rolloff, sampling]: 89 | raise ValueError( 90 | "Butter:rolloff and sampling required for %s filter" % btype) 91 | if "rolloff" not in valid: 92 | raise TypeError("Butter:invalid rolloff argument") 93 | if "sampling" not in valid: 94 | raise TypeError("Butter:invalid sampling argument") 95 | if btype in ["lowpass", "highpass", "notch"]: 96 | if None in [cutoff]: 97 | raise ValueError( 98 | "Butter:cutoff required for %s filter" % btype) 99 | if "cutoff" not in valid: 100 | raise TypeError("Butter:invalid cutoff argument") 101 | elif btype in ["bandpass", "bandstop"]: 102 | if None in [cutoff1, cutoff2]: 103 | raise ValueError( 104 | "Butter:cutoff1 and cutoff2 required for %s filter" % btype) 105 | if "cutoff1" not in valid: 106 | raise TypeError("Butter:invalid cutoff1 argument") 107 | if "cutoff2" not in valid: 108 | raise TypeError("Butter:invalid cutoff2 argument") 109 | if cutoff1 > cutoff2: 110 | raise ValueError( 111 | "Butter:cutoff1 must be less than or equal to cutoff2") 112 | else: 113 | raise ValueError("Butter: invalid btype %s" % btype) 114 | self.btype = btype 115 | # initialize base filter variables 116 | A = float(rolloff) 117 | fs = float(sampling) 118 | Oc = cutoff 119 | f1 = cutoff1 120 | f2 = cutoff2 121 | B = 99.99 122 | wp = .3 * np.pi 123 | ws = 2 * wp 124 | d1 = B / 100.0 125 | d2 = 10**(np.log10(d1) - (A / 20.0)) 126 | self.N = int(np.ceil((np.log10(((1 / (d1**2)) - 1) / 127 | ((1 / (d2**2)) - 1))) / (2 * np.log10(wp / ws)))) 128 | if self.N % 2 == 1: 129 | self.N += 1 130 | self.wc = 10**(np.log10(wp) - (1.0 / (2 * self.N)) 131 | * np.log10((1 / (d1**2)) - 1)) 132 | self.fs = fs 133 | self.fc = Oc 134 | self.f1 = f1 135 | self.f2 = f2 136 | 137 | # to store the filtered data 138 | self.output = [] 139 | # to 
store passed in data
        self.data = []
        # list of frequencies used in calculation of filters
        self.frequencylist = np.zeros((self.N // 2 + 1, 5))

        # set variables for desired filter
        self.filter = {
            "lowpass": self.__lowpass_filter_variables,
            "highpass": self.__highpass_filter_variables,
            "bandpass": self.__bandpass_filter_variables,
            "notch": self.__notch_filter_variables,
            "bandstop": self.__bandstop_filter_variables
        }[btype]()

    def filtfilt(self):
        """Returns accumulated output values with forward-backwards filtering
        @return list of float/int accumulated output values, filtered through forward-backward filtering
        """
        # fresh filter state for the backward pass, same layout as self.frequencylist
        tempfrequencylist = np.zeros((self.N // 2 + 1, 5))
        data = self.output[:]
        data.reverse()
        for i in range(len(data)):
            # call the module-level helper with the full argument list
            data[i] = _filterHelper(
                data[i], tempfrequencylist, self.filter, self.N)
        data.reverse()
        return data

    def send(self, data):
        """Send data to Butterworth filter
        @param data list of floats amplitude data to take in
        @return values from the filtered data, with forward filtering
        """
        if type(data) != list:
            raise TypeError(
                "Butter.send: type of data must be a list of floats")
        self.data += data
        output = []
        for amplitude in data:
            newamp = _filterHelper(
                amplitude, self.frequencylist, self.filter, self.N)
            output.append(newamp)
        self.output += output
        return output

    def __basic_filter_variables(self):
        """Returns basic filter variables
        @return 9 x (N // 2) numpy array of filter coefficients
        """
        basic = np.zeros((9, (self.N // 2)))
        for k in range(self.N // 2):
            a = self.wc * \
                np.sin((float(2.0 * (k + 1) - 1) / (2.0 * self.N)) * np.pi)
            B = 1 + a + ((self.wc**2) / 4.0)
            basic[0][k] = (self.wc**2) / (4.0 * B)
            basic[1][k] = 2
            basic[2][k] = 1
195 | basic[3][k] = 0 196 | basic[4][k] = 0 197 | basic[5][k] = 2 * ((self.wc**2 / (4.0)) - 1) / (B) 198 | basic[6][k] = ( 199 | 1 - a + (self.wc**2 / (4.0))) / (B) 200 | basic[7][k] = 0 201 | basic[8][k] = 0 202 | 203 | return basic 204 | 205 | def __lowpass_filter_variables(self): 206 | """Returns lowpass filter variables 207 | @return dictionary key:string variable value: lambda k 208 | """ 209 | basic = self.__basic_filter_variables() 210 | 211 | Op = 2 * (np.pi * self.fc / self.fs) 212 | vp = 2 * np.arctan(self.wc / 2.0) 213 | 214 | alpha = np.sin((vp - Op) / 2.0) / \ 215 | np.sin((vp + Op) / 2.0) 216 | 217 | lowpass = np.zeros((9, (self.N // 2))) 218 | for k in range(self.N // 2): 219 | C = 1 - basic[5][k] * \ 220 | alpha + basic[6][k] * (alpha**2) 221 | a = self.wc * \ 222 | np.sin((float(2.0 * (k + 1) - 1) / (2.0 * self.N)) * np.pi) 223 | B = 1 + a + ((self.wc**2) / 4.0) 224 | lowpass[0][k] = ((1 - alpha)**2) * basic[0][k] / C 225 | lowpass[1][k] = basic[1][k] 226 | lowpass[2][k] = basic[2][k] 227 | lowpass[3][k] = basic[3][k] 228 | lowpass[4][k] = basic[4][k] 229 | lowpass[5][k] = ( 230 | (1 + alpha**2) * basic[5][k] - 2 * alpha * (1 + basic[6][k])) / C 231 | lowpass[6][k] = ( 232 | alpha**2 - basic[5][k] * alpha + basic[6][k]) / C 233 | lowpass[7][k] = basic[7][k] 234 | lowpass[8][k] = basic[8][k] 235 | return lowpass 236 | 237 | def __highpass_filter_variables(self): 238 | """Returns highpass filter variables 239 | @return dictionary key:string variable value: lambda k 240 | """ 241 | basic = self.__basic_filter_variables() 242 | Op = 2 * (np.pi * float(self.fc) / self.fs) 243 | vp = 2 * np.arctan(self.wc / 2.0) 244 | 245 | alpha = -(np.cos((vp + Op) / (2.0))) / \ 246 | (np.cos((vp - Op) / (2.0))) 247 | 248 | highpass = np.zeros((9, (self.N // 2))) 249 | for k in range(self.N // 2): 250 | C = 1 - basic[5][k] * \ 251 | alpha + basic[6][k] * (alpha**2) 252 | highpass[0][k] = ((1 - alpha)**2) * basic[0][k] / C 253 | highpass[1][k] = -basic[1][k] 254 | 
highpass[2][k] = basic[2][k] 255 | highpass[3][k] = basic[3][k] 256 | highpass[4][k] = basic[4][k] 257 | highpass[5][k] = ( 258 | -(1.0 + alpha**2) * basic[5][k] + 2 * alpha * (1 + basic[6][k])) / C 259 | highpass[6][k] = ( 260 | float(alpha**2) - basic[5][k] * alpha + basic[6][k]) / C 261 | highpass[7][k] = basic[7][k] 262 | highpass[8][k] = basic[8][k] 263 | return highpass 264 | 265 | def __bandpass_filter_variables(self): 266 | """Returns bandpass filter variables 267 | @return dictionary key:string variable value: lambda k 268 | """ 269 | basic = self.__basic_filter_variables() 270 | Op1 = 2 * (np.pi * (self.f1) / self.fs) 271 | Op2 = 2 * (np.pi * (self.f2) / self.fs) 272 | alpha = np.cos((Op2 + Op1) / 2.0) / np.cos((Op2 - Op1) / 2.0) 273 | k = (self.wc / 2.0) / np.tan((Op2 - Op1) / 2.0) 274 | A = 2 * alpha * k / (k + 1) 275 | B = (k - 1) / (k + 1) 276 | 277 | bandpass = np.zeros((9, (self.N // 2))) 278 | for k in range(self.N // 2): 279 | C = 1 - basic[5][k] * B + basic[6][k] * (B**2) 280 | 281 | bandpass[0][k] = basic[0][k] * ((1 - B)**2) / C 282 | bandpass[1][k] = 0 283 | bandpass[2][k] = -basic[1][k] 284 | bandpass[3][k] = 0 285 | bandpass[4][k] = basic[2][k] 286 | bandpass[5][k] = (A / C) * (B * (basic[5][k] - 287 | 2 * basic[6][k]) + (basic[5][k] - 2)) 288 | bandpass[6][k] = (1 / C) * ((A**2) * (1 - basic[5][k] + basic[6][k]) + 289 | 2 * B * (1 + basic[6][k]) - basic[5][k] * (B**2) - basic[5][k]) 290 | bandpass[7][k] = (A / C) * (B * (basic[5][k] - 2) + 291 | (basic[5][k] - 2 * basic[6][k])) 292 | bandpass[8][k] = (1 / C) * ((B**2) - basic[5][k] * B + basic[6][k]) 293 | return bandpass 294 | 295 | def __notch_filter_variables(self): 296 | """Returns notch filter variables 297 | @return dictionary key:string variable value: lambda k 298 | """ 299 | basic = self.__basic_filter_variables() 300 | x = 1.0 301 | f1 = (1.0 - (x / 100)) * self.fc 302 | f2 = (1.0 + (x / 100)) * self.fc 303 | Op1 = 2 * (np.pi * f1 / self.fs) 304 | Op2 = 2 * (np.pi * f2 / self.fs) 
305 | alpha = np.cos((Op2 + Op1) / 2.0) / np.cos((Op2 - Op1) / 2.0) 306 | k = (self.wc / 2.0) * np.tan((Op2 - Op1) / 2.0) 307 | A = 2 * alpha / (k + 1) 308 | B = (1 - k) / (1 + k) 309 | 310 | notch = np.zeros((9, (self.N // 2))) 311 | for k in range(self.N // 2): 312 | C = 1 + basic[5][k] * \ 313 | B + basic[6][k] * (B**2) 314 | notch[0][k] = basic[0][k] * ((1 + B)**2) / C 315 | notch[1][k] = -4.0 * A / (B + 1) 316 | notch[2][k] = 2.0 * ((2 * (A**2)) / ((B + 1)**2) + 1) 317 | notch[3][k] = -4.0 * A / (B + 1) 318 | notch[4][k] = 1 319 | notch[5][k] = -(A / C) * \ 320 | (B * (basic[5][k] + 2 * basic[6][k]) + 321 | (2 + basic[5][k])) 322 | notch[6][k] = (1 / C) * \ 323 | ((A**2) * (1 + basic[5][k] + basic[6][k]) + 324 | 2 * B * (1 + basic[6][k]) + 325 | basic[5][k] * (B**2) + 326 | basic[5][k]) 327 | notch[7][k] = -(A / C) * \ 328 | (B * (basic[5][k] + 2) + 329 | (basic[5][k] + 2 * basic[6][k])) 330 | notch[8][k] = (1 / C) * \ 331 | ((B**2) + 332 | basic[5][k] * B + 333 | basic[6][k]) 334 | return notch 335 | 336 | def __bandstop_filter_variables(self): 337 | """Returns bandstop filter variables 338 | @return dictionary key:string variable value: lambda k 339 | """ 340 | basic = self.__basic_filter_variables() 341 | Op1 = 2 * (np.pi * self.f1 / self.fs) 342 | Op2 = 2 * (np.pi * self.f2 / self.fs) 343 | alpha = np.cos((Op2 + Op1) / 2.0) / np.cos((Op2 - Op1) / 2.0) 344 | k = (self.wc / 2.0) * np.tan((Op2 - Op1) / 2.0) 345 | A = 2 * alpha / (k + 1) 346 | B = (1 - k) / (1 + k) 347 | 348 | bandstop = np.zeros((9, self.N // 2)) 349 | for k in range(self.N // 2): 350 | C = 1 + basic[5][k] * \ 351 | B + basic[6][k] * (B**2) 352 | bandstop[0][k] = basic[0][k] * ((1 + B)**2) / C 353 | bandstop[1][k] = -4.0 * A / (B + 1) 354 | bandstop[2][k] = 2.0 * ((2 * (A**2)) / ((B + 1)**2) + 1) 355 | bandstop[3][k] = -4.0 * A / (B + 1) 356 | bandstop[4][k] = 1 357 | bandstop[5][k] = -(A / C) * \ 358 | (B * (basic[5][k] + 2 * basic[6][k]) + 359 | (2 + basic[5][k])) 360 | bandstop[6][k] = (1 
/ C) * \ 361 | ((A**2) * (1 + basic[5][k] + basic[6][k]) + 362 | 2 * B * (1 + basic[6][k]) + 363 | basic[5][k] * (B**2) + 364 | basic[5][k]) 365 | bandstop[7][k] = -(A / C) * \ 366 | (B * (basic[5][k] + 2) + 367 | (basic[5][k] + 2 * basic[6][k])) 368 | bandstop[8][k] = (1 / C) * \ 369 | ((B**2) + 370 | basic[5][k] * B + 371 | basic[6][k]) 372 | return bandstop 373 | -------------------------------------------------------------------------------- /data_receiver.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | 4 | class Receiver: 5 | 6 | def __init__(self): 7 | super().__init__() 8 | self.HOST = '' # use '' to expose to all networks 9 | self.PORT = 20550 10 | 11 | def receive(self): 12 | """Open specified port and return file-like object""" 13 | 14 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 15 | sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 16 | sock.bind((self.HOST, self.PORT)) 17 | sock.listen(0) 18 | request, addr = sock.accept() 19 | return request.makefile('r') 20 | 21 | # r = Receiver() 22 | # for line in r.receive(): 23 | # print(line) -------------------------------------------------------------------------------- /main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import numpy.linalg" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "%matplotlib widget\n", 20 | "np.set_printoptions(precision=8, suppress=True)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import data_receiver\n", 30 | "from mathlib import *\n", 31 | "from plotlib import *" 32 | ] 33 | }, 34 | { 35 | 
"cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "# sampling rate\n", 41 | "dt = 0.01 # s\n", 42 | "\n", 43 | "# the initialization interval\n", 44 | "ts = 1 # s" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": {}, 50 | "source": [ 51 | "# Pull Data From Phone\n", 52 | "data order: gyroscorpe, accelerometer, magnetometer" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "r = data_receiver.Receiver()\n", 62 | "\n", 63 | "data = []\n", 64 | "\n", 65 | "for line in r.receive():\n", 66 | " data.append(line.split(','))\n", 67 | "\n", 68 | "data = np.array(data, dtype = np.float)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "# Initialization" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "# discard the first and last few readings\n", 85 | "# for some reason they fluctuate a lot\n", 86 | "w = data[10:-10, 0:3]\n", 87 | "a = data[10:-10, 3:6]\n", 88 | "m = data[10:-10, 6:9]\n", 89 | "\n", 90 | "if(np.shape(w)[0] < ts/dt):\n", 91 | " print(\"not enough data for intialization!\")\n", 92 | "\n", 93 | "# gravity\n", 94 | "gn = a[:int(ts/dt)].mean(axis = 0)\n", 95 | "gn = -gn[:, np.newaxis]\n", 96 | "g0 = np.linalg.norm(gn) # save the initial magnitude of gravity\n", 97 | "\n", 98 | "# magnetic field\n", 99 | "mn = m[:int(ts/dt)].mean(axis = 0)\n", 100 | "mn = Normalized(mn)[:, np.newaxis] # magnitude is not important\n", 101 | "\n", 102 | "avar = a[:int(ts/dt)].var(axis=0)\n", 103 | "wvar = w[:int(ts/dt)].var(axis=0)\n", 104 | "mvar = m[:int(ts/dt)].var(axis=0)\n", 105 | "print('acc var: ', avar, ', ', np.linalg.norm(avar))\n", 106 | "print('ang var: ', wvar, ', ', np.linalg.norm(wvar))\n", 107 | "print('mag var: ', mvar, ', ', np.linalg.norm(mvar))\n", 108 | 
"\n", 109 | "# cut the initialization data\n", 110 | "w = w[int(ts/dt) - 1:] - w[:int(ts/dt)].mean(axis=0)\n", 111 | "a = a[int(ts/dt):]\n", 112 | "m = m[int(ts/dt):]\n", 113 | "\n", 114 | "sample_number = np.shape(a)[0]" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "a_filtered, w_filtered = Filt_signal((a, w), dt=dt, wn=10, btype='lowpass')\n", 124 | "plot_signal([a, a_filtered], [w, w_filtered], [m])" 125 | ] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "metadata": {}, 130 | "source": [ 131 | "# Kalman Filter" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "gyro_noise = 10 * np.linalg.norm(wvar)\n", 141 | "acc_noise = 10 * np.linalg.norm(avar)\n", 142 | "mag_noise = 10 * np.linalg.norm(mvar)\n", 143 | "\n", 144 | "P = 1e-10 * I(4)" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": { 151 | "tags": [ 152 | "outputPrepend" 153 | ] 154 | }, 155 | "outputs": [], 156 | "source": [ 157 | "a_nav = []\n", 158 | "orientations = []\n", 159 | "\n", 160 | "q = np.array([[1., 0., 0., 0.]]).T\n", 161 | "orin = -gn / np.linalg.norm(gn)\n", 162 | "\n", 163 | "t = 0\n", 164 | "while t < sample_number:\n", 165 | " wt = w[t, np.newaxis].T\n", 166 | " at = a[t, np.newaxis].T\n", 167 | " mt = m[t, np.newaxis].T \n", 168 | " mt = Normalized(mt)\n", 169 | "\n", 170 | " # Propagation\n", 171 | " Ft = F(q, wt, dt)\n", 172 | " Gt = G(q)\n", 173 | " Q = (gyro_noise * dt)**2 * Gt @ Gt.T\n", 174 | " \n", 175 | " q = Ft @ q\n", 176 | " q = Normalized(q)\n", 177 | " P = Ft @ P @ Ft.T + Q \n", 178 | "\n", 179 | " # Measurement Update\n", 180 | " # Use only normalized measurements to reduce error!\n", 181 | " \n", 182 | " # acc and mag prediction\n", 183 | " pa = Normalized(-Rotate(q) @ gn)\n", 184 | " pm = Normalized(Rotate(q) @ 
mn)\n", 185 | "\n", 186 | " # Residual\n", 187 | " Eps = np.vstack((Normalized(at), mt)) - np.vstack((pa, pm))\n", 188 | " \n", 189 | " # internal error + external error\n", 190 | " Ra = [(acc_noise / np.linalg.norm(at))**2 + (1 - g0 / np.linalg.norm(at))**2] * 3\n", 191 | " Rm = [mag_noise**2] * 3\n", 192 | " R = np.diag(Ra + Rm)\n", 193 | " \n", 194 | " Ht = H(q, gn, mn)\n", 195 | "\n", 196 | " S = Ht @ P @ Ht.T + R\n", 197 | " K = P @ Ht.T @ np.linalg.inv(S)\n", 198 | " q = q + K @ Eps\n", 199 | " P = P - K @ Ht @ P\n", 200 | " \n", 201 | " # Post Correction\n", 202 | " q = Normalized(q)\n", 203 | " P = 0.5 * (P + P.T) # make sure P is symmetric\n", 204 | " \n", 205 | " conj = -I(4)\n", 206 | " conj[0, 0] = 1\n", 207 | " an = Rotate(conj @ q) @ at + gn\n", 208 | " ori = Rotate(conj @ q) @ orin\n", 209 | "\n", 210 | " a_nav.append(an.T[0])\n", 211 | " orientations.append(ori.T[0])\n", 212 | "\n", 213 | " t += 1\n", 214 | "\n", 215 | "a_nav = np.array(a_nav)\n", 216 | "orientations = np.array(orientations)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "# Accelerometer Bias/Error Correction" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "a_threshold = 0.2" 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": null, 238 | "metadata": { 239 | "tags": [ 240 | "outputPrepend" 241 | ] 242 | }, 243 | "outputs": [], 244 | "source": [ 245 | "t_start = 0\n", 246 | "for t in range(sample_number):\n", 247 | " at = a_nav[t]\n", 248 | " if np.linalg.norm(at) > a_threshold:\n", 249 | " t_start = t\n", 250 | " break\n", 251 | "\n", 252 | "t_end = 0\n", 253 | "for t in range(sample_number - 1, -1,-1):\n", 254 | " at = a_nav[t]\n", 255 | " if np.linalg.norm(at - a_nav[-1]) > a_threshold:\n", 256 | " t_end = t\n", 257 | " break" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | 
"execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "print('motion starts at: ', t_start)\n", 267 | "print('motion ends at: ', t_end)" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "an_drift = a_nav[t_end:].mean(axis=0)\n", 277 | "an_drift_rate = an_drift / (t_end - t_start)\n", 278 | "\n", 279 | "for i in range(t_end - t_start):\n", 280 | " a_nav[t_start + i] -= (i+1) * an_drift_rate\n", 281 | "\n", 282 | "for i in range(sample_number - t_end):\n", 283 | " a_nav[t_end + i] -= an_drift" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": null, 289 | "metadata": {}, 290 | "outputs": [], 291 | "source": [ 292 | "filtered_a_nav, = Filt_signal([a_nav], dt=dt, wn=(0.01, 15), btype='bandpass')\n", 293 | "plot_3([a_nav, filtered_a_nav])\n", 294 | "# plot_3([a_nav])" 295 | ] 296 | }, 297 | { 298 | "cell_type": "markdown", 299 | "metadata": {}, 300 | "source": [ 301 | "# Zero Velocity Update" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": { 308 | "tags": [ 309 | "outputPrepend" 310 | ] 311 | }, 312 | "outputs": [], 313 | "source": [ 314 | "velocities = []\n", 315 | "prevt = -1\n", 316 | "still_phase = False\n", 317 | "\n", 318 | "v = np.zeros((3, 1))\n", 319 | "t = 0\n", 320 | "while t < sample_number:\n", 321 | " at = filtered_a_nav[t, np.newaxis].T\n", 322 | "\n", 323 | " if np.linalg.norm(at) < a_threshold:\n", 324 | " if not still_phase:\n", 325 | " predict_v = v + at * dt\n", 326 | "\n", 327 | " v_drift_rate = predict_v / (t - prevt)\n", 328 | " for i in range(t - prevt - 1):\n", 329 | " velocities[prevt + 1 + i] -= (i + 1) * v_drift_rate.T[0]\n", 330 | "\n", 331 | " v = np.zeros((3, 1))\n", 332 | " prevt = t\n", 333 | " still_phase = True\n", 334 | " else:\n", 335 | " v = v + at * dt\n", 336 | " still_phase = False\n", 337 | " \n", 338 | " t += 
1\n", 339 | " velocities.append(v.T[0])\n", 340 | "velocities = np.array(velocities)" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": null, 346 | "metadata": {}, 347 | "outputs": [], 348 | "source": [ 349 | "plot_3([velocities])" 350 | ] 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "metadata": {}, 355 | "source": [ 356 | "# Integration To Get Position" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "positions = []\n", 366 | "p = np.array([[0, 0, 0]]).T\n", 367 | "\n", 368 | "t = 0\n", 369 | "while t < sample_number:\n", 370 | " at = filtered_a_nav[t, np.newaxis].T\n", 371 | " vt = velocities[t, np.newaxis].T\n", 372 | "\n", 373 | " p = p + vt * dt + 0.5 * at * dt**2\n", 374 | " positions.append(p.T[0])\n", 375 | "\n", 376 | " t += 1\n", 377 | "\n", 378 | "positions = np.array(positions)" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "metadata": {}, 385 | "outputs": [], 386 | "source": [ 387 | "plot_3D([[positions, 'position']])" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": null, 393 | "metadata": {}, 394 | "outputs": [], 395 | "source": [ 396 | "plot_3([positions])" 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "metadata": {}, 402 | "source": [ 403 | "# Close All Graphs" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": null, 409 | "metadata": {}, 410 | "outputs": [], 411 | "source": [ 412 | "plt.close('all')" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": null, 418 | "metadata": {}, 419 | "outputs": [], 420 | "source": [] 421 | } 422 | ], 423 | "metadata": { 424 | "kernelspec": { 425 | "display_name": "Python 3", 426 | "language": "python", 427 | "name": "python3" 428 | }, 429 | "language_info": { 430 | "codemirror_mode": { 431 | "name": "ipython", 432 | "version": 3 433 | }, 434 
| "file_extension": ".py", 435 | "mimetype": "text/x-python", 436 | "name": "python", 437 | "nbconvert_exporter": "python", 438 | "pygments_lexer": "ipython3", 439 | "version": "3.7.7-final" 440 | } 441 | }, 442 | "nbformat": 4, 443 | "nbformat_minor": 4 444 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import inv, norm 3 | 4 | import data_receiver 5 | from mathlib import * 6 | from plotlib import * 7 | 8 | 9 | class IMUTracker: 10 | 11 | def __init__(self, sampling, data_order={'w': 1, 'a': 2, 'm': 3}): 12 | ''' 13 | @param sampling: sampling rate of the IMU, in Hz 14 | Note: the device is expected to stay still while the initialization data is recorded 15 | @param data_order: specify the order of data in the data array 16 | ''' 17 | 18 | super().__init__() 19 | # ---- parameters ---- 20 | self.sampling = sampling 21 | self.dt = 1 / sampling # second 22 | self.data_order = data_order 23 | 24 | # ---- helpers ---- 25 | idx = {1: [0, 3], 2: [3, 6], 3: [6, 9]} 26 | self._widx = idx[data_order['w']] 27 | self._aidx = idx[data_order['a']] 28 | self._midx = idx[data_order['m']] 29 | 30 | def initialize(self, data, noise_coefficient={'w': 100, 'a': 100, 'm': 10}): 31 | ''' 32 | Algorithm initialization 33 | 34 | @param data: (n, 9) ndarray of readings recorded while the device is still; 35 | drop the first few samples before calling, as they may be corrupted 36 | @param noise_coefficient: sensor noise is determined by variance magnitude times this coefficient 37 | 38 | Return: a tuple of initialization values used by the EKF algorithm: 39 | (gn, g0, mn, gyro_noise, gyro_bias, acc_noise, mag_noise) 40 | ''' 41 | 42 | # the caller is expected to have discarded the first few readings, 43 | # because for some reason they might fluctuate a lot 44 | w = data[:, self._widx[0]:self._widx[1]] 45 | a = data[:, self._aidx[0]:self._aidx[1]] 46 | m = data[:, self._midx[0]:self._midx[1]] 47 | 48 | # 
---- gravity ---- 49 | gn = -a.mean(axis=0) 50 | gn = gn[:, np.newaxis] 51 | # save the initial magnitude of gravity 52 | g0 = np.linalg.norm(gn) 53 | 54 | # ---- magnetic field ---- 55 | mn = m.mean(axis=0) 56 | # magnitude is not important 57 | mn = normalized(mn)[:, np.newaxis] 58 | 59 | # ---- compute noise covariance ---- 60 | avar = a.var(axis=0) 61 | wvar = w.var(axis=0) 62 | mvar = m.var(axis=0) 63 | print('acc var: %s, norm: %s' % (avar, np.linalg.norm(avar))) 64 | print('ang var: %s, norm: %s' % (wvar, np.linalg.norm(wvar))) 65 | print('mag var: %s, norm: %s' % (mvar, np.linalg.norm(mvar))) 66 | 67 | # ---- define sensor noise ---- 68 | gyro_noise = noise_coefficient['w'] * np.linalg.norm(wvar) 69 | gyro_bias = w.mean(axis=0) 70 | acc_noise = noise_coefficient['a'] * np.linalg.norm(avar) 71 | mag_noise = noise_coefficient['m'] * np.linalg.norm(mvar) 72 | return (gn, g0, mn, gyro_noise, gyro_bias, acc_noise, mag_noise) 73 | 74 | def attitudeTrack(self, data, init_list): 75 | ''' 76 | Removes gravity from acceleration data and transforms it into the navigation frame. 77 | Also tracks the device's orientation. 
 78 | 79 | @param data: (n, 9) ndarray 80 | @param init_list: initialization values for the EKF algorithm: 81 | (gn, g0, mn, gyro_noise, gyro_bias, acc_noise, mag_noise) 82 | 83 | Return: (a_nav, orix, oriy, oriz): navigation frame acceleration and the three orientation axes 84 | ''' 85 | 86 | # ------------------------------- # 87 | # ---- Initialization ---- 88 | # ------------------------------- # 89 | gn, g0, mn, gyro_noise, gyro_bias, acc_noise, mag_noise = init_list 90 | w = data[:, self._widx[0]:self._widx[1]] - gyro_bias 91 | a = data[:, self._aidx[0]:self._aidx[1]] 92 | m = data[:, self._midx[0]:self._midx[1]] 93 | sample_number = np.shape(data)[0] 94 | 95 | # ---- data container ---- 96 | a_nav = [] 97 | orix = [] 98 | oriy = [] 99 | oriz = [] 100 | 101 | # ---- states and covariance matrix ---- 102 | P = 1e-10 * I(4) # state covariance matrix 103 | q = np.array([[1, 0, 0, 0]]).T # quaternion state 104 | init_ori = I(3) # initial orientation 105 | 106 | # ------------------------------- # 107 | # ---- Extended Kalman Filter ---- 108 | # ------------------------------- # 109 | 110 | # all vectors are column vectors 111 | 112 | t = 0 113 | while t < sample_number: 114 | 115 | # ------------------------------- # 116 | # ---- 0. Data Preparation ---- 117 | # ------------------------------- # 118 | 119 | wt = w[t, np.newaxis].T 120 | at = a[t, np.newaxis].T 121 | mt = normalized(m[t, np.newaxis].T) 122 | 123 | # ------------------------------- # 124 | # ---- 1. Propagation ---- 125 | # ------------------------------- # 126 | 127 | Ft = F(q, wt, self.dt) 128 | Gt = G(q) 129 | Q = (gyro_noise * self.dt)**2 * Gt @ Gt.T 130 | 131 | q = normalized(Ft @ q) 132 | P = Ft @ P @ Ft.T + Q 133 | 134 | # ------------------------------- # 135 | # ---- 2. Measurement Update ---- 136 | # ------------------------------- # 137 | 138 | # Use normalized measurements to reduce error! 
 139 | 140 | # ---- acc and mag prediction ---- 141 | pa = normalized(-rotate(q) @ gn) 142 | pm = normalized(rotate(q) @ mn) 143 | 144 | # ---- residual ---- 145 | Eps = np.vstack((normalized(at), mt)) - np.vstack((pa, pm)) 146 | 147 | # ---- sensor noise ---- 148 | # R = internal error + external error 149 | Ra = [(acc_noise / np.linalg.norm(at))**2 + (1 - g0 / np.linalg.norm(at))**2] * 3 150 | Rm = [mag_noise**2] * 3 151 | R = np.diag(Ra + Rm) 152 | 153 | # ---- kalman gain ---- 154 | Ht = H(q, gn, mn) 155 | S = Ht @ P @ Ht.T + R 156 | K = P @ Ht.T @ np.linalg.inv(S) 157 | 158 | # ---- actual update ---- 159 | q = q + K @ Eps 160 | P = P - K @ Ht @ P 161 | 162 | # ------------------------------- # 163 | # ---- 3. Post Correction ---- 164 | # ------------------------------- # 165 | 166 | q = normalized(q) 167 | P = 0.5 * (P + P.T) # make sure P is symmetric 168 | 169 | # ------------------------------- # 170 | # ---- 4. other things ---- 171 | # ------------------------------- # 172 | 173 | # ---- navigation frame acceleration ---- 174 | conj = -I(4) 175 | conj[0, 0] = 1 176 | an = rotate(conj @ q) @ at + gn 177 | 178 | # ---- navigation frame orientation ---- 179 | orin = rotate(conj @ q) @ init_ori 180 | 181 | # ---- saving data ---- 182 | a_nav.append(an.T[0]) 183 | orix.append(orin.T[0, :]) 184 | oriy.append(orin.T[1, :]) 185 | oriz.append(orin.T[2, :]) 186 | 187 | t += 1 188 | 189 | a_nav = np.array(a_nav) 190 | orix = np.array(orix) 191 | oriy = np.array(oriy) 192 | oriz = np.array(oriz) 193 | return (a_nav, orix, oriy, oriz) 194 | 195 | def removeAccErr(self, a_nav, threshold=0.2, filter=False, wn=(0.01, 15)): 196 | ''' 197 | Removes drift in acc data assuming that 198 | the device stays still during the initialization and ending periods. 199 | The initial and final acc are assumed to be exactly 0. 200 | If filter is True, the output is also passed through a bandpass filter to further reduce noise and drift. 
201 | 202 | @param a_nav: acc data, raw output from the kalman filter 203 | @param threshold: acc threshold to detect the starting and ending point of motion 204 | @param wn: bandpass filter cutoff frequencies 205 | 206 | Return: corrected and filtered acc data 207 | ''' 208 | 209 | sample_number = np.shape(a_nav)[0] 210 | t_start = 0 211 | for t in range(sample_number): 212 | at = a_nav[t] 213 | if np.linalg.norm(at) > threshold: 214 | t_start = t 215 | break 216 | 217 | t_end = 0 218 | for t in range(sample_number - 1, -1, -1): 219 | at = a_nav[t] 220 | if np.linalg.norm(at - a_nav[-1]) > threshold: 221 | t_end = t 222 | break 223 | 224 | an_drift = a_nav[t_end:].mean(axis=0) 225 | an_drift_rate = an_drift / (t_end - t_start) 226 | 227 | for i in range(t_end - t_start): 228 | a_nav[t_start + i] -= (i + 1) * an_drift_rate 229 | 230 | for i in range(sample_number - t_end): 231 | a_nav[t_end + i] -= an_drift 232 | 233 | if filter: 234 | filtered_a_nav = filtSignal([a_nav], dt=self.dt, wn=wn, btype='bandpass')[0] 235 | return filtered_a_nav 236 | else: 237 | return a_nav 238 | 239 | def zupt(self, a_nav, threshold): 240 | ''' 241 | Applies Zero Velocity Update(ZUPT) algorithm to acc data. 
242 | 243 | @param a_nav: acc data 244 | @param threshold: stationary detection threshold, the more intense the movement is the higher this should be 245 | 246 | Return: velocity data 247 | ''' 248 | 249 | sample_number = np.shape(a_nav)[0] 250 | velocities = [] 251 | prevt = -1 252 | still_phase = False 253 | 254 | v = np.zeros((3, 1)) 255 | t = 0 256 | while t < sample_number: 257 | at = a_nav[t, np.newaxis].T 258 | 259 | if np.linalg.norm(at) < threshold: 260 | if not still_phase: 261 | predict_v = v + at * self.dt 262 | 263 | v_drift_rate = predict_v / (t - prevt) 264 | for i in range(t - prevt - 1): 265 | velocities[prevt + 1 + i] -= (i + 1) * v_drift_rate.T[0] 266 | 267 | v = np.zeros((3, 1)) 268 | prevt = t 269 | still_phase = True 270 | else: 271 | v = v + at * self.dt 272 | still_phase = False 273 | 274 | velocities.append(v.T[0]) 275 | t += 1 276 | 277 | velocities = np.array(velocities) 278 | return velocities 279 | 280 | def positionTrack(self, a_nav, velocities): 281 | ''' 282 | Simple integration of acc data and velocity data. 
 283 | 284 | @param a_nav: acc data 285 | @param velocities: velocity data 286 | 287 | Return: 3D coordinates in navigation frame 288 | ''' 289 | 290 | sample_number = np.shape(a_nav)[0] 291 | positions = [] 292 | p = np.array([[0, 0, 0]]).T 293 | 294 | t = 0 295 | while t < sample_number: 296 | at = a_nav[t, np.newaxis].T 297 | vt = velocities[t, np.newaxis].T 298 | 299 | p = p + vt * self.dt + 0.5 * at * self.dt**2 300 | positions.append(p.T[0]) 301 | t += 1 302 | 303 | positions = np.array(positions) 304 | return positions 305 | 306 | 307 | def receive_data(mode='tcp'): 308 | data = [] 309 | 310 | if mode == 'tcp': 311 | r = data_receiver.Receiver() 312 | file = open('data.txt', 'w') 313 | print('listening...') 314 | for line in r.receive(): 315 | file.write(line) 316 | data.append(line.split(',')) 317 | data = np.array(data, dtype=float) # the np.float alias was removed in NumPy 1.24 318 | return data 319 | 320 | if mode == 'file': 321 | file = open('data.txt', 'r') 322 | for line in file.readlines(): 323 | data.append(line.split(',')) 324 | data = np.array(data, dtype=float) 325 | return data 326 | 327 | else: 328 | raise ValueError('Invalid mode argument: %s' % mode) 329 | 330 | 331 | def plot_trajectory(): 332 | tracker = IMUTracker(sampling=100) 333 | data = receive_data('file') # toggle data source between 'tcp' and 'file' here 334 | 335 | print('initializing...') 336 | init_list = tracker.initialize(data[5:30]) 337 | 338 | print('--------') 339 | print('processing...') 340 | 341 | # EKF step 342 | a_nav, orix, oriy, oriz = tracker.attitudeTrack(data[30:], init_list) 343 | 344 | # Acceleration correction step 345 | a_nav_filtered = tracker.removeAccErr(a_nav, filter=False) 346 | # plot3([a_nav, a_nav_filtered]) 347 | 348 | # ZUPT step 349 | v = tracker.zupt(a_nav_filtered, threshold=0.2) 350 | # plot3([v]) 351 | 352 | # Integration Step 353 | p = tracker.positionTrack(a_nav_filtered, v) 354 | plot3D([[p, 'position']]) 355 | 356 | # make 3D animation 357 | # xl = np.min(p[:, 0]) - 0.05 358 | # xh = 
np.max(p[:, 0]) + 0.05 359 | # yl = np.min(p[:, 1]) - 0.05 360 | # yh = np.max(p[:, 1]) + 0.05 361 | # zl = np.min(p[:, 2]) - 0.05 362 | # zh = np.max(p[:, 2]) + 0.05 363 | # plot3DAnimated(p, lim=[[xl, xh], [yl, yh], [zl, zh]], label='position', interval=5) 364 | 365 | 366 | if __name__ == '__main__': 367 | plot_trajectory() 368 | -------------------------------------------------------------------------------- /mathlib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.linalg import norm 3 | import scipy.signal 4 | 5 | 6 | def I(n): 7 | ''' 8 | unit matrix 9 | just making its name prettier than np.eye 10 | ''' 11 | return np.eye(n) 12 | 13 | 14 | def normalized(x): 15 | n = np.linalg.norm(x) 16 | if n == 0: # NumPy yields nan instead of raising on division by zero 17 | return x 18 | return x / n 19 | 20 | 21 | def skew(x): 22 | ''' 23 | takes in a 3d column vector 24 | returns its skew-symmetric matrix 25 | ''' 26 | 27 | x = x.T[0] 28 | return np.array([[0, -x[2], x[1]], [x[2], 0, -x[0]], [-x[1], x[0], 0]]) 29 | 30 | 31 | def rotate(q): 32 | ''' 33 | rotation transformation matrix 34 | nav frame to body frame as q is expected to be q^nb 35 | R(q) @ x to rotate x 36 | ''' 37 | 38 | qv = q[1:4, :] 39 | qc = q[0] 40 | return (qc**2 - qv.T @ qv) * I(3) - 2 * qc * skew(qv) + 2 * qv @ qv.T 41 | 42 | 43 | def F(q, wt, dt): 44 | '''state transition matrix''' 45 | 46 | w = wt.T[0] 47 | Omega = np.array([[0, -w[0], -w[1], -w[2]], [w[0], 0, w[2], -w[1]], 48 | [w[1], -w[2], 0, w[0]], [w[2], w[1], -w[0], 0]]) 49 | 50 | return I(4) + 0.5 * dt * Omega 51 | 52 | 53 | def G(q): 54 | '''process noise input matrix G(q): maps gyro noise into quaternion state noise''' 55 | 56 | q = q.T[0] 57 | return 0.5 * np.array([[-q[1], -q[2], -q[3]], [q[0], -q[3], q[2]], 58 | [q[3], q[0], -q[1]], [-q[2], q[1], q[0]]]) 59 | 60 | 61 | def Hhelper(q, vector): 62 | # just for convenience 63 | x = vector.T[0][0] 64 | y = vector.T[0][1] 65 | z = vector.T[0][2] 66 | q0 = q.T[0][0] 67 | q1 = q.T[0][1] 68 | q2 = q.T[0][2] 69 | q3 = q.T[0][3] 70 | 71 | 
h = np.array([ 72 | [q0*x - q3*y + q2*z, q1*x + q2*y + q3*z, -q2*x + q1*y + q0*z, -q3*x - q0*y + q1*z], 73 | [q3*x + q0*y - q1*z, q2*x - q1*y - q0*z, q1*x + q2*y + q3*z, q0*x - q3*y + q2*z], 74 | [-q2*x + q1*y +q0*z, q3*x + q0*y - q1*z, -q0*x + q3*y - q2*z, q1*x + q2*y + q3*z] 75 | ]) 76 | return 2 * h 77 | 78 | 79 | def H(q, gn, mn): 80 | ''' 81 | Measurement matrix 82 | ''' 83 | 84 | H1 = Hhelper(q, gn) 85 | H2 = Hhelper(q, mn) 86 | return np.vstack((-H1, H2)) 87 | 88 | 89 | def filtSignal(data, dt=0.01, wn=10, btype='lowpass', order=1): 90 | ''' 91 | filter all data at once 92 | uses butterworth filter of scipy 93 | @param data: [...] 94 | @param dt: sampling time 95 | @param wn: critical frequency 96 | ''' 97 | 98 | res = [] 99 | n, s = scipy.signal.butter(order, wn, fs=1 / dt, btype=btype) 100 | for d in data: 101 | d = scipy.signal.filtfilt(n, s, d, axis=0) 102 | res.append(d) 103 | return res -------------------------------------------------------------------------------- /plotlib.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from matplotlib.animation import FuncAnimation 5 | 6 | # plt.style.use('fivethirtyeight') 7 | 8 | 9 | def plotSignal(al: list, wl: list, ml: list): 10 | f, ax = plt.subplots(ncols=3, nrows=3) 11 | plot3(al, ax=ax[:, 0]) 12 | plot3(wl, ax=ax[:, 1]) 13 | plot3(ml, ax=ax[:, 2]) 14 | ax[0, 0].set_ylabel('x') 15 | ax[1, 0].set_ylabel('y') 16 | ax[2, 0].set_ylabel('z') 17 | ax[2, 0].set_xlabel('a') 18 | ax[2, 1].set_xlabel(r'$\omega$') 19 | ax[2, 2].set_xlabel('m') 20 | 21 | 22 | def plotgAndAcc(g, ab): 23 | ''' 24 | plot tracked gravity and body frame acceleration 25 | ''' 26 | 27 | fig, ax = plt.subplots(nrows=1, ncols=3) 28 | plot3([g, ab], 29 | ax=ax, 30 | lims=[[None, [-12, 12]]] * 3, 31 | labels=[['$g_x$', '$g_y$', '$g_z$'], ['$a^b_x$', '$a^b_y$', '$a^b_z$']], 32 | show_legend=True) 33 | 34 | 
35 | def plot3(data, ax=None, lims=None, labels=None, show=False, show_legend=False): 36 | ''' 37 | @param data: [ndarray, ...] 38 | @param lims: [[[xl, xh], [yl, yh]], ...] 39 | @param labels: [[label_string, ...], ...] 40 | ''' 41 | 42 | show_flag = False 43 | if ax is None: 44 | show_flag = True 45 | f, ax = plt.subplots(ncols=1, nrows=3) 46 | 47 | for axel in range(3): 48 | has_label = False 49 | for n in range(len(data)): 50 | d = data[n] 51 | label = labels[n] if labels is not None else None 52 | 53 | if label is not None: 54 | ax[axel].plot(d[:, axel], label=label[axel]) 55 | has_label = True 56 | else: 57 | ax[axel].plot(d[:, axel]) 58 | 59 | lim = lims[axel] if lims is not None else None 60 | if lim is not None: 61 | if lim[0] is not None: 62 | ax[axel].set_xlim(lim[0][0], lim[0][1]) 63 | if lim[1] is not None: 64 | ax[axel].set_ylim(lim[1][0], lim[1][1]) 65 | 66 | if has_label and show_legend: 67 | ax[axel].legend() 68 | ax[axel].grid(True) 69 | 70 | plt.tight_layout(pad=0.4, w_pad=0.5, h_pad=1.0) 71 | 72 | if show or show_flag: 73 | plt.show() 74 | return ax 75 | 76 | 77 | def plot3D(data, lim=None, ax=None): 78 | ''' 79 | @param data: [[data, label_string], ...] 
80 | @param lim: [[xl, xh], [yl, yh], [zl, zh]] 81 | ''' 82 | 83 | if ax is None: 84 | fig = plt.figure() 85 | ax = fig.add_subplot(111, projection='3d') 86 | 87 | for item in data: 88 | label = item[1] 89 | d = item[0] 90 | ax.plot(d[:, 0], d[:, 1], d[:, 2], 'o', label=label) 91 | 92 | if lim is not None: 93 | if lim[0] is not None: 94 | ax.set_xlim(lim[0][0], lim[0][1]) 95 | if lim[1] is not None: 96 | ax.set_ylim(lim[1][0], lim[1][1]) 97 | if lim[2] is not None: 98 | ax.set_zlim(lim[2][0], lim[2][1]) 99 | 100 | ax.legend() 101 | ax.set_xlabel('X axis') 102 | ax.set_ylabel('Y axis') 103 | ax.set_zlabel('Z axis') 104 | ax.plot([0], [0], [0], 'ro') 105 | plt.show() 106 | 107 | 108 | # ----------------------------- # 109 | # ---- Animated Plots ---- 110 | # ----------------------------- # 111 | 112 | 113 | def plot3DAnimated(data, lim=[[-1, 1], [-1, 1], [-1, 1]], label=None, interval=10, show=True, repeat=False): 114 | ''' 115 | @param data: (n, 3) ndarray 116 | @param lim: [[xl, xh], [yl, yh], [zl, zh]] 117 | @param show: if it's set to false, you can call this function multiple times to draw multiple lines 118 | ''' 119 | 120 | fig = plt.figure() 121 | ax = fig.add_subplot(111, projection='3d') 122 | 123 | if label is not None: 124 | ln, = ax.plot([], [], [], 'o', label=label) 125 | else: 126 | ln, = ax.plot([], [], [], 'o') 127 | 128 | def init(): 129 | ax.plot([0], [0], [0], 'ro') 130 | ax.set_xlim(lim[0][0], lim[0][1]) 131 | ax.set_ylim(lim[1][0], lim[1][1]) 132 | ax.set_zlim(lim[2][0], lim[2][1]) 133 | if label is not None: 134 | ax.legend() 135 | return ln, 136 | 137 | def update(frame): 138 | ln.set_xdata(data[:frame, 0]) 139 | ln.set_ydata(data[:frame, 1]) 140 | ln.set_3d_properties(data[:frame, 2]) 141 | return ln, 142 | 143 | ani = FuncAnimation(fig, 144 | update, 145 | frames=range(1, 146 | np.shape(data)[0] + 1), 147 | init_func=init, 148 | blit=True, 149 | interval=interval, 150 | repeat=repeat) 151 | if show: 152 | plt.show() 
--------------------------------------------------------------------------------
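The propagation step used by `attitudeTrack` (`q = normalized(F(q, wt, dt) @ q)` with `F = I(4) + 0.5 * dt * Omega`) can be sanity-checked in isolation. The sketch below is not part of the repository: it is a standard-library-only re-implementation, and the names `omega_matrix` and `propagate` are made up for illustration. It assumes a constant angular rate about the z axis and checks that the quaternion approaches the analytic value `[cos(angle/2), 0, 0, sin(angle/2)]`.

```python
import math


def omega_matrix(w):
    """4x4 Omega(w) matrix, same layout as in mathlib.F (hypothetical helper)."""
    wx, wy, wz = w
    return [[0.0, -wx, -wy, -wz],
            [wx, 0.0, wz, -wy],
            [wy, -wz, 0.0, wx],
            [wz, wy, -wx, 0.0]]


def propagate(q, w, dt):
    """One first-order step: q <- normalize((I + 0.5 * dt * Omega(w)) q)."""
    om = omega_matrix(w)
    q_new = [q[i] + 0.5 * dt * sum(om[i][j] * q[j] for j in range(4))
             for i in range(4)]
    n = math.sqrt(sum(c * c for c in q_new))
    return [c / n for c in q_new]


# Rotate at pi/2 rad/s about z for 1 s (1000 steps of 1 ms),
# i.e. a total rotation of 90 degrees.
q = [1.0, 0.0, 0.0, 0.0]
dt = 0.001
for _ in range(1000):
    q = propagate(q, (0.0, 0.0, math.pi / 2), dt)

print(q)  # close to [cos(pi/4), 0, 0, sin(pi/4)]
```

With a small `dt` the first-order update plus renormalization stays very close to the exact solution; with a coarse `dt` the per-step angle error grows, which is one reason the sampling rate passed to `IMUTracker` matters.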