├── LICENSE ├── README.md ├── dependencies.txt ├── scripts ├── run_full_pipeline.sh ├── run_preprocess_data.sh └── train_and_test_driver_sensor_model.sh └── src ├── driver_sensor_model ├── inference_cvae.py ├── inference_gmm.py ├── inference_kmeans.py ├── models_cvae.py ├── train_cvae.py ├── train_gmm.py ├── train_kmeans.py └── visualize_cvae.ipynb ├── full_pipeline ├── full_pipeline_metrics.py ├── main_save_full_pipeline.py └── main_visualize_full_pipeline.py ├── preprocess ├── generate_data.py ├── get_driver_sensor_data.py ├── preprocess_driver_sensor_data.py ├── statistics.txt └── train_val_test_split.py └── utils ├── combinations.py ├── data_generator.py ├── dataset_reader.py ├── dataset_types.py ├── grid_fuse.py ├── grid_utils.py ├── interaction_utils.py ├── map_vis_without_lanelet.py ├── tracks_save.py ├── tracks_vis.py └── utils_model.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiAgentVariationalOcclusionInference 2 | Multi-agent occlusion inference using observed driver behaviors. A driver sensor model is learned using a conditional variational autoencoder which maps an observed driver trajectory to the space ahead of the driver, represented as an occupancy grid map (OGM). Information from multiple drivers is fused into an ego vehicle's map using evidential theory. 
See our [video](https://www.youtube.com/watch?v=cTHl5nDBNBM) and [paper](https://arxiv.org/abs/2109.02173) for more details: 3 | 4 | M. Itkina, Y.-J. Mun, K. Driggs-Campbell, and M. J. Kochenderfer. "Multi-Agent Variational Occlusion Inference Using People as Sensors". In International Conference on Robotics and Automation (ICRA), 2022. 5 | 6 |

7 | 8 |

9 | 10 | **Approach Overview:** Our proposed occlusion inference approach. The learned driver sensor model maps behaviors of visible drivers (cyan) to an OGM of the environment ahead of them (gray). The inferred OGMs are then fused into the ego vehicle’s (green) map. Our goal is to infer the presence or absence of occluded agents (blue). Driver 1 is waiting to turn left, occluding oncoming traffic from the ego vehicle’s view. Driver 1 being stopped should indicate to the ego that there may be oncoming traffic, and thus that it is not safe to proceed with its right turn. Driver 2 is driving at a constant speed. This observed behavior is not enough to discern whether the vehicle is traveling in traffic or on an open road. We aim to encode such intuition into our occlusion inference algorithm. 11 | 12 |
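The fusion step combines the per-driver inferred OGMs cell by cell using evidential (Dempster-Shafer) theory. As a rough, self-contained illustration of that idea (not the pipeline's own code — see `src/utils/grid_fuse.py` and `src/utils/combinations.py` in this repository), the sketch below combines two per-cell mass functions over {free, occupied, unknown} with Dempster's rule. The `probs_to_masses` heuristic and the 0.8 confidence discount are assumptions made only for this example.

```python
import numpy as np

def probs_to_masses(p_occ, confidence=0.8):
    """Turn an occupancy probability map into (free, occupied, unknown) masses.

    The fixed confidence discount is a heuristic chosen for this example only.
    """
    occ = confidence * p_occ
    free = confidence * (1.0 - p_occ)
    unknown = 1.0 - occ - free
    return np.stack([free, occ, unknown], axis=-1)

def fuse_dempster(m1, m2):
    """Combine two mass grids of shape (H, W, 3) cell by cell with Dempster's rule."""
    f1, o1, u1 = m1[..., 0], m1[..., 1], m1[..., 2]
    f2, o2, u2 = m2[..., 0], m2[..., 1], m2[..., 2]
    conflict = f1 * o2 + o1 * f2                 # mass assigned to contradictory evidence
    norm = np.clip(1.0 - conflict, 1e-6, None)   # renormalization constant
    free = (f1 * f2 + f1 * u2 + u1 * f2) / norm
    occ = (o1 * o2 + o1 * u2 + u1 * o2) / norm
    unknown = (u1 * u2) / norm
    return np.stack([free, occ, unknown], axis=-1)

# Two observed drivers report on the same cells (after alignment into the ego frame):
driver_1 = probs_to_masses(np.full((20, 30), 0.9))  # a stopped driver suggests occupied space ahead
driver_2 = probs_to_masses(np.full((20, 30), 0.5))  # a constant-speed driver is uninformative
fused = fuse_dempster(driver_1, driver_2)
print(fused[0, 0])  # fused (free, occupied, unknown) masses for one cell
```

Agreeing measurements reinforce each other (as in Example Scenario 2 below), while an uninformative driver leaves the fused cell close to the other source's estimate.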

13 | 14 |

15 | 16 | **Results (Example Scenario 1):** The ego vehicle (green) is waiting to turn right. Observed driver 33 is waiting for a gap in cross-traffic to make a left turn, blocking the view of the ego vehicle. Our occlusion inference algorithm produces occupancy estimates as soon as 1 second of trajectory data is accumulated for an observed driver. While driver 33 is stopped (left), our algorithm estimates occupied space in the region ahead of driver 33, encompassing occluded driver 92. When driver 33 accelerates (right), this space is inferred to be free, indicating to the ego vehicle that it may proceed with its right turn. These results match our intuition for these observed behaviors. 17 | 18 |

19 | 20 |

21 | 22 | **Results (Example Scenario 2):** Observed driver 106 is waiting to turn left ahead of the ego vehicle (green). Observed drivers 102, 113, and 116 are also waiting for a gap in cross-traffic. While drivers 106, 102, 113, and 116 are all stopped (left), our algorithm estimates highly occupied space surrounding occluded driver 112 due to multiple agreeing measurements. When driver 106 starts their left turn (right), this space is estimated to be more free, indicating to the ego vehicle that it may be safe to complete its maneuver (e.g., a U-turn). 23 | 24 | ## Instructions 25 | The code reproduces the qualitative and quantitative experiments in the paper. The required dependencies are listed in dependencies.txt. Note that the INTERACTION dataset for the GL intersection has to be downloaded from: https://interaction-dataset.com/ and placed into the `data` directory. Then, a directory: `data/INTERACTION-Dataset-DR-v1_1` should exist. 26 | 27 | To run the experiments, please run the following scripts: 28 | 29 | - To process the data: 30 | `scripts/run_preprocess_data.sh`. 31 | 32 | - To train and test the driver sensor models (ours, k-means PaS, GMM PaS): 33 | `scripts/train_and_test_driver_sensor_model.sh`. 34 | 35 | - To run and evaluate the multi-agent occlusion inference pipeline: 36 | `scripts/run_full_pipeline.sh`. 37 | 38 | Running the above code will reproduce the numbers in the tables reported in the paper. To visualize the qualitative results, see `src/driver_sensor_model/visualize_cvae.ipynb` (Fig. 3) and `src/full_pipeline/main_visualize_full_pipeline.py` (Fig. 4). 39 | 40 | ## Data Split 41 | To process the data, ego vehicle IDs were sampled from the GL intersection in the [INTERACTION dataset](https://interaction-dataset.com/). For each of the 60 available scenes, a maximum of 100 ego vehicles were chosen. The train, validation, and test sets were randomly split based on the ego vehicle IDs, accounting for 85%, 5%, and 10% of the total number of ego vehicles, respectively. Due to computational constraints, the training set for the driver sensor model was further reduced to 70,001 contiguous driver sensor trajectories, or 2,602,332 time steps of data. The validation set used to select the driver sensor model and the sensor fusion scheme contained 4,858 contiguous driver sensor trajectories, or 180,244 time steps, and 289 ego vehicles. The results presented in this paper were reported on the test set, which consists of 9,884 contiguous driver sensor trajectories, or 365,201 time steps, and 578 ego vehicles. 42 | 43 | ## CVAE Driver Sensor Model Architecture and Training 44 | We set the number of latent classes in the CVAE to K = 100 based on computational time and tractability for the considered baselines. We standardize the trajectory data to have zero mean and unit standard deviation. The prior encoder in the model consists of an LSTM with a hidden dimension of 5 to process the 1 s of trajectory input data. A linear layer then converts the output into a K-dimensional vector that goes into a softmax function, producing the prior distribution. The posterior encoder extracts features from the ground truth input OGM using a [VQ-VAE](https://arxiv.org/abs/1711.00937) backbone with a hidden dimension of 4. These features are flattened and concatenated with the LSTM output from the trajectory data, and then passed into a linear layer and a softmax function, producing the posterior distribution.
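For orientation, here is a minimal, self-contained sketch of the two encoder heads just described, using the dimensions quoted above (an LSTM hidden size of 5, a backbone hidden dimension of 4, and K = 100 latent classes). It is an illustration only: the plain strided convolutions stand in for the VQ-VAE residual backbone, and the actual model, including the decoder, is defined in `src/driver_sensor_model/models_cvae.py`.

```python
import torch
import torch.nn as nn

K = 100           # number of discrete latent classes
T, D = 10, 7      # 1 s of trajectory input: 10 time steps x 7 state features
H = 5             # LSTM hidden dimension
C = 4             # hidden dimension of the OGM backbone

class EncoderSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(D, H, batch_first=True)
        self.prior_head = nn.Linear(T * H, K)
        # Strided convolutions standing in for the VQ-VAE backbone over the 20x30 OGM.
        self.ogm_backbone = nn.Sequential(
            nn.Conv2d(1, C, 4, stride=2, padding=1),  # -> (C, 10, 15)
            nn.ReLU(),
            nn.Conv2d(C, C, 4, stride=2, padding=1),  # -> (C, 5, 7)
        )
        self.posterior_head = nn.Linear(C * 5 * 7 + T * H, K)

    def forward(self, traj, ogm=None):
        feats, _ = self.lstm(traj)                    # (B, T, H)
        feats = feats.reshape(traj.shape[0], -1)      # (B, T*H)
        prior = torch.softmax(self.prior_head(feats), dim=-1)
        posterior = None
        if ogm is not None:                           # ground truth OGM: training time only
            g = self.ogm_backbone(ogm).flatten(1)     # (B, C*5*7)
            posterior = torch.softmax(
                self.posterior_head(torch.cat([g, feats], dim=-1)), dim=-1)
        return prior, posterior

prior, posterior = EncoderSketch()(torch.randn(2, T, D), torch.rand(2, 1, 20, 30))
print(prior.shape, posterior.shape)  # torch.Size([2, 100]) for both
```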
The decoder passes the latent encoding through two linear layers with ReLU activation functions and a transposed VQ-VAE backbone, outputting the inferred OGM. 45 | 46 | To avoid latent space collapse, we clamped the KL divergence term in the loss in at 0.2. Additionally, we anneal the beta hyperparameter in the loss according to a sigmoid schedule. In our hyperparameter search, we found a maximum beta of 1 with a crossover point at 10,000 iterations to work well. The beta hyperparameter increases from a value of 0 to 1 over 1,000 iterations. We set the alpha hyperparameter to 1.5. We trained the network with a batch size of 256 for 30 epochs using the [Adam optimizer](https://arxiv.org/abs/1412.6980) with a starting learning rate of 0.001. 47 | -------------------------------------------------------------------------------- /dependencies.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | absl-py=0.9.0=pypi_0 6 | argon2-cffi=20.1.0=py36h27cfd23_1 7 | async_generator=1.10=py36h28b3542_0 8 | attrs=19.3.0=pypi_0 9 | backcall=0.2.0=pyhd3eb1b0_0 10 | bleach=3.1.5=pypi_0 11 | ca-certificates=2021.5.25=h06a4308_1 12 | cachetools=4.1.0=pypi_0 13 | certifi=2021.5.30=py36h06a4308_0 14 | cffi=1.14.5=py36h261ae71_0 15 | chardet=3.0.4=pypi_0 16 | cvxpy=1.1.1=pypi_0 17 | cycler=0.10.0=pypi_0 18 | decorator=4.4.2=pypi_0 19 | defusedxml=0.6.0=pypi_0 20 | diffcp=1.0.13=pypi_0 21 | dill=0.3.2=pypi_0 22 | ecos=2.0.7.post1=pypi_0 23 | entrypoints=0.3=pypi_0 24 | future=0.18.2=pypi_0 25 | google-auth=1.16.0=pypi_0 26 | google-auth-oauthlib=0.4.1=pypi_0 27 | grpcio=1.29.0=pypi_0 28 | h5py=2.10.0=pypi_0 29 | hickle=4.0.0=pypi_0 30 | idna=2.9=pypi_0 31 | imageio=2.8.0=pypi_0 32 | importlib-metadata=1.6.0=pypi_0 33 | importlib_metadata=3.10.0=hd3eb1b0_0 34 | ipykernel=5.3.2=pypi_0 35 | ipython=7.16.1=py36h5ca1d4c_0 36 | ipython_genutils=0.2.0=pyhd3eb1b0_1 37 | ipywidgets=7.6.3=pyhd3eb1b0_1 38 | jedi=0.17.1=pypi_0 39 | jinja2=2.11.2=pypi_0 40 | joblib=0.16.0=pypi_0 41 | json5=0.9.5=pypi_0 42 | jsonschema=3.2.0=py_2 43 | jupyter-client=6.1.6=pypi_0 44 | jupyter-core=4.6.3=pypi_0 45 | jupyter_client=6.1.12=pyhd3eb1b0_0 46 | jupyter_core=4.7.1=py36h06a4308_0 47 | jupyterlab=2.2.0=pypi_0 48 | jupyterlab-server=1.2.0=pypi_0 49 | jupyterlab_pygments=0.1.2=py_0 50 | jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 51 | kiwisolver=1.2.0=pypi_0 52 | kmeans-pytorch=0.3=pypi_0 53 | ld_impl_linux-64=2.33.1=h53a641e_7 54 | libedit=3.1.20181209=hc058e9b_0 55 | libffi=3.3=he6710b0_1 56 | libgcc-ng=9.1.0=hdf63c60_0 57 | libsodium=1.0.18=h7b6447c_0 58 | libstdcxx-ng=9.1.0=hdf63c60_0 59 | markdown=3.2.2=pypi_0 60 | markupsafe=1.1.1=pypi_0 61 | matplotlib=3.2.1=pypi_0 62 | mistune=0.8.4=py36h7b6447c_0 63 | nbclient=0.5.3=pyhd3eb1b0_0 64 | nbconvert=5.6.1=pypi_0 65 | nbformat=5.0.7=pypi_0 66 | ncurses=6.2=he6710b0_1 67 | nest-asyncio=1.5.1=pyhd3eb1b0_0 68 | networkx=2.5=pypi_0 69 | notebook=6.0.3=pypi_0 70 | numpy=1.18.4=pypi_0 71 | oauthlib=3.1.0=pypi_0 72 | opencv-python=4.4.0.46=pypi_0 73 | openssl=1.1.1k=h27cfd23_0 74 | osqp=0.6.1=pypi_0 75 | packaging=20.4=pypi_0 76 | pandas=1.1.0=pypi_0 77 | pandoc=2.12=h06a4308_0 78 | pandocfilters=1.4.2=pypi_0 79 | parso=0.7.0=pypi_0 80 | pathlib=1.0.1=pypi_0 81 | pexpect=4.8.0=pyhd3eb1b0_3 82 | pickleshare=0.7.5=pyhd3eb1b0_1003 83 | pillow=7.1.2=pypi_0 84 | pip=20.0.2=py36_3 85 | prometheus-client=0.8.0=pypi_0 86 | 
prometheus_client=0.10.1=pyhd3eb1b0_0 87 | prompt-toolkit=3.0.5=pypi_0 88 | protobuf=3.12.2=pypi_0 89 | ptyprocess=0.6.0=pypi_0 90 | pwlf=2.0.4=pypi_0 91 | pyasn1=0.4.8=pypi_0 92 | pyasn1-modules=0.2.8=pypi_0 93 | pybind11=2.5.0=pypi_0 94 | pycparser=2.20=py_2 95 | pydoe=0.3.8=pypi_0 96 | pygments=2.6.1=pypi_0 97 | pyparsing=2.4.7=pyhd3eb1b0_0 98 | pyproj=2.6.1.post1=pypi_0 99 | pyrsistent=0.16.0=pypi_0 100 | python=3.6.10=h7579374_2 101 | python-dateutil=2.8.1=pyhd3eb1b0_0 102 | pytorch-lightning=0.7.6=pypi_0 103 | pytz=2020.1=pypi_0 104 | pywavelets=1.1.1=pypi_0 105 | pyyaml=5.3.1=pypi_0 106 | pyzmq=19.0.1=pypi_0 107 | readline=8.0=h7b6447c_0 108 | requests=2.23.0=pypi_0 109 | requests-oauthlib=1.3.0=pypi_0 110 | rsa=4.0=pypi_0 111 | scikit-image=0.17.2=pypi_0 112 | scikit-learn=0.23.2=pypi_0 113 | scipy=1.4.1=pypi_0 114 | scs=2.1.2=pypi_0 115 | seaborn=0.10.1=pypi_0 116 | send2trash=1.5.0=pyhd3eb1b0_1 117 | setuptools=46.4.0=py36_0 118 | six=1.15.0=py36h06a4308_0 119 | sklearn=0.0=pypi_0 120 | sqlite=3.31.1=h62c20be_1 121 | tensorboard=2.4.1=pypi_0 122 | tensorboard-plugin-wit=1.8.0=pypi_0 123 | terminado=0.8.3=pypi_0 124 | testpath=0.4.4=pyhd3eb1b0_0 125 | threadpoolctl=2.1.0=pypi_0 126 | tifffile=2020.9.3=pypi_0 127 | tikzplotlib=0.9.8=pypi_0 128 | tk=8.6.8=hbc83047_0 129 | torch=1.2.0=pypi_0 130 | torchvision=0.4.0=pypi_0 131 | tornado=6.0.4=pypi_0 132 | tqdm=4.46.0=pypi_0 133 | traitlets=4.3.3=py36_0 134 | typing_extensions=3.7.4.3=pyha847dfd_0 135 | urllib3=1.25.9=pypi_0 136 | wcwidth=0.2.5=py_0 137 | webencodings=0.5.1=pypi_0 138 | werkzeug=1.0.1=pypi_0 139 | wheel=0.34.2=py36_0 140 | widgetsnbextension=3.5.1=py36_0 141 | xz=5.2.5=h7b6447c_0 142 | zeromq=4.3.4=h2531618_0 143 | zipp=3.1.0=pypi_0 144 | zlib=1.2.11=h7b6447c_3 145 | -------------------------------------------------------------------------------- /scripts/run_full_pipeline.sh: -------------------------------------------------------------------------------- 1 | cd ../src/full_pipeline 2 | python main_save_full_pipeline.py --model=vae --mode=evidential 3 | python main_save_full_pipeline.py --model=gmm --mode=evidential 4 | python main_save_full_pipeline.py --model=kmeans --mode=evidential 5 | python main_save_full_pipeline.py --model=vae --mode=average 6 | 7 | # Change the model and fusion mode in the file to obtain metrics. 8 | python full_pipeline_metrics.py 9 | 10 | # Visualize the scenarios in the paper and in the appendix. 11 | python main_visualize_full_pipeline.py --model=vae --mode=evidential 12 | 13 | -------------------------------------------------------------------------------- /scripts/run_preprocess_data.sh: -------------------------------------------------------------------------------- 1 | # Record the mean and standard deviation statistics at the end into statistics.txt and into src/utils/data_generator.py in three functions. 2 | cd ../src/preprocess 3 | python generate_data.py 4 | python train_val_test_split.py 5 | python get_driver_sensor_data.py 6 | python preprocess_driver_sensor_data.py 7 | -------------------------------------------------------------------------------- /scripts/train_and_test_driver_sensor_model.sh: -------------------------------------------------------------------------------- 1 | cd ../src/driver_sensor_model 2 | # Train and test our CVAE driver sensor model. 
3 | python train_cvae.py --norm --mut_info='const' --epochs=30 --learning_rate=0.001 --beta=1.0 --alpha=1.5 --latent_size=100 --batch_size=256 --crossover=10000 4 | python inference_cvae.py --norm --mut_info='const' --epochs=30 --learning_rate=0.001 --beta=1.0 --alpha=1.5 --latent_size=100 --batch_size=256 --crossover=10000 5 | 6 | # Train and test the k-means PaS baseline driver sensor model. 7 | python train_kmeans.py 8 | python inference_kmeans.py 9 | 10 | # Train and test the GMM PaS baseline driver sensor model. 11 | python train_gmm.py 12 | python inference_gmm.py 13 | 14 | # To view the qualitative driver sensor model results, please see the Jupyter notebook: src/driver_sensor_model/visualize_cvae.ipynb. 15 | -------------------------------------------------------------------------------- /src/driver_sensor_model/inference_gmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(123) 3 | torch.backends.cudnn.deterministic = True 4 | torch.backends.cudnn.benchmark = False 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | import torchvision 8 | from torchvision import datasets, transforms 9 | 10 | import numpy as np 11 | np.random.seed(0) 12 | from matplotlib import pyplot as plt 13 | import hickle as hkl 14 | import pickle as pkl 15 | import pdb 16 | import os 17 | import multiprocessing as mp 18 | from mpl_toolkits.axes_grid1 import ImageGrid 19 | 20 | os.chdir("../..") 21 | 22 | from src.utils.data_generator import * 23 | from sklearn.mixture import GaussianMixture as GMM 24 | from src.utils.interaction_utils import * 25 | import time 26 | 27 | from tqdm import tqdm 28 | 29 | # Load data. 30 | nt = 10 31 | num_states = 7 32 | grid_shape = (20, 30) 33 | 34 | use_cuda = torch.cuda.is_available() 35 | device = torch.device("cuda:0" if use_cuda else "cpu") 36 | 37 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset' 38 | 39 | # Test on test data. 
40 | data_file_states = os.path.join(dir, 'states_test.hkl') 41 | data_file_grids = os.path.join(dir, 'label_grids_test.hkl') 42 | data_file_sources = os.path.join(dir, 'sources_test.hkl') 43 | 44 | data_test = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt, 45 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False) 46 | 47 | test_loader = torch.utils.data.DataLoader(data_test, 48 | batch_size=len(data_test), shuffle=False, 49 | num_workers=mp.cpu_count()-1, pin_memory=True) 50 | 51 | for batch_x_test, batch_y_test, sources_test in test_loader: 52 | batch_x_test, batch_y_test_orig = batch_x_test.to(device), batch_y_test.to(device) 53 | 54 | batch_x_test = batch_x_test.cpu().data.numpy() 55 | batch_y_test_orig = batch_y_test_orig.cpu().data.numpy() 56 | 57 | # GMM 58 | plt.ioff() 59 | 60 | models = 'gmm' 61 | dir_models = os.path.join('/models/', models) 62 | if not os.path.isdir(dir_models): 63 | os.mkdir(dir_models) 64 | 65 | Ks = [100] # [4, 10, 25, 50, 100, 150] 66 | N = 3 67 | ncols = {4:2, 10:2, 25:5, 50:5, 100:10, 150:10} 68 | nrows = {4:2, 10:5, 25:5, 50:10, 100:10, 150:15} 69 | acc_nums = [] 70 | mse_nums = [] 71 | im_nums = [] 72 | plot_train_flag = True 73 | plot_metric_clusters_flag = False 74 | plot_scatter_flag = False 75 | plot_scatter_clusters_flag = False 76 | 77 | for K in Ks: 78 | 79 | model, p_m_a_np = pkl.load(open(os.path.join(dir_models, "GMM_K_" + str(K) + "_reg_covar_0_001.pkl"), "rb" ) ) 80 | 81 | cluster_centers_np = model.means_ 82 | cluster_centers_np = np.reshape(cluster_centers_np, (K,nt,num_states)) 83 | cluster_centers_np = unnormalize(cluster_centers_np) 84 | 85 | if plot_train_flag: 86 | plot_train(K, p_m_a_np, cluster_centers_np, dir_models, grid_shape) 87 | 88 | cluster_ids_y_test = model.predict(batch_x_test) 89 | cluster_ids_y_test_prob = model.predict_proba(batch_x_test) 90 | 91 | # Most likely class. 
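# Note: p_m_a_np (loaded with the GMM above and computed in train_gmm.py) has one row per latent
# cluster; row k holds the per-cell probability that a grid cell is occupied given cluster k, over the
# flattened 20x30 grid. Indexing it with the predicted cluster ids therefore yields a soft OGM
# prediction for every test trajectory.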
92 | grids_pred_orig = p_m_a_np[cluster_ids_y_test] 93 | 94 | grids_plot = np.reshape(p_m_a_np, (-1,grid_shape[0],grid_shape[1])) 95 | fig, axeslist = plt.subplots(ncols=ncols[K], nrows=nrows[K]) 96 | for i in range(grids_plot.shape[0]): 97 | axeslist.ravel()[i].imshow(grids_plot[i], cmap=plt.gray()) 98 | axeslist.ravel()[i].set_xticks([]) 99 | axeslist.ravel()[i].set_xticklabels([]) 100 | axeslist.ravel()[i].set_yticks([]) 101 | axeslist.ravel()[i].set_yticklabels([]) 102 | axeslist.ravel()[i].set_aspect('equal') 103 | plt.subplots_adjust(left = 0.5, hspace = 0, wspace = 0) 104 | plt.margins(0,0) 105 | fig.savefig(os.path.join(dir_models, 'test_' + str(K) + '.png'), pad_inches=0) 106 | plt.close(fig) 107 | 108 | grids_pred = deepcopy(grids_pred_orig) 109 | grids_pred[grids_pred_orig >= 0.6] = 1.0 110 | grids_pred[np.logical_and(grids_pred_orig < 0.6, grids_pred_orig > 0.4)] = 0.5 111 | grids_pred[grids_pred_orig <= 0.4] = 0.0 112 | 113 | acc_occ_free = np.mean(grids_pred == batch_y_test_orig) 114 | mse_occ_free = np.mean((grids_pred_orig - batch_y_test_orig)**2) 115 | 116 | im_grids, im_occ_grids, im_free_grids, im_ocl_grids = \ 117 | MapSimilarityMetric(np.reshape(grids_pred, (-1,grid_shape[0],grid_shape[1])), np.reshape(batch_y_test_orig, (-1,grid_shape[0],grid_shape[1]))) # im_ocl_grids, 118 | im = np.mean(im_grids) 119 | im_occ = np.mean(im_occ_grids) 120 | im_free = np.mean(im_free_grids) 121 | im_occ_free_grids = im_occ_grids + im_free_grids 122 | im_occ_free = np.mean(im_occ_grids + im_free_grids) 123 | 124 | acc_nums.append(acc_occ_free) 125 | mse_nums.append(mse_occ_free) 126 | im_nums.append(im) 127 | 128 | acc_occ = np.mean(grids_pred[batch_y_test_orig == 1] == batch_y_test_orig[batch_y_test_orig == 1]) 129 | acc_free = np.mean(grids_pred[batch_y_test_orig == 0] == batch_y_test_orig[batch_y_test_orig == 0]) 130 | 131 | mse_occ = np.mean((grids_pred_orig[batch_y_test_orig == 1] - batch_y_test_orig[batch_y_test_orig == 1])**2) 132 | mse_free = np.mean((grids_pred_orig[batch_y_test_orig == 0] - batch_y_test_orig[batch_y_test_orig == 0])**2) 133 | 134 | print("Metrics: ") 135 | 136 | print("Occupancy and Free Metrics: ") 137 | print("K: ", K, "Accuracy: ", acc_occ_free, "MSE: ", mse_occ_free, "IM: ", im_occ_free, "IM max: ", np.amax(im_occ_free_grids), "IM min: ", np.amin(im_occ_free_grids)) 138 | 139 | print("Occupancy Metrics: ") 140 | print("Accuracy: ", acc_occ, "MSE: ", mse_occ, "IM: ", im_occ, "IM max: ", np.amax(im_occ_grids), "IM min: ", np.amin(im_occ_grids)) 141 | 142 | print("Free Metrics: ") 143 | print("Accuracy: ", acc_free, "MSE: ", mse_free, "IM: ", im_free, "IM max: ", np.amax(im_free_grids), "IM min: ", np.amin(im_free_grids)) 144 | 145 | num_occ = np.sum(batch_y_test_orig == 1, axis=-1) 146 | num_free = np.sum(batch_y_test_orig == 0, axis=-1) 147 | num_occ_free = np.sum(np.logical_or(batch_y_test_orig == 1, batch_y_test_orig == 0), axis=-1) 148 | 149 | im_nums_N = np.empty((batch_y_test_orig.shape[0], N)) 150 | acc_occ_free_nums_N = np.empty((batch_y_test_orig.shape[0], N)) 151 | mse_occ_free_nums_N = np.empty((batch_y_test_orig.shape[0], N)) 152 | im_occ_free_nums_N = np.empty((batch_y_test_orig.shape[0], N)) 153 | acc_occ_nums_N = np.empty((np.sum(num_occ > 0), N)) 154 | mse_occ_nums_N = np.empty((np.sum(num_occ > 0), N)) 155 | im_occ_nums_N = np.empty((batch_y_test_orig.shape[0], N)) 156 | acc_free_nums_N = np.empty((np.sum(num_free > 0), N)) 157 | mse_free_nums_N = np.empty((np.sum(num_free > 0), N)) 158 | im_free_nums_N = np.empty((batch_y_test_orig.shape[0], N)) 159 | 160 |
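# Top-N evaluation: for each test trajectory, score the OGMs predicted by its N (= 3) most likely
# clusters and keep the best value per sample (the "Top 3 Metrics" printed below). The *_occ_* and
# *_free_* arrays are allocated only for samples that contain at least one occupied or free cell,
# respectively.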
for n in range(1,N+1): 161 | grids_pred_orig = p_m_a_np[cluster_ids_y_test_prob.argsort(1)[:,-n]] 162 | grids_pred = (grids_pred_orig >= 0.6).astype(float) 163 | grids_pred[grids_pred_orig <= 0.4] = 0.0 164 | grids_pred[np.logical_and(grids_pred_orig < 0.6, grids_pred_orig > 0.4)] = 0.5 165 | 166 | grids_pred_free = (grids_pred * (batch_y_test_orig == 0)) 167 | grids_pred_free[batch_y_test_orig != 0] = 2.0 168 | 169 | acc_occ_free_grids = np.mean(grids_pred == batch_y_test_orig, axis=-1) 170 | acc_occ_grids = np.sum((grids_pred * (batch_y_test_orig == 1)) == 1., axis=-1)[num_occ > 0] / num_occ[num_occ > 0] * 1.0 171 | acc_free_grids = np.sum(grids_pred_free == 0., axis=-1)[num_free > 0] / num_free[num_free > 0] * 1.0 172 | 173 | mse_occ_free_grids = np.mean((grids_pred_orig - batch_y_test_orig)**2, axis=-1) 174 | mse_occ_grids = np.sum(((grids_pred_orig * (batch_y_test_orig == 1)) - batch_y_test_orig * (batch_y_test_orig == 1))**2, axis=-1)[num_occ > 0] / num_occ[num_occ > 0] * 1.0 175 | mse_free_grids = np.sum(((grids_pred_orig * (batch_y_test_orig == 0)) - batch_y_test_orig * (batch_y_test_orig == 0))**2, axis=-1)[num_free > 0] / num_free[num_free > 0] * 1.0 176 | 177 | im_grids, im_occ_grids, im_free_grids, im_ocl_grids = MapSimilarityMetric(np.reshape(grids_pred, (-1,grid_shape[0],grid_shape[1])), np.reshape(batch_y_test_orig, (-1,grid_shape[0],grid_shape[1]))) 178 | 179 | acc_occ_free_nums_N[:,n-1] = acc_occ_free_grids 180 | mse_occ_free_nums_N[:,n-1] = mse_occ_free_grids 181 | im_occ_free_nums_N[:,n-1] = im_occ_grids + im_free_grids 182 | 183 | acc_occ_nums_N[:,n-1] = acc_occ_grids 184 | mse_occ_nums_N[:,n-1] = mse_occ_grids 185 | im_occ_nums_N[:,n-1] = im_occ_grids 186 | 187 | acc_free_nums_N[:,n-1] = acc_free_grids 188 | mse_free_nums_N[:,n-1] = mse_free_grids 189 | im_free_nums_N[:,n-1] = im_free_grids 190 | 191 | acc_occ_free_best = np.mean(np.amax(acc_occ_free_nums_N, axis=1)) 192 | mse_occ_free_best = np.mean(np.amin(mse_occ_free_nums_N, axis=1)) 193 | im_occ_free_best = np.mean(np.amin(im_occ_free_nums_N, axis=1)) 194 | 195 | acc_occ_best = np.mean(np.amax(acc_occ_nums_N, axis=1)) 196 | mse_occ_best = np.mean(np.amin(mse_occ_nums_N, axis=1)) 197 | im_occ_best = np.mean(np.amin(im_occ_nums_N, axis=1)) 198 | 199 | acc_free_best = np.mean(np.amax(acc_free_nums_N, axis=1)) 200 | mse_free_best = np.mean(np.amin(mse_free_nums_N, axis=1)) 201 | im_free_best = np.mean(np.amin(im_free_nums_N, axis=1)) 202 | 203 | print("Top 3 Metrics: ") 204 | 205 | print("Occupancy and Free Metrics: ") 206 | print("Accuracy: ", acc_occ_free_best, "MSE: ", mse_occ_free_best, "IM: ", im_occ_free_best) 207 | 208 | print("Occupancy Metrics: ") 209 | print("Accuracy: ", acc_occ_best, "MSE: ", mse_occ_best, "IM: ", im_occ_best) 210 | 211 | print("Free Metrics: ") 212 | print("Accuracy: ", acc_free_best, "MSE: ", mse_free_best, "IM: ", im_free_best) 213 | 214 | hkl.dump(np.array([acc_best, acc_occ_best, acc_free_best, acc_occ_free_best,\ 215 | mse_best, mse_occ_best, mse_free_best, mse_occ_free_best,\ 216 | im_best, im_occ_best, im_free_best, im_occ_free_best]),\ 217 | os.path.join(dir_models, 'top_3_metrics.hkl'), mode='w',) 218 | 219 | batch_x_np_test = batch_x_test 220 | batch_x_np_test = np.reshape(batch_x_np_test, (-1,nt,num_states)) 221 | batch_x_np_test = unnormalize(batch_x_np_test) 222 | 223 | if plot_metric_clusters_flag: 224 | plot_metric_clusters(K, batch_x_np_test, cluster_ids_y_test, im_grids2, 'IM2', dir_models, sources_test) 225 | 226 | if plot_scatter_flag: 227 | plot_scatter(K, 
batch_x_np_test, im_grids, 'IM', dir_models) 228 | 229 | if plot_scatter_clusters_flag: 230 | plot_scatter_clusters(K, batch_x_np_test, cluster_ids_y_test, dir_models) 231 | plt.show() 232 | 233 | # Standard error 234 | acc_occ_std_error = np.std(grids_pred[batch_y_test_orig == 1] == 1)/np.sqrt(grids_pred[batch_y_test_orig == 1].size) 235 | acc_free_std_error = np.std(grids_pred[batch_y_test_orig == 0] == 0)/np.sqrt(grids_pred[batch_y_test_orig == 0].size) 236 | acc_occ_free_std_error = np.std(grids_pred[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] == batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size) 237 | 238 | mse_occ_std_error = np.std((grids_pred_orig[batch_y_test_orig == 1] - 1)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 0].size) 239 | mse_free_std_error = np.std((grids_pred_orig[batch_y_test_orig == 0] - 0)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 0].size) 240 | mse_occ_free_std_error = np.std((grids_pred_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] - batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])**2)/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size) 241 | 242 | im_occ_std_error = np.std(im_occ_grids)/np.sqrt(im_occ_grids.size) 243 | im_free_std_error = np.std(im_free_grids)/np.sqrt(im_free_grids.size) 244 | im_occ_free_std_error = np.std(im_occ_grids + im_free_grids)/np.sqrt(im_occ_grids.size) 245 | 246 | print('Maximum standard error:') 247 | print(np.amax([acc_occ_std_error, acc_free_std_error, acc_occ_free_std_error, mse_occ_std_error, mse_free_std_error, mse_occ_free_std_error])) 248 | print(np.amax([im_occ_std_error, im_free_std_error, im_occ_free_std_error])) 249 | 250 | # GMM latent space visualization. 251 | all_latent_classes = np.reshape(p_m_a_np, (100,20,30)) 252 | 253 | fig = plt.figure(figsize=(10, 10)) 254 | grid = ImageGrid(fig, 111, 255 | nrows_ncols=(10, 10), 256 | axes_pad=0.1, 257 | ) 258 | 259 | for ax, im in zip(grid, all_latent_classes): 260 | ax.matshow(im, cmap='gray_r', vmin=0, vmax=1) 261 | ax.set_xticks([], []) 262 | ax.set_yticks([], []) 263 | 264 | plt.savefig('models/gmm/all_latent_classes_gmm.png') 265 | plt.show() -------------------------------------------------------------------------------- /src/driver_sensor_model/inference_kmeans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.manual_seed(123) 3 | torch.backends.cudnn.deterministic = True 4 | torch.backends.cudnn.benchmark = False 5 | import torch.nn as nn 6 | import torch.utils.data as data 7 | import torchvision 8 | from torchvision import datasets, transforms 9 | 10 | import numpy as np 11 | np.random.seed(0) 12 | from matplotlib import pyplot as plt 13 | import hickle as hkl 14 | import pickle as pkl 15 | import pdb 16 | import os 17 | import multiprocessing as mp 18 | from mpl_toolkits.axes_grid1 import ImageGrid 19 | 20 | os.chdir("../..") 21 | 22 | from src.utils.data_generator import * 23 | from sklearn.cluster import KMeans 24 | from src.utils.interaction_utils import * 25 | import time 26 | 27 | from tqdm import tqdm 28 | 29 | # Load data. 
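# nt is the number of past time steps in each driver trajectory snippet (10 steps, which corresponds
# to the 1 s observation window described in the README, assuming the dataset's 10 Hz sampling rate).
# num_states is the number of state features per time step, and grid_shape is the size of the OGM
# inferred ahead of the observed driver.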
30 | nt = 10 31 | num_states = 7 32 | grid_shape = (20, 30) 33 | 34 | use_cuda = torch.cuda.is_available() 35 | device = torch.device("cuda:0" if use_cuda else "cpu") 36 | 37 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset' 38 | 39 | # Test on test data. 40 | data_file_states = os.path.join(dir, 'states_test.hkl') 41 | data_file_grids = os.path.join(dir, 'label_grids_test.hkl') 42 | data_file_sources = os.path.join(dir, 'sources_test.hkl') 43 | 44 | data_test = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt, 45 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False) 46 | 47 | test_loader = torch.utils.data.DataLoader(data_test, 48 | batch_size=len(data_test), shuffle=False, 49 | num_workers=mp.cpu_count()-1, pin_memory=True) 50 | 51 | for batch_x_test, batch_y_test, sources_test in test_loader: 52 | batch_x_test, batch_y_test_orig = batch_x_test.to(device), batch_y_test.to(device) 53 | 54 | batch_x_test = batch_x_test.cpu().data.numpy() 55 | batch_y_test_orig = batch_y_test_orig.cpu().data.numpy() 56 | 57 | plt.ioff() 58 | 59 | models = 'kmeans' 60 | dir_models = os.path.join('/models/', models) 61 | if not os.path.isdir(dir_models): 62 | os.mkdir(dir_models) 63 | 64 | Ks = [100] # [4, 10, 25, 50, 100, 150] 65 | ncols = {4:2, 10:2, 25:5, 50:5, 100:10, 150:10} 66 | nrows = {4:2, 10:5, 25:5, 50:10, 100:10, 150:15} 67 | 68 | plot_train_flag = True 69 | plot_metric_clusters_flag = False 70 | plot_scatter_flag = False 71 | plot_scatter_clusters_flag = False 72 | 73 | for K in Ks: 74 | 75 | [model, p_m_a_np] = pkl.load(open(os.path.join(dir_models,"clusters_kmeans_K_" + str(K) + "_sklearn.pkl"), "rb" ) ) 76 | 77 | cluster_centers_np = model.cluster_centers_ 78 | cluster_centers_np = np.reshape(cluster_centers_np, (K,nt,num_states)) 79 | cluster_centers_np = unnormalize(cluster_centers_np) 80 | 81 | if plot_train_flag: 82 | plot_train(K, p_m_a_np, cluster_centers_np, dir_models, grid_shape) 83 | 84 | start = time.time() 85 | cluster_ids_y_test = model.predict(batch_x_test) 86 | grids_pred_orig = p_m_a_np[cluster_ids_y_test] 87 | 88 | grids_plot = np.reshape(p_m_a_np, (-1,grid_shape[0],grid_shape[1])) 89 | fig, axeslist = plt.subplots(ncols=ncols[K], nrows=nrows[K]) 90 | for i in range(grids_plot.shape[0]): 91 | axeslist.ravel()[i].imshow(grids_plot[i], cmap=plt.gray()) 92 | axeslist.ravel()[i].set_xticks([]) 93 | axeslist.ravel()[i].set_xticklabels([]) 94 | axeslist.ravel()[i].set_yticks([]) 95 | axeslist.ravel()[i].set_yticklabels([]) 96 | axeslist.ravel()[i].set_aspect('equal') 97 | plt.subplots_adjust(left = 0.5, hspace = 0, wspace = 0) 98 | plt.margins(0,0) 99 | fig.savefig(os.path.join(dir_models, 'test_' + str(K) + '.png'), pad_inches=0) 100 | plt.close(fig) 101 | 102 | grids_pred = deepcopy(grids_pred_orig) 103 | grids_pred[grids_pred_orig >= 0.6] = 1.0 104 | grids_pred[np.logical_and(grids_pred_orig > 0.4, grids_pred_orig < 0.6)] = 0.5 105 | grids_pred[grids_pred_orig <= 0.4] = 0.0 106 | 107 | acc_occ_free = np.mean(grids_pred == batch_y_test_orig) 108 | mse_occ_free = np.mean((grids_pred_orig - batch_y_test_orig)**2) 109 | 110 | im_grids, im_occ_grids, im_free_grids, im_ocl_grids = \ 111 | MapSimilarityMetric(np.reshape(grids_pred, (-1,grid_shape[0],grid_shape[1])), np.reshape(batch_y_test_orig, (-1,grid_shape[0],grid_shape[1]))) # im_ocl_grids, 112 | im = np.mean(im_grids) 113 | im_occ = np.mean(im_occ_grids) 114 | im_free = np.mean(im_free_grids) 115 | 
im_occ_free_grids = im_occ_grids + im_free_grids 116 | im_occ_free = np.mean(im_occ_grids + im_free_grids) 117 | 118 | acc_occ = np.mean(grids_pred[batch_y_test_orig == 1] == 1) 119 | acc_free = np.mean(grids_pred[batch_y_test_orig == 0] == 0) 120 | 121 | mse_occ = np.mean((grids_pred_orig[batch_y_test_orig == 1] - 1)**2) 122 | mse_free = np.mean((grids_pred_orig[batch_y_test_orig == 0] - 0)**2) 123 | 124 | print("Metrics: ") 125 | 126 | print("Occupancy and Free Metrics: ") 127 | print("K: ", K, "Accuracy: ", acc_occ_free, "MSE: ", mse_occ_free, "IM: ", im_occ_free, "IM max: ", np.amax(im_occ_free_grids), "IM min: ", np.amin(im_occ_free_grids)) 128 | 129 | print("Occupancy Metrics: ") 130 | print("Accuracy: ", acc_occ, "MSE: ", mse_occ, "IM: ", im_occ, "IM max: ", np.amax(im_occ_grids), "IM min: ", np.amin(im_occ_grids)) 131 | 132 | print("Free Metrics: ") 133 | print("Accuracy: ", acc_free, "MSE: ", mse_free, "IM: ", im_free, "IM max: ", np.amax(im_free_grids), "IM min: ", np.amin(im_free_grids)) 134 | 135 | batch_x_np_test = batch_x_test 136 | batch_x_np_test = np.reshape(batch_x_np_test, (-1,nt,num_states)) 137 | batch_x_np_test = unnormalize(batch_x_np_test) 138 | 139 | if plot_metric_clusters_flag: 140 | plot_metric_clusters(K, batch_x_np_test, cluster_ids_y_test, im_grids2, 'IM2', dir_models, sources_test) 141 | 142 | if plot_scatter_flag: 143 | plot_scatter(K, batch_x_np_test, im_grids, 'IM', dir_models) 144 | 145 | if plot_scatter_clusters_flag: 146 | plot_scatter_clusters(K, batch_x_np_test, cluster_ids_y_test, dir_models) 147 | plt.show() 148 | 149 | # Standard error 150 | acc_occ_std_error = np.std(grids_pred[batch_y_test_orig == 1] == 1)/np.sqrt(grids_pred[batch_y_test_orig == 1].size) 151 | acc_free_std_error = np.std(grids_pred[batch_y_test_orig == 0] == 0)/np.sqrt(grids_pred[batch_y_test_orig == 0].size) 152 | acc_occ_free_std_error = np.std(grids_pred[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] == batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size) 153 | 154 | mse_occ_std_error = np.std((grids_pred_orig[batch_y_test_orig == 1] - 1)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 0].size) 155 | mse_free_std_error = np.std((grids_pred_orig[batch_y_test_orig == 0] - 0)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 0].size) 156 | mse_occ_free_std_error = np.std((grids_pred_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] - batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])**2)/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size) 157 | 158 | im_occ_std_error = np.std(im_occ_grids)/np.sqrt(im_occ_grids.size) 159 | im_free_std_error = np.std(im_free_grids)/np.sqrt(im_free_grids.size) 160 | im_occ_free_std_error = np.std(im_occ_grids + im_free_grids)/np.sqrt(im_occ_grids.size) 161 | 162 | print('Maximum standard error:') 163 | print(np.amax([acc_occ_std_error, acc_free_std_error, acc_occ_free_std_error, mse_occ_std_error, mse_free_std_error, mse_occ_free_std_error])) 164 | print(np.amax([im_occ_std_error, im_free_std_error, im_occ_free_std_error])) 165 | 166 | # Visualize all latent classes. 
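# Each of the K = 100 clusters has an associated occupancy map in p_m_a_np; the figure below reshapes
# them into 20x30 grids and tiles them into a 10x10 panel so the learned "latent classes" of the
# k-means PaS baseline can be inspected side by side.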
167 | all_latent_classes = np.reshape(p_m_a_np, (100,20,30)) 168 | 169 | fig = plt.figure(figsize=(10, 10)) 170 | grid = ImageGrid(fig, 111, 171 | nrows_ncols=(10, 10), 172 | axes_pad=0.1, 173 | ) 174 | 175 | for ax, im in zip(grid, all_latent_classes): 176 | ax.matshow(im, cmap='gray_r', vmin=0, vmax=1) 177 | ax.set_xticks([], []) 178 | ax.set_yticks([], []) 179 | 180 | plt.savefig('models/kmeans/all_latent_classes_kmeans.png') 181 | plt.show() -------------------------------------------------------------------------------- /src/driver_sensor_model/models_cvae.py: -------------------------------------------------------------------------------- 1 | # Architecture for the CVAE driver sensor model. Code is adapted from: https://github.com/sisl/EvidentialSparsification and 2 | # https://github.com/ritheshkumar95/pytorch-vqvae/blob/master/modules.py. 3 | 4 | seed = 123 5 | import numpy as np 6 | np.random.seed(seed) 7 | import torch 8 | import math 9 | 10 | import torch.nn as nn 11 | from src.utils.utils_model import to_var, sample_p 12 | import pdb 13 | 14 | class ResBlock(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | self.block = nn.Sequential( 18 | nn.ReLU(False), 19 | nn.Conv2d(dim, dim, 3, 1, 1), 20 | nn.BatchNorm2d(dim), 21 | nn.ReLU(False), 22 | nn.Conv2d(dim, dim, 1), 23 | nn.BatchNorm2d(dim) 24 | ) 25 | 26 | def forward(self, x): 27 | return x + self.block(x) 28 | 29 | class VAE(nn.Module): 30 | 31 | def __init__(self, encoder_layer_sizes_p, n_lstms, latent_size, dim): 32 | 33 | super().__init__() 34 | 35 | torch.manual_seed(seed) 36 | torch.cuda.manual_seed(seed) 37 | 38 | assert type(encoder_layer_sizes_p) == list 39 | assert type(latent_size) == int 40 | 41 | self.latent_size = latent_size 42 | self.label_size = encoder_layer_sizes_p[-1] 43 | 44 | self.encoder = Encoder(encoder_layer_sizes_p, n_lstms, latent_size, dim) 45 | self.decoder = Decoder(latent_size, self.label_size, dim) 46 | 47 | def forward(self, x, c=None): 48 | 49 | batch_size = x.size(0) 50 | 51 | # Encode the input. 52 | alpha_q, alpha_p, output_all_c = self.encoder(x, c) 53 | 54 | # Obtain all possible latent classes. 55 | z = torch.eye(self.latent_size).cuda() 56 | 57 | # Decode all latent classes. 58 | recon_x = self.decoder(z) 59 | 60 | return recon_x, alpha_q, alpha_p, self.encoder.linear_latent_q, self.encoder.linear_latent_p, output_all_c, z 61 | 62 | def inference(self, n=1, c=None, mode='sample', k=None): 63 | 64 | batch_size = n 65 | 66 | alpha_q, alpha_p, output_all_c = self.encoder(x=torch.empty((0,0)), c=c, train=False) 67 | 68 | if mode == 'sample': 69 | # Decode the mode sampled from the prior distribution. 70 | z = sample_p(alpha_p, batch_size=batch_size).view(-1,self.latent_size) 71 | elif mode == 'all': 72 | # Decode all the modes. 73 | z = torch.eye(self.latent_size).cuda() 74 | elif mode == 'most_likely': 75 | # Decode the most likely mode. 76 | z = torch.nn.functional.one_hot(torch.argmax(alpha_p, dim=1), num_classes=self.latent_size).float() 77 | elif mode == 'multimodal': 78 | # Decode a particular mode. 
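# torch.argsort is ascending, so column -k selects the k-th most likely latent class under the prior
# (k = 1 recovers the most likely class).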
79 | z = torch.nn.functional.one_hot(torch.argsort(alpha_p, dim=1)[:,-k], num_classes=self.latent_size).float() 80 | 81 | recon_x = self.decoder(z) 82 | 83 | return recon_x, alpha_p, self.encoder.linear_latent_p, output_all_c, z 84 | 85 | class Encoder(nn.Module): 86 | 87 | def __init__(self, layer_sizes_p, n_lstms, latent_size, dim): 88 | 89 | super().__init__() 90 | 91 | input_dim = 1 92 | self.VQVAEBlock = nn.Sequential( 93 | nn.Conv2d(input_dim, dim, 4, 2, 1), 94 | nn.BatchNorm2d(dim), 95 | nn.ReLU(False), 96 | nn.Conv2d(dim, dim, 4, 2, 1), 97 | ResBlock(dim), 98 | ResBlock(dim) 99 | ) 100 | 101 | self.lstm = nn.LSTM(layer_sizes_p[0], 102 | layer_sizes_p[-1], 103 | num_layers=n_lstms, 104 | batch_first=True) 105 | 106 | self.linear_latent_q = nn.Linear(5*7*dim + layer_sizes_p[-1]*10, latent_size) # 10 is the time dimension. 107 | self.softmax_q = nn.Softmax(dim=-1) 108 | 109 | self.linear_latent_p = nn.Linear(layer_sizes_p[-1]*10, latent_size) # 10 is the time dimension. 110 | self.softmax_p = nn.Softmax(dim=-1) 111 | 112 | def forward(self, x=None, c=None, train=True): 113 | 114 | output_all, (full_c, _) = self.lstm(c) 115 | output_all_c = torch.reshape(output_all, (c.shape[0], -1)) 116 | alpha_p_lin = self.linear_latent_p(output_all_c) 117 | alpha_p = self.softmax_p(alpha_p_lin) 118 | 119 | if train: 120 | 121 | full_x = self.VQVAEBlock(x) 122 | full_x = full_x.view(full_x.shape[0],-1) 123 | output_all_c = torch.cat((full_x, output_all_c), dim=-1) 124 | alpha_q_lin = self.linear_latent_q(output_all_c) 125 | alpha_q = self.softmax_q(alpha_q_lin) 126 | 127 | else: 128 | alpha_q_lin = None 129 | alpha_q = None 130 | 131 | return alpha_q, alpha_p, output_all_c 132 | 133 | 134 | class Decoder(nn.Module): 135 | 136 | def __init__(self, latent_size, label_size, dim): 137 | 138 | super().__init__() 139 | self.latent_size = latent_size 140 | self.dim = dim 141 | input_dim = 1 142 | 143 | self.decode_linear = nn.Sequential( 144 | nn.Linear(self.latent_size, self.latent_size*dim), 145 | nn.ReLU(False), 146 | nn.Linear(self.latent_size*dim, self.latent_size*dim), 147 | nn.ReLU(False), 148 | ) 149 | 150 | self.VQVAEBlock = nn.Sequential( 151 | ResBlock(dim), 152 | ResBlock(dim), 153 | nn.ReLU(False), 154 | nn.ConvTranspose2d(dim, dim, (4,5), (2,3), 1), 155 | nn.BatchNorm2d(dim), 156 | nn.ReLU(False), 157 | nn.ConvTranspose2d(dim, input_dim, 1, 1, 0), 158 | nn.Sigmoid() 159 | ) 160 | 161 | def forward(self, z): 162 | latent_size_sqrt = int(math.sqrt(self.latent_size)) 163 | z_c = self.decode_linear(z) 164 | z_c = z_c.view(-1, self.dim, latent_size_sqrt, latent_size_sqrt) 165 | x = self.VQVAEBlock(z_c) 166 | return x -------------------------------------------------------------------------------- /src/driver_sensor_model/train_gmm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | random = 123 3 | torch.manual_seed(random) 4 | torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | import torchvision 9 | from torchvision import datasets, transforms 10 | 11 | import numpy as np 12 | np.random.seed(random) 13 | from matplotlib import pyplot as plt 14 | import hickle as hkl 15 | import pickle as pkl 16 | import pdb 17 | import os 18 | import multiprocessing as mp 19 | 20 | os.chdir("../..") 21 | 22 | from src.utils.data_generator import * 23 | from sklearn.mixture import GaussianMixture as GMM 24 | from tqdm import tqdm 25 | import time 26 | 27 | # Load 
data. 28 | nt = 10 29 | 30 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset/' 31 | models = 'gmm' 32 | dir_models = os.path.join(dir, models) 33 | if not os.path.isdir(dir_models): 34 | os.mkdir(dir_models) 35 | data_file_states = os.path.join(dir, 'states_shuffled_train.hkl') 36 | data_file_grids = os.path.join(dir, 'label_grids_shuffled_train.hkl') 37 | data_file_sources = os.path.join(dir, 'sources_shuffled_train.hkl') 38 | 39 | data_train = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt, 40 | batch_size=None, shuffle=True, sequence_start_mode='all', norm=False) 41 | print("number of unique sources: ", len(np.unique(data_train.sources))) 42 | 43 | train_loader = torch.utils.data.DataLoader(data_train, 44 | batch_size=len(data_train), shuffle=True, 45 | num_workers=mp.cpu_count()-1, pin_memory=True) 46 | 47 | use_cuda = torch.cuda.is_available() 48 | device = torch.device("cuda:0" if use_cuda else "cpu") 49 | 50 | for batch_x, batch_y, sources in train_loader: 51 | batch_x, batch_y = batch_x.to("cpu").data.numpy(), batch_y.to("cpu").data.numpy() 52 | print(batch_x.shape, batch_y.shape) 53 | 54 | # Test on validation data. 55 | data_file_states = os.path.join(dir, 'states_val.hkl') 56 | data_file_grids = os.path.join(dir, 'label_grids_val.hkl') 57 | data_file_sources = os.path.join(dir, 'sources_val.hkl') 58 | 59 | data_val = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt, 60 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False) 61 | 62 | val_loader = torch.utils.data.DataLoader(data_val, 63 | batch_size=len(data_val), shuffle=False, 64 | num_workers=mp.cpu_count()-1, pin_memory=True) 65 | 66 | for batch_x_val, batch_y_val, sources_val in val_loader: 67 | batch_x_val, batch_y_val = batch_x_val.to("cpu").data.numpy(), batch_y_val.to("cpu").data.numpy() 68 | 69 | print('unique values', np.unique(batch_y_val)) 70 | 71 | Ks = [100] # [4, 10, 25, 50, 100, 150] 72 | acc_nums = [] 73 | acc_nums_free_occ = [] 74 | acc_nums_occ = [] 75 | acc_nums_free = [] 76 | acc_nums_ocl = [] 77 | plot_train = False 78 | grid_shape = (20,30) 79 | 80 | for K in tqdm(Ks): 81 | 82 | model = GMM(n_components=K, covariance_type='diag', reg_covar=1e-3) 83 | model.fit(batch_x) 84 | cluster_ids_x = model.predict(batch_x) 85 | cluster_centers = model.means_ 86 | 87 | p_a_m_1 = np.zeros((K, batch_y.shape[1])) 88 | p_a_m_0 = np.zeros((K, batch_y.shape[1])) 89 | for i in range(batch_y.shape[1]): 90 | p_m_1 = np.sum(batch_y[:,i] == 1).astype(float)/batch_y.shape[0] 91 | p_m_0 = (np.sum(batch_y[:,i] == 0).astype(float)/batch_y.shape[0]) 92 | for k in range(K): 93 | p_a_m_1[k,i] = np.sum(np.logical_and(cluster_ids_x == k, batch_y[:,i] == 1)).astype(float)/batch_y.shape[0] 94 | p_a_m_0[k,i] = np.sum(np.logical_and(cluster_ids_x == k, batch_y[:,i] == 0)).astype(float)/batch_y.shape[0] 95 | 96 | if p_m_1 != 0: 97 | p_a_m_1[:,i] = p_a_m_1[:,i]/p_m_1 98 | else: 99 | p_a_m_1[:,i] = p_a_m_1[:,i]/1.0 100 | 101 | if p_m_0 != 0: 102 | p_a_m_0[:,i] = p_a_m_0[:,i]/p_m_0 103 | else: 104 | p_a_m_0[:,i] = p_a_m_0[:,i]/1.0 105 | 106 | p_m_a_np = p_a_m_1/(p_a_m_1+p_a_m_0) 107 | 108 | pkl.dump([model, p_m_a_np], open( os.path.join(dir_models, "GMM_K_" + str(K) + "_reg_covar_0_001.pkl"), "wb" ) ) 109 | 110 | cluster_centers_np = np.reshape(cluster_centers, (K,nt,7)) 111 | cluster_centers_np = unnormalize(cluster_centers_np) 112 | 113 | if 
plot_train: 114 | fig, (ax1, ax2, ax3) = plt.subplots(3) 115 | fig.suptitle('State Clusters') 116 | 117 | for k in range(K): 118 | fig_occ, ax_occ = plt.subplots(1) 119 | image = np.flip(np.transpose(1.0-np.reshape(p_m_a_np[k], grid_shape), (1,0)), axis=0) 120 | ax_occ.imshow(image, cmap='gray') 121 | picture_file = os.path.join(dir_models, 'cluster_' + str(k) + '.png') 122 | plt.savefig(picture_file) 123 | fig_occ.clf() 124 | 125 | ax1.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,1], label=str(k)) 126 | ax2.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,2], label=str(k)) 127 | ax3.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,3], label=str(k)) 128 | 129 | ax1.set_ylabel("Pos (m)") 130 | ax1.set_ylim(-5,120) 131 | ax1.set_xlim(-5,120) 132 | ax2.set_ylabel("Vel (m/s)") 133 | ax2.set_ylim(0,8) 134 | ax2.set_xlim(-5,120) 135 | ax3.set_xlabel("Position (m)") 136 | ax3.set_ylabel("Acc (m/s^2)") 137 | ax3.set_ylim(-3,3) 138 | ax3.set_xlim(-5,120) 139 | 140 | handles, labels = ax1.get_legend_handles_labels() 141 | fig.legend(handles, labels, loc='center right') 142 | 143 | picture_file = os.path.join(dir_models, 'state_clusters_gmm.png') 144 | fig.savefig(picture_file) 145 | 146 | del(p_m_a) 147 | del(cluster_centers_np) 148 | del(cluster_centers) 149 | del(cluster_ids) 150 | 151 | # Test on validation data. 152 | cluster_ids_y_val = model.predict(batch_x_val) 153 | 154 | grids_pred = p_m_a_np[cluster_ids_y_val] 155 | 156 | grids_pred[grids_pred >= 0.6] = 1.0 157 | grids_pred[grids_pred <= 0.4] = 0.0 158 | grids_pred[np.logical_and(grids_pred > 0.4, grids_pred < 0.6)] = 0.5 159 | grids_gt = batch_y_val 160 | 161 | acc = np.mean(grids_pred == grids_gt) 162 | acc_nums.append(acc) 163 | 164 | mask_occ = (batch_y_val == 1) 165 | acc_occ = np.mean(grids_pred[mask_occ] == grids_gt[mask_occ]) 166 | acc_nums_occ.append(acc_occ) 167 | 168 | mask_free = (batch_y_val == 0) 169 | acc_free = np.mean(grids_pred[mask_free] == grids_gt[mask_free]) 170 | acc_nums_free.append(acc_free) 171 | 172 | print("K: ", K, " Accuracy: ", acc, acc_occ, acc_free) 173 | 174 | if plot_train: 175 | plt.scatter(Ks, acc_nums, label='acc') 176 | plt.scatter(Ks, acc_nums_occ, label='acc occupied') 177 | plt.scatter(Ks, acc_nums_free, label='acc free') 178 | plt.ylim(0,1) 179 | plt.legend() 180 | plt.title('Accuracy vs Number of Clusters for Driver Sensor Model') 181 | plt.xlabel('Number of Clusters') 182 | plt.ylabel('Accuracy') 183 | plt.savefig(os.path.join(dir_models, 'acc_clusters_gmm.png')) -------------------------------------------------------------------------------- /src/driver_sensor_model/train_kmeans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | random = 123 3 | torch.manual_seed(random) 4 | torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | import torchvision 9 | from torchvision import datasets, transforms 10 | 11 | import numpy as np 12 | np.random.seed(random) 13 | from matplotlib import pyplot as plt 14 | import hickle as hkl 15 | import pickle as pkl 16 | import pdb 17 | import os 18 | import multiprocessing as mp 19 | 20 | os.chdir("../..") 21 | 22 | from src.utils.data_generator import * 23 | from sklearn.cluster import KMeans 24 | import time 25 | 26 | from tqdm import tqdm 27 | 28 | # Load data. 
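# Overview of this script: fit k-means on the flattened driver-state windows, then, for every cluster
# and every grid cell, count how often the cell is occupied vs. free among the training samples to
# build the likelihoods p(cluster | occupied) and p(cluster | free); their normalized ratio p_m_a_np
# acts as the cluster's occupancy map for the k-means PaS baseline. The fitted model and p_m_a_np are
# pickled for later use by the inference script.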
29 | nt = 10 30 | 31 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data_new_goal/driver_sensor_dataset/' 32 | models = 'kmeans' 33 | dir_models = os.path.join('/models/', models) 34 | if not os.path.isdir(dir_models): 35 | os.mkdir(dir_models) 36 | data_file_states = os.path.join(dir, 'states_shuffled_train.hkl') 37 | data_file_grids = os.path.join(dir, 'label_grids_shuffled_train.hkl') 38 | data_file_sources = os.path.join(dir, 'sources_shuffled_train.hkl') 39 | 40 | data_train = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt, 41 | batch_size=None, shuffle=True, sequence_start_mode='all', norm=False) 42 | print("number of unique sources: ", len(np.unique(data_train.sources))) 43 | 44 | train_loader = torch.utils.data.DataLoader(data_train, 45 | batch_size=len(data_train), shuffle=True, 46 | num_workers=mp.cpu_count()-1, pin_memory=True) 47 | 48 | use_cuda = torch.cuda.is_available() 49 | device = torch.device("cuda:0" if use_cuda else "cpu") 50 | 51 | for batch_x, batch_y, sources in train_loader: 52 | batch_x, batch_y = batch_x.to("cpu").data.numpy(), batch_y.to("cpu").data.numpy() 53 | print(batch_x.shape, batch_y.shape) 54 | 55 | data_file_states = os.path.join(dir, 'states_val.hkl') 56 | data_file_grids = os.path.join(dir, 'label_grids_val.hkl') 57 | data_file_sources = os.path.join(dir, 'sources_val.hkl') 58 | 59 | data_val = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt, 60 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False) 61 | 62 | val_loader = torch.utils.data.DataLoader(data_val, 63 | batch_size=len(data_val), shuffle=False, 64 | num_workers=mp.cpu_count()-1, pin_memory=True) 65 | 66 | for batch_x_val, batch_y_val, sources_val in val_loader: 67 | batch_x_val, batch_y_val = batch_x_val.to("cpu").data.numpy(), batch_y_val.to("cpu").data.numpy() 68 | 69 | Ks = [100] # [4, 10, 25, 50, 100, 150] 70 | acc_nums = [] 71 | acc_nums_free_occ = [] 72 | acc_nums_occ = [] 73 | acc_nums_free = [] 74 | acc_nums_ocl = [] 75 | plot_train = False 76 | grid_shape = (20,30) 77 | 78 | for K in tqdm(Ks): 79 | 80 | model = KMeans(n_clusters=K) 81 | model.fit(batch_x) 82 | 83 | cluster_ids_x = model.predict(batch_x) 84 | cluster_centers = model.cluster_centers_ 85 | 86 | p_a_m_occ = np.zeros((K, batch_y.shape[1])) 87 | p_a_m_free = np.zeros((K, batch_y.shape[1])) 88 | 89 | for i in range(batch_y.shape[1]): 90 | p_m_occ = np.sum(batch_y[:,i] == 1).astype(float)/batch_y.shape[0] 91 | p_m_free = np.sum(batch_y[:,i] == 0).astype(float)/batch_y.shape[0] 92 | 93 | for k in range(K): 94 | p_a_m_occ[k,i] = np.sum((cluster_ids_x == k) & (batch_y[:,i] == 1)).astype(float)/batch_y.shape[0] 95 | p_a_m_free[k,i] = np.sum((cluster_ids_x == k) & (batch_y[:,i] == 0)).astype(float)/batch_y.shape[0] 96 | 97 | if p_m_occ != 0: 98 | p_a_m_occ[:,i] = p_a_m_occ[:,i]/p_m_occ 99 | else: 100 | p_a_m_occ[:,i] = p_a_m_occ[:,i]/1.0 101 | 102 | if p_m_free != 0: 103 | p_a_m_free[:,i] = p_a_m_free[:,i]/p_m_free 104 | else: 105 | p_a_m_free[:,i] = p_a_m_free[:,i]/1.0 106 | 107 | p_m_a = p_a_m_occ/(p_a_m_occ + p_a_m_free) 108 | p_m_a_np = p_m_a 109 | 110 | pkl.dump([model, p_m_a_np], open(os.path.join(dir_models, "clusters_kmeans_K_" + str(K) + ".pkl"), "wb" ) ) 111 | 112 | cluster_centers_np = cluster_centers 113 | cluster_centers_np = np.reshape(cluster_centers_np, (K,nt,-1)) 114 | cluster_centers_np = unnormalize(cluster_centers_np) 115 | 116 | if 
plot_train: 117 | fig, (ax1, ax2, ax3) = plt.subplots(3) 118 | fig.suptitle('State Clusters') 119 | 120 | for k in range(K): 121 | fig_occ, ax_occ = plt.subplots(1) 122 | image = np.flip(np.transpose(1.0-np.reshape(p_m_a_np[k], grid_shape), (1,0)), axis=0) 123 | ax_occ.imshow(image, cmap='gray') 124 | picture_file = os.path.join(dir_models, 'cluster_' + str(k) + '.png') 125 | plt.savefig(picture_file) 126 | fig_occ.clf() 127 | 128 | ax1.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,1], label=str(k)) 129 | ax2.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,2], label=str(k)) 130 | ax3.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,3], label=str(k)) 131 | 132 | ax1.set_ylabel("Pos (m)") 133 | ax1.set_ylim(-5,120) 134 | ax1.set_xlim(-5,120) 135 | ax2.set_ylabel("Vel (m/s)") 136 | ax2.set_ylim(0,8) 137 | ax2.set_xlim(-5,120) 138 | ax3.set_xlabel("Position (m)") 139 | ax3.set_ylabel("Acc (m/s^2)") 140 | ax3.set_ylim(-3,3) 141 | ax3.set_xlim(-5,120) 142 | 143 | handles, labels = ax1.get_legend_handles_labels() 144 | fig.legend(handles, labels, loc='center right') 145 | 146 | picture_file = os.path.join(dir_models, 'state_clusters.png') 147 | fig.savefig(picture_file) 148 | 149 | # Test on validation data. 150 | cluster_ids_y_val = model.predict(batch_x_val) 151 | 152 | grids_pred = p_m_a_np[cluster_ids_y_val] 153 | grids_pred[grids_pred >= 0.6] = 1.0 154 | grids_pred[grids_pred <= 0.4] = 0.0 155 | grids_pred[np.logical_and(grids_pred > 0.4, grids_pred < 0.6)] = 0.5 156 | grids_gt = batch_y_val 157 | 158 | acc = np.mean(grids_pred == grids_gt) 159 | acc_nums.append(acc) 160 | 161 | mask_occ = (batch_y_val == 1) 162 | acc_occ = np.mean(grids_pred[mask_occ] == grids_gt[mask_occ]) 163 | acc_nums_occ.append(acc_occ) 164 | 165 | mask_free = (batch_y_val == 0) 166 | acc_free = np.mean(grids_pred[mask_free] == grids_gt[mask_free]) 167 | acc_nums_free.append(acc_free) 168 | 169 | print("K: ", K, " Accuracy: ", acc, acc_occ, acc_free) 170 | 171 | if plot_train: 172 | plt.scatter(Ks, acc_nums, label='acc') 173 | plt.scatter(Ks, acc_nums_occ, label='acc occupied') 174 | plt.scatter(Ks, acc_nums_free, label='acc free') 175 | plt.legend() 176 | plt.ylim(0,1) 177 | plt.title('Accuracy vs Number of Clusters for Driver Sensor Model') 178 | plt.xlabel('Number of Clusters') 179 | plt.ylabel('Accuracy') 180 | plt.savefig(os.path.join(dir_models, 'acc_clusters.png')) -------------------------------------------------------------------------------- /src/driver_sensor_model/visualize_cvae.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "seed = 123\n", 10 | "\n", 11 | "import numpy as np\n", 12 | "np.random.seed(seed)\n", 13 | "from matplotlib import pyplot as plt\n", 14 | "from mpl_toolkits.axes_grid1 import ImageGrid\n", 15 | "import torch\n", 16 | "torch.manual_seed(seed)\n", 17 | "torch.cuda.manual_seed(seed)\n", 18 | "torch.backends.cudnn.deterministic = True\n", 19 | "torch.backends.cudnn.benchmark = False\n", 20 | "import torch.nn as nn\n", 21 | "torch.autograd.set_detect_anomaly(True)\n", 22 | "import torchvision\n", 23 | "from torchvision import datasets, transforms\n", 24 | "\n", 25 | "import hickle as hkl\n", 26 | "import pickle as pkl\n", 27 | "import pdb\n", 28 | "import os\n", 29 | "\n", 30 | "from sklearn.cluster import KMeans\n", 31 | "from sklearn.mixture import GaussianMixture as GMM\n", 32 | 
"import time\n", 33 | "\n", 34 | "from torch.utils.data import DataLoader\n", 35 | "\n", 36 | "from tqdm import tqdm\n", 37 | "\n", 38 | "from copy import deepcopy\n", 39 | "import io\n", 40 | "import PIL.Image\n", 41 | "import multiprocessing as mp\n", 42 | "import tikzplotlib\n", 43 | "\n", 44 | "os.chdir(\"../..\")\n", 45 | "\n", 46 | "from src.utils.utils_model import to_var\n", 47 | "from src.driver_sensor_model.models_cvae import VAE\n", 48 | "from src.utils.data_generator import *\n", 49 | "from src.utils.interaction_utils import *" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "# Load data.\n", 59 | "nt = 10\n", 60 | "num_states = 7\n", 61 | "grid_shape = (20, 30)\n", 62 | " \n", 63 | "use_cuda = torch.cuda.is_available()\n", 64 | "device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n", 65 | "\n", 66 | "dir = '/data/INTERACTION-Dataset-DR-v1_1/Processed_data_new_goal/driver_sensor_dataset/'\n", 67 | "\n", 68 | "# Test data.\n", 69 | "data_file_states = os.path.join(dir, 'states_test.hkl')\n", 70 | "data_file_grids = os.path.join(dir, 'label_grids_test.hkl')\n", 71 | "data_file_sources = os.path.join(dir, 'sources_test.hkl')\n", 72 | "\n", 73 | "data_test = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,\n", 74 | " batch_size=None, shuffle=False, sequence_start_mode='unique', norm=True)\n", 75 | "\n", 76 | "test_loader = torch.utils.data.DataLoader(data_test,\n", 77 | " batch_size=len(data_test), shuffle=False,\n", 78 | " num_workers=mp.cpu_count()-1, pin_memory=True)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "for batch_x_test, batch_y_test, sources_test in test_loader:\n", 88 | " batch_x_test, batch_y_test_orig = batch_x_test.to(device), batch_y_test.to(device)\n", 89 | " batch_y_test_orig = batch_y_test_orig.view(batch_y_test_orig.shape[0],1,20,30)\n", 90 | " y_full = unnormalize(batch_x_test.cpu().data.numpy(), nt)\n", 91 | " pos_x = y_full[:,:,0]\n", 92 | " pos_y = y_full[:,:,1]\n", 93 | " orientation = y_full[:,:,2]\n", 94 | " cos_theta = np.cos(orientation)\n", 95 | " sin_theta = np.sin(orientation)\n", 96 | " vel_x = y_full[:,:,3]\n", 97 | " vel_y = y_full[:,:,4]\n", 98 | " speed = np.sqrt(vel_x**2 + vel_y**2)\n", 99 | " acc_x = y_full[:,:,5]\n", 100 | " acc_y = y_full[:,:,6]\n", 101 | "\n", 102 | " # Project the acceleration on the orientation vector to get longitudinal acceleration.\n", 103 | " dot_prod = acc_x * cos_theta + acc_y * sin_theta\n", 104 | " sign = np.sign(dot_prod)\n", 105 | " acc_proj_x = dot_prod * cos_theta\n", 106 | " acc_proj_y = dot_prod * sin_theta\n", 107 | "\n", 108 | " acc_proj_sign = sign * np.sqrt(acc_proj_x**2 + acc_proj_y**2)\n", 109 | "\n", 110 | " batch_size = batch_y_test_orig.shape[0]" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "# VAE\n", 120 | "folder_vae = '/models/cvae'\n", 121 | "name_vae = 'lstm_1_Adam_z_100_lr_0.001_rand_123_norm_True_kl_start_0_finish_1.0_center_10000.0_mutual_info_const_alpha_1.5_epochs_30_batch_256'\n", 122 | "\n", 123 | "vae = VAE(\n", 124 | " encoder_layer_sizes_p=[7, 5],\n", 125 | " n_lstms=1,\n", 126 | " latent_size=100,\n", 127 | " decoder_layer_sizes=[256, 600], # used to be 10\n", 128 | " dim=4\n", 129 | " )\n", 130 | "\n", 131 | "vae 
= vae.cuda()\n", 132 | "\n", 133 | "save_filename = os.path.join(folder_vae, name_vae) + 'epoch_30_vae.pt'\n", 134 | "\n", 135 | "with open(save_filename, 'rb') as f:\n", 136 | " state_dict = torch.load(f)\n", 137 | " vae.load_state_dict(state_dict)\n", 138 | "\n", 139 | "vae.eval()\n", 140 | "\n", 141 | "with torch.no_grad():\n", 142 | " recon_y_inf_most_likely, alpha_p, alpha_p_lin, full_c, z = vae.inference(n=1, c=batch_x_test, mode='most_likely')\n", 143 | " recon_x_inf, _, _, _, _ = vae.inference(n=100, c=batch_x_test, mode='all')\n", 144 | "print(recon_y_inf_most_likely.shape, batch_y_test_orig.shape)\n", 145 | "print(torch.max(recon_y_inf_most_likely), torch.min(recon_y_inf_most_likely))\n", 146 | "\n", 147 | "grid_shape = (20,30)\n", 148 | "recon_y_inf_np = np.reshape(recon_y_inf_most_likely.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n", 149 | "y_np = np.reshape(batch_y_test_orig.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n", 150 | "print(recon_y_inf_np.shape, y_np.shape)\n", 151 | "\n", 152 | "recon_y_inf_np_pred = (recon_y_inf_np >= 0.6).astype(float)\n", 153 | "recon_y_inf_np_pred[recon_y_inf_np <= 0.4] = 0.0\n", 154 | "recon_y_inf_np_pred[np.logical_and(recon_y_inf_np < 0.6, recon_y_inf_np > 0.4)] = 0.5\n", 155 | "\n", 156 | "acc = np.mean(recon_y_inf_np_pred == y_np)\n", 157 | "mse = np.mean((recon_y_inf_np - y_np)**2)\n", 158 | "print('Acc: ', acc, 'MSE: ', mse)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "# Visualize all latent classes.\n", 168 | "recon_y_inf_all, _, _, _, _ = vae.inference(n=1, c=batch_x_test, mode='all')\n", 169 | "recon_y_inf_all_np = recon_y_inf_all.cpu().data.numpy()\n", 170 | "\n", 171 | "fig = plt.figure(figsize=(10, 10))\n", 172 | "grid = ImageGrid(fig, 111,\n", 173 | " nrows_ncols=(10, 10),\n", 174 | " axes_pad=0.1,\n", 175 | " )\n", 176 | "\n", 177 | "for ax, im in zip(grid, recon_y_inf_all_np[:,0]):\n", 178 | " ax.matshow(im, cmap='gray_r', vmin=0, vmax=1)\n", 179 | " ax.set_xticks([], [])\n", 180 | " ax.set_yticks([], [])\n", 181 | "plt.savefig('all_latent_classes_vae.png') \n", 182 | "plt.show()" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# Visualize the three most likely modes VAE: slowing down.\n", 192 | "sample = np.where(np.sum(acc_proj_sign < -1.5, axis=-1) > 0)[0][16]\n", 193 | "print('sample: ', sample) \n", 194 | "print('acceleration: ', acc_proj_sign[sample])\n", 195 | "print('speed: ', speed[sample])\n", 196 | "print('probabilitiies: ', torch.sort(alpha_p[sample])[0])\n", 197 | "\n", 198 | "recon_y_inf_1, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=1)\n", 199 | "recon_y_inf_np_1 = np.reshape(recon_y_inf_1.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n", 200 | "\n", 201 | "plt.gca().set_aspect('equal', adjustable='box') # 'datalim'\n", 202 | "plt.scatter(pos_x[sample], pos_y[sample])\n", 203 | "plt.savefig(str(sample) + '_traj_dec.png')\n", 204 | "tikzplotlib.save(str(sample) + '_traj_dec.tex')\n", 205 | "plt.figure()\n", 206 | "plt.matshow(recon_y_inf_np_1[0], cmap='gray_r', vmin=0, vmax=1)\n", 207 | "plt.xticks([])\n", 208 | "plt.yticks([])\n", 209 | "plt.savefig('models/' + str(sample) + '_mode_1_' + str(torch.sort(alpha_p[sample])[0][-1].cpu().data.numpy()) + '_dec.png', pad_inches=0.0)\n", 210 | "plt.show()\n", 211 | "\n", 212 | 
"recon_y_inf_2, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=2)\n", 213 | "recon_y_inf_np_2 = np.reshape(recon_y_inf_2.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n", 214 | "\n", 215 | "plt.matshow(recon_y_inf_np_2[0], cmap='gray_r', vmin=0, vmax=1)\n", 216 | "plt.xticks([])\n", 217 | "plt.yticks([])\n", 218 | "plt.savefig('models/' + str(sample) + '_mode_2_' + str(torch.sort(alpha_p[sample])[0][-2].cpu().data.numpy()) + '_dec.png', pad_inches=0.0)\n", 219 | "plt.show()\n", 220 | "\n", 221 | "plt.matshow(batch_y_test_orig[sample,0].cpu().data.numpy(), cmap='gray_r', vmin=0, vmax=1)\n", 222 | "plt.xticks([])\n", 223 | "plt.yticks([])\n", 224 | "plt.savefig('models/' + str(sample) + '_gt_dec.png', pad_inches=0.0)\n", 225 | "plt.show()" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": null, 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "# Visualize the most likely modes VAE: constant speed.\n", 235 | "print(np.sum(np.sum(np.logical_and(np.abs(acc_proj_sign) < 10, speed > 5.0), axis=-1) > 9), speed.shape)\n", 236 | "sample = np.where(np.logical_and(np.sum(np.logical_and(np.abs(acc_proj_sign) < 0.25, speed > 5.0), axis=-1) > 9, alpha_p.cpu().detach().numpy()[:,27] > 0.3))[0][7]\n", 237 | "print('sample: ', sample) \n", 238 | "print('acceleration: ', acc_proj_sign[sample])\n", 239 | "print('speed: ', speed[sample])\n", 240 | "print('probabilitiies: ', torch.sort(alpha_p[sample])[0])\n", 241 | "\n", 242 | "recon_y_inf_1, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=1)\n", 243 | "recon_y_inf_np_1 = np.reshape(recon_y_inf_1.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n", 244 | "\n", 245 | "plt.gca().set_aspect('equal', adjustable='box') # 'datalim'\n", 246 | "plt.scatter(pos_x[sample], pos_y[sample])\n", 247 | "plt.savefig(str(sample) + '_traj_const.png')\n", 248 | "tikzplotlib.save(str(sample) + '_traj_const.tex')\n", 249 | "plt.figure()\n", 250 | "plt.matshow(recon_y_inf_np_1[0], cmap='gray_r', vmin=0, vmax=1)\n", 251 | "plt.xticks([])\n", 252 | "plt.yticks([])\n", 253 | "plt.savefig('models/' + str(sample) + '_mode_1_' + str(torch.sort(alpha_p[sample])[0][-1].cpu().data.numpy()) + '_const.png', pad_inches=0.0)\n", 254 | "plt.show()\n", 255 | "\n", 256 | "recon_y_inf_2, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=2)\n", 257 | "recon_y_inf_np_2 = np.reshape(recon_y_inf_2.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n", 258 | "\n", 259 | "plt.matshow(recon_y_inf_np_2[0], cmap='gray_r', vmin=0, vmax=1)\n", 260 | "plt.xticks([])\n", 261 | "plt.yticks([])\n", 262 | "plt.savefig('models/' + str(sample) + '_mode_2_' + str(torch.sort(alpha_p[sample])[0][-2].cpu().data.numpy()) + '_const.png', pad_inches=0.0)\n", 263 | "plt.show()\n", 264 | "\n", 265 | "plt.matshow(batch_y_test_orig[sample,0].cpu().data.numpy(), cmap='gray_r', vmin=0, vmax=1)\n", 266 | "plt.xticks([])\n", 267 | "plt.yticks([])\n", 268 | "plt.savefig('models/' + str(sample) + '_gt_const.png', pad_inches=0.0)\n", 269 | "plt.show()" 270 | ] 271 | } 272 | ], 273 | "metadata": { 274 | "kernelspec": { 275 | "display_name": "Python 3", 276 | "language": "python", 277 | "name": "python3" 278 | }, 279 | "language_info": { 280 | "codemirror_mode": { 281 | "name": "ipython", 282 | "version": 3 283 | }, 284 | "file_extension": ".py", 285 | "mimetype": "text/x-python", 286 | "name": "python", 287 | 
"nbconvert_exporter": "python", 288 | "pygments_lexer": "ipython3", 289 | "version": "3.6.10" 290 | } 291 | }, 292 | "nbformat": 4, 293 | "nbformat_minor": 4 294 | } 295 | -------------------------------------------------------------------------------- /src/full_pipeline/main_save_full_pipeline.py: -------------------------------------------------------------------------------- 1 | # Code to save full occlusion inference pipeline. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset. 2 | 3 | try: 4 | import lanelet2 5 | use_lanelet2_lib = True 6 | except: 7 | import warnings 8 | string = "Could not import lanelet2. It must be built and sourced, " + \ 9 | "see https://github.com/fzi-forschungszentrum-informatik/Lanelet2 for details." 10 | warnings.warn(string) 11 | print("Using visualization without lanelet2.") 12 | use_lanelet2_lib = False 13 | from utils import map_vis_without_lanelet 14 | 15 | import argparse 16 | import os 17 | import time 18 | import matplotlib.pyplot as plt 19 | from matplotlib.widgets import Button 20 | import hickle as hkl 21 | import pickle as pkl 22 | import pdb 23 | import numpy as np 24 | import csv 25 | 26 | seed = 123 27 | 28 | import numpy as np 29 | np.random.seed(seed) 30 | from matplotlib import pyplot as plt 31 | import torch 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = False 36 | import torch.nn as nn 37 | torch.autograd.set_detect_anomaly(True) 38 | 39 | import io 40 | import PIL.Image 41 | from tqdm import tqdm 42 | import time 43 | 44 | import argparse 45 | import pandas as pd 46 | import seaborn as sns 47 | import matplotlib.pyplot as plt 48 | from torchvision import datasets, transforms 49 | from torch.utils.data import DataLoader 50 | from torch.utils.tensorboard import SummaryWriter 51 | from torch.optim.lr_scheduler import ExponentialLR 52 | from collections import OrderedDict, defaultdict 53 | 54 | os.chdir("../..") 55 | 56 | from src.utils import dataset_reader 57 | from src.utils import dataset_types 58 | from src.utils import map_vis_lanelet2 59 | from src.utils import tracks_save 60 | # from src.utils import dict_utils 61 | from src.driver_sensor_model.models_cvae import VAE 62 | from src.utils.interaction_utils import * 63 | 64 | import torch._utils 65 | try: 66 | torch._utils._rebuild_tensor_v2 67 | except AttributeError: 68 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 69 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 70 | tensor.requires_grad = requires_grad 71 | tensor._backward_hooks = backward_hooks 72 | return tensor 73 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 74 | 75 | 76 | def update_plot(): 77 | global timestamp, track_dictionary, pedestrian_dictionary,\ 78 | data, car_ids, sensor_grids, id_grids, label_grids, driver_sensor_data, driver_sensor_state_data, driver_sensor_state, driver_sensor_state_dict,\ 79 | min_max_xy, models, results, mode, model 80 | 81 | # Update text and tracks based on current timestamp. 
82 | if (timestamp < timestamp_min): 83 | pdb.set_trace() 84 | assert(timestamp <= timestamp_max), "timestamp=%i" % timestamp 85 | assert(timestamp >= timestamp_min), "timestamp=%i" % timestamp 86 | assert(timestamp % dataset_types.DELTA_TIMESTAMP_MS == 0), "timestamp=%i" % timestamp 87 | tracks_save.update_objects_plot(timestamp, track_dict=track_dictionary, pedest_dict=pedestrian_dictionary, 88 | data=data, car_ids=car_ids, sensor_grids=sensor_grids, id_grids=id_grids, label_grids=label_grids, 89 | driver_sensor_data = driver_sensor_data, driver_sensor_state_data=driver_sensor_state_data, driver_sensor_state = driver_sensor_state, 90 | endpoint=min_max_xy, models=models, results=results, mode=mode, model=model) 91 | 92 | if __name__ == "__main__": 93 | 94 | # Provide data to be visualized. 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument("scenario_name", type=str, help="Name of the scenario (to identify map and folder for track " 97 | "files)", nargs="?") 98 | parser.add_argument("track_file_number", type=int, help="Number of the track file (int)", default=0, nargs="?") 99 | parser.add_argument("load_mode", type=str, help="Dataset to load (vehicle, pedestrian, or both)", default="both", 100 | nargs="?") 101 | parser.add_argument("--mode", type=str, help="Sensor fusion mode: evidential or average", nargs="?") 102 | parser.add_argument("--model", type=str, help="Name of the model: vae, gmm, kmeans", nargs="?") 103 | args = parser.parse_args() 104 | 105 | if args.load_mode != "vehicle" and args.load_mode != "pedestrian" and args.load_mode != "both": 106 | raise IOError("Invalid load command. Use 'vehicle', 'pedestrian', or 'both'") 107 | 108 | # Load test folder. 109 | error_string = "" 110 | tracks_dir = "/data/INTERACTION-Dataset-DR-v1_1/recorded_trackfiles" 111 | maps_dir = "/data/INTERACTION-Dataset-DR-v1_1/maps" 112 | scenario_name = 'DR_USA_Intersection_GL' 113 | home = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/" 114 | test_set_dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split' 115 | test_set_file = 'ego_test_set.csv' 116 | 117 | ego_files = [] 118 | with open(os.path.join(test_set_dir, test_set_file)) as csv_file: 119 | csv_reader = csv.reader(csv_file, delimiter=',') 120 | for row in csv_reader: 121 | ego_files.append(row[0].split(home)[-1]) 122 | 123 | number = 0 124 | ego_file = ego_files[number] 125 | 126 | # Load the models.
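# A minimal sketch of how the clustering-based driver sensor models saved by train_gmm.py /
# train_kmeans.py (each pickle stores [model, p_m_a_np]) would be applied to one observed driver
# state history; variable names here are illustrative only:
#   model, p_m_a = models['gmm']                   # or models['kmeans']
#   cluster_id = model.predict(state_history)      # state_history: (1, nt*7) flattened driver states
#   occupancy = p_m_a[cluster_id].reshape(20, 30)  # per-cell P(occupied | cluster)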
127 | models = dict() 128 | folder_vae = '/model/cvae/' 129 | name_vae = 'lstm_1_Adam_z_100_lr_0.001_rand_123_norm_True_kl_start_0_finish_1.0_center_10000.0_mutual_info_const_alpha_1.5_epochs_30_batch_256' 130 | 131 | models['vae'] = VAE( 132 | encoder_layer_sizes_p=[7, 5], 133 | n_lstms=1, 134 | latent_size=100, 135 | dim=4 136 | ) 137 | 138 | models['vae'] = models['vae'].cuda() 139 | 140 | save_filename = os.path.join(folder_vae, name_vae) + 'epoch_30_vae.pt' 141 | 142 | with open(save_filename, 'rb') as f: 143 | state_dict = torch.load(f) 144 | models['vae'].load_state_dict(state_dict) 145 | models['vae'].eval() 146 | 147 | folder_kmeans = '/model/kmeans/' 148 | models['kmeans'] = pkl.load(open(os.path.join(folder_kmeans,"clusters_kmeans_K_100.pkl"), "rb" ) ) 149 | 150 | folder_gmm = '/model/gmm/' 151 | models['gmm'] = pkl.load(open(os.path.join(folder_gmm, "GMM_K_100_reg_covar_0_001.pkl"), "rb" ) ) 152 | 153 | if args.model == 'vae': 154 | folder_model = folder_vae 155 | elif args.model == 'gmm': 156 | folder_model = folder_gmm 157 | elif args.model == 'kmeans': 158 | folder_model = folder_kmeans 159 | 160 | # Get the driver sensor data. 161 | for ego_file in tqdm(ego_files): 162 | print(ego_file) 163 | ego_run = int(ego_file.split('_run_')[-1].split('_ego_')[0]) 164 | if scenario_name[-2:] == 'GL': 165 | middle = scenario_name 166 | track_file_number = ego_file.split("_run_")[0][-3:] 167 | if os.path.exists(os.path.join(home, ego_file)): 168 | data = pkl.load(open(os.path.join(home, ego_file), 'rb')) 169 | else: 170 | continue 171 | 172 | driver_sensor_data = dict() 173 | driver_sensor_state_data = dict() 174 | driver_sensor_state = dict() 175 | for file in os.listdir(os.path.join(home)): 176 | if scenario_name[-2:] == 'GL': 177 | file_split = file.split("_run_") 178 | if track_file_number != file_split[0][-3:]: 179 | continue 180 | [run, vtype, _, vid] = file_split[-1][:-4].split('_') 181 | run = int(run) 182 | vid = int(vid) 183 | 184 | if (run == ego_run) and (vtype == 'ref'): 185 | driver_sensor_data[vid] = pkl.load(open(os.path.join(home, file), 'rb')) 186 | for item in driver_sensor_data[vid][1:]: 187 | timestamp = item[0] 188 | x = item[2] 189 | y = item[3] 190 | orientation = item[4] 191 | vx = item[5] 192 | vy = item[6] 193 | ax = item[7] 194 | ay = item[8] 195 | if vid in driver_sensor_state_data.keys(): 196 | driver_sensor_state_data[vid].append(np.array([timestamp, x, y, orientation, vx, vy, ax, ay])) 197 | else: 198 | driver_sensor_state_data[vid] = [np.array([timestamp, x, y, orientation, vx, vy, ax, ay])] 199 | 200 | lanelet_map_ending = ".osm" 201 | lanelet_map_file = maps_dir + "/" + scenario_name + lanelet_map_ending 202 | scenario_dir = tracks_dir + "/" + scenario_name 203 | track_file_prefix = "vehicle_tracks_" 204 | track_file_ending = ".csv" 205 | track_file_name = scenario_dir + "/" + track_file_prefix + str(track_file_number).zfill(3) + track_file_ending 206 | pedestrian_file_prefix = "pedestrian_tracks_" 207 | pedestrian_file_ending = ".csv" 208 | pedestrian_file_name = scenario_dir + "/" + pedestrian_file_prefix + str(track_file_number).zfill(3) + pedestrian_file_ending 209 | if not os.path.isdir(tracks_dir): 210 | error_string += "Did not find track file directory \"" + tracks_dir + "\"\n" 211 | if not os.path.isdir(maps_dir): 212 | error_string += "Did not find map file directory \"" + tracks_dir + "\"\n" 213 | if not os.path.isdir(scenario_dir): 214 | error_string += "Did not find scenario directory \"" + scenario_dir + "\"\n" 215 | if not 
os.path.isfile(lanelet_map_file): 216 | error_string += "Did not find lanelet map file \"" + lanelet_map_file + "\"\n" 217 | if not os.path.isfile(track_file_name): 218 | error_string += "Did not find track file \"" + track_file_name + "\"\n" 219 | if not os.path.isfile(pedestrian_file_name): 220 | flag_ped = 0 221 | else: 222 | flag_ped = 1 223 | if error_string != "": 224 | error_string += "Type --help for help." 225 | raise IOError(error_string) 226 | 227 | # Load and draw the lanelet2 map, either with or without the lanelet2 library. 228 | lat_origin = 0. # Origin is necessary to correctly project the lat lon values in the osm file to the local. 229 | lon_origin = 0. 230 | print("Loading map...") 231 | fig, axes = plt.subplots(1, 1) 232 | min_max_xy = map_vis_without_lanelet.draw_map_without_lanelet(lanelet_map_file, axes, lat_origin, lon_origin) 233 | 234 | # Expand map size. 235 | min_max_xy[0:2] -= 25.0 236 | min_max_xy[2:] += 25.0 237 | 238 | # Load the tracks. 239 | print("Loading tracks...") 240 | track_dictionary = None 241 | pedestrian_dictionary = None 242 | if args.load_mode == 'both': 243 | track_dictionary = dataset_reader.read_tracks(track_file_name) 244 | if flag_ped: 245 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name) 246 | 247 | elif args.load_mode == 'vehicle': 248 | track_dictionary = dataset_reader.read_tracks(track_file_name) 249 | elif args.load_mode == 'pedestrian': 250 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name) 251 | 252 | timestamp_min = 1e9 253 | timestamp_max = 0 254 | 255 | # Set the time to the run and get the sensor vehicles. 256 | ego_id = data[1][1] 257 | car_ids = {} 258 | sensor_grids = {} 259 | label_grids = {} 260 | id_grids = {} 261 | for i in range(1,len(data)): 262 | timestamp_min = min(timestamp_min, data[i][0]) 263 | timestamp_max = max(timestamp_max, data[i][0]) 264 | 265 | if data[i][0] in car_ids.keys(): 266 | car_ids[data[i][0]].append(data[i][1]) 267 | else: 268 | car_ids[data[i][0]] = [data[i][1]] 269 | 270 | if data[i][1] == ego_id: 271 | sensor_grids[data[i][0]] = data[i][-2] 272 | id_grids[data[i][0]] = data[i][-3] 273 | label_grids[data[i][0]] = data[i][-4] 274 | 275 | args.start_timestamp = timestamp_min 276 | 277 | # Results. 278 | results = dict() 279 | results['ego_sensor'] = [] 280 | results['ego_label'] = [] 281 | results['vae'] = [] 282 | results['timestamp'] = [] 283 | results['source'] = [] 284 | results['run_time'] = [] 285 | results['all_latent_classes'] = [] 286 | results['ref_local_xy'] = [] 287 | results['ego_local_xy'] = [] 288 | results['alpha_p'] = [] 289 | results['ego_sensor_dst'] = [] 290 | results['endpoint'] = [] 291 | results['res'] = [] 292 | mode = args.mode 293 | model = args.model 294 | 295 | print("Saving...") 296 | timestamp = args.start_timestamp 297 | 298 | while timestamp < timestamp_max: 299 | update_plot() 300 | timestamp += dataset_types.DELTA_TIMESTAMP_MS 301 | 302 | # Clear all variables for the next scenario 303 | del(timestamp_min) 304 | del(timestamp_max) 305 | del(timestamp) 306 | 307 | for i in range(len(results['timestamp'])): 308 | results['source'].append(ego_file[:-4]) 309 | 310 | # Save the data for each run. 
311 | folder_model_new = os.path.join(folder_model, 'full_pipeline_' + model + '_' + mode) 312 | if not os.path.isdir(folder_model_new): 313 | os.mkdir(folder_model_new) 314 | pkl.dump(results, open(os.path.join(folder_model_new, ego_file[:-4] + '_ego_results.pkl'), "wb")) 315 | 316 | -------------------------------------------------------------------------------- /src/full_pipeline/main_visualize_full_pipeline.py: -------------------------------------------------------------------------------- 1 | # Code to visualize full occlusion inference pipeline. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset. 2 | 3 | try: 4 | import lanelet2 5 | use_lanelet2_lib = True 6 | except: 7 | import warnings 8 | string = "Could not import lanelet2. It must be built and sourced, " + \ 9 | "see https://github.com/fzi-forschungszentrum-informatik/Lanelet2 for details." 10 | warnings.warn(string) 11 | print("Using visualization without lanelet2.") 12 | use_lanelet2_lib = False 13 | from utils import map_vis_without_lanelet 14 | 15 | import argparse 16 | import os 17 | import time 18 | import matplotlib.pyplot as plt 19 | from matplotlib.widgets import Button 20 | import hickle as hkl 21 | import pickle as pkl 22 | import pdb 23 | import numpy as np 24 | import csv 25 | 26 | seed = 123 27 | 28 | import numpy as np 29 | np.random.seed(seed) 30 | from matplotlib import pyplot as plt 31 | import torch 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.backends.cudnn.deterministic = True 35 | torch.backends.cudnn.benchmark = False 36 | import torch.nn as nn 37 | torch.autograd.set_detect_anomaly(True) 38 | 39 | import io 40 | import PIL.Image 41 | from tqdm import tqdm 42 | import time 43 | 44 | import argparse 45 | import pandas as pd 46 | import seaborn as sns 47 | import matplotlib.pyplot as plt 48 | from collections import OrderedDict, defaultdict 49 | 50 | from src.utils import dataset_reader 51 | from src.utils import dataset_types 52 | from src.utils import map_vis_lanelet2 53 | from src.utils import tracks_vis 54 | # from src.utils import dict_utils 55 | from src.driver_sensor_model.models_cvae import VAE 56 | from src.utils.interaction_utils import * 57 | 58 | import torch._utils 59 | try: 60 | torch._utils._rebuild_tensor_v2 61 | except AttributeError: 62 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 63 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 64 | tensor.requires_grad = requires_grad 65 | tensor._backward_hooks = backward_hooks 66 | return tensor 67 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 68 | 69 | 70 | def update_plot(): 71 | global fig, timestamp, title_text, track_dictionary, patches_dict, text_dict, axes, pedestrian_dictionary,\ 72 | data, car_ids, sensor_grids, id_grids, label_grids, grids_dict, driver_sensor_data, driver_sensor_state_data, driver_sensor_state, driver_sensor_state_dict,\ 73 | min_max_xy, models, results, mode, model 74 | 75 | # Update text and tracks based on current timestamp. 
76 | if (timestamp < timestamp_min): 77 | pdb.set_trace() 78 | assert(timestamp <= timestamp_max), "timestamp=%i" % timestamp 79 | assert(timestamp >= timestamp_min), "timestamp=%i" % timestamp 80 | assert(timestamp % dataset_types.DELTA_TIMESTAMP_MS == 0), "timestamp=%i" % timestamp 81 | title_text.set_text("\nts = {}".format(timestamp)) 82 | tracks_vis.update_objects_plot(timestamp, patches_dict, text_dict, axes, track_dict=track_dictionary, pedest_dict=pedestrian_dictionary, 83 | data=data, car_ids=car_ids, sensor_grids=sensor_grids, id_grids=id_grids, label_grids=label_grids, grids_dict=grids_dict, 84 | driver_sensor_data = driver_sensor_data, driver_sensor_state_data=driver_sensor_state_data, driver_sensor_state = driver_sensor_state, driver_sensor_state_dict=driver_sensor_state_dict, 85 | endpoint=min_max_xy, models=models, mode=mode, model=model) 86 | 87 | fig.canvas.draw() 88 | 89 | 90 | def start_playback(): 91 | global timestamp, timestamp_min, timestamp_max, playback_stopped 92 | playback_stopped = False 93 | plt.ion() 94 | while timestamp < timestamp_max and not playback_stopped: 95 | timestamp += dataset_types.DELTA_TIMESTAMP_MS 96 | start_time = time.time() 97 | update_plot() 98 | end_time = time.time() 99 | diff_time = end_time - start_time 100 | # plt.pause(max(0.001, dataset_types.DELTA_TIMESTAMP_MS / 1000. - diff_time)) 101 | plt.ioff() 102 | 103 | 104 | class FrameControlButton(object): 105 | def __init__(self, position, label): 106 | self.ax = plt.axes(position) 107 | self.label = label 108 | self.button = Button(self.ax, label) 109 | self.button.on_clicked(self.on_click) 110 | 111 | def on_click(self, event): 112 | global timestamp, timestamp_min, timestamp_max, playback_stopped 113 | 114 | if self.label == "play": 115 | if not playback_stopped: 116 | return 117 | else: 118 | start_playback() 119 | return 120 | playback_stopped = True 121 | if self.label == "<<": 122 | timestamp -= 10*dataset_types.DELTA_TIMESTAMP_MS 123 | elif self.label == "<": 124 | timestamp -= dataset_types.DELTA_TIMESTAMP_MS 125 | elif self.label == ">": 126 | timestamp += dataset_types.DELTA_TIMESTAMP_MS 127 | elif self.label == ">>": 128 | timestamp += 10*dataset_types.DELTA_TIMESTAMP_MS 129 | timestamp = min(timestamp, timestamp_max) 130 | timestamp = max(timestamp, timestamp_min) 131 | update_plot() 132 | 133 | if __name__ == "__main__": 134 | 135 | # Provide data to be visualized. 136 | parser = argparse.ArgumentParser() 137 | parser.add_argument("scenario_name", type=str, help="Name of the scenario (to identify map and folder for track " 138 | "files)", nargs="?") 139 | parser.add_argument("track_file_number", type=int, help="Number of the track file (int)", default=0, nargs="?") 140 | parser.add_argument("load_mode", type=str, help="Dataset to load (vehicle, pedestrian, or both)", default="both", 141 | nargs="?") 142 | parser.add_argument("--start_timestamp", type=int, nargs="?") 143 | parser.add_argument("--mode", type=str, default='evidential', help="Sensor fusion mode: evidential or average", nargs="?") 144 | parser.add_argument("--model", type=str, default='vae', help="Name of the model: vae, gmm, kmeans", nargs="?") 145 | args = parser.parse_args() 146 | 147 | if args.load_mode != "vehicle" and args.load_mode != "pedestrian" and args.load_mode != "both": 148 | raise IOError("Invalid load command. Use 'vehicle', 'pedestrian', or 'both'") 149 | 150 | # Load test folder.
151 | error_string = "" 152 | tracks_dir = "/data/INTERACTION-Dataset-DR-v1_1/recorded_trackfiles" 153 | maps_dir = "/data/INTERACTION-Dataset-DR-v1_1/maps" 154 | scenario_name = 'DR_USA_Intersection_GL' 155 | home = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/" 156 | test_set_dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split' 157 | test_set_file = 'ego_test_set.csv' 158 | 159 | ego_files = [] 160 | with open(os.path.join(test_set_dir, test_set_file)) as csv_file: 161 | csv_reader = csv.reader(csv_file, delimiter=',') 162 | for row in csv_reader: 163 | ego_files.append(row[0].split(home)[-1]) 164 | 165 | number = 0 166 | ego_file = ego_files[number] 167 | 168 | # Load the models. 169 | models = dict() 170 | folder_vae = '/model/cvae/' 171 | name_vae = 'lstm_1_Adam_z_100_lr_0.001_rand_123_norm_True_kl_start_0_finish_1.0_center_10000.0_mutual_info_const_alpha_1.5_epochs_30_batch_256' 172 | 173 | models['vae'] = VAE( 174 | encoder_layer_sizes_p=[7, 5], 175 | n_lstms=1, 176 | latent_size=100, 177 | dim=4 178 | ) 179 | 180 | models['vae'] = models['vae'].cuda() 181 | 182 | save_filename = os.path.join(folder_vae, name_vae) + 'epoch_30_vae.pt' 183 | 184 | with open(save_filename, 'rb') as f: 185 | state_dict = torch.load(f) 186 | models['vae'].load_state_dict(state_dict) 187 | models['vae'].eval() 188 | 189 | folder_kmeans = '/model/kmeans/' 190 | models['kmeans'] = pkl.load(open(os.path.join(folder_kmeans,"clusters_kmeans_K_100.pkl"), "rb" ) ) 191 | 192 | folder_gmm = '/model/gmm/' 193 | models['gmm'] = pkl.load(open(os.path.join(folder_gmm, "GMM_K_100_reg_covar_0_001.pkl"), "rb" ) ) 194 | 195 | if args.model == 'vae': 196 | folder_model = folder_vae 197 | elif args.model == 'gmm': 198 | folder_model = folder_gmm 199 | elif args.model == 'kmeans': 200 | folder_model = folder_kmeans 201 | 202 | for ego_file in tqdm(ego_files): 203 | print(ego_file) 204 | 205 | # Visualize the paper and appendix scenarios. 206 | if (ego_file != 'DR_USA_Intersection_GL_037_run_54_ego_vehicle_83.pkl') and (ego_file != 'DR_USA_Intersection_GL_021_run_60_ego_vehicle_108.pkl'): 207 | continue 208 | 209 | ego_run = int(ego_file.split('_run_')[-1].split('_ego_')[0]) 210 | if scenario_name[-2:] == 'VA': 211 | track_file_number = int(ego_file.split(scenario_name)[-1][1:4]) 212 | middle = scenario_name + "/" + scenario_name + "_00" + str(track_file_number) 213 | if os.path.exists(os.path.join(home, middle, ego_file)): 214 | data = hkl.load(os.path.join(home, middle, ego_file)) 215 | else: 216 | continue 217 | elif scenario_name[-2:] == 'GL': 218 | middle = scenario_name 219 | track_file_number = ego_file.split("_run_")[0][-3:] 220 | if os.path.exists(os.path.join(home, ego_file)): 221 | data = pkl.load(open(os.path.join(home, ego_file), 'rb')) 222 | else: 223 | continue 224 | 225 | # Get the driver sensor data. 
226 | driver_sensor_data = dict() 227 | driver_sensor_state_data = dict() 228 | driver_sensor_state = dict() 229 | for file in os.listdir(os.path.join(home)): 230 | if scenario_name[-2:] == 'VA': 231 | file_split = file.split(scenario_name + "_00" + str(track_file_number) + "_run_") 232 | [run, vtype, _, vid] = file_split[-1][:-4].split('_') 233 | elif scenario_name[-2:] == 'GL': 234 | file_split = file.split("_run_") 235 | if track_file_number != file_split[0][-3:]: 236 | continue 237 | [run, vtype, _, vid] = file_split[-1][:-4].split('_') 238 | run = int(run) 239 | vid = int(vid) 240 | 241 | if (run == ego_run) and (vtype == 'ref'): 242 | driver_sensor_data[vid] = pkl.load(open(os.path.join(home, file), 'rb')) 243 | for item in driver_sensor_data[vid][1:]: 244 | timestamp = item[0] 245 | x = item[2] 246 | y = item[3] 247 | orientation = item[4] 248 | vx = item[5] 249 | vy = item[6] 250 | ax = item[7] 251 | ay = item[8] 252 | if vid in driver_sensor_state_data.keys(): 253 | driver_sensor_state_data[vid].append(np.array([timestamp, x, y, orientation, vx, vy, ax, ay])) 254 | else: 255 | driver_sensor_state_data[vid] = [np.array([timestamp, x, y, orientation, vx, vy, ax, ay])] 256 | 257 | lanelet_map_ending = ".osm" 258 | lanelet_map_file = maps_dir + "/" + scenario_name + lanelet_map_ending 259 | scenario_dir = tracks_dir + "/" + scenario_name 260 | track_file_prefix = "vehicle_tracks_" 261 | track_file_ending = ".csv" 262 | track_file_name = scenario_dir + "/" + track_file_prefix + str(track_file_number).zfill(3) + track_file_ending 263 | pedestrian_file_prefix = "pedestrian_tracks_" 264 | pedestrian_file_ending = ".csv" 265 | pedestrian_file_name = scenario_dir + "/" + pedestrian_file_prefix + str(track_file_number).zfill(3) + pedestrian_file_ending 266 | if not os.path.isdir(tracks_dir): 267 | error_string += "Did not find track file directory \"" + tracks_dir + "\"\n" 268 | if not os.path.isdir(maps_dir): 269 | error_string += "Did not find map file directory \"" + tracks_dir + "\"\n" 270 | if not os.path.isdir(scenario_dir): 271 | error_string += "Did not find scenario directory \"" + scenario_dir + "\"\n" 272 | if not os.path.isfile(lanelet_map_file): 273 | error_string += "Did not find lanelet map file \"" + lanelet_map_file + "\"\n" 274 | if not os.path.isfile(track_file_name): 275 | error_string += "Did not find track file \"" + track_file_name + "\"\n" 276 | if not os.path.isfile(pedestrian_file_name): 277 | flag_ped = 0 278 | else: 279 | flag_ped = 1 280 | if error_string != "": 281 | error_string += "Type --help for help." 282 | raise IOError(error_string) 283 | 284 | # Create a figure. 285 | fig, axes = plt.subplots(1, 1) 286 | fig.canvas.set_window_title("Interaction Dataset Visualization") 287 | 288 | # Load and draw the lanelet2 map, either with or without the lanelet2 library. 289 | lat_origin = 0. # Origin is necessary to correctly project the lat lon values in the osm file to the local. 290 | lon_origin = 0. 291 | print("Loading map...") 292 | if use_lanelet2_lib: 293 | projector = lanelet2.projection.UtmProjector(lanelet2.io.Origin(lat_origin, lon_origin)) 294 | laneletmap = lanelet2.io.load(lanelet_map_file, projector) 295 | map_vis_lanelet2.draw_lanelet_map(laneletmap, axes) 296 | else: 297 | min_max_xy = map_vis_without_lanelet.draw_map_without_lanelet(lanelet_map_file, axes, lat_origin, lon_origin) 298 | 299 | # Expand map size. 300 | min_max_xy[0:2] -= 25.0 301 | min_max_xy[2:] += 25.0 302 | 303 | # Load the tracks. 
304 | print("Loading tracks...") 305 | track_dictionary = None 306 | pedestrian_dictionary = None 307 | if args.load_mode == 'both': 308 | track_dictionary = dataset_reader.read_tracks(track_file_name) 309 | if flag_ped: 310 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name) 311 | 312 | elif args.load_mode == 'vehicle': 313 | track_dictionary = dataset_reader.read_tracks(track_file_name) 314 | elif args.load_mode == 'pedestrian': 315 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name) 316 | 317 | if (pedestrian_dictionary == None) or len(pedestrian_dictionary.items()) == 0: 318 | continue 319 | 320 | timestamp_min = 1e9 321 | timestamp_max = 0 322 | 323 | # Set the time to the run and get the sensor vehicles 324 | ego_id = data[1][1] 325 | car_ids = {} 326 | sensor_grids = {} 327 | label_grids = {} 328 | id_grids = {} 329 | for i in range(1,len(data)): 330 | timestamp_min = min(timestamp_min, data[i][0]) 331 | timestamp_max = max(timestamp_max, data[i][0]) 332 | 333 | if data[i][0] in car_ids.keys(): 334 | car_ids[data[i][0]].append(data[i][1]) 335 | else: 336 | car_ids[data[i][0]] = [data[i][1]] 337 | 338 | if data[i][1] == ego_id: 339 | sensor_grids[data[i][0]] = data[i][-2] 340 | id_grids[data[i][0]] = data[i][-3] 341 | label_grids[data[i][0]] = data[i][-4] 342 | 343 | if ego_file == 'DR_USA_Intersection_GL_037_run_54_ego_vehicle_83.pkl': 344 | args.start_timestamp = 117500 345 | elif ego_file == 'DR_USA_Intersection_GL_021_run_60_ego_vehicle_108.pkl': 346 | args.start_timestamp = 168700 347 | else: 348 | args.start_timestamp = timestamp_min 349 | 350 | mode = args.mode 351 | model = args.model 352 | 353 | button_pp = FrameControlButton([0.2, 0.05, 0.05, 0.05], '<<') 354 | button_p = FrameControlButton([0.27, 0.05, 0.05, 0.05], '<') 355 | button_f = FrameControlButton([0.4, 0.05, 0.05, 0.05], '>') 356 | button_ff = FrameControlButton([0.47, 0.05, 0.05, 0.05], '>>') 357 | 358 | button_play = FrameControlButton([0.6, 0.05, 0.1, 0.05], 'play') 359 | button_pause = FrameControlButton([0.71, 0.05, 0.1, 0.05], 'pause') 360 | 361 | # Storage for track visualization. 362 | patches_dict = dict() 363 | text_dict = dict() 364 | grids_dict = dict() 365 | driver_sensor_state_dict = dict() 366 | 367 | # Visualize tracks. 368 | print("Plotting...") 369 | timestamp = args.start_timestamp 370 | title_text = fig.suptitle("") 371 | playback_stopped = True 372 | update_plot() 373 | plt.show() 374 | 375 | # Clear all variables for the next scenario. 376 | plt.clf() 377 | plt.close() 378 | del(timestamp_min) 379 | del(timestamp_max) 380 | del(timestamp) 381 | 382 | -------------------------------------------------------------------------------- /src/preprocess/generate_data.py: -------------------------------------------------------------------------------- 1 | # INTERACTION dataset processing code. Some code snippets are adapted from: https://github.com/interaction-dataset/interaction-dataset. 2 | 3 | import argparse 4 | try: 5 | import lanelet2 6 | use_lanelet2_lib = True 7 | except: 8 | import warnings 9 | string = "Could not import lanelet2. It must be built and sourced, " + \ 10 | "see https://github.com/fzi-forschungszentrum-informatik/Lanelet2 for details." 
11 | warnings.warn(string) 12 | print("Using visualization without lanelet2.") 13 | use_lanelet2_lib = False 14 | from utils import map_vis_without_lanelet 15 | 16 | from utils import dataset_reader 17 | from utils import dataset_types 18 | # from utils import dict_utils 19 | from utils import map_vis_lanelet2 20 | from utils.grid_utils import SceneObjects, global_grid, AllObjects, generateLabelGrid, generateSensorGrid 21 | from utils.dataset_types import Track, MotionState 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | import csv 25 | from collections import defaultdict 26 | import os 27 | import pickle as pkl 28 | from datetime import datetime 29 | import glob 30 | import random 31 | from tqdm import tqdm 32 | import time 33 | np.random.seed(123) 34 | 35 | 36 | def getstate(timestamp, track_dict, id): 37 | for key, value in track_dict.items(): 38 | if key==id: 39 | return value.motion_states[timestamp] 40 | 41 | def getmap(maps_dir, scenario_data): 42 | # load and draw the lanelet2 map, either with or without the lanelet2 library 43 | fig, axes = plt.subplots(1) 44 | lat_origin = 0. # origin is necessary to correctly project the lat lon values in the osm file to the local 45 | lon_origin = 0. # coordinates in which the tracks are provided; we decided to use (0|0) for every scenario 46 | lanelet_map_ending = ".osm" 47 | lanelet_map_file = maps_dir + "/" + scenario_data + lanelet_map_ending 48 | # print("Loading map...") 49 | if use_lanelet2_lib: 50 | projector = lanelet2.projection.UtmProjector(lanelet2.io.Origin(lat_origin, lon_origin)) 51 | laneletmap = lanelet2.io.load(lanelet_map_file, projector) 52 | map_vis_lanelet2.draw_lanelet_map(laneletmap, axes) 53 | else: 54 | min_max_xy = map_vis_without_lanelet.draw_map_without_lanelet(lanelet_map_file, axes, lat_origin, lon_origin) 55 | plt.close(fig) 56 | return min_max_xy 57 | 58 | # Get data for the observed drivers. 59 | def vis_state(vis_datas, ref_datas, object_id, label_grid_ego, prev_id_grid_ego, sensor_grid_ego, prev_sensor_grid_ego, track_dict, stamp, start_timestamp, track_pedes_dict=None, pedes_id=None): # For sensor vis & vis vis. Not for ego vis 60 | global vis_ax, vis_ay, vis_ids, ego_id, num_vis, res, gridglobal_x, gridglobal_y 61 | 62 | # Create a mask for where there is occupied space in the sensor grid that is not the ego vehicle. 63 | mask = np.where(np.logical_and(sensor_grid_ego==1., label_grid_ego[3]!=ego_id), True, False) 64 | 65 | # Get the visible car ids around the ego vehicle. 66 | vis_temp = np.array(np.unique(label_grid_ego[3,mask]),dtype=int) 67 | 68 | # Remove the pedestrians from the labels. 69 | vis_temp = vis_temp[vis_temp >= 0] 70 | 71 | # Create a mask for the previous timestamp sensor grid where there is occupied space that is not the ego vehicle. 72 | mask = np.where(np.logical_and(prev_sensor_grid_ego==1., prev_id_grid_ego!=ego_id), True, False) 73 | prev_vis_temp = np.array(np.unique(prev_id_grid_ego[mask]),dtype=int) 74 | prev_vis_temp = prev_vis_temp[prev_vis_temp >= 0] 75 | 76 | # Compute the number of new visible ids at this timestamp. 77 | num_new = 0 78 | # Loop through the visible ids at the current timestamp and add them to vis_ids. 79 | for vis_t in vis_temp: 80 | if vis_t not in vis_ids: 81 | # Add new visible vehicle to this vis_ids array. 82 | vis_ids.append(vis_t) 83 | num_new += 1 84 | 85 | # All visible cars till the current timestamp. 
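# vis_ids accumulates every vehicle id observed so far in this run; the lists built below stay index-aligned
# with it, using None placeholders for ids that are not visible at both the current and previous timestamp
# (those entries are skipped when the grids are generated further down).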
86 | num_vis = len(vis_ids) 87 | vis_ms = [] 88 | prev_vis_ms = [] 89 | widths = [] 90 | lengths = [] 91 | 92 | for id in vis_ids: 93 | # Check if the ID is present in the current and previous timestamp. This is important for acceleration computations. 94 | if id in vis_temp and id in prev_vis_temp: 95 | vis_ms.append(getstate(stamp, track_dict, id)) 96 | widths.append(track_dict[id].width) 97 | lengths.append(track_dict[id].length) 98 | prev_vis_ms.append(getstate(stamp-100, track_dict, id)) 99 | else: 100 | # If the ID was not present in a previous timestamp, set to None. 101 | vis_ms.append(None) 102 | prev_vis_ms.append(None) 103 | widths.append(None) 104 | lengths.append(None) 105 | 106 | # If there are new visible cars, extend the lists for accelerations and data. 107 | if num_new != 0: 108 | vis_ax = vis_ax+ [0]*num_new 109 | vis_ay = vis_ay+ [0]*num_new 110 | vis_datas = vis_datas + [[]]*num_new 111 | ref_datas = ref_datas + [[]]*num_new 112 | 113 | # Compute the accelerations for all the visible vehicles. 114 | for k in range(num_vis): 115 | # Only compute the acceleration if there are two timestamps of data for the vehicle. 116 | if vis_ms[k] is not None: 117 | vis_ax[k] = (vis_ms[k].vx - prev_vis_ms[k].vx) / 0.1 118 | vis_ay[k] = (vis_ms[k].vy - prev_vis_ms[k].vy) / 0.1 119 | 120 | # Loop through all the visible vehicles. 121 | for k in range(num_vis): 122 | # Check if the vehicle is visible in at least two timestamps. 123 | if vis_ids[k] in vis_temp and vis_ids[k] in prev_vis_temp : 124 | 125 | # Generate label and sensor grids for the observed driver. 126 | label_grid_ref, center_ref, _, _, pre_local_x_ref, pre_local_y_ref = generateLabelGrid(stamp, track_dict, vis_ids[k], object_id, ego_flag=False, res=res, track_pedes_dict=track_pedes_dict, pedes_id=pedes_id) 127 | width = widths[k] 128 | length = lengths[k] 129 | sensor_grid_ref, _, _ = generateSensorGrid(label_grid_ref, pre_local_x_ref, pre_local_y_ref, vis_ms[k], width, length, res=res, ego_flag=False) 130 | 131 | # Mark the ego car as free. 132 | sensor_grid_ref[label_grid_ref[0] == 2] = 0 133 | label_grid_ref[0] = np.where(label_grid_ref[0]==2., 0. ,label_grid_ref[0]) 134 | 135 | # Store data for visible vehicles for the ego vehicle and the observed driver. 136 | vis_step_data = [stamp, vis_ids[k], vis_ms[k].x, vis_ms[k].y, vis_ms[k].psi_rad, vis_ms[k].vx, vis_ms[k].vy, vis_ax[k], vis_ay[k], label_grid_ref[0], label_grid_ref[3], sensor_grid_ref, np.nan] 137 | ref_step_data = [stamp, vis_ids[k], vis_ms[k].x, vis_ms[k].y, vis_ms[k].psi_rad, vis_ms[k].vx, vis_ms[k].vy, vis_ax[k], vis_ay[k], label_grid_ref[0], sensor_grid_ref] 138 | ref_datas[k] = ref_datas[k]+[ref_step_data] 139 | vis_datas[k] = vis_datas[k]+[vis_step_data] 140 | 141 | # Note ref_datas may include unseen vehicles (or vehicles seen for just one timestamp). 142 | return vis_datas, ref_datas 143 | 144 | def Dataprocessing(): 145 | global vis_ids, vis_ax, vis_ay, ref_ax, ref_ay, ego_id, res, gridglobal_x, gridglobal_y 146 | 147 | main_folder = '/data/INTERACTION-Dataset-DR-v1_1/' 148 | scenarios = ['DR_USA_Intersection_GL'] 149 | 150 | # Total number of track files.
151 | num_files = [60] 152 | for scene, nth_scene in zip(scenarios, num_files): 153 | for i in tqdm(range(nth_scene)): 154 | 155 | i_str = ['%03d' % i][0] 156 | filename = os.path.join(main_folder, 'recorded_trackfiles/'+scene+'/'+'vehicle_tracks_'+ i_str +'.csv') 157 | track_dict = dataset_reader.read_tracks(filename) 158 | filename_pedes = os.path.join(main_folder, 'recorded_trackfiles/'+scene+'/'+'pedestrian_tracks_'+ i_str +'.csv') 159 | 160 | if os.path.exists(filename_pedes): 161 | track_pedes_dict = dataset_reader.read_pedestrian(filename_pedes) 162 | else: 163 | track_pedes_dict = None 164 | 165 | run_count = 0 166 | 167 | maps_dir = os.path.join(main_folder, "maps") 168 | min_max_xy = getmap(maps_dir, scene) 169 | xminGPS = min_max_xy[0] 170 | xmaxGPS = min_max_xy[2] 171 | yminGPS = min_max_xy[1] 172 | ymaxGPS = min_max_xy[3] 173 | 174 | res = 1. 175 | gridglobal_x,gridglobal_y = global_grid(np.array([xminGPS,yminGPS]),np.array([xmaxGPS,ymaxGPS]),res) 176 | 177 | vehobjects, pedesobjects = AllObjects(track_dict, track_pedes_dict) 178 | 179 | processed_file = glob.glob(os.path.join(main_folder, '/Processed_data_new_goal/pkl/DR_USA_Intersection_GL_'+i_str+'*_ego_*')) 180 | processed_id = [ int(file.split('_')[-1][:-4]) for file in processed_file] 181 | 182 | num = 0 183 | sampled_key = [id for id in vehobjects if id not in processed_id][num:] 184 | run_count = num 185 | sampled_key = np.random.choice(sampled_key,np.minimum(100, len(sampled_key)),replace=False) 186 | 187 | for key, value in track_dict.items(): 188 | assert isinstance(value, Track) 189 | if key in sampled_key: 190 | 191 | start_time = datetime.now() 192 | ego_data = [['timestamp', 'car_id', 'x', 'y', 'orientation', 'vx','vy', 'ax', 'ay', 'label_grid', 'id_grid', 'sensor_grid', 'occluded_id']] 193 | ref_data = [['timestamp', 'car_id', 'x', 'y', 'orientation', 'vx','vy', 'ax', 'ay', 'label_grid', 'sensor_grid']] 194 | 195 | ego_id = int(key) 196 | vis_ids = [] 197 | 198 | start_timestamp = value.time_stamp_ms_first 199 | last_timestamp = value.time_stamp_ms_last 200 | 201 | ref_datas = [] 202 | vis_datas = [] 203 | 204 | vis_ax = [] 205 | vis_ay = [] 206 | 207 | # Get accelerations. 208 | for stamp in range(start_timestamp, last_timestamp, 100): 209 | object_id, pedes_id = SceneObjects(track_dict, stamp, track_pedes_dict) 210 | ego_ms = getstate(stamp, track_dict, ego_id) 211 | 212 | if stamp == start_timestamp : 213 | prev_ego_vx = ego_ms.vx 214 | prev_ego_vy = ego_ms.vy 215 | 216 | else: 217 | ego_ax = (ego_ms.vx - prev_ego_vx) / 0.1 218 | ego_ay = (ego_ms.vy - prev_ego_vy) / 0.1 219 | 220 | # Get label grid. 221 | label_grid_ego, center_ego, local_x_ego, local_y_ego, pre_local_x_ego, pre_local_y_ego = generateLabelGrid(stamp, track_dict, ego_id, object_id, ego_flag=True, res=res, track_pedes_dict=track_pedes_dict, pedes_id=pedes_id) 222 | 223 | # Get sensor grid. 224 | width = value.width 225 | length = value.length 226 | sensor_grid_ego, occluded_id, visible_id = generateSensorGrid(label_grid_ego, pre_local_x_ego, pre_local_y_ego, ego_ms, width, length, res=res, ego_flag=True) 227 | 228 | # Convert the ego grid cells to occupied. 229 | label_grid_ego[0] = np.where(label_grid_ego[0]==2., 1. ,label_grid_ego[0]) 230 | 231 | # Ignore the first timestamp because we do not have acceleration data. 232 | if stamp != start_timestamp : 233 | 234 | # Save ego data. 
235 | ego_step_data = [stamp, ego_id, ego_ms.x, ego_ms.y, ego_ms.psi_rad, ego_ms.vx, ego_ms.vy, ego_ax, ego_ay, label_grid_ego[0], label_grid_ego[3], sensor_grid_ego, occluded_id] 236 | ego_data.append(ego_step_data) 237 | 238 | vis_datas, ref_datas = vis_state(vis_datas, ref_datas, object_id, label_grid_ego, prev_id_grid_ego, sensor_grid_ego, prev_sensor_grid_ego, track_dict, stamp, start_timestamp, track_pedes_dict=track_pedes_dict, pedes_id=pedes_id) 239 | 240 | # Get the previous time stamp information. 241 | prev_id_grid_ego = label_grid_ego[3] 242 | prev_sensor_grid_ego = sensor_grid_ego 243 | prev_ego_vx = ego_ms.vx 244 | prev_ego_vy = ego_ms.vy 245 | 246 | if vis_ids is not None: 247 | ego_data = ego_data + sum(vis_datas,[]) 248 | 249 | # Save the ego information. 250 | pkl_path = os.path.join(main_folder, 'processed_data/pkl/') 251 | if not os.path.exists(pkl_path): 252 | os.makedirs(pkl_path) 253 | ego_filename = scene+'_'+i_str+'_run_'+str(run_count)+'_ego_vehicle_'+str(key) 254 | pkl.dump(ego_data, open(pkl_path + ego_filename+'.pkl', 'wb')) 255 | 256 | # Save only the visible reference drivers. 257 | for k in range(num_vis): 258 | ref_filename = scene+'_'+i_str+'_run_'+str(run_count)+'_ref_vehicle_'+str(vis_ids[k]) 259 | ref_d = ref_data + ref_datas[k] 260 | pkl.dump(ref_d, open(pkl_path +ref_filename+'.pkl', 'wb')) 261 | 262 | run_count += 1 263 | end_time = datetime.now() 264 | print(ego_filename, ', execution time:', end_time - start_time) 265 | 266 | if __name__ == "__main__": 267 | Dataprocessing() 268 | -------------------------------------------------------------------------------- /src/preprocess/get_driver_sensor_data.py: -------------------------------------------------------------------------------- 1 | # Code to form the state information for the driver sensor data and split into contiguous segments. 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import hickle as hkl 6 | import pickle as pkl 7 | import pdb 8 | import os 9 | import fnmatch 10 | import csv 11 | from tqdm import tqdm 12 | 13 | dir_train_set = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/' 14 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/' 15 | for split in ['train', 'val', 'test']: 16 | 17 | train_set = os.path.join(dir_train_set, 'ref_' + split + '_set.csv') 18 | train_set_files = [] 19 | 20 | with open(train_set,'rt') as f: 21 | data = csv.reader(f) 22 | for row in data: 23 | train_set_files.append(row[0]) 24 | 25 | dir_pickle = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_' + split + '_data/' 26 | 27 | if not os.path.isdir(dir_pickle): 28 | os.mkdir(dir_pickle) 29 | 30 | grid_shape = (70,50) 31 | 32 | # Set-up count for number of files processed. 33 | count = 0 34 | 35 | for file in tqdm(os.listdir(dir)): 36 | file_path = os.path.join(dir, file) 37 | 38 | if file_path in train_set_files: 39 | count += 1 40 | f = open(file_path, 'rb') 41 | X = pkl.load(f) 42 | f.close() 43 | 44 | # Check if the data is empty. Skip if it is.
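# A length-1 pickle holds only the header row, i.e. the reference vehicle was never visible for two
# consecutive timestamps, so it has no usable state/grid sequence.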
45 | if len(X) == 1: 46 | continue 47 | 48 | ts = [] 49 | posx = [] 50 | posy = [] 51 | orientation = [] 52 | velx = [] 53 | vely = [] 54 | accx = [] 55 | accy = [] 56 | grid_driver = [] 57 | split_ids = [] 58 | 59 | for i in range(1,len(X)): 60 | if ((i != 1) and (X[i][0] != (ts[-1] + 100))): 61 | split_ids.append(i-1) 62 | ts.append(X[i][0]) 63 | posx.append(X[i][2]) 64 | posy.append(X[i][3]) 65 | orientation.append(X[i][4]) 66 | velx.append(X[i][5]) 67 | vely.append(X[i][6]) 68 | accx.append(X[i][7]) 69 | accy.append(X[i][8]) 70 | 71 | # Reduce grid shape to (20,30). 72 | grid_driver.append(X[i][9][25:-25,:30]) 73 | 74 | start = 0 75 | if len(split_ids) > 0: 76 | for i in range(len(split_ids)): 77 | if len(ts[start:split_ids[i]]) == 0: 78 | print("the loop is a problem") 79 | pdb.set_trace() 80 | pkl.dump([np.array(grid_driver[start:split_ids[i]]),\ 81 | np.array(ts[start:split_ids[i]]),\ 82 | np.array(posx[start:split_ids[i]]),\ 83 | np.array(posy[start:split_ids[i]]),\ 84 | np.array(orientation[start:split_ids[i]]),\ 85 | np.array(velx[start:split_ids[i]]),\ 86 | np.array(vely[start:split_ids[i]]),\ 87 | np.array(accx[start:split_ids[i]]),\ 88 | np.array(accy[start:split_ids[i]])],\ 89 | open(os.path.join(dir_pickle, file[0:-4] + '_' + chr(i + 65) + '.pkl'), 'wb')) 90 | start = split_ids[i] 91 | 92 | if len(ts[start:]) == 0: 93 | print("the last one is a problem") 94 | pdb.set_trace() 95 | ts = np.array(ts[start:]) 96 | posx = np.array(posx[start:]) 97 | posy = np.array(posy[start:]) 98 | orientation = np.array(orientation[start:]) 99 | velx = np.array(velx[start:]) 100 | vely = np.array(vely[start:]) 101 | accx = np.array(accx[start:]) 102 | accy = np.array(accy[start:]) 103 | grid_driver = np.array(grid_driver[start:]) 104 | 105 | if len(split_ids) > 0: 106 | pkl.dump([grid_driver, ts, posx, posy, orientation, velx, vely, accx, accy], open(os.path.join(dir_pickle, file[0:-4] + '_' + chr(i + 1 + 65) + '.pkl'), 'wb')) 107 | else: 108 | pkl.dump([grid_driver, ts, posx, posy, orientation, velx, vely, accx, accy], open(os.path.join(dir_pickle, file[0:-4] + '.pkl'), 'wb')) -------------------------------------------------------------------------------- /src/preprocess/preprocess_driver_sensor_data.py: -------------------------------------------------------------------------------- 1 | # Code to downsample the training set and to form numpy arrays for the driver sensor dataset with the associated unique sources. 
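# Editor's sketch (not part of this file): a minimal, self-contained illustration of the
# contiguous-segment split performed in get_driver_sensor_data.py above. Samples are assumed
# to be 100 ms apart; any larger gap starts a new segment, and each segment is later saved
# with a letter suffix ('_A', '_B', ... via chr(65 + k)). Names and timestamps are illustrative.
import numpy as np

def split_contiguous(timestamps, step_ms=100):
    # Indices where the gap to the previous sample differs from one step start a new segment.
    breaks = np.where(np.diff(timestamps) != step_ms)[0] + 1
    return np.split(np.arange(len(timestamps)), breaks)

# Example: one missing frame yields two segments.
segments = split_contiguous(np.array([0, 100, 200, 400, 500]))
# segments -> [array([0, 1, 2]), array([3, 4])]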
2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | from scipy import signal 6 | 7 | import os 8 | import pdb 9 | import pickle as pkl 10 | import hickle as hkl 11 | from tqdm import tqdm 12 | import random 13 | 14 | random.seed(123) 15 | np.random.seed(seed=123) 16 | 17 | dir_pickle_train = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_train_data/' 18 | dir_pickle_val = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_val_data/' 19 | dir_pickle_test = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_test_data/' 20 | dirs_pickle = [dir_pickle_train, dir_pickle_val, dir_pickle_test] 21 | dir_dataset = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset/' 22 | 23 | if not os.path.isdir(dir_dataset): 24 | os.mkdir(dir_dataset) 25 | 26 | count_train = 0 27 | count_val = 0 28 | count_test = 0 29 | 30 | grids_train_list = [] 31 | states_train_list = [] 32 | sources_train_list = [] 33 | 34 | grids_val_list = [] 35 | states_val_list = [] 36 | sources_val_list = [] 37 | 38 | grids_test_list = [] 39 | states_test_list = [] 40 | sources_test_list = [] 41 | 42 | for i in range(0, len(dirs_pickle)): 43 | dir_pickle = dirs_pickle[i] 44 | 45 | if i == 0: 46 | # Random shuffle the directories for the training set. 47 | dirs_list = sorted(os.listdir(dir_pickle)) 48 | random.shuffle(dirs_list) 49 | else: 50 | dirs_list = sorted(os.listdir(dir_pickle)) 51 | 52 | for file in tqdm(dirs_list): 53 | path = os.path.join(dir_pickle, file) 54 | [grid_driver, ts, posx, posy, orientation, velx, vely, accx, accy, goalx, goaly, tgoal] = pkl.load(open(path, 'rb')) 55 | state = np.vstack((posx,posy,orientation,velx,vely,accx,accy,goalx,goaly,tgoal)).T 56 | 57 | if grid_driver.shape == (0,): 58 | continue 59 | 60 | if i == 0: 61 | # Downsample the training set and ensure that at least 10 seconds of data exists for every observed trajectory. 62 | if count_train <= 70000 and state.shape[0] >= 10: 63 | count_train += 1 64 | 65 | states_train_list.append(state.astype('float32')) 66 | grids_train_list.append(grid_driver.astype('float32')) 67 | sources_train_list.append(np.repeat(np.array(file[:-4]), state.shape[0])) 68 | 69 | if count_train % 1000 == 0: 70 | print(count_train) 71 | 72 | elif i == 1: 73 | # Ensure that at least 10 seconds of data exists for every observed trajectory. 74 | if state.shape[0] >= 10: 75 | 76 | states_val_list.append(state.astype('float32')) 77 | grids_val_list.append(grid_driver.astype('float32')) 78 | sources_val_list.append(np.repeat(np.array(file[:-4]), state.shape[0])) 79 | 80 | count_val += 1 81 | 82 | elif i == 2: 83 | # Ensure that at least 10 seconds of data exists for every observed trajectory. 84 | if state.shape[0] >= 10: 85 | states_test_list.append(state.astype('float32')) 86 | grids_test_list.append(grid_driver.astype('float32')) 87 | sources_test_list.append(np.repeat(np.array(file[:-4]), state.shape[0])) 88 | 89 | count_test += 1 90 | 91 | if i == 0: 92 | states_train = np.concatenate(states_train_list, axis=0) 93 | grids_train = np.concatenate(grids_train_list, axis=0) 94 | sources_train = np.concatenate(sources_train_list, axis=0) 95 | 96 | hkl.dump(states_train, os.path.join(dir_dataset, 'states_shuffled_train.hkl'), mode='w') 97 | hkl.dump(grids_train, os.path.join(dir_dataset, 'label_grids_shuffled_train.hkl'), mode='w') 98 | hkl.dump(sources_train, os.path.join(dir_dataset, 'sources_shuffled_train.hkl'), mode='w') 99 | 100 | # Get mean and std. 
101 | mean = np.mean(states_train, axis=0) 102 | std = np.std(states_train, axis=0) 103 | print('mean: ', mean, 'std: ', std) 104 | 105 | del(states_train) 106 | del(grids_train) 107 | del(sources_train) 108 | del(states_train_list) 109 | del(grids_train_list) 110 | del(sources_train_list) 111 | 112 | states_val = np.concatenate(states_val_list, axis=0) 113 | grids_val = np.concatenate(grids_val_list, axis=0) 114 | sources_val = np.concatenate(sources_val_list, axis=0) 115 | 116 | states_test = np.concatenate(states_test_list, axis=0) 117 | grids_test = np.concatenate(grids_test_list, axis=0) 118 | sources_test = np.concatenate(sources_test_list, axis=0) 119 | 120 | hkl.dump(states_val, os.path.join(dir_dataset, 'states_val.hkl'), mode='w') 121 | hkl.dump(grids_val, os.path.join(dir_dataset, 'label_grids_val.hkl'), mode='w') 122 | hkl.dump(sources_val, os.path.join(dir_dataset, 'sources_val.hkl'), mode='w') 123 | del(states_val) 124 | del(grids_val) 125 | del(sources_val) 126 | 127 | hkl.dump(states_test, os.path.join(dir_dataset, 'states_test.hkl'), mode='w') 128 | hkl.dump(grids_test, os.path.join(dir_dataset, 'label_grids_test.hkl'), mode='w') 129 | hkl.dump(sources_test, os.path.join(dir_dataset, 'sources_test.hkl'), mode='w') 130 | del(states_test) 131 | del(grids_test) 132 | del(sources_test) -------------------------------------------------------------------------------- /src/preprocess/statistics.txt: -------------------------------------------------------------------------------- 1 | mean: [ 9.9949402e+02 9.9329816e+02 1.5631536e-01 -1.2793580e-01 2 | -6.1600453e-01 -1.1768220e-01 -1.2684283e-01] 3 | std: [1.9491096e+01 1.2559539e+01 2.0200148e+00 5.6204329e+00 2.5343339e+00 4 | 9.0749109e-01 7.8766114e-01] 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /src/preprocess/train_val_test_split.py: -------------------------------------------------------------------------------- 1 | # Code to split the dataset into train/validation/test. 
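# Editor's sketch (not part of this file): a quick arithmetic check of the split fractions used
# in train_val_test_split.py below. Holding out 10% of all ego files for test, then 5/90 of the
# remaining 90% for validation, leaves an 85/5/10 train/val/test split. Numbers are illustrative.
n_total = 1000
n_test = int(0.10 * n_total)                    # 100 files -> 10% test
n_val = int((5.0 / 90.0) * (n_total - n_test))  # 50 files  -> 5% of the original total
n_train = n_total - n_test - n_val              # 850 files -> 85% train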
2 | 3 | import argparse 4 | import os 5 | import time 6 | from collections import defaultdict 7 | 8 | import math 9 | import numpy as np 10 | import csv 11 | 12 | import hickle as hkl 13 | import glob 14 | from sklearn.model_selection import train_test_split 15 | import pdb 16 | import random 17 | from generate_data import getmap 18 | from tqdm import tqdm 19 | np.random.seed(123) 20 | 21 | if __name__ == "__main__": 22 | ego_file = glob.glob('/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/*ego*') 23 | y = np.arange(len(ego_file)) 24 | ego_temp, ego_test_files, _, _ = train_test_split(ego_file, y, test_size=0.1, random_state=42) # test set : 10% of total data 25 | y = np.arange(len(ego_temp)) 26 | ego_train_files, ego_val_files, _, _ = train_test_split(ego_temp, y, test_size=5./90., random_state=42) # validation set : 5% of total data 27 | 28 | ego_train_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ego_train_set.csv" 29 | ego_val_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ego_val_set.csv" 30 | ego_test_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ego_test_set.csv" 31 | train_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ref_train_set.csv" 32 | val_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ref_val_set.csv" 33 | test_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ref_test_set.csv" 34 | 35 | with open(ego_train_set_path, 'w') as f: 36 | wr = csv.writer(f, delimiter=',') 37 | for row in ego_train_files: 38 | wr.writerow([row]) 39 | 40 | with open(ego_val_set_path, 'w') as f: 41 | wr = csv.writer(f, delimiter=',') 42 | for row in ego_val_files: 43 | wr.writerow([row]) 44 | 45 | with open(ego_test_set_path, 'w') as f: 46 | wr = csv.writer(f, delimiter=',') 47 | for row in ego_test_files: 48 | wr.writerow([row]) 49 | 50 | # Training 51 | ref_train_file = [glob.glob("_".join(filename.split('_')[:-3])+'_ref_*') for filename in ego_train_files] 52 | ref_train_file = sum(ref_train_file, []) 53 | 54 | with open(train_set_path, 'w') as f: 55 | wr = csv.writer(f, delimiter=',') 56 | for row in ref_train_file: 57 | wr.writerow([row]) 58 | 59 | # Validation 60 | ref_val_file = [glob.glob("_".join(filename.split('_')[:-3])+'_ref_*') for filename in ego_val_files] 61 | ref_val_file = sum(ref_val_file, []) 62 | 63 | with open(val_set_path, 'w') as f: 64 | wr = csv.writer(f) 65 | for row in ref_val_file: 66 | wr.writerow([row]) 67 | 68 | # Test 69 | ref_test_file = [glob.glob("_".join(filename.split('_')[:-3])+'_ref_*') for filename in ego_test_files] 70 | ref_test_file = sum(ref_test_file, []) 71 | 72 | with open(test_set_path, 'w') as f: 73 | wr = csv.writer(f) 74 | for row in ref_test_file: 75 | wr.writerow([row]) 76 | 77 | print('ego_train_files', len(ego_train_files),len(set(ego_train_files))) 78 | print('ego_val_files', len(ego_val_files),len(set(ego_val_files))) 79 | print('ego_test_files', len(ego_test_files),len(set(ego_test_files))) 80 | print('ref_train_file', len(ref_train_file),len(set(ref_train_file))) 81 | print('ref_val_file', len(ref_val_file),len(set(ref_val_file))) 82 | print('ref_test_file', len(ref_test_file),len(set(ref_test_file))) 83 | 84 | print('train/val/test split done') 85 | -------------------------------------------------------------------------------- /src/utils/combinations.py: -------------------------------------------------------------------------------- 1 | # Code for 
obtaining the top num_modes most likely combinations for the fused ego vehicle grid using breadth-first search. 2 | # The algorithm we use is adapted from here: https://cs.stackexchange.com/questions/46910/efficiently-finding-k-smallest-elements-of-cartesian-product. 3 | 4 | import numpy as np 5 | from copy import deepcopy 6 | 7 | def compute_best_cost(alpha, indices): 8 | cost = 1.0 9 | for i in range(len(alpha)): 10 | cost *= alpha[i][indices[i]] 11 | return cost 12 | 13 | def covert_indices(alpha_indices, alpha_sorted, indices): 14 | best_indices = np.zeros(indices.shape) 15 | for i in range(len(alpha_sorted)): 16 | best_indices[i] = alpha_indices[i][indices[i]] 17 | return best_indices 18 | 19 | # Takes as input sorted alpha. 20 | def BFS(alpha, top=3): 21 | N = len(alpha) 22 | num_modes = len(alpha[0]) 23 | 24 | alpha_sorted = [] 25 | alpha_indices = [] 26 | for i in range(len(alpha)): 27 | alpha_indices.append(np.argsort(alpha[i])[::-1]) 28 | alpha_sorted.append(alpha[i][alpha_indices[-1]]) 29 | 30 | Q = [np.zeros((N,)).astype('int')] 31 | best_index = np.array([]) 32 | best_cost = np.array([]) 33 | cost = np.array([compute_best_cost(alpha_sorted, Q[-1])]) 34 | 35 | # Do BFS to obtain the top num_modes most likely products. 36 | while len(best_index) < top: 37 | idx = np.argmax(cost) 38 | if len(best_cost) == 0: 39 | best_cost = np.array([cost[idx]]) 40 | best_index = np.array([covert_indices(alpha_indices, alpha_sorted, Q[idx])]) 41 | else: 42 | best_cost = np.vstack((best_cost, cost[idx])) 43 | best_index = np.vstack((best_index, covert_indices(alpha_indices, alpha_sorted, Q[idx]))) 44 | last = Q.pop(idx) 45 | cost = np.delete(cost, idx) 46 | if len(Q) == 0: 47 | for i in range(N): 48 | array = np.zeros((N,)).astype('int') 49 | array[i] = 1 50 | Q.append(array) 51 | if len(cost) == 0: 52 | cost = np.array([compute_best_cost(alpha_sorted, Q[-1])]) 53 | else: 54 | cost = np.hstack((cost, compute_best_cost(alpha_sorted, Q[-1]))) 55 | else: 56 | for i in range(N): 57 | array = deepcopy(last) 58 | if array[i] < num_modes-1: 59 | array[i] += 1 60 | else: 61 | continue 62 | Q.append(array) 63 | cost = np.hstack((cost, compute_best_cost(alpha_sorted, Q[-1]))) 64 | if last[i] > 0: 65 | break 66 | return best_index, best_cost -------------------------------------------------------------------------------- /src/utils/data_generator.py: -------------------------------------------------------------------------------- 1 | # Code for preprocessing and data generation. Code is adapted from: https://github.com/sisl/EvidentialSparsification. 
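# Editor's sketch (not part of this file): a small usage example for BFS() defined in
# combinations.py above. With two agents whose per-mode probabilities are alpha[0] and alpha[1],
# the three highest-probability index combinations of the Cartesian product come back in
# descending order of cost. The import path is an assumption based on the repository layout.
import numpy as np
from utils.combinations import BFS

alpha = [np.array([0.7, 0.3]), np.array([0.6, 0.4])]
best_index, best_cost = BFS(alpha, top=3)
# Expected: index pairs (0, 0), (0, 1), (1, 0) with costs 0.42, 0.28, 0.18,
# matching a brute-force enumeration of all four products.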
2 | import torch 3 | torch.manual_seed(123) 4 | torch.backends.cudnn.deterministic = True 5 | torch.backends.cudnn.benchmark = False 6 | import torch.nn as nn 7 | import torch.utils.data as data 8 | import torchvision 9 | 10 | import numpy as np 11 | np.random.seed(0) 12 | import hickle as hkl 13 | import pdb 14 | 15 | def unnormalize(X, nt=1, norm=True): 16 | mean = np.array([9.9923511e+02, 9.9347858e+02, 1.3369133e-01, -3.7689242e-01, -6.0116798e-01, -1.3141568e-01, -1.2848811e-01]) 17 | std = np.array([1.9491451e+01, 1.2455911e+01, 2.0100091e+00, 5.9114690e+00, 2.5912049e+00, 9.5518535e-01, 8.1346464e-01]) 18 | 19 | if norm: 20 | X = X * std + mean 21 | return X.astype(np.float32) 22 | 23 | 24 | def preprocess(X, norm=True): 25 | mean = np.array([9.9923511e+02, 9.9347858e+02, 1.3369133e-01, -3.7689242e-01, -6.0116798e-01, -1.3141568e-01, -1.2848811e-01]) 26 | std = np.array([1.9491451e+01, 1.2455911e+01, 2.0100091e+00, 5.9114690e+00, 2.5912049e+00, 9.5518535e-01, 8.1346464e-01]) 27 | if norm: 28 | X = (X - mean)/std 29 | return X.astype(np.float32) 30 | 31 | # Data generator that creates sequences for input. 32 | class SequenceGenerator(data.Dataset): 33 | def __init__(self, data_file_state, data_file_grid, source_file, nt, 34 | batch_size=8, shuffle=False, sequence_start_mode='all', norm=True): 35 | self.state = hkl.load(data_file_state) 36 | self.grid = hkl.load(data_file_grid) 37 | self.grid = np.reshape(self.grid, (self.grid.shape[0], -1)) 38 | self.grid = self.grid.astype(np.float32) 39 | 40 | print(self.grid.shape) 41 | 42 | # Source for each grid so when creating sequences can ensure that consecutive frames are from same data run. 43 | self.sources = hkl.load(source_file) 44 | self.nt = nt 45 | self.norm = norm 46 | 47 | if batch_size == None: 48 | self.batch_size = self.state.shape[0] 49 | else: 50 | self.batch_size = batch_size 51 | assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}' 52 | self.sequence_start_mode = sequence_start_mode 53 | 54 | # Allow for any possible sequence, starting from any frame. 55 | if self.sequence_start_mode == 'all': 56 | self.possible_starts = np.array([i for i in range(self.state.shape[0] - self.nt) if self.sources[i] == self.sources[i + self.nt - 1]]) 57 | # Create sequences where each unique frame is in at most one sequence. 
58 | elif self.sequence_start_mode == 'unique': 59 | curr_location = 0 60 | possible_starts = [] 61 | while curr_location < self.state.shape[0] - self.nt + 1: 62 | if self.sources[curr_location] == self.sources[curr_location + self.nt - 1]: 63 | possible_starts.append(curr_location) 64 | curr_location += self.nt 65 | else: 66 | curr_location += 1 67 | self.possible_starts = possible_starts 68 | 69 | if shuffle: 70 | self.possible_starts = np.random.permutation(self.possible_starts) 71 | self.N_sequences = len(self.possible_starts) 72 | 73 | def __getitem__(self, idx): 74 | if torch.is_tensor(idx): 75 | idx = idx.tolist() 76 | idx = self.possible_starts[idx] 77 | batch_x = self.preprocess(self.state[idx:idx+self.nt]) 78 | batch_y = self.grid[idx+self.nt-1] 79 | sources = self.sources[idx+self.nt-1] 80 | return batch_x, batch_y, sources 81 | 82 | def __len__(self): 83 | return self.N_sequences 84 | 85 | def preprocess(self, X): 86 | mean = np.array([9.9923511e+02, 9.9347858e+02, 1.3369133e-01, -3.7689242e-01, -6.0116798e-01, -1.3141568e-01, -1.2848811e-01]) 87 | std = np.array([1.9491451e+01, 1.2455911e+01, 2.0100091e+00, 5.9114690e+00, 2.5912049e+00, 9.5518535e-01, 8.1346464e-01]) 88 | 89 | if self.norm: 90 | X = (X - mean)/std 91 | return X.astype(np.float32) 92 | -------------------------------------------------------------------------------- /src/utils/dataset_reader.py: -------------------------------------------------------------------------------- 1 | # Code is from: https://github.com/interaction-dataset/interaction-dataset. 2 | 3 | import csv 4 | 5 | from utils.dataset_types import MotionState, Track 6 | 7 | 8 | class Key: 9 | track_id = "track_id" 10 | frame_id = "frame_id" 11 | time_stamp_ms = "timestamp_ms" 12 | agent_type = "agent_type" 13 | x = "x" 14 | y = "y" 15 | vx = "vx" 16 | vy = "vy" 17 | psi_rad = "psi_rad" 18 | length = "length" 19 | width = "width" 20 | 21 | 22 | class KeyEnum: 23 | track_id = 0 24 | frame_id = 1 25 | time_stamp_ms = 2 26 | agent_type = 3 27 | x = 4 28 | y = 5 29 | vx = 6 30 | vy = 7 31 | psi_rad = 8 32 | length = 9 33 | width = 10 34 | 35 | 36 | def read_tracks(filename): 37 | 38 | with open(filename) as csv_file: 39 | csv_reader = csv.reader(csv_file, delimiter=',') 40 | 41 | track_dict = dict() 42 | track_id = None 43 | 44 | for i, row in enumerate(list(csv_reader)): 45 | 46 | if i == 0: 47 | # check first line with key names 48 | assert(row[KeyEnum.track_id] == Key.track_id) 49 | assert(row[KeyEnum.frame_id] == Key.frame_id) 50 | assert(row[KeyEnum.time_stamp_ms] == Key.time_stamp_ms) 51 | assert(row[KeyEnum.agent_type] == Key.agent_type) 52 | assert(row[KeyEnum.x] == Key.x) 53 | assert(row[KeyEnum.y] == Key.y) 54 | assert(row[KeyEnum.vx] == Key.vx) 55 | assert(row[KeyEnum.vy] == Key.vy) 56 | assert(row[KeyEnum.psi_rad] == Key.psi_rad) 57 | assert(row[KeyEnum.length] == Key.length) 58 | assert(row[KeyEnum.width] == Key.width) 59 | continue 60 | 61 | if int(row[KeyEnum.track_id]) != track_id: 62 | # new track 63 | track_id = int(row[KeyEnum.track_id]) 64 | assert(track_id not in track_dict.keys()), \ 65 | "Line %i: Track id %i already in dict, track file not sorted properly" % (i+1, track_id) 66 | track = Track(track_id) 67 | track.agent_type = row[KeyEnum.agent_type] 68 | track.length = float(row[KeyEnum.length]) 69 | track.width = float(row[KeyEnum.width]) 70 | track.time_stamp_ms_first = int(row[KeyEnum.time_stamp_ms]) 71 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms]) 72 | track_dict[track_id] = track 73 | 74 | track = 
track_dict[track_id] 75 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms]) 76 | ms = MotionState(int(row[KeyEnum.time_stamp_ms])) 77 | ms.x = float(row[KeyEnum.x]) 78 | ms.y = float(row[KeyEnum.y]) 79 | ms.vx = float(row[KeyEnum.vx]) 80 | ms.vy = float(row[KeyEnum.vy]) 81 | ms.psi_rad = float(row[KeyEnum.psi_rad]) 82 | track.motion_states[ms.time_stamp_ms] = ms 83 | 84 | return track_dict 85 | 86 | 87 | def read_pedestrian(filename): 88 | 89 | with open(filename) as csv_file: 90 | csv_reader = csv.reader(csv_file, delimiter=',') 91 | 92 | track_dict = dict() 93 | track_id = None 94 | 95 | for i, row in enumerate(list(csv_reader)): 96 | 97 | if i == 0: 98 | # check first line with key names 99 | assert (row[KeyEnum.track_id] == Key.track_id) 100 | assert (row[KeyEnum.frame_id] == Key.frame_id) 101 | assert (row[KeyEnum.time_stamp_ms] == Key.time_stamp_ms) 102 | assert (row[KeyEnum.agent_type] == Key.agent_type) 103 | assert (row[KeyEnum.x] == Key.x) 104 | assert (row[KeyEnum.y] == Key.y) 105 | assert (row[KeyEnum.vx] == Key.vx) 106 | assert (row[KeyEnum.vy] == Key.vy) 107 | continue 108 | 109 | if row[KeyEnum.track_id] != track_id: 110 | # new track 111 | track_id = row[KeyEnum.track_id] 112 | assert (track_id not in track_dict.keys()), \ 113 | "Line %i: Track id %s already in dict, track file not sorted properly" % (i + 1, track_id) 114 | track = Track(track_id) 115 | track.agent_type = row[KeyEnum.agent_type] 116 | track.time_stamp_ms_first = int(row[KeyEnum.time_stamp_ms]) 117 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms]) 118 | track_dict[track_id] = track 119 | 120 | track = track_dict[track_id] 121 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms]) 122 | ms = MotionState(int(row[KeyEnum.time_stamp_ms])) 123 | ms.x = float(row[KeyEnum.x]) 124 | ms.y = float(row[KeyEnum.y]) 125 | ms.vx = float(row[KeyEnum.vx]) 126 | ms.vy = float(row[KeyEnum.vy]) 127 | track.motion_states[ms.time_stamp_ms] = ms 128 | 129 | return track_dict 130 | -------------------------------------------------------------------------------- /src/utils/dataset_types.py: -------------------------------------------------------------------------------- 1 | # Code is from: https://github.com/interaction-dataset/interaction-dataset. 
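# Editor's sketch (not part of this file): a minimal usage example for read_tracks() in
# dataset_reader.py above. The CSV path is a placeholder, and the import path assumes the
# utils package layout used elsewhere in this repository.
from utils import dataset_reader

track_dict = dataset_reader.read_tracks("vehicle_tracks_000.csv")
for track_id, track in track_dict.items():
    # Each Track stores MotionState objects keyed by timestamp in ms (100 ms apart),
    # between track.time_stamp_ms_first and track.time_stamp_ms_last.
    first_ms = track.motion_states[track.time_stamp_ms_first]
    print(track_id, track.agent_type, first_ms.x, first_ms.y, first_ms.psi_rad)
    break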
2 | 3 | DELTA_TIMESTAMP_MS = 100 # similar throughout the whole dataset 4 | 5 | class MotionState: 6 | def __init__(self, time_stamp_ms): 7 | assert isinstance(time_stamp_ms, int) 8 | self.time_stamp_ms = time_stamp_ms 9 | self.x = None 10 | self.y = None 11 | self.vx = None 12 | self.vy = None 13 | self.psi_rad = None 14 | 15 | def __str__(self): 16 | return "MotionState: " + str(self.__dict__) 17 | 18 | 19 | class Track: 20 | def __init__(self, id): 21 | # assert isinstance(id, int) 22 | self.track_id = id 23 | self.agent_type = None 24 | self.length = None 25 | self.width = None 26 | self.time_stamp_ms_first = None 27 | self.time_stamp_ms_last = None 28 | self.motion_states = dict() 29 | 30 | def __str__(self): 31 | string = "Track: track_id=" + str(self.track_id) + ", agent_type=" + str(self.agent_type) + \ 32 | ", length=" + str(self.length) + ", width=" + str(self.width) + \ 33 | ", time_stamp_ms_first=" + str(self.time_stamp_ms_first) + \ 34 | ", time_stamp_ms_last=" + str(self.time_stamp_ms_last) + \ 35 | "\n motion_states:" 36 | for key, value in sorted(self.motion_states.items()): 37 | string += "\n " + str(key) + ": " + str(value) 38 | return string 39 | -------------------------------------------------------------------------------- /src/utils/grid_fuse.py: -------------------------------------------------------------------------------- 1 | # Code to transform the driver sensor OGMs to the ego vehicle's OGM frame of reference. 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import math 6 | import copy 7 | from utils.grid_utils import global_grid 8 | import time 9 | from scipy.spatial import cKDTree 10 | import pdb 11 | 12 | def mask_in_EgoGrid(global_grid_x, global_grid_y, ref_xy, ego_xy, pred_egoGrid, pred_maps, res, mask_unk=None, tolerance=1): 13 | 14 | # Consider only the unknown cells in pred_egoGrid (ego sensor grid before trasfering values). 15 | indices = np.where(mask_unk) 16 | ego_x = ego_xy[0][indices] 17 | ego_y = ego_xy[1][indices] 18 | ego_xy = [ego_x, ego_y] 19 | flat_indicies = indices[0]*pred_egoGrid.shape[1]+indices[1] 20 | 21 | # ref indx --> global indx 22 | ref_x_ind = np.floor(global_grid_x.shape[1]*(ref_xy[0]-x_min+res/2.)/(x_max-x_min+res)).astype(int) # column index 23 | ref_y_ind = np.floor(global_grid_y.shape[0]*(ref_xy[1]-y_min+res/2.)/(y_max-y_min+res)).astype(int) # row index 24 | 25 | ref_global_ind = np.vstack((ref_y_ind.flatten(), ref_x_ind.flatten())).T 26 | 27 | # ego indx --> global indx 28 | ego_x_ind = np.floor(global_grid_x.shape[1]*(ego_xy[0]-x_min+res/2.)/(x_max-x_min+res)).astype(int) # column index 29 | ego_y_ind = np.floor(global_grid_y.shape[0]*(ego_xy[1]-y_min+res/2.)/(y_max-y_min+res)).astype(int) # row index 30 | ego_global_ind = np.vstack((ego_y_ind.flatten(), ego_x_ind.flatten())).T 31 | 32 | # Look for the matching global_grid indices between the ref_grid and ego_grid. 33 | kdtree = cKDTree(ref_global_ind) 34 | dists, inds = kdtree.query(ego_global_ind) 35 | 36 | pred_egoGrid_flat = pred_egoGrid.flatten() 37 | pred_maps_flat = pred_maps.flatten() 38 | 39 | # Back to the local grid indices. Tolerance should be an integer because kd tree is comparing indices. 40 | ego_ind = flat_indicies[np.where(dists<=tolerance)] 41 | ref_ind = inds[np.where(dists<=tolerance)] 42 | 43 | # Assign the values for the corresponding cells. 
44 | pred_egoGrid_flat[ego_ind] = pred_maps_flat[ref_ind] 45 | pred_egoGrid = pred_egoGrid_flat.reshape(pred_egoGrid.shape) 46 | return pred_egoGrid 47 | 48 | def Transfer_to_EgoGrid(ref_local_xy, pred_maps, ego_local_xy, ego_sensor_grid, endpoint, res=0.1, mask_unk=None): 49 | global x_min, x_max, y_min, y_max 50 | ##################################################################################################################################### 51 | ## Goal : Transfer pred_maps (in driver sensor's grid) cell information to the unknown cells of ego car's sensor_grid 52 | ## Method : Use global grid as an intermediate (ref indx --> global indx --> ego indx). 53 | ## ref_local_xy (N, 2, w, h) & pred_maps (N, w, h) 54 | ## ego_xy (2, w', h') & & ego_sensor_grid (w', h') 55 | ## return pred_maps_egoGrid(N, w', h') 56 | ## * N : number of agents 57 | ##################################################################################################################################### 58 | 59 | x_min = endpoint[0] 60 | x_max = endpoint[2] 61 | y_min = endpoint[1] 62 | y_max = endpoint[3] 63 | 64 | global_res = 1.0 65 | global_grid_x, global_grid_y = global_grid(np.array([x_min,y_min]),np.array([x_max,y_max]),global_res) 66 | 67 | if np.any(ref_local_xy[0] == None): 68 | pred_maps_egoGrid.append(None) 69 | 70 | else: 71 | pred_egoGrid = copy.copy(ego_sensor_grid) 72 | pred_egoGrid = np.ones(ego_sensor_grid.shape)*2 73 | 74 | pred_egoGrid = mask_in_EgoGrid(global_grid_x, global_grid_y, ref_local_xy, ego_local_xy, pred_egoGrid, pred_maps, res, mask_unk) 75 | 76 | return pred_egoGrid -------------------------------------------------------------------------------- /src/utils/grid_utils.py: -------------------------------------------------------------------------------- 1 | # Code includes utilities for evidential fusion and OGMs. 2 | 3 | import numpy as np 4 | from utils.dataset_types import Track, MotionState 5 | import pdb 6 | import time 7 | from matplotlib import pyplot as plt 8 | 9 | # Outputs grid_dst: [2,w,h]. 10 | # The first channel denotes occupied and free belief masses. 11 | def get_belief_mass(grid, ego_flag=True, m=None): 12 | grid_dst = np.zeros((2,grid.shape[0], grid.shape[1])) 13 | 14 | if ego_flag: 15 | mask_occ = grid == 1 16 | mask_free = grid == 0 17 | mask_unk = grid == 2 18 | 19 | else: 20 | mask_not_unk = grid != 2 21 | mask_unk = None 22 | 23 | if m is not None: 24 | mass = m 25 | elif ego_flag: 26 | mass = 1.0 27 | else: 28 | mass = 0.75 29 | 30 | if ego_flag: 31 | grid_dst[0,mask_occ] = mass 32 | grid_dst[1,mask_free] = mass 33 | 34 | else: 35 | grid_dst[0, mask_not_unk] = grid[mask_not_unk] * mass 36 | grid_dst[1, mask_not_unk] = (1.0-grid[mask_not_unk]) * mass 37 | 38 | return grid_dst, mask_unk 39 | 40 | def dst_fusion(sensor_grid_dst, ego_grid_dst, mask_unk): 41 | 42 | fused_grid_dst = np.zeros(ego_grid_dst.shape) 43 | 44 | # predicted unknown mass 45 | ego_unknown = 1. - ego_grid_dst[0] - ego_grid_dst[1] 46 | 47 | # measurement masses: meas_m_free, meas_m_occ 48 | sensor_unknown = 1. - sensor_grid_dst[0] - sensor_grid_dst[1] 49 | 50 | # Implement DST rule of combination. 51 | K = np.multiply(ego_grid_dst[1], sensor_grid_dst[0]) + np.multiply(ego_grid_dst[0], sensor_grid_dst[1]) 52 | 53 | fused_grid_dst[0] = np.divide((np.multiply(ego_grid_dst[0], sensor_unknown) + np.multiply(ego_unknown, sensor_grid_dst[0]) + np.multiply(ego_grid_dst[0], sensor_grid_dst[0])), (1. 
- K)) 54 | fused_grid_dst[1] = np.divide((np.multiply(ego_grid_dst[1], sensor_unknown) + np.multiply(ego_unknown, sensor_grid_dst[1]) + np.multiply(ego_grid_dst[1], sensor_grid_dst[1])), (1. - K)) 55 | 56 | pignistic_grid = pignistic(fused_grid_dst) 57 | 58 | return fused_grid_dst, pignistic_grid 59 | 60 | def pignistic(grid_dst): 61 | grid = 0.5*grid_dst[0] + 0.5*(1.-grid_dst[1]) 62 | return grid 63 | 64 | def rotate_around_center(pts, center, yaw): 65 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center 66 | 67 | # Create a grid of x and y values that have an associated x,y centre coordinate in the global frame. 68 | def global_grid(origin,endpoint,res): 69 | 70 | xmin = min(origin[0],endpoint[0]) 71 | xmax = max(origin[0],endpoint[0]) 72 | ymin = min(origin[1],endpoint[1]) 73 | ymax = max(origin[1],endpoint[1]) 74 | 75 | x_coords = np.arange(xmin,xmax,res) 76 | y_coords = np.arange(ymin,ymax,res) 77 | 78 | gridx,gridy = np.meshgrid(x_coords,y_coords) 79 | return gridx.T,gridy.T 80 | 81 | def local_grid(ms, width, length, res, ego_flag=True, grid_shape=None): 82 | 83 | center = np.array([ms.x, ms.y]) 84 | 85 | if ego_flag: 86 | minx = center[0] - 10. 87 | miny = center[1] - 35. 88 | maxx = center[0] + 50. 89 | maxy = center[1] + 35. 90 | else: 91 | if grid_shape is not None: 92 | minx = center[0] + length/2. 93 | miny = center[1] - grid_shape[0]/2. 94 | maxx = center[0] + length/2. + grid_shape[1] 95 | maxy = center[1] + grid_shape[0]/2. 96 | else: 97 | minx = center[0] + length/2. 98 | miny = center[1] - 35. 99 | maxx = center[0] + length/2. + 50. 100 | maxy = center[1] + 35. 101 | 102 | x_coords = np.arange(minx,maxx,res) 103 | y_coords = np.arange(miny,maxy,res) 104 | 105 | mesh_x, mesh_y = np.meshgrid(x_coords,y_coords) 106 | pre_local_x = mesh_x 107 | pre_local_y = np.flipud(mesh_y) 108 | 109 | xy_local = rotate_around_center(np.vstack((pre_local_x.flatten(), pre_local_y.flatten())).T, center, ms.psi_rad) 110 | 111 | x_local = xy_local[:,0].reshape(mesh_x.shape) 112 | y_local = xy_local[:,1].reshape(mesh_y.shape) 113 | 114 | if grid_shape is not None: 115 | x_local = grid_reshape(x_local, grid_shape) 116 | y_local = grid_reshape(y_local, grid_shape) 117 | 118 | elif ego_flag: 119 | grid_shape = (int(70/res),int(60/res)) 120 | 121 | x_local = grid_reshape(x_local, grid_shape) 122 | y_local = grid_reshape(y_local, grid_shape) 123 | 124 | return x_local, y_local, pre_local_x, pre_local_y 125 | 126 | # Element in nd array closest to the scalar value v. 127 | def find_nearest(n,v,v0,vn,res): 128 | idx = int(np.floor( n*(v-v0+res/2.)/(vn-v0+res) )) 129 | return idx 130 | 131 | def grid_reshape(data, grid_shape): 132 | if len(data)==1: 133 | d = data[0] 134 | if d.shape[0]!= grid_shape[0] or d.shape[1]!= grid_shape[1]: 135 | data = [d[:grid_shape[0],:grid_shape[1]]] 136 | 137 | else: 138 | if len(data.shape) == 3: 139 | if data.shape[1] > grid_shape[0]: 140 | data = data[:,:grid_shape[0]] 141 | if data.shape[2] > grid_shape[1]: 142 | data = data[:,:,:grid_shape[1]] 143 | 144 | elif len(data.shape) == 2: 145 | if data.shape[0]!= grid_shape[0] or data.shape[1]!= grid_shape[1]: 146 | data = data[:grid_shape[0],:grid_shape[1]] 147 | return data 148 | 149 | ############## LabelGrid ########################### 150 | # 0 = empty 151 | # 1 = ego vehicle 152 | # 2 = vehicle 153 | # 3 = pedestrian 154 | # Pedestrians have negative labels (e.g., P1 -> -1). 
155 | ##################################################### 156 | def generateLabelGrid(timestamp, track_dict, ego_id, object_id, ego_flag, res=1., grid_shape=None, track_pedes_dict=None, pedes_id=None): 157 | 158 | # Initialize the labels. 159 | labels = [] 160 | boxes_vehicles = [] 161 | boxes_persons = [] 162 | dynamics_vehicles = [] 163 | dynamics_persons = [] 164 | 165 | # The global reference direction of all vehicles and people is in the direction of the ego vehicle. 166 | for key, value in track_dict.items(): 167 | assert isinstance(value, Track) 168 | if key == ego_id: 169 | ms = value.motion_states[timestamp] 170 | assert isinstance(ms, MotionState) 171 | x_ego = ms.x 172 | y_ego = ms.y 173 | theta_ego = ms.psi_rad 174 | vx_ego = ms.vx 175 | vy_ego = ms.vy 176 | w_ego = value.width 177 | l_ego = value.length 178 | 179 | center = np.array([x_ego, y_ego]) 180 | ms = getstate(timestamp, track_dict, ego_id) 181 | 182 | x_local, y_local, pre_local_x, pre_local_y = local_grid(ms, w_ego, l_ego, res, ego_flag=ego_flag, grid_shape=grid_shape) 183 | 184 | label_grid = np.zeros((4,x_local.shape[0],x_local.shape[1])) 185 | label_grid[3] = np.nan 186 | 187 | for key, value in track_dict.items(): 188 | assert isinstance(value, Track) 189 | if key in object_id: 190 | ms = value.motion_states[timestamp] 191 | assert isinstance(ms, MotionState) 192 | 193 | d = np.sqrt((ms.x-x_ego)**2 + (ms.y-y_ego)**2) 194 | if d < (np.sqrt(2.)*22.+2.*l_ego): 195 | vx = ms.vx 196 | vy = ms.vy 197 | 198 | w = value.width 199 | l = value.length 200 | ID = key 201 | 202 | if key == ego_id: 203 | w += res/4. 204 | 205 | coords = polygon_xy_from_motionstate(ms, w, l) 206 | mask = point_in_rectangle(x_local, y_local, coords) 207 | if ID == ego_id: 208 | # Mark as occupied with ego vehicle. 209 | label_grid[0,mask] = 2. 210 | 211 | else: 212 | # Mark as occupied with vehicle. 213 | label_grid[0,mask] = 1. # 214 | 215 | label_grid[3,mask] = value.track_id 216 | 217 | label_grid[1,mask] = vx 218 | label_grid[2,mask] = vy 219 | 220 | if track_pedes_dict != None: 221 | for key, value in track_pedes_dict.items(): 222 | assert isinstance(value, Track) 223 | if key in pedes_id: 224 | ms = value.motion_states[timestamp] 225 | assert isinstance(ms, MotionState) 226 | 227 | d = np.sqrt((ms.x-x_ego)**2 + (ms.y-y_ego)**2) 228 | if d < (np.sqrt(2.)*22.+2.*l_ego): 229 | vx = ms.vx 230 | vy = ms.vy 231 | 232 | w = 1.5 233 | l = 1.5 234 | ID = key 235 | 236 | coords = polygon_xy_from_motionstate_pedest(ms, w, l) 237 | mask = point_in_rectangle(x_local, y_local, coords) 238 | # Mark as occupied with vehicle. 239 | label_grid[0,mask] = 1. 
# 240 | 241 | label_grid[3,mask] = -1.0 * float(value.track_id[1:]) 242 | 243 | label_grid[1,mask] = vx 244 | label_grid[2,mask] = vy 245 | 246 | if grid_shape == None: 247 | if ego_flag: 248 | grid_shape = (int(70/res),int(60/res)) 249 | else: 250 | grid_shape = (int(70/res),int(50/res)) 251 | 252 | label_grid = grid_reshape(label_grid, grid_shape) 253 | x_local = grid_reshape(x_local, grid_shape) 254 | y_local = grid_reshape(y_local, grid_shape) 255 | pre_local_x = grid_reshape(pre_local_x, grid_shape) 256 | pre_local_y = grid_reshape(pre_local_y, grid_shape) 257 | 258 | return label_grid, center, x_local, y_local, pre_local_x, pre_local_y 259 | 260 | ############## SensorGrid ################## 261 | # 0: empty 262 | # 1 : occupied 263 | # 2 : unknown 264 | ############################################# 265 | def generateSensorGrid(labels_grid, pre_local_x, pre_local_y, ms, width, length, res=1., ego_flag=True, grid_shape=None): 266 | center_ego = (ms.x, ms.y) 267 | theta_ego = ms.psi_rad 268 | occluded_id = [] 269 | visible_id = [] 270 | # All the angles of the LiDAR simulation. 271 | angle_res = 0.01 #0.002 272 | if ego_flag: 273 | angles = np.arange(0., 2.*np.pi+angle_res, angle_res) 274 | else: 275 | angles = np.arange(0., 2.*np.pi+angle_res, angle_res) 276 | 277 | # Get the maximum and minimum x and y values in the local grids. 278 | x_min = np.amin(pre_local_x) 279 | x_max = np.amax(pre_local_x) 280 | y_min = np.amin(pre_local_y) 281 | y_max = np.amax(pre_local_y) 282 | 283 | x_shape = pre_local_x.shape[0] 284 | y_shape = pre_local_y.shape[1] 285 | 286 | # Get the cells not occupied by the ego vehicle. 287 | mask = np.where(labels_grid[0]!=2, True,False) 288 | 289 | # No need to do ray tracing if no object on the grid. 290 | if np.all(labels_grid[0,mask]==0.): 291 | sensor_grid = np.zeros((x_shape, y_shape)) 292 | 293 | else: 294 | # Generate a line from the center indices to the edge of the local grid: sqrt(2)*128./3. meters away (LiDAR distance). 295 | r = (np.sqrt(x_shape**2 + y_shape**2) + 10) * 1.0 * res 296 | x = (r*np.cos(angles)+center_ego[0]) # length of angles 297 | y = (r*np.sin(angles)+center_ego[1]) # length of angles 298 | 299 | sensor_grid = np.zeros((x_shape, y_shape)) 300 | 301 | for i in range(x.shape[0]): 302 | 303 | if x[i] < center_ego[0]: 304 | x_range = np.arange(center_ego[0],np.maximum(x[i], x_min-res),-res*angle_res) 305 | else: 306 | x_range = np.arange(center_ego[0],np.minimum(x[i]+res, x_max+res),res*angle_res) 307 | 308 | # Find the corresponding ys. 309 | y_range = linefunction(center_ego[0],center_ego[1],x[i],y[i],x_range) 310 | 311 | y_temp = np.floor(y_shape*(x_range-x_min-res/2.)/(x_max-x_min+res)).astype(int) 312 | x_temp = np.floor(x_shape*(y_range-y_min-res/2.)/(y_max-y_min+res)).astype(int) 313 | 314 | # Take only the indices inside the local grid. 315 | indices = np.where(np.logical_and(np.logical_and(np.logical_and((x_temp < x_shape), (x_temp >= 0)), (y_temp < y_shape)), (y_temp >= 0))) 316 | x_temp = x_temp[indices] 317 | y_temp = y_temp[indices] 318 | 319 | # Found first occupied cell. 320 | labels_reduced = labels_grid[0,x_temp,y_temp] 321 | 322 | if len(labels_reduced)!=0 : 323 | unique_labels = np.unique(labels_reduced) 324 | 325 | # Check if there are any occupied cells. 
326 | if np.any(unique_labels==1.): 327 | ind = np.where(labels_reduced == 1) 328 | sensor_grid[x_temp[ind[0][0]:],y_temp[ind[0][0]:]] = 2 329 | sensor_grid[x_temp[ind[0][0]], y_temp[ind[0][0]]] = 1 330 | else: 331 | sensor_grid[x_temp, y_temp] = 0 332 | 333 | # No ego id included. 334 | unique_id = np.unique(labels_grid[3,:,:]) 335 | unique_id= np.delete(unique_id, np.where(np.isnan(unique_id))) 336 | 337 | # Set the unknown area as free. 338 | for id in unique_id: 339 | mask = (labels_grid[3,:,:]==id) 340 | if np.any(sensor_grid[mask] == 1): 341 | # Set as occupied/visible. 342 | sensor_grid[mask] = 1 343 | visible_id.append(id) 344 | # Ignore ego vehicle. 345 | elif np.any(labels_grid[0,mask]== 2): # except ego car 346 | pass 347 | else: 348 | # Set as occluded/invisible. 349 | sensor_grid[mask] = 2 350 | occluded_id.append(id) 351 | 352 | # Set the ego vehicle as occupied. 353 | mask = (labels_grid[0,:,:] == 2) 354 | sensor_grid[mask] = 1 355 | 356 | if grid_shape == None: 357 | if ego_flag: 358 | grid_shape = (int(70/res),int(60/res)) 359 | else: 360 | grid_shape = (int(70/res),int(50/res)) 361 | 362 | sensor_grid = grid_reshape(sensor_grid, grid_shape) 363 | 364 | return sensor_grid, occluded_id, visible_id 365 | 366 | def linefunction(velx,vely,indx,indy,x_range): 367 | m = (indy-vely)/(indx-velx) 368 | b = vely-m*velx 369 | return m*x_range + b 370 | 371 | def point_in_rectangle(x, y, rectangle): 372 | A = rectangle[0] 373 | B = rectangle[1] 374 | C = rectangle[2] 375 | 376 | M = np.array([x,y]).transpose((1,2,0)) 377 | 378 | AB = B-A 379 | AM = M-A 380 | BC = C-B 381 | BM = M-B 382 | 383 | dotABAM = np.dot(AM,AB) 384 | dotABAB = np.dot(AB,AB) 385 | dotBCBM = np.dot(BM,BC) 386 | dotBCBC = np.dot(BC,BC) 387 | 388 | return np.logical_and(np.logical_and(np.logical_and((0. <= dotABAM), (dotABAM <= dotABAB)), (0. <= dotBCBM)), (dotBCBM <= dotBCBC)) # nxn 389 | 390 | def rotate_around_center(pts, center, yaw): 391 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center 392 | 393 | 394 | def polygon_xy_from_motionstate(ms, width, length): 395 | assert isinstance(ms, MotionState) 396 | lowleft = (ms.x - length / 2., ms.y - width / 2.) 397 | lowright = (ms.x + length / 2., ms.y - width / 2.) 398 | upright = (ms.x + length / 2., ms.y + width / 2.) 399 | upleft = (ms.x - length / 2., ms.y + width / 2.) 400 | return rotate_around_center(np.array([lowleft, lowright, upright, upleft]), np.array([ms.x, ms.y]), yaw=ms.psi_rad) 401 | 402 | def polygon_xy_from_motionstate_pedest(ms, width, length): 403 | assert isinstance(ms, MotionState) 404 | lowleft = (ms.x - length / 2., ms.y - width / 2.) 405 | lowright = (ms.x + length / 2., ms.y - width / 2.) 406 | upright = (ms.x + length / 2., ms.y + width / 2.) 407 | upleft = (ms.x - length / 2., ms.y + width / 2.) 408 | return np.array([lowleft, lowright, upright, upleft]) 409 | 410 | def getstate(timestamp, track_dict, id): 411 | for key, value in track_dict.items(): 412 | if key==id: 413 | return value.motion_states[timestamp] 414 | 415 | # All objects at time step t in interval [time_stamp_ms_start,time_stamp_ms_last]. 
416 | def SceneObjects(track_dict, time_step, track_pedes_dict=None): 417 | object_id = [] 418 | pedes_id = [] 419 | if track_dict != None: 420 | for key, value in track_dict.items(): 421 | assert isinstance(value, Track) 422 | 423 | if value.time_stamp_ms_first <= time_step <= value.time_stamp_ms_last: 424 | object_id.append(value.track_id) 425 | 426 | if track_pedes_dict != None: 427 | for key, value in track_pedes_dict.items(): 428 | assert isinstance(value, Track) 429 | 430 | if value.time_stamp_ms_first <= time_step <= value.time_stamp_ms_last: 431 | pedes_id.append(value.track_id) 432 | 433 | return object_id, pedes_id 434 | 435 | # All objects in the scene. 436 | def AllObjects(track_dict, track_pedes_dict=None): 437 | object_id = [] 438 | for key, value in track_dict.items(): 439 | assert isinstance(value, Track) 440 | 441 | object_id.append(value.track_id) 442 | 443 | pedes_id = [] 444 | if track_pedes_dict != None: 445 | for key, value in track_pedes_dict.items(): 446 | assert isinstance(value, Track) 447 | 448 | pedes_id.append(value.track_id) 449 | 450 | return object_id, pedes_id -------------------------------------------------------------------------------- /src/utils/interaction_utils.py: -------------------------------------------------------------------------------- 1 | # Code for computing the IS metric and visualizing metrics. IS metric code is modified from: https://github.com/BenQLange/Occupancy-Grid-Metrics. 2 | 3 | import numpy as np 4 | from matplotlib import pyplot as plt 5 | import matplotlib.patches as mpatches 6 | import os 7 | import imageio 8 | from copy import deepcopy 9 | import pdb 10 | from tqdm import tqdm 11 | 12 | # Metric visualization. 13 | def plot_metric_clusters(K, states, cluster_ids, metric_grids, metric, dir, sources_val): 14 | patches = [] 15 | color = plt.cm.jet(np.linspace(0,1,K)) 16 | 17 | for c in range(K): 18 | patches.append(mpatches.Patch(color=color[c], label=c)) 19 | 20 | unique_sources = np.unique(np.array(sources_val)) 21 | 22 | for source in unique_sources: 23 | if source[-1] != scenario: 24 | continue 25 | 26 | mask = (np.array(sources_val) == source) 27 | 28 | batch_x_np_val_scenario = states[mask] 29 | cluster_ids_y_val_scenario = cluster_ids[mask] 30 | metric_grids_scenario = metric_grids[mask] 31 | 32 | if dict_scenario[scenario] == 0: 33 | fig, (ax1, ax2, ax3, ax4) = plt.subplots(4) 34 | symbol = '.' 
35 | else: 36 | symbol = 'x' 37 | 38 | for k in range(batch_x_np_val_scenario.shape[0]): 39 | ax1.scatter(batch_x_np_val_scenario[k,:,1], batch_x_np_val_scenario[k,:,1], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol) 40 | ax2.scatter(batch_x_np_val_scenario[k,:,1], batch_x_np_val_scenario[k,:,2], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol) 41 | ax3.scatter(batch_x_np_val_scenario[k,:,1], batch_x_np_val_scenario[k,:,3], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol) 42 | ax4.scatter(batch_x_np_val_scenario[k,0,1], metric_grids_scenario[k], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol) 43 | 44 | if dict_scenario[scenario] == 1: 45 | ax1.set_ylabel("Pos (m)") 46 | ax1.set_ylim(-5,120) 47 | ax1.set_xlim(-5,120) 48 | ax2.set_ylabel("Vel (m/s)") 49 | ax2.set_ylim(0,8) 50 | ax2.set_xlim(-5,120) 51 | ax3.set_ylabel("Acc (m/s^2)") 52 | ax3.set_ylim(-3,3) 53 | ax3.set_xlim(-5,120) 54 | ax4.set_ylim(0,1) 55 | if metric == 'IM': 56 | ax4.set_ylim(0,50) 57 | if metric == 'IM2': 58 | ax4.set_ylim(0, 50) 59 | ax4.set_xlim(-5,120) 60 | ax4.set_xlabel("Position (m)") 61 | ax4.set_ylabel(metric) 62 | fig.tight_layout(pad=.05) 63 | 64 | picture_file = os.path.join(dir, 'val_examples_K_' + str(K) + '_scenario_' + scenario + '_' + metric + '.png') 65 | fig.savefig(picture_file) 66 | fig.clf() 67 | 68 | dict_scenario[scenario] += 1 69 | 70 | def plot_train(K, p_m_a_np, cluster_centers_np, dir, grid_shape): 71 | fig, (ax1, ax2, ax3) = plt.subplots(3) 72 | fig.suptitle('State Clusters') 73 | 74 | for k in range(K): 75 | fig_occ, ax_occ = plt.subplots(1) 76 | image = np.flip(np.transpose(1.0-np.reshape(p_m_a_np[k], grid_shape), (1,0)), axis=0) 77 | ax_occ.imshow(image, cmap='gray') 78 | picture_file = os.path.join(dir, str(K) + '_cluster_' + str(k) + '.png') 79 | plt.savefig(picture_file) 80 | fig_occ.clf() 81 | 82 | def plot_scatter(K, states, metric_vals, metric, dir): 83 | x = states[:,0,0] 84 | y = states[:,0,1] 85 | plt.figure() 86 | plt.scatter(x, y, marker='.', s=150, linewidths=1, c=metric_vals, cmap=plt.cm.coolwarm) 87 | cb = plt.colorbar() 88 | plt.clim(0,80) 89 | plt.xlabel('x (m)') 90 | plt.ylabel('y (m)') 91 | plt.title('IM over First Driver State') 92 | picture_file = os.path.join(dir, str(K) + '_' + metric + '_first_driver_state.png') 93 | plt.savefig(picture_file) 94 | plt.clf() 95 | del(cb) 96 | 97 | def plot_scatter_clusters(K, states, labels, dir): 98 | x = states[:,0,0] 99 | y = states[:,0,1] 100 | plt.figure() 101 | plt.scatter(x, y, marker='.', s=150, linewidths=1, c=labels, cmap=plt.cm.coolwarm) 102 | cb = plt.colorbar() 103 | plt.clim(0,K) 104 | plt.xlabel('x (m)') 105 | plt.ylabel('y (m)') 106 | plt.title('Clusters over First Driver State') 107 | picture_file = os.path.join(dir, str(K) + '_clusters_first_driver_state.png') 108 | plt.savefig(picture_file) 109 | plt.clf() 110 | del(cb) 111 | 112 | def plot_multimodal_metrics(Ks, acc_nums, mse_nums, im_nums, acc_nums_best, mse_nums_best, im_nums_best, dir): 113 | plt.figure() 114 | plt.plot(Ks, acc_nums, label='Acc') 115 | plt.plot(Ks, acc_nums_best, label='Best 3 Acc') 116 | plt.title('Multimodality Performance Accuracy') 117 | plt.xlabel('K') 118 | plt.ylabel('Acc') 119 | plt.legend() 120 | picture_file = os.path.join(dir, 'multimodality_acc_best_3.png') 121 | 
plt.savefig(picture_file) 122 | 123 | plt.figure() 124 | plt.plot(Ks, mse_nums, label='MSE') 125 | plt.plot(Ks, mse_nums_best, label='Best 3 MSE') 126 | plt.title('Multimodality Performance MSE') 127 | plt.xlabel('K') 128 | plt.ylabel('MSE') 129 | plt.legend() 130 | picture_file = os.path.join(dir, 'multimodality_mse_best_3.png') 131 | plt.savefig(picture_file) 132 | 133 | plt.figure() 134 | plt.plot(Ks, im_nums, label='IM') 135 | plt.plot(Ks, im_nums_best, label='Best 3 IM') 136 | plt.title('Multimodality Performance IM') 137 | plt.xlabel('K') 138 | plt.ylabel('IM') 139 | plt.legend() 140 | picture_file = os.path.join(dir, 'multimodality_im_best_3.png') 141 | plt.savefig(picture_file) 142 | 143 | # IS metric. 144 | def MapSimilarityMetric(grids_pred, grids_actual): 145 | 146 | num_samples,_,_ = grids_pred.shape 147 | score, score_occupied, score_free, score_occluded = np.zeros((num_samples,)), np.zeros((num_samples,)), np.zeros((num_samples,)), np.zeros((num_samples,)) # score_occluded, np.zeros((num_samples,)) 148 | for sample in range(num_samples): # tqdm 149 | occupied, free, occluded = computeSimilarityMetric(grids_actual[sample,:,:], grids_pred[sample,:,:]) 150 | score[sample] += occupied 151 | score[sample] += occluded 152 | score[sample] += free 153 | 154 | score_occupied[sample] += occupied 155 | score_occluded[sample] += occluded 156 | score_free[sample] += free 157 | 158 | return score, score_occupied, score_free, score_occluded 159 | 160 | def toDiscrete(m): 161 | """ 162 | Args: 163 | - m (m,n) : np.array with the occupancy grid 164 | Returns: 165 | - discrete_m : thresholded m 166 | """ 167 | 168 | y_size, x_size = m.shape 169 | m_occupied = np.zeros(m.shape) 170 | m_free = np.zeros(m.shape) 171 | m_occluded = np.zeros(m.shape) 172 | 173 | m_occupied[m == 1.0] = 1.0 174 | m_occluded[m == 0.5] = 1.0 175 | m_free[m == 0.0] = 1.0 176 | 177 | return m_occupied, m_free, m_occluded 178 | 179 | def todMap(m): 180 | 181 | """ 182 | Extra if statements are for edge cases. 183 | """ 184 | 185 | y_size, x_size = m.shape 186 | dMap = np.ones(m.shape) * np.Inf 187 | dMap[m == 1] = 0.0 188 | 189 | for y in range(0,y_size): 190 | if y == 0: 191 | for x in range(1,x_size): 192 | h = dMap[y,x-1]+1 193 | dMap[y,x] = min(dMap[y,x], h) 194 | 195 | else: 196 | for x in range(0,x_size): 197 | if x == 0: 198 | h = dMap[y-1,x]+1 199 | dMap[y,x] = min(dMap[y,x], h) 200 | else: 201 | h = min(dMap[y,x-1]+1, dMap[y-1,x]+1) 202 | dMap[y,x] = min(dMap[y,x], h) 203 | 204 | for y in range(y_size-1,-1,-1): 205 | 206 | if y == y_size-1: 207 | for x in range(x_size-2,-1,-1): 208 | h = dMap[y,x+1]+1 209 | dMap[y,x] = min(dMap[y,x], h) 210 | 211 | else: 212 | for x in range(x_size-1,-1,-1): 213 | if x == x_size-1: 214 | h = dMap[y+1,x]+1 215 | dMap[y,x] = min(dMap[y,x], h) 216 | else: 217 | h = min(dMap[y+1,x]+1, dMap[y,x+1]+1) 218 | dMap[y,x] = min(dMap[y,x], h) 219 | 220 | return dMap 221 | 222 | def computeDistance(m1,m2): 223 | 224 | y_size, x_size = m1.shape 225 | dMap = todMap(m2) 226 | 227 | d = np.sum(dMap[m1 == 1]) 228 | num_cells = np.sum(m1 == 1) 229 | 230 | # If either of the grids does not have a particular class, 231 | # set to x_size + y_size (proxy for infinity - worst case Manhattan distance). 232 | # If both of the grids do not have a class, set to zero. 
233 | if ((num_cells != 0) and (np.sum(dMap == np.Inf) == 0)): 234 | output = d/num_cells 235 | elif ((num_cells == 0) and (np.sum(dMap == np.Inf) != 0)): 236 | output = 0.0 237 | elif ((num_cells == 0) or (np.sum(dMap == np.Inf) != 0)): 238 | output = x_size + y_size 239 | 240 | if output == np.Inf: 241 | pdb.set_trace() 242 | 243 | return output 244 | 245 | def computeSimilarityMetric(m1, m2): 246 | 247 | m1_occupied, m1_free, m1_occluded = toDiscrete(m1) 248 | m2_occupied, m2_free, m2_occluded = toDiscrete(m2) 249 | 250 | occupied = computeDistance(m1_occupied,m2_occupied) + computeDistance(m2_occupied,m1_occupied) 251 | occluded = computeDistance(m2_occluded, m1_occluded) + computeDistance(m1_occluded, m2_occluded) 252 | free = computeDistance(m1_free,m2_free) + computeDistance(m2_free,m1_free) 253 | 254 | return occupied, free, occluded -------------------------------------------------------------------------------- /src/utils/map_vis_without_lanelet.py: -------------------------------------------------------------------------------- 1 | # Code is from: https://github.com/interaction-dataset/interaction-dataset. 2 | 3 | import matplotlib 4 | import matplotlib.axes 5 | import matplotlib.pyplot as plt 6 | 7 | import xml.etree.ElementTree as xml 8 | import pyproj 9 | import math 10 | import numpy as np 11 | 12 | # from utils import dict_utils 13 | 14 | 15 | class Point: 16 | def __init__(self): 17 | self.x = None 18 | self.y = None 19 | 20 | 21 | class LL2XYProjector: 22 | def __init__(self, lat_origin, lon_origin): 23 | self.lat_origin = lat_origin 24 | self.lon_origin = lon_origin 25 | self.zone = math.floor((lon_origin+180.)/6)+1 # works for most tiles, and for all in the dataset 26 | self.p = pyproj.Proj(proj='utm', ellps='WGS84', zone=self.zone, datum='WGS84') 27 | [self.x_origin, self.y_origin] = self.p(lon_origin, lat_origin) 28 | 29 | def latlon2xy(self, lat, lon): 30 | [x, y] = self.p(lon, lat) 31 | return [x-self.x_origin, y-self.y_origin] 32 | 33 | 34 | def get_type(element): 35 | for tag in element.findall("tag"): 36 | if tag.get("k") == "type": 37 | return tag.get("v") 38 | return None 39 | 40 | 41 | def get_subtype(element): 42 | for tag in element.findall("tag"): 43 | if tag.get("k") == "subtype": 44 | return tag.get("v") 45 | return None 46 | 47 | 48 | def get_x_y_lists(element, point_dict): 49 | x_list = list() 50 | y_list = list() 51 | for nd in element.findall("nd"): 52 | pt_id = int(nd.get("ref")) 53 | point = point_dict[pt_id] 54 | x_list.append(point.x) 55 | y_list.append(point.y) 56 | return x_list, y_list 57 | 58 | 59 | def set_visible_area(point_dict, axes): 60 | min_x = 10e9 61 | min_y = 10e9 62 | max_x = -10e9 63 | max_y = -10e9 64 | 65 | for id, point in dict_utils.get_item_iterator(point_dict): 66 | min_x = min(point.x, min_x) 67 | min_y = min(point.y, min_y) 68 | max_x = max(point.x, max_x) 69 | max_y = max(point.y, max_y) 70 | 71 | axes.set_aspect('equal', adjustable='box') 72 | axes.set_xlim([min_x - 10, max_x + 10]) 73 | axes.set_ylim([min_y - 10, max_y + 10]) 74 | 75 | return np.array([min_x, min_y, max_x, max_y]) 76 | 77 | 78 | def draw_map_without_lanelet(filename, axes, lat_origin, lon_origin): 79 | 80 | assert isinstance(axes, matplotlib.axes.Axes) 81 | 82 | axes.set_aspect('equal', adjustable='box') 83 | axes.patch.set_facecolor('lightgrey') 84 | 85 | projector = LL2XYProjector(lat_origin, lon_origin) 86 | 87 | e = xml.parse(filename).getroot() 88 | 89 | point_dict = dict() 90 | for node in e.findall("node"): 91 | point = Point() 92 | point.x, 
point.y = projector.latlon2xy(float(node.get('lat')), float(node.get('lon'))) 93 | point_dict[int(node.get('id'))] = point 94 | 95 | endpoint = set_visible_area(point_dict, axes) 96 | 97 | unknown_linestring_types = list() 98 | 99 | for way in e.findall('way'): 100 | way_type = get_type(way) 101 | if way_type is None: 102 | raise RuntimeError("Linestring type must be specified") 103 | elif way_type == "curbstone": 104 | type_dict = dict(color="black", linewidth=1, zorder=10) 105 | elif way_type == "line_thin": 106 | way_subtype = get_subtype(way) 107 | if way_subtype == "dashed": 108 | type_dict = dict(color="white", linewidth=1, zorder=10, dashes=[10, 10]) 109 | else: 110 | type_dict = dict(color="white", linewidth=1, zorder=10) 111 | elif way_type == "line_thick": 112 | way_subtype = get_subtype(way) 113 | if way_subtype == "dashed": 114 | type_dict = dict(color="white", linewidth=2, zorder=10, dashes=[10, 10]) 115 | else: 116 | type_dict = dict(color="white", linewidth=2, zorder=10) 117 | elif way_type == "pedestrian_marking": 118 | type_dict = dict(color="white", linewidth=1, zorder=10, dashes=[5, 10]) 119 | elif way_type == "bike_marking": 120 | type_dict = dict(color="white", linewidth=1, zorder=10, dashes=[5, 10]) 121 | elif way_type == "stop_line": 122 | type_dict = dict(color="white", linewidth=3, zorder=10) 123 | elif way_type == "virtual": 124 | type_dict = dict(color="blue", linewidth=1, zorder=10, dashes=[2, 5]) 125 | elif way_type == "road_border": 126 | type_dict = dict(color="black", linewidth=1, zorder=10) 127 | elif way_type == "guard_rail": 128 | type_dict = dict(color="black", linewidth=1, zorder=10) 129 | elif way_type == "traffic_sign": 130 | continue 131 | else: 132 | if way_type not in unknown_linestring_types: 133 | unknown_linestring_types.append(way_type) 134 | continue 135 | 136 | x_list, y_list = get_x_y_lists(way, point_dict) 137 | plt.plot(x_list, y_list, **type_dict) 138 | 139 | if len(unknown_linestring_types) != 0: 140 | print("Found the following unknown types, did not plot them: " + str(unknown_linestring_types)) 141 | 142 | return endpoint 143 | -------------------------------------------------------------------------------- /src/utils/tracks_save.py: -------------------------------------------------------------------------------- 1 | # Save the full pipeline occlusion inference output. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset. 
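# Editor's sketch (not part of this file): ties the map extent returned by
# draw_map_without_lanelet() in map_vis_without_lanelet.py above ([min_x, min_y, max_x, max_y])
# to the global grid built by grid_utils.global_grid(). The endpoint values are placeholders.
import numpy as np
from utils.grid_utils import global_grid

endpoint = np.array([0.0, 0.0, 60.0, 40.0])  # min_x, min_y, max_x, max_y (illustrative)
gridx, gridy = global_grid(endpoint[:2], endpoint[2:], res=1.0)
# gridx[i, j], gridy[i, j] hold the global x/y center coordinates of cell (i, j);
# at 1 m resolution this example grid has shape (60, 40).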
2 | 3 | import matplotlib 4 | import matplotlib.patches 5 | import matplotlib.transforms 6 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox, AnchoredOffsetbox 7 | from scipy import ndimage 8 | import skimage.transform 9 | from PIL import Image 10 | import pdb 11 | from matplotlib import pyplot as plt 12 | from matplotlib.colors import ListedColormap 13 | from copy import deepcopy 14 | 15 | import os 16 | import time 17 | 18 | seed = 123 19 | 20 | import numpy as np 21 | np.random.seed(seed) 22 | from matplotlib import pyplot as plt 23 | import torch 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | import torch.nn as nn 29 | torch.autograd.set_detect_anomaly(True) 30 | 31 | import io 32 | from tqdm import tqdm 33 | import time 34 | 35 | import argparse 36 | import pandas as pd 37 | import seaborn as sns 38 | import matplotlib.pyplot as plt 39 | from collections import OrderedDict, defaultdict 40 | import torch._utils 41 | 42 | from src.utils.dataset_types import Track, MotionState 43 | from src.utils.grid_utils import * 44 | from src.utils.grid_fuse import * 45 | from src.utils.utils_model import to_var 46 | from src.driver_sensor_model.models_cvae import VAE 47 | from src.utils.interaction_utils import * 48 | 49 | try: 50 | torch._utils._rebuild_tensor_v2 51 | except AttributeError: 52 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 53 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 54 | tensor.requires_grad = requires_grad 55 | tensor._backward_hooks = backward_hooks 56 | return tensor 57 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 58 | 59 | def rotate_around_center(pts, center, yaw): 60 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center 61 | 62 | 63 | def polygon_xy_from_motionstate(ms, width, length): 64 | assert isinstance(ms, MotionState) 65 | lowleft = (ms.x - length / 2., ms.y - width / 2.) 66 | lowright = (ms.x + length / 2., ms.y - width / 2.) 67 | upright = (ms.x + length / 2., ms.y + width / 2.) 68 | upleft = (ms.x - length / 2., ms.y + width / 2.) 69 | return rotate_around_center(np.array([lowleft, lowright, upright, upleft]), np.array([ms.x, ms.y]), yaw=ms.psi_rad) 70 | 71 | 72 | def polygon_xy_from_motionstate_pedest(ms, width, length): 73 | assert isinstance(ms, MotionState) 74 | lowleft = (ms.x - length / 2., ms.y - width / 2.) 75 | lowright = (ms.x + length / 2., ms.y - width / 2.) 76 | upright = (ms.x + length / 2., ms.y + width / 2.) 77 | upleft = (ms.x - length / 2., ms.y + width / 2.) 78 | return np.array([lowleft, lowright, upright, upleft]) 79 | 80 | def update_objects_plot(timestamp, track_dict=None, pedest_dict=None, 81 | data=None, car_ids=None, sensor_grids=None, id_grids=None, label_grids=None, 82 | driver_sensor_data=None, driver_sensor_state_data=None, driver_sensor_state=None, endpoint=None, models=None, results=None, mode='evidential', model='vae'): 83 | 84 | ego_id = data[1][1] 85 | 86 | update = False 87 | 88 | if track_dict is not None: 89 | 90 | # Plot and save the ego-vehicle first. 
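        # The ego processing below proceeds as follows: SceneObjects() collects the vehicle and
        # pedestrian ids present at this timestamp; generateLabelGrid() builds the ground-truth
        # occupancy (label) grid around the ego vehicle; generateSensorGrid() determines which
        # cells the ego vehicle can currently observe and returns the occluded/visible agent ids
        # (cells appear to be coded 0 = free, 1 = occupied, 2 = occluded); get_belief_mass() then
        # converts the sensor grid into evidential (Dempster-Shafer) masses and pignistic()
        # turns those masses into per-cell occupancy probabilities.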
91 | assert isinstance(track_dict[ego_id], Track) 92 | value = track_dict[ego_id] 93 | if (value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last): 94 | ms_ego = value.motion_states[timestamp] 95 | assert isinstance(ms_ego, MotionState) 96 | 97 | width = value.width 98 | length = value.length 99 | 100 | # Obtain the ego vehicle's grid. 101 | res = 1.0 102 | object_id, pedes_id = SceneObjects(track_dict, timestamp, track_pedes_dict=pedest_dict) 103 | local_x_ego, local_y_ego, _, _ = local_grid(ms_ego, width, length, res=res, ego_flag=True) 104 | label_grid_ego, center, x_local, y_local, pre_local_x, pre_local_y = generateLabelGrid(timestamp, track_dict, ego_id, object_id, ego_flag=True, res=res, track_pedes_dict=pedest_dict, pedes_id=pedes_id) 105 | sensor_grid_ego, occluded_id, visible_id = generateSensorGrid(label_grid_ego, pre_local_x, pre_local_y, ms_ego, width, length, ego_flag=True, res=res) 106 | visible_id += [ego_id] 107 | 108 | full_sensor_grid_dst, mask_unk = get_belief_mass(sensor_grid_ego, ego_flag=True) 109 | full_sensor_grid = pignistic(full_sensor_grid_dst) 110 | 111 | run_time = 0. 112 | 113 | # Initialize the variables to keep for later computation. 114 | if model != 'kmeans': 115 | all_latent_classes = [] 116 | ref_local_xy_list = [] 117 | ego_local_xy_list = [] 118 | alpha_p_list = [] 119 | sensor_grid_ego_dst = [full_sensor_grid_dst] 120 | 121 | if mode == 'average': 122 | average_mask = np.zeros(full_sensor_grid.shape) 123 | driver_sensor_grid = np.zeros(full_sensor_grid.shape) 124 | 125 | # Consider the rest of the agents. 126 | for key, value in track_dict.items(): 127 | assert isinstance(value, Track) 128 | if ((value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last)): 129 | ms = value.motion_states[timestamp] 130 | assert isinstance(ms, MotionState) 131 | 132 | width = value.width 133 | length = value.length 134 | 135 | # Consider all the visible drivers. 136 | if ((key in visible_id) and (key != ego_id)): 137 | res = 1.0 138 | if key in driver_sensor_state_data.keys(): 139 | for state in driver_sensor_state_data[key]: 140 | if state[0] == timestamp: 141 | if key in driver_sensor_state.keys(): 142 | 143 | # Make sure that the states are contiguous. 144 | if state[0] - driver_sensor_state[key][-1,0] == 100: 145 | driver_sensor_state[key] = np.concatenate((driver_sensor_state[key], np.reshape(state, (1,-1)))) 146 | else: 147 | driver_sensor_state[key] = np.reshape(state, (1,-1)) 148 | else: 149 | driver_sensor_state[key] = np.reshape(state, (1,-1)) 150 | 151 | # Perform occlusion inference if at least 1 second of observed driver state has been observed. 152 | if ((driver_sensor_state[key].shape[0] == 10)): 153 | 154 | # Flag that the map is being updated. 155 | update = True 156 | 157 | start = time.time() 158 | 159 | x_local, y_local, _, _ = local_grid(ms, width, length, res=res, ego_flag=False, grid_shape=(20,30)) 160 | 161 | # Merge grid with driver_sensor_grid. 
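                                        # x_local/y_local hold the coordinates of the observed driver's local
                                        # (20, 30) grid; stacking them (and the ego grid coordinates) into
                                        # (2, H, W) arrays gives Transfer_to_EgoGrid() what it needs to
                                        # re-project the driver sensor model output into the ego frame.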
162 | ref_local_xy = np.stack((x_local, y_local), axis=0) 163 | ego_local_xy = np.stack((local_x_ego, local_y_ego), axis=0) 164 | 165 | if model == 'kmeans': 166 | input_state = driver_sensor_state[key][:,1:] 167 | input_state = np.expand_dims(input_state.flatten(), 0) 168 | [kmeans, p_m_a_np] = models['kmeans'] 169 | cluster_ids_y_val = kmeans.predict(input_state.astype('float32')) 170 | pred_maps = p_m_a_np[cluster_ids_y_val] 171 | pred_maps = np.reshape(pred_maps[0], (20,30)) 172 | 173 | elif model == 'gmm': 174 | input_state = driver_sensor_state[key][:,1:] 175 | input_state = np.expand_dims(input_state.flatten(), 0) 176 | [gmm, p_m_a_np] = models['gmm'] 177 | cluster_ids_y_val = gmm.predict(input_state.astype('float32')) 178 | alpha_p = gmm.predict_proba(input_state.astype('float32')) 179 | pred_maps = p_m_a_np[cluster_ids_y_val] 180 | pred_maps = np.reshape(pred_maps[0], (20,30)) 181 | 182 | if len(all_latent_classes) == 0: 183 | all_latent_classes = [np.reshape(p_m_a_np, (100,20,30))] 184 | 185 | elif model == 'vae': 186 | input_state = preprocess(driver_sensor_state[key][:,1:]) 187 | input_state = torch.unsqueeze(to_var(torch.from_numpy(input_state)), 0).float().cuda() 188 | 189 | models['vae'].eval() 190 | with torch.no_grad(): 191 | pred_maps, alpha_p, _, _, z = models['vae'].inference(n=1, c=input_state, mode='most_likely') 192 | if len(all_latent_classes) == 0: 193 | recon_x_inf, _, _, _, _ = models['vae'].inference(n=100, c=input_state, mode='all') 194 | all_latent_classes = [recon_x_inf.cpu().numpy()] 195 | 196 | pred_maps = pred_maps[0][0].cpu().numpy() 197 | alpha_p = alpha_p.cpu().numpy() 198 | 199 | # Transfer the driver sensor model prediction to the ego vehicle's frame of reference. 200 | predEgoMaps = Transfer_to_EgoGrid(ref_local_xy, pred_maps, ego_local_xy, full_sensor_grid, endpoint=endpoint, res=res, mask_unk=mask_unk) 201 | 202 | if model != 'kmeans': 203 | if not np.all(predEgoMaps == 2): 204 | alpha_p_list.append(alpha_p[0]) 205 | ref_local_xy_list.append(ref_local_xy) 206 | ego_local_xy_list.append(ego_local_xy) 207 | 208 | # Fuse the driver sensor model prediction into the ego vehicle's grid. 209 | if mode == 'evidential': 210 | driver_sensor_grid_dst, _ = get_belief_mass(predEgoMaps, ego_flag=False, m=0.95) 211 | full_sensor_grid_dst, full_sensor_grid = dst_fusion(driver_sensor_grid_dst, full_sensor_grid_dst, mask_unk) 212 | elif mode == 'average': 213 | average_mask[predEgoMaps != 2] += 1.0 214 | driver_sensor_grid[predEgoMaps != 2] += predEgoMaps[predEgoMaps != 2] 215 | 216 | run_time += (time.time() - start) 217 | 218 | # Remove the oldest state. 219 | driver_sensor_state[key] = driver_sensor_state[key][1:] 220 | 221 | break 222 | 223 | # Save data if the map is updated. 
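            # Keys appended below: 'ego_sensor' (raw ego visibility grid, occluded cells re-coded
            # to 0.5), 'ego_label' (ground-truth occupancy), 'vae' (the fused occupancy estimate,
            # stored under this key regardless of which driver sensor model was used), plus
            # bookkeeping such as the timestamp and per-frame run time and, for the GMM/CVAE
            # models, the latent-class grids and per-driver mixture weights (presumably consumed
            # by src/full_pipeline/full_pipeline_metrics.py).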
224 | if update: 225 | 226 | if mode == 'average': 227 | driver_sensor_grid[average_mask != 0] /= average_mask[average_mask != 0] 228 | full_sensor_grid[average_mask != 0] = driver_sensor_grid[average_mask != 0] 229 | 230 | sensor_grid_ego[sensor_grid_ego == 2] = 0.5 231 | results['ego_sensor'].append(sensor_grid_ego) 232 | results['ego_label'].append(label_grid_ego[0]) 233 | results['vae'].append(full_sensor_grid) 234 | results['timestamp'].append(timestamp) 235 | results['run_time'].append(run_time) 236 | if model != 'kmeans': 237 | results['all_latent_classes'].append(all_latent_classes[0]) 238 | results['ref_local_xy'].append(ref_local_xy_list) 239 | results['ego_local_xy'].append(ego_local_xy_list) 240 | results['alpha_p'].append(alpha_p_list) 241 | results['ego_sensor_dst'].append(sensor_grid_ego_dst[0]) 242 | results['endpoint'].append(endpoint) 243 | results['res'].append(res) -------------------------------------------------------------------------------- /src/utils/tracks_vis.py: -------------------------------------------------------------------------------- 1 | # Plot the full pipeline occlusion inference. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset. 2 | 3 | import matplotlib 4 | import matplotlib.patches 5 | import matplotlib.transforms 6 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox, AnchoredOffsetbox 7 | from scipy import ndimage 8 | import skimage.transform 9 | from PIL import Image 10 | import pdb 11 | from matplotlib import pyplot as plt 12 | from matplotlib.colors import ListedColormap 13 | from copy import deepcopy 14 | 15 | import os 16 | import time 17 | 18 | seed = 123 19 | 20 | import numpy as np 21 | np.random.seed(seed) 22 | from matplotlib import pyplot as plt 23 | import torch 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | import torch.nn as nn 29 | torch.autograd.set_detect_anomaly(True) 30 | 31 | import io 32 | from tqdm import tqdm 33 | import time 34 | 35 | import argparse 36 | import pandas as pd 37 | import seaborn as sns 38 | import matplotlib.pyplot as plt 39 | from collections import OrderedDict, defaultdict 40 | import torch._utils 41 | 42 | from src.utils.dataset_types import Track, MotionState 43 | from src.utils.grid_utils import * 44 | from src.utils.grid_fuse import * 45 | from src.utils.utils_model import to_var 46 | from src.driver_sensor_model.models_cvae import VAE 47 | from src.utils.interaction_utils import * 48 | 49 | try: 50 | torch._utils._rebuild_tensor_v2 51 | except AttributeError: 52 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks): 53 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride) 54 | tensor.requires_grad = requires_grad 55 | tensor._backward_hooks = backward_hooks 56 | return tensor 57 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2 58 | 59 | def rotate_around_center(pts, center, yaw): 60 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center 61 | 62 | def polygon_xy_from_motionstate(ms, width, length): 63 | assert isinstance(ms, MotionState) 64 | lowleft = (ms.x - length / 2., ms.y - width / 2.) 65 | lowright = (ms.x + length / 2., ms.y - width / 2.) 66 | upright = (ms.x + length / 2., ms.y + width / 2.) 67 | upleft = (ms.x - length / 2., ms.y + width / 2.) 
68 | return rotate_around_center(np.array([lowleft, lowright, upright, upleft]), np.array([ms.x, ms.y]), yaw=ms.psi_rad) 69 | 70 | 71 | def polygon_xy_from_motionstate_pedest(ms, width, length): 72 | assert isinstance(ms, MotionState) 73 | lowleft = (ms.x - length / 2., ms.y - width / 2.) 74 | lowright = (ms.x + length / 2., ms.y - width / 2.) 75 | upright = (ms.x + length / 2., ms.y + width / 2.) 76 | upleft = (ms.x - length / 2., ms.y + width / 2.) 77 | return np.array([lowleft, lowright, upright, upleft]) 78 | 79 | def update_objects_plot(timestamp, patches_dict, text_dict, axes, track_dict=None, pedest_dict=None, 80 | data=None, car_ids=None, sensor_grids=None, id_grids=None, label_grids=None, grids_dict=None, 81 | driver_sensor_data=None, driver_sensor_state_data=None, driver_sensor_state=None, driver_sensor_state_dict=None, endpoint=None, models=None, mode='evidential', model='vae'): 82 | 83 | ego_id = data[1][1] 84 | 85 | if track_dict is not None: 86 | 87 | # Plot and save the ego-vehicle first. 88 | assert isinstance(track_dict[ego_id], Track) 89 | value = track_dict[ego_id] 90 | if (value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last): 91 | 92 | ms_ego = value.motion_states[timestamp] 93 | assert isinstance(ms_ego, MotionState) 94 | 95 | # Check if the ego vehicle already exists. 96 | if ego_id not in patches_dict: 97 | width = value.width 98 | length = value.length 99 | 100 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate(ms_ego, width, length), closed=True, 101 | zorder=40, color='green') 102 | 103 | axes.add_patch(rect) 104 | patches_dict[ego_id] = rect 105 | text_dict[ego_id] = axes.text(ms_ego.x, ms_ego.y + 3, str(ego_id), horizontalalignment='center', zorder=50, fontsize='xx-large') 106 | 107 | else: 108 | width = value.width 109 | length = value.length 110 | patches_dict[ego_id].set_xy(polygon_xy_from_motionstate(ms_ego, width, length)) 111 | 112 | text_dict[ego_id].set_position((ms_ego.x, ms_ego.y + 3)) 113 | 114 | # Set colors of the ego vehicle. 115 | patches_dict[ego_id].set_color('green') 116 | 117 | if ego_id in grids_dict.keys(): 118 | grids_dict[ego_id].remove() 119 | 120 | # Plot the ego vehicle's vanilla grid. 
121 | res = 1.0 122 | colormap = ListedColormap(['white','black','gray']) 123 | object_id, pedes_id = SceneObjects(track_dict, timestamp, track_pedes_dict=pedest_dict) 124 | local_x_ego, local_y_ego, _, _ = local_grid(ms_ego, width, length, res=res, ego_flag=True) 125 | label_grid_ego, center, x_local, y_local, pre_local_x, pre_local_y = generateLabelGrid(timestamp, track_dict, ego_id, object_id, ego_flag=True, res=res, track_pedes_dict=pedest_dict, pedes_id=pedes_id) 126 | start = time.time() 127 | sensor_grid_ego, occluded_id, visible_id = generateSensorGrid(label_grid_ego, pre_local_x, pre_local_y, ms_ego, width, length, ego_flag=True, res=res) 128 | visible_id += [ego_id] 129 | # sensor_grid_ego = 2.0*np.ones((sensor_grid_ego.shape)) 130 | 131 | full_sensor_grid_dst, mask_unk = get_belief_mass(sensor_grid_ego, ego_flag=True) 132 | full_sensor_grid = pignistic(full_sensor_grid_dst) 133 | full_sensor_grid_dst[0,label_grid_ego[0] == 2] = 0.0 134 | full_sensor_grid_dst[1,label_grid_ego[0] == 2] = 1.0 135 | full_sensor_grid[label_grid_ego[0] == 2] = 0.0 136 | sensor_grid_ego[label_grid_ego[0] == 2] = 0.0 137 | 138 | box = axes.pcolormesh(local_x_ego, local_y_ego, full_sensor_grid, cmap='gray_r', zorder=30, alpha=0.7, vmin=0, vmax=1) 139 | grids_dict[ego_id] = box 140 | 141 | # Initialize the variables to keep for later computation. 142 | if model != 'kmeans': 143 | all_latent_classes = [] 144 | ref_local_xy_list = [] 145 | ego_local_xy_list = [] 146 | alpha_p_list = [] 147 | sensor_grid_ego_dst = [full_sensor_grid_dst] 148 | 149 | if mode == 'average': 150 | average_mask = np.zeros(full_sensor_grid.shape) 151 | driver_sensor_grid = np.zeros(full_sensor_grid.shape) 152 | 153 | # Plot the rest of the agents. 154 | for key, value in track_dict.items(): 155 | assert isinstance(value, Track) 156 | if ((value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last)): 157 | 158 | ms = value.motion_states[timestamp] 159 | assert isinstance(ms, MotionState) 160 | 161 | if key not in patches_dict: 162 | width = value.width 163 | length = value.length 164 | 165 | if ((key in visible_id) and (key != ego_id)): 166 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate(ms, width, length), closed=True, 167 | zorder=40, color='cyan') 168 | colormap = ListedColormap(['white','black','gray']) 169 | res = 1.0 170 | if key in driver_sensor_state_data.keys(): 171 | for state in driver_sensor_state_data[key]: 172 | if state[0] == timestamp: 173 | driver_sensor_state[key] = np.reshape(state, (1,-1)) 174 | break 175 | 176 | elif ((key not in visible_id) and (key != ego_id)) : 177 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate(ms, width, length), closed=True, 178 | zorder=20) 179 | patches_dict[key] = rect 180 | axes.add_patch(rect) 181 | text_dict[key] = axes.text(ms.x, ms.y + 3, str(key), horizontalalignment='center', zorder=50, fontsize='xx-large') 182 | 183 | else: 184 | width = value.width 185 | length = value.length 186 | patches_dict[key].set_xy(polygon_xy_from_motionstate(ms, width, length)) 187 | if (key in visible_id): 188 | patches_dict[key].set_zorder(40) 189 | elif (key not in visible_id): 190 | patches_dict[key].set_zorder(20) 191 | text_dict[key].set_position((ms.x, ms.y + 3)) 192 | 193 | # Consider all the visible drivers. 
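                # A driver's state history is only used once it contains 10 consecutive samples
                # spaced 100 ms apart (i.e. 1 s of observation at 10 Hz); after each inference the
                # oldest sample is dropped, so the history acts as a sliding one-second window.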
194 | if ((key in visible_id) and (key != ego_id)): 195 | patches_dict[key].set_color('cyan') 196 | 197 | colormap = ListedColormap(['white','black','gray']) 198 | res = 1.0 199 | if key in driver_sensor_state_data.keys(): 200 | for state in driver_sensor_state_data[key]: 201 | if state[0] == timestamp: 202 | if key in driver_sensor_state.keys(): 203 | 204 | # Make sure that the states are contiguous. 205 | if state[0] - driver_sensor_state[key][-1,0] == 100: 206 | driver_sensor_state[key] = np.concatenate((driver_sensor_state[key], np.reshape(state, (1,-1)))) 207 | else: 208 | driver_sensor_state[key] = np.reshape(state, (1,-1)) 209 | else: 210 | driver_sensor_state[key] = np.reshape(state, (1,-1)) 211 | 212 | # Perform occlusion inference if at least 1 second of observed driver state has been observed. 213 | if ((driver_sensor_state[key].shape[0] == 10)): # and key == 116): 214 | 215 | # Clear the existing plot. 216 | if key in driver_sensor_state_dict.keys(): 217 | driver_sensor_state_dict[key].remove() 218 | 219 | box_sensor_state = axes.scatter(driver_sensor_state[key][:,1], driver_sensor_state[key][:,2], s=20, color='orange', zorder=50, alpha=1.0) 220 | driver_sensor_state_dict[key] = box_sensor_state 221 | 222 | x_local, y_local, _, _ = local_grid(ms, width, length, res=res, ego_flag=False, grid_shape=(20,30)) 223 | 224 | # Clear the ego grid plot. 225 | if ego_id in grids_dict.keys(): 226 | grids_dict[ego_id].remove() 227 | 228 | # Merge grid with driver_sensor_grid. 229 | ref_local_xy = np.stack((x_local, y_local), axis=0) 230 | ego_local_xy = np.stack((local_x_ego, local_y_ego), axis=0) 231 | 232 | if model == 'kmeans': 233 | input_state = driver_sensor_state[key][:,1:] 234 | input_state = np.expand_dims(input_state.flatten(), 0) 235 | [kmeans, p_m_a_np] = models['kmeans'] 236 | cluster_ids_y_val = kmeans.predict(input_state.astype('float32')) 237 | pred_maps = p_m_a_np[cluster_ids_y_val] 238 | pred_maps = np.reshape(pred_maps[0], (20,30)) 239 | 240 | elif model == 'gmm': 241 | input_state = driver_sensor_state[key][:,1:] 242 | input_state = np.expand_dims(input_state.flatten(), 0) 243 | [gmm, p_m_a_np] = models['gmm'] 244 | cluster_ids_y_val = gmm.predict(input_state.astype('float32')) 245 | alpha_p = gmm.predict_proba(input_state.astype('float32')) 246 | pred_maps = p_m_a_np[cluster_ids_y_val] 247 | pred_maps = np.reshape(pred_maps[0], (20,30)) 248 | 249 | if len(all_latent_classes) == 0: 250 | all_latent_classes = [np.reshape(p_m_a_np, (100,20,30))] 251 | 252 | elif model == 'vae': 253 | input_state = preprocess(driver_sensor_state[key][:,1:]) 254 | input_state = torch.unsqueeze(to_var(torch.from_numpy(input_state)), 0).float().cuda() 255 | 256 | models['vae'].eval() 257 | with torch.no_grad(): 258 | pred_maps, alpha_p, _, _, z = models['vae'].inference(n=1, c=input_state, mode='most_likely') 259 | if len(all_latent_classes) == 0: 260 | recon_x_inf, _, _, _, _ = models['vae'].inference(n=100, c=input_state, mode='all') 261 | all_latent_classes = [recon_x_inf.cpu().numpy()] 262 | 263 | pred_maps = pred_maps[0][0].cpu().numpy() 264 | alpha_p = alpha_p.cpu().numpy() 265 | 266 | # print(key,np.sort(alpha_p[0])) 267 | 268 | # Transfer the driver sensor model prediction to the ego vehicle's frame of reference. 
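                                        # Two fusion strategies are supported below: 'evidential' converts the
                                        # re-projected prediction into belief masses (with m=0.95, presumably a
                                        # discounting weight) and combines them with the ego masses via
                                        # dst_fusion(), while 'average' accumulates the predicted probabilities
                                        # together with an update-count mask so they can be averaged per cell.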
269 | predEgoMaps = Transfer_to_EgoGrid(ref_local_xy, pred_maps, ego_local_xy, full_sensor_grid, endpoint=endpoint, res=res, mask_unk=mask_unk) 270 | 271 | if model != 'kmeans': 272 | if not np.all(predEgoMaps == 2): 273 | alpha_p_list.append(alpha_p[0]) 274 | ref_local_xy_list.append(ref_local_xy) 275 | ego_local_xy_list.append(ego_local_xy) 276 | 277 | # Fuse the driver sensor model prediction into the ego vehicle's grid. 278 | if mode == 'evidential': 279 | driver_sensor_grid_dst, _ = get_belief_mass(predEgoMaps, ego_flag=False, m=0.95) 280 | full_sensor_grid_dst, full_sensor_grid = dst_fusion(driver_sensor_grid_dst, full_sensor_grid_dst, mask_unk) 281 | 282 | elif mode == 'average': 283 | average_mask[predEgoMaps != 2] += 1.0 284 | driver_sensor_grid[predEgoMaps != 2] += predEgoMaps[predEgoMaps != 2] 285 | 286 | grids_dict[ego_id] = axes.pcolormesh(local_x_ego, local_y_ego, full_sensor_grid, cmap='gray_r', zorder=30, alpha=0.7, vmin=0, vmax=1) 287 | 288 | # Remove the oldest state. 289 | driver_sensor_state[key] = driver_sensor_state[key][1:] 290 | 291 | break 292 | elif ((key not in visible_id) and (key != ego_id)): 293 | patches_dict[key].set_color('#1f77b4') 294 | if key in driver_sensor_state_dict.keys(): 295 | driver_sensor_state_dict[key].remove() 296 | del driver_sensor_state_dict[key] 297 | 298 | else: 299 | if key in patches_dict: 300 | patches_dict[key].remove() 301 | patches_dict.pop(key) 302 | text_dict[key].remove() 303 | text_dict.pop(key) 304 | if key in driver_sensor_state_dict.keys(): 305 | driver_sensor_state_dict[key].remove() 306 | del driver_sensor_state_dict[key] 307 | 308 | # Plot the pedestrians. 309 | if pedest_dict is not None: 310 | 311 | for key, value in pedest_dict.items(): 312 | assert isinstance(value, Track) 313 | if value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last: 314 | ms = value.motion_states[timestamp] 315 | assert isinstance(ms, MotionState) 316 | 317 | if key not in patches_dict: 318 | width = 1.5 319 | length = 1.5 320 | 321 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate_pedest(ms, width, length), closed=True, 322 | zorder=20, color='red') 323 | patches_dict[key] = rect 324 | axes.add_patch(rect) 325 | text_dict[key] = axes.text(ms.x, ms.y + 3, str(key), horizontalalignment='center', zorder=50, fontsize='xx-large') 326 | else: 327 | width = 1.5 328 | length = 1.5 329 | patches_dict[key].set_xy(polygon_xy_from_motionstate_pedest(ms, width, length)) 330 | text_dict[key].set_position((ms.x, ms.y + 3)) 331 | else: 332 | if key in patches_dict: 333 | patches_dict[key].remove() 334 | patches_dict.pop(key) 335 | text_dict[key].remove() 336 | text_dict.pop(key) 337 | 338 | -------------------------------------------------------------------------------- /src/utils/utils_model.py: -------------------------------------------------------------------------------- 1 | # Helper code for the CVAE driver sensor model. Code is adapted from: https://github.com/sisl/EvidentialSparsification. 
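#
# Minimal usage sketch (illustrative only; the shapes are hypothetical and a CUDA device is
# assumed, matching the .cuda() calls below):
#
#     alpha = to_var(torch.tensor([[0.7, 0.2, 0.1]]))  # class probabilities over 3 latent modes
#     z = sample_p(alpha, batch_size=5)                 # one-hot samples, shape (5, 1, 3)
#     idx = to_var(torch.tensor([[2], [0]]))            # class indices, shape (2, 1)
#     onehot = idx2onehot(idx, n=3)                     # one-hot rows, shape (2, 3)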
2 | 3 | seed = 123 4 | import numpy as np 5 | np.random.seed(seed) 6 | import torch 7 | torch.manual_seed(seed) 8 | torch.cuda.manual_seed(seed) 9 | 10 | from torch.autograd import Variable 11 | 12 | def to_var(x, volatile=False): 13 | if torch.cuda.is_available(): 14 | x = x.cuda() 15 | return Variable(x, volatile=volatile) 16 | 17 | def idx2onehot(idx, n): 18 | 19 | assert idx.size(1) == 1 20 | assert torch.max(idx).data < n 21 | 22 | onehot = torch.zeros(idx.size(0), n).cuda() 23 | onehot.scatter_(1, idx.data, 1) 24 | onehot = to_var(onehot) 25 | 26 | return onehot 27 | 28 | def sample_p(alpha, batch_size=1): 29 | zdist = torch.distributions.one_hot_categorical.OneHotCategorical(probs = alpha) 30 | return zdist.sample(torch.Size([batch_size])) 31 | --------------------------------------------------------------------------------