├── .gitignore ├── INSTALL.md ├── LICENSE ├── README.md ├── config ├── gs_linear.yml ├── gs_nonlinear.yml ├── nn_linear.yml ├── nn_nonlinear.yml ├── pda_linear.yml └── pda_nonlinear.yml ├── demo_mbm_tracker.py ├── demo_pmbm_tracker.py ├── scenarios ├── __init__.py ├── base.py ├── linear.py └── nonlinear.py ├── scripts ├── debug_radarinfo.py ├── demo_mbm_filter.py ├── ekf_slam.py ├── entry_bernoulli_sot.py ├── entry_config.py ├── entry_mot.py ├── entry_mot_rfs.py ├── entry_sot.py ├── test_kalman_error_ellipse.py └── test_range_bearing_meas_cov.py ├── snapshots ├── eval.png ├── meas.png └── res.png ├── statecircle ├── __init__.py ├── base.py ├── configuration │ ├── __init__.py │ └── base.py ├── datasets │ ├── __init__.py │ ├── base.py │ └── tests │ │ ├── __init__.py │ │ └── test_data_generators.py ├── estimator │ ├── __init__.py │ └── base.py ├── hypothesiser │ ├── __init__.py │ └── base.py ├── lib │ ├── harem │ │ ├── README.md │ │ ├── debugger.py │ │ └── test │ │ │ └── test_prev_state_decorator.py │ ├── lane_regressor.py │ └── ransac.py ├── models │ ├── __init__.py │ ├── base.py │ ├── birth │ │ ├── __init__.py │ │ └── base.py │ ├── density │ │ ├── __init__.py │ │ ├── base.py │ │ ├── kalman.py │ │ ├── kalman_accumulated.py │ │ ├── tests │ │ │ ├── __init__.py │ │ │ └── test_state_models.py │ │ ├── unscented.py │ │ └── unscented_accumulated.py │ ├── measurement │ │ ├── __init__.py │ │ ├── base.py │ │ ├── clutter.py │ │ ├── linear.py │ │ ├── nonlinear.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ ├── test_linear.py │ │ │ └── test_nonlinear.py │ ├── sensor │ │ ├── __init__.py │ │ ├── base.py │ │ └── tests │ │ │ ├── __init__.py │ │ │ └── test_sensor_models.py │ └── transition │ │ ├── __init__.py │ │ ├── base.py │ │ ├── linear.py │ │ ├── nonlinear.py │ │ └── tests │ │ ├── __init__.py │ │ ├── test_linear_model.py │ │ └── test_nonlinear_model.py ├── platform │ ├── __init__.py │ └── base.py ├── reader │ ├── __init__.py │ ├── base.py │ └── tests │ │ ├── __init__.py │ │ └── 
test_readers.py ├── reductor │ ├── __init__.py │ ├── base.py │ ├── gate.py │ ├── hypothesis_reductor.py │ └── tests │ │ ├── __init__.py │ │ └── test_gates.py ├── trackers │ ├── base.py │ ├── mot │ │ ├── __init__.py │ │ ├── global_nearest_neighbour_tracker.py │ │ ├── joint_probabilistic_data_association_tracker.py │ │ ├── mbm_filter.py │ │ ├── mbm_tracker.py │ │ ├── md_tracker.py │ │ ├── meas_driven_pmbm_tracker.py │ │ ├── multi_hypothesis_tracker.py │ │ ├── phd_filter.py │ │ ├── pmbm_filter.py │ │ ├── pmbm_tracker.py │ │ └── prototype_tracker.py │ └── sot │ │ ├── __init__.py │ │ ├── bernoulli_trackers.py │ │ ├── ego_slam_tracker.py │ │ ├── ego_tracker.py │ │ ├── gaussian_sum_tracker.py │ │ ├── nearest_neighbour_tracker.py │ │ └── probabilistic_data_association_tracker.py ├── types │ ├── __init__.py │ ├── base.py │ ├── data.py │ ├── state.py │ └── tests │ │ ├── __init__.py │ │ └── test_states.py ├── utils │ ├── __init__.py │ ├── assignment.py │ ├── common.py │ ├── data_association.py │ └── visualizer.py └── wiki │ ├── MultiHypothesis.pdf │ ├── MultiHypothesis.pptx │ ├── flow.png │ └── framework.png ├── testcase ├── test_mhtracker.py └── test_tracker_methods.py └── tools ├── __init__.py ├── make_movie.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | *.o 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | #lib/ 20 | #lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # my ignore list 105 | *.lprof 106 | *untitled* 107 | .idea/* 108 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # 安装 2 | 3 | ## 依赖库 4 | 5 | - python基本库,推荐安装anaconda,python版本>=3.7 6 | 7 | 下载地址 8 | 9 | https://www.anaconda.com/products/distribution 10 | 11 | - 其他依赖库: 12 | 13 | - opencv-python: `pip install opencv-python` 14 | 15 | - cmake:编译murty库需要,可在官网下载对应平台的版本 16 | 17 | - eigin: 编译murty库需要,ubuntu系统下`sudo apt install libeigen3-dev` 18 | 19 | - murty: 20 | 21 | 注意该库编译目前仅在linux系统编译测试过 22 | 23 | ```bash 24 | git clone --recursive https://github.com/erikbohnsack/murty.git 25 | pip3 install ./murty 26 | ``` 27 | 28 | 该库编译时存在一些问题,需要修改下编译代码再运行上述安装指令: 29 | 30 | 1. `setup.py`中L48,删除`'-DPYTHON_EXECUTABLE=' + '/usr/local/opt/python/bin/python3.7'` 31 | 2. 
`CMakeList.txt`中L5需正确设置eigen包含路径:`SET( EIGEN3_INCLUDE_DIR "/your/path/to/eigen3" )` 32 | 33 | ## 例程运行 34 | 35 | `python demo_pmbm_tracker.py`正确输出如下结果即安装正常。 36 | 37 | ![res](snapshots/res.png) 38 | 39 | ![eval](snapshots/eval.png) 40 | 41 | ![meas](snapshots/meas.png) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StateCircle 2 | 状态循环:多传感器多目标融合库,实现了多种单/多目标滤波/跟踪融合算法 3 | ![framework](statecircle/wiki/framework.png) 4 | 5 | ## 术语 6 | - 状态(`State`):在`statecircle`中状态是一个泛化的描述,指代库中的不同分布表示 7 | 8 | - 搜索区域(surveillance area):人为指定的算法工作区域,为2维平面或3维空间 9 | 10 | - 滤波(filtering):依据历史时刻至当前时刻的所有观测对当前搜索区域内目标状态量进行更新,不输出轨迹 11 | 12 | - 跟踪(track):依据历史时刻至当前时刻的所有观测对当前搜索区域内目标状态量进行更新,输出轨迹,每段轨迹有唯一标识 13 | 14 | ## 主要模块 15 | 16 | ### 状态类(`Types`) 17 | 18 | #### 密度/强度(density/intensity)状态 19 | - 高斯状态(`GaussianState`):包含充分统计量`mean`与`cov`属性,描述一个高斯分布。 20 | - 高斯和状态(`GaussianSumState`):描述一个为高斯和形式的强度函数,包含一组权重以及对应的高斯状态属性 21 | - 高斯混合状态(`GaussianMixtureState`):描述一个为高斯混合形式的密度函数,包含一组权重(权重之和为1)以及对应的高斯状态属性 22 | - 无迹状态(`UnscentedState`):描述无迹变换下的高斯分布表示,在无迹共轭密度模型`UnscentedKalmanDensity`中被使用 23 | - 粒子状态(`ParticleState`):描述任意分布的粒子表示,在粒子共轭密度`ParticleDensity`中被使用 24 | 25 | #### 随机有限集(Random finite sets)状态 26 | - 伯努利状态(`BernoulliState`):包含存在概率`prob`以及一个密度/强度状态`state`,表示伯努利随机有限集 27 | - 伯努利轨迹(`BernoulliTrajectory`):继承自伯努利状态,并附加出生时刻(`t_birth`),死亡时刻(`t_death`)以及死亡权重(`w_death`)三个属性,用来表示单目标轨迹 28 | - (TODO)多伯努利状态(`MultiBernoulli`):表示多伯努利随机有限集 29 | - (TODO)多伯努利混合状态(`MultiBernoulliMixture`):表示多伯努利混合随机有限集,有两种等价表示 30 | - 全局假设表示 31 | - Multi-Bernoulli states 32 | - 局部假设表示 33 | - Hypothesis look-up table 34 | - Hypothesis forest 35 | - Poisson状态(`PoissonState`):包含一个强度函数属性,表示泊松随机有限集(泊松点过程) 36 | 37 | #### 其他状态 38 | - 标签状态(`LabeledState`):任意状态可继承标签状态来扩展唯一ID以形成轨迹 39 | 40 | ### 模型类(`Models`) 41 | 42 | #### 共轭密度模型(`ConjugatedDensityModel`) 43 | 共轭密度模型用保证每个滤波/跟踪器中的状态在整个状态循环过程中维持相同的假设密度函数(assumed density 
function),每个共轭密度模型提供三个接口方法: 44 | - predict:预测状态 45 | - update:更新状态 46 | - predicted_log_likelihood:预测观测似然 47 | 48 | 具体实现包括: 49 | - 卡尔曼密度模型(`KalmanDensityModel`):用于线性高斯模型假设,卡尔曼滤波组件 50 | - 扩展卡尔曼密度模型(`KalmanDensityModel`):用于非线性程度不高的模型假设,扩展卡尔曼滤波组件 51 | - 卡尔曼累计密度模型(`KalmanAccumulatedDensityModel`):对累计密度函数进行估计,平滑输出轨迹 52 | - 无迹卡尔曼密度模型(`UnscenetedKalmanDensityModel`):用于非线性模型假设,无迹卡尔曼滤波组件 53 | - 粒子密度模型(`ParticleDensityModel`):用于任意非线性模型假设,粒子滤波组件 54 | 55 | #### (*)出生模型(`BirthModel`) 56 | 描述物体出现,分为: 57 | - 单目标出生模型 58 | - 多目标出生模型 59 | 60 | #### (*)观测模型(`MeasurementModel`) 61 | 描述传感器观测过程 62 | - 线性观测模型: 63 | - 非线性观测模型:RangeBearing 64 | -(TODO)点物体检测,框物体检测,扩展物体观测,强度观测 65 | -杂波模型: 66 | - 泊松杂波模型 67 | - (TODO)状态依赖伯努利杂波模型 68 | 69 | #### (*)转移模型(`TransitionModel`) 70 | 描述状态转移过程 71 | - 线性转移模型:CV 72 | - 非线性转移模型:CTRV 73 | - (TODO) 74 | 75 | #### 跟踪器(`Tracker`) 76 | 跟踪器(`Trackers`)目录下实现了多种单/多目标滤波/跟踪算法,所有算法实现在滤波基类(`Filter`)中的`filtering`方法中被统一为相同的流程: 77 | 1. predict 78 | 2. birth 79 | 3. update 80 | 4. estimate 81 | 5. 
reduction 82 | 83 | 每种算法实现在滤波主循环中的各个步骤维持各自相同形式的状态,完成状态的循环。 84 | ![framework](statecircle/wiki/flow.png) 85 | 图中状态循环包含5个个状态,`predict`, `birth`, `update`, `estimate`与`reduction`,在循环开始时,由`reader`读取 86 | 观测数据送入`predict`模块,此时`predict`只是记录当前观测数据的时间戳`timestamp`后即返回,并未做实际的状态预测 87 | 运算,然后陆续经过后续模块后再回到`predict`时,此时`Tracker`类已经记录了初始帧的时间戳,在`predict`模块拿到当前 88 | 观测时即可计算出两次观测数据时间间隔,然后才可进行`predict`运算。 89 | 90 | 具体实现可分为单目标跟踪与多目标跟踪两大类 91 | 92 | ##### 单目标跟踪器 93 | - 最近邻跟踪器(`NearestNeighbourTracker`) 94 | - 概率数据关联跟踪器(`ProbabilisticDataAssociationTracker`) 95 | - 高斯核跟踪器(`GaussianSumTracker`) 96 | - 伯努利跟踪器(`BernoulliTracker`) 97 | 98 | ##### 多目标滤波/跟踪器 99 | - 全局最近邻跟踪器(`GlobalNearestNeighbourTracker`) 100 | - 联合概率数据关联跟踪器(`JointProbabilisticDataAssociationTracker`) 101 | - 多假设跟踪器(`MultiHypothesisTracker`) 102 | - 概率假设密度滤波器(`PHDFilter`) 103 | - (TODO)概率假设密度跟踪器(`PHDTracker`) 104 | - (TODO)集势概率假设密度滤波器(`CPHDFilter`) 105 | - 多伯努利混合滤波器(`MBMFilter`) 106 | - 多伯努利混合跟踪器(`MBMTracker`) 107 | - 泊松多伯努利混合滤波器(`PMBMFilter`) 108 | - 泊松多伯努利混合跟踪器(`PMBMTracker`) 109 | -------------------------------------------------------------------------------- /config/gs_linear.yml: -------------------------------------------------------------------------------- 1 | TRACKER: 2 | name: gs linear 3 | type: statecircle.trackers.sot.gaussian_sum_tracker.GaussianSumTracker 4 | 5 | # basic components 6 | BIRTH_MODEL: 7 | type: statecircle.models.birth.base.SingleObjectBirthModel 8 | birth_cov: 9 | initial_state: [0, 0, 10, 10] 10 | 11 | DENSITY_MODEL: 12 | type: statecircle.models.density.kalman.KalmanDensityModel 13 | 14 | TRANSITION_MODEL: 15 | type: statecircle.models.transition.linear.ConstantVelocityModel 16 | sigma: 5 17 | 18 | MEASUREMENT_MODEL: 19 | type: statecircle.models.measurement.linear.LinearMeasurementModel 20 | mapping: [1, 1, 0, 0] 21 | sigma: 10 22 | 23 | CLUTTER_MODEL: 24 | type: statecircle.models.measurement.clutter.PoissonClutterModel 25 | detection_rate: 0.9 26 | lambda_clutter: 20 # expectation number of 
clutter per frame 27 | scope: [[0, 1000], [0, 1000]] 28 | 29 | GATE: 30 | type: statecircle.reductor.gate.EllipsoidalGate 31 | percentile: 0.999 32 | 33 | ESTIMATOR: 34 | type: statecircle.estimator.base.EAPEstimator 35 | 36 | REDUCTOR: 37 | type: statecircle.reductor.hypothesis_reductor.HypothesisReductor 38 | weight_min: 0.001 39 | merging_threshold: 2 40 | capping_num: 100 41 | 42 | 43 | -------------------------------------------------------------------------------- /config/gs_nonlinear.yml: -------------------------------------------------------------------------------- 1 | TRACKER: 2 | name: gs nonlinear 3 | type: statecircle.trackers.sot.gaussian_sum_tracker.GaussianSumTracker 4 | 5 | # basic components 6 | BIRTH_MODEL: 7 | type: statecircle.models.birth.base.SingleObjectBirthModel 8 | birth_cov: 9 | initial_state: <[0, 0, 10, 0, np.pi/180]> 10 | 11 | DENSITY_MODEL: 12 | type: statecircle.models.density.kalman.KalmanDensityModel 13 | 14 | TRANSITION_MODEL: 15 | type: statecircle.models.transition.nonlinear.SimpleCTRVModel 16 | sigma_vel: 1 17 | sigma_omega: 18 | 19 | MEASUREMENT_MODEL: 20 | type: statecircle.models.measurement.nonlinear.RangeBearningMeasurementModel 21 | sigma_range: 5 22 | sigma_bearing: 23 | origin: [300, 400] 24 | 25 | CLUTTER_MODEL: 26 | type: statecircle.models.measurement.clutter.PoissonClutterModel 27 | detection_rate: 0.9 28 | lambda_clutter: 20 # expectation number of clutter per frame 29 | scope: <[[0, 1000], [-np.pi, np.pi]]> 30 | 31 | GATE: 32 | type: statecircle.reductor.gate.EllipsoidalGate 33 | percentile: 0.999 34 | 35 | ESTIMATOR: 36 | type: statecircle.estimator.base.EAPEstimator 37 | 38 | REDUCTOR: 39 | type: statecircle.reductor.hypothesis_reductor.HypothesisReductor 40 | weight_min: 0.001 41 | merging_threshold: 2 42 | capping_num: 100 -------------------------------------------------------------------------------- /config/nn_linear.yml: -------------------------------------------------------------------------------- 1 | 
TRACKER: 2 | name: nn linear 3 | type: statecircle.trackers.sot.nearest_neighbour_tracker.NearestNeighbourTracker 4 | 5 | # basic components 6 | BIRTH_MODEL: 7 | type: statecircle.models.birth.base.SingleObjectBirthModel 8 | birth_cov: 9 | initial_state: [0, 0, 10, 10] 10 | 11 | DENSITY_MODEL: 12 | type: statecircle.models.density.kalman.KalmanDensityModel 13 | 14 | TRANSITION_MODEL: 15 | type: statecircle.models.transition.linear.ConstantVelocityModel 16 | sigma: 5 17 | 18 | MEASUREMENT_MODEL: 19 | type: statecircle.models.measurement.linear.LinearMeasurementModel 20 | mapping: [1, 1, 0, 0] 21 | sigma: 10 22 | 23 | CLUTTER_MODEL: 24 | type: statecircle.models.measurement.clutter.PoissonClutterModel 25 | detection_rate: 0.9 26 | lambda_clutter: 20 # expectation number of clutter per frame 27 | scope: [[0, 1000], [0, 1000]] 28 | 29 | GATE: 30 | type: statecircle.reductor.gate.EllipsoidalGate 31 | percentile: 0.999 32 | 33 | ESTIMATOR: 34 | type: statecircle.estimator.base.EAPEstimator 35 | 36 | -------------------------------------------------------------------------------- /config/nn_nonlinear.yml: -------------------------------------------------------------------------------- 1 | TRACKER: 2 | name: nn nonlinear 3 | type: statecircle.trackers.sot.nearest_neighbour_tracker.NearestNeighbourTracker 4 | 5 | # basic components 6 | BIRTH_MODEL: 7 | type: statecircle.models.birth.base.SingleObjectBirthModel 8 | birth_cov: 9 | initial_state: <[0, 0, 10, 0, np.pi/180]> 10 | 11 | DENSITY_MODEL: 12 | type: statecircle.models.density.kalman.KalmanDensityModel 13 | 14 | TRANSITION_MODEL: 15 | type: statecircle.models.transition.nonlinear.SimpleCTRVModel 16 | sigma_vel: 1 17 | sigma_omega: 18 | 19 | MEASUREMENT_MODEL: 20 | type: statecircle.models.measurement.nonlinear.RangeBearningMeasurementModel 21 | sigma_range: 5 22 | sigma_bearing: 23 | origin: [300, 400] 24 | 25 | CLUTTER_MODEL: 26 | type: statecircle.models.measurement.clutter.PoissonClutterModel 27 | detection_rate: 
0.9 28 | lambda_clutter: 20 # expectation number of clutter per frame 29 | scope: <[[0, 1000], [-np.pi, np.pi]]> 30 | 31 | GATE: 32 | type: statecircle.reductor.gate.EllipsoidalGate 33 | percentile: 0.999 34 | 35 | ESTIMATOR: 36 | type: statecircle.estimator.base.EAPEstimator 37 | 38 | -------------------------------------------------------------------------------- /config/pda_linear.yml: -------------------------------------------------------------------------------- 1 | TRACKER: 2 | name: pda linear 3 | type: statecircle.trackers.sot.probabilistic_data_association_tracker.ProbabilisticDataAssociationTracker 4 | 5 | # basic components 6 | BIRTH_MODEL: 7 | type: statecircle.models.birth.base.SingleObjectBirthModel 8 | birth_cov: 9 | initial_state: [0, 0, 10, 10] 10 | 11 | DENSITY_MODEL: 12 | type: statecircle.models.density.kalman.KalmanDensityModel 13 | 14 | TRANSITION_MODEL: 15 | type: statecircle.models.transition.linear.ConstantVelocityModel 16 | sigma: 5 17 | 18 | MEASUREMENT_MODEL: 19 | type: statecircle.models.measurement.linear.LinearMeasurementModel 20 | mapping: [1, 1, 0, 0] 21 | sigma: 10 22 | 23 | CLUTTER_MODEL: 24 | type: statecircle.models.measurement.clutter.PoissonClutterModel 25 | detection_rate: 0.9 26 | lambda_clutter: 20 # expectation number of clutter per frame 27 | scope: [[0, 1000], [0, 1000]] 28 | 29 | GATE: 30 | type: statecircle.reductor.gate.EllipsoidalGate 31 | percentile: 0.999 32 | 33 | ESTIMATOR: 34 | type: statecircle.estimator.base.EAPEstimator 35 | 36 | REDUCTOR: 37 | type: statecircle.reductor.hypothesis_reductor.HypothesisReductor 38 | weight_min: 0.001 39 | merging_threshold: 2 40 | capping_num: 100 41 | 42 | 43 | -------------------------------------------------------------------------------- /config/pda_nonlinear.yml: -------------------------------------------------------------------------------- 1 | TRACKER: 2 | name: pda nonlinear 3 | type: 
statecircle.trackers.sot.probabilistic_data_association_tracker.ProbabilisticDataAssociationTracker 4 | 5 | # basic components 6 | BIRTH_MODEL: 7 | type: statecircle.models.birth.base.SingleObjectBirthModel 8 | birth_cov: 9 | initial_state: <[0, 0, 10, 0, np.pi/180]> 10 | 11 | DENSITY_MODEL: 12 | type: statecircle.models.density.kalman.KalmanDensityModel 13 | 14 | TRANSITION_MODEL: 15 | type: statecircle.models.transition.nonlinear.SimpleCTRVModel 16 | sigma_vel: 1 17 | sigma_omega: 18 | 19 | MEASUREMENT_MODEL: 20 | type: statecircle.models.measurement.nonlinear.RangeBearningMeasurementModel 21 | sigma_range: 5 22 | sigma_bearing: 23 | origin: [300, 400] 24 | 25 | CLUTTER_MODEL: 26 | type: statecircle.models.measurement.clutter.PoissonClutterModel 27 | detection_rate: 0.9 28 | lambda_clutter: 20 # expectation number of clutter per frame 29 | scope: <[[0, 1000], [-np.pi, np.pi]]> 30 | 31 | GATE: 32 | type: statecircle.reductor.gate.EllipsoidalGate 33 | percentile: 0.999 34 | 35 | ESTIMATOR: 36 | type: statecircle.estimator.base.EAPEstimator 37 | 38 | REDUCTOR: 39 | type: statecircle.reductor.hypothesis_reductor.HypothesisReductor 40 | weight_min: 0.001 41 | merging_threshold: 2 42 | capping_num: 100 -------------------------------------------------------------------------------- /demo_pmbm_tracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Dec 25 17:32:44 2019 5 | 6 | @author: zhaoxm 7 | """ 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | from scenarios.linear import LinearScenario 12 | from scenarios.nonlinear import NonlinearScenario 13 | from statecircle.estimator.base import EAPEstimator 14 | from statecircle.models.density.kalman_accumulated import KalmanAccumulatedDensityModel 15 | from statecircle.models.measurement.nonlinear import RangeBearningMeasurementModel 16 | from statecircle.models.transition.nonlinear 
import SimpleCTRVModel 17 | from statecircle.reader.base import MeasurementReader 18 | from statecircle.reductor.gate import EllipsoidalGate 19 | from statecircle.models.sensor.base import DummySensorModel 20 | from statecircle.models.measurement.clutter import PoissonClutterModel 21 | from statecircle.models.measurement.linear import LinearMeasurementModel 22 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 23 | from statecircle.models.transition.linear import ConstantVelocityModel 24 | from statecircle.reductor.hypothesis_reductor import HypothesisReductor 25 | from statecircle.trackers.mot.pmbm_tracker import PMBMTracker 26 | from statecircle.trackers.mot.meas_driven_pmbm_tracker import MeasDrivenPMBMTracker 27 | 28 | PMBMTracker = MeasDrivenPMBMTracker # [PMBMTracker | MeasDrivenPMBMTracker] 29 | seed = 666 30 | scenario = 'linear' 31 | # build scene 32 | if scenario == 'linear': 33 | scene = LinearScenario.caseC(birth_weight=0.0005, birth_cov_scale=400) 34 | birth_model = scene.birth_model 35 | 36 | # build transition/measurement/clutter/birth models 37 | transition_model = ConstantVelocityModel(sigma=5) 38 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], 39 | sigma=10) 40 | clutter_model = PoissonClutterModel(detection_rate=0.9, 41 | lambda_clutter=20, 42 | scope=[[-1000, 1000], [-1000, 1000]]) 43 | else: 44 | scene = NonlinearScenario.caseC(birth_weight=0.001) 45 | birth_model = scene.birth_model 46 | 47 | # make transition/measurement/clutter/birth models 48 | transition_model = SimpleCTRVModel(sigma_vel=5, 49 | sigma_omega=np.pi / 180) 50 | measurement_model = RangeBearningMeasurementModel(sigma_range=5, 51 | sigma_bearing=np.pi / 180, 52 | origin=[300, 400]) 53 | clutter_model = PoissonClutterModel(detection_rate=0.9, 54 | lambda_clutter=20, 55 | scope=[[200, 1200], [-np.pi, np.pi]]) 56 | 57 | 58 | # build data generator 59 | data_generator = SimulatedGroundTruthDataGenerator(scene, transition_model, noisy=False) 
60 | 61 | # build sensor model 62 | sensor_model = DummySensorModel(clutter_model, measurement_model, random_seed=seed) 63 | 64 | # build data reader 65 | data_reader = MeasurementReader(data_generator, sensor_model) 66 | 67 | # build density model 68 | #density_model = KalmanDensityModel() 69 | density_model = KalmanAccumulatedDensityModel(traceback_range=10) 70 | 71 | # gate method 72 | gate = EllipsoidalGate(percentile=0.999) 73 | 74 | # estimator 75 | estimator = EAPEstimator() 76 | 77 | # reductor 78 | reductor = HypothesisReductor(weight_min=0.01, merging_threshold=4, capping_num=100) 79 | 80 | # %% build trackers & filtering 81 | # some extra parameters 82 | # TODO: reformat the input parameters 83 | prior_birth = False 84 | surviving_rate = 0.99 85 | recycle_threshold = 0.1 86 | prob_min = 0.01 87 | prob_estimate = 0.5 88 | meas_models_dict = None 89 | clutter_models_dict = None 90 | 91 | pmbm_filter = PMBMTracker(prior_birth, 92 | surviving_rate, 93 | recycle_threshold, 94 | prob_min, 95 | prob_estimate, 96 | meas_models_dict, 97 | clutter_models_dict, 98 | birth_model, 99 | density_model, 100 | transition_model, 101 | measurement_model, 102 | clutter_model, 103 | gate, 104 | estimator, 105 | reductor) 106 | 107 | pmbm_estimates = pmbm_filter.filtering(data_reader) 108 | 109 | #%% ploting 110 | animation = False 111 | show_birth = True 112 | 113 | gt_datum = np.concatenate(data_generator.gt_series.datum, axis=-1) 114 | true_state = np.concatenate([ele.states for ele in gt_datum], -1) 115 | for k, (_, obj_meas_data, clutter_data) in enumerate(data_reader.truth_meas_generator()): 116 | if not animation: 117 | k = scene.time_range[-1] - 1 118 | PMBM_estimated_state = np.hstack(pmbm_estimates[k]).squeeze() 119 | 120 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 121 | plot_gt = ax.plot(true_state[0], true_state[1], 'yo', alpha=0.2, markersize=10) 122 | 123 | ax.grid() 124 | ax.set_xlabel('x (m)') 125 | ax.set_ylabel('y (m)') 126 | 127 | # # plot birth region 
128 | # if prior_birth and show_birth: 129 | # for birth_state in birth_model: 130 | # plot_birth = plot_covariance_ellipse(birth_state.x[:2], birth_state.P[:2,:2], 'b', ax, 3) 131 | # else: 132 | # plot_birth = None 133 | 134 | # plot tracks 135 | for track in pmbm_estimates[k]: 136 | range_t = track['range'][-1] - track['range'][0] + 1 137 | # if range_t > -1: 138 | plt.plot(track['trajectory'][0], track['trajectory'][1], '-') 139 | 140 | # plot obejct measurements 141 | state_meas = measurement_model.reverse(obj_meas_data) 142 | plot_meas = ax.plot(state_meas[0], state_meas[1], 'r*', alpha=1) 143 | # plot clutter 144 | state_clutter = measurement_model.reverse(clutter_data) 145 | plot_clutter = ax.plot(state_clutter[0], state_clutter[1], 'k.', alpha=1) 146 | 147 | ax.legend((plot_gt[0], plot_meas[0], plot_clutter[0]), 148 | ['ground truth', 'detections', 'clutter'], 149 | loc='upper left') 150 | 151 | ax.set_xlim([-1000, 1000]) 152 | ax.set_ylim([-1000, 1000]) 153 | # plt.axis('equal') 154 | # plt.savefig('snapshot/results/track_{:04d}.png'.format(k)) 155 | 156 | plt.show() 157 | print('step: {}'.format(k)) 158 | plt.close('all') 159 | 160 | if not animation: 161 | break 162 | 163 | #%% plot cardinality 164 | plt.figure() 165 | plt.plot(data_generator.gt_series.num, 'yo') 166 | #PMBM_card_pred = [for ele in track for step, track in enumerate(PMBMEstimates)] 167 | PMBM_card_pred = [] 168 | for step, tracks in enumerate(pmbm_estimates): 169 | valid_track_num = 0 170 | for track in tracks: 171 | range_t = track['range'][-1] - track['range'][0] + 1 172 | if track['range'][-1] == step: 173 | valid_track_num += 1 174 | PMBM_card_pred.append(valid_track_num) 175 | plt.plot(PMBM_card_pred, 'b+') 176 | plt.legend(['GT', 'PMBM']) 177 | plt.grid() 178 | 179 | 180 | # %% plot measurements 181 | meas, obj_meas, clutter_meas = [], [], [] 182 | for meas_data, obj_meas_, clutter_meas_ in data_reader.truth_meas_generator(): 183 | meas.append(meas_data.meas) 184 | 
obj_meas.append(obj_meas_) 185 | clutter_meas.append(clutter_meas_) 186 | meas, obj_meas, clutter_meas = np.hstack(meas), np.hstack(obj_meas), np.hstack(clutter_meas) 187 | 188 | plt.figure() 189 | plt.plot(obj_meas[0], obj_meas[1], 'r.', alpha=0.5) 190 | 191 | # plot clutter 192 | plt.plot(clutter_meas[0], clutter_meas[1], 'k.', alpha=0.2) 193 | plt.legend(['measurements', 'clutter']) 194 | plt.show() 195 | plt.close('all') 196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /scenarios/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Tue Dec 24 15:59:55 2019 6 | 7 | @author: zhaoxm 8 | """ 9 | class BaseScenario: 10 | def __init__(self, time_range, birth_times, death_times, initial_states, birth_model): 11 | r""" Nonlinear scenerio class 12 | 13 | Attributes 14 | ---------- 15 | birth_times : list[#states] 16 | death_times : list[#states] 17 | time_range : list[#states] 18 | initial_states : 2darray[state_dim, #states] 19 | """ 20 | assert len(birth_times) == len(death_times) == initial_states.shape[1] 21 | assert len(time_range) == 2 and time_range[1] >= time_range[0] 22 | assert time_range[0] <= min(birth_times) 23 | assert time_range[1] >= max(death_times) 24 | 25 | self.birth_times = birth_times 26 | self.death_times = death_times 27 | self.time_range = time_range 28 | self.initial_states = initial_states 29 | self.state_dim, self.birth_num = initial_states.shape 30 | self.birth_model = birth_model -------------------------------------------------------------------------------- /scenarios/linear.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Tue Dec 24 15:20:19 2019 6 | 7 | @author: zhaoxm 8 | """ 9 | import numpy as np 10 | 11 | from .base import BaseScenario 12 | from statecircle.models.birth.base import SingleObjectBirthModel, MultiObjectBirthModel, PoissonBirthModel, BernoulliBirthModel 13 | from statecircle.types.state import GaussianState, GaussianSumState, BernoulliState, GaussianMixtureState, GaussianState 14 | 15 | 16 | class LinearScenario(BaseScenario): 17 | 18 | @classmethod 19 | def caseA(cls, state_func_handle=GaussianState): 20 | time_range = [0, 100] 21 | tbirth = [0] 22 | tdeath = [100] 23 | state_dim = 4 24 | initial_states = np.array([[0, 0, 10, 10]]).T 25 | 26 | # make birth model 27 | birth_cov = 10 * np.eye(state_dim) 28 | birth_model = SingleObjectBirthModel(initial_states, birth_cov, state_func_handle) 29 | 30 | return LinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 31 | 32 | @classmethod 33 | def caseB(cls, state_func_handle=GaussianState): 34 | # create ground truth model 35 | birth_num = 5 36 | time_range = [0, 100] 37 | tbirth = [0] * birth_num 38 | tdeath = [100] * birth_num 39 | initial_states = np.array([[0, 0, 0, -10], 40 | [400, -600, -10, 5], 41 | [-800, -200, 20, -5], 42 | [0, 0, 7.5, -5], 43 | [-200, 800, -3, -15]]).T 44 | # make birth model 45 | state_dim = 4 46 | birth_cov = np.eye(state_dim) 47 | birth_model = MultiObjectBirthModel(initial_states, birth_cov, state_func_handle) 48 | 49 | return LinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 50 | 51 | @classmethod 52 | def caseC(cls, birth_weight=0.03, birth_cov_scale=400, state_func_handle=GaussianState): 53 | time_range = [0, 100] 54 | tbirth = [1, 0, 1, 20, 20, 20, 40, 40, 60, 60, 80, 80] 55 | tdeath = [70, 100, 70, 100, 100, 100, 100, 100, 100, 100, 100, 100] 56 | initial_states = np.array([[0, 0, 0, 
-10], 57 | [400, -600, -10, 5], 58 | [-800, -200, 20, -5], 59 | [400, -600, -7, -4], 60 | [400, -600, -2.5, 10], 61 | [0, 0, 7.5, -5], 62 | [-800, -200, 12, 7], 63 | [-200, 800, 15, -10], 64 | [-800, -200, 3, 15], 65 | [-200, 800, -3, -15], 66 | [0, 0, -20, -15], 67 | [-200, 800, 15, -5]]).T 68 | 69 | # build birth model 70 | state_dim = 4 71 | birth_log_weights = np.log(birth_weight * np.ones(state_dim)) 72 | birth_num = 4 73 | birth_states = [None] * birth_num 74 | 75 | birth_cov = birth_cov_scale * np.eye(state_dim) 76 | birth_states[0] = state_func_handle(np.array([0, 0, 0, 0]), birth_cov) 77 | birth_states[1] = state_func_handle(np.array([400, -600, 0, 0]), birth_cov) 78 | birth_states[2] = state_func_handle(np.array([-800, -200, 0, 0]), birth_cov) 79 | birth_states[3] = state_func_handle(np.array([-200, 800, 0, 0]), birth_cov) 80 | intensity = GaussianSumState(birth_log_weights, birth_states) 81 | 82 | birth_model = PoissonBirthModel(intensity) 83 | return LinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 84 | 85 | @classmethod 86 | def caseD(cls, birth_prob=0.1, birth_cov_scale=400, state_func_handle=GaussianState): 87 | time_range = [0, 100] 88 | tbirth = [10] 89 | tdeath = [80] 90 | state_dim = 4 91 | initial_states = np.array([[0, 0, 10, 10]]).T 92 | 93 | # make birth model 94 | birth_cov = birth_cov_scale * np.eye(state_dim) 95 | birth_gaussian = state_func_handle(initial_states, birth_cov) 96 | bern = BernoulliState(prob=birth_prob, 97 | state=GaussianMixtureState(log_weights=np.array([0.]), 98 | gaussian_states=[birth_gaussian])) 99 | birth_model = BernoulliBirthModel(bern) 100 | 101 | return LinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 102 | 103 | @classmethod 104 | def caseE(cls, birth_weight=0.03, birth_cov_scale=400, state_func_handle=GaussianState): 105 | # build birth model 106 | time_range = [0, 100] 107 | tbirth = [10] 108 | tdeath = [80] 109 | initial_states = np.array([[0, 0, 0, 0, 0, 0]]).T 
110 | 111 | state_dim = 6 112 | birth_num = 6 113 | birth_log_weights = np.log(birth_weight * np.ones(state_dim)) 114 | birth_states = [None] * birth_num 115 | birth_cov = birth_cov_scale * np.eye(state_dim) 116 | birth_states[0] = state_func_handle(np.ones(state_dim), birth_cov) 117 | birth_states[1] = state_func_handle(np.ones(state_dim), birth_cov) 118 | birth_states[2] = state_func_handle(np.ones(state_dim), birth_cov) 119 | birth_states[3] = state_func_handle(np.ones(state_dim), birth_cov) 120 | birth_states[4] = state_func_handle(np.ones(state_dim), birth_cov) 121 | birth_states[5] = state_func_handle(np.ones(state_dim), birth_cov) 122 | intensity = GaussianSumState(birth_log_weights, birth_states) 123 | 124 | birth_model = PoissonBirthModel(intensity) 125 | return LinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 126 | 127 | @classmethod 128 | def caseF(cls, birth_weight=0.03, birth_cov_scale=400, state_func_handle=GaussianState): 129 | time_range = [0, 100] 130 | tbirth = [1] 131 | tdeath = [100] 132 | initial_states = np.array([[0, 0, 10, 0]]).T 133 | 134 | # build birth model 135 | state_dim = 4 136 | birth_num = 1 137 | birth_log_weights = np.log(birth_weight * np.ones(birth_num)) 138 | birth_states = [None] * birth_num 139 | 140 | birth_cov = birth_cov_scale * np.eye(state_dim) 141 | birth_states[0] = state_func_handle(np.array([0, 0, 0, 0]), birth_cov) 142 | intensity = GaussianSumState(birth_log_weights, birth_states) 143 | 144 | birth_model = PoissonBirthModel(intensity) 145 | return LinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 146 | -------------------------------------------------------------------------------- /scenarios/nonlinear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Tue Dec 24 15:54:01 2019 6 | 7 | @author: zhaoxm 8 | """ 9 | import numpy as np 10 | 11 | from .base 
import BaseScenario 12 | from statecircle.models.birth.base import SingleObjectBirthModel, MultiObjectBirthModel, PoissonBirthModel, \ 13 | BernoulliBirthModel 14 | from statecircle.types.state import GaussianState, GaussianSumState, BernoulliState, GaussianMixtureState 15 | 16 | 17 | class NonlinearScenario(BaseScenario): 18 | 19 | @classmethod 20 | def caseA(cls, state_func_handle=GaussianState): 21 | time_range = [0, 100] 22 | tbirth = [0] 23 | tdeath = [100] 24 | initial_states = np.array([[0, 0, 10, 0, np.pi / 180]]).T 25 | 26 | # make birth model 27 | birth_cov = np.diag(np.array([1, 1, 1, 1 * np.pi / 180, 1 * np.pi / 180]) ** 2) 28 | birth_model = SingleObjectBirthModel(initial_states, birth_cov, state_func_handle) 29 | 30 | return NonlinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 31 | 32 | @classmethod 33 | def caseB(cls, state_func_handle): 34 | # create ground truth model 35 | time_range = [0, 100] 36 | birth_num = 4 37 | tbirth = [0] * birth_num 38 | tdeath = [100] * birth_num 39 | initial_states = np.array([[0, 0, 5, 0, np.pi / 180], 40 | [20, 20, -20, 0, np.pi / 90], 41 | [-20, 10, -10, 0, np.pi / 360], 42 | [-10, -10, 8, 0, np.pi / 270]]).T 43 | 44 | # make birth model 45 | birth_cov = np.diag(np.array([1, 1, 1, 1 * np.pi / 180, 1 * np.pi / 180]) ** 2) 46 | birth_model = MultiObjectBirthModel(initial_states, birth_cov, state_func_handle) 47 | 48 | return NonlinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 49 | 50 | @classmethod 51 | def caseC(cls, birth_weight=0.3, birth_cov_scale=10, state_func_handle=GaussianState): 52 | time_range = [0, 100] 53 | tbirth = [0, 20, 40, 60] 54 | tdeath = [50, 70, 90, 100] 55 | initial_states = np.array([[0, 0, 5, 0, np.pi / 180], 56 | [20, 20, -10, 0, np.pi / 90], 57 | [-20, 10, -10, 0, np.pi / 360], 58 | [-10, -10, 8, 0, np.pi / 270]]).T 59 | 60 | # build birth model 61 | birth_num = 4 62 | birth_log_weights = np.log(birth_weight * np.ones(birth_num)) 63 | 64 | 
birth_cov = birth_cov_scale * np.diag(np.array([1, 1, 1, 1 * np.pi / 90, 1 * np.pi / 90]) ** 2) 65 | birth_states = [None] * birth_num 66 | birth_states[0] = state_func_handle(np.array([0, 0, 5, 0, np.pi / 180]), birth_cov) 67 | birth_states[1] = state_func_handle(np.array([20, 20, -10, 0, np.pi / 90]), birth_cov) 68 | birth_states[2] = state_func_handle(np.array([-20, 10, -10, 0, np.pi / 360]), birth_cov) 69 | birth_states[3] = state_func_handle(np.array([-10, -10, 8, 0, np.pi / 270]), birth_cov) 70 | intensity = GaussianSumState(birth_log_weights, birth_states) 71 | birth_model = PoissonBirthModel(intensity) 72 | 73 | return NonlinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 74 | 75 | @classmethod 76 | def caseD(cls, birth_prob=0.1, state_func_handle=GaussianState): 77 | time_range = [0, 100] 78 | tbirth = [10] 79 | tdeath = [80] 80 | initial_states = np.array([[0, 0, 10, 0, np.pi / 180]]).T 81 | 82 | # make birth model: single-component Bernoulli birth with existence probability birth_prob and one Gaussian (log-weight 0) centred on initial_states 83 | birth_cov = np.diag(np.array([1, 1, 1, 1 * np.pi / 180, 1 * np.pi / 180]) ** 2) 84 | birth_gaussian = state_func_handle(initial_states, birth_cov) 85 | birth_bern = BernoulliState(prob=birth_prob, 86 | state=GaussianMixtureState(log_weights=np.array([0.]), 87 | gaussian_states=[birth_gaussian])) 88 | 89 | birth_model = BernoulliBirthModel(birth_bern) 90 | 91 | return NonlinearScenario(time_range, tbirth, tdeath, initial_states, birth_model) 92 | -------------------------------------------------------------------------------- /scripts/debug_radarinfo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Sun May 24 16:06:40 2020 6 | 7 | @author: zhaoxm 8 | """ 9 | import numpy as np 10 | #%% NOTE(review): interactive scratch cell — `data_reader` is not defined in this file, it must already exist in the session 11 | meas_data = data_reader[302] 12 | print(meas_data.sensor_type) 13 | for obj in meas_data.objects: 14 | vx = obj.VeloX 15 | vy = obj.VeloY 16 | x = obj.WldX 17 | y = obj.WldY 18 | thetap = np.arctan2(y,x) 19 | thetav =
np.arctan2(vy,vx) 20 | print('x:{}, y:{}'.format(x, y)) 21 | print('vx:{}, vy:{}'.format(vx, vy)) 22 | print('p:{} vs v:{}'.format(thetap * 180/np.pi, thetav*180/np.pi)) 23 | 24 | #%% 25 | dd = { 26 | 'WldX': -1.28395, 27 | 'WldY': 87.4402, 28 | 'WldWidth': 0.0, 29 | 'WldHeight': 1.6, 30 | 'Dist': 87.4497, 31 | 'Angle': -89.1587, 32 | 'VeloX': 0.107444, 33 | 'VeloY': -14.7395, 34 | } 35 | 36 | np.arctan2(dd['WldY'], dd['WldX']) * 180 / np.pi 37 | np.arctan2(dd['VeloY'], dd['VeloX']) * 180 / np.pi 38 | -------------------------------------------------------------------------------- /scripts/demo_mbm_filter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Tue Dec 24 16:28:42 2019 6 | 7 | @author: zhaoxm 8 | """ 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from scenarios.linear import LinearScenario 13 | from scenarios.nonlinear import NonlinearScenario 14 | from statecircle.estimator.base import EAPEstimator 15 | from statecircle.models.birth.base import MultiObjectBirthModel, PoissonBirthModel 16 | from statecircle.models.density.kalman import KalmanDensityModel 17 | from statecircle.models.measurement.nonlinear import RangeBearningMeasurementModel 18 | from statecircle.models.transition.nonlinear import SimpleCTRVModel 19 | from statecircle.reader.base import MeasurementReader 20 | from statecircle.reductor.gate import EllipsoidalGate 21 | from statecircle.models.sensor.base import DummySensorModel 22 | from statecircle.models.measurement.clutter import PoissonClutterModel 23 | from statecircle.models.measurement.linear import LinearMeasurementModel 24 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 25 | from statecircle.models.transition.linear import ConstantVelocityModel 26 | from statecircle.reductor.hypothesis_reductor import HypothesisReductor 27 | from statecircle.types.state import 
GaussianState, GaussianSumState 28 | from statecircle.trackers.mot.phd_filter import PHDFilter 29 | from statecircle.trackers.mot.mbm_filter import MBMFilter 30 | from statecircle.trackers.mot.pmbm_filter import PMBMFilter 31 | from tools.visualizer import Visualizer 32 | 33 | seed = None 34 | 35 | scenario = 'linear' 36 | # build scene 37 | if scenario == 'linear': 38 | scene = LinearScenario.caseC(birth_weight=0.03, birth_cov_scale=400) 39 | birth_model = scene.birth_model 40 | 41 | # build transition/measurement/clutter/birth models 42 | transition_model = ConstantVelocityModel(sigma=5) 43 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], 44 | sigma=10) 45 | clutter_model = PoissonClutterModel(detection_rate=0.9, 46 | lambda_clutter=20, 47 | scope=[[-1000, 1000], [-1000, 1000]]) 48 | else: 49 | scene = NonlinearScenario.caseC(birth_weight=0.03) 50 | birth_model = scene.birth_model 51 | 52 | # make transition/measurement/clutter/birth models 53 | transition_model = SimpleCTRVModel(sigma_vel=1, 54 | sigma_omega=np.pi / 180) 55 | measurement_model = RangeBearningMeasurementModel(sigma_range=5, 56 | sigma_bearing=np.pi / 180, 57 | origin=[300, 400]) 58 | clutter_model = PoissonClutterModel(detection_rate=0.9, 59 | lambda_clutter=20, 60 | scope=[[200, 1200], [-np.pi, np.pi]]) 61 | 62 | # build data generator 63 | data_generator = SimulatedGroundTruthDataGenerator(scene, transition_model, noisy=False) 64 | 65 | # build sensor model 66 | sensor_model = DummySensorModel(clutter_model, measurement_model, random_seed=seed) 67 | 68 | # build data reader 69 | data_reader = MeasurementReader(data_generator, sensor_model) 70 | 71 | # build density model 72 | density_model = KalmanDensityModel() 73 | 74 | # gate method 75 | gate = EllipsoidalGate(percentile=0.999) 76 | 77 | # estimator 78 | estimator = EAPEstimator() 79 | 80 | # reductor 81 | reductor = HypothesisReductor(weight_min=1e-3, merging_threshold=4, capping_num=100) 82 | 83 | # %% build trackers & 
filtering 84 | # some extra parameters 85 | surviving_rate = 0.99 86 | recycle_threshold = 0.1 87 | prob_min = 1e-3 88 | prob_estimate = 0.5 89 | 90 | mbm_filter = MBMFilter(surviving_rate, 91 | prob_min, 92 | prob_estimate, 93 | birth_model, 94 | density_model, 95 | transition_model, 96 | measurement_model, 97 | clutter_model, 98 | gate, 99 | estimator, 100 | reductor) 101 | 102 | mbm_estimates = mbm_filter.filtering(data_reader) 103 | mbm_card_pred = [ele.shape[1] for ele in mbm_estimates] 104 | mbm_estimates = np.hstack(mbm_estimates) 105 | 106 | # %% visualise results 107 | visualizer = Visualizer(data_generator, 'MBM') 108 | visualizer.show_estimates(mbm_estimates) 109 | visualizer.show_cardinality(mbm_card_pred) 110 | visualizer.show_measurements(data_reader) 111 | -------------------------------------------------------------------------------- /scripts/entry_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | 11 | from scenarios.linear import LinearScenario 12 | from scenarios.nonlinear import NonlinearScenario 13 | from tools.visualizer import Visualizer 14 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 15 | from statecircle.models.sensor.base import DummySensorModel 16 | from statecircle.platform.base import TrackerPlatform 17 | from statecircle.reader.base import MeasurementReader 18 | 19 | 20 | # np.random.seed(9) 21 | 22 | # build tracker 23 | 24 | #scene = LinearScenario.caseA() 25 | #config_path = 'config/pda_linear.yml' 26 | 27 | scene = NonlinearScenario.caseA() 28 | config_path = 'config/gs_nonlinear.yml' 29 | 30 | platform = TrackerPlatform(config_path) 31 | 32 | # build data generator 33 | data_generator = SimulatedGroundTruthDataGenerator(scene, 34 | platform.transition_model, 35 | noisy=False) 36 | 37 | # build sensor 
model 38 | sensor_model = DummySensorModel(platform.clutter_model, platform.measurement_model) 39 | 40 | # build data reader 41 | data_reader = MeasurementReader(data_generator, sensor_model) 42 | 43 | # filtering 44 | estimates = platform.tracker.filtering(data_reader) 45 | 46 | card_pred = [ele.shape[1] for ele in estimates] 47 | estimates = np.hstack(estimates) 48 | 49 | # %% visualise results 50 | visualizer = Visualizer(data_generator, config_path) 51 | visualizer.show_estimates(estimates) 52 | visualizer.show_cardinality(card_pred) 53 | visualizer.show_measurements(data_reader) 54 | 55 | -------------------------------------------------------------------------------- /scripts/entry_mot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Wed Dec 18 17:50:11 2019 6 | 7 | @author: zhxm 8 | """ 9 | 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | 13 | from scenarios.linear import LinearScenario 14 | from scenarios.nonlinear import NonlinearScenario 15 | from statecircle.estimator.base import EAPEstimator 16 | from statecircle.models.birth.base import MultiObjectBirthModel 17 | from statecircle.models.density.kalman import KalmanDensityModel 18 | from statecircle.models.density.unscented import UnscentedDensityModel 19 | from statecircle.models.measurement.nonlinear import RangeBearningMeasurementModel 20 | from statecircle.models.transition.nonlinear import SimpleCTRVModel 21 | from statecircle.reader.base import MeasurementReader 22 | from statecircle.reductor.gate import EllipsoidalGate 23 | from statecircle.models.sensor.base import DummySensorModel 24 | from statecircle.models.measurement.clutter import PoissonClutterModel 25 | from statecircle.models.measurement.linear import LinearMeasurementModel 26 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 27 | from statecircle.models.transition.linear import 
ConstantVelocityModel 28 | from statecircle.reductor.hypothesis_reductor import HypothesisReductor 29 | from statecircle.trackers.mot.global_nearest_neighbour_tracker import GlobalNearestNeighbourTracker 30 | from statecircle.trackers.mot.joint_probabilistic_data_association_tracker import JointProbabilisticDataAssociationTracker 31 | from statecircle.trackers.mot.multi_hypothesis_tracker import TrackOrientedMultiHypothesisTracker 32 | from statecircle.types.state import GaussianState, UnscentedState 33 | 34 | seed = None 35 | scenario = 'nonlinear' 36 | state_func_handle = UnscentedState # Supported `GaussianState` and `UnscentedState` 37 | 38 | if scenario == 'linear': 39 | # make transition/measurement/clutter/birth models 40 | transition_model = ConstantVelocityModel(sigma=5) 41 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], 42 | sigma=10) 43 | clutter_model = PoissonClutterModel(detection_rate=0.9, 44 | lambda_clutter=20, 45 | scope=[[-1000, 1000], [-1000, 1000]]) 46 | 47 | scene = LinearScenario.caseB(state_func_handle=state_func_handle) 48 | birth_model = scene.birth_model 49 | 50 | elif scenario == 'nonlinear': 51 | # make transition/measurement/clutter/birth models 52 | transition_model = SimpleCTRVModel(sigma_vel=1, 53 | sigma_omega=np.pi / 180) 54 | measurement_model = RangeBearningMeasurementModel(sigma_range=5, 55 | sigma_bearing=np.pi / 180, 56 | origin=[300, 400]) 57 | clutter_model = PoissonClutterModel(detection_rate=0.9, 58 | lambda_clutter=20, 59 | scope=[[200, 1200], [-np.pi, np.pi]]) 60 | 61 | scene = NonlinearScenario.caseB(state_func_handle=state_func_handle) 62 | birth_model = scene.birth_model 63 | 64 | # make data generator 65 | noisy = False 66 | data_generator = SimulatedGroundTruthDataGenerator(scene, transition_model, noisy=noisy) 67 | 68 | # make sensor model 69 | sensor_model = DummySensorModel(clutter_model, measurement_model, random_seed=seed) 70 | 71 | # make data reader 72 | data_reader = 
MeasurementReader(data_generator, sensor_model) 73 | 74 | # make density model 75 | if state_func_handle is GaussianState: 76 | density_model = KalmanDensityModel() 77 | elif state_func_handle is UnscentedState: 78 | density_model = UnscentedDensityModel(transition_model.state_dim, alpha=1.0, beta=2.0) 79 | else: 80 | raise TypeError 81 | 82 | # gate method 83 | gate = EllipsoidalGate(percentile=0.999) 84 | 85 | # estimator 86 | estimator = EAPEstimator() 87 | 88 | # reductor 89 | reductor = HypothesisReductor(weight_min=1e-3, merging_threshold=2, capping_num=100) 90 | 91 | # %% build trackers & filtering 92 | # GNN tracker 93 | gnn_tracker = GlobalNearestNeighbourTracker(birth_model, 94 | density_model, 95 | transition_model, 96 | measurement_model, 97 | clutter_model, 98 | gate, 99 | estimator) 100 | gnn_estimates = gnn_tracker.filtering(data_reader) 101 | gnn_estimates = np.stack(gnn_estimates, -1) 102 | 103 | ## JPDA tracker 104 | jpda_tracker = JointProbabilisticDataAssociationTracker(birth_model, 105 | density_model, 106 | transition_model, 107 | measurement_model, 108 | clutter_model, 109 | gate, 110 | estimator, 111 | reductor) 112 | jpda_estimates = jpda_tracker.filtering(data_reader) 113 | jpda_estimates = np.stack(jpda_estimates, -1) 114 | 115 | # MH tracker 116 | mh_tracker = TrackOrientedMultiHypothesisTracker(birth_model, 117 | density_model, 118 | transition_model, 119 | measurement_model, 120 | clutter_model, 121 | gate, 122 | estimator, 123 | reductor) 124 | mh_estimates = mh_tracker.filtering(data_reader) 125 | mh_estimates = np.stack(mh_estimates, -1) 126 | # %% visualise results 127 | gt_series = data_generator.gt_series 128 | gt_data = np.hstack([ele.states for ele in data_generator]) 129 | fig, ax = plt.subplots(1, 1, figsize=(6,6)) 130 | plot_gt = ax.plot(gt_data[0], gt_data[1], 'yo', alpha=0.2, markersize=10) 131 | plot_gnn = ax.plot(gnn_estimates[0].T, gnn_estimates[1].T, 'r*-', alpha=0.5) 132 | plot_jpda = ax.plot(jpda_estimates[0].T, 
jpda_estimates[1].T, 'g.-', alpha=0.5) 133 | plot_mh = ax.plot(mh_estimates[0].T, mh_estimates[1].T, 'b+-', alpha=0.5) 134 | 135 | ax.grid() 136 | ax.set_xlabel('x (m)') 137 | ax.set_ylabel('y (m)') 138 | ax.legend((plot_gt[0], plot_gnn[0], plot_jpda[0], plot_mh[0]), ['Ground Truth', 'GNN', 'JPDA', "MHT"]) 139 | plt.axis('equal') 140 | 141 | # %% plot measurements 142 | meas, obj_meas, clutter_meas = [], [], [] 143 | for meas_data, obj_meas_, clutter_meas_ in data_reader.truth_meas_generator(): 144 | meas.append(meas_data.meas) 145 | obj_meas.append(obj_meas_) 146 | clutter_meas.append(clutter_meas_) 147 | meas, obj_meas, clutter_meas = np.hstack(meas), np.hstack(obj_meas), np.hstack(clutter_meas) 148 | 149 | plt.figure() 150 | plt.plot(obj_meas[0], obj_meas[1], 'r.', alpha=0.5) 151 | 152 | # plot clutter 153 | plt.plot(clutter_meas[0], clutter_meas[1], 'k.', alpha=0.2) 154 | plt.legend(['measurements', 'clutter']) 155 | plt.show() 156 | plt.close('all') 157 | -------------------------------------------------------------------------------- /scripts/entry_mot_rfs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Fri Dec 20 14:14:33 2019 6 | 7 | @author: zhaoxm 8 | """ 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from scenarios.linear import LinearScenario 13 | from scenarios.nonlinear import NonlinearScenario 14 | from statecircle.estimator.base import EAPEstimator 15 | from statecircle.models.birth.base import MultiObjectBirthModel, PoissonBirthModel 16 | from statecircle.models.density.kalman import KalmanDensityModel 17 | from statecircle.models.density.unscented import UnscentedDensityModel 18 | from statecircle.models.measurement.nonlinear import RangeBearningMeasurementModel 19 | from statecircle.models.transition.nonlinear import SimpleCTRVModel 20 | from statecircle.reader.base import MeasurementReader 21 | from 
statecircle.reductor.gate import EllipsoidalGate 22 | from statecircle.models.sensor.base import DummySensorModel 23 | from statecircle.models.measurement.clutter import PoissonClutterModel 24 | from statecircle.models.measurement.linear import LinearMeasurementModel 25 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 26 | from statecircle.models.transition.linear import ConstantVelocityModel 27 | from statecircle.reductor.hypothesis_reductor import HypothesisReductor 28 | from statecircle.types.state import GaussianState, GaussianSumState, UnscentedState 29 | from statecircle.trackers.mot.phd_filter import PHDFilter 30 | from statecircle.trackers.mot.mbm_filter import MBMFilter 31 | from statecircle.trackers.mot.pmbm_filter import PMBMFilter 32 | 33 | seed = None 34 | scenario = 'linear' 35 | state_func_handle = UnscentedState # supported `GaussianState` and `UnscentedState` 36 | 37 | if scenario == 'linear': 38 | # make transition/measurement/clutter/birth models 39 | transition_model = ConstantVelocityModel(sigma=5) 40 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], 41 | sigma=10) 42 | clutter_model = PoissonClutterModel(detection_rate=0.9, 43 | lambda_clutter=20, 44 | scope=[[-1000, 1000], [-1000, 1000]]) 45 | 46 | scene = LinearScenario.caseC(birth_weight=0.03, birth_cov_scale=400, state_func_handle=state_func_handle) 47 | birth_model = scene.birth_model 48 | 49 | elif scenario == 'nonlinear': 50 | # make transition/measurement/clutter/birth models 51 | transition_model = SimpleCTRVModel(sigma_vel=1, 52 | sigma_omega=np.pi / 180) 53 | measurement_model = RangeBearningMeasurementModel(sigma_range=5, 54 | sigma_bearing=np.pi / 180, 55 | origin=[300, 400]) 56 | clutter_model = PoissonClutterModel(detection_rate=0.9, 57 | lambda_clutter=20, 58 | scope=[[200, 1200], [-np.pi, np.pi]]) 59 | 60 | scene = NonlinearScenario.caseC(birth_weight=0.03, state_func_handle=state_func_handle) 61 | birth_model = scene.birth_model 62 | 
63 | 64 | # make data generator 65 | noisy = False 66 | data_generator = SimulatedGroundTruthDataGenerator(scene, transition_model, noisy=noisy) 67 | 68 | # make sensor model 69 | sensor_model = DummySensorModel(clutter_model, measurement_model, random_seed=seed) 70 | 71 | # make data reader 72 | data_reader = MeasurementReader(data_generator, sensor_model) 73 | 74 | # make density model 75 | if state_func_handle is GaussianState: 76 | density_model = KalmanDensityModel() 77 | elif state_func_handle is UnscentedState: 78 | density_model = UnscentedDensityModel(transition_model.state_dim, alpha=1.0, beta=2.0) 79 | else: 80 | raise TypeError 81 | 82 | # gate method 83 | gate = EllipsoidalGate(percentile=0.999) 84 | 85 | # estimator 86 | estimator = EAPEstimator() 87 | 88 | # reductor 89 | reductor = HypothesisReductor(weight_min=1e-3, merging_threshold=4, capping_num=100) 90 | 91 | # %% build trackers & filtering 92 | # PHD filter 93 | surviving_rate = 0.99 94 | phd_filter = PHDFilter(surviving_rate, 95 | birth_model, 96 | density_model, 97 | transition_model, 98 | measurement_model, 99 | clutter_model, 100 | gate, 101 | estimator, 102 | reductor) 103 | phd_estimates_list = phd_filter.filtering(data_reader) 104 | phd_estimates = np.hstack(phd_estimates_list) 105 | phd_card_pred = [ele.shape[1] for ele in phd_estimates_list] 106 | 107 | surviving_rate = 0.99 108 | prob_min = 1e-3 109 | prob_estimate = 0.5 110 | mbm_filter = MBMFilter(surviving_rate, 111 | prob_min, 112 | prob_estimate, 113 | birth_model, 114 | density_model, 115 | transition_model, 116 | measurement_model, 117 | clutter_model, 118 | gate, 119 | estimator, 120 | reductor) 121 | mbm_estimates_list = mbm_filter.filtering(data_reader) 122 | mbm_estimates = np.hstack(mbm_estimates_list) 123 | mbm_card_pred = [ele.shape[1] for ele in mbm_estimates_list] 124 | 125 | #surviving_rate = 0.99 126 | recycle_threshold = 0.1 127 | prob_min = 1e-3 128 | prob_estimate = 0.5 129 | pmbm_filter = 
PMBMFilter(surviving_rate, 130 | recycle_threshold, 131 | prob_min, 132 | prob_estimate, 133 | birth_model, 134 | density_model, 135 | transition_model, 136 | measurement_model, 137 | clutter_model, 138 | gate, 139 | estimator, 140 | reductor) 141 | pmbm_estimates_list = pmbm_filter.filtering(data_reader) 142 | pmbm_estimates = np.hstack(pmbm_estimates_list) 143 | pmbm_card_pred = [ele.shape[1] for ele in pmbm_estimates_list] 144 | 145 | # %% visualise results 146 | gt_series = data_generator.gt_series 147 | gt_data = np.hstack([ele.states for ele in data_generator]) 148 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 149 | plot_gt = ax.plot(gt_data[0], gt_data[1], 'yo', alpha=0.2, markersize=10) 150 | plot_phd = ax.plot(phd_estimates[0].T, phd_estimates[1].T, 'g+', alpha=0.5) 151 | plot_mbm = ax.plot(mbm_estimates[0].T, mbm_estimates[1].T, 'b+', alpha=0.5) 152 | plot_pmbm = ax.plot(pmbm_estimates[0].T, pmbm_estimates[1].T, 'r+', alpha=0.5) 153 | 154 | ax.grid() 155 | ax.set_xlabel('x (m)') 156 | ax.set_ylabel('y (m)') 157 | ax.legend((plot_gt[0], plot_phd[0], plot_mbm[0], plot_pmbm[0]), ['Ground Truth', 'PHD', 'MBM', 'PMBM']) 158 | plt.axis('equal') 159 | 160 | # %% plot cardinality 161 | plt.figure() 162 | plt.plot(gt_series.num, 'yo') 163 | plt.plot(phd_card_pred, 'g+') 164 | plt.plot(mbm_card_pred, 'b+') 165 | plt.plot(pmbm_card_pred, 'r+') 166 | plt.legend(['GT', 'PHD', 'MBM', 'PMBM']) 167 | plt.grid() 168 | 169 | # %% plot measurements 170 | meas, obj_meas, clutter_meas = [], [], [] 171 | for meas_data, obj_meas_, clutter_meas_ in data_reader.truth_meas_generator(): 172 | meas.append(meas_data.meas) 173 | obj_meas.append(obj_meas_) 174 | clutter_meas.append(clutter_meas_) 175 | meas, obj_meas, clutter_meas = np.hstack(meas), np.hstack(obj_meas), np.hstack(clutter_meas) 176 | 177 | plt.figure() 178 | plt.plot(obj_meas[0], obj_meas[1], 'r.', alpha=0.5) 179 | 180 | # plot clutter 181 | plt.plot(clutter_meas[0], clutter_meas[1], 'k.', alpha=0.2) 182 | 
plt.legend(['measurements', 'clutter']) 183 | plt.show() 184 | plt.close('all') 185 | -------------------------------------------------------------------------------- /scripts/entry_sot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Thu Dec 5 15:17:17 2019 6 | 7 | @author: zhxm 8 | """ 9 | 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | 13 | from scenarios.linear import LinearScenario 14 | from scenarios.nonlinear import NonlinearScenario 15 | from statecircle.estimator.base import EAPEstimator 16 | from statecircle.models.birth.base import SingleObjectBirthModel 17 | from statecircle.models.density.kalman import KalmanDensityModel 18 | from statecircle.models.density.unscented import UnscentedDensityModel 19 | from statecircle.models.measurement.nonlinear import RangeBearningMeasurementModel 20 | from statecircle.models.transition.nonlinear import SimpleCTRVModel 21 | from statecircle.reader.base import MeasurementReader 22 | from statecircle.reductor.gate import EllipsoidalGate 23 | from statecircle.models.sensor.base import DummySensorModel 24 | from statecircle.models.measurement.clutter import PoissonClutterModel 25 | from statecircle.models.measurement.linear import LinearMeasurementModel 26 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 27 | from statecircle.models.transition.linear import ConstantVelocityModel 28 | from statecircle.reductor.hypothesis_reductor import HypothesisReductor 29 | from statecircle.trackers.sot.nearest_neighbour_tracker import NearestNeighbourTracker 30 | from statecircle.trackers.sot.probabilistic_data_association_tracker import ProbabilisticDataAssociationTracker 31 | from statecircle.trackers.sot.gaussian_sum_tracker import GaussianSumTracker 32 | from statecircle.types.state import GaussianState, UnscentedState 33 | 34 | seed = None 35 | scenario = 'nonlinear' 36 | 
state_func_handle = UnscentedState # Supported `GaussianState` and `UnscentedState` 37 | 38 | if scenario == 'linear': 39 | # make transition/measurement/clutter/birth models 40 | transition_model = ConstantVelocityModel(sigma=5) 41 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], 42 | sigma=10) 43 | clutter_model = PoissonClutterModel(detection_rate=0.9, 44 | lambda_clutter=20, 45 | scope=[[0, 1000], [0, 1000]]) 46 | 47 | scene = LinearScenario.caseA(state_func_handle=state_func_handle) 48 | birth_model = scene.birth_model 49 | 50 | elif scenario == 'nonlinear': 51 | # make transition/measurement/clutter/birth models 52 | transition_model = SimpleCTRVModel(sigma_vel=1, 53 | sigma_omega=np.pi / 180) 54 | measurement_model = RangeBearningMeasurementModel(sigma_range=5, 55 | sigma_bearing=np.pi / 180, 56 | origin=[300, 400]) 57 | clutter_model = PoissonClutterModel(detection_rate=0.9, 58 | lambda_clutter=20, 59 | scope=[[0, 1000], [-np.pi, np.pi]]) 60 | 61 | scene = NonlinearScenario.caseA(state_func_handle=state_func_handle) 62 | birth_model = scene.birth_model 63 | 64 | 65 | 66 | # make data generator 67 | birth_times = [0] 68 | death_times = [100] 69 | time_range = [0, 100] 70 | noisy = False 71 | data_generator = SimulatedGroundTruthDataGenerator(scene, transition_model, noisy=noisy) 72 | 73 | # make sensor model 74 | sensor_model = DummySensorModel(clutter_model, measurement_model, random_seed=seed) 75 | 76 | # make data reader 77 | data_reader = MeasurementReader(data_generator, sensor_model) 78 | 79 | # make density model 80 | if state_func_handle is GaussianState: 81 | density_model = KalmanDensityModel() 82 | elif state_func_handle is UnscentedState: 83 | density_model = UnscentedDensityModel(transition_model.state_dim, alpha=1.0, beta=2.0) 84 | else: 85 | raise TypeError 86 | 87 | # gate method 88 | gate = EllipsoidalGate(percentile=0.999) 89 | 90 | # estimator 91 | estimator = EAPEstimator() 92 | 93 | # reductor 94 | reductor = 
HypothesisReductor(weight_min=1e-3, merging_threshold=2, capping_num=100) 95 | 96 | # %% build trackers & filtering 97 | # NN tracker 98 | nn_tracker = NearestNeighbourTracker(birth_model, 99 | density_model, 100 | transition_model, 101 | measurement_model, 102 | clutter_model, 103 | gate, 104 | estimator) 105 | nn_estimates = nn_tracker.filtering(data_reader) 106 | nn_estimates = np.stack(nn_estimates, -1) 107 | 108 | # PDA tracker 109 | pda_tracker = ProbabilisticDataAssociationTracker(birth_model, 110 | density_model, 111 | transition_model, 112 | measurement_model, 113 | clutter_model, 114 | gate, 115 | estimator, 116 | reductor) 117 | pda_estimates = pda_tracker.filtering(data_reader) 118 | pda_estimates = np.stack(pda_estimates, -1) 119 | 120 | # GS tracker 121 | gs_tracker = GaussianSumTracker(birth_model, 122 | density_model, 123 | transition_model, 124 | measurement_model, 125 | clutter_model, 126 | gate, 127 | estimator, 128 | reductor) 129 | gs_estimates = gs_tracker.filtering(data_reader) 130 | gs_estimates = np.stack(gs_estimates, -1) 131 | 132 | # %% visualise results 133 | gt_series = data_generator.gt_series 134 | gt_data = np.hstack([data[0].states for data in gt_series.datum]) 135 | fig, ax = plt.subplots(1, 1, figsize=(6,6)) 136 | plot_gt = ax.plot(gt_data[0], gt_data[1], 'y-', linewidth=10, alpha=0.5) 137 | plot_nn = ax.plot(nn_estimates[0], nn_estimates[1], 'r*-', alpha=0.5) 138 | plot_pda = ax.plot(pda_estimates[0], pda_estimates[1], 'g.-', alpha=0.5) 139 | plot_gs = ax.plot(gs_estimates[0], gs_estimates[1], 'b+-', alpha=0.5) 140 | 141 | ax.grid() 142 | ax.set_xlabel('x (m)') 143 | ax.set_ylabel('y (m)') 144 | ax.legend((plot_gt[0], plot_nn[0], plot_pda[0], plot_gs[0]), ['Ground Truth', 'Nearest Neighbour', 'Probabilistic Data Associaiton', "Gaussian Sum"]) 145 | plt.axis('equal') 146 | 147 | # %% plot measurements 148 | meas, obj_meas, clutter_meas = [], [], [] 149 | for meas_data, obj_meas_, clutter_meas_ in 
data_reader.truth_meas_generator(): 150 | meas.append(meas_data.meas) 151 | obj_meas.append(obj_meas_) 152 | clutter_meas.append(clutter_meas_) 153 | meas, obj_meas, clutter_meas = np.hstack(meas), np.hstack(obj_meas), np.hstack(clutter_meas) 154 | 155 | plt.figure() 156 | plt.plot(obj_meas[0], obj_meas[1], 'r.', alpha=0.5) 157 | 158 | # plot clutter 159 | plt.plot(clutter_meas[0], clutter_meas[1], 'k.', alpha=0.2) 160 | plt.legend(['measurements', 'clutter']) 161 | plt.show() 162 | plt.close('all') 163 | -------------------------------------------------------------------------------- /scripts/test_kalman_error_ellipse.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Fri May 15 17:03:24 2020 6 | 7 | @author: zhaoxm 8 | """ 9 | #%% 10 | import numpy as np 11 | #from statecircle.lib.harem.debugger import plot_covariance_ellipse 12 | from scipy.stats import multivariate_normal, chi2 13 | import matplotlib.pyplot as plt 14 | from matplotlib.patches import Ellipse 15 | mvn = multivariate_normal 16 | 17 | def covariance_ellipse(P, thresh): 18 | U, s, v = np.linalg.svd(P) 19 | orientation = np.math.atan2(U[1,0], U[0,0]) 20 | width, height = 2 * np.sqrt(s * thresh) 21 | return orientation, width, height 22 | 23 | 24 | def plot_covariance_ellipse(mean, cov, percentile=0.95, color='b', ax=None, plot_center=True, plot_edge=False, scale=1): 25 | """ plot the covariance ellipse where mean is a (x,y) tuple for the mean 26 | of the covariance (center of ellipse) 27 | cov is a 2x2 covariance matrix 28 | """ 29 | thresh = chi2.ppf(percentile, 2) 30 | angle, width, height = covariance_ellipse(cov, thresh) 31 | angle = np.degrees(angle) 32 | width *= scale 33 | height *= scale 34 | ax = ax or plt.gca() 35 | 36 | e = Ellipse(mean, width, height, angle, fill=False, edgecolor=color, linewidth=2) 37 | if plot_center: 38 | center_obj = ax.scatter(*mean, marker='+', 
color=color) 39 | else: 40 | center_obj = None 41 | if plot_edge: 42 | e_aug = Ellipse(mean, width, height, angle, fill=False, edgecolor='k', linewidth=5) 43 | ax.add_patch(e_aug) 44 | ax.add_patch(e) 45 | return center_obj 46 | 47 | #%% initial state 48 | m = np.array([0, 0, 1, 0.5]) 49 | P = np.diag([100, 100, 100, 100]) 50 | #%% transition model 51 | dt = 1 52 | F = np.array([[1, 0, dt, 0], 53 | [0, 1, 0, dt], 54 | [0, 0, 1, 0], 55 | [0, 0, 0, 1]]) 56 | sigma = 1 57 | Q = sigma ** 2 * np.array([[dt ** 4 / 4, 0, dt ** 3 / 2, 0], 58 | [0, dt ** 4 / 4, 0, dt ** 3 / 2], 59 | [dt ** 3 / 2, 0, dt ** 2, 0], 60 | [0, dt ** 3 / 2, 0, dt ** 2]]) 61 | #%% measurement model 62 | meas_dim = 2 63 | H = np.array([[1, 0, 0, 0], 64 | [0, 1, 0, 0]]) 65 | est_sigma = 3 66 | est_R = est_sigma ** 2 * np.diag([1, 2]) 67 | 68 | real_sigma = 5 69 | real_R = real_sigma ** 2 * np.diag([1, 2]) 70 | 71 | #%% generate dataset 72 | T = 100 73 | x0 = np.array([0, 0, 1, 1]) 74 | x = x0 75 | meas = [] 76 | for t in range(T): 77 | x = F.dot(x) + mvn.rvs(cov=Q) 78 | z = H.dot(x) + mvn.rvs(cov=real_R) 79 | meas.append(z) 80 | meas = np.array(meas) 81 | 82 | #%% plot measurements 83 | plt.plot(meas[:,0], meas[:,1], '.') 84 | 85 | #%% NIS(NOrmalized Innovation Squared) 86 | def NIS(z_pred, mea, S): 87 | res = (z_pred - mea) 88 | # res.T * S^{-1} * res 89 | return res.dot(np.linalg.inv(S)).dot(res) 90 | 91 | #%% 92 | scope = 100 93 | percentile = 0.999 94 | nis_thresh = chi2.ppf(percentile, meas_dim) 95 | NIS_trace = [] 96 | for i in range(T): 97 | # predict 98 | m_pred = F.dot(m) 99 | P_pred = F.dot(P).dot(F.T) + Q 100 | 101 | # update 102 | z_pred = H.dot(m_pred) 103 | S = H.dot(P_pred).dot(H.T) + est_R 104 | K = P_pred.dot(H.T).dot(np.linalg.inv(S)) 105 | 106 | m_upd = m_pred + K.dot(meas[i] - z_pred) 107 | P_upd = P_pred - K.dot(S).dot(K.T) 108 | 109 | m = m_upd 110 | P = P_upd 111 | 112 | NIS_trace.append(NIS(z_pred, meas[i], S)) 113 | 114 | fig, ax = plt.subplots(1, 2, figsize=(10, 5)) 115 | # 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Visual sanity check of the Cartesian measurement covariance obtained by
propagating range/bearing noise through the polar-to-Cartesian transform.

Created on Tue Jun 2 10:45:27 2020

@author: zhaoxm
"""

#%%
import numpy as np
from scipy.stats import chi2


def covariance_ellipse(P, thresh):
    """Return (orientation [rad], width, height) of the confidence ellipse
    of a 2x2 covariance `P` at chi-square threshold `thresh`.
    """
    U, s, _ = np.linalg.svd(P)
    # Bug fix: `np.math.atan2` relied on the deprecated `numpy.math` alias,
    # which was removed in NumPy 2.0; use the numpy ufunc instead.
    orientation = np.arctan2(U[1, 0], U[0, 0])
    # Each semi-axis is sqrt(eigenvalue * threshold); full axis is twice that.
    width, height = 2 * np.sqrt(s * thresh)
    return orientation, width, height


def plot_covariance_ellipse(mean, cov, percentile=0.95, color='b', ax=None, plot_center=False, plot_edge=False):
    """ plot the covariance ellipse where mean is a (x,y) tuple for the mean
        of the covariance (center of ellipse)
        cov is a 2x2 covariance matrix

        Returns the center artist when `plot_center` is set, otherwise the
        ellipse patch.
    """
    # matplotlib is only needed when actually plotting; importing it locally
    # keeps the numeric helpers importable without a display backend.
    from matplotlib.patches import Ellipse
    import matplotlib.pyplot as plt

    thresh = chi2.ppf(percentile, 2)
    angle, width, height = covariance_ellipse(cov, thresh)
    angle = np.degrees(angle)
    ax = ax or plt.gca()

    e = Ellipse(mean, width, height, angle, fill=False, edgecolor=color, linewidth=2, linestyle='-', alpha=0.7)
    if plot_center:
        center_obj = ax.scatter(*mean, marker='+', color=color)
    else:
        center_obj = None

    if plot_edge:
        # Thick black outline drawn underneath the colored ellipse.
        e_aug = Ellipse(mean, width, height, angle, fill=False, edgecolor='k', linewidth=5)
        ax.add_patch(e_aug)
    ax.add_patch(e)
    return center_obj or e


def draw_ellipse(var_x, var_y, r, phi):
    """Propagate range/bearing measurement noise into Cartesian space.

    NOTE: despite their names, `var_x` / `var_y` are the variances of the
    *range* and *bearing* measurements respectively; the returned covariance
    is the first-order (EKF-style) projection J R J^T evaluated at (r, phi).

    Returns
    -------
    (mean, cov) : Cartesian position [x, y] and its 2x2 covariance.
    """
    R = np.diag([var_x, var_y])

    x = r * np.cos(phi)
    y = r * np.sin(phi)
    # Jacobian of (r, phi) -> (x, y).
    J = np.array([[np.cos(phi), -r * np.sin(phi)],
                  [np.sin(phi), r * np.cos(phi)]])

    cov = J.dot(R).dot(J.T)
    return np.array([x, y]), cov


if __name__ == "__main__":
    # Demo: a tight bearing variance yields a covariance stretched
    # tangentially at range r (the classic "banana" linearised as an ellipse).
    import matplotlib.pyplot as plt

    var_x = 1
    var_y = 0.001
    r = 10
    phi = np.pi / 3
    res = draw_ellipse(var_x, var_y, r, phi)
    fig, ax = plt.subplots(1, 1)
    plot_covariance_ellipse(*res, plot_center=True)
    plt.axis([-30, 30, 0, 50])
    ax.set_aspect(1)
class BaseMeta(ABCMeta):
    r"""Metaclass shared by every StateCircle type.

    Currently a thin pass-through over :class:`abc.ABCMeta`; it exists as a
    single hook point for future class-creation customisation.
    """

    def __new__(mcls, name, bases, namespace, **kwargs):
        cls = super().__new__(mcls, name, bases, namespace, **kwargs)
        return cls


class Base(metaclass=BaseMeta):
    r"""Root base class for all StateCircle objects."""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Dataset base classes: sensor-type constants and simulated ground-truth
generation for tracking scenarios.

@author: zhxm
"""
from abc import abstractmethod
import numpy as np

from ..base import Base
from statecircle.types.data import GroundTruthData, GroundTruthSeries


class SENSOR_TYPE:
    # Integer tags identifying the sensor that produced a measurement.
    CAMERA = 0
    MOBILEYE = 1
    RADAR = 2
    CONTI_RADAR = 3
    WHST_RADAR = 4
    UNKNOWN = -1
    FUSED = 666  # sentinel for fused (multi-sensor) data; member of both groups

    CAMERA_TYPES = [CAMERA, MOBILEYE, FUSED]
    RADAR_TYPES = [RADAR, CONTI_RADAR, WHST_RADAR, FUSED]


class DataGenerator(Base):
    r"""Data generator base class"""

    @abstractmethod
    def generate(self):
        # Subclasses yield/return generated data.
        pass


class SimulatedGroundTruthDataGenerator(Base):
    r"""Simulated ground truth data generator

    Propagates each object's initial state through `transition_model` over
    its [birth, death) lifetime and stores the per-timestamp ground truth.
    """

    def __init__(self, scene, transition_model, noisy=False):
        r"""

        Parameters
        ----------
        scene : Scennario class
            - initial_states
            - birth_times
            - death_times
            - time_range
        transition_model
        noisy : bool
            if True, process noise is sampled when propagating states
        """
        # Normalise a single state vector to a [ndim, 1] column matrix.
        initial_states = scene.initial_states[:, None] if scene.initial_states.ndim == 1 else scene.initial_states
        self.obj_num = initial_states.shape[1]
        self.initial_states = initial_states
        self.birth_times = scene.birth_times
        self.death_times = scene.death_times
        assert len(scene.time_range) == 2 and scene.time_range[1] >= scene.time_range[0]
        self.time_range = scene.time_range
        self.time_len = scene.time_range[1] - scene.time_range[0]
        self.transition_model = transition_model

        # Trajectories are generated eagerly, once, at construction time.
        self.gt_series = self.initialize_trajectories(noisy)

    @property
    def ndim(self):
        # State dimension (rows of the initial state matrix).
        return self.initial_states.shape[0]

    def __len__(self):
        return int(self.time_len)

    def initialize_trajectories(self, noisy=False):
        # Build the full GroundTruthSeries for the configured time range.
        gt_series = GroundTruthSeries(time_len=self.time_len)

        # NOTE(review): timestamps/datum are indexed directly by the absolute
        # timestamp k — presumably time_range[0] == 0 is assumed, otherwise
        # indices may exceed time_len; confirm against GroundTruthSeries.
        for k in np.arange(self.time_range[0], self.time_range[1], dtype=np.int_):
            gt_series.timestamps[k] = k

        for i in range(self.obj_num):
            obj_state = self.initial_states[:, i][:, None]
            # Propagate object i over its lifetime clipped to the time range.
            for k in np.arange(max(self.birth_times[i], self.time_range[0]),
                               min((self.death_times[i], self.time_range[1])),
                               dtype=np.int_):
                time_step = 1
                if noisy:
                    # Sample the next state from N(f(x), Q).
                    obj_state = np.random.multivariate_normal(self.transition_model.forward(obj_state, time_step)[:, 0],
                                                              self.transition_model.noise_covar(time_step=time_step))[:, None]
                else:
                    obj_state = self.transition_model.forward(obj_state, time_step)

                gt_series.datum[k].append(GroundTruthData(timestamp=k, states=obj_state))
                gt_series.num[k] += 1

        return gt_series

    def __getitem__(self, idx):
        # Returns one GroundTruthData per timestamp; `states` is None when no
        # object is alive at that time.
        if len(self.gt_series.datum[idx]) == 0:
            return GroundTruthData(self.gt_series.timestamps[idx], None)

        # Stack all live objects' state columns into one [ndim, n_obj] array.
        states = np.hstack([ele.states for ele in self.gt_series.datum[idx]])
        timestamp = self.gt_series.timestamps[idx]

        return GroundTruthData(timestamp, states)
def test_dummy_gt_data_generator(input_paras):
    """Smoke-test SimulatedGroundTruthDataGenerator and check the expected
    number of live objects at sampled timestamps."""
    from types import SimpleNamespace

    initial_states, birth_times, death_times, time_scope, transition_model = input_paras

    # Bug fix: the generator's constructor takes a scene-like object plus the
    # transition model (see SimulatedGroundTruthDataGenerator.__init__), not
    # the five unpacked arrays; build a minimal scene stand-in here.
    scene = SimpleNamespace(initial_states=initial_states,
                            birth_times=birth_times,
                            death_times=death_times,
                            time_range=time_scope)
    data_gen = SimulatedGroundTruthDataGenerator(scene, transition_model, noisy=False)

    for data in data_gen:
        # Bug fix: GroundTruthData stores its state matrix in `.states`
        # (None when no object is alive at that timestamp), not `.state`.
        if data.states is not None:
            print(data.states.shape)

    for timestamp, datum in zip(data_gen.gt_series.timestamps, data_gen.gt_series.datum):
        for data in datum:
            assert timestamp == data.timestamp

    # births at t=10,20,30,40,50 and deaths at t=40,50,60,70,80 give these
    # live-object counts sampled every 10 steps:
    assert len(data_gen.gt_series.datum[5]) == 0
    assert len(data_gen.gt_series.datum[15]) == 1
    assert len(data_gen.gt_series.datum[25]) == 2
    assert len(data_gen.gt_series.datum[35]) == 3
    assert len(data_gen.gt_series.datum[45]) == 3
    assert len(data_gen.gt_series.datum[55]) == 3
    assert len(data_gen.gt_series.datum[65]) == 2
    assert len(data_gen.gt_series.datum[75]) == 1
    assert len(data_gen.gt_series.datum[85]) == 0
density""" 19 | 20 | 21 | class EAPEstimator(Estimator): 22 | r"""Expected a posterior estimator""" 23 | 24 | def __call__(self, state): 25 | return state.mean 26 | 27 | 28 | class MAPEstimator(Estimator): 29 | r"""Maximum a posterior estimator""" 30 | 31 | def __call__(self, state): 32 | return state.max() 33 | -------------------------------------------------------------------------------- /statecircle/hypothesiser/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 4 | 5 | @author: zhxm 6 | """ -------------------------------------------------------------------------------- /statecircle/lib/harem/README.md: -------------------------------------------------------------------------------- 1 | # **Harem** for all general purposes 2 | - store all sciprts and general functions here! -------------------------------------------------------------------------------- /statecircle/lib/harem/test/test_prev_state_decorator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Fri Jun 5 10:41:56 2020 6 | 7 | @author: zhxm 8 | """ 9 | 10 | from debugger import prev_state 11 | 12 | class Foo: 13 | def __init__(self, x): 14 | self.x = x 15 | 16 | @prev_state(10, KeyError) 17 | def test1(self, idx): 18 | return self.x[idx] 19 | 20 | @prev_state(2, KeyError) 21 | def test2(self, idx): 22 | return self.x[idx] 23 | 24 | @prev_state(2, TypeError) 25 | def test3(self, idx): 26 | return self.x[idx] 27 | 28 | 29 | foo = Foo({'a':1, 'c':3}) 30 | print(foo.test1('a1')) 31 | print(foo.test1('b')) 32 | print(foo.test1('a')) 33 | print(foo.test1('b')) 34 | print(foo.test1('c')) 35 | 36 | print(foo.test2('c')) 37 | print(foo.test2('b')) 38 | print(foo.test2('a')) 39 | 40 | foo.test3('d') 41 | -------------------------------------------------------------------------------- 
class LaneRegressor:
    """RANSAC-based left/right lane-boundary regressor.

    Fits two robust linear models x = f(y): one for points right of the ego
    vehicle (x >= 0) and one for the left side, optionally forcing the two
    fitted lines to share the same (averaged) slope.
    """

    def __init__(self, min_shift=0.0, min_samples=5, sample_range=(0, 20), sample_num=10, parallel=True):
        # NOTE(review): `loss='squared_loss'` was renamed to 'squared_error'
        # in scikit-learn 1.0 and removed in 1.2 — pin sklearn or update.
        self.estimator_pos = RANSACRegressor(loss='squared_loss', stop_n_inliers=6)
        self.estimator_neg = RANSACRegressor(loss='squared_loss', stop_n_inliers=6)
        self.min_samples = min_samples
        self.min_shift = min_shift
        # default changed from a mutable list to a tuple (never mutated;
        # backward compatible — np.linspace(*...) accepts either)
        self.sample_range = sample_range
        self.sample_num = sample_num
        self.parallel = parallel
        self.pos_bond = None
        self.neg_bond = None
        # Fitted regressors; populated by fit(), None when too few samples.
        self.reg_pos = None
        self.reg_neg = None

    def fit(self, coords):
        """Fit the two boundary models.

        Parameters
        ----------
        coords : array [2, #pts]
            Row 0 is the lateral (x) coordinate, row 1 the longitudinal (y).

        Returns
        -------
        self
        """
        # filter out lane points too close to the ego-lane centre
        coords = coords[:, np.abs(coords[0]) >= self.min_shift]
        pos_idx = coords[0] >= 0
        self.pos_bond = coords[:, pos_idx][:, :, None]
        self.neg_bond = coords[:, ~pos_idx][:, :, None]

        if self.pos_bond.shape[1] > self.min_samples:
            self.reg_pos = self.estimator_pos.fit(self.pos_bond[1], self.pos_bond[0])
        else:
            self.reg_pos = None

        if self.neg_bond.shape[1] > self.min_samples:
            self.reg_neg = self.estimator_neg.fit(self.neg_bond[1], self.neg_bond[0])
        else:
            self.reg_neg = None

        if self.parallel:
            # parallelize boundaries: make both fits share the mean slope
            if self.reg_pos is not None and self.reg_neg is not None:
                # Bug fix: the original averaged the positive-side slope with
                # itself ((pos + pos) / 2), leaving the boundaries
                # non-parallel; average the positive and negative slopes.
                self.reg_pos.estimator_.coef_ = self.reg_neg.estimator_.coef_ = \
                    (self.reg_pos.estimator_.coef_ + self.reg_neg.estimator_.coef_) / 2
        return self

    def sample(self):
        """Sample `sample_num` (x, y) points along each fitted boundary.

        Returns
        -------
        list of arrays [2, sample_num] — one per successfully fitted side.
        """
        samples = []
        if self.reg_pos is not None:
            y_pos = np.linspace(*self.sample_range, self.sample_num)[:, None]
            x_pos = self.reg_pos.predict(y_pos)
            samples.append(np.hstack((x_pos, y_pos)).T)

        if self.reg_neg is not None:
            y_neg = np.linspace(*self.sample_range, self.sample_num)[:, None]
            x_neg = self.reg_neg.predict(y_neg)
            samples.append(np.hstack((x_neg, y_neg)).T)

        return samples
class BirthModel(Model):
    r"""Birth model base class: yields the prior state(s) of newborn objects."""
    @abstractmethod
    def birth(self):
        r"""virtual birth method"""


class SingleObjectBirthModel(BirthModel):
    # Gaussian birth prior for exactly one object.
    def __init__(self, initial_state, birth_cov, state_func_handle=GaussianState):
        r"""

        Parameters
        ----------
        initial_state : state
            birth mean; a 1-D vector is promoted to a column matrix
        birth_cov : array
            birth covariance
        state_func_handle : callable(mean, cov)
            state constructor used by `birth()` (defaults to GaussianState)
        """
        initial_state = np.array(initial_state)
        initial_state = initial_state[:, None] if initial_state.ndim == 1 else initial_state
        # `mean` and `initial_state` alias the same array on purpose.
        self.mean = self.initial_state = initial_state
        self.cov = birth_cov
        self.state_func_handle = state_func_handle

    def birth(self):
        # A fresh state object wrapping the (shared) birth mean/cov.
        return self.state_func_handle(self.mean, self.cov)

    @property
    def ndim(self):
        # State dimension.
        return self.mean.shape[0]


class MultiObjectBirthModel(BirthModel):
    # Independent Gaussian birth priors for a fixed set of objects.
    def __init__(self, initial_states, birth_cov, state_func_handle=GaussianState):
        r"""

        Parameters
        ----------
        initial_states : array(ndim_state, num_state)
            one column per newborn object
        birth_cov: array
            covariance shared by every object
        """
        self.means = initial_states
        self.covs = [birth_cov for _ in range(self.means.shape[1])]
        self.state_func_handle = state_func_handle

    def birth(self):
        # One state per column of `means`.
        return [self.state_func_handle(mean, cov) for mean, cov in zip(self.means.T, self.covs)]

class BernoulliBirthModel(BirthModel):
    # Wraps a pre-built Bernoulli birth component.
    def __init__(self, bern):
        self.bern = bern

    def birth(self):
        return self.bern

class MultiBernoulliBirthModel(BirthModel):
    # Wraps a pre-built multi-Bernoulli birth mixture.
    def __init__(self, multi_bern):
        self.multi_bern = multi_bern

    def birth(self):
        return self.multi_bern


class PoissonBirthModel(BirthModel):
    # Poisson point process birth, parameterised by its intensity.
    def __init__(self, intensity):
        self.intensity = intensity

    def birth(self):
        return PoissonState(self.intensity)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Abstract interface for multi-object conjugate density models.

@author: zhxm
"""
from abc import abstractmethod

from ..base import Model


class ConjugateDensityModel(Model):
    r"""Conjugate Density model

    The density is multi-object conjugate under predict and update steps,
    i.e. predict/update map a density of this family back into the family.
    """

    @abstractmethod
    def predict(self, state, time_step, transition_model):
        r"""Propagate `state` through the transition model.

        Parameters
        ----------
        state
            prior density state
        transition_model
            motion model providing forward / transition_matrix / noise_covar

        Returns
        -------
        predicted state of the same family

        """

    @abstractmethod
    def update(self, state, measurement_model):
        r"""Condition `state` on a measurement.

        Parameters
        ----------
        state
            predicted density state
        measurement_model
            observation model used to form the innovation

        Returns
        -------
        posterior state of the same family

        """

    @abstractmethod
    def predicted_log_likelihood(self, state, meas, measurement_model):
        r"""Log-likelihood of measurements under the predicted density.

        Parameters
        ----------
        state
            predicted density state
        meas
            measurement array
        measurement_model
            observation model

        Returns
        -------
        per-measurement predicted log-likelihood

        """
class KalmanDensityModel(ConjugateDensityModel):
    r"""[Extended] Kalman density model.

    Gaussian conjugate prediction/update for (locally) linear Gaussian
    transition and measurement models.
    """

    def predict(self, state, time_step, transition_model):
        """Kalman prediction: propagate mean and covariance one step."""
        assert isinstance(state, GaussianState)
        predicted_mean = transition_model.forward(state.mean, time_step=time_step)
        trans_mat = transition_model.transition_matrix(state.mean, time_step)
        predicted_cov = trans_mat @ state.cov @ trans_mat.T \
                        + transition_model.noise_covar(time_step=time_step)

        # TODO: add labeled state support (GaussianLabeledState pass-through)

        return GaussianState(predicted_mean, predicted_cov)

    def update(self, state_pred, meas, measurement_model):
        """Kalman update with a single measurement column."""
        if meas.ndim == 1:
            meas = meas[:, None]
        assert meas.shape[1] == 1, 'expected one measurement for update step, but got {} measurements.' \
            .format(meas.shape[1])

        # measurement matrix (Jacobian for nonlinear models)
        meas_mat = measurement_model.measurement_matrix(state_pred.mean)

        # predicted measurement and innovation covariance
        meas_pred = measurement_model.forward(state_pred.mean)
        innov_cov = meas_mat @ state_pred.cov @ meas_mat.T \
                    + measurement_model.noise_covar(meas_pred)

        # symmetrise so the innovation covariance stays positive definite
        innov_cov = (innov_cov + innov_cov.T) / 2

        gain = state_pred.cov @ meas_mat.T @ np.linalg.inv(innov_cov)

        # posterior mean
        mean_upd = state_pred.mean + gain @ (meas - meas_pred)
        # posterior covariance via the `P - K S K^T` form, which is better
        # behaved numerically than (I - K H) P
        cov_upd = state_pred.cov - gain @ innov_cov @ gain.T

        # TODO: add labeled state support (GaussianLabeledState pass-through)

        return GaussianState(mean_upd, cov_upd)

    def predicted_log_likelihood(self, state_pred, meas, measurement_model):
        r"""Per-measurement log N(z; z_pred, S) under the predicted density.

        Parameters
        ----------
        state_pred : `GaussianState`
        meas : array [`ndim_meas`, num_meas]
        measurement_model :

        Returns
        -------
        log_likelihood : array [num_meas]
        """
        if 0 in meas.shape:
            # no measurements -> empty likelihood list
            return []

        meas_mat = measurement_model.measurement_matrix(state_pred.mean)

        meas_pred = measurement_model.forward(state_pred.mean)
        innov_cov = meas_mat @ state_pred.cov @ meas_mat.T \
                    + measurement_model.noise_covar(meas_pred)

        # symmetrise so the innovation covariance stays positive definite
        innov_cov = (innov_cov + innov_cov.T) / 2

        pred_loglik = multivariate_normal.logpdf(meas.T, meas_pred[:, 0], innov_cov, allow_singular=True)

        return pred_loglik
    def predict_state(self, state, time_step, transition_model):
        # Predict only the *latest* state (no accumulation): standard Kalman
        # prediction from the last column of the stacked mean.
        assert isinstance(state, GaussianState)

        if state.mean.ndim == 1:
            state.mean = state.mean[:, None]
        d = transition_model.state_dim

        mean_pred = transition_model.forward(state.mean[:, -1], time_step=time_step)
        F = transition_model.transition_matrix(state.mean[:, -1], time_step)
        # Only the bottom-right d x d block holds the current state covariance.
        cov_pred = F.dot(state.cov[-d:, -d:]).dot(F.T) + transition_model.noise_covar(time_step=time_step)

        return GaussianState(mean_pred, cov_pred)

    def predict(self, state, time_step, transition_model):
        # Accumulated prediction: append the newly predicted state column and
        # grow the joint covariance over the last `traceback_range` steps.
        state.mean = state.mean[:, None] if state.mean.ndim == 1 else state.mean
        mean_pred = np.hstack((state.mean, transition_model.forward(state.mean[:, -1:], time_step)))
        F = transition_model.transition_matrix(state.mean[:, -1], time_step)
        d = transition_model.state_dim
        if self.traceback_range > 1:
            # Keep only the most recent (traceback_range - 1) past states and
            # extend the joint covariance with the cross terms to the new one:
            # [[P_hist,        P_hist,new],
            #  [P_new,hist,    F P F^T + Q]]
            step_d = (self.traceback_range - 1) * d
            new_cov1 = state.cov[-step_d:, -d:].dot(F.T)
            new_cov2 = F.dot(state.cov[-d:, -step_d:])
            new_cov3 = F.dot(state.cov[-d:, -d:]).dot(F.T) + transition_model.noise_covar(time_step)
            cov_pred = np.vstack((np.hstack((state.cov[-step_d:, -step_d:], new_cov1)),
                                  np.hstack((new_cov2, new_cov3))
                                  ))
            return GaussianAccumulatedState(mean_pred, cov_pred)
        else:
            # traceback_range == 1: no history kept, plain one-step covariance.
            cov_pred = F.dot(state.cov[-d:, -d:]).dot(F.T) + transition_model.noise_covar(time_step)
            return GaussianAccumulatedState(mean_pred, cov_pred)

    def update(self, state_pred, meas, measurement_model):
        # Joint update: the measurement of the *latest* state also corrects
        # every accumulated past state through the cross-covariances.
        meas = meas[:, None] if meas.ndim == 1 else meas
        assert meas.shape[1] == 1, 'expected one measurement for update step, but got {} measurements.' \
            .format(meas.shape[1])

        state_pred.mean = state_pred.mean[:, None] if state_pred.mean.ndim == 1 else state_pred.mean

        # per-step state dimension (rows of the stacked mean)
        nx = state_pred.mean.shape[0]

        # measurement model Jacobian at the latest state
        H = measurement_model.measurement_matrix(state_pred.mean[:, -1])

        # innovation covariance
        meas_pred = measurement_model.forward(state_pred.mean[:, -1])
        S = H.dot(state_pred.cov[-nx:, -nx:]).dot(H.T) + measurement_model.noise_covar(meas_pred)

        # make sure matrix S is positive definite
        S = (S + S.T) / 2

        step = self.traceback_range
        # Gain for the full stacked state via the cross-covariance column.
        K = state_pred.cov[:, -nx:].dot(H.T).dot(np.linalg.inv(S))

        # state update
        state_upd = deepcopy(state_pred)
        # The stacked correction K(z - z_pred) is reshaped back into one
        # nx-vector per retained timestep.
        state_upd.mean[:, -step:] = state_pred.mean[:, -step:] + (
            K.dot(meas - meas_pred)).reshape(-1, nx).T

        # covariance update
        # bug fix, use Kalman update equation `P = P - KSK` to avoid numerical problem
        # state_upd.cov = state_pred.cov- K.dot(H).dot(state_pred.cov[-nx:, :])
        state_upd.cov = state_pred.cov - K.dot(S).dot(K.T)

        return state_upd


    def predicted_log_likelihood(self, state_pred, meas, measurement_model):
        r"""Log-likelihood of `meas` under the predicted *latest* state.

        Parameters
        ----------
        state_pred : `GaussianState`
        meas : array [`ndim_meas`, num_meas]
        measurement_model :

        Returns
        -------
        log_likelihood : array [num_meas]
        """
        state_pred.mean = state_pred.mean[:, None] if state_pred.mean.ndim == 1 else state_pred.mean

        if 0 in meas.shape:
            # no measurements -> empty likelihood list
            return []

        nx = state_pred.mean.shape[0]

        # measurement model Jocobian
        H = measurement_model.measurement_matrix(state_pred.mean[:, -1])

        # innovation convariance
        meas_pred = measurement_model.forward(state_pred.mean[:, -1])
        S = H.dot(state_pred.cov[-nx:, -nx:]).dot(H.T) + measurement_model.noise_covar(meas_pred)

        # make sure matrix S is positive definite
        S = (S + S.T) / 2

        pred_loglik = multivariate_normal.logpdf(meas.T, meas_pred[:, 0], S, allow_singular=True)

        return pred_loglik
def test_kalman_state_model_labeled_gaussian_state(input_gaussian_labeled_state):
    # Smoke test: exercise predict/update/predicted_log_likelihood with a
    # labeled Gaussian state and make sure nothing raises.
    state, time_step = input_gaussian_labeled_state

    transition_model = ConstantVelocityModel(sigma=10)
    measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], sigma=5)
    state_model = KalmanDensityModel()
    # One noisy measurement for the update, ten for the likelihood batch.
    single_meas = measurement_model.forward(state.mean, noisy=True)
    some_meas = measurement_model.forward(np.repeat(state.mean, 10, axis=1), noisy=True)

    state_pred = state_model.predict(state, time_step, transition_model)
    # NOTE(review): update/loglik intentionally reuse the prior `state`
    # rather than `state_pred` — presumably fine for a smoke test; confirm.
    state_upd = state_model.update(state, single_meas, measurement_model)
    loglik = state_model.predicted_log_likelihood(state, some_meas, measurement_model)
    # print(state_pred)
    # print(state_upd)
    # print(loglik)
[state_dim ,state_dim] covariance 23 | # state.sigma.points [state_dim/meas_dim, 2*state_dim + 1] sigma points position 24 | # state.sigma.mean_weights [2*state_dim + 1] mean sigma points weights 25 | # state.sigma.cov_weights [2*state_dim + 1] covariance sigma points weights 26 | 27 | Ref: 28 | [1] Wan E A, Van Der Merwe R. The unscented Kalman filter for nonlinear estimation. // 29 | Proceedings of the IEEE 2000 Adaptive System for Signal Processing, Communications and 30 | Control Symposium.Ieee, 2000: 153-158 31 | [2] S.J. Julier and J.K. Uhlmann. A new extension of the Kalman gilter to nonlinear system. // 32 | In proc. of AeroSense: The 11th Int.Symp.on Aerospace/Defence Sensing, Simulation and Controls, 33 | 1997. 34 | """ 35 | 36 | def __init__(self, state_dim, alpha=1.0, beta=2.0, kappa=None): 37 | """ 38 | 39 | Parameters 40 | ---------- 41 | state_dim: int 42 | state dimension 43 | alpha: float 44 | Determine the spread of the sigma points 45 | beta: float 46 | Incorpate prior knowledge of the distribution 47 | kappa: float 48 | Secondary scaling parameter which is usually set to 3 - `state_dim` 49 | 50 | """ 51 | super().__init__() 52 | self.state_dim = state_dim 53 | 54 | # hyper-pamateters 55 | self.alpha = alpha # range in (0, 1] 56 | self.beta = beta # optimal choice for Gaussians is 2 57 | self.kappa = kappa or 3. 
- state_dim # 'state_dim + kappa range' in [0, +inf) 58 | self.lambda_ = alpha ** 2 * (state_dim + self.kappa) - state_dim 59 | 60 | # constant sigma weights 61 | self.mean_weights = np.ones([2 * state_dim + 1]) / 2 / (state_dim + self.lambda_) 62 | self.mean_weights[0] = self.lambda_ / (state_dim + self.lambda_) 63 | 64 | self.cov_weights = self.mean_weights.copy() 65 | self.cov_weights[0] += 1 - alpha ** 2 + beta 66 | 67 | def update_sigma_points(self, state): 68 | # inplace op 69 | # TODO: membrane unscented parameters 70 | L = np.linalg.cholesky(state.cov) 71 | points = np.tile(state.mean.astype('float'), (1, 2 * self.state_dim + 1)) 72 | points[:, 1:(self.state_dim + 1)] += np.sqrt(self.state_dim + self.lambda_) * L 73 | points[:, (self.state_dim + 1):] -= np.sqrt(self.state_dim + self.lambda_) * L 74 | 75 | sigma_pts = SigmaPoints(points, self.mean_weights, self.cov_weights) 76 | state.sigma = sigma_pts 77 | 78 | def compute_innovation_cov(self, measurement_model, sigma): 79 | sigma_pts_meas = measurement_model.forward(sigma.points) 80 | 81 | meas_hat = sigma_pts_meas.dot(sigma.mean_weights)[:, None] 82 | # innovation covariance 83 | res_meas = sigma_pts_meas - meas_hat 84 | S = res_meas.dot(np.diag(sigma.cov_weights).dot(res_meas.T)) + measurement_model.noise_covar(meas_hat) 85 | 86 | return S 87 | 88 | def predict(self, state, time_step, transition_model): 89 | assert isinstance(state, UnscentedState) 90 | # generate sigma points 91 | self.update_sigma_points(state) 92 | 93 | # TODO: consider the nonlinear noise factor, i.e. 
augmented sigma points 94 | # sigma points predict 95 | points_pred = transition_model.forward(state.sigma.points, time_step=time_step) 96 | 97 | mean_pred = points_pred.dot(state.sigma.mean_weights)[:, None] 98 | res = points_pred - mean_pred 99 | cov_pred = res.dot(np.diag(state.sigma.cov_weights).dot(res.T)) + transition_model.noise_covar(time_step) 100 | 101 | # TODO: add labeled state 102 | # if hasattr(state, 'label'): 103 | # state_pred.label = state.label 104 | 105 | return UnscentedState(mean_pred, cov_pred, SigmaPoints(points_pred, self.mean_weights, self.cov_weights)) 106 | 107 | def update(self, state_pred, meas, measurement_model): 108 | meas = meas[:, None] if meas.ndim == 1 else meas 109 | assert meas.shape[1] == 1, 'expected one measurement for update step, but got {} measurements.' \ 110 | .format(meas.shape[1]) 111 | 112 | sigma_pts_meas = measurement_model.forward(state_pred.sigma.points) 113 | meas_hat = sigma_pts_meas.dot(state_pred.sigma.mean_weights)[:, None] 114 | 115 | # innovation covariance 116 | res_meas = sigma_pts_meas - meas_hat 117 | # NOTE: pre-calculate the innovation covariance matrix in `predicted_log_likelihood` function 118 | # and save the property in `UnscentedState`, because this function always run before `update` 119 | # function 120 | # inno_cov = res_meas.dot(np.diag(state_pred.sigma.cov_weights).dot(res_meas.T)) + measurement_model.noise_covar() 121 | 122 | # cross-covariance 123 | res_state = state_pred.sigma.points - state_pred.sigma.points[:, 0:1] 124 | # almost equal to the below equation 125 | # res_state = state_pred.sigma.points - state_pred.sigma.points.dot(state_pred.sigma.mean_weights)[:, None] 126 | 127 | T = res_state.dot(np.diag(state_pred.sigma.cov_weights).dot(res_meas.T)) 128 | # kalman gain 129 | K = T.dot(np.linalg.inv(state_pred.inno_cov)) 130 | 131 | # state update 132 | mean_upd = state_pred.mean + K.dot(meas - meas_hat) 133 | # covariance update 134 | cov_upd = state_pred.cov - 
K.dot(state_pred.inno_cov).dot(K.T) 135 | 136 | # TODO: add labeled state 137 | # if hasattr(state_pred, 'label'): 138 | # state_upd.label = state_pred.label 139 | 140 | return UnscentedState(mean_upd, cov_upd, SigmaPoints) 141 | 142 | def predicted_log_likelihood(self, state_pred, meas, measurement_model): 143 | r""" 144 | 145 | Parameters 146 | ---------- 147 | state_pred : `GaussianState` -> `UnscentedState` 148 | meas : array [`ndim_meas`, num_meas] 149 | measurement_model : 150 | 151 | Returns 152 | ------- 153 | log_likelihood : array [num_meas] 154 | """ 155 | if 0 in meas.shape: 156 | return [] 157 | 158 | if isinstance(state_pred, GaussianState) or state_pred.sigma is None: 159 | # initialize sigma points 160 | self.update_sigma_points(state_pred) 161 | 162 | sigma_pts_meas = measurement_model.forward(state_pred.sigma.points) 163 | 164 | if isinstance(state_pred, GaussianState) or state_pred.inno_cov is None: 165 | # update innovation covariance 166 | meas_hat = sigma_pts_meas.dot(state_pred.sigma.mean_weights)[:, None] 167 | res_meas = sigma_pts_meas - meas_hat 168 | inno_cov = res_meas.dot(np.diag(state_pred.sigma.cov_weights).dot(res_meas.T)) + \ 169 | measurement_model.noise_covar(meas_hat) 170 | else: 171 | inno_cov = state_pred.inno_cov 172 | 173 | mean_meas = sigma_pts_meas.dot(state_pred.sigma.mean_weights) 174 | pred_loglik = multivariate_normal.logpdf(meas.T, mean_meas, inno_cov, allow_singular=True) 175 | 176 | state_pred.inno_cov = inno_cov 177 | 178 | return pred_loglik 179 | 180 | @property 181 | def ndim(self): 182 | return self.state_dim 183 | -------------------------------------------------------------------------------- /statecircle/models/density/unscented_accumulated.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Tue Mar 3 17:16:59 2020 6 | 7 | @author: zhaoxm 8 | """ 9 | 10 | import numpy as np 11 | from copy import 
deepcopy 12 | from scipy.stats import multivariate_normal 13 | 14 | from statecircle.models.density.unscented import UnscentedDensityModel 15 | from statecircle.types.base import SigmaPoints 16 | from statecircle.types.state import GaussianState, GaussianAccumulatedState, UnscentedState 17 | from .base import ConjugateDensityModel 18 | 19 | from deprecated import deprecated 20 | 21 | @deprecated(reason="uncompleted.") 22 | class UnscentedAccumulatedDensityModel(UnscentedDensityModel): 23 | r"""[Extended] Kalman accumulated state model 24 | 25 | linear gaussian transition/measurement model 26 | """ 27 | 28 | def __init__(self, state_dim, alpha=1.0, beta=2.0, kappa=None, traceback_range=100): 29 | r""" 30 | 31 | Parameters 32 | ---------- 33 | state_dim: int 34 | state dimension 35 | alpha: float 36 | Determine the spread of the sigma points 37 | beta: float 38 | Incorpate prior knowledge of the distribution 39 | kappa: float 40 | Secondary scaling parameter which is usually set to 3 - `state_dim` 41 | traceback_range : int 42 | accumulated density range 43 | """ 44 | super().__init__(state_dim, alpha, beta, kappa) 45 | self.traceback_range = traceback_range 46 | 47 | def update_sigma_points(self, state): 48 | # inplace op 49 | # TODO: membrane unscented parameters 50 | L = np.linalg.cholesky(state.cov) 51 | points = np.tile(state.mean.astype('float'), (1, 2 * self.state_dim + 1)) 52 | points[:, 1:(self.state_dim + 1)] += np.sqrt(self.state_dim + self.lambda_) * L 53 | points[:, (self.state_dim + 1):] -= np.sqrt(self.state_dim + self.lambda_) * L 54 | 55 | sigma_pts = SigmaPoints(points, self.mean_weights, self.cov_weights) 56 | state.sigma = sigma_pts 57 | 58 | def predict_state(self, state, time_step, transition_model): 59 | if state.mean.ndim == 1: 60 | state.mean = state.mean[:, None] 61 | d = transition_model.state_dim 62 | 63 | assert isinstance(state, UnscentedState) 64 | # generate sigma points 65 | self.update_sigma_points(UnscentedState(state.mean[:, -1], 
state.cov[-d:, -d:])) 66 | 67 | points_pred = transition_model.forward(state.sigma.points, time_step=time_step) 68 | mean_pred = points_pred.dot(state.sigma.mean_weights)[:, None] 69 | res_pred = points_pred - mean_pred 70 | cov_pred = res_pred.dot(np.diag(state.sigma.cov_weights).dot(res_pred.T)) + \ 71 | transition_model.noise_covar(time_step) 72 | 73 | return GaussianState(mean_pred, cov_pred) 74 | 75 | def predict(self, state, time_step, transition_model): 76 | state.mean = state.mean[:, None] if state.mean.ndim == 1 else state.mean 77 | # sigma points predict 78 | points_pred = transition_model.forward(state.sigma.points, time_step=time_step) 79 | mean_pred = points_pred.dot(state.sigma.mean_weights)[:, None] 80 | mean_pred_accu = np.hstack((state.mean, mean_pred)) 81 | 82 | F = transition_model.transition_matrix(state.mean[:, -1], time_step) 83 | d = transition_model.state_dim 84 | step_d = self.traceback_range * d 85 | new_cov1 = state.cov[-step_d:, -d:].dot(F.T) 86 | new_cov2 = F.dot(state.cov[-d:, -step_d:]) 87 | new_cov3 = F.dot(state.cov[-d:, -d:]).dot(F.T) + transition_model.noise_covar(time_step) 88 | cov_pred = np.vstack((np.hstack((state.cov[-step_d:, -step_d:], new_cov1)), 89 | np.hstack((new_cov2, new_cov3)) 90 | )) 91 | return GaussianAccumulatedState(mean_pred_accu, cov_pred) 92 | 93 | def update(self, state_pred, meas, measurement_model): 94 | meas = meas[:, None] if meas.ndim == 1 else meas 95 | assert meas.shape[1] == 1, 'expected one measurement for update step, but got {} measurements.' 
\ 96 | .format(meas.shape[1]) 97 | 98 | state_pred.mean = state_pred.mean[:, None] if state_pred.mean.ndim == 1 else state_pred.mean 99 | 100 | nx = state_pred.mean.shape[0] 101 | 102 | # measurement model Jacobian 103 | H = measurement_model.measurement_matrix(state_pred.mean[:, -1]) 104 | 105 | # innovation covariance 106 | S = H.dot(state_pred.cov[-nx:, -nx:]).dot(H.T) + measurement_model.noise_covar() 107 | 108 | # make sure matrix S is positive definite 109 | S = (S + S.T) / 2 110 | 111 | step = self.traceback_range 112 | step_d = self.traceback_range * nx 113 | K = state_pred.cov[-step_d:, -nx:].dot(H.T).dot(np.linalg.inv(S)) 114 | 115 | # state update 116 | state_upd = deepcopy(state_pred) 117 | state_upd.mean[:, -step:] = state_pred.mean[:, -step:] + ( 118 | K.dot(meas - measurement_model.forward(state_pred.mean[:, -1]))).reshape(-1, nx).T 119 | 120 | # covariance update 121 | state_upd.cov = state_pred.cov[-step_d:, -step_d:] - K.dot(H).dot(state_pred.cov[-nx:, -step_d:]) 122 | 123 | return state_upd 124 | 125 | def predicted_log_likelihood(self, state_pred, meas, measurement_model): 126 | r""" 127 | 128 | Parameters 129 | ---------- 130 | state_pred : `GaussianState` 131 | meas : array [`ndim_meas`, num_meas] 132 | measurement_model : 133 | 134 | Returns 135 | ------- 136 | log_likelihood : array [num_meas] 137 | """ 138 | state_pred.mean = state_pred.mean[:, None] if state_pred.mean.ndim == 1 else state_pred.mean 139 | 140 | if 0 in meas.shape: 141 | return [] 142 | 143 | nx = state_pred.mean.shape[0] 144 | 145 | # measurement model Jocobian 146 | H = measurement_model.measurement_matrix(state_pred.mean[:, -1]) 147 | 148 | # innovation convariance 149 | S = H.dot(state_pred.cov[-nx:, -nx:]).dot(H.T) + measurement_model.noise_covar() 150 | 151 | # make sure matrix S is positive definite 152 | S = (S + S.T) / 2 153 | 154 | mean = measurement_model.forward(state_pred.mean[:, -1]) 155 | pred_loglik = multivariate_normal.logpdf(meas.T, mean[:, 0], S, 
allow_singular=True) 156 | 157 | return pred_loglik 158 | -------------------------------------------------------------------------------- /statecircle/models/measurement/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | 9 | from .base import MeasurementModel 10 | 11 | __all__ = ['MeasurementModel'] -------------------------------------------------------------------------------- /statecircle/models/measurement/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | from ..base import Model 9 | from abc import abstractmethod 10 | 11 | class MeasurementModel(Model): 12 | r"""Base measurement models class""" 13 | 14 | @property 15 | @abstractmethod 16 | def ndim_meas(self): 17 | r"""measurement dimentsion""" 18 | 19 | 20 | -------------------------------------------------------------------------------- /statecircle/models/measurement/clutter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | import numpy as np 9 | from scipy.stats import poisson 10 | 11 | from statecircle.models.base import Model 12 | 13 | 14 | class ClutterModel(Model): 15 | r"""Clutter model base class""" 16 | 17 | 18 | class PoissonClutterModel(ClutterModel): 19 | r"""Poisson sensor model""" 20 | 21 | def __init__(self, detection_rate, lambda_clutter, scope): 22 | self.detection_rate = detection_rate 23 | self.lambda_clutter = lambda_clutter 24 | self.scope = np.array(scope) 25 | 26 | volumn = np.prod(self.scope[:, 1] - self.scope[:, 0]) 27 | self.intensity_clutter = lambda_clutter / volumn 28 | self.density = 1 / volumn 29 | self.cardinality_pmf = poisson(lambda_clutter).pmf 30 | 
-------------------------------------------------------------------------------- /statecircle/models/measurement/linear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | 9 | import numpy as np 10 | from scipy.stats import multivariate_normal 11 | 12 | from .base import MeasurementModel 13 | 14 | 15 | class LinearMeasurementModel(MeasurementModel): 16 | r""" Linear measurement models 17 | 18 | Some dimensions of the state can be measured. 19 | Linear measuremodel can be described as follows: 20 | 21 | .. math:: 22 | 23 | y_t = H_t*x_t + v_t, \ \ \ \ v_t \sim \mathcal{N}(0, R) 24 | 25 | Attributes 26 | ---------- 27 | mapping : boolean mapping index 28 | Represent which dimension of the state can be measured 29 | """ 30 | 31 | def __init__(self, mapping, sigma=None, cov=None): 32 | if sigma is not None: 33 | assert sigma > 0 34 | self.sigma = sigma 35 | self.mapping = np.array(mapping, dtype=np.bool_) 36 | 37 | # measurement funciton, in linear case, observation function is a matrix 38 | self.meas_mat = np.zeros((self.ndim_meas, self.ndim), dtype=np.float_) 39 | row = 0 40 | for idx, ele in enumerate(self.mapping): 41 | if ele: 42 | self.meas_mat[row, idx] = 1 43 | row += 1 44 | 45 | # noise convariance 46 | if cov is not None: 47 | self.noise_cov = cov 48 | elif sigma is not None: 49 | self.noise_cov = self.sigma ** 2 * np.eye(self.ndim_meas, dtype=np.float_) 50 | else: 51 | raise ValueError 52 | 53 | def measurement_matrix(self, x): 54 | return self.meas_mat 55 | 56 | def noise_covar(self, *args, **kwargs): 57 | return self.noise_cov 58 | 59 | def rvs(self, mean, cov, num_samples): 60 | noise = multivariate_normal.rvs(mean, cov, num_samples) 61 | return np.atleast_2d(noise).T 62 | 63 | @property 64 | def ndim(self): 65 | return len(self.mapping) 66 | 67 | @property 68 | def ndim_meas(self): 69 | return self.mapping.sum() 70 | 71 | 
def forward(self, x, noisy=False): 72 | r"""calcute measurements with/without noise""" 73 | x = x[:, None] if x.ndim == 1 else x 74 | assert x.shape[0] == self.ndim 75 | num_states = x.shape[1] 76 | if noisy: 77 | noise = self.rvs(np.zeros(self.ndim_meas, dtype=np.float_), 78 | self.noise_cov, 79 | num_states) 80 | return self.meas_mat.dot(x) + noise 81 | else: 82 | return self.meas_mat.dot(x) 83 | 84 | def reverse(self, z): 85 | r"""calcute state from measurement""" 86 | assert z.shape[0] == self.ndim_meas 87 | return np.linalg.pinv(self.meas_mat).dot(z) 88 | -------------------------------------------------------------------------------- /statecircle/models/measurement/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ -------------------------------------------------------------------------------- /statecircle/models/measurement/tests/test_linear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | import numpy as np 9 | from operator import __ne__ 10 | from ..linear import LinearMeasurementModel 11 | import pytest 12 | 13 | @pytest.fixture(scope='module') 14 | def state_measurement_model_testing(x, observered_index): 15 | meas_model = LinearMeasurementModel(mapping=observered_index, 16 | sigma=1) 17 | z = meas_model.forward(x, noisy=False) 18 | x_inv = meas_model.reverse(z) 19 | z_noise1 = meas_model.forward(x, noisy=True) 20 | z_noise2 = meas_model.forward(x, noisy=True) 21 | np.testing.assert_almost_equal(meas_model.jacobian(x), meas_model.meas_mat) 22 | x[~observered_index] = 0 23 | assert x.shape[0] == meas_model.ndim 24 | assert z.shape[0] == meas_model.ndim_meas 25 | assert x.shape[1] == z.shape[1] 26 | np.testing.assert_almost_equal(x, x_inv) 27 | 
np.testing.assert_array_compare(__ne__, z_noise1, z_noise2) 28 | 29 | 30 | def test_state_measurement_model_case1(): 31 | x = np.random.rand(4, 1) 32 | observered_index = np.array([1, 1, 0, 0], dtype=np.bool) 33 | state_measurement_model_testing(x, observered_index) 34 | 35 | 36 | def test_state_measurement_model_case2(): 37 | x = np.random.rand(4, 2) 38 | observered_index = np.array([1, 0, 1, 0], dtype=np.bool) 39 | state_measurement_model_testing(x, observered_index) 40 | 41 | 42 | def test_state_measurement_model_case3(): 43 | x = np.random.rand(5, 1) 44 | observered_index = np.array([1, 1, 0, 0, 0], dtype=np.bool) 45 | state_measurement_model_testing(x, observered_index) 46 | 47 | 48 | def test_state_measurement_model_case4(): 49 | x = np.random.rand(5, 3) 50 | observered_index = np.array([1, 0, 0, 1, 1], dtype=np.bool) 51 | state_measurement_model_testing(x, observered_index) 52 | -------------------------------------------------------------------------------- /statecircle/models/measurement/tests/test_nonlinear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | 9 | import numpy as np 10 | from operator import __ne__ 11 | from statecircle.models.measurement.nonlinear import * 12 | 13 | 14 | def test_range_bearing_measurement_model_1state(): 15 | sigma_range = 1. 16 | sigma_bearing = 2. 
origin = np.array([10, 20])
meas_model = RangeBearningMeasurementModel(sigma_range, sigma_bearing, origin)

# single 5-D state column
x = np.random.rand(5, 1)
z = meas_model.forward(x, noisy=False)
z_noise1 = meas_model.forward(x, noisy=True)
z_noise2 = meas_model.forward(x, noisy=True)

# finite-difference check: a small state perturbation propagated through the
# nonlinear forward() should match the Jacobian's linear approximation
x_delta = 0.01 * x
z_delta = meas_model.forward(x + x_delta) - meas_model.forward(x)
jacobian_mat = meas_model.jacobian(x)
z_delta_approx = jacobian_mat.dot(x_delta)
x_inv = meas_model.reverse(z)

assert x.shape[0] == meas_model.state_dim == meas_model.ndim
assert z.shape[0] == meas_model.meas_dim == meas_model.ndim_meas
assert x.shape[1] == z.shape[1]
# two independent noisy draws must differ element-wise
np.testing.assert_array_compare(__ne__, z_noise1, z_noise2)
np.testing.assert_almost_equal(z_delta, z_delta_approx, decimal=3)
assert jacobian_mat.shape == (meas_model.meas_dim, meas_model.state_dim)
# reverse() only reconstructs the observed components; remaining state
# dimensions are zeroed here for comparison — presumably dims 0-1 are the
# position measured by range/bearing (NOTE(review): confirm against
# RangeBearningMeasurementModel.reverse)
x[2:] = 0
np.testing.assert_almost_equal(x, x_inv)


def test_range_bearing_measurement_model_2state():
    sigma_range = 1.
    sigma_bearing = 2.
44 | origin = np.array([10, 20]) 45 | meas_model = RangeBearningMeasurementModel(sigma_range, sigma_bearing, origin) 46 | 47 | x = np.random.rand(5, 2) 48 | z = meas_model.forward(x, noisy=False) 49 | z_noise1 = meas_model.forward(x, noisy=True) 50 | z_noise2 = meas_model.forward(x, noisy=True) 51 | x_inv = meas_model.reverse(z) 52 | assert x.shape[0] == meas_model.state_dim == meas_model.ndim 53 | assert z.shape[0] == meas_model.meas_dim == meas_model.ndim_meas 54 | assert x.shape[1] == z.shape[1] 55 | np.testing.assert_array_compare(__ne__, z_noise1, z_noise2) 56 | x[2:] = 0 57 | np.testing.assert_almost_equal(x, x_inv) 58 | 59 | 60 | def test_range_bearing_2d_model_1state(): 61 | noise_cov = np.diag([1, 1]) 62 | meas_model = RangeBearing2DMeasurementModel(noise_cov) 63 | 64 | x = np.random.rand(4, 1) 65 | z = meas_model.forward(x, noisy=False) 66 | z_noise1 = meas_model.forward(x, noisy=True) 67 | z_noise2 = meas_model.forward(x, noisy=True) 68 | 69 | x_delta = 0.01 * x 70 | z_delta = meas_model.forward(x + x_delta) - meas_model.forward(x) 71 | jacobian_mat = meas_model.jacobian(x) 72 | z_delta_approx = jacobian_mat.dot(x_delta) 73 | x_inv = meas_model.reverse(z) 74 | 75 | assert x.shape[0] == meas_model.state_dim == meas_model.ndim 76 | assert z.shape[0] == meas_model.meas_dim == meas_model.ndim_meas 77 | assert x.shape[1] == z.shape[1] 78 | np.testing.assert_array_compare(__ne__, z_noise1, z_noise2) 79 | np.testing.assert_almost_equal(z_delta, z_delta_approx, decimal=3) 80 | assert jacobian_mat.shape == (meas_model.meas_dim, meas_model.state_dim) 81 | x[2:] = 0 82 | np.testing.assert_almost_equal(x, x_inv) 83 | 84 | 85 | def test_range_bearing_2d_model_3state(): 86 | noise_cov = np.diag([1, 1]) 87 | meas_model = RangeBearing2DMeasurementModel(noise_cov) 88 | 89 | x = np.random.rand(4, 3) 90 | z = meas_model.forward(x, noisy=False) 91 | x_inv = meas_model.reverse(z) 92 | 93 | z_noise1 = meas_model.forward(x, noisy=True) 94 | z_noise2 = meas_model.forward(x, 
noisy=True) 95 | 96 | assert x.shape[0] == meas_model.state_dim == meas_model.ndim 97 | assert z.shape[0] == meas_model.meas_dim == meas_model.ndim_meas 98 | assert x.shape[1] == z.shape[1] 99 | np.testing.assert_array_compare(__ne__, z_noise1, z_noise2) 100 | x[2:] = 0 101 | np.testing.assert_almost_equal(x, x_inv) 102 | 103 | def test_range_bearing_4d_model_1state(): 104 | noise_cov = np.diag([1, 1, 1, 1]) 105 | meas_model = RangeBearing4DMeasurementModel(noise_cov) 106 | 107 | x = np.random.rand(4, 1) 108 | z = meas_model.forward(x, noisy=False) 109 | z_noise1 = meas_model.forward(x, noisy=True) 110 | z_noise2 = meas_model.forward(x, noisy=True) 111 | 112 | x_delta = 0.01 * x 113 | z_delta = meas_model.forward(x + x_delta) - meas_model.forward(x) 114 | jacobian_mat = meas_model.jacobian(x) 115 | z_delta_approx = jacobian_mat.dot(x_delta) 116 | x_inv = meas_model.reverse(z) 117 | 118 | assert x.shape[0] == meas_model.state_dim == meas_model.ndim 119 | assert z.shape[0] == meas_model.meas_dim == meas_model.ndim_meas 120 | assert x.shape[1] == z.shape[1] 121 | np.testing.assert_array_compare(__ne__, z_noise1, z_noise2) 122 | np.testing.assert_almost_equal(z_delta, z_delta_approx, decimal=3) 123 | assert jacobian_mat.shape == (meas_model.meas_dim, meas_model.state_dim) 124 | np.testing.assert_almost_equal(x, x_inv) 125 | 126 | 127 | def test_range_bearing_4d_model_3state(): 128 | noise_cov = np.diag([1, 1, 1, 1]) 129 | meas_model = RangeBearing4DMeasurementModel(noise_cov) 130 | 131 | x = np.random.rand(4, 3) 132 | z = meas_model.forward(x, noisy=False) 133 | x_inv = meas_model.reverse(z) 134 | 135 | z_noise1 = meas_model.forward(x, noisy=True) 136 | z_noise2 = meas_model.forward(x, noisy=True) 137 | 138 | assert x.shape[0] == meas_model.state_dim == meas_model.ndim 139 | assert z.shape[0] == meas_model.meas_dim == meas_model.ndim_meas 140 | assert x.shape[1] == z.shape[1] 141 | np.testing.assert_array_compare(__ne__, z_noise1, z_noise2) 142 | 
np.testing.assert_almost_equal(x, x_inv) 143 | 144 | 145 | if __name__ == "__main__": 146 | test_range_bearing_measurement_model_1state() 147 | test_range_bearing_measurement_model_2state() 148 | 149 | test_range_bearing_2d_model_1state() 150 | test_range_bearing_2d_model_3state() 151 | 152 | test_range_bearing_4d_model_1state() 153 | test_range_bearing_4d_model_3state() 154 | -------------------------------------------------------------------------------- /statecircle/models/sensor/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ -------------------------------------------------------------------------------- /statecircle/models/sensor/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | from abc import abstractmethod 9 | import numpy as np 10 | 11 | from statecircle.types.data import MeasurementSeries, MeasurementData 12 | from ..base import Model 13 | 14 | 15 | class SensorModel(Model): 16 | r""""Sensor model base class""" 17 | 18 | @abstractmethod 19 | def detect(self): 20 | r"""virtual detect method""" 21 | 22 | 23 | class DummySensorModel(Model): 24 | r"""Generated simulated sensor data from ground truth data""" 25 | 26 | def __init__(self, clutter_model, measurement_model, random_seed=None, noisy=True): 27 | self.clutter_model = clutter_model 28 | self.measurement_model = measurement_model 29 | self.seed = random_seed or np.random.randint(2**32) 30 | self.noisy = noisy 31 | 32 | def detect(self, gt_series, shuffle=True): 33 | if self.seed is not None: 34 | np.random.seed(self.seed) 35 | 36 | # initialize memory 37 | data_len = len(gt_series.datum) 38 | # meas_data = [[] for _ in range(data_len)] 39 | meas_data = MeasurementSeries(data_len) 40 | obj_meas_data = 
[np.empty((self.measurement_model.ndim_meas, 0)) for _ in range(data_len)] 41 | clutter_data = [np.empty((self.measurement_model.ndim_meas, 0)) for _ in range(data_len)] 42 | 43 | # generate measurements 44 | for k in range(data_len): 45 | meas_data.timestamps[k] = gt_series.timestamps[k] 46 | if gt_series.num[k] > 0: 47 | idx = np.random.rand(gt_series.num[k]) < self.clutter_model.detection_rate 48 | # only generate object-orienated observations for detected objects 49 | if any(idx): 50 | obj_states = np.hstack([gt_data.states for i, gt_data in enumerate(gt_series.datum[k]) if idx[i]]) 51 | for i in range(obj_states.shape[1]): 52 | if self.noisy: 53 | meas = np.random.multivariate_normal( 54 | self.measurement_model.forward(obj_states[:, i])[:, 0], 55 | self.measurement_model.noise_covar())[:, None] 56 | else: 57 | meas = self.measurement_model.forward(obj_states[:, i])[:, 0] 58 | meas_data.datum[k].append(meas) 59 | 60 | # number of clutter measurements 61 | num_clutter = np.random.poisson(self.clutter_model.lambda_clutter) 62 | 63 | # generate clutter 64 | clutter = np.tile(self.clutter_model.scope[:, 0], [num_clutter, 1]).T + \ 65 | np.diag(self.clutter_model.scope.dot(np.array([-1, 1]))). 
\ 66 | dot(np.random.rand(self.measurement_model.ndim_meas, num_clutter)) 67 | 68 | if len(meas_data.datum[k]) == 0: 69 | # no detection measurement 70 | clutter_data[k] = meas_data.datum[k] = clutter 71 | else: 72 | obj_meas_data[k] = meas_data.datum[k] = np.hstack(meas_data.datum[k]) 73 | # total measurements are the union of object detections and clutter 74 | meas_data.datum[k] = np.hstack((meas_data.datum[k], clutter)) 75 | clutter_data[k] = clutter 76 | 77 | if shuffle: 78 | meas_data.datum[k] = np.random.permutation(meas_data.datum[k].T).T 79 | 80 | return meas_data, obj_meas_data, clutter_data 81 | 82 | def detect_iter(self, data_generator, shuffle=True): 83 | if self.seed is not None: 84 | np.random.seed(self.seed) 85 | 86 | for data in data_generator: 87 | obj_meas = np.empty([self.measurement_model.ndim_meas, 0]) 88 | if data.states is not None: 89 | idx = np.random.rand(data.states.shape[1]) < self.clutter_model.detection_rate 90 | # only generate object-orienated observations for detected objects 91 | if any(idx): 92 | obj_states = data.states[:, idx] 93 | for i in range(obj_states.shape[1]): 94 | if self.noisy: 95 | obj_meas = np.hstack((obj_meas, np.random.multivariate_normal( 96 | self.measurement_model.forward(obj_states[:, i])[:, 0], 97 | self.measurement_model.noise_covar())[:, None])) 98 | else: 99 | obj_meas = np.hstack((obj_meas, self.measurement_model.forward(obj_states[:, i]))) 100 | 101 | # number of clutter measurements 102 | num_clutter = np.random.poisson(self.clutter_model.lambda_clutter) 103 | 104 | # generate clutter 105 | clutter_meas = np.tile(self.clutter_model.scope[:, 0], [num_clutter, 1]).T + \ 106 | np.diag(self.clutter_model.scope.dot(np.array([-1, 1]))). 
\ 107 | dot(np.random.rand(self.measurement_model.ndim_meas, num_clutter)) 108 | 109 | meas_data = MeasurementData(timestamp=data.timestamp, meas=np.hstack((obj_meas, clutter_meas))) 110 | 111 | if shuffle: 112 | meas_data.meas = np.random.permutation(meas_data.meas.T).T 113 | 114 | yield meas_data, obj_meas, clutter_meas 115 | -------------------------------------------------------------------------------- /statecircle/models/sensor/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ -------------------------------------------------------------------------------- /statecircle/models/sensor/tests/test_sensor_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | import pytest 9 | import numpy as np 10 | 11 | from ..base import DummySensorModel 12 | from statecircle.models.measurement.clutter import PoissonClutterModel 13 | from statecircle.models.transition.linear import ConstantVelocityModel 14 | from statecircle.models.measurement.linear import LinearMeasurementModel 15 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 16 | 17 | 18 | @pytest.fixture() 19 | def input_paras(): 20 | P_D = 0.7 21 | lambda_clutter = 20 22 | scope = np.array([[0, 1000], [0, 1000]]) 23 | clutter_model = PoissonClutterModel(P_D, lambda_clutter, scope) 24 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], sigma=5) 25 | return clutter_model, measurement_model 26 | 27 | @pytest.fixture() 28 | def make_data_generator(): 29 | initial_states = np.random.rand(4, 5) 30 | birth_times = [10, 20, 30, 40, 50] 31 | death_times = [40, 50, 60, 70, 80] 32 | time_scope = [0, 100] 33 | transition_model = ConstantVelocityModel(sigma=5) 34 | data_gen = 
SimulatedGroundTruthDataGenerator(initial_states, birth_times, death_times, 35 | time_scope, transition_model, noisy=False) 36 | return data_gen 37 | 38 | 39 | def test_dummy_sensor_model(input_paras, make_data_generator): 40 | clutter_model, measurement_model = input_paras 41 | data_gen = make_data_generator 42 | sensor_model = DummySensorModel(clutter_model, measurement_model) 43 | meas_data, obj_meas_data, clutter_data = sensor_model.detect(data_gen.gt_series) 44 | for meas, obj_meas, c_meas in zip(meas_data.datum, obj_meas_data, clutter_data): 45 | assert meas.shape[1] == obj_meas.shape[1] + c_meas.shape[1] 46 | 47 | data_reader = sensor_model.detect_iter(data_gen) 48 | print(len(list(data_reader))) 49 | for meas in data_reader: 50 | meas_data, obj_meas, clutter_meas = meas 51 | assert meas_data.shape[1] == obj_meas.shape[1] + clutter_meas.shape[1] 52 | 53 | 54 | -------------------------------------------------------------------------------- /statecircle/models/transition/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ -------------------------------------------------------------------------------- /statecircle/models/transition/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 8 | from ..base import Model 9 | 10 | 11 | class TransitionModel(Model): 12 | r"""Base transition models class""" 13 | def __init__(self, *args, **kwargs): 14 | pass 15 | 16 | def jacobian(self, x): 17 | r"""Virtual Jacobian matrix for transition forward function""" 18 | raise NotImplementedError 19 | 20 | def forward(self, x, noise=None): 21 | r"""Virtual transition forward method for transition models 22 | 23 | .. 
def constant_velocity_model_testing(x, step):
    """Shared checks for the constant-velocity model.

    Verifies state/prediction dimensionality, that the Jacobian of a linear
    model equals its transition matrix, and that injecting noise perturbs
    every entry of the prediction.
    """
    cv_model = ConstantVelocityModel(sigma=1)
    pred_clean = cv_model.forward(x, step, noisy=False)
    pred_noisy = cv_model.forward(x, step, noisy=True)

    # A linear model's Jacobian is exactly its transition matrix.
    np.testing.assert_almost_equal(cv_model.jacobian(x, step),
                                   cv_model.transition_matrix(step))
    assert x.shape[0] == cv_model.state_dim
    assert pred_clean.shape[0] == cv_model.ndim
    # Noise should change every element of the prediction.
    np.testing.assert_array_compare(__ne__, pred_clean, pred_noisy)


def _elapsed_seconds(delta_sec):
    """Compute a time step (in seconds) via datetime arithmetic."""
    start = datetime.datetime.now()
    later = start + datetime.timedelta(seconds=delta_sec)
    return (later - start).total_seconds()


def test_constant_velocity_model_case1():
    # single state column, 1 second step
    single_state = np.random.rand(4, 1)
    constant_velocity_model_testing(single_state, _elapsed_seconds(1))


def test_constant_velocity_model_case2():
    # two state columns, 10 second step
    batch_state = np.random.rand(4, 2)
    constant_velocity_model_testing(batch_state, _elapsed_seconds(10))
class Platform(Base):
    r"""Base platform class"""


class TrackerPlatform(Platform):
    r"""Single Object Tracker Platform

    Builds a tracker and all of its sub-modules from a YAML configuration
    file.  The parsed configuration is cached on the class attribute ``cfg``.
    """
    cfg = None

    def __init__(self, config_path):
        r"""
        Parameters
        ----------
        config_path : str
            path to the YAML tracker configuration
        """
        # BUG FIX: use `yaml.safe_load` — the original `yaml.load(f)` with no
        # Loader can execute arbitrary Python from a crafted config and raises
        # a warning (or TypeError) on PyYAML >= 5.1.
        with open(config_path) as f:
            TrackerPlatform.cfg = AttrDict.from_dict(yaml.safe_load(f))

        print('building {}...\ntracker name: {}'.format(self.cfg.TRACKER.type, self.cfg.TRACKER.name))
        # instantiate the modules declared in the TRACKER section
        self.build_tracker_modules()

        # sanity check
        self.sanity_check()

        # build tracker, forwarding every constructed module as keyword args
        self.tracker = import_model(self.cfg.TRACKER.type)(**self.__dict__)

    def build_tracker_modules(self):
        """Instantiate every sub-config that declares a 'type' and attach it
        to `self` under the lower-cased config key."""
        for key, sub_cfg in self.cfg.TRACKER.items():
            if isinstance(sub_cfg, AttrDict) and 'type' in sub_cfg:
                setattr(self, key.lower(), import_model_with_paras(sub_cfg))

    def sanity_check(self):
        """Assert that all mandatory modules were built and that the birth
        and transition models agree on the state dimension.

        Raises
        ------
        AssertionError
            if a mandatory module is missing or the dimensions disagree
        """
        mandatory_modules = ['birth_model', 'density_model', 'transition_model',
                             'measurement_model', 'clutter_model', 'gate', 'estimator']
        for module_name in mandatory_modules:
            assert module_name in self.__dict__, \
                'missing mandatory module: {}'.format(module_name)

        assert self.birth_model.ndim == self.transition_model.ndim
MeasurementReader 12 | from statecircle.models.sensor.base import DummySensorModel 13 | from statecircle.models.measurement.clutter import PoissonClutterModel 14 | from statecircle.models.measurement.linear import LinearMeasurementModel 15 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 16 | from statecircle.models.transition.linear import ConstantVelocityModel 17 | 18 | 19 | @pytest.fixture() 20 | def data_generator(): 21 | initial_states = np.random.rand(4, 5) 22 | birth_times = [10, 20, 30, 40, 50] 23 | death_times = [40, 50, 60, 70, 80] 24 | time_scope = [0, 100] 25 | transition_model = ConstantVelocityModel(sigma=5) 26 | return SimulatedGroundTruthDataGenerator(initial_states, birth_times, death_times, 27 | time_scope, transition_model, noisy=False) 28 | 29 | @pytest.fixture() 30 | def sensor_model(): 31 | P_D = 0.7 32 | lambda_clutter = 20 33 | scope = np.array([[0, 1000], [0, 1000]]) 34 | clutter_model = PoissonClutterModel(P_D, lambda_clutter, scope) 35 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], sigma=5) 36 | return DummySensorModel(clutter_model, measurement_model) 37 | 38 | def test_measurement_model(data_generator, sensor_model): 39 | reader = MeasurementReader(data_generator, sensor_model) 40 | assert len(reader) == len(data_generator) 41 | for meas in reader: 42 | pass 43 | 44 | meas = next(iter(reader)) 45 | assert meas is not None 46 | 47 | for meas in reader: 48 | assert meas.shape[0] == sensor_model.measurement_model.ndim_meas 49 | 50 | for meas, obj_meas, clutter_meas in reader.truth_meas_generator: 51 | assert meas.shape[1] == obj_meas.shape[1] + clutter_meas.shape[1] -------------------------------------------------------------------------------- /statecircle/reductor/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 
class Gate(Reductor):
    r"""Gating base class"""

    def gating(self):
        """Virtual gating operation function."""
        raise NotImplementedError


class RectangularGate(Gate):
    r"""Rectangular gate (not implemented yet)."""

    def gating(self):
        raise NotImplementedError


class EllipsoidalGate(Gate):
    r"""Ellipsoidal gating mechanism"""

    def __init__(self, percentile=None, thresh=None):
        r"""
        Parameters
        ----------
        percentile : float, optional
            chi-square percentile defining the gate size; when given, `thresh`
            is (re)derived from it on every `gating` call
        thresh : float, optional
            squared-Mahalanobis-distance threshold, used when `percentile`
            is not given
        """
        self.percentile = percentile
        self.thresh = thresh

    def gating(self, state, meas, measurement_model):
        r"""Ellipsoidal gating operation function

        Parameters
        ----------
        state
        meas
        measurement_model

        Returns
        -------
        meas_ingate : `Matrix` (meas_dim, num_meas_in_gate)
            measurements in current state gate scope
        meas_index : `BoolVector` (num_meas)
            bool vector of the measurement indices inside the gate
        """
        # used for `KalmanAccumulatedDensityModel`: gate on the newest block
        ndim = state.mean.shape[0]
        ndim_meas = measurement_model.ndim_meas

        if len(meas) == 0:
            # BUG FIX: the empty index must be boolean — the original
            # `np.empty(0)` (float) cannot be used as a mask by callers
            # (e.g. `lij[meas_index] = ...` raises IndexError on float arrays).
            return np.empty((ndim_meas, 0)), np.zeros(0, dtype=bool)

        # keep `percentile` and `thresh` mutually consistent
        if self.percentile is not None:
            self.thresh = chi2.ppf(self.percentile, ndim_meas)
        elif self.thresh is not None:
            self.percentile = chi2.cdf(self.thresh, ndim_meas)
        else:
            raise ValueError('either `percentile` or `thresh` must be set')

        # measurement matrix and predicted measurement at the latest state column
        H = measurement_model.measurement_matrix(state.mean[:, -1:])
        meas_pred = measurement_model.forward(state.mean[:, -1:])

        # innovation covariance
        S = H.dot(state.cov[-ndim:, -ndim:]).dot(H.T) + measurement_model.noise_covar(meas_pred)
        # symmetrize to keep S positive definite despite round-off
        S = (S + S.T) / 2

        # Squared Mahalanobis distances, vectorized.  PERF FIX: the original
        # recomputed `np.linalg.inv(S)` once per measurement inside a loop;
        # `solve` factorizes S once for all columns.
        diff = meas - meas_pred  # (ndim_meas, num_meas), broadcast over columns
        distances = np.einsum('ij,ij->j', diff, np.linalg.solve(S, diff))

        meas_index = distances <= self.thresh
        meas_ingate = meas[:, meas_index]

        return meas_ingate, meas_index
class HypothesisReductor(Reductor):
    r"""Hypothesis trees reductor

    Provides prune / cap / merge operations over weighted hypothesis sets,
    plus Gaussian moment matching.
    """

    def __init__(self, weight_min, merging_threshold, capping_num,
                 prob_min=None, prob_recycle=None):
        r"""
        Parameters
        ----------
        weight_min : used for the prune method
        merging_threshold : used for the merge method
        capping_num : used for the cap method
        prob_min, prob_recycle : used for the recycle method
        """
        self.weight_min = weight_min
        self.log_weight_min = np.log(weight_min)
        self.merging_threshold = merging_threshold
        self.capping_num = capping_num
        self.prob_min = prob_min
        self.prob_recycle = prob_recycle

    def prune(self, log_weights, hypo_tree):
        r"""Drop hypotheses whose log-weight falls below `log_weight_min`.

        Returns
        -------
        log_weights, hypo_tree : the surviving weights and hypotheses
        """
        if len(log_weights) == 0:
            return log_weights, hypo_tree

        idx = log_weights > self.log_weight_min
        log_weights = log_weights[idx]
        hypo_tree = list_logical_index(hypo_tree, idx)
        return log_weights, hypo_tree

    def cap(self, log_weights, hypo_tree):
        r"""Keep at most `capping_num` hypotheses, the heaviest first.

        Returns
        -------
        log_weights, hypo_tree
        """
        if len(log_weights) == 0 or len(log_weights) <= self.capping_num:
            return log_weights, hypo_tree

        idx = np.argsort(-log_weights)[:self.capping_num]
        return log_weights[idx], list_index(hypo_tree, idx)

    def moment_matching(self, log_weights, states):
        r"""Collapse a weighted mixture into a single state by matching the
        first two moments.

        Parameters
        ----------
        log_weights : array-like of normalized log weights
        states : list of Gaussian-like states (`mean`, `cov` attributes)

        Returns
        -------
        state : the moment-matched state (or the input when trivial)
        """
        if len(log_weights) == 0:
            return states

        if len(log_weights) == 1:
            return states[0]

        # BUG FIX: accept plain lists as well as ndarrays — the JPDA tracker
        # passes a list, and `list.argmax()` below would raise for labelled
        # states.
        log_weights = np.asarray(log_weights)
        w = np.exp(log_weights)

        num_states = len(states)
        # TODO: support different state types
        if num_states == 0:
            return None

        state = type(states[0])(0., 0.)
        # 1st moment: weighted mean
        for i in range(num_states):
            state.mean = state.mean + w[i] * states[i].mean

        # 2nd moment: weighted covariance plus spread-of-means term
        for i in range(num_states):
            spread = np.outer(state.mean - states[i].mean,
                              state.mean - states[i].mean)
            state.cov = state.cov + w[i] * (states[i].cov + spread)

        if hasattr(states[0], 'label'):
            # keep the label of the most likely component
            max_idx = log_weights.argmax()
            state.label = states[max_idx].label

        return state

    def merge(self, log_weights, states):
        r"""Greedily merge state components that are close in Mahalanobis
        distance, heaviest components first.

        Parameters
        ----------
        log_weights
        states

        Returns
        -------
        log_weights, states : merged components and their log weights
        """
        if len(log_weights) <= 1:
            return log_weights, states

        # BUG FIX: work on a copy — the original damped the caller's weight
        # array in place as a side effect.
        log_weights = np.array(log_weights, dtype=float)

        remaining = list(range(len(states)))
        log_w_hat = []
        states_hat = []
        while remaining:
            # component with the highest weight (merged ones are damped to
            # ~log(0) below, so argmax never re-selects them)
            j = int(np.argmax(log_weights))
            # PERF FIX: the original recomputed inv(states[j].cov) per i
            cov_j_inv = np.linalg.inv(states[j].cov)

            cluster = []
            for i in remaining:
                diff = (states[i].mean - states[j].mean)[:, 0]
                dist = diff.dot(cov_j_inv).dot(diff)
                # small Mahalanobis distance => similar component
                # TODO: use a percentile to calculate the merging threshold
                #       like the gating size
                if dist < self.merging_threshold:
                    cluster.append(i)

            # merge the cluster by moment matching
            norm_log_w, log_w_sum = normalize_log_weights(
                np.array([log_weights[idx] for idx in cluster]))
            merged_state = self.moment_matching(norm_log_w,
                                                [states[idx] for idx in cluster])

            log_w_hat.append(log_w_sum)
            states_hat.append(merged_state)

            # drop merged components, preserving index order (the original
            # round-tripped through `set`, making the order accidental)
            remaining = [i for i in remaining if i not in cluster]
            # damp merged weights so argmax never picks them again
            for idx in cluster:
                log_weights[idx] = np.log(1e-100)

        return np.array(log_w_hat), states_hat
-------------------------------------------------------------------------------- /statecircle/reductor/tests/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ -------------------------------------------------------------------------------- /statecircle/reductor/tests/test_gates.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Thu Dec 5 16:00:54 2019 6 | 7 | @author: zhxm 8 | """ 9 | import numpy as np 10 | from scipy.stats import multivariate_normal 11 | from pytest import approx 12 | 13 | from ..gate import EllipsoidalGate, RectangularGate 14 | from statecircle.types.state import GaussianState 15 | from statecircle.models.measurement.linear import LinearMeasurementModel 16 | 17 | 18 | def test_retangular_gate(): 19 | pass 20 | 21 | 22 | def test_ellipsoidal_gate(): 23 | ndim = 4 24 | ndim_meas = 2 25 | num_samples = 10000 26 | percentile = 0.3 27 | gate = EllipsoidalGate(percentile=percentile) 28 | mean = np.random.rand(ndim, 1) 29 | cov = np.eye(4) 30 | state = GaussianState(mean, cov) 31 | meas_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], sigma=2) 32 | input_states = multivariate_normal.rvs(mean[:, 0], cov, num_samples).T 33 | meas = meas_model.forward(input_states, noisy=True) 34 | meas_ingate, meas_index = gate.gating(state, meas, meas_model) 35 | assert meas_ingate.shape[1] == meas_index.sum() 36 | assert meas_ingate.shape[1] / num_samples == approx(percentile, rel=1e-1) 37 | -------------------------------------------------------------------------------- /statecircle/trackers/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Thu Dec 5 15:02:46 2019 6 | 7 | @author: zhxm 8 | 
""" 9 | from abc import abstractmethod 10 | 11 | from ..base import Base 12 | 13 | 14 | class Filter(Base): 15 | """ Filter class 16 | Filter can estimate current state using the history measurements up to current 17 | time, but can not output the whole trajectory 18 | 19 | Attributes 20 | ---------- 21 | 22 | """ 23 | def __init__(self, birth_model, density_model, transition_model, 24 | measurement_model, clutter_model, gate, estimator, reductor=None): 25 | self.birth_model = birth_model 26 | self.density_model = density_model 27 | self.transition_model = transition_model 28 | self.measurement_model = measurement_model 29 | self.clutter_model = clutter_model 30 | self.gate = gate 31 | self.estimator = estimator 32 | self.reductor = reductor 33 | self.timestamp = None 34 | 35 | @abstractmethod 36 | def predict(self, *args, **kwargs): 37 | pass 38 | 39 | def birth(self, *args, **kwargs): 40 | pass 41 | 42 | @abstractmethod 43 | def update(self, *args, **kwargs): 44 | pass 45 | 46 | @abstractmethod 47 | def estimate(self, *args, **kwargs): 48 | pass 49 | 50 | def reduction(self, *args, **kwargs): 51 | pass 52 | 53 | @abstractmethod 54 | def filtering(self, *args, **kwargs): 55 | pass 56 | 57 | 58 | class Tracker(Filter): 59 | """ Tracker class 60 | Tracker can estimate current state and current state label at the same time 61 | using the history measurements up to currnet time 62 | 63 | Attributes 64 | ---------- 65 | """ 66 | 67 | class SingleObjectTracker(Tracker): 68 | """ Single Object Tracker """ 69 | 70 | 71 | class MultiObjectFilter(Filter): 72 | """ Multi-object filter """ 73 | 74 | 75 | class MultiObjectTracker(Tracker): 76 | """ Multi-Object trackers """ 77 | -------------------------------------------------------------------------------- /statecircle/trackers/mot/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ 
class GlobalNearestNeighbourTracker(MultiObjectTracker):
    """
    Global Nearest Neighbour tracker with known objects number.

    Keeps exactly one local hypothesis per object; each scan it solves a
    2D assignment problem between objects and gated measurements.  The
    assumed density is a Gaussian state.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # every local hypothesis tree holds a single hypothesis for GNN
        self.hypo_forest = []
        # number of objects / trees
        self.num_trees = 0

        self.birth_initialized = False

    def predict(self, meas_data):
        """Propagate every local hypothesis to `meas_data.timestamp`."""
        # TODO: implement time_step using datetime
        timestamp = meas_data.timestamp
        if self.timestamp is None:
            # first scan: just latch the timestamp
            self.timestamp = timestamp
            return
        self.timestamp, time_step = timestamp, timestamp - self.timestamp

        for i, hypo in enumerate(self.hypo_forest):
            # NOTE: shallow in-place update of the single leaf
            hypo[0] = self.density_model.predict(hypo[0], time_step, self.transition_model)

    def birth(self, meas_data):
        """One-shot initialization of the forest from the birth model."""
        if not self.birth_initialized:
            self.hypo_forest = [[state] for state in self.birth_model.birth()]
            self.num_trees = self.num_objs = len(self.hypo_forest)
            self.birth_initialized = True

    def update(self, meas_data):
        """Gate, build the assignment cost matrix, solve it, update states."""
        meas = meas_data.meas
        detection_rate = self.clutter_model.detection_rate
        intensity_clutter = self.clutter_model.intensity_clutter
        num_meas = meas.shape[1]
        num_objs = self.num_trees

        # cost matrix columns: [detections | missed detections]
        cost_mat = np.zeros([num_objs, num_meas + num_objs])
        for i in range(num_objs):
            hypo = self.hypo_forest[i][0]
            # ellipsoidal gating for each predicted local hypothesis
            meas_ingate, meas_index = self.gate.gating(hypo, meas, self.measurement_model)
            # 2D cost matrix of size (num_objs x (num_meas in gate + num_objs))
            pred_loglik = self.density_model.predicted_log_likelihood(hypo, meas_ingate, self.measurement_model)
            li0 = -np.inf * np.ones(num_objs)
            lij = -np.inf * np.ones(num_meas)
            li0[i] = np.log(1 - detection_rate)
            lij[meas_index] = np.log(detection_rate) - np.log(intensity_clutter) + pred_loglik
            cost_mat[i, :] = np.hstack((-lij, -li0))

        # TODO: simplify the cost matrix by dropping all-miss columns

        # Find the best assignment using a 2D assignment solver.
        # BUG FIX: assign `theta_t` on *both* branches — the original set it
        # only in the non-Hungarian branch, so flipping the flag would leave
        # `theta_t` undefined (NameError).
        __use_hungarian_algorithm = False
        if __use_hungarian_algorithm:
            _, col_idx = linear_sum_assignment(cost_mat)
        else:
            col_idx, cost = data_association(cost_mat, topN=1)
        theta_t = col_idx
        # columns beyond the measurement block encode missed detections
        theta_t[theta_t > num_meas - 1] = -1

        # update each object with its assigned measurement (miss => keep prior)
        for i in range(num_objs):
            if theta_t[i] >= 0:
                self.hypo_forest[i][0] = self.density_model.update(
                    self.hypo_forest[i][0], meas[:, theta_t[i]], self.measurement_model)

    def estimate(self):
        """Stack the point estimates of every object."""
        est_t = [self.estimator(hypo[0]) for hypo in self.hypo_forest]
        return np.hstack(est_t)

    def filtering(self, data_reader):
        """Run predict / birth / update / estimate over the whole stream."""
        estimates = []
        for t, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)):
            # predict
            self.predict(meas_data)

            # birth
            self.birth(meas_data)

            # update
            self.update(meas_data)

            # estimate
            estimates.append(self.estimate())

            # reduction
            self.reduction()

        return estimates
class JointProbabilisticDataAssociationTracker(MultiObjectTracker):
    """
    Joint Probabilistic Data Association tracker with known objects number.

    Each local hypothesis tree keeps exactly one leaf: after each scan the
    per-object association hypotheses are collapsed back into a single
    Gaussian by moment matching.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # hypotheses forest: one tree per object, one leaf per tree in JPDA
        self.hypo_forest = []
        # number of local hypothesis trees / objects
        self.num_trees = self.num_objs = 0

        self.birth_initialized = False

    def predict(self, meas_data):
        """Propagate every local hypothesis to `meas_data.timestamp`."""
        # TODO: implement time_step using datetime
        timestamp = meas_data.timestamp
        if self.timestamp is None:
            # first scan: just latch the timestamp
            self.timestamp = timestamp
            return
        self.timestamp, time_step = timestamp, timestamp - self.timestamp

        for i, hypo in enumerate(self.hypo_forest):
            # NOTE: shallow in-place update of the single leaf
            hypo[0] = self.density_model.predict(hypo[0], time_step, self.transition_model)

    def birth(self, meas_data):
        """One-shot initialization of the forest from the birth model."""
        if not self.birth_initialized:
            self.hypo_forest = [[state] for state in self.birth_model.birth()]
            self.num_trees = self.num_objs = len(self.hypo_forest)
            self.birth_initialized = True

    def update(self, meas_data):
        """Enumerate M-best joint associations, then per-object moment match."""
        # NOTE(review): this module imports `logsumexp` from `scipy.misc`,
        # which was removed in SciPy 1.0 — the import should move to
        # `scipy.special`; confirm against the file header.
        meas = meas_data.meas
        detection_rate = self.clutter_model.detection_rate
        intensity_clutter = self.clutter_model.intensity_clutter
        num_meas = meas.shape[1]

        cost_mat = np.zeros((self.num_objs, num_meas + self.num_objs))
        for i in range(self.num_objs):
            hypo = self.hypo_forest[i][0]
            # 1. ellipsoidal gating for each predicted local hypothesis
            meas_ingate, meas_index = self.gate.gating(hypo, meas, self.measurement_model)

            # 2. cost row: [detection log-likelihoods | miss]
            pred_loglik = self.density_model.predicted_log_likelihood(hypo, meas_ingate, self.measurement_model)
            li0 = -np.inf * np.ones(self.num_objs)
            lij = -np.inf * np.ones(num_meas)
            li0[i] = np.log(1 - detection_rate)
            lij[meas_index] = np.log(detection_rate) - np.log(intensity_clutter) + pred_loglik
            cost_mat[i, :] = np.hstack((-lij, -li0))

        # TODO: simplify the cost matrix
        # 3. M-best joint assignments via an M-best 2D assignment solver
        col, cost = data_association(cost_mat, self.reductor.capping_num)
        theta_t = col
        log_weights = -cost

        # 4. normalize the weights of the data association hypotheses
        log_weights, log_sum_weights = normalize_log_weights(log_weights)

        # 5. prune low-weight association hypotheses and renormalize
        log_weights, keep_idx = self.reductor.prune(log_weights, np.arange(theta_t.shape[1]))
        log_weights, log_sum_weights = normalize_log_weights(log_weights)
        theta_t = theta_t[:, keep_idx]
        # '-1' encodes the missed-detection hypothesis
        theta_t[theta_t > num_meas - 1] = -1
        for i in range(self.num_objs):
            hypo = self.hypo_forest[i][0]
            # 6. build the marginal association hypotheses for object i
            log_marginal_weights = []
            multi_hypo = []
            for j in np.unique(theta_t[i, :]):
                meas_idx_i = theta_t[i, :] == j
                if np.any(meas_idx_i):
                    multi_hypo.append(hypo if j == -1
                                      else self.density_model.update(hypo, meas[:, j], self.measurement_model))
                    log_marginal_weights.append(logsumexp(log_weights[meas_idx_i]))

            # 7. merge the per-object hypotheses by moment matching.
            # BUG FIX: pass an ndarray — a plain list breaks
            # `log_weights.argmax()` inside `moment_matching` for labelled
            # states.
            self.hypo_forest[i][0] = self.reductor.moment_matching(
                np.array(log_marginal_weights), multi_hypo)

    def estimate(self):
        """Stack the point estimates of every object."""
        est_t = [self.estimator(hypo[0]) for hypo in self.hypo_forest]
        return np.hstack(est_t)

    def filtering(self, data_reader):
        """Run predict / birth / update / estimate over the whole stream."""
        estimates = []
        for t, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)):
            # predict
            self.predict(meas_data)

            # birth
            self.birth(meas_data)

            # update
            self.update(meas_data)

            # estimate
            estimates.append(self.estimate())

            # reduction
            self.reduction()

        return estimates
objects 39 | log_weights = np.log(self.surviving_rate) + self.poisson_state.intensity.log_weights 40 | gaussian_states = [self.density_model.predict(state, time_step, self.transition_model) 41 | for state in self.poisson_state.intensity.gaussian_states] 42 | self.poisson_state.intensity = GaussianSumState(log_weights, gaussian_states) 43 | 44 | def birth(self, meas_data): 45 | # add (Gaussian mixture) Poisson birth intensity to (Gaussian mixture) 46 | # Poisson intensity for pre-existing objects 47 | birth_state = self.birth_model.birth().intensity 48 | self.poisson_state.intensity += deepcopy(birth_state) 49 | 50 | def update(self, meas_data): 51 | meas = meas_data.meas 52 | detection_rate = self.clutter_model.detection_rate 53 | intensity_clutter = self.clutter_model.intensity_clutter 54 | num_comps = self.poisson_state.intensity.num_comps 55 | num_meas = meas.shape[1] 56 | 57 | # 1) construct update components resulted from missed detection 58 | log_weights_missed = np.log(1 - detection_rate) + self.poisson_state.intensity.log_weights 59 | intensity_missed = GaussianSumState(log_weights_missed, self.poisson_state.intensity.gaussian_states) 60 | 61 | # 2) perform ellipsoidal gating for each Gaussian component in 62 | # the Poisson intensity 63 | gating_mat = np.zeros((num_comps, num_meas), dtype=np.bool) 64 | for i, state in enumerate(self.poisson_state.intensity.gaussian_states): 65 | _, meas_index = self.gate.gating(state, meas, self.measurement_model) 66 | gating_mat[i, :] = meas_index 67 | 68 | meas_in_all_gates = np.any(gating_mat != 0, axis=0) 69 | meas = meas[:, meas_in_all_gates] 70 | gating_mat = gating_mat[:, meas_in_all_gates] 71 | num_meas = meas.shape[1] 72 | 73 | # 3) construct update components resulted from object 74 | # detections that are inside the gates 75 | states_update = [] 76 | log_weights_update = [] 77 | for i in range(num_meas): 78 | log_weights_unnorm = [] 79 | for h, state in enumerate(self.poisson_state.intensity.gaussian_states): 80 
| if gating_mat[h, i]: 81 | # 3) construct update components resulted from object 82 | # detections that are inside the gates 83 | pred_loglik = self.density_model.predicted_log_likelihood(state, 84 | meas[:, i], 85 | self.measurement_model) 86 | 87 | # TODO: optimization for the update parameters without measurement input 88 | # TODO: (mean, covariance for measurement, Kalman gain and updated covariance) 89 | states_update.append(self.density_model.update(state, meas[:, i], self.measurement_model)) 90 | log_weights_unnorm.append(np.log(detection_rate) + self.poisson_state.intensity.log_weights[h] + pred_loglik) 91 | 92 | log_weights_update.append(log_weights_unnorm - \ 93 | np.log(intensity_clutter + np.sum(np.exp(log_weights_unnorm)))) 94 | if len(log_weights_update) > 0: 95 | intensity_update = GaussianSumState(np.hstack(log_weights_update), states_update) 96 | else: # there is no meansurement in all gates 97 | intensity_update = GaussianSumState() 98 | 99 | self.poisson_state.intensity = intensity_missed + intensity_update 100 | 101 | def reduction(self): 102 | # component reduction approximates the PPP by representing its intensity 103 | # with fewer parameters 104 | reduction_function_handles = (self.reductor.prune, self.reductor.merge, self.reductor.cap) 105 | for func in reduction_function_handles: 106 | self.poisson_state.intensity = GaussianSumState(*func(self.poisson_state.intensity.log_weights, 107 | self.poisson_state.intensity.gaussian_states)) 108 | 109 | def estimate(self): 110 | # PHD estimator performs object state estimation in the GMPHD filter 111 | 112 | # 1) get a mean estimate of the cadinality of objects by 113 | # taking the summation of the weights of the Gaussian 114 | # components (rounded to the nearest integer), denoted as n 115 | card_mean = np.int(np.minimum(self.poisson_state.intensity.num_comps, 116 | np.round(self.poisson_state.mean) 117 | ) 118 | ) 119 | 120 | # 2) extract n object states from the means of the n Gaussian 
components 121 | # with the hightest weights 122 | keep_idx = np.argsort(-self.poisson_state.intensity.log_weights)[:card_mean] 123 | est = [self.poisson_state.intensity.gaussian_states[idx].mean for idx in keep_idx] 124 | 125 | if len(est) == 0: 126 | return np.empty((self.transition_model.ndim, 0)) 127 | 128 | return np.hstack(est) 129 | 130 | def estimate_duplicated(self, *args, **kwargs): 131 | raise NotImplementedError 132 | 133 | def filtering(self, data_reader): 134 | estimates = [] 135 | for t, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)): 136 | # PPP prediction 137 | self.predict(meas_data) 138 | 139 | # PPP birth 140 | self.birth(meas_data) 141 | 142 | # PPP update 143 | self.update(meas_data) 144 | 145 | # extract state estimates from PPP 146 | estimates.append(self.estimate()) 147 | 148 | # PPP approximation 149 | self.reduction() 150 | 151 | return estimates 152 | -------------------------------------------------------------------------------- /statecircle/trackers/mot/prototype_tracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Sat May 30 13:53:12 2020 6 | 7 | @author: zhaoxm 8 | """ 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from .md_tracker import MDTracker 13 | 14 | class ProtoTypeTracker(MDTracker): 15 | r""" EKF SLAM ego tracker + PMBM tracker """ 16 | def __init__(self, ego_tracker, *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | self.ego_tracker = ego_tracker 19 | 20 | def ego_motion_update(self, meas_data): 21 | # covert km/h and degree/s to m/s and rad/s 22 | vel = meas_data.ego_info['Velocity'] / 3.6 23 | yaw_rate = meas_data.ego_info['YawRate'] * -np.pi / 180 24 | meas_data.ego_meas = np.array([vel, yaw_rate])[:, None] 25 | 26 | # ego predict 27 | self.ego_tracker.predict(meas_data) 28 | 29 | # ego birth 30 | self.ego_tracker.birth() 31 | 32 | # ego update 33 | 
self.ego_tracker.update(meas_data) 34 | 35 | # ego reduct 36 | self.ego_tracker.reduction() 37 | 38 | ego_pose = self.ego_tracker.state.mean[:3, 0] # [ego_x, ego_y, theta] 39 | landmarks = self.ego_tracker.state.mean[3:, 0].reshape(-1, 2).T # landmarks [x1, y1, ..., xn, yn] 40 | state_cov = self.ego_tracker.state.cov 41 | return ego_pose, landmarks, state_cov 42 | 43 | def filtering(self, data_reader, cut_latency_time=10): 44 | # tracks multiple objects using Poisson multi-Bernoulli mixture filter 45 | estimates = [] 46 | ego_trace = [] 47 | landmark_trace = [] 48 | cov_trace = [] 49 | for self.step, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)): 50 | # ego car EKF-SLAM 51 | ego_pose, landmarks, state_cov = self.ego_motion_update(meas_data) 52 | ego_trace.append(ego_pose) 53 | landmark_trace.append(landmarks) 54 | cov_trace.append(state_cov) 55 | 56 | # ego car motion compensation 57 | self.ego_motion_compensation(meas_data) 58 | 59 | # PMBM prediction 60 | self.predict(meas_data) 61 | 62 | # PMBM birth 63 | self.birth(meas_data) 64 | 65 | # PMBM update 66 | self.update(meas_data) 67 | 68 | # extract state estimates from PMBM 69 | estimates.append(self.estimate()) 70 | 71 | # Bern recycling & reduction and PPP reduction 72 | self.reduction(cut_latency_time) 73 | 74 | return estimates, np.array(ego_trace), landmark_trace, cov_trace -------------------------------------------------------------------------------- /statecircle/trackers/sot/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | 6 | @author: zhxm 7 | """ -------------------------------------------------------------------------------- /statecircle/trackers/sot/ego_tracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Sun May 24 18:31:19 2020 6 | 7 | @author: 
zhaoxm 8 | """ 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from ..base import SingleObjectTracker 13 | from statecircle.types.state import GaussianState 14 | 15 | 16 | class EgoTracker(SingleObjectTracker): 17 | r""" ego tracker 18 | 19 | The assumed density is Gaussian state 20 | """ 21 | def __init__(self, meas_models_dict, clutter_models_dict, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | # cyclic state in sot 24 | self.state = None 25 | self.birth_initialized = False 26 | self.meas_models_dict = meas_models_dict 27 | self.clutter_models_dict = clutter_models_dict 28 | 29 | def update(self, meas_data): 30 | # TODO: reformat the update input parameters, hard coding for now... 31 | meas = meas_data.ego_meas 32 | 33 | detection_rate = self.clutter_model.detection_rate 34 | intensity_clutter = self.clutter_model.intensity_clutter 35 | 36 | meas_ingate, meas_index = self.gate.gating(self.state, meas, self.measurement_model) 37 | pred_loglik = self.density_model.predicted_log_likelihood(self.state, meas_ingate, self.measurement_model) 38 | 39 | num_meas_ingate = meas_ingate.shape[1] 40 | log_weights_unnorm = np.empty(num_meas_ingate + 1) 41 | log_weights_unnorm[0] = np.log(1 - detection_rate) 42 | log_weights_unnorm[1:] = np.log(detection_rate) + pred_loglik - np.log(intensity_clutter) 43 | 44 | theta_max = np.argmax(log_weights_unnorm) 45 | if theta_max == 0: 46 | # missed, not update state 47 | pass 48 | else: 49 | # detected 50 | self.state = self.density_model.update(self.state, meas_ingate[:, theta_max - 1], self.measurement_model) 51 | 52 | def estimate(self): 53 | return self.estimator(self.state) 54 | 55 | def predict(self, meas_data): 56 | # TODO: implenment time_step using datetime 57 | timestamp = meas_data.timestamp 58 | if self.timestamp is None: 59 | self.timestamp = timestamp 60 | return 61 | self.timestamp, time_step = timestamp, timestamp - self.timestamp 62 | 63 | self.state = self.density_model.predict(self.state, 
time_step, self.transition_model) 64 | 65 | def birth(self, meas_data): 66 | if not self.birth_initialized: 67 | # initialize birth state 68 | vel, yaw_rate = meas_data.ego_meas[0, 0], meas_data.ego_meas[1, 0] 69 | self.state = GaussianState(mean=np.array([[0, 0, vel, 0, yaw_rate]]).T, cov=np.zeros((5,5))) 70 | self.birth_initialized = True 71 | # TODO: add radar meas 72 | # use UNKNOWN measurement model first 73 | self.measurement_model = self.meas_models_dict[-1] 74 | self.clutter_model = self.clutter_models_dict[-1] 75 | 76 | def filtering(self, data_reader): 77 | estimates = [] 78 | for t, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)): 79 | # covert km/h and degree/s to m/s and rad/s 80 | vel = meas_data.ego_info['Velocity'] / 3.6 81 | yaw_rate = meas_data.ego_info['YawRate'] * -np.pi / 180 82 | meas_data.ego_meas = np.array([vel, yaw_rate])[:, None] 83 | 84 | # predict 85 | self.predict(meas_data) 86 | 87 | # birth 88 | self.birth(meas_data) 89 | 90 | # gating & update 91 | self.update(meas_data) 92 | 93 | # estimate 94 | estimates.append(self.estimate()) 95 | 96 | return estimates 97 | -------------------------------------------------------------------------------- /statecircle/trackers/sot/gaussian_sum_tracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Wed Dec 18 13:58:56 2019 6 | 7 | @author: zhxm 8 | """ 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from ..base import SingleObjectTracker 13 | from statecircle.utils.common import normalize_log_weights 14 | 15 | 16 | class GaussianSumTracker(SingleObjectTracker): 17 | r"""Gaussian sum filter for single object 18 | 19 | In fact, the assumed density in `GaussianSumTracker` is a Gaussian mixture state, 20 | not a Gaussian sum intensity. 
21 | """ 22 | 23 | def __init__(self, *args, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | # TODO: use specified hypothesis type 26 | self.hypo_tree = [] 27 | self.log_weights = np.empty([0]) 28 | self.birth_initialized = False 29 | 30 | def predict(self, meas_data): 31 | # TODO: implenment time_step using datetime 32 | timestamp = meas_data.timestamp 33 | if self.timestamp is None: 34 | self.timestamp = timestamp 35 | return 36 | self.timestamp, time_step = timestamp, timestamp - self.timestamp 37 | 38 | # for each hypothesis, perform prediction 39 | self.hypo_tree = [self.density_model.predict(hypo, time_step, self.transition_model) 40 | for hypo in self.hypo_tree] 41 | 42 | def birth(self, meas_data): 43 | if not self.birth_initialized: 44 | self.hypo_tree = [self.birth_model.birth()] 45 | self.log_weights = np.zeros(1) 46 | self.birth_initialized = True 47 | 48 | def update(self, meas_data): 49 | meas = meas_data.meas 50 | detection_rate = self.clutter_model.detection_rate 51 | intensity_clutter = self.clutter_model.intensity_clutter 52 | log_weights, hypo_tree = [], [] 53 | for i, hypo in enumerate(self.hypo_tree): 54 | meas_ingate, meas_index = self.gate.gating(hypo, meas, self.measurement_model) 55 | num_meas_ingate = meas_ingate.shape[1] 56 | 57 | # create missed detection hypothesis for each hypothesis 58 | hypo_tree.append(hypo) 59 | log_weights.append(self.log_weights[i] + np.log(1 - detection_rate)) 60 | 61 | if num_meas_ingate > 0: 62 | # create object detection hypothesis for each detection inside the gate 63 | pred_loglik = self.density_model.predicted_log_likelihood(hypo, meas_ingate, self.measurement_model) 64 | log_weights.append(self.log_weights[i] + np.log(detection_rate) + pred_loglik - np.log(intensity_clutter)) 65 | for k in range(num_meas_ingate): 66 | hypo_tree.append(self.density_model.update(hypo, meas_ingate[:, k], self.measurement_model)) 67 | log_weights = np.hstack(log_weights) 68 | 69 | # normalize hypothesis weights 70 | 
log_weights, log_suw_weihts = normalize_log_weights(log_weights) 71 | 72 | # return 73 | self.log_weights = log_weights 74 | self.hypo_tree = hypo_tree 75 | 76 | def estimate(self): 77 | # extract object state estimate using the most probably hypothesis estimation 78 | max_idx = np.argmax(self.log_weights) 79 | return self.estimator(self.hypo_tree[max_idx]) 80 | 81 | def reduction(self): 82 | # prune hypothesis with samll weights, and then re-normalize the weights 83 | log_weights, hypo_tree = self.reductor.prune(self.log_weights, self.hypo_tree) 84 | log_weights, log_sum_w = normalize_log_weights(log_weights) 85 | 86 | # hypothesis merging 87 | log_weights, hypo_tree = self.reductor.merge(log_weights, hypo_tree) 88 | 89 | # cap the number of the hypothesis, and then re-normalize the weights 90 | log_weights, hypo_tree = self.reductor.cap(log_weights, hypo_tree) 91 | log_weights, log_suw_weihts = normalize_log_weights(log_weights) 92 | 93 | # return 94 | self.log_weights = log_weights 95 | self.hypo_tree = hypo_tree 96 | 97 | def filtering(self, data_reader): 98 | estimates = [] 99 | for t, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)): 100 | #predict 101 | self.predict(meas_data) 102 | 103 | # birth 104 | self.birth(meas_data) 105 | 106 | # update 107 | self.update(meas_data) 108 | 109 | # estimate 110 | estimates.append(self.estimate()) 111 | 112 | # reduction 113 | self.reduction() 114 | 115 | return estimates 116 | -------------------------------------------------------------------------------- /statecircle/trackers/sot/nearest_neighbour_tracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Thu Dec 5 15:07:05 2019 6 | 7 | @author: zhxm 8 | """ 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | from ..base import SingleObjectTracker 13 | 14 | 15 | class NearestNeighbourTracker(SingleObjectTracker): 16 | r""" 
Nearest neighbour tracker 17 | 18 | The assumed density is Gaussian state 19 | """ 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | # cyclic state in sot 23 | self.state = None 24 | self.birth_initialized = False 25 | 26 | def update(self, meas_data): 27 | meas = meas_data.meas 28 | detection_rate = self.clutter_model.detection_rate 29 | intensity_clutter = self.clutter_model.intensity_clutter 30 | 31 | meas_ingate, meas_index = self.gate.gating(self.state, meas, self.measurement_model) 32 | pred_loglik = self.density_model.predicted_log_likelihood(self.state, meas_ingate, self.measurement_model) 33 | 34 | num_meas_ingate = meas_ingate.shape[1] 35 | log_weights_unnorm = np.empty(num_meas_ingate + 1) 36 | log_weights_unnorm[0] = np.log(1 - detection_rate) 37 | log_weights_unnorm[1:] = np.log(detection_rate) + pred_loglik - np.log(intensity_clutter) 38 | 39 | theta_max = np.argmax(log_weights_unnorm) 40 | if theta_max == 0: 41 | # missed, not update state 42 | pass 43 | else: 44 | # detected 45 | self.state = self.density_model.update(self.state, meas_ingate[:, theta_max - 1], self.measurement_model) 46 | 47 | def estimate(self): 48 | return self.estimator(self.state) 49 | 50 | def predict(self, meas_data): 51 | # TODO: implenment time_step using datetime 52 | timestamp = meas_data.timestamp 53 | if self.timestamp is None: 54 | self.timestamp = timestamp 55 | return 56 | self.timestamp, time_step = timestamp, timestamp - self.timestamp 57 | 58 | self.state = self.density_model.predict(self.state, time_step, self.transition_model) 59 | 60 | def birth(self, meas_data): 61 | if not self.birth_initialized: 62 | # initialize birth state 63 | self.state = self.birth_model.birth() 64 | self.birth_initialized = True 65 | 66 | def filtering(self, data_reader): 67 | estimates = [] 68 | for t, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)): 69 | # predict 70 | self.predict(meas_data) 71 | 72 | # birth 73 | 
self.birth(meas_data) 74 | 75 | # gating & update 76 | self.update(meas_data) 77 | 78 | # estimate 79 | estimates.append(self.estimate()) 80 | 81 | return estimates 82 | -------------------------------------------------------------------------------- /statecircle/trackers/sot/probabilistic_data_association_tracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Wed Dec 18 11:01:58 2019 6 | 7 | @author: zhxm 8 | """ 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from ..base import SingleObjectTracker 13 | from statecircle.utils.common import normalize_log_weights 14 | 15 | 16 | class ProbabilisticDataAssociationTracker(SingleObjectTracker): 17 | r""" Single object tracker using probabilistic data association 18 | 19 | The assumed density is Gaussian state 20 | """ 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | # cyclic state in sot 24 | self.state = None 25 | self.birth_initialized = False 26 | 27 | def predict(self, meas_data): 28 | # TODO: implenment time_step using datetime 29 | timestamp = meas_data.timestamp 30 | if self.timestamp is None: 31 | self.timestamp = timestamp 32 | return 33 | self.timestamp, time_step = timestamp, timestamp - self.timestamp 34 | 35 | self.state = self.density_model.predict(self.state, time_step, self.transition_model) 36 | 37 | def birth(self, meas_data): 38 | if not self.birth_initialized: 39 | self.state = self.birth_model.birth() 40 | self.birth_initialized = True 41 | 42 | def update(self, meas_data): 43 | r""" 44 | Parameters 45 | ---------- 46 | state : `State` 47 | meas : Array, [meas_dim, 1] 48 | """ 49 | meas = meas_data.meas 50 | detection_rate = self.clutter_model.detection_rate 51 | intensity_clutter = self.clutter_model.intensity_clutter 52 | 53 | meas_ingate, meas_index = self.gate.gating(self.state, meas, self.measurement_model) 54 | pred_loglik = 
self.density_model.predicted_log_likelihood(self.state, meas_ingate, self.measurement_model) 55 | 56 | num_meas_ingate = meas_ingate.shape[1] 57 | log_weights_unnorm = np.empty(num_meas_ingate + 1) 58 | 59 | # generate hypothesis tree 60 | hypo_tree = [] 61 | 62 | # missed detection hypothesis 63 | log_weights_unnorm[0] = np.log(1 - detection_rate) 64 | hypo_tree.append(self.state) 65 | 66 | # object detection hypothesis 67 | log_weights_unnorm[1:] = np.log(detection_rate) + pred_loglik - np.log(intensity_clutter) 68 | for i in range(1, num_meas_ingate + 1): 69 | hypo_tree.append(self.density_model.update(self.state, meas_ingate[:, i - 1], self.measurement_model)) 70 | 71 | # normalize hypothesis weights 72 | log_weights, log_sum_weights = normalize_log_weights(log_weights_unnorm) 73 | 74 | # prune hypothesis 75 | log_weights, hypo_tree = self.reductor.prune(log_weights, hypo_tree) 76 | 77 | # re-normalize 78 | log_weights, log_sum_w = normalize_log_weights(log_weights) 79 | 80 | # moment matching 81 | self.state = self.reductor.moment_matching(log_weights, hypo_tree) 82 | 83 | def estimate(self): 84 | return self.estimator(self.state) 85 | 86 | def filtering(self, data_reader): 87 | estimates = [] 88 | for t, meas_data in tqdm(enumerate(data_reader), total=len(data_reader)): 89 | # predict 90 | self.predict(meas_data) 91 | 92 | # birth 93 | self.birth(meas_data) 94 | 95 | # update 96 | self.update(meas_data) 97 | 98 | # estimate 99 | estimates.append(self.estimate()) 100 | 101 | # reduction 102 | self.reduction() 103 | 104 | return estimates 105 | -------------------------------------------------------------------------------- /statecircle/types/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /statecircle/types/base.py: -------------------------------------------------------------------------------- 1 | 
class SigmaPoints(Particles):
    r""" Sigma Points
    Sigma points used in the unscented Kalman filter.

    NOTE: number of sigma points = 2 * state_dim + 1

    Attributes
    ----------
    points : Matrix(state_dim, num)
        Sigma points stored as columns
    mean_weights : Vector(num)
        First-order (mean) weights
    cov_weights : Vector(num)
        Second-order (covariance) weights
    """

    # NOTE: the parameter spelling `mean_weighs` is kept for backward
    # compatibility with existing keyword callers.
    def __init__(self, points, mean_weighs, cov_weights):
        self.points = points
        self.mean_weights = mean_weighs
        self.cov_weights = cov_weights

    @property
    def mean(self):
        # weighted mean over the columns: (dim, num) @ (num,) -> (dim,)
        return self.points.dot(self.mean_weights)

    @property
    def cov(self):
        # Fix: the original `self.points - self.mean` broadcast (dim, num)
        # against (dim,) — a shape error whenever num != dim — and
        # `res.dot(cov_weights).dot(res.T)` had incompatible shapes.
        # The weighted covariance is sum_i w_i * r_i r_i^T, i.e.
        # (res * w) @ res.T with res holding residual columns.
        res = self.points - self.mean[:, None]
        return (res * self.cov_weights).dot(res.T)
class DataSeries(Type):
    r"""Time-indexed series of data frames.

    Attributes
    ----------
    timestamps : list
        One timestamp slot per frame, initialized to None.
    datum : list of lists
        Per-frame data containers, initially empty.
    num : int ndarray
        Per-frame item counts, initialized to zero.
    """

    def __init__(self, time_len):
        self.timestamps = [None] * time_len
        self.datum = [[] for _ in range(time_len)]
        self.num = np.zeros(time_len, dtype=np.int_)

    def __getitem__(self, idx):
        # NOTE(review): always wraps frames in GroundTruthData, even for the
        # MeasurementSeries subclass — confirm that is intended.
        return GroundTruthData(self.timestamps[idx], self.datum[idx])
def rotate2D(x, theta, origin=(0, 0)):
    r"""Rotate 2-D points counter-clockwise by ``theta`` degrees about ``origin``.

    Parameters
    ----------
    x : array[2, #points]
        Points stored as columns. An empty sequence is returned unchanged.
    theta : float
        Rotation angle in **degrees** (converted to radians internally).
    origin : 2-element sequence, default (0, 0)
        Center of rotation. (Fix: the default was a mutable list.)

    Returns
    -------
    array[2, #points]
        The rotated points.
    """
    if len(x) == 0:
        return x
    # convert to radians without mutating the caller-visible parameter name
    rad = np.deg2rad(theta)
    pivot = np.atleast_2d(origin).T
    c, s = np.cos(rad), np.sin(rad)
    rot_mat = np.array([[c, -s], [s, c]])
    return rot_mat.dot(x - pivot) + pivot
def normalize_log_weights(log_w):
    """Normalize weights given in log scale so they sum to one.

    Parameters
    ----------
    log_w : list or 1-D array
        Unnormalized log weights.

    Returns
    -------
    log_w : array
        Normalized log weights (``logsumexp(log_w) == 0`` in the regular
        case).
    log_sum_w : float
        Log of the sum of the unnormalized weights.

    Notes
    -----
    The input is never modified: computations run on a float copy.  (The
    previous implementation wrote back into a caller-supplied ndarray.)
    """
    # Always copy; also upcasts int inputs so the subtraction below works.
    log_w = np.array(log_w, dtype=float)
    if len(log_w) == 0:
        log_sum_w = np.empty([0])
        log_w = np.empty([0])
    elif len(log_w) == 1:
        log_sum_w = log_w[0]
        log_w = np.array([0.])
    else:
        if np.max(log_w) == np.inf:
            # corner case 1: +inf weights dominate; they share the mass
            # uniformly and every finite weight vanishes
            inf_idx = log_w == np.inf
            log_w[inf_idx] = -np.log(np.sum(inf_idx))
            log_w[~inf_idx] = -np.inf
            log_sum_w = np.inf
        elif np.all(log_w == -np.inf):
            # corner case 2: all weights are zero; fall back to uniform.
            # NOTE(review): the true logsumexp here is -inf, not len(log_w);
            # kept as-is to preserve the original behavior — confirm intent.
            log_sum_w = float(len(log_w))
            log_w = np.zeros_like(log_w) - np.log(len(log_w))
        else:
            # standard log-sum-exp with the max factored out for stability
            idx = np.argsort(-log_w)
            log_w_max = log_w[idx[0]]
            log_sum_w = log_w_max + np.log(1 + np.sum(np.exp(log_w[idx[1:]] - log_w_max)))
            log_w = log_w - log_sum_w

    return log_w, log_sum_w
def tuple2int(x, y):
    """Round ``(x, y)`` to a tuple of integer scalars.

    Uses numpy rounding (ties-to-even, i.e. banker's rounding), e.g. for
    turning float coordinates into pixel indices.

    Fix: the original cast with ``np.int``, an alias deprecated in NumPy
    1.20 and removed in 1.24; ``astype(int)`` is the portable equivalent.
    """
    return tuple(np.round(ele).astype(int) for ele in [x, y])
def data_association(cost_mat, topN):
    """Draw up to ``topN`` cheapest assignments from a cost matrix using
    Murty's algorithm.

    Parameters
    ----------
    cost_mat : array[#rows, #cols]
        Assignment cost matrix; may contain ``inf`` for forbidden pairs.
        The caller's matrix is NOT modified (it previously was).
    topN : int
        Maximum number of assignment hypotheses to draw.

    Returns
    -------
    col_idx : array[#rows, #hypotheses]
        Column assignment for each row, one column per hypothesis.
    cost : array[#hypotheses]
        Total cost of each hypothesis, in increasing order.

    Raises
    ------
    AssertionError
        If the solver yields no valid assignment.
    """
    # Work on a float copy: the workarounds below would otherwise corrupt
    # the caller's matrix in place.
    cost_mat = np.array(cost_mat, dtype=float)
    # TODO: temp workaround — the murty cpp solver mishandles inf entries
    cost_mat[np.isinf(cost_mat)] = 10000
    # TODO: workaround for another murty bug triggered when two rows are
    # identical; the tiny jitter breaks ties.  NOTE(review): uses the
    # global numpy RNG, so results depend on the ambient seed.
    cost_mat += 1e-10 * np.random.randn(*cost_mat.shape)
    op = Murty(cost_mat)
    status = True
    cost = []
    col_idx = []
    num = 0
    while status and num < topN:
        status, cost_iter, col_iter = op.draw()
        if status:
            cost.append(cost_iter)
            col_idx.append(col_iter)
            num += 1
    assert len(cost) > 0, "invalid cost matrix."
    cost = np.array(cost)
    col_idx = np.stack(col_idx, -1)
    return col_idx, cost
/statecircle/wiki/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxiaomzxm/statecircle-python/956067e7c7ec0c1029c200256bc4b6fe5e40c551/statecircle/wiki/framework.png -------------------------------------------------------------------------------- /testcase/test_mhtracker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | 5 | Created on Sat Sep 12 19:51:11 2020 6 | 7 | @author: zhxm 8 | """ 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from scenarios.linear import LinearScenario 13 | from scenarios.nonlinear import NonlinearScenario 14 | from statecircle.estimator.base import EAPEstimator 15 | from statecircle.models.density.kalman_accumulated import KalmanAccumulatedDensityModel 16 | from statecircle.models.measurement.nonlinear import RangeBearningMeasurementModel 17 | from statecircle.models.transition.nonlinear import SimpleCTRVModel 18 | from statecircle.reader.base import MeasurementReader 19 | from statecircle.reductor.gate import EllipsoidalGate 20 | from statecircle.models.sensor.base import DummySensorModel 21 | from statecircle.models.measurement.clutter import PoissonClutterModel 22 | from statecircle.models.measurement.linear import LinearMeasurementModel 23 | from statecircle.datasets.base import SimulatedGroundTruthDataGenerator 24 | from statecircle.models.transition.linear import ConstantVelocityModel 25 | from statecircle.reductor.hypothesis_reductor import HypothesisReductor 26 | from statecircle.trackers.mot.pmbm_tracker import PMBMTracker 27 | from statecircle.trackers.mot.meas_driven_pmbm_tracker import MeasDrivenPMBMTracker 28 | 29 | PMBMTracker = MeasDrivenPMBMTracker # [PMBMTracker | MeasDrivenPMBMTracker] 30 | seed = 666 31 | # build scene 32 | scene = LinearScenario.caseF(birth_weight=0.0005, birth_cov_scale=400) 33 | birth_model = 
scene.birth_model 34 | 35 | # build transition/measurement/clutter/birth models 36 | transition_model = ConstantVelocityModel(sigma=5) 37 | measurement_model = LinearMeasurementModel(mapping=[1, 1, 0, 0], 38 | sigma=10) 39 | clutter_model = PoissonClutterModel(detection_rate=1.0, 40 | lambda_clutter=20, 41 | scope=[[-1000, 1000], [-1000, 1000]]) 42 | 43 | # build data generator 44 | data_generator = SimulatedGroundTruthDataGenerator(scene, transition_model, noisy=False) 45 | 46 | # build sensor model 47 | no_clutter_model = PoissonClutterModel(detection_rate=1.0, 48 | lambda_clutter=0, 49 | scope=[[-1000, 1000], [-1000, 1000]]) 50 | 51 | sensor_model = DummySensorModel(no_clutter_model, measurement_model, random_seed=seed, noisy=True) 52 | 53 | # build data reader 54 | data_reader = MeasurementReader(data_generator, sensor_model) 55 | 56 | # build density model 57 | #density_model = KalmanDensityModel() 58 | density_model = KalmanAccumulatedDensityModel(traceback_range=1) 59 | 60 | # gate method 61 | gate = EllipsoidalGate(percentile=0.999) 62 | 63 | # estimator 64 | estimator = EAPEstimator() 65 | 66 | # reductor 67 | reductor = HypothesisReductor(weight_min=0.01, merging_threshold=4, capping_num=1) 68 | 69 | # %% build trackers & filtering 70 | # some extra parameters 71 | # TODO: reformat the input parameters 72 | prior_birth = False 73 | surviving_rate = 0.99 74 | recycle_threshold = 0.1 75 | prob_min = 0.01 76 | prob_estimate = 0.5 77 | meas_models_dict = None 78 | clutter_models_dict = None 79 | 80 | pmbm_filter = PMBMTracker(prior_birth, 81 | surviving_rate, 82 | recycle_threshold, 83 | prob_min, 84 | prob_estimate, 85 | meas_models_dict, 86 | clutter_models_dict, 87 | birth_model, 88 | density_model, 89 | transition_model, 90 | measurement_model, 91 | clutter_model, 92 | gate, 93 | estimator, 94 | reductor) 95 | 96 | pmbm_estimates = pmbm_filter.filtering(data_reader) 97 | 98 | #%% ploting 99 | animation = False 100 | show_birth = True 101 | 102 | gt_datum 
= np.concatenate(data_generator.gt_series.datum, axis=-1) 103 | true_state = np.concatenate([ele.states for ele in gt_datum], -1) 104 | for k, (_, obj_meas_data, clutter_data) in enumerate(data_reader.truth_meas_generator()): 105 | if not animation: 106 | k = scene.time_range[-1] - 1 107 | PMBM_estimated_state = np.hstack(pmbm_estimates[k]).squeeze() 108 | 109 | fig, ax = plt.subplots(1, 1, figsize=(6, 6)) 110 | plot_gt = ax.plot(true_state[0], true_state[1], 'yo', alpha=0.2, markersize=10) 111 | 112 | ax.grid() 113 | ax.set_xlabel('x (m)') 114 | ax.set_ylabel('y (m)') 115 | 116 | # # plot birth region 117 | # if prior_birth and show_birth: 118 | # for birth_state in birth_model: 119 | # plot_birth = plot_covariance_ellipse(birth_state.x[:2], birth_state.P[:2,:2], 'b', ax, 3) 120 | # else: 121 | # plot_birth = None 122 | 123 | # plot tracks 124 | for track in pmbm_estimates[k]: 125 | range_t = track['range'][-1] - track['range'][0] + 1 126 | # if range_t > -1: 127 | plt.plot(track['trajectory'][0], track['trajectory'][1], '-') 128 | 129 | # plot obejct measurements 130 | state_meas = measurement_model.reverse(obj_meas_data) 131 | plot_meas = ax.plot(state_meas[0], state_meas[1], 'r*', alpha=1) 132 | # plot clutter 133 | state_clutter = measurement_model.reverse(clutter_data) 134 | plot_clutter = ax.plot(state_clutter[0], state_clutter[1], 'k.', alpha=1) 135 | 136 | ax.legend((plot_gt[0], plot_meas[0], plot_clutter[0]), 137 | ['ground truth', 'detections', 'clutter'], 138 | loc='upper left') 139 | 140 | ax.set_xlim([-1000, 1000]) 141 | ax.set_ylim([-1000, 1000]) 142 | # plt.axis('equal') 143 | # plt.savefig('snapshot/results/track_{:04d}.png'.format(k)) 144 | 145 | plt.show() 146 | print('step: {}'.format(k)) 147 | plt.close('all') 148 | 149 | if not animation: 150 | break 151 | 152 | #%% plot cardinality 153 | plt.figure() 154 | plt.plot(data_generator.gt_series.num, 'yo') 155 | #PMBM_card_pred = [for ele in track for step, track in enumerate(PMBMEstimates)] 156 | 
def mkvideo(im_seq_path, video_file, image_size, fps=12, use_cv=True):
    """Encode an on-disk image sequence into a video file.

    Parameters
    ----------
    im_seq_path : iterable of str or Path
        Frame paths in playback order.
    video_file : str
        Output video path (e.g. ``*.avi`` for the XVID codec).
    image_size : tuple
        ``(width, height)`` expected by the writer.
    fps : int
        Frames per second.
    use_cv : bool
        If True, use OpenCV's ``VideoWriter`` (XVID, BGR frames);
        otherwise moviepy's FFMPEG writer (RGB frames).

    Raises
    ------
    ValueError
        If a frame file cannot be read.
    """
    if use_cv:
        writer = cv2.VideoWriter(video_file, cv2.VideoWriter_fourcc(*'XVID'), fps, image_size)
    else:
        writer = FFMPEG_VideoWriter(video_file, size=image_size, fps=fps)
    try:
        for im_path in tqdm(im_seq_path):
            im = cv2.imread(str(im_path))
            if im is None:
                # fail fast with the offending path instead of letting the
                # encoder crash opaquely on a None frame
                raise ValueError("could not read frame: {}".format(im_path))
            if use_cv:
                writer.write(im)  # OpenCV expects BGR
            else:
                writer.write_frame(im[:, :, ::-1])  # moviepy expects RGB
    finally:
        # release the encoder even if a frame fails mid-sequence
        # (the original leaked the writer on any exception)
        if use_cv:
            cv2.destroyAllWindows()
            writer.release()
        else:
            writer.close()
def show_measurements(self, data_reader):
    """Scatter-plot object detections (red) and clutter (black).

    Parameters
    ----------
    data_reader : reader
        Must provide ``truth_meas_generator()`` yielding
        ``(meas_data, obj_meas, clutter_meas)`` triples per time step,
        where the measurement arrays are stacked column-wise.
    """
    # Fix: the original also collected `meas_data.meas` into a local that
    # was hstacked and then never used — dead work, removed.
    obj_meas, clutter_meas = [], []
    for _, obj_meas_, clutter_meas_ in data_reader.truth_meas_generator():
        obj_meas.append(obj_meas_)
        clutter_meas.append(clutter_meas_)
    obj_meas, clutter_meas = np.hstack(obj_meas), np.hstack(clutter_meas)

    plt.figure()
    plt.plot(obj_meas[0], obj_meas[1], 'r.', alpha=0.5)

    # plot clutter
    plt.plot(clutter_meas[0], clutter_meas[1], 'k.', alpha=0.2)
    plt.legend(['measurements', 'clutter'])
    plt.show()
    plt.close('all')