├── LICENSE
├── README.md
├── dependencies.txt
├── scripts
│   ├── run_full_pipeline.sh
│   ├── run_preprocess_data.sh
│   └── train_and_test_driver_sensor_model.sh
└── src
    ├── driver_sensor_model
    │   ├── inference_cvae.py
    │   ├── inference_gmm.py
    │   ├── inference_kmeans.py
    │   ├── models_cvae.py
    │   ├── train_cvae.py
    │   ├── train_gmm.py
    │   ├── train_kmeans.py
    │   └── visualize_cvae.ipynb
    ├── full_pipeline
    │   ├── full_pipeline_metrics.py
    │   ├── main_save_full_pipeline.py
    │   └── main_visualize_full_pipeline.py
    ├── preprocess
    │   ├── generate_data.py
    │   ├── get_driver_sensor_data.py
    │   ├── preprocess_driver_sensor_data.py
    │   ├── statistics.txt
    │   └── train_val_test_split.py
    └── utils
        ├── combinations.py
        ├── data_generator.py
        ├── dataset_reader.py
        ├── dataset_types.py
        ├── grid_fuse.py
        ├── grid_utils.py
        ├── interaction_utils.py
        ├── map_vis_without_lanelet.py
        ├── tracks_save.py
        ├── tracks_vis.py
        └── utils_model.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MultiAgentVariationalOcclusionInference
2 | Multi-agent occlusion inference using observed driver behaviors. A driver sensor model is learned using a conditional variational autoencoder which maps an observed driver trajectory to the space ahead of the driver, represented as an occupancy grid map (OGM). Information from multiple drivers is fused into an ego vehicle's map using evidential theory. See our [video](https://www.youtube.com/watch?v=cTHl5nDBNBM) and [paper](https://arxiv.org/abs/2109.02173) for more details:
3 |
4 | M. Itkina, Y.-J. Mun, K. Driggs-Campbell, and M. J. Kochenderfer. "Multi-Agent Variational Occlusion Inference Using People as Sensors". In International Conference on Robotics and Automation (ICRA), 2022.
5 |
6 |
7 |
8 |
9 |
10 | **Approach Overview:** Our proposed occlusion inference approach. The learned driver sensor model maps behaviors of visible drivers (cyan) to an OGM of the environment ahead of them (gray). The inferred OGMs are then fused into the ego vehicle’s (green) map. Our goal is to infer the presence or absence of occluded agents (blue). Driver 1 is waiting to turn left, occluding oncoming traffic from the ego vehicle’s view. Driver 1 being stopped should indicate to the ego that there may be oncoming traffic, and thus that it is not safe to proceed with its right turn. Driver 2 is driving at a constant speed. This observed behavior is not enough to discern whether the vehicle is traveling in traffic or on an open road. We aim to encode such intuition into our occlusion inference algorithm.
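For intuition, the per-driver OGMs are fused cell by cell in a Dempster-Shafer sense. The sketch below is illustrative only: it combines two hypothetical (occupied, free, unknown) mass tuples with Dempster's rule of combination and is not the repository's actual implementation (see `src/utils/grid_fuse.py`); the mass values and the helper name `fuse_cell` are made up.

```python
def fuse_cell(m1, m2):
    """Combine two (occupied, free, unknown) mass tuples for one cell with Dempster's rule."""
    o1, f1, u1 = m1
    o2, f2, u2 = m2
    conflict = o1 * f2 + f1 * o2            # mass assigned to contradictory hypotheses
    if conflict >= 1.0:                     # total conflict: fall back to ignorance
        return (0.0, 0.0, 1.0)
    norm = 1.0 - conflict
    occ = (o1 * o2 + o1 * u2 + u1 * o2) / norm
    free = (f1 * f2 + f1 * u2 + u1 * f2) / norm
    unknown = (u1 * u2) / norm
    return (occ, free, unknown)

# Two drivers weakly agreeing that a cell is occupied yield a stronger occupied mass.
print(fuse_cell((0.6, 0.1, 0.3), (0.5, 0.2, 0.3)))  # -> approximately (0.76, 0.13, 0.11)
```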
11 |
12 |
13 |
14 |
15 |
16 | **Results (Example Scenario 1):** The ego vehicle (green) is waiting to turn right. Observed driver 33 is waiting for a gap in cross-traffic to make a left turn, blocking the view of the ego vehicle. Our occlusion inference algorithm produces occupancy estimates as soon as 1 second of trajectory data is accumulated for an observed driver. While driver 33 is stopped (left), our algorithm estimates occupied space in the region ahead of driver 33, encompassing occluded driver 92. When driver 33 accelerates (right), this space is inferred to be free, indicating to the ego vehicle that it can proceed with its right turn. These results match our intuition for these observed behaviors.
17 |
18 |
19 |
20 |
21 |
22 | **Results (Example Scenario 2):** Observed driver 106 is waiting to turn left ahead of the ego vehicle (green). Observed drivers 102, 113, and 116 are also waiting for a gap in cross-traffic. While drivers 106, 102, 113, and 116 are all stopped (left), our algorithm estimates highly occupied space surrounding occluded driver 112 due to multiple agreeing measurements. When driver 106 starts their left turn (right), this space is estimated to be more free, indicating to the ego vehicle that it may be safe to complete its maneuver (e.g., a U-turn).
23 |
24 | ## Instructions
25 | The code reproduces the qualitative and quantitative experiments in the paper. The required dependencies are listed in dependencies.txt. Note that the INTERACTION dataset for the GL intersection has to be downloaded from https://interaction-dataset.com/ and placed into the `data` directory, so that the directory `data/INTERACTION-Dataset-DR-v1_1` exists.
26 |
27 | To run the experiments, please run the following files:
28 |
29 | - To process the data:
30 | `scripts/run_preprocess_data.sh`.
31 |
32 | - To train and test the driver sensor models (ours, k-means PaS, GMM PaS):
33 | `scripts/train_and_test_driver_sensor_model.sh`.
34 |
35 | - To run and evaluate the multi-agent occlusion inference pipeline:
36 | `scripts/run_full_pipeline.sh`.
37 |
38 | Running the above code will reproduce the numbers in the tables reported in the paper. To visualize the qualitative results, see `src/driver_sensor_model/visualize_cvae.ipynb` (Fig. 3) and `src/full_pipeline/main_visualize_full_pipeline.py` (Fig. 4).
39 |
40 | ## Data Split
41 | To process the data, ego vehicle IDs were sampled from the GL intersection in the [INTERACTION dataset](https://interaction-dataset.com/). For each of the 60 available scenes, a maximum of 100 ego vehicles were chosen. The train, validation, and test sets were randomly split based on the ego vehicle IDs, accounting for 85%, 5%, and 10% of the total number of ego vehicles, respectively. Due to computational constraints, the training set for the driver sensor model was further reduced to 70,001 contiguous driver sensor trajectories, or 2,602,332 time steps of data. The validation set used to select the driver sensor model and the sensor fusion scheme contained 4,858 contiguous driver sensor trajectories, or 180,244 time steps, and 289 ego vehicles. The results presented in the paper were reported on the test set, which consists of 9,884 contiguous driver sensor trajectories, or 365,201 time steps, and 578 ego vehicles.
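The split itself is implemented in `src/preprocess/train_val_test_split.py`. The snippet below is only a schematic sketch of an 85%/5%/10% split over ego vehicle IDs; the ID array and seed are placeholders.

```python
import numpy as np

rng = np.random.RandomState(0)        # placeholder seed for a reproducible split
ego_ids = np.arange(1000)             # placeholder ego vehicle IDs
rng.shuffle(ego_ids)

n = len(ego_ids)
n_train, n_val = int(0.85 * n), int(0.05 * n)
train_ids = ego_ids[:n_train]
val_ids = ego_ids[n_train:n_train + n_val]
test_ids = ego_ids[n_train + n_val:]  # remaining ~10%
```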
42 |
43 | ## CVAE Driver Sensor Model Architecture and Training
44 | We set the number of latent classes in the CVAE to K = 100 based on computational time and tractability for the considered baselines. We standardize the trajectory data to have zero mean and unit standard deviation. The prior encoder in the model consists of an LSTM with a hidden dimension of 5 to process the 1 s of trajectory input data. A linear layer then converts the output into a K-dimensional vector that goes into a softmax function, producing the prior distribution. The posterior encoder extracts features from the ground truth input OGM using a [VQ-VAE](https://arxiv.org/abs/1711.00937) backbone with a hidden dimension of 4. These features are flattened and concatenated with the LSTM output from the trajectory data, and then passed into a linear layer and a softmax function, producing the posterior distribution. The decoder passes the latent encoding through two linear layers with ReLU activation functions and a transposed VQ-VAE backbone, outputting the inferred OGM.
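The full model (prior encoder, VQ-VAE posterior encoder, and decoder) is implemented in `src/driver_sensor_model/models_cvae.py`. The snippet below is only a stripped-down sketch of the prior branch described above, assuming 10 time steps of a 7-dimensional driver state (the dimensions used in the preprocessing and inference scripts), an LSTM hidden size of 5, and K = 100 latent classes.

```python
import torch
import torch.nn as nn

class PriorEncoderSketch(nn.Module):
    """Simplified prior p(z | trajectory): LSTM -> linear -> softmax."""
    def __init__(self, state_dim=7, hidden_dim=5, nt=10, latent_size=100):
        super().__init__()
        self.lstm = nn.LSTM(state_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim * nt, latent_size)

    def forward(self, traj):                    # traj: (batch, nt, state_dim)
        out, _ = self.lstm(traj)                # (batch, nt, hidden_dim)
        logits = self.linear(out.flatten(1))    # use the outputs from every time step
        return torch.softmax(logits, dim=-1)    # prior over the K latent classes

alpha_p = PriorEncoderSketch()(torch.randn(4, 10, 7))  # (4, 100), each row sums to 1
```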
45 |
46 | To avoid latent space collapse, we clamp the KL divergence term in the loss at 0.2. Additionally, we anneal the beta hyperparameter in the loss according to a sigmoid schedule. In our hyperparameter search, we found a maximum beta of 1 with a crossover point at 10,000 iterations to work well. The beta hyperparameter increases from a value of 0 to 1 over 1,000 iterations. We set the alpha hyperparameter to 1.5. We trained the network with a batch size of 256 for 30 epochs using the [Adam optimizer](https://arxiv.org/abs/1412.6980) with a starting learning rate of 0.001.
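As a small illustration of the schedule and clamping described above (the sigmoid slope below is an illustrative choice; the crossover iteration and the 0.2 clamp, applied here as a lower bound on the KL term, follow the text):

```python
import math

def beta_schedule(step, beta_max=1.0, crossover=10000, slope=0.006):
    """Sigmoid annealing: beta is near 0 early on, reaches beta_max/2 at the
    crossover iteration, and saturates at beta_max; the slope is illustrative."""
    return beta_max / (1.0 + math.exp(-slope * (step - crossover)))

def clamped_kl(kl_value, floor=0.2):
    """Clamp the KL term from below to discourage latent space collapse."""
    return max(kl_value, floor)

# Schematic per-iteration loss: loss = reconstruction + beta_schedule(step) * clamped_kl(kl)
```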
47 |
--------------------------------------------------------------------------------
/dependencies.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | absl-py=0.9.0=pypi_0
6 | argon2-cffi=20.1.0=py36h27cfd23_1
7 | async_generator=1.10=py36h28b3542_0
8 | attrs=19.3.0=pypi_0
9 | backcall=0.2.0=pyhd3eb1b0_0
10 | bleach=3.1.5=pypi_0
11 | ca-certificates=2021.5.25=h06a4308_1
12 | cachetools=4.1.0=pypi_0
13 | certifi=2021.5.30=py36h06a4308_0
14 | cffi=1.14.5=py36h261ae71_0
15 | chardet=3.0.4=pypi_0
16 | cvxpy=1.1.1=pypi_0
17 | cycler=0.10.0=pypi_0
18 | decorator=4.4.2=pypi_0
19 | defusedxml=0.6.0=pypi_0
20 | diffcp=1.0.13=pypi_0
21 | dill=0.3.2=pypi_0
22 | ecos=2.0.7.post1=pypi_0
23 | entrypoints=0.3=pypi_0
24 | future=0.18.2=pypi_0
25 | google-auth=1.16.0=pypi_0
26 | google-auth-oauthlib=0.4.1=pypi_0
27 | grpcio=1.29.0=pypi_0
28 | h5py=2.10.0=pypi_0
29 | hickle=4.0.0=pypi_0
30 | idna=2.9=pypi_0
31 | imageio=2.8.0=pypi_0
32 | importlib-metadata=1.6.0=pypi_0
33 | importlib_metadata=3.10.0=hd3eb1b0_0
34 | ipykernel=5.3.2=pypi_0
35 | ipython=7.16.1=py36h5ca1d4c_0
36 | ipython_genutils=0.2.0=pyhd3eb1b0_1
37 | ipywidgets=7.6.3=pyhd3eb1b0_1
38 | jedi=0.17.1=pypi_0
39 | jinja2=2.11.2=pypi_0
40 | joblib=0.16.0=pypi_0
41 | json5=0.9.5=pypi_0
42 | jsonschema=3.2.0=py_2
43 | jupyter-client=6.1.6=pypi_0
44 | jupyter-core=4.6.3=pypi_0
45 | jupyter_client=6.1.12=pyhd3eb1b0_0
46 | jupyter_core=4.7.1=py36h06a4308_0
47 | jupyterlab=2.2.0=pypi_0
48 | jupyterlab-server=1.2.0=pypi_0
49 | jupyterlab_pygments=0.1.2=py_0
50 | jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
51 | kiwisolver=1.2.0=pypi_0
52 | kmeans-pytorch=0.3=pypi_0
53 | ld_impl_linux-64=2.33.1=h53a641e_7
54 | libedit=3.1.20181209=hc058e9b_0
55 | libffi=3.3=he6710b0_1
56 | libgcc-ng=9.1.0=hdf63c60_0
57 | libsodium=1.0.18=h7b6447c_0
58 | libstdcxx-ng=9.1.0=hdf63c60_0
59 | markdown=3.2.2=pypi_0
60 | markupsafe=1.1.1=pypi_0
61 | matplotlib=3.2.1=pypi_0
62 | mistune=0.8.4=py36h7b6447c_0
63 | nbclient=0.5.3=pyhd3eb1b0_0
64 | nbconvert=5.6.1=pypi_0
65 | nbformat=5.0.7=pypi_0
66 | ncurses=6.2=he6710b0_1
67 | nest-asyncio=1.5.1=pyhd3eb1b0_0
68 | networkx=2.5=pypi_0
69 | notebook=6.0.3=pypi_0
70 | numpy=1.18.4=pypi_0
71 | oauthlib=3.1.0=pypi_0
72 | opencv-python=4.4.0.46=pypi_0
73 | openssl=1.1.1k=h27cfd23_0
74 | osqp=0.6.1=pypi_0
75 | packaging=20.4=pypi_0
76 | pandas=1.1.0=pypi_0
77 | pandoc=2.12=h06a4308_0
78 | pandocfilters=1.4.2=pypi_0
79 | parso=0.7.0=pypi_0
80 | pathlib=1.0.1=pypi_0
81 | pexpect=4.8.0=pyhd3eb1b0_3
82 | pickleshare=0.7.5=pyhd3eb1b0_1003
83 | pillow=7.1.2=pypi_0
84 | pip=20.0.2=py36_3
85 | prometheus-client=0.8.0=pypi_0
86 | prometheus_client=0.10.1=pyhd3eb1b0_0
87 | prompt-toolkit=3.0.5=pypi_0
88 | protobuf=3.12.2=pypi_0
89 | ptyprocess=0.6.0=pypi_0
90 | pwlf=2.0.4=pypi_0
91 | pyasn1=0.4.8=pypi_0
92 | pyasn1-modules=0.2.8=pypi_0
93 | pybind11=2.5.0=pypi_0
94 | pycparser=2.20=py_2
95 | pydoe=0.3.8=pypi_0
96 | pygments=2.6.1=pypi_0
97 | pyparsing=2.4.7=pyhd3eb1b0_0
98 | pyproj=2.6.1.post1=pypi_0
99 | pyrsistent=0.16.0=pypi_0
100 | python=3.6.10=h7579374_2
101 | python-dateutil=2.8.1=pyhd3eb1b0_0
102 | pytorch-lightning=0.7.6=pypi_0
103 | pytz=2020.1=pypi_0
104 | pywavelets=1.1.1=pypi_0
105 | pyyaml=5.3.1=pypi_0
106 | pyzmq=19.0.1=pypi_0
107 | readline=8.0=h7b6447c_0
108 | requests=2.23.0=pypi_0
109 | requests-oauthlib=1.3.0=pypi_0
110 | rsa=4.0=pypi_0
111 | scikit-image=0.17.2=pypi_0
112 | scikit-learn=0.23.2=pypi_0
113 | scipy=1.4.1=pypi_0
114 | scs=2.1.2=pypi_0
115 | seaborn=0.10.1=pypi_0
116 | send2trash=1.5.0=pyhd3eb1b0_1
117 | setuptools=46.4.0=py36_0
118 | six=1.15.0=py36h06a4308_0
119 | sklearn=0.0=pypi_0
120 | sqlite=3.31.1=h62c20be_1
121 | tensorboard=2.4.1=pypi_0
122 | tensorboard-plugin-wit=1.8.0=pypi_0
123 | terminado=0.8.3=pypi_0
124 | testpath=0.4.4=pyhd3eb1b0_0
125 | threadpoolctl=2.1.0=pypi_0
126 | tifffile=2020.9.3=pypi_0
127 | tikzplotlib=0.9.8=pypi_0
128 | tk=8.6.8=hbc83047_0
129 | torch=1.2.0=pypi_0
130 | torchvision=0.4.0=pypi_0
131 | tornado=6.0.4=pypi_0
132 | tqdm=4.46.0=pypi_0
133 | traitlets=4.3.3=py36_0
134 | typing_extensions=3.7.4.3=pyha847dfd_0
135 | urllib3=1.25.9=pypi_0
136 | wcwidth=0.2.5=py_0
137 | webencodings=0.5.1=pypi_0
138 | werkzeug=1.0.1=pypi_0
139 | wheel=0.34.2=py36_0
140 | widgetsnbextension=3.5.1=py36_0
141 | xz=5.2.5=h7b6447c_0
142 | zeromq=4.3.4=h2531618_0
143 | zipp=3.1.0=pypi_0
144 | zlib=1.2.11=h7b6447c_3
145 |
--------------------------------------------------------------------------------
/scripts/run_full_pipeline.sh:
--------------------------------------------------------------------------------
1 | cd ../src/full_pipeline
2 | python main_save_full_pipeline.py --model=vae --mode=evidential
3 | python main_save_full_pipeline.py --model=gmm --mode=evidential
4 | python main_save_full_pipeline.py --model=kmeans --mode=evidential
5 | python main_save_full_pipeline.py --model=vae --mode=average
6 |
7 | # Change the model and fusion mode in the file to obtain metrics.
8 | python full_pipeline_metrics.py
9 |
10 | # Visualize the scenarios in the paper and in the appendix.
11 | python main_visualize_full_pipeline.py --model=vae --mode=evidential
12 |
13 |
--------------------------------------------------------------------------------
/scripts/run_preprocess_data.sh:
--------------------------------------------------------------------------------
1 | # Record the mean and standard deviation statistics produced at the end into statistics.txt and into the three corresponding functions in src/utils/data_generator.py.
2 | cd ../src/preprocess
3 | python generate_data.py
4 | python train_val_test_split.py
5 | python get_driver_sensor_data.py
6 | python preprocess_driver_sensor_data.py
7 |
--------------------------------------------------------------------------------
/scripts/train_and_test_driver_sensor_model.sh:
--------------------------------------------------------------------------------
1 | cd ../src/driver_sensor_model
2 | # Train and test our CVAE driver sensor model.
3 | python train_cvae.py --norm --mut_info='const' --epochs=30 --learning_rate=0.001 --beta=1.0 --alpha=1.5 --latent_size=100 --batch_size=256 --crossover=10000
4 | python inference_cvae.py --norm --mut_info='const' --epochs=30 --learning_rate=0.001 --beta=1.0 --alpha=1.5 --latent_size=100 --batch_size=256 --crossover=10000
5 |
6 | # Train and test the k-means PaS baseline driver sensor model.
7 | python train_kmeans.py
8 | python inference_kmeans.py
9 |
10 | # Train and test the GMM PaS baseline driver sensor model.
11 | python train_gmm.py
12 | python inference_gmm.py
13 |
14 | # To view the qualitative driver sensor model results, please see the Jupyter notebook: src/driver_sensor_model/visualize_cvae.ipynb.
15 |
--------------------------------------------------------------------------------
/src/driver_sensor_model/inference_gmm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.manual_seed(123)
3 | torch.backends.cudnn.deterministic = True
4 | torch.backends.cudnn.benchmark = False
5 | import torch.nn as nn
6 | import torch.utils.data as data
7 | import torchvision
8 | from torchvision import datasets, transforms
9 |
10 | import numpy as np
11 | np.random.seed(0)
12 | from matplotlib import pyplot as plt
13 | import hickle as hkl
14 | import pickle as pkl
15 | import pdb
16 | import os
17 | import multiprocessing as mp
18 | from mpl_toolkits.axes_grid1 import ImageGrid
19 |
20 | os.chdir("../..")
21 |
22 | from src.utils.data_generator import *
23 | from sklearn.mixture import GaussianMixture as GMM
24 | from src.utils.interaction_utils import *
25 | import time
26 |
27 | from tqdm import tqdm
28 |
29 | # Load data.
30 | nt = 10
31 | num_states = 7
32 | grid_shape = (20, 30)
33 |
34 | use_cuda = torch.cuda.is_available()
35 | device = torch.device("cuda:0" if use_cuda else "cpu")
36 |
37 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset'
38 |
39 | # Test on test data.
40 | data_file_states = os.path.join(dir, 'states_test.hkl')
41 | data_file_grids = os.path.join(dir, 'label_grids_test.hkl')
42 | data_file_sources = os.path.join(dir, 'sources_test.hkl')
43 |
44 | data_test = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,
45 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False)
46 |
47 | test_loader = torch.utils.data.DataLoader(data_test,
48 | batch_size=len(data_test), shuffle=False,
49 | num_workers=mp.cpu_count()-1, pin_memory=True)
50 |
51 | for batch_x_test, batch_y_test, sources_test in test_loader:
52 | batch_x_test, batch_y_test_orig = batch_x_test.to(device), batch_y_test.to(device)
53 |
54 | batch_x_test = batch_x_test.cpu().data.numpy()
55 | batch_y_test_orig = batch_y_test_orig.cpu().data.numpy()
56 |
57 | # GMM
58 | plt.ioff()
59 |
60 | models = 'gmm'
61 | dir_models = os.path.join('/models/', models)
62 | if not os.path.isdir(dir_models):
63 | os.mkdir(dir_models)
64 |
65 | Ks = [100] # [4, 10, 25, 50, 100, 150]
66 | N = 3
67 | ncols = {4:2, 10:2, 25:5, 50:5, 100:10, 150:10}
68 | nrows = {4:2, 10:5, 25:5, 50:10, 100:10, 150:15}
69 | acc_nums = []
70 | mse_nums = []
71 | im_nums = []
72 | plot_train_flag = True
73 | plot_metric_clusters_flag = False
74 | plot_scatter_flag = False
75 | plot_scatter_clusters_flag = False
76 |
77 | for K in Ks:
78 |
79 | model, p_m_a_np = pkl.load(open(os.path.join(dir_models, "GMM_K_" + str(K) + "_reg_covar_0_001.pkl"), "rb" ) )
80 |
81 | cluster_centers_np = model.means_
82 | cluster_centers_np = np.reshape(cluster_centers_np, (K,nt,num_states))
83 | cluster_centers_np = unnormalize(cluster_centers_np)
84 |
85 | if plot_train_flag:
86 | plot_train(K, p_m_a_np, cluster_centers_np, dir_models, grid_shape)
87 |
88 | cluster_ids_y_test = model.predict(batch_x_test)
89 | cluster_ids_y_test_prob = model.predict_proba(batch_x_test)
90 |
91 | # Most likely class.
92 | grids_pred_orig = p_m_a_np[cluster_ids_y_test]
93 |
94 | grids_plot = np.reshape(p_m_a_np, (-1,grid_shape[0],grid_shape[1]))
95 | fig, axeslist = plt.subplots(ncols=ncols[K], nrows=nrows[K])
96 | for i in range(grids_plot.shape[0]):
97 | axeslist.ravel()[i].imshow(grids_plot[i], cmap=plt.gray())
98 | axeslist.ravel()[i].set_xticks([])
99 | axeslist.ravel()[i].set_xticklabels([])
100 | axeslist.ravel()[i].set_yticks([])
101 | axeslist.ravel()[i].set_yticklabels([])
102 | axeslist.ravel()[i].set_aspect('equal')
103 | plt.subplots_adjust(left = 0.5, hspace = 0, wspace = 0)
104 | plt.margins(0,0)
105 | fig.savefig(os.path.join(dir_models, 'test_' + str(K) + '.png'), pad_inches=0)
106 | plt.close(fig)
107 |
108 | grids_pred = deepcopy(grids_pred_orig)
109 | grids_pred[grids_pred_orig >= 0.6] = 1.0
110 | grids_pred[np.logical_and(grids_pred_orig < 0.6, grids_pred_orig > 0.4)] = 0.5
111 | grids_pred[grids_pred_orig <= 0.4] = 0.0
112 |
113 | acc_occ_free = np.mean(grids_pred == batch_y_test_orig)
114 | mse_occ_free = np.mean((grids_pred_orig - batch_y_test_orig)**2)
115 |
116 | im_grids, im_occ_grids, im_free_grids, im_ocl_grids = \
117 | MapSimilarityMetric(np.reshape(grids_pred, (-1,grid_shape[0],grid_shape[1])), np.reshape(batch_y_test_orig, (-1,grid_shape[0],grid_shape[1]))) # im_ocl_grids,
118 | im = np.mean(im_grids)
119 | im_occ = np.mean(im_occ_grids)
120 | im_free = np.mean(im_free_grids)
121 | im_occ_free_grids = im_occ_grids + im_free_grids
122 | im_occ_free = np.mean(im_occ_grids + im_free_grids)
123 |
124 |     acc_nums.append(acc_occ_free)
125 |     mse_nums.append(mse_occ_free)
126 | im_nums.append(im)
127 |
128 | acc_occ = np.mean(grids_pred[batch_y_test_orig == 1] == batch_y_test_orig[batch_y_test_orig == 1])
129 | acc_free = np.mean(grids_pred[batch_y_test_orig == 0] == batch_y_test_orig[batch_y_test_orig == 0])
130 |
131 | mse_occ = np.mean((grids_pred_orig[batch_y_test_orig == 1] - batch_y_test_orig[batch_y_test_orig == 1])**2)
132 | mse_free = np.mean((grids_pred_orig[batch_y_test_orig == 0] - batch_y_test_orig[batch_y_test_orig == 0])**2)
133 |
134 | print("Metrics: ")
135 |
136 | print("Occupancy and Free Metrics: ")
137 | print("K: ", K, "Accuracy: ", acc_occ_free, "MSE: ", mse_occ_free, "IM: ", im_occ_free, "IM max: ", np.amax(im_occ_free_grids), "IM min: ", np.amin(im_occ_free_grids))
138 |
139 | print("Occupancy Metrics: ")
140 | print("Accuracy: ", acc_occ, "MSE: ", mse_occ, "IM: ", im_occ, "IM max: ", np.amax(im_occ_grids), "IM min: ", np.amin(im_occ_grids))
141 |
142 | print("Free Metrics: ")
143 | print("Accuracy: ", acc_free, "MSE: ", mse_free, "IM: ", im_free, "IM max: ", np.amax(im_free_grids), "IM min: ", np.amin(im_free_grids))
144 |
145 | num_occ = np.sum(batch_y_test_orig == 1, axis=-1)
146 | num_free = np.sum(batch_y_test_orig == 0, axis=-1)
147 | num_occ_free = np.sum(np.logical_or(batch_y_test_orig == 1, batch_y_test_orig == 0), axis=-1)
148 |
149 | im_nums_N = np.empty((batch_y_test_orig.shape[0], N))
150 | acc_occ_free_nums_N = np.empty((batch_y_test_orig.shape[0], N))
151 | mse_occ_free_nums_N = np.empty((batch_y_test_orig.shape[0], N))
152 | im_occ_free_nums_N = np.empty((batch_y_test_orig.shape[0], N))
153 | acc_occ_nums_N = np.empty((np.sum(num_occ > 0), N))
154 | mse_occ_nums_N = np.empty((np.sum(num_occ > 0), N))
155 | im_occ_nums_N = np.empty((batch_y_test_orig.shape[0], N))
156 | acc_free_nums_N = np.empty((np.sum(num_free > 0), N))
157 | mse_free_nums_N = np.empty((np.sum(num_free > 0), N))
158 | im_free_nums_N = np.empty((batch_y_test_orig.shape[0], N))
159 |
160 | for n in range(1,N+1):
161 | grids_pred_orig = p_m_a_np[cluster_ids_y_test_prob.argsort(1)[:,-n]]
162 | grids_pred = (grids_pred_orig >= 0.6).astype(float)
163 | grids_pred[grids_pred_orig <= 0.4] = 0.0
164 | grids_pred[np.logical_and(grids_pred_orig < 0.6, grids_pred_orig > 0.4)] = 0.5
165 |
166 | grids_pred_free = (grids_pred * (batch_y_test_orig == 0))
167 | grids_pred_free[batch_y_test_orig != 0] = 2.0
168 |
169 | acc_occ_free_grids = np.mean(grids_pred == batch_y_test_orig, axis=-1)
170 | acc_occ_grids = np.sum((grids_pred * (batch_y_test_orig == 1)) == 1., axis=-1)[num_occ > 0] / num_occ[num_occ > 0] * 1.0
171 | acc_free_grids = np.sum(grids_pred_free == 0., axis=-1)[num_free > 0] / num_free[num_free > 0] * 1.0
172 |
173 | mse_occ_free_grids = np.mean((grids_pred_orig - batch_y_test_orig)**2, axis=-1)
174 | mse_occ_grids = np.sum(((grids_pred_orig * (batch_y_test_orig == 1)) - batch_y_test_orig * (batch_y_test_orig == 1))**2, axis=-1)[num_occ > 0] / num_occ[num_occ > 0] * 1.0
175 | mse_free_grids = np.sum(((grids_pred_orig * (batch_y_test_orig == 0)) - batch_y_test_orig * (batch_y_test_orig == 0))**2, axis=-1)[num_free > 0] / num_free[num_free > 0] * 1.0
176 |
177 | im_grids, im_occ_grids, im_free_grids, im_ocl_grids = MapSimilarityMetric(np.reshape(grids_pred, (-1,grid_shape[0],grid_shape[1])), np.reshape(batch_y_test_orig, (-1,grid_shape[0],grid_shape[1])))
178 |
179 | acc_occ_free_nums_N[:,n-1] = acc_occ_free_grids
180 | mse_occ_free_nums_N[:,n-1] = mse_occ_free_grids
181 | im_occ_free_nums_N[:,n-1] = im_occ_grids + im_free_grids
182 |
183 | acc_occ_nums_N[:,n-1] = acc_occ_grids
184 | mse_occ_nums_N[:,n-1] = mse_occ_grids
185 | im_occ_nums_N[:,n-1] = im_occ_grids
186 |
187 | acc_free_nums_N[:,n-1] = acc_free_grids
188 | mse_free_nums_N[:,n-1] = mse_free_grids
189 | im_free_nums_N[:,n-1] = im_free_grids
190 |
191 | acc_occ_free_best = np.mean(np.amax(acc_occ_free_nums_N, axis=1))
192 | mse_occ_free_best = np.mean(np.amin(mse_occ_free_nums_N, axis=1))
193 | im_occ_free_best = np.mean(np.amin(im_occ_free_nums_N, axis=1))
194 |
195 | acc_occ_best = np.mean(np.amax(acc_occ_nums_N, axis=1))
196 | mse_occ_best = np.mean(np.amin(mse_occ_nums_N, axis=1))
197 | im_occ_best = np.mean(np.amin(im_occ_nums_N, axis=1))
198 |
199 | acc_free_best = np.mean(np.amax(acc_free_nums_N, axis=1))
200 | mse_free_best = np.mean(np.amin(mse_free_nums_N, axis=1))
201 | im_free_best = np.mean(np.amin(im_free_nums_N, axis=1))
202 |
203 | print("Top 3 Metrics: ")
204 |
205 | print("Occupancy and Free Metrics: ")
206 | print("Accuracy: ", acc_occ_free_best, "MSE: ", mse_occ_free_best, "IM: ", im_occ_free_best)
207 |
208 | print("Occupancy Metrics: ")
209 | print("Accuracy: ", acc_occ_best, "MSE: ", mse_occ_best, "IM: ", im_occ_best)
210 |
211 | print("Free Metrics: ")
212 | print("Accuracy: ", acc_free_best, "MSE: ", mse_free_best, "IM: ", im_free_best)
213 |
214 |     hkl.dump(np.array([acc_occ_best, acc_free_best, acc_occ_free_best,\
215 |                     mse_occ_best, mse_free_best, mse_occ_free_best,\
216 |                     im_occ_best, im_free_best, im_occ_free_best]),\
217 | os.path.join(dir_models, 'top_3_metrics.hkl'), mode='w',)
218 |
219 | batch_x_np_test = batch_x_test
220 | batch_x_np_test = np.reshape(batch_x_np_test, (-1,nt,num_states))
221 | batch_x_np_test = unnormalize(batch_x_np_test)
222 |
223 | if plot_metric_clusters_flag:
224 |         plot_metric_clusters(K, batch_x_np_test, cluster_ids_y_test, im_grids, 'IM', dir_models, sources_test)
225 |
226 | if plot_scatter_flag:
227 | plot_scatter(K, batch_x_np_test, im_grids, 'IM', dir_models)
228 |
229 | if plot_scatter_clusters_flag:
230 | plot_scatter_clusters(K, batch_x_np_test, cluster_ids_y_test, dir_models)
231 | plt.show()
232 |
233 | # Standard error
234 | acc_occ_std_error = np.std(grids_pred[batch_y_test_orig == 1] == 1)/np.sqrt(grids_pred[batch_y_test_orig == 1].size)
235 | acc_free_std_error = np.std(grids_pred[batch_y_test_orig == 0] == 0)/np.sqrt(grids_pred[batch_y_test_orig == 0].size)
236 | acc_occ_free_std_error = np.std(grids_pred[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] == batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size)
237 |
238 |     mse_occ_std_error = np.std((grids_pred_orig[batch_y_test_orig == 1] - 1)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 1].size)
239 | mse_free_std_error = np.std((grids_pred_orig[batch_y_test_orig == 0] - 0)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 0].size)
240 | mse_occ_free_std_error = np.std((grids_pred_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] - batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])**2)/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size)
241 |
242 | im_occ_std_error = np.std(im_occ_grids)/np.sqrt(im_occ_grids.size)
243 | im_free_std_error = np.std(im_free_grids)/np.sqrt(im_free_grids.size)
244 | im_occ_free_std_error = np.std(im_occ_grids + im_free_grids)/np.sqrt(im_occ_grids.size)
245 |
246 | print('Maximum standard error:')
247 | print(np.amax([acc_occ_std_error, acc_free_std_error, acc_occ_free_std_error, mse_occ_std_error, mse_free_std_error, mse_occ_free_std_error]))
248 | print(np.amax([im_occ_std_error, im_free_std_error, im_occ_free_std_error]))
249 |
250 | # GMM latent space visualization.
251 | all_latent_classes = np.reshape(p_m_a_np, (100,20,30))
252 |
253 | fig = plt.figure(figsize=(10, 10))
254 | grid = ImageGrid(fig, 111,
255 | nrows_ncols=(10, 10),
256 | axes_pad=0.1,
257 | )
258 |
259 | for ax, im in zip(grid, all_latent_classes):
260 | ax.matshow(im, cmap='gray_r', vmin=0, vmax=1)
261 | ax.set_xticks([], [])
262 | ax.set_yticks([], [])
263 |
264 | plt.savefig(os.path.join(dir_models, 'all_latent_classes_gmm.png'))
265 | plt.show()
--------------------------------------------------------------------------------
/src/driver_sensor_model/inference_kmeans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.manual_seed(123)
3 | torch.backends.cudnn.deterministic = True
4 | torch.backends.cudnn.benchmark = False
5 | import torch.nn as nn
6 | import torch.utils.data as data
7 | import torchvision
8 | from torchvision import datasets, transforms
9 |
10 | import numpy as np
11 | np.random.seed(0)
12 | from matplotlib import pyplot as plt
13 | import hickle as hkl
14 | import pickle as pkl
15 | import pdb
16 | import os
17 | import multiprocessing as mp
18 | from mpl_toolkits.axes_grid1 import ImageGrid
19 |
20 | os.chdir("../..")
21 |
22 | from src.utils.data_generator import *
23 | from sklearn.cluster import KMeans
24 | from src.utils.interaction_utils import *
25 | import time
26 |
27 | from tqdm import tqdm
28 |
29 | # Load data.
30 | nt = 10
31 | num_states = 7
32 | grid_shape = (20, 30)
33 |
34 | use_cuda = torch.cuda.is_available()
35 | device = torch.device("cuda:0" if use_cuda else "cpu")
36 |
37 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset'
38 |
39 | # Test on test data.
40 | data_file_states = os.path.join(dir, 'states_test.hkl')
41 | data_file_grids = os.path.join(dir, 'label_grids_test.hkl')
42 | data_file_sources = os.path.join(dir, 'sources_test.hkl')
43 |
44 | data_test = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,
45 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False)
46 |
47 | test_loader = torch.utils.data.DataLoader(data_test,
48 | batch_size=len(data_test), shuffle=False,
49 | num_workers=mp.cpu_count()-1, pin_memory=True)
50 |
51 | for batch_x_test, batch_y_test, sources_test in test_loader:
52 | batch_x_test, batch_y_test_orig = batch_x_test.to(device), batch_y_test.to(device)
53 |
54 | batch_x_test = batch_x_test.cpu().data.numpy()
55 | batch_y_test_orig = batch_y_test_orig.cpu().data.numpy()
56 |
57 | plt.ioff()
58 |
59 | models = 'kmeans'
60 | dir_models = os.path.join('/models/', models)
61 | if not os.path.isdir(dir_models):
62 | os.mkdir(dir_models)
63 |
64 | Ks = [100] # [4, 10, 25, 50, 100, 150]
65 | ncols = {4:2, 10:2, 25:5, 50:5, 100:10, 150:10}
66 | nrows = {4:2, 10:5, 25:5, 50:10, 100:10, 150:15}
67 |
68 | plot_train_flag = True
69 | plot_metric_clusters_flag = False
70 | plot_scatter_flag = False
71 | plot_scatter_clusters_flag = False
72 |
73 | for K in Ks:
74 |
75 | [model, p_m_a_np] = pkl.load(open(os.path.join(dir_models,"clusters_kmeans_K_" + str(K) + "_sklearn.pkl"), "rb" ) )
76 |
77 | cluster_centers_np = model.cluster_centers_
78 | cluster_centers_np = np.reshape(cluster_centers_np, (K,nt,num_states))
79 | cluster_centers_np = unnormalize(cluster_centers_np)
80 |
81 | if plot_train_flag:
82 | plot_train(K, p_m_a_np, cluster_centers_np, dir_models, grid_shape)
83 |
84 | start = time.time()
85 | cluster_ids_y_test = model.predict(batch_x_test)
86 | grids_pred_orig = p_m_a_np[cluster_ids_y_test]
87 |
88 | grids_plot = np.reshape(p_m_a_np, (-1,grid_shape[0],grid_shape[1]))
89 | fig, axeslist = plt.subplots(ncols=ncols[K], nrows=nrows[K])
90 | for i in range(grids_plot.shape[0]):
91 | axeslist.ravel()[i].imshow(grids_plot[i], cmap=plt.gray())
92 | axeslist.ravel()[i].set_xticks([])
93 | axeslist.ravel()[i].set_xticklabels([])
94 | axeslist.ravel()[i].set_yticks([])
95 | axeslist.ravel()[i].set_yticklabels([])
96 | axeslist.ravel()[i].set_aspect('equal')
97 | plt.subplots_adjust(left = 0.5, hspace = 0, wspace = 0)
98 | plt.margins(0,0)
99 | fig.savefig(os.path.join(dir_models, 'test_' + str(K) + '.png'), pad_inches=0)
100 | plt.close(fig)
101 |
102 | grids_pred = deepcopy(grids_pred_orig)
103 | grids_pred[grids_pred_orig >= 0.6] = 1.0
104 | grids_pred[np.logical_and(grids_pred_orig > 0.4, grids_pred_orig < 0.6)] = 0.5
105 | grids_pred[grids_pred_orig <= 0.4] = 0.0
106 |
107 | acc_occ_free = np.mean(grids_pred == batch_y_test_orig)
108 | mse_occ_free = np.mean((grids_pred_orig - batch_y_test_orig)**2)
109 |
110 | im_grids, im_occ_grids, im_free_grids, im_ocl_grids = \
111 | MapSimilarityMetric(np.reshape(grids_pred, (-1,grid_shape[0],grid_shape[1])), np.reshape(batch_y_test_orig, (-1,grid_shape[0],grid_shape[1]))) # im_ocl_grids,
112 | im = np.mean(im_grids)
113 | im_occ = np.mean(im_occ_grids)
114 | im_free = np.mean(im_free_grids)
115 | im_occ_free_grids = im_occ_grids + im_free_grids
116 | im_occ_free = np.mean(im_occ_grids + im_free_grids)
117 |
118 | acc_occ = np.mean(grids_pred[batch_y_test_orig == 1] == 1)
119 | acc_free = np.mean(grids_pred[batch_y_test_orig == 0] == 0)
120 |
121 | mse_occ = np.mean((grids_pred_orig[batch_y_test_orig == 1] - 1)**2)
122 | mse_free = np.mean((grids_pred_orig[batch_y_test_orig == 0] - 0)**2)
123 |
124 | print("Metrics: ")
125 |
126 | print("Occupancy and Free Metrics: ")
127 | print("K: ", K, "Accuracy: ", acc_occ_free, "MSE: ", mse_occ_free, "IM: ", im_occ_free, "IM max: ", np.amax(im_occ_free_grids), "IM min: ", np.amin(im_occ_free_grids))
128 |
129 | print("Occupancy Metrics: ")
130 | print("Accuracy: ", acc_occ, "MSE: ", mse_occ, "IM: ", im_occ, "IM max: ", np.amax(im_occ_grids), "IM min: ", np.amin(im_occ_grids))
131 |
132 | print("Free Metrics: ")
133 | print("Accuracy: ", acc_free, "MSE: ", mse_free, "IM: ", im_free, "IM max: ", np.amax(im_free_grids), "IM min: ", np.amin(im_free_grids))
134 |
135 | batch_x_np_test = batch_x_test
136 | batch_x_np_test = np.reshape(batch_x_np_test, (-1,nt,num_states))
137 | batch_x_np_test = unnormalize(batch_x_np_test)
138 |
139 | if plot_metric_clusters_flag:
140 |         plot_metric_clusters(K, batch_x_np_test, cluster_ids_y_test, im_grids, 'IM', dir_models, sources_test)
141 |
142 | if plot_scatter_flag:
143 | plot_scatter(K, batch_x_np_test, im_grids, 'IM', dir_models)
144 |
145 | if plot_scatter_clusters_flag:
146 | plot_scatter_clusters(K, batch_x_np_test, cluster_ids_y_test, dir_models)
147 | plt.show()
148 |
149 | # Standard error
150 | acc_occ_std_error = np.std(grids_pred[batch_y_test_orig == 1] == 1)/np.sqrt(grids_pred[batch_y_test_orig == 1].size)
151 | acc_free_std_error = np.std(grids_pred[batch_y_test_orig == 0] == 0)/np.sqrt(grids_pred[batch_y_test_orig == 0].size)
152 | acc_occ_free_std_error = np.std(grids_pred[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] == batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size)
153 |
154 |     mse_occ_std_error = np.std((grids_pred_orig[batch_y_test_orig == 1] - 1)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 1].size)
155 | mse_free_std_error = np.std((grids_pred_orig[batch_y_test_orig == 0] - 0)**2)/np.sqrt(grids_pred_orig[batch_y_test_orig == 0].size)
156 | mse_occ_free_std_error = np.std((grids_pred_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)] - batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)])**2)/np.sqrt(batch_y_test_orig[np.logical_or(batch_y_test_orig == 0, batch_y_test_orig == 1)].size)
157 |
158 | im_occ_std_error = np.std(im_occ_grids)/np.sqrt(im_occ_grids.size)
159 | im_free_std_error = np.std(im_free_grids)/np.sqrt(im_free_grids.size)
160 | im_occ_free_std_error = np.std(im_occ_grids + im_free_grids)/np.sqrt(im_occ_grids.size)
161 |
162 | print('Maximum standard error:')
163 | print(np.amax([acc_occ_std_error, acc_free_std_error, acc_occ_free_std_error, mse_occ_std_error, mse_free_std_error, mse_occ_free_std_error]))
164 | print(np.amax([im_occ_std_error, im_free_std_error, im_occ_free_std_error]))
165 |
166 | # Visualize all latent classes.
167 | all_latent_classes = np.reshape(p_m_a_np, (100,20,30))
168 |
169 | fig = plt.figure(figsize=(10, 10))
170 | grid = ImageGrid(fig, 111,
171 | nrows_ncols=(10, 10),
172 | axes_pad=0.1,
173 | )
174 |
175 | for ax, im in zip(grid, all_latent_classes):
176 | ax.matshow(im, cmap='gray_r', vmin=0, vmax=1)
177 | ax.set_xticks([], [])
178 | ax.set_yticks([], [])
179 |
180 | plt.savefig(os.path.join(dir_models, 'all_latent_classes_kmeans.png'))
181 | plt.show()
--------------------------------------------------------------------------------
/src/driver_sensor_model/models_cvae.py:
--------------------------------------------------------------------------------
1 | # Architecture for the CVAE driver sensor model. Code is adapted from: https://github.com/sisl/EvidentialSparsification and
2 | # https://github.com/ritheshkumar95/pytorch-vqvae/blob/master/modules.py.
3 |
4 | seed = 123
5 | import numpy as np
6 | np.random.seed(seed)
7 | import torch
8 | import math
9 |
10 | import torch.nn as nn
11 | from src.utils.utils_model import to_var, sample_p
12 | import pdb
13 |
14 | class ResBlock(nn.Module):
15 | def __init__(self, dim):
16 | super().__init__()
17 | self.block = nn.Sequential(
18 | nn.ReLU(False),
19 | nn.Conv2d(dim, dim, 3, 1, 1),
20 | nn.BatchNorm2d(dim),
21 | nn.ReLU(False),
22 | nn.Conv2d(dim, dim, 1),
23 | nn.BatchNorm2d(dim)
24 | )
25 |
26 | def forward(self, x):
27 | return x + self.block(x)
28 |
29 | class VAE(nn.Module):
30 |
31 | def __init__(self, encoder_layer_sizes_p, n_lstms, latent_size, dim):
32 |
33 | super().__init__()
34 |
35 | torch.manual_seed(seed)
36 | torch.cuda.manual_seed(seed)
37 |
38 | assert type(encoder_layer_sizes_p) == list
39 | assert type(latent_size) == int
40 |
41 | self.latent_size = latent_size
42 | self.label_size = encoder_layer_sizes_p[-1]
43 |
44 | self.encoder = Encoder(encoder_layer_sizes_p, n_lstms, latent_size, dim)
45 | self.decoder = Decoder(latent_size, self.label_size, dim)
46 |
47 | def forward(self, x, c=None):
48 |
49 | batch_size = x.size(0)
50 |
51 | # Encode the input.
52 | alpha_q, alpha_p, output_all_c = self.encoder(x, c)
53 |
54 | # Obtain all possible latent classes.
55 | z = torch.eye(self.latent_size).cuda()
56 |
57 | # Decode all latent classes.
58 | recon_x = self.decoder(z)
59 |
60 | return recon_x, alpha_q, alpha_p, self.encoder.linear_latent_q, self.encoder.linear_latent_p, output_all_c, z
61 |
62 | def inference(self, n=1, c=None, mode='sample', k=None):
63 |
64 | batch_size = n
65 |
66 | alpha_q, alpha_p, output_all_c = self.encoder(x=torch.empty((0,0)), c=c, train=False)
67 |
68 | if mode == 'sample':
69 | # Decode the mode sampled from the prior distribution.
70 | z = sample_p(alpha_p, batch_size=batch_size).view(-1,self.latent_size)
71 | elif mode == 'all':
72 | # Decode all the modes.
73 | z = torch.eye(self.latent_size).cuda()
74 | elif mode == 'most_likely':
75 | # Decode the most likely mode.
76 | z = torch.nn.functional.one_hot(torch.argmax(alpha_p, dim=1), num_classes=self.latent_size).float()
77 | elif mode == 'multimodal':
78 | # Decode a particular mode.
79 | z = torch.nn.functional.one_hot(torch.argsort(alpha_p, dim=1)[:,-k], num_classes=self.latent_size).float()
80 |
81 | recon_x = self.decoder(z)
82 |
83 | return recon_x, alpha_p, self.encoder.linear_latent_p, output_all_c, z
84 |
85 | class Encoder(nn.Module):
86 |
87 | def __init__(self, layer_sizes_p, n_lstms, latent_size, dim):
88 |
89 | super().__init__()
90 |
91 | input_dim = 1
92 | self.VQVAEBlock = nn.Sequential(
93 | nn.Conv2d(input_dim, dim, 4, 2, 1),
94 | nn.BatchNorm2d(dim),
95 | nn.ReLU(False),
96 | nn.Conv2d(dim, dim, 4, 2, 1),
97 | ResBlock(dim),
98 | ResBlock(dim)
99 | )
100 |
101 | self.lstm = nn.LSTM(layer_sizes_p[0],
102 | layer_sizes_p[-1],
103 | num_layers=n_lstms,
104 | batch_first=True)
105 |
106 | self.linear_latent_q = nn.Linear(5*7*dim + layer_sizes_p[-1]*10, latent_size) # 10 is the time dimension.
107 | self.softmax_q = nn.Softmax(dim=-1)
108 |
109 | self.linear_latent_p = nn.Linear(layer_sizes_p[-1]*10, latent_size) # 10 is the time dimension.
110 | self.softmax_p = nn.Softmax(dim=-1)
111 |
112 | def forward(self, x=None, c=None, train=True):
113 |
114 | output_all, (full_c, _) = self.lstm(c)
115 | output_all_c = torch.reshape(output_all, (c.shape[0], -1))
116 | alpha_p_lin = self.linear_latent_p(output_all_c)
117 | alpha_p = self.softmax_p(alpha_p_lin)
118 |
119 | if train:
120 |
121 | full_x = self.VQVAEBlock(x)
122 | full_x = full_x.view(full_x.shape[0],-1)
123 | output_all_c = torch.cat((full_x, output_all_c), dim=-1)
124 | alpha_q_lin = self.linear_latent_q(output_all_c)
125 | alpha_q = self.softmax_q(alpha_q_lin)
126 |
127 | else:
128 | alpha_q_lin = None
129 | alpha_q = None
130 |
131 | return alpha_q, alpha_p, output_all_c
132 |
133 |
134 | class Decoder(nn.Module):
135 |
136 | def __init__(self, latent_size, label_size, dim):
137 |
138 | super().__init__()
139 | self.latent_size = latent_size
140 | self.dim = dim
141 | input_dim = 1
142 |
143 | self.decode_linear = nn.Sequential(
144 | nn.Linear(self.latent_size, self.latent_size*dim),
145 | nn.ReLU(False),
146 | nn.Linear(self.latent_size*dim, self.latent_size*dim),
147 | nn.ReLU(False),
148 | )
149 |
150 | self.VQVAEBlock = nn.Sequential(
151 | ResBlock(dim),
152 | ResBlock(dim),
153 | nn.ReLU(False),
154 | nn.ConvTranspose2d(dim, dim, (4,5), (2,3), 1),
155 | nn.BatchNorm2d(dim),
156 | nn.ReLU(False),
157 | nn.ConvTranspose2d(dim, input_dim, 1, 1, 0),
158 | nn.Sigmoid()
159 | )
160 |
161 | def forward(self, z):
162 | latent_size_sqrt = int(math.sqrt(self.latent_size))
163 | z_c = self.decode_linear(z)
164 | z_c = z_c.view(-1, self.dim, latent_size_sqrt, latent_size_sqrt)
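        # With latent_size=100, z_c is a (batch, dim, 10, 10) feature map; the transposed
        # convolution below (kernel (4,5), stride (2,3), padding 1) upsamples it to the 20x30 OGM.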
165 | x = self.VQVAEBlock(z_c)
166 | return x
--------------------------------------------------------------------------------
/src/driver_sensor_model/train_gmm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | random = 123
3 | torch.manual_seed(random)
4 | torch.backends.cudnn.deterministic = True
5 | torch.backends.cudnn.benchmark = False
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | import torchvision
9 | from torchvision import datasets, transforms
10 |
11 | import numpy as np
12 | np.random.seed(random)
13 | from matplotlib import pyplot as plt
14 | import hickle as hkl
15 | import pickle as pkl
16 | import pdb
17 | import os
18 | import multiprocessing as mp
19 |
20 | os.chdir("../..")
21 |
22 | from src.utils.data_generator import *
23 | from sklearn.mixture import GaussianMixture as GMM
24 | from tqdm import tqdm
25 | import time
26 |
27 | # Load data.
28 | nt = 10
29 |
30 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset/'
31 | models = 'gmm'
32 | dir_models = os.path.join('/models/', models)
33 | if not os.path.isdir(dir_models):
34 | os.mkdir(dir_models)
35 | data_file_states = os.path.join(dir, 'states_shuffled_train.hkl')
36 | data_file_grids = os.path.join(dir, 'label_grids_shuffled_train.hkl')
37 | data_file_sources = os.path.join(dir, 'sources_shuffled_train.hkl')
38 |
39 | data_train = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,
40 | batch_size=None, shuffle=True, sequence_start_mode='all', norm=False)
41 | print("number of unique sources: ", len(np.unique(data_train.sources)))
42 |
43 | train_loader = torch.utils.data.DataLoader(data_train,
44 | batch_size=len(data_train), shuffle=True,
45 | num_workers=mp.cpu_count()-1, pin_memory=True)
46 |
47 | use_cuda = torch.cuda.is_available()
48 | device = torch.device("cuda:0" if use_cuda else "cpu")
49 |
50 | for batch_x, batch_y, sources in train_loader:
51 | batch_x, batch_y = batch_x.to("cpu").data.numpy(), batch_y.to("cpu").data.numpy()
52 | print(batch_x.shape, batch_y.shape)
53 |
54 | # Test on validation data.
55 | data_file_states = os.path.join(dir, 'states_val.hkl')
56 | data_file_grids = os.path.join(dir, 'label_grids_val.hkl')
57 | data_file_sources = os.path.join(dir, 'sources_val.hkl')
58 |
59 | data_val = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,
60 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False)
61 |
62 | val_loader = torch.utils.data.DataLoader(data_val,
63 | batch_size=len(data_val), shuffle=False,
64 | num_workers=mp.cpu_count()-1, pin_memory=True)
65 |
66 | for batch_x_val, batch_y_val, sources_val in val_loader:
67 | batch_x_val, batch_y_val = batch_x_val.to("cpu").data.numpy(), batch_y_val.to("cpu").data.numpy()
68 |
69 | print('unique values', np.unique(batch_y_val))
70 |
71 | Ks = [100] # [4, 10, 25, 50, 100, 150]
72 | acc_nums = []
73 | acc_nums_free_occ = []
74 | acc_nums_occ = []
75 | acc_nums_free = []
76 | acc_nums_ocl = []
77 | plot_train = False
78 | grid_shape = (20,30)
79 |
80 | for K in tqdm(Ks):
81 |
82 | model = GMM(n_components=K, covariance_type='diag', reg_covar=1e-3)
83 | model.fit(batch_x)
84 | cluster_ids_x = model.predict(batch_x)
85 | cluster_centers = model.means_
86 |
87 | p_a_m_1 = np.zeros((K, batch_y.shape[1]))
88 | p_a_m_0 = np.zeros((K, batch_y.shape[1]))
89 | for i in range(batch_y.shape[1]):
90 | p_m_1 = np.sum(batch_y[:,i] == 1).astype(float)/batch_y.shape[0]
91 | p_m_0 = (np.sum(batch_y[:,i] == 0).astype(float)/batch_y.shape[0])
92 | for k in range(K):
93 | p_a_m_1[k,i] = np.sum(np.logical_and(cluster_ids_x == k, batch_y[:,i] == 1)).astype(float)/batch_y.shape[0]
94 | p_a_m_0[k,i] = np.sum(np.logical_and(cluster_ids_x == k, batch_y[:,i] == 0)).astype(float)/batch_y.shape[0]
95 |
96 | if p_m_1 != 0:
97 | p_a_m_1[:,i] = p_a_m_1[:,i]/p_m_1
98 | else:
99 | p_a_m_1[:,i] = p_a_m_1[:,i]/1.0
100 |
101 | if p_m_0 != 0:
102 | p_a_m_0[:,i] = p_a_m_0[:,i]/p_m_0
103 | else:
104 | p_a_m_0[:,i] = p_a_m_0[:,i]/1.0
105 |
106 | p_m_a_np = p_a_m_1/(p_a_m_1+p_a_m_0)
107 |
108 | pkl.dump([model, p_m_a_np], open( os.path.join(dir_models, "GMM_K_" + str(K) + "_reg_covar_0_001.pkl"), "wb" ) )
109 |
110 | cluster_centers_np = np.reshape(cluster_centers, (K,nt,7))
111 | cluster_centers_np = unnormalize(cluster_centers_np)
112 |
113 | if plot_train:
114 | fig, (ax1, ax2, ax3) = plt.subplots(3)
115 | fig.suptitle('State Clusters')
116 |
117 | for k in range(K):
118 | fig_occ, ax_occ = plt.subplots(1)
119 | image = np.flip(np.transpose(1.0-np.reshape(p_m_a_np[k], grid_shape), (1,0)), axis=0)
120 | ax_occ.imshow(image, cmap='gray')
121 | picture_file = os.path.join(dir_models, 'cluster_' + str(k) + '.png')
122 | plt.savefig(picture_file)
123 | fig_occ.clf()
124 |
125 | ax1.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,1], label=str(k))
126 | ax2.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,2], label=str(k))
127 | ax3.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,3], label=str(k))
128 |
129 | ax1.set_ylabel("Pos (m)")
130 | ax1.set_ylim(-5,120)
131 | ax1.set_xlim(-5,120)
132 | ax2.set_ylabel("Vel (m/s)")
133 | ax2.set_ylim(0,8)
134 | ax2.set_xlim(-5,120)
135 | ax3.set_xlabel("Position (m)")
136 | ax3.set_ylabel("Acc (m/s^2)")
137 | ax3.set_ylim(-3,3)
138 | ax3.set_xlim(-5,120)
139 |
140 | handles, labels = ax1.get_legend_handles_labels()
141 | fig.legend(handles, labels, loc='center right')
142 |
143 | picture_file = os.path.join(dir_models, 'state_clusters_gmm.png')
144 | fig.savefig(picture_file)
145 |
146 |         # Free the plotting intermediates before validation.
147 |         del(cluster_centers_np)
148 |         del(cluster_centers)
149 |         del(cluster_ids_x)
150 |
151 | # Test on validation data.
152 | cluster_ids_y_val = model.predict(batch_x_val)
153 |
154 | grids_pred = p_m_a_np[cluster_ids_y_val]
155 |
156 | grids_pred[grids_pred >= 0.6] = 1.0
157 | grids_pred[grids_pred <= 0.4] = 0.0
158 | grids_pred[np.logical_and(grids_pred > 0.4, grids_pred < 0.6)] = 0.5
159 | grids_gt = batch_y_val
160 |
161 | acc = np.mean(grids_pred == grids_gt)
162 | acc_nums.append(acc)
163 |
164 | mask_occ = (batch_y_val == 1)
165 | acc_occ = np.mean(grids_pred[mask_occ] == grids_gt[mask_occ])
166 | acc_nums_occ.append(acc_occ)
167 |
168 | mask_free = (batch_y_val == 0)
169 | acc_free = np.mean(grids_pred[mask_free] == grids_gt[mask_free])
170 | acc_nums_free.append(acc_free)
171 |
172 | print("K: ", K, " Accuracy: ", acc, acc_occ, acc_free)
173 |
174 | if plot_train:
175 | plt.scatter(Ks, acc_nums, label='acc')
176 | plt.scatter(Ks, acc_nums_occ, label='acc occupied')
177 | plt.scatter(Ks, acc_nums_free, label='acc free')
178 | plt.ylim(0,1)
179 | plt.legend()
180 | plt.title('Accuracy vs Number of Clusters for Driver Sensor Model')
181 | plt.xlabel('Number of Clusters')
182 | plt.ylabel('Accuracy')
183 | plt.savefig(os.path.join(dir_models, 'acc_clusters_gmm.png'))
--------------------------------------------------------------------------------
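The script above pickles the fitted mixture together with the per-cluster occupancy table p_m_a_np, so downstream inference reduces to a hard cluster assignment followed by a table lookup. The sketch below restates the validation logic as a standalone helper; it is a minimal illustration that assumes the artifact path and the (nt*7)-feature layout used in this script, and it is not the repository's inference_gmm.py.

    import pickle as pkl
    import numpy as np

    # Assumed artifact path; train_gmm.py writes GMM_K_100_reg_covar_0_001.pkl under its dir_models folder.
    model, p_m_a_np = pkl.load(open('/models/gmm/GMM_K_100_reg_covar_0_001.pkl', 'rb'))

    def predict_occupancy(trajectories, grid_shape=(20, 30)):
        # trajectories: (N, nt*7) array with the same normalization as the training data.
        cluster_ids = model.predict(trajectories)        # (N,) hard cluster assignment
        probs = p_m_a_np[cluster_ids]                    # (N, 600) per-cell occupancy probability
        grids = probs.reshape(-1, *grid_shape)
        # Same thresholds as the validation block: occupied / free / unknown.
        labels = np.full_like(grids, 0.5)
        labels[grids >= 0.6] = 1.0
        labels[grids <= 0.4] = 0.0
        return grids, labels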
/src/driver_sensor_model/train_kmeans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | random = 123
3 | torch.manual_seed(random)
4 | torch.backends.cudnn.deterministic = True
5 | torch.backends.cudnn.benchmark = False
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | import torchvision
9 | from torchvision import datasets, transforms
10 |
11 | import numpy as np
12 | np.random.seed(random)
13 | from matplotlib import pyplot as plt
14 | import hickle as hkl
15 | import pickle as pkl
16 | import pdb
17 | import os
18 | import multiprocessing as mp
19 |
20 | os.chdir("../..")
21 |
22 | from src.utils.data_generator import *
23 | from sklearn.cluster import KMeans
24 | import time
25 |
26 | from tqdm import tqdm
27 |
28 | # Load data.
29 | nt = 10
30 |
31 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data_new_goal/driver_sensor_dataset/'
32 | models = 'kmeans'
33 | dir_models = os.path.join('/models/', models)
34 | if not os.path.isdir(dir_models):
35 | os.mkdir(dir_models)
36 | data_file_states = os.path.join(dir, 'states_shuffled_train.hkl')
37 | data_file_grids = os.path.join(dir, 'label_grids_shuffled_train.hkl')
38 | data_file_sources = os.path.join(dir, 'sources_shuffled_train.hkl')
39 |
40 | data_train = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,
41 | batch_size=None, shuffle=True, sequence_start_mode='all', norm=False)
42 | print("number of unique sources: ", len(np.unique(data_train.sources)))
43 |
44 | train_loader = torch.utils.data.DataLoader(data_train,
45 | batch_size=len(data_train), shuffle=True,
46 | num_workers=mp.cpu_count()-1, pin_memory=True)
47 |
48 | use_cuda = torch.cuda.is_available()
49 | device = torch.device("cuda:0" if use_cuda else "cpu")
50 |
51 | for batch_x, batch_y, sources in train_loader:
52 | batch_x, batch_y = batch_x.to("cpu").data.numpy(), batch_y.to("cpu").data.numpy()
53 | print(batch_x.shape, batch_y.shape)
54 |
55 | data_file_states = os.path.join(dir, 'states_val.hkl')
56 | data_file_grids = os.path.join(dir, 'label_grids_val.hkl')
57 | data_file_sources = os.path.join(dir, 'sources_val.hkl')
58 |
59 | data_val = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,
60 | batch_size=None, shuffle=False, sequence_start_mode='unique', norm=False)
61 |
62 | val_loader = torch.utils.data.DataLoader(data_val,
63 | batch_size=len(data_val), shuffle=False,
64 | num_workers=mp.cpu_count()-1, pin_memory=True)
65 |
66 | for batch_x_val, batch_y_val, sources_val in val_loader:
67 | batch_x_val, batch_y_val = batch_x_val.to("cpu").data.numpy(), batch_y_val.to("cpu").data.numpy()
68 |
69 | Ks = [100] # [4, 10, 25, 50, 100, 150]
70 | acc_nums = []
71 | acc_nums_free_occ = []
72 | acc_nums_occ = []
73 | acc_nums_free = []
74 | acc_nums_ocl = []
75 | plot_train = False
76 | grid_shape = (20,30)
77 |
78 | for K in tqdm(Ks):
79 |
80 | model = KMeans(n_clusters=K)
81 | model.fit(batch_x)
82 |
83 | cluster_ids_x = model.predict(batch_x)
84 | cluster_centers = model.cluster_centers_
85 |
86 | p_a_m_occ = np.zeros((K, batch_y.shape[1]))
87 | p_a_m_free = np.zeros((K, batch_y.shape[1]))
88 |
89 | for i in range(batch_y.shape[1]):
90 | p_m_occ = np.sum(batch_y[:,i] == 1).astype(float)/batch_y.shape[0]
91 | p_m_free = np.sum(batch_y[:,i] == 0).astype(float)/batch_y.shape[0]
92 |
93 | for k in range(K):
94 | p_a_m_occ[k,i] = np.sum((cluster_ids_x == k) & (batch_y[:,i] == 1)).astype(float)/batch_y.shape[0]
95 | p_a_m_free[k,i] = np.sum((cluster_ids_x == k) & (batch_y[:,i] == 0)).astype(float)/batch_y.shape[0]
96 |
97 | if p_m_occ != 0:
98 | p_a_m_occ[:,i] = p_a_m_occ[:,i]/p_m_occ
99 | else:
100 | p_a_m_occ[:,i] = p_a_m_occ[:,i]/1.0
101 |
102 | if p_m_free != 0:
103 | p_a_m_free[:,i] = p_a_m_free[:,i]/p_m_free
104 | else:
105 | p_a_m_free[:,i] = p_a_m_free[:,i]/1.0
106 |
107 | p_m_a = p_a_m_occ/(p_a_m_occ + p_a_m_free)
108 | p_m_a_np = p_m_a
109 |
110 | pkl.dump([model, p_m_a_np], open(os.path.join(dir_models, "clusters_kmeans_K_" + str(K) + ".pkl"), "wb" ) )
111 |
112 | cluster_centers_np = cluster_centers
113 | cluster_centers_np = np.reshape(cluster_centers_np, (K,nt,-1))
114 | cluster_centers_np = unnormalize(cluster_centers_np)
115 |
116 | if plot_train:
117 | fig, (ax1, ax2, ax3) = plt.subplots(3)
118 | fig.suptitle('State Clusters')
119 |
120 | for k in range(K):
121 | fig_occ, ax_occ = plt.subplots(1)
122 | image = np.flip(np.transpose(1.0-np.reshape(p_m_a_np[k], grid_shape), (1,0)), axis=0)
123 | ax_occ.imshow(image, cmap='gray')
124 | picture_file = os.path.join(dir_models, 'cluster_' + str(k) + '.png')
125 | plt.savefig(picture_file)
126 | fig_occ.clf()
127 |
128 | ax1.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,1], label=str(k))
129 | ax2.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,2], label=str(k))
130 | ax3.scatter(cluster_centers_np[k,:,1], cluster_centers_np[k,:,3], label=str(k))
131 |
132 | ax1.set_ylabel("Pos (m)")
133 | ax1.set_ylim(-5,120)
134 | ax1.set_xlim(-5,120)
135 | ax2.set_ylabel("Vel (m/s)")
136 | ax2.set_ylim(0,8)
137 | ax2.set_xlim(-5,120)
138 | ax3.set_xlabel("Position (m)")
139 | ax3.set_ylabel("Acc (m/s^2)")
140 | ax3.set_ylim(-3,3)
141 | ax3.set_xlim(-5,120)
142 |
143 | handles, labels = ax1.get_legend_handles_labels()
144 | fig.legend(handles, labels, loc='center right')
145 |
146 | picture_file = os.path.join(dir_models, 'state_clusters.png')
147 | fig.savefig(picture_file)
148 |
149 | # Test on validation data.
150 | cluster_ids_y_val = model.predict(batch_x_val)
151 |
152 | grids_pred = p_m_a_np[cluster_ids_y_val]
153 | grids_pred[grids_pred >= 0.6] = 1.0
154 | grids_pred[grids_pred <= 0.4] = 0.0
155 | grids_pred[np.logical_and(grids_pred > 0.4, grids_pred < 0.6)] = 0.5
156 | grids_gt = batch_y_val
157 |
158 | acc = np.mean(grids_pred == grids_gt)
159 | acc_nums.append(acc)
160 |
161 | mask_occ = (batch_y_val == 1)
162 | acc_occ = np.mean(grids_pred[mask_occ] == grids_gt[mask_occ])
163 | acc_nums_occ.append(acc_occ)
164 |
165 | mask_free = (batch_y_val == 0)
166 | acc_free = np.mean(grids_pred[mask_free] == grids_gt[mask_free])
167 | acc_nums_free.append(acc_free)
168 |
169 | print("K: ", K, " Accuracy: ", acc, acc_occ, acc_free)
170 |
171 | if plot_train:
172 | plt.scatter(Ks, acc_nums, label='acc')
173 | plt.scatter(Ks, acc_nums_occ, label='acc occupied')
174 | plt.scatter(Ks, acc_nums_free, label='acc free')
175 | plt.legend()
176 | plt.ylim(0,1)
177 | plt.title('Accuracy vs Number of Clusters for Driver Sensor Model')
178 | plt.xlabel('Number of Clusters')
179 | plt.ylabel('Accuracy')
180 | plt.savefig(os.path.join(dir_models, 'acc_clusters.png'))
--------------------------------------------------------------------------------
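Both train_kmeans.py and train_gmm.py build the occupancy table with the same nested loops: for each cell i they estimate P(a = k | m_i = 1) and P(a = k | m_i = 0) from joint counts, then take p_m_a[k, i] = P(a | m = 1) / (P(a | m = 1) + P(a | m = 0)), which reads as the posterior probability that the cell is occupied given cluster k under an implicit equal-prior assumption on the observed (0/1) labels. For reference, the same table can be produced with a single matrix product over one-hot cluster assignments; this is an illustrative vectorization of the loops above (it reproduces their 0/0 behaviour for never-observed cells), not code from the repository.

    import numpy as np

    def occupancy_table(cluster_ids, labels, K):
        # cluster_ids: (N,) hard assignments in [0, K); labels: (N, cells) grids with values in {0, 0.5, 1}.
        N = labels.shape[0]
        one_hot = np.eye(K)[cluster_ids]                    # (N, K) cluster indicator matrix
        p_a_m_occ = (one_hot.T @ (labels == 1)) / N         # joint P(a = k, m_i = 1)
        p_a_m_free = (one_hot.T @ (labels == 0)) / N        # joint P(a = k, m_i = 0)
        p_m_occ = (labels == 1).mean(axis=0)                # per-cell P(m_i = 1)
        p_m_free = (labels == 0).mean(axis=0)               # per-cell P(m_i = 0)
        p_a_m_occ = p_a_m_occ / np.where(p_m_occ == 0, 1.0, p_m_occ)
        p_a_m_free = p_a_m_free / np.where(p_m_free == 0, 1.0, p_m_free)
        return p_a_m_occ / (p_a_m_occ + p_a_m_free)         # (K, cells) occupancy table p_m_a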
/src/driver_sensor_model/visualize_cvae.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "seed = 123\n",
10 | "\n",
11 | "import numpy as np\n",
12 | "np.random.seed(seed)\n",
13 | "from matplotlib import pyplot as plt\n",
14 | "from mpl_toolkits.axes_grid1 import ImageGrid\n",
15 | "import torch\n",
16 | "torch.manual_seed(seed)\n",
17 | "torch.cuda.manual_seed(seed)\n",
18 | "torch.backends.cudnn.deterministic = True\n",
19 | "torch.backends.cudnn.benchmark = False\n",
20 | "import torch.nn as nn\n",
21 | "torch.autograd.set_detect_anomaly(True)\n",
22 | "import torchvision\n",
23 | "from torchvision import datasets, transforms\n",
24 | "\n",
25 | "import hickle as hkl\n",
26 | "import pickle as pkl\n",
27 | "import pdb\n",
28 | "import os\n",
29 | "\n",
30 | "from sklearn.cluster import KMeans\n",
31 | "from sklearn.mixture import GaussianMixture as GMM\n",
32 | "import time\n",
33 | "\n",
34 | "from torch.utils.data import DataLoader\n",
35 | "\n",
36 | "from tqdm import tqdm\n",
37 | "\n",
38 | "from copy import deepcopy\n",
39 | "import io\n",
40 | "import PIL.Image\n",
41 | "import multiprocessing as mp\n",
42 | "import tikzplotlib\n",
43 | "\n",
44 | "os.chdir(\"../..\")\n",
45 | "\n",
46 | "from src.utils.utils_model import to_var\n",
47 | "from src.driver_sensor_model.models_cvae import VAE\n",
48 | "from src.utils.data_generator import *\n",
49 | "from src.utils.interaction_utils import *"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "# Load data.\n",
59 | "nt = 10\n",
60 | "num_states = 7\n",
61 | "grid_shape = (20, 30)\n",
62 | " \n",
63 | "use_cuda = torch.cuda.is_available()\n",
64 | "device = torch.device(\"cuda:0\" if use_cuda else \"cpu\")\n",
65 | "\n",
66 | "dir = '/data/INTERACTION-Dataset-DR-v1_1/Processed_data_new_goal/driver_sensor_dataset/'\n",
67 | "\n",
68 | "# Test data.\n",
69 | "data_file_states = os.path.join(dir, 'states_test.hkl')\n",
70 | "data_file_grids = os.path.join(dir, 'label_grids_test.hkl')\n",
71 | "data_file_sources = os.path.join(dir, 'sources_test.hkl')\n",
72 | "\n",
73 | "data_test = SequenceGenerator(data_file_state=data_file_states, data_file_grid=data_file_grids, source_file=data_file_sources, nt=nt,\n",
74 | " batch_size=None, shuffle=False, sequence_start_mode='unique', norm=True)\n",
75 | "\n",
76 | "test_loader = torch.utils.data.DataLoader(data_test,\n",
77 | " batch_size=len(data_test), shuffle=False,\n",
78 | " num_workers=mp.cpu_count()-1, pin_memory=True)"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "for batch_x_test, batch_y_test, sources_test in test_loader:\n",
88 | " batch_x_test, batch_y_test_orig = batch_x_test.to(device), batch_y_test.to(device)\n",
89 | " batch_y_test_orig = batch_y_test_orig.view(batch_y_test_orig.shape[0],1,20,30)\n",
90 | " y_full = unnormalize(batch_x_test.cpu().data.numpy(), nt)\n",
91 | " pos_x = y_full[:,:,0]\n",
92 | " pos_y = y_full[:,:,1]\n",
93 | " orientation = y_full[:,:,2]\n",
94 | " cos_theta = np.cos(orientation)\n",
95 | " sin_theta = np.sin(orientation)\n",
96 | " vel_x = y_full[:,:,3]\n",
97 | " vel_y = y_full[:,:,4]\n",
98 | " speed = np.sqrt(vel_x**2 + vel_y**2)\n",
99 | " acc_x = y_full[:,:,5]\n",
100 | " acc_y = y_full[:,:,6]\n",
101 | "\n",
102 | " # Project the acceleration on the orientation vector to get longitudinal acceleration.\n",
103 | " dot_prod = acc_x * cos_theta + acc_y * sin_theta\n",
104 | " sign = np.sign(dot_prod)\n",
105 | " acc_proj_x = dot_prod * cos_theta\n",
106 | " acc_proj_y = dot_prod * sin_theta\n",
107 | "\n",
108 | " acc_proj_sign = sign * np.sqrt(acc_proj_x**2 + acc_proj_y**2)\n",
109 | "\n",
110 | " batch_size = batch_y_test_orig.shape[0]"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "# VAE\n",
120 | "folder_vae = '/models/cvae'\n",
121 | "name_vae = 'lstm_1_Adam_z_100_lr_0.001_rand_123_norm_True_kl_start_0_finish_1.0_center_10000.0_mutual_info_const_alpha_1.5_epochs_30_batch_256'\n",
122 | "\n",
123 | "vae = VAE(\n",
124 | " encoder_layer_sizes_p=[7, 5],\n",
125 | " n_lstms=1,\n",
126 | " latent_size=100,\n",
127 | " decoder_layer_sizes=[256, 600], # used to be 10\n",
128 | " dim=4\n",
129 | " )\n",
130 | "\n",
131 | "vae = vae.cuda()\n",
132 | "\n",
133 | "save_filename = os.path.join(folder_vae, name_vae) + 'epoch_30_vae.pt'\n",
134 | "\n",
135 | "with open(save_filename, 'rb') as f:\n",
136 | " state_dict = torch.load(f)\n",
137 | " vae.load_state_dict(state_dict)\n",
138 | "\n",
139 | "vae.eval()\n",
140 | "\n",
141 | "with torch.no_grad():\n",
142 | " recon_y_inf_most_likely, alpha_p, alpha_p_lin, full_c, z = vae.inference(n=1, c=batch_x_test, mode='most_likely')\n",
143 | " recon_x_inf, _, _, _, _ = vae.inference(n=100, c=batch_x_test, mode='all')\n",
144 | "print(recon_y_inf_most_likely.shape, batch_y_test_orig.shape)\n",
145 | "print(torch.max(recon_y_inf_most_likely), torch.min(recon_y_inf_most_likely))\n",
146 | "\n",
147 | "grid_shape = (20,30)\n",
148 | "recon_y_inf_np = np.reshape(recon_y_inf_most_likely.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n",
149 | "y_np = np.reshape(batch_y_test_orig.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n",
150 | "print(recon_y_inf_np.shape, y_np.shape)\n",
151 | "\n",
152 | "recon_y_inf_np_pred = (recon_y_inf_np >= 0.6).astype(float)\n",
153 | "recon_y_inf_np_pred[recon_y_inf_np <= 0.4] = 0.0\n",
154 | "recon_y_inf_np_pred[np.logical_and(recon_y_inf_np < 0.6, recon_y_inf_np > 0.4)] = 0.5\n",
155 | "\n",
156 | "acc = np.mean(recon_y_inf_np_pred == y_np)\n",
157 | "mse = np.mean((recon_y_inf_np - y_np)**2)\n",
158 | "print('Acc: ', acc, 'MSE: ', mse)"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": null,
164 | "metadata": {},
165 | "outputs": [],
166 | "source": [
167 | "# Visualize all latent classes.\n",
168 | "recon_y_inf_all, _, _, _, _ = vae.inference(n=1, c=batch_x_test, mode='all')\n",
169 | "recon_y_inf_all_np = recon_y_inf_all.cpu().data.numpy()\n",
170 | "\n",
171 | "fig = plt.figure(figsize=(10, 10))\n",
172 | "grid = ImageGrid(fig, 111,\n",
173 | " nrows_ncols=(10, 10),\n",
174 | " axes_pad=0.1,\n",
175 | " )\n",
176 | "\n",
177 | "for ax, im in zip(grid, recon_y_inf_all_np[:,0]):\n",
178 | " ax.matshow(im, cmap='gray_r', vmin=0, vmax=1)\n",
179 | " ax.set_xticks([], [])\n",
180 | " ax.set_yticks([], [])\n",
181 | "plt.savefig('all_latent_classes_vae.png') \n",
182 | "plt.show()"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "# Visualize the three most likely modes VAE: slowing down.\n",
192 | "sample = np.where(np.sum(acc_proj_sign < -1.5, axis=-1) > 0)[0][16]\n",
193 | "print('sample: ', sample) \n",
194 | "print('acceleration: ', acc_proj_sign[sample])\n",
195 | "print('speed: ', speed[sample])\n",
196 | "print('probabilities: ', torch.sort(alpha_p[sample])[0])\n",
197 | "\n",
198 | "recon_y_inf_1, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=1)\n",
199 | "recon_y_inf_np_1 = np.reshape(recon_y_inf_1.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n",
200 | "\n",
201 | "plt.gca().set_aspect('equal', adjustable='box') # 'datalim'\n",
202 | "plt.scatter(pos_x[sample], pos_y[sample])\n",
203 | "plt.savefig(str(sample) + '_traj_dec.png')\n",
204 | "tikzplotlib.save(str(sample) + '_traj_dec.tex')\n",
205 | "plt.figure()\n",
206 | "plt.matshow(recon_y_inf_np_1[0], cmap='gray_r', vmin=0, vmax=1)\n",
207 | "plt.xticks([])\n",
208 | "plt.yticks([])\n",
209 | "plt.savefig('models/' + str(sample) + '_mode_1_' + str(torch.sort(alpha_p[sample])[0][-1].cpu().data.numpy()) + '_dec.png', pad_inches=0.0)\n",
210 | "plt.show()\n",
211 | "\n",
212 | "recon_y_inf_2, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=2)\n",
213 | "recon_y_inf_np_2 = np.reshape(recon_y_inf_2.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n",
214 | "\n",
215 | "plt.matshow(recon_y_inf_np_2[0], cmap='gray_r', vmin=0, vmax=1)\n",
216 | "plt.xticks([])\n",
217 | "plt.yticks([])\n",
218 | "plt.savefig('models/' + str(sample) + '_mode_2_' + str(torch.sort(alpha_p[sample])[0][-2].cpu().data.numpy()) + '_dec.png', pad_inches=0.0)\n",
219 | "plt.show()\n",
220 | "\n",
221 | "plt.matshow(batch_y_test_orig[sample,0].cpu().data.numpy(), cmap='gray_r', vmin=0, vmax=1)\n",
222 | "plt.xticks([])\n",
223 | "plt.yticks([])\n",
224 | "plt.savefig('models/' + str(sample) + '_gt_dec.png', pad_inches=0.0)\n",
225 | "plt.show()"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": null,
231 | "metadata": {},
232 | "outputs": [],
233 | "source": [
234 | "# Visualize the most likely modes VAE: constant speed.\n",
235 | "print(np.sum(np.sum(np.logical_and(np.abs(acc_proj_sign) < 10, speed > 5.0), axis=-1) > 9), speed.shape)\n",
236 | "sample = np.where(np.logical_and(np.sum(np.logical_and(np.abs(acc_proj_sign) < 0.25, speed > 5.0), axis=-1) > 9, alpha_p.cpu().detach().numpy()[:,27] > 0.3))[0][7]\n",
237 | "print('sample: ', sample) \n",
238 | "print('acceleration: ', acc_proj_sign[sample])\n",
239 | "print('speed: ', speed[sample])\n",
240 | "print('probabilities: ', torch.sort(alpha_p[sample])[0])\n",
241 | "\n",
242 | "recon_y_inf_1, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=1)\n",
243 | "recon_y_inf_np_1 = np.reshape(recon_y_inf_1.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n",
244 | "\n",
245 | "plt.gca().set_aspect('equal', adjustable='box') # 'datalim'\n",
246 | "plt.scatter(pos_x[sample], pos_y[sample])\n",
247 | "plt.savefig(str(sample) + '_traj_const.png')\n",
248 | "tikzplotlib.save(str(sample) + '_traj_const.tex')\n",
249 | "plt.figure()\n",
250 | "plt.matshow(recon_y_inf_np_1[0], cmap='gray_r', vmin=0, vmax=1)\n",
251 | "plt.xticks([])\n",
252 | "plt.yticks([])\n",
253 | "plt.savefig('models/' + str(sample) + '_mode_1_' + str(torch.sort(alpha_p[sample])[0][-1].cpu().data.numpy()) + '_const.png', pad_inches=0.0)\n",
254 | "plt.show()\n",
255 | "\n",
256 | "recon_y_inf_2, _, _, _, _ = vae.inference(n=1, c=torch.unsqueeze(batch_x_test[sample], dim=0), mode='multimodal', k=2)\n",
257 | "recon_y_inf_np_2 = np.reshape(recon_y_inf_2.cpu().data.numpy(), (-1, grid_shape[0], grid_shape[1]))\n",
258 | "\n",
259 | "plt.matshow(recon_y_inf_np_2[0], cmap='gray_r', vmin=0, vmax=1)\n",
260 | "plt.xticks([])\n",
261 | "plt.yticks([])\n",
262 | "plt.savefig('models/' + str(sample) + '_mode_2_' + str(torch.sort(alpha_p[sample])[0][-2].cpu().data.numpy()) + '_const.png', pad_inches=0.0)\n",
263 | "plt.show()\n",
264 | "\n",
265 | "plt.matshow(batch_y_test_orig[sample,0].cpu().data.numpy(), cmap='gray_r', vmin=0, vmax=1)\n",
266 | "plt.xticks([])\n",
267 | "plt.yticks([])\n",
268 | "plt.savefig('models/' + str(sample) + '_gt_const.png', pad_inches=0.0)\n",
269 | "plt.show()"
270 | ]
271 | }
272 | ],
273 | "metadata": {
274 | "kernelspec": {
275 | "display_name": "Python 3",
276 | "language": "python",
277 | "name": "python3"
278 | },
279 | "language_info": {
280 | "codemirror_mode": {
281 | "name": "ipython",
282 | "version": 3
283 | },
284 | "file_extension": ".py",
285 | "mimetype": "text/x-python",
286 | "name": "python",
287 | "nbconvert_exporter": "python",
288 | "pygments_lexer": "ipython3",
289 | "version": "3.6.10"
290 | }
291 | },
292 | "nbformat": 4,
293 | "nbformat_minor": 4
294 | }
295 |
--------------------------------------------------------------------------------
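In the "slowing down" and "constant speed" cells, the notebook reduces (ax, ay) to a signed longitudinal acceleration by projecting it onto the heading vector (cos θ, sin θ) and re-attaching the sign of the dot product; since the projected vector has magnitude |a · h|, the result equals the dot product itself. A standalone sketch with made-up numbers:

    import numpy as np

    theta = np.deg2rad(30.0)                            # heading angle (hypothetical)
    heading = np.array([np.cos(theta), np.sin(theta)])
    acc = np.array([-1.2, -0.7])                        # (ax, ay), hypothetical values

    dot = acc @ heading                                 # signed longitudinal component
    acc_proj = dot * heading                            # acceleration projected onto the heading
    acc_long = np.sign(dot) * np.linalg.norm(acc_proj)  # what the notebook computes
    assert np.isclose(acc_long, dot)                    # identical to the dot product; negative means braking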
/src/full_pipeline/main_save_full_pipeline.py:
--------------------------------------------------------------------------------
1 | # Code to save full occlusion inference pipeline. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | try:
4 | import lanelet2
5 | use_lanelet2_lib = True
6 | except:
7 | import warnings
8 | string = "Could not import lanelet2. It must be built and sourced, " + \
9 | "see https://github.com/fzi-forschungszentrum-informatik/Lanelet2 for details."
10 | warnings.warn(string)
11 | print("Using visualization without lanelet2.")
12 | use_lanelet2_lib = False
13 | from utils import map_vis_without_lanelet
14 |
15 | import argparse
16 | import os
17 | import time
18 | import matplotlib.pyplot as plt
19 | from matplotlib.widgets import Button
20 | import hickle as hkl
21 | import pickle as pkl
22 | import pdb
23 | import numpy as np
24 | import csv
25 |
26 | seed = 123
27 |
28 | import numpy as np
29 | np.random.seed(seed)
30 | from matplotlib import pyplot as plt
31 | import torch
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | torch.backends.cudnn.deterministic = True
35 | torch.backends.cudnn.benchmark = False
36 | import torch.nn as nn
37 | torch.autograd.set_detect_anomaly(True)
38 |
39 | import io
40 | import PIL.Image
41 | from tqdm import tqdm
42 | import time
43 |
44 | import argparse
45 | import pandas as pd
46 | import seaborn as sns
47 | import matplotlib.pyplot as plt
48 | from torchvision import datasets, transforms
49 | from torch.utils.data import DataLoader
50 | from torch.utils.tensorboard import SummaryWriter
51 | from torch.optim.lr_scheduler import ExponentialLR
52 | from collections import OrderedDict, defaultdict
53 |
54 | os.chdir("../..")
55 |
56 | from src.utils import dataset_reader
57 | from src.utils import dataset_types
58 | if use_lanelet2_lib: from src.utils import map_vis_lanelet2  # only needed when the lanelet2 library is available
59 | from src.utils import tracks_save
60 | # from src.utils import dict_utils
61 | from src.driver_sensor_model.models_cvae import VAE
62 | from src.utils.interaction_utils import *
63 |
64 | import torch._utils
65 | try:
66 | torch._utils._rebuild_tensor_v2
67 | except AttributeError:
68 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
69 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
70 | tensor.requires_grad = requires_grad
71 | tensor._backward_hooks = backward_hooks
72 | return tensor
73 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
74 |
75 |
76 | def update_plot():
77 | global timestamp, track_dictionary, pedestrian_dictionary,\
78 | data, car_ids, sensor_grids, id_grids, label_grids, driver_sensor_data, driver_sensor_state_data, driver_sensor_state, driver_sensor_state_dict,\
79 | min_max_xy, models, results, mode, model
80 |
81 | # Update text and tracks based on current timestamp.
82 | if (timestamp < timestamp_min):
83 | pdb.set_trace()
84 | assert(timestamp <= timestamp_max), "timestamp=%i" % timestamp
85 | assert(timestamp >= timestamp_min), "timestamp=%i" % timestamp
86 | assert(timestamp % dataset_types.DELTA_TIMESTAMP_MS == 0), "timestamp=%i" % timestamp
87 |     tracks_save.update_objects_plot(timestamp, track_dict=track_dictionary, pedest_dict=pedestrian_dictionary,
88 | data=data, car_ids=car_ids, sensor_grids=sensor_grids, id_grids=id_grids, label_grids=label_grids,
89 | driver_sensor_data = driver_sensor_data, driver_sensor_state_data=driver_sensor_state_data, driver_sensor_state = driver_sensor_state,
90 | endpoint=min_max_xy, models=models, results=results, mode=mode, model=model)
91 |
92 | if __name__ == "__main__":
93 |
94 | # Provide data to be visualized.
95 | parser = argparse.ArgumentParser()
96 | parser.add_argument("scenario_name", type=str, help="Name of the scenario (to identify map and folder for track "
97 | "files)", nargs="?")
98 | parser.add_argument("track_file_number", type=int, help="Number of the track file (int)", default=0, nargs="?")
99 | parser.add_argument("load_mode", type=str, help="Dataset to load (vehicle, pedestrian, or both)", default="both",
100 | nargs="?")
101 | parser.add_argument("--mode", type=str, help="Sensor fusion mode: evidential or average", nargs="?")
102 | parser.add_argument("--model", type=str, help="Name of the model: vae, gmm, kmeans", nargs="?")
103 | args = parser.parse_args()
104 |
105 | if args.load_mode != "vehicle" and args.load_mode != "pedestrian" and args.load_mode != "both":
106 | raise IOError("Invalid load command. Use 'vehicle', 'pedestrian', or 'both'")
107 |
108 | # Load test folder.
109 | error_string = ""
110 | tracks_dir = "/data/INTERACTION-Dataset-DR-v1_1/recorded_trackfiles"
111 | maps_dir = "/data/INTERACTION-Dataset-DR-v1_1/maps"
112 | scenario_name = 'DR_USA_Intersection_GL'
113 | home = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/"
114 | test_set_dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split'
115 | test_set_file = 'ego_test_set.csv'
116 |
117 | ego_files = []
118 | with open(os.path.join(test_set_dir, test_set_file)) as csv_file:
119 | csv_reader = csv.reader(csv_file, delimiter=',')
120 | for row in csv_reader:
121 | ego_files.append(row[0].split(home)[-1])
122 |
123 | number = 0
124 | ego_file = ego_files[number]
125 |
126 | # Load the models.
127 | models = dict()
128 | folder_vae = '/model/cvae/'
129 | name_vae = 'lstm_1_Adam_z_100_lr_0.001_rand_123_norm_True_kl_start_0_finish_1.0_center_10000.0_mutual_info_const_alpha_1.5_epochs_30_batch_256'
130 |
131 | models['vae'] = VAE(
132 | encoder_layer_sizes_p=[7, 5],
133 | n_lstms=1,
134 | latent_size=100,
135 | dim=4
136 | )
137 |
138 | models['vae'] = models['vae'].cuda()
139 |
140 | save_filename = os.path.join(folder_vae, name_vae) + 'epoch_30_vae.pt'
141 |
142 | with open(save_filename, 'rb') as f:
143 | state_dict = torch.load(f)
144 | models['vae'].load_state_dict(state_dict)
145 | models['vae'].eval()
146 |
147 | folder_kmeans = '/model/kmeans/'
148 | models['kmeans'] = pkl.load(open(os.path.join(folder_kmeans,"clusters_kmeans_K_100.pkl"), "rb" ) )
149 |
150 | folder_gmm = '/model/gmm/'
151 | models['gmm'] = pkl.load(open(os.path.join(folder_gmm, "GMM_K_100_reg_covar_0_001.pkl"), "rb" ) )
152 |
153 | if args.model == 'vae':
154 | folder_model = folder_vae
155 | elif args.model == 'gmm':
156 | folder_model = folder_gmm
157 | elif args.model == 'kmeans':
158 | folder_model = folder_kmeans
159 |
160 | # Get the driver sensor data.
161 | for ego_file in tqdm(ego_files):
162 | print(ego_file)
163 | ego_run = int(ego_file.split('_run_')[-1].split('_ego_')[0])
164 | if scenario_name[-2:] == 'GL':
165 | middle = scenario_name
166 | track_file_number = ego_file.split("_run_")[0][-3:]
167 | if os.path.exists(os.path.join(home, ego_file)):
168 | data = pkl.load(open(os.path.join(home, ego_file), 'rb'))
169 | else:
170 | continue
171 |
172 | driver_sensor_data = dict()
173 | driver_sensor_state_data = dict()
174 | driver_sensor_state = dict()
175 | for file in os.listdir(os.path.join(home)):
176 | if scenario_name[-2:] == 'GL':
177 | file_split = file.split("_run_")
178 | if track_file_number != file_split[0][-3:]:
179 | continue
180 | [run, vtype, _, vid] = file_split[-1][:-4].split('_')
181 | run = int(run)
182 | vid = int(vid)
183 |
184 | if (run == ego_run) and (vtype == 'ref'):
185 | driver_sensor_data[vid] = pkl.load(open(os.path.join(home, file), 'rb'))
186 | for item in driver_sensor_data[vid][1:]:
187 | timestamp = item[0]
188 | x = item[2]
189 | y = item[3]
190 | orientation = item[4]
191 | vx = item[5]
192 | vy = item[6]
193 | ax = item[7]
194 | ay = item[8]
195 | if vid in driver_sensor_state_data.keys():
196 | driver_sensor_state_data[vid].append(np.array([timestamp, x, y, orientation, vx, vy, ax, ay]))
197 | else:
198 | driver_sensor_state_data[vid] = [np.array([timestamp, x, y, orientation, vx, vy, ax, ay])]
199 |
200 | lanelet_map_ending = ".osm"
201 | lanelet_map_file = maps_dir + "/" + scenario_name + lanelet_map_ending
202 | scenario_dir = tracks_dir + "/" + scenario_name
203 | track_file_prefix = "vehicle_tracks_"
204 | track_file_ending = ".csv"
205 | track_file_name = scenario_dir + "/" + track_file_prefix + str(track_file_number).zfill(3) + track_file_ending
206 | pedestrian_file_prefix = "pedestrian_tracks_"
207 | pedestrian_file_ending = ".csv"
208 | pedestrian_file_name = scenario_dir + "/" + pedestrian_file_prefix + str(track_file_number).zfill(3) + pedestrian_file_ending
209 | if not os.path.isdir(tracks_dir):
210 | error_string += "Did not find track file directory \"" + tracks_dir + "\"\n"
211 | if not os.path.isdir(maps_dir):
212 |             error_string += "Did not find map file directory \"" + maps_dir + "\"\n"
213 | if not os.path.isdir(scenario_dir):
214 | error_string += "Did not find scenario directory \"" + scenario_dir + "\"\n"
215 | if not os.path.isfile(lanelet_map_file):
216 | error_string += "Did not find lanelet map file \"" + lanelet_map_file + "\"\n"
217 | if not os.path.isfile(track_file_name):
218 | error_string += "Did not find track file \"" + track_file_name + "\"\n"
219 | if not os.path.isfile(pedestrian_file_name):
220 | flag_ped = 0
221 | else:
222 | flag_ped = 1
223 | if error_string != "":
224 | error_string += "Type --help for help."
225 | raise IOError(error_string)
226 |
227 | # Load and draw the lanelet2 map, either with or without the lanelet2 library.
228 | lat_origin = 0. # Origin is necessary to correctly project the lat lon values in the osm file to the local.
229 | lon_origin = 0.
230 | print("Loading map...")
231 | fig, axes = plt.subplots(1, 1)
232 | min_max_xy = map_vis_without_lanelet.draw_map_without_lanelet(lanelet_map_file, axes, lat_origin, lon_origin)
233 |
234 | # Expand map size.
235 | min_max_xy[0:2] -= 25.0
236 | min_max_xy[2:] += 25.0
237 |
238 | # Load the tracks.
239 | print("Loading tracks...")
240 | track_dictionary = None
241 | pedestrian_dictionary = None
242 | if args.load_mode == 'both':
243 | track_dictionary = dataset_reader.read_tracks(track_file_name)
244 | if flag_ped:
245 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name)
246 |
247 | elif args.load_mode == 'vehicle':
248 | track_dictionary = dataset_reader.read_tracks(track_file_name)
249 | elif args.load_mode == 'pedestrian':
250 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name)
251 |
252 | timestamp_min = 1e9
253 | timestamp_max = 0
254 |
255 | # Set the time to the run and get the sensor vehicles.
256 | ego_id = data[1][1]
257 | car_ids = {}
258 | sensor_grids = {}
259 | label_grids = {}
260 | id_grids = {}
261 | for i in range(1,len(data)):
262 | timestamp_min = min(timestamp_min, data[i][0])
263 | timestamp_max = max(timestamp_max, data[i][0])
264 |
265 | if data[i][0] in car_ids.keys():
266 | car_ids[data[i][0]].append(data[i][1])
267 | else:
268 | car_ids[data[i][0]] = [data[i][1]]
269 |
270 | if data[i][1] == ego_id:
271 | sensor_grids[data[i][0]] = data[i][-2]
272 | id_grids[data[i][0]] = data[i][-3]
273 | label_grids[data[i][0]] = data[i][-4]
274 |
275 | args.start_timestamp = timestamp_min
276 |
277 | # Results.
278 | results = dict()
279 | results['ego_sensor'] = []
280 | results['ego_label'] = []
281 | results['vae'] = []
282 | results['timestamp'] = []
283 | results['source'] = []
284 | results['run_time'] = []
285 | results['all_latent_classes'] = []
286 | results['ref_local_xy'] = []
287 | results['ego_local_xy'] = []
288 | results['alpha_p'] = []
289 | results['ego_sensor_dst'] = []
290 | results['endpoint'] = []
291 | results['res'] = []
292 | mode = args.mode
293 | model = args.model
294 |
295 | print("Saving...")
296 | timestamp = args.start_timestamp
297 |
298 | while timestamp < timestamp_max:
299 | update_plot()
300 | timestamp += dataset_types.DELTA_TIMESTAMP_MS
301 |
302 | # Clear all variables for the next scenario
303 | del(timestamp_min)
304 | del(timestamp_max)
305 | del(timestamp)
306 |
307 | for i in range(len(results['timestamp'])):
308 | results['source'].append(ego_file[:-4])
309 |
310 | # Save the data for each run.
311 | folder_model_new = os.path.join(folder_model, 'full_pipeline_' + model + '_' + mode)
312 | if not os.path.isdir(folder_model_new):
313 | os.mkdir(folder_model_new)
314 | pkl.dump(results, open(os.path.join(folder_model_new, ego_file[:-4] + '_ego_results.pkl'), "wb"))
315 |
316 |
--------------------------------------------------------------------------------
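Going by the argparse setup and the os.chdir("../..") at the top, main_save_full_pipeline.py appears intended to be launched from src/full_pipeline, e.g. python main_save_full_pipeline.py DR_USA_Intersection_GL 0 both --mode average --model vae; this invocation is only illustrative, and the scripts/ entry points (run_full_pipeline.sh) presumably wrap it. Note that scenario_name and track_file_number are effectively overridden by hardcoded values inside the script, while load_mode, --mode, and --model do determine which tracks are read, how the driver sensor grids are fused, and which model's output directory receives the per-run results pickles.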
/src/full_pipeline/main_visualize_full_pipeline.py:
--------------------------------------------------------------------------------
1 | # Code to visualize full occlusion inference pipeline. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | try:
4 | import lanelet2
5 | use_lanelet2_lib = True
6 | except:
7 | import warnings
8 | string = "Could not import lanelet2. It must be built and sourced, " + \
9 | "see https://github.com/fzi-forschungszentrum-informatik/Lanelet2 for details."
10 | warnings.warn(string)
11 | print("Using visualization without lanelet2.")
12 | use_lanelet2_lib = False
13 | from utils import map_vis_without_lanelet
14 |
15 | import argparse
16 | import os
17 | import time
18 | import matplotlib.pyplot as plt
19 | from matplotlib.widgets import Button
20 | import hickle as hkl
21 | import pickle as pkl
22 | import pdb
23 | import numpy as np
24 | import csv
25 |
26 | seed = 123
27 |
28 | import numpy as np
29 | np.random.seed(seed)
30 | from matplotlib import pyplot as plt
31 | import torch
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | torch.backends.cudnn.deterministic = True
35 | torch.backends.cudnn.benchmark = False
36 | import torch.nn as nn
37 | torch.autograd.set_detect_anomaly(True)
38 |
39 | import io
40 | import PIL.Image
41 | from tqdm import tqdm
42 | import time
43 |
44 | import argparse
45 | import pandas as pd
46 | import seaborn as sns
47 | import matplotlib.pyplot as plt
48 | from collections import OrderedDict, defaultdict
49 |
50 | from src.utils import dataset_reader
51 | from src.utils import dataset_types
52 | if use_lanelet2_lib: from src.utils import map_vis_lanelet2  # only needed when the lanelet2 library is available
53 | from src.utils import tracks_vis
54 | # from src.utils import dict_utils
55 | from src.driver_sensor_model.models_cvae import VAE
56 | from src.utils.interaction_utils import *
57 |
58 | import torch._utils
59 | try:
60 | torch._utils._rebuild_tensor_v2
61 | except AttributeError:
62 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
63 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
64 | tensor.requires_grad = requires_grad
65 | tensor._backward_hooks = backward_hooks
66 | return tensor
67 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
68 |
69 |
70 | def update_plot():
71 | global fig, timestamp, title_text, track_dictionary, patches_dict, text_dict, axes, pedestrian_dictionary,\
72 | data, car_ids, sensor_grids, id_grids, label_grids, grids_dict, driver_sensor_data, driver_sensor_state_data, driver_sensor_state, driver_sensor_state_dict,\
73 | min_max_xy, models, results, mode, model
74 |
75 | # Update text and tracks based on current timestamp.
76 | if (timestamp < timestamp_min):
77 | pdb.set_trace()
78 | assert(timestamp <= timestamp_max), "timestamp=%i" % timestamp
79 | assert(timestamp >= timestamp_min), "timestamp=%i" % timestamp
80 | assert(timestamp % dataset_types.DELTA_TIMESTAMP_MS == 0), "timestamp=%i" % timestamp
81 | title_text.set_text("\nts = {}".format(timestamp))
82 |     tracks_vis.update_objects_plot(timestamp, patches_dict, text_dict, axes, track_dict=track_dictionary, pedest_dict=pedestrian_dictionary,
83 | data=data, car_ids=car_ids, sensor_grids=sensor_grids, id_grids=id_grids, label_grids=label_grids, grids_dict=grids_dict,
84 | driver_sensor_data = driver_sensor_data, driver_sensor_state_data=driver_sensor_state_data, driver_sensor_state = driver_sensor_state, driver_sensor_state_dict=driver_sensor_state_dict,
85 | endpoint=min_max_xy, models=models, mode=mode, model=model)
86 |
87 | fig.canvas.draw()
88 |
89 |
90 | def start_playback():
91 | global timestamp, timestamp_min, timestamp_max, playback_stopped
92 | playback_stopped = False
93 | plt.ion()
94 | while timestamp < timestamp_max and not playback_stopped:
95 | timestamp += dataset_types.DELTA_TIMESTAMP_MS
96 | start_time = time.time()
97 | update_plot()
98 | end_time = time.time()
99 | diff_time = end_time - start_time
100 | # plt.pause(max(0.001, dataset_types.DELTA_TIMESTAMP_MS / 1000. - diff_time))
101 | plt.ioff()
102 |
103 |
104 | class FrameControlButton(object):
105 | def __init__(self, position, label):
106 | self.ax = plt.axes(position)
107 | self.label = label
108 | self.button = Button(self.ax, label)
109 | self.button.on_clicked(self.on_click)
110 |
111 | def on_click(self, event):
112 | global timestamp, timestamp_min, timestamp_max, playback_stopped
113 |
114 | if self.label == "play":
115 | if not playback_stopped:
116 | return
117 | else:
118 | start_playback()
119 | return
120 | playback_stopped = True
121 | if self.label == "<<":
122 | timestamp -= 10*dataset_types.DELTA_TIMESTAMP_MS
123 | elif self.label == "<":
124 | timestamp -= dataset_types.DELTA_TIMESTAMP_MS
125 | elif self.label == ">":
126 | timestamp += dataset_types.DELTA_TIMESTAMP_MS
127 | elif self.label == ">>":
128 | timestamp += 10*dataset_types.DELTA_TIMESTAMP_MS
129 | timestamp = min(timestamp, timestamp_max)
130 | timestamp = max(timestamp, timestamp_min)
131 | update_plot()
132 |
133 | if __name__ == "__main__":
134 |
135 | # Provide data to be visualized.
136 | parser = argparse.ArgumentParser()
137 | parser.add_argument("scenario_name", type=str, help="Name of the scenario (to identify map and folder for track "
138 | "files)", nargs="?")
139 | parser.add_argument("track_file_number", type=int, help="Number of the track file (int)", default=0, nargs="?")
140 | parser.add_argument("load_mode", type=str, help="Dataset to load (vehicle, pedestrian, or both)", default="both",
141 | nargs="?")
142 | parser.add_argument("--start_timestamp", type=int, nargs="?")
143 | parser.add_argument("--mode", type=str, default='evidential', help="Sensor fusion mode: evidential or average", nargs="?")
144 | parser.add_argument("--model", type=str, default='vae', help="Name of the model: vae, gmm, kmeans", nargs="?")
145 | args = parser.parse_args()
146 |
147 | if args.load_mode != "vehicle" and args.load_mode != "pedestrian" and args.load_mode != "both":
148 | raise IOError("Invalid load command. Use 'vehicle', 'pedestrian', or 'both'")
149 |
150 | # Load test folder.
151 | error_string = ""
152 | tracks_dir = "/data/INTERACTION-Dataset-DR-v1_1/recorded_trackfiles"
153 | maps_dir = "/data/INTERACTION-Dataset-DR-v1_1/maps"
154 | scenario_name = 'DR_USA_Intersection_GL'
155 | home = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/"
156 | test_set_dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split'
157 | test_set_file = 'ego_test_set.csv'
158 |
159 | ego_files = []
160 | with open(os.path.join(test_set_dir, test_set_file)) as csv_file:
161 | csv_reader = csv.reader(csv_file, delimiter=',')
162 | for row in csv_reader:
163 | ego_files.append(row[0].split(home)[-1])
164 |
165 | number = 0
166 | ego_file = ego_files[number]
167 |
168 | # Load the models.
169 | models = dict()
170 | folder_vae = '/model/cvae/'
171 | name_vae = 'lstm_1_Adam_z_100_lr_0.001_rand_123_norm_True_kl_start_0_finish_1.0_center_10000.0_mutual_info_const_alpha_1.5_epochs_30_batch_256'
172 |
173 | models['vae'] = VAE(
174 | encoder_layer_sizes_p=[7, 5],
175 | n_lstms=1,
176 | latent_size=100,
177 | dim=4
178 | )
179 |
180 | models['vae'] = models['vae'].cuda()
181 |
182 | save_filename = os.path.join(folder_vae, name_vae) + 'epoch_30_vae.pt'
183 |
184 | with open(save_filename, 'rb') as f:
185 | state_dict = torch.load(f)
186 | models['vae'].load_state_dict(state_dict)
187 | models['vae'].eval()
188 |
189 | folder_kmeans = '/model/kmeans/'
190 | models['kmeans'] = pkl.load(open(os.path.join(folder_kmeans,"clusters_kmeans_K_100.pkl"), "rb" ) )
191 |
192 | folder_gmm = '/model/gmm/'
193 | models['gmm'] = pkl.load(open(os.path.join(folder_gmm, "GMM_K_100_reg_covar_0_001.pkl"), "rb" ) )
194 |
195 | if args.model == 'vae':
196 | folder_model = folder_vae
197 | elif args.model == 'gmm':
198 | folder_model = folder_gmm
199 | elif args.model == 'kmeans':
200 | folder_model = folder_kmeans
201 |
202 | for ego_file in tqdm(ego_files):
203 | print(ego_file)
204 |
205 | # Visualize the paper and appendix scenarios.
206 | if (ego_file != 'DR_USA_Intersection_GL_037_run_54_ego_vehicle_83.pkl') and (ego_file != 'DR_USA_Intersection_GL_021_run_60_ego_vehicle_108.pkl'):
207 | continue
208 |
209 | ego_run = int(ego_file.split('_run_')[-1].split('_ego_')[0])
210 | if scenario_name[-2:] == 'VA':
211 | track_file_number = int(ego_file.split(scenario_name)[-1][1:4])
212 | middle = scenario_name + "/" + scenario_name + "_00" + str(track_file_number)
213 | if os.path.exists(os.path.join(home, middle, ego_file)):
214 | data = hkl.load(os.path.join(home, middle, ego_file))
215 | else:
216 | continue
217 | elif scenario_name[-2:] == 'GL':
218 | middle = scenario_name
219 | track_file_number = ego_file.split("_run_")[0][-3:]
220 | if os.path.exists(os.path.join(home, ego_file)):
221 | data = pkl.load(open(os.path.join(home, ego_file), 'rb'))
222 | else:
223 | continue
224 |
225 | # Get the driver sensor data.
226 | driver_sensor_data = dict()
227 | driver_sensor_state_data = dict()
228 | driver_sensor_state = dict()
229 | for file in os.listdir(os.path.join(home)):
230 | if scenario_name[-2:] == 'VA':
231 | file_split = file.split(scenario_name + "_00" + str(track_file_number) + "_run_")
232 | [run, vtype, _, vid] = file_split[-1][:-4].split('_')
233 | elif scenario_name[-2:] == 'GL':
234 | file_split = file.split("_run_")
235 | if track_file_number != file_split[0][-3:]:
236 | continue
237 | [run, vtype, _, vid] = file_split[-1][:-4].split('_')
238 | run = int(run)
239 | vid = int(vid)
240 |
241 | if (run == ego_run) and (vtype == 'ref'):
242 | driver_sensor_data[vid] = pkl.load(open(os.path.join(home, file), 'rb'))
243 | for item in driver_sensor_data[vid][1:]:
244 | timestamp = item[0]
245 | x = item[2]
246 | y = item[3]
247 | orientation = item[4]
248 | vx = item[5]
249 | vy = item[6]
250 | ax = item[7]
251 | ay = item[8]
252 | if vid in driver_sensor_state_data.keys():
253 | driver_sensor_state_data[vid].append(np.array([timestamp, x, y, orientation, vx, vy, ax, ay]))
254 | else:
255 | driver_sensor_state_data[vid] = [np.array([timestamp, x, y, orientation, vx, vy, ax, ay])]
256 |
257 | lanelet_map_ending = ".osm"
258 | lanelet_map_file = maps_dir + "/" + scenario_name + lanelet_map_ending
259 | scenario_dir = tracks_dir + "/" + scenario_name
260 | track_file_prefix = "vehicle_tracks_"
261 | track_file_ending = ".csv"
262 | track_file_name = scenario_dir + "/" + track_file_prefix + str(track_file_number).zfill(3) + track_file_ending
263 | pedestrian_file_prefix = "pedestrian_tracks_"
264 | pedestrian_file_ending = ".csv"
265 | pedestrian_file_name = scenario_dir + "/" + pedestrian_file_prefix + str(track_file_number).zfill(3) + pedestrian_file_ending
266 | if not os.path.isdir(tracks_dir):
267 | error_string += "Did not find track file directory \"" + tracks_dir + "\"\n"
268 | if not os.path.isdir(maps_dir):
269 |             error_string += "Did not find map file directory \"" + maps_dir + "\"\n"
270 | if not os.path.isdir(scenario_dir):
271 | error_string += "Did not find scenario directory \"" + scenario_dir + "\"\n"
272 | if not os.path.isfile(lanelet_map_file):
273 | error_string += "Did not find lanelet map file \"" + lanelet_map_file + "\"\n"
274 | if not os.path.isfile(track_file_name):
275 | error_string += "Did not find track file \"" + track_file_name + "\"\n"
276 | if not os.path.isfile(pedestrian_file_name):
277 | flag_ped = 0
278 | else:
279 | flag_ped = 1
280 | if error_string != "":
281 | error_string += "Type --help for help."
282 | raise IOError(error_string)
283 |
284 | # Create a figure.
285 | fig, axes = plt.subplots(1, 1)
286 | fig.canvas.set_window_title("Interaction Dataset Visualization")
287 |
288 | # Load and draw the lanelet2 map, either with or without the lanelet2 library.
289 | lat_origin = 0. # Origin is necessary to correctly project the lat lon values in the osm file to the local.
290 | lon_origin = 0.
291 | print("Loading map...")
292 | if use_lanelet2_lib:
293 | projector = lanelet2.projection.UtmProjector(lanelet2.io.Origin(lat_origin, lon_origin))
294 | laneletmap = lanelet2.io.load(lanelet_map_file, projector)
295 | map_vis_lanelet2.draw_lanelet_map(laneletmap, axes)
296 | else:
297 | min_max_xy = map_vis_without_lanelet.draw_map_without_lanelet(lanelet_map_file, axes, lat_origin, lon_origin)
298 |
299 | # Expand map size.
300 | min_max_xy[0:2] -= 25.0
301 | min_max_xy[2:] += 25.0
302 |
303 | # Load the tracks.
304 | print("Loading tracks...")
305 | track_dictionary = None
306 | pedestrian_dictionary = None
307 | if args.load_mode == 'both':
308 | track_dictionary = dataset_reader.read_tracks(track_file_name)
309 | if flag_ped:
310 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name)
311 |
312 | elif args.load_mode == 'vehicle':
313 | track_dictionary = dataset_reader.read_tracks(track_file_name)
314 | elif args.load_mode == 'pedestrian':
315 | pedestrian_dictionary = dataset_reader.read_pedestrian(pedestrian_file_name)
316 |
317 | if (pedestrian_dictionary == None) or len(pedestrian_dictionary.items()) == 0:
318 | continue
319 |
320 | timestamp_min = 1e9
321 | timestamp_max = 0
322 |
323 | # Set the time to the run and get the sensor vehicles
324 | ego_id = data[1][1]
325 | car_ids = {}
326 | sensor_grids = {}
327 | label_grids = {}
328 | id_grids = {}
329 | for i in range(1,len(data)):
330 | timestamp_min = min(timestamp_min, data[i][0])
331 | timestamp_max = max(timestamp_max, data[i][0])
332 |
333 | if data[i][0] in car_ids.keys():
334 | car_ids[data[i][0]].append(data[i][1])
335 | else:
336 | car_ids[data[i][0]] = [data[i][1]]
337 |
338 | if data[i][1] == ego_id:
339 | sensor_grids[data[i][0]] = data[i][-2]
340 | id_grids[data[i][0]] = data[i][-3]
341 | label_grids[data[i][0]] = data[i][-4]
342 |
343 | if ego_file == 'DR_USA_Intersection_GL_037_run_54_ego_vehicle_83.pkl':
344 | args.start_timestamp = 117500
345 | elif ego_file == 'DR_USA_Intersection_GL_021_run_60_ego_vehicle_108.pkl':
346 | args.start_timestamp = 168700
347 | else:
348 | args.start_timestamp = timestamp_min
349 |
350 | mode = args.mode
351 | model = args.model
352 |
353 | button_pp = FrameControlButton([0.2, 0.05, 0.05, 0.05], '<<')
354 | button_p = FrameControlButton([0.27, 0.05, 0.05, 0.05], '<')
355 | button_f = FrameControlButton([0.4, 0.05, 0.05, 0.05], '>')
356 | button_ff = FrameControlButton([0.47, 0.05, 0.05, 0.05], '>>')
357 |
358 | button_play = FrameControlButton([0.6, 0.05, 0.1, 0.05], 'play')
359 | button_pause = FrameControlButton([0.71, 0.05, 0.1, 0.05], 'pause')
360 |
361 | # Storage for track visualization.
362 | patches_dict = dict()
363 | text_dict = dict()
364 | grids_dict = dict()
365 | driver_sensor_state_dict = dict()
366 |
367 | # Visualize tracks.
368 | print("Plotting...")
369 | timestamp = args.start_timestamp
370 | title_text = fig.suptitle("")
371 | playback_stopped = True
372 | update_plot()
373 | plt.show()
374 |
375 | # Clear all variables for the next scenario.
376 | plt.clf()
377 | plt.close()
378 | del(timestamp_min)
379 | del(timestamp_max)
380 | del(timestamp)
381 |
382 |
--------------------------------------------------------------------------------
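main_visualize_full_pipeline.py takes the same arguments but, unlike the save script, gives --mode and --model defaults ('evidential' and 'vae') and adds a --start_timestamp flag whose value is overwritten inside the loop. During playback it skips every ego file except the two hardcoded paper/appendix scenarios and any run without pedestrian tracks, so most test runs are passed over rather than rendered.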
/src/preprocess/generate_data.py:
--------------------------------------------------------------------------------
1 | # INTERACTION dataset processing code. Some code snippets are adapted from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | import argparse
4 | try:
5 | import lanelet2
6 | use_lanelet2_lib = True
7 | except:
8 | import warnings
9 | string = "Could not import lanelet2. It must be built and sourced, " + \
10 | "see https://github.com/fzi-forschungszentrum-informatik/Lanelet2 for details."
11 | warnings.warn(string)
12 | print("Using visualization without lanelet2.")
13 | use_lanelet2_lib = False
14 | from utils import map_vis_without_lanelet
15 |
16 | from utils import dataset_reader
17 | from utils import dataset_types
18 | # from utils import dict_utils
19 | from utils import map_vis_lanelet2
20 | from utils.grid_utils import SceneObjects, global_grid, AllObjects, generateLabelGrid, generateSensorGrid
21 | from utils.dataset_types import Track, MotionState
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | import csv
25 | from collections import defaultdict
26 | import os
27 | import pickle as pkl
28 | from datetime import datetime
29 | import glob
30 | import random
31 | from tqdm import tqdm
32 | import time
33 | np.random.seed(123)
34 |
35 |
36 | def getstate(timestamp, track_dict, id):
37 | for key, value in track_dict.items():
38 | if key==id:
39 | return value.motion_states[timestamp]
40 |
41 | def getmap(maps_dir, scenario_data):
42 | # load and draw the lanelet2 map, either with or without the lanelet2 library
43 | fig, axes = plt.subplots(1)
44 | lat_origin = 0. # origin is necessary to correctly project the lat lon values in the osm file to the local
45 | lon_origin = 0. # coordinates in which the tracks are provided; we decided to use (0|0) for every scenario
46 | lanelet_map_ending = ".osm"
47 | lanelet_map_file = maps_dir + "/" + scenario_data + lanelet_map_ending
48 | # print("Loading map...")
49 | if use_lanelet2_lib:
50 | projector = lanelet2.projection.UtmProjector(lanelet2.io.Origin(lat_origin, lon_origin))
51 | laneletmap = lanelet2.io.load(lanelet_map_file, projector)
52 | map_vis_lanelet2.draw_lanelet_map(laneletmap, axes)
53 | else:
54 | min_max_xy = map_vis_without_lanelet.draw_map_without_lanelet(lanelet_map_file, axes, lat_origin, lon_origin)
55 | plt.close(fig)
56 | return min_max_xy
57 |
58 | # Get data for the observed drivers.
59 | def vis_state(vis_datas, ref_datas, object_id, label_grid_ego, prev_id_grid_ego, sensor_grid_ego, prev_sensor_grid_ego, track_dict, stamp, start_timestamp, track_pedes_dict=None, pedes_id=None): # For sensor vis & vis vis. Not for ego vis
60 | global vis_ax, vis_ay, vis_ids, ego_id, num_vis, res, gridglobal_x, gridglobal_y
61 |
62 | # Create a mask for where there is occupied space in the sensor grid that is not the ego vehicle.
63 | mask = np.where(np.logical_and(sensor_grid_ego==1., label_grid_ego[3]!=ego_id), True, False)
64 |
65 | # Get the visible car ids around the ego vehicle.
66 | vis_temp = np.array(np.unique(label_grid_ego[3,mask]),dtype=int)
67 |
68 | # Remove the pedestrians from the labels.
69 | vis_temp = vis_temp[vis_temp >= 0]
70 |
71 | # Create a mask for the previous timestamp sensor grid where there is occupied space that is not the ego vehicle.
72 | mask = np.where(np.logical_and(prev_sensor_grid_ego==1., prev_id_grid_ego!=ego_id), True, False)
73 | prev_vis_temp = np.array(np.unique(prev_id_grid_ego[mask]),dtype=int)
74 | prev_vis_temp = prev_vis_temp[prev_vis_temp >= 0]
75 |
76 | # Compute the number of new visible ids at this timestamp.
77 | num_new = 0
78 | # Loop through the visible ids at the current timestamp and add them to vis_ids.
79 | for vis_t in vis_temp:
80 | if vis_t not in vis_ids:
81 | # Add new visible vehicle to this vis_ids array.
82 | vis_ids.append(vis_t)
83 | num_new += 1
84 |
85 | # All visible cars till the current timestamp.
86 | num_vis = len(vis_ids)
87 | vis_ms = []
88 | prev_vis_ms = []
89 | widths = []
90 | lengths = []
91 |
92 | for id in vis_ids:
93 | # Check if the ID is present in the current and previous timestamp. This is important for acceleration computations.
94 | if id in vis_temp and id in prev_vis_temp:
95 | vis_ms.append(getstate(stamp, track_dict, id))
96 | widths.append(track_dict[id].width)
97 | lengths.append(track_dict[id].length)
98 | prev_vis_ms.append(getstate(stamp-100, track_dict, id))
99 | else:
100 | # If the ID was not present in a previous timestamp, set to None.
101 | vis_ms.append(None)
102 | prev_vis_ms.append(None)
103 | widths.append(None)
104 | lengths.append(None)
105 |
106 | # If there are new visible cars, extend the lists for accelerations and data.
107 | if num_new != 0:
108 | vis_ax = vis_ax+ [0]*num_new
109 | vis_ay = vis_ay+ [0]*num_new
110 | vis_datas = vis_datas + [[]]*num_new
111 | ref_datas = ref_datas + [[]]*num_new
112 |
113 | # Compute the accelerations for all the visible vehicles.
114 | for k in range(num_vis):
115 |         # Only compute the acceleration if there are two timestamps of data for the vehicle.
116 | if vis_ms[k] is not None:
117 | vis_ax[k] = (vis_ms[k].vx - prev_vis_ms[k].vx) / 0.1
118 | vis_ay[k] = (vis_ms[k].vy - prev_vis_ms[k].vy) / 0.1
119 |
120 | # Loop through all the visible vehicles.
121 | for k in range(num_vis):
122 | # Check if the vehicle is visible in at least two timestamps.
123 | if vis_ids[k] in vis_temp and vis_ids[k] in prev_vis_temp :
124 |
125 | # Generate label and sensor grids for the observed driver.
126 | label_grid_ref, center_ref, _, _, pre_local_x_ref, pre_local_y_ref = generateLabelGrid(stamp, track_dict, vis_ids[k], object_id, ego_flag=False, res=res, track_pedes_dict=track_pedes_dict, pedes_id=pedes_id)
127 | width = widths[k]
128 | length = lengths[k]
129 | sensor_grid_ref, _, _ = generateSensorGrid(label_grid_ref, pre_local_x_ref, pre_local_y_ref, vis_ms[k], width, length, res=res, ego_flag=False)
130 |
131 | # Mark the ego car as free.
132 | sensor_grid_ref[label_grid_ref[0] == 2] = 0
133 | label_grid_ref[0] = np.where(label_grid_ref[0]==2., 0. ,label_grid_ref[0])
134 |
135 | # Store data for visible vehicles for the ego vehicle and the observed driver.
136 | vis_step_data = [stamp, vis_ids[k], vis_ms[k].x, vis_ms[k].y, vis_ms[k].psi_rad, vis_ms[k].vx, vis_ms[k].vy, vis_ax[k], vis_ay[k], label_grid_ref[0], label_grid_ref[3], sensor_grid_ref, np.nan]
137 | ref_step_data = [stamp, vis_ids[k], vis_ms[k].x, vis_ms[k].y, vis_ms[k].psi_rad, vis_ms[k].vx, vis_ms[k].vy, vis_ax[k], vis_ay[k], label_grid_ref[0], sensor_grid_ref]
138 | ref_datas[k] = ref_datas[k]+[ref_step_data]
139 | vis_datas[k] = vis_datas[k]+[vis_step_data]
140 |
141 | # Note ref_datas may include unseen vehicles (or vehicles seen for just one timestamp).
142 | return vis_datas, ref_datas
143 |
144 | def Dataprocessing():
145 | global vis_ids, vis_ax, vis_ay, ref_ax, ref_ay, ego_id, res, gridglobal_x, gridglobal_y
146 |
147 | main_folder = '/data/INTERACTION-Dataset-DR-v1_1/'
148 | scenarios = ['DR_USA_Intersection_GL']
149 |
150 | # Total number of track files.
151 | num_files = [60]
152 | for scene, nth_scene in zip(scenarios, num_files):
153 | for i in tqdm(range(nth_scene)):
154 |
155 | i_str = ['%03d' % i][0]
156 | filename = os.path.join(main_folder, 'recorded_trackfiles/'+scene+'/'+'vehicle_tracks_'+ i_str +'.csv')
157 | track_dict = dataset_reader.read_tracks(filename)
158 | filename_pedes = os.path.join(main_folder, 'recorded_trackfiles/'+scene+'/'+'pedestrian_tracks_'+ i_str +'.csv')
159 |
160 | if os.path.exists(filename_pedes):
161 | track_pedes_dict = dataset_reader.read_pedestrian(filename_pedes)
162 | else:
163 | track_pedes_dict = None
164 |
165 | run_count = 0
166 |
167 | maps_dir = os.path.join(main_folder, "maps")
168 | min_max_xy = getmap(maps_dir, scene)
169 | xminGPS = min_max_xy[0]
170 | xmaxGPS = min_max_xy[2]
171 | yminGPS = min_max_xy[1]
172 | ymaxGPS = min_max_xy[3]
173 |
174 | res = 1.
175 | gridglobal_x,gridglobal_y = global_grid(np.array([xminGPS,yminGPS]),np.array([xmaxGPS,ymaxGPS]),res)
176 |
177 | vehobjects, pedesobjects = AllObjects(track_dict, track_pedes_dict)
178 |
179 |             processed_file = glob.glob(os.path.join(main_folder, 'Processed_data_new_goal/pkl/DR_USA_Intersection_GL_'+i_str+'*_ego_*'))
180 | processed_id = [ int(file.split('_')[-1][:-4]) for file in processed_file]
181 |
182 | num = 0
183 | sampled_key = [id for id in vehobjects if id not in processed_id][num:]
184 | run_count = num
185 | sampled_key = np.random.choice(sampled_key,np.minimum(100, len(sampled_key)),replace=False)
186 |
187 | for key, value in track_dict.items():
188 | assert isinstance(value, Track)
189 | if key in sampled_key:
190 |
191 | start_time = datetime.now()
192 | ego_data = [['timestamp', 'car_id', 'x', 'y', 'orientation', 'vx','vy', 'ax', 'ay', 'label_grid', 'id_grid', 'sensor_grid', 'occluded_id']]
193 | ref_data = [['timestamp', 'car_id', 'x', 'y', 'orientation', 'vx','vy', 'ax', 'ay', 'label_grid', 'sensor_grid']]
194 |
195 | ego_id = int(key)
196 | vis_ids = []
197 |
198 | start_timestamp = value.time_stamp_ms_first
199 | last_timestamp = value.time_stamp_ms_last
200 |
201 | ref_datas = []
202 | vis_datas = []
203 |
204 | vis_ax = []
205 | vis_ay = []
206 |
207 | # Get accelerations.
208 | for stamp in range(start_timestamp, last_timestamp, 100):
209 | object_id, pedes_id = SceneObjects(track_dict, stamp, track_pedes_dict)
210 | ego_ms = getstate(stamp, track_dict, ego_id)
211 |
212 | if stamp == start_timestamp :
213 | prev_ego_vx = ego_ms.vx
214 | prev_ego_vy = ego_ms.vy
215 |
216 | else:
217 | ego_ax = (ego_ms.vx - prev_ego_vx) / 0.1
218 | ego_ay = (ego_ms.vy - prev_ego_vy) / 0.1
219 |
220 | # Get label grid.
221 | label_grid_ego, center_ego, local_x_ego, local_y_ego, pre_local_x_ego, pre_local_y_ego = generateLabelGrid(stamp, track_dict, ego_id, object_id, ego_flag=True, res=res, track_pedes_dict=track_pedes_dict, pedes_id=pedes_id)
222 |
223 | # Get sensor grid.
224 | width = value.width
225 | length = value.length
226 | sensor_grid_ego, occluded_id, visible_id = generateSensorGrid(label_grid_ego, pre_local_x_ego, pre_local_y_ego, ego_ms, width, length, res=res, ego_flag=True)
227 |
228 | # Convert the ego grid cells to occupied.
229 | label_grid_ego[0] = np.where(label_grid_ego[0]==2., 1. ,label_grid_ego[0])
230 |
231 | # Ignore the first timestamp because we do not have acceleration data.
232 | if stamp != start_timestamp :
233 |
234 | # Save ego data.
235 | ego_step_data = [stamp, ego_id, ego_ms.x, ego_ms.y, ego_ms.psi_rad, ego_ms.vx, ego_ms.vy, ego_ax, ego_ay, label_grid_ego[0], label_grid_ego[3], sensor_grid_ego, occluded_id]
236 | ego_data.append(ego_step_data)
237 |
238 | vis_datas, ref_datas = vis_state(vis_datas, ref_datas, object_id, label_grid_ego, prev_id_grid_ego, sensor_grid_ego, prev_sensor_grid_ego, track_dict, stamp, start_timestamp, track_pedes_dict=track_pedes_dict, pedes_id=pedes_id)
239 |
240 | # Get the previous time stamp information.
241 | prev_id_grid_ego = label_grid_ego[3]
242 | prev_sensor_grid_ego = sensor_grid_ego
243 | prev_ego_vx = ego_ms.vx
244 | prev_ego_vy = ego_ms.vy
245 |
246 | if vis_ids is not None:
247 | ego_data = ego_data + sum(vis_datas,[])
248 |
249 | # Save the ego information.
250 | pkl_path = os.path.join(main_folder, 'processed_data/pkl/')
251 | if not os.path.exists(pkl_path):
252 | os.makedirs(pkl_path)
253 | ego_filename = scene+'_'+i_str+'_run_'+str(run_count)+'_ego_vehicle_'+str(key)
254 |                     pkl.dump(ego_data, open(pkl_path + ego_filename+'.pkl', 'wb'))
255 |
256 | # Save only the visible reference drivers.
257 | for k in range(num_vis):
258 | ref_filename = scene+'_'+i_str+'_run_'+str(run_count)+'_ref_vehicle_'+str(vis_ids[k])
259 | ref_d = ref_data + ref_datas[k]
260 | pkl.dump(ref_d, open(pkl_path +ref_filename+'.pkl', 'wb'))
261 |
262 | run_count += 1
263 | end_time = datetime.now()
264 | print(ego_filename, ', execution time:', end_time - start_time)
265 |
266 | if __name__ == "__main__":
267 | Dataprocessing()
268 |
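
For reference, the accelerations computed above are simple finite differences of the recorded velocities over the 100 ms frame spacing. A minimal illustrative sketch (the helper name finite_difference_accel is not part of the repository):

def finite_difference_accel(curr_ms, prev_ms, dt=0.1):
    # Approximate (ax, ay) in m/s^2 from two consecutive MotionState samples spaced dt seconds apart.
    ax = (curr_ms.vx - prev_ms.vx) / dt
    ay = (curr_ms.vy - prev_ms.vy) / dt
    return ax, ay
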
--------------------------------------------------------------------------------
/src/preprocess/get_driver_sensor_data.py:
--------------------------------------------------------------------------------
1 | # Code to form the state information for the driver sensor data and split into contiguous segments.
2 |
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 | import hickle as hkl
6 | import pickle as pkl
7 | import pdb
8 | import os
9 | import fnmatch
10 | import csv
11 | from tqdm import tqdm
12 |
13 | dir_train_set = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/'
14 | dir = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/'
15 | for split in ['train', 'val', 'test']:
16 |
17 | train_set = os.path.join(dir_train_set, 'ref_' + split + '_set.csv')
18 | train_set_files = []
19 |
20 | with open(train_set,'rt') as f:
21 | data = csv.reader(f)
22 | for row in data:
23 | train_set_files.append(row[0])
24 |
25 | dir_pickle = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_' + split + '_data/'
26 |
27 | if not os.path.isdir(dir_pickle):
28 | os.mkdir(dir_pickle)
29 |
30 | grid_shape = (70,50)
31 |
32 | # Set-up count for number of files processed.
33 | count = 0
34 |
35 | for file in tqdm(os.listdir(dir)):
36 | file_path = os.path.join(dir, file)
37 |
38 | if file_path in train_set_files:
39 | count += 1
40 | f = open(file_path, 'rb')
41 | X = pkl.load(f)
42 | f.close()
43 |
44 | # Check if the data is empty. Skip if it is.
45 | if len(X) == 1:
46 | continue
47 |
48 | ts = []
49 | posx = []
50 | posy = []
51 | orientation = []
52 | velx = []
53 | vely = []
54 | accx = []
55 | accy = []
56 | grid_driver = []
57 | split_ids = []
58 |
59 | for i in range(1,len(X)):
60 | if ((i != 1) and (X[i][0] != (ts[-1] + 100))):
61 | split_ids.append(i-1)
62 | ts.append(X[i][0])
63 | posx.append(X[i][2])
64 | posy.append(X[i][3])
65 | orientation.append(X[i][4])
66 | velx.append(X[i][5])
67 | vely.append(X[i][6])
68 | accx.append(X[i][7])
69 | accy.append(X[i][8])
70 |
71 | # Reduce grid shape to (20,30).
72 | grid_driver.append(X[i][9][25:-25,:30])
73 |
74 | start = 0
75 | if len(split_ids) > 0:
76 | for i in range(len(split_ids)):
77 | if len(ts[start:split_ids[i]]) == 0:
78 |                     print("Empty segment encountered while splitting the trajectory.")
79 | pdb.set_trace()
80 | pkl.dump([np.array(grid_driver[start:split_ids[i]]),\
81 | np.array(ts[start:split_ids[i]]),\
82 | np.array(posx[start:split_ids[i]]),\
83 | np.array(posy[start:split_ids[i]]),\
84 | np.array(orientation[start:split_ids[i]]),\
85 | np.array(velx[start:split_ids[i]]),\
86 | np.array(vely[start:split_ids[i]]),\
87 | np.array(accx[start:split_ids[i]]),\
88 | np.array(accy[start:split_ids[i]])],\
89 | open(os.path.join(dir_pickle, file[0:-4] + '_' + chr(i + 65) + '.pkl'), 'wb'))
90 | start = split_ids[i]
91 |
92 | if len(ts[start:]) == 0:
93 |             print("Empty final segment encountered.")
94 | pdb.set_trace()
95 | ts = np.array(ts[start:])
96 | posx = np.array(posx[start:])
97 | posy = np.array(posy[start:])
98 | orientation = np.array(orientation[start:])
99 | velx = np.array(velx[start:])
100 | vely = np.array(vely[start:])
101 | accx = np.array(accx[start:])
102 | accy = np.array(accy[start:])
103 | grid_driver = np.array(grid_driver[start:])
104 |
105 | if len(split_ids) > 0:
106 | pkl.dump([grid_driver, ts, posx, posy, orientation, velx, vely, accx, accy], open(os.path.join(dir_pickle, file[0:-4] + '_' + chr(i + 1 + 65) + '.pkl'), 'wb'))
107 | else:
108 | pkl.dump([grid_driver, ts, posx, posy, orientation, velx, vely, accx, accy], open(os.path.join(dir_pickle, file[0:-4] + '.pkl'), 'wb'))
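
The segmentation above splits each trajectory wherever consecutive timestamps are not exactly 100 ms apart. A self-contained sketch of the same idea (the function name split_contiguous is illustrative, not part of the repository):

import numpy as np

def split_contiguous(timestamps, step_ms=100):
    # Return lists of indices for segments whose timestamps increase in exact step_ms increments.
    ts = np.asarray(timestamps)
    breaks = np.where(np.diff(ts) != step_ms)[0] + 1
    return [seg.tolist() for seg in np.split(np.arange(len(ts)), breaks)]

print(split_contiguous([0, 100, 200, 300, 500, 600, 800]))
# [[0, 1, 2, 3], [4, 5], [6]]
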
--------------------------------------------------------------------------------
/src/preprocess/preprocess_driver_sensor_data.py:
--------------------------------------------------------------------------------
1 | # Code to downsample the training set and to form numpy arrays for the driver sensor dataset with the associated unique sources.
2 |
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 | from scipy import signal
6 |
7 | import os
8 | import pdb
9 | import pickle as pkl
10 | import hickle as hkl
11 | from tqdm import tqdm
12 | import random
13 |
14 | random.seed(123)
15 | np.random.seed(seed=123)
16 |
17 | dir_pickle_train = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_train_data/'
18 | dir_pickle_val = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_val_data/'
19 | dir_pickle_test = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_label_grids_test_data/'
20 | dirs_pickle = [dir_pickle_train, dir_pickle_val, dir_pickle_test]
21 | dir_dataset = '/data/INTERACTION-Dataset-DR-v1_1/processed_data/driver_sensor_dataset/'
22 |
23 | if not os.path.isdir(dir_dataset):
24 | os.mkdir(dir_dataset)
25 |
26 | count_train = 0
27 | count_val = 0
28 | count_test = 0
29 |
30 | grids_train_list = []
31 | states_train_list = []
32 | sources_train_list = []
33 |
34 | grids_val_list = []
35 | states_val_list = []
36 | sources_val_list = []
37 |
38 | grids_test_list = []
39 | states_test_list = []
40 | sources_test_list = []
41 |
42 | for i in range(0, len(dirs_pickle)):
43 | dir_pickle = dirs_pickle[i]
44 |
45 | if i == 0:
46 | # Random shuffle the directories for the training set.
47 | dirs_list = sorted(os.listdir(dir_pickle))
48 | random.shuffle(dirs_list)
49 | else:
50 | dirs_list = sorted(os.listdir(dir_pickle))
51 |
52 | for file in tqdm(dirs_list):
53 | path = os.path.join(dir_pickle, file)
54 | [grid_driver, ts, posx, posy, orientation, velx, vely, accx, accy, goalx, goaly, tgoal] = pkl.load(open(path, 'rb'))
55 | state = np.vstack((posx,posy,orientation,velx,vely,accx,accy,goalx,goaly,tgoal)).T
56 |
57 | if grid_driver.shape == (0,):
58 | continue
59 |
60 | if i == 0:
61 |             # Downsample the training set and ensure that at least 10 time steps of data exist for every observed trajectory.
62 | if count_train <= 70000 and state.shape[0] >= 10:
63 | count_train += 1
64 |
65 | states_train_list.append(state.astype('float32'))
66 | grids_train_list.append(grid_driver.astype('float32'))
67 | sources_train_list.append(np.repeat(np.array(file[:-4]), state.shape[0]))
68 |
69 | if count_train % 1000 == 0:
70 | print(count_train)
71 |
72 | elif i == 1:
73 |             # Ensure that at least 10 time steps of data exist for every observed trajectory.
74 | if state.shape[0] >= 10:
75 |
76 | states_val_list.append(state.astype('float32'))
77 | grids_val_list.append(grid_driver.astype('float32'))
78 | sources_val_list.append(np.repeat(np.array(file[:-4]), state.shape[0]))
79 |
80 | count_val += 1
81 |
82 | elif i == 2:
83 |             # Ensure that at least 10 time steps of data exist for every observed trajectory.
84 | if state.shape[0] >= 10:
85 | states_test_list.append(state.astype('float32'))
86 | grids_test_list.append(grid_driver.astype('float32'))
87 | sources_test_list.append(np.repeat(np.array(file[:-4]), state.shape[0]))
88 |
89 | count_test += 1
90 |
91 | if i == 0:
92 | states_train = np.concatenate(states_train_list, axis=0)
93 | grids_train = np.concatenate(grids_train_list, axis=0)
94 | sources_train = np.concatenate(sources_train_list, axis=0)
95 |
96 | hkl.dump(states_train, os.path.join(dir_dataset, 'states_shuffled_train.hkl'), mode='w')
97 | hkl.dump(grids_train, os.path.join(dir_dataset, 'label_grids_shuffled_train.hkl'), mode='w')
98 | hkl.dump(sources_train, os.path.join(dir_dataset, 'sources_shuffled_train.hkl'), mode='w')
99 |
100 | # Get mean and std.
101 | mean = np.mean(states_train, axis=0)
102 | std = np.std(states_train, axis=0)
103 | print('mean: ', mean, 'std: ', std)
104 |
105 | del(states_train)
106 | del(grids_train)
107 | del(sources_train)
108 | del(states_train_list)
109 | del(grids_train_list)
110 | del(sources_train_list)
111 |
112 | states_val = np.concatenate(states_val_list, axis=0)
113 | grids_val = np.concatenate(grids_val_list, axis=0)
114 | sources_val = np.concatenate(sources_val_list, axis=0)
115 |
116 | states_test = np.concatenate(states_test_list, axis=0)
117 | grids_test = np.concatenate(grids_test_list, axis=0)
118 | sources_test = np.concatenate(sources_test_list, axis=0)
119 |
120 | hkl.dump(states_val, os.path.join(dir_dataset, 'states_val.hkl'), mode='w')
121 | hkl.dump(grids_val, os.path.join(dir_dataset, 'label_grids_val.hkl'), mode='w')
122 | hkl.dump(sources_val, os.path.join(dir_dataset, 'sources_val.hkl'), mode='w')
123 | del(states_val)
124 | del(grids_val)
125 | del(sources_val)
126 |
127 | hkl.dump(states_test, os.path.join(dir_dataset, 'states_test.hkl'), mode='w')
128 | hkl.dump(grids_test, os.path.join(dir_dataset, 'label_grids_test.hkl'), mode='w')
129 | hkl.dump(sources_test, os.path.join(dir_dataset, 'sources_test.hkl'), mode='w')
130 | del(states_test)
131 | del(grids_test)
132 | del(sources_test)
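
The per-feature mean and std printed above are later used to z-score the state inputs (see preprocess() in utils/data_generator.py and the values recorded in statistics.txt). A hedged sketch of that normalization step:

import numpy as np

def normalize_states(states, mean, std):
    # Column-wise z-scoring of the state features: (x - mean) / std.
    return ((np.asarray(states) - mean) / std).astype(np.float32)
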
--------------------------------------------------------------------------------
/src/preprocess/statistics.txt:
--------------------------------------------------------------------------------
1 | mean: [ 9.9949402e+02 9.9329816e+02 1.5631536e-01 -1.2793580e-01
2 | -6.1600453e-01 -1.1768220e-01 -1.2684283e-01]
3 | std: [1.9491096e+01 1.2559539e+01 2.0200148e+00 5.6204329e+00 2.5343339e+00
4 | 9.0749109e-01 7.8766114e-01]
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/src/preprocess/train_val_test_split.py:
--------------------------------------------------------------------------------
1 | # Code to split the dataset into train/validation/test.
2 |
3 | import argparse
4 | import os
5 | import time
6 | from collections import defaultdict
7 |
8 | import math
9 | import numpy as np
10 | import csv
11 |
12 | import hickle as hkl
13 | import glob
14 | from sklearn.model_selection import train_test_split
15 | import pdb
16 | import random
17 | from generate_data import getmap
18 | from tqdm import tqdm
19 | np.random.seed(123)
20 |
21 | if __name__ == "__main__":
22 | ego_file = glob.glob('/data/INTERACTION-Dataset-DR-v1_1/processed_data/pkl/*ego*')
23 | y = np.arange(len(ego_file))
24 | ego_temp, ego_test_files, _, _ = train_test_split(ego_file, y, test_size=0.1, random_state=42) # test set : 10% of total data
25 | y = np.arange(len(ego_temp))
26 | ego_train_files, ego_val_files, _, _ = train_test_split(ego_temp, y, test_size=5./90., random_state=42) # validation set : 5% of total data
27 |
28 | ego_train_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ego_train_set.csv"
29 | ego_val_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ego_val_set.csv"
30 | ego_test_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ego_test_set.csv"
31 | train_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ref_train_set.csv"
32 | val_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ref_val_set.csv"
33 | test_set_path = "/data/INTERACTION-Dataset-DR-v1_1/processed_data/train_test_split/ref_test_set.csv"
34 |
35 | with open(ego_train_set_path, 'w') as f:
36 | wr = csv.writer(f, delimiter=',')
37 | for row in ego_train_files:
38 | wr.writerow([row])
39 |
40 | with open(ego_val_set_path, 'w') as f:
41 | wr = csv.writer(f, delimiter=',')
42 | for row in ego_val_files:
43 | wr.writerow([row])
44 |
45 | with open(ego_test_set_path, 'w') as f:
46 | wr = csv.writer(f, delimiter=',')
47 | for row in ego_test_files:
48 | wr.writerow([row])
49 |
50 | # Training
51 | ref_train_file = [glob.glob("_".join(filename.split('_')[:-3])+'_ref_*') for filename in ego_train_files]
52 | ref_train_file = sum(ref_train_file, [])
53 |
54 | with open(train_set_path, 'w') as f:
55 | wr = csv.writer(f, delimiter=',')
56 | for row in ref_train_file:
57 | wr.writerow([row])
58 |
59 | # Validation
60 | ref_val_file = [glob.glob("_".join(filename.split('_')[:-3])+'_ref_*') for filename in ego_val_files]
61 | ref_val_file = sum(ref_val_file, [])
62 |
63 | with open(val_set_path, 'w') as f:
64 | wr = csv.writer(f)
65 | for row in ref_val_file:
66 | wr.writerow([row])
67 |
68 | # Test
69 | ref_test_file = [glob.glob("_".join(filename.split('_')[:-3])+'_ref_*') for filename in ego_test_files]
70 | ref_test_file = sum(ref_test_file, [])
71 |
72 | with open(test_set_path, 'w') as f:
73 | wr = csv.writer(f)
74 | for row in ref_test_file:
75 | wr.writerow([row])
76 |
77 | print('ego_train_files', len(ego_train_files),len(set(ego_train_files)))
78 | print('ego_val_files', len(ego_val_files),len(set(ego_val_files)))
79 | print('ego_test_files', len(ego_test_files),len(set(ego_test_files)))
80 | print('ref_train_file', len(ref_train_file),len(set(ref_train_file)))
81 | print('ref_val_file', len(ref_val_file),len(set(ref_val_file)))
82 | print('ref_test_file', len(ref_test_file),len(set(ref_test_file)))
83 |
84 | print('train/val/test split done')
85 |
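
The two chained train_test_split calls above give roughly an 85/5/10 train/val/test split of the ego files: 10% is held out first, and 5/90 of the remainder (5% of the total) becomes validation. A small sketch of the same proportions on dummy file names:

import numpy as np
from sklearn.model_selection import train_test_split

files = ['ego_%d' % i for i in range(1000)]
temp, test, _, _ = train_test_split(files, np.arange(len(files)), test_size=0.1, random_state=42)
train, val, _, _ = train_test_split(temp, np.arange(len(temp)), test_size=5./90., random_state=42)
print(len(train), len(val), len(test))  # approximately 850 / 50 / 100
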
--------------------------------------------------------------------------------
/src/utils/combinations.py:
--------------------------------------------------------------------------------
1 | # Code for obtaining the top num_modes most likely combinations for the fused ego vehicle grid using breadth-first search.
2 | # The algorithm we use is adapted from here: https://cs.stackexchange.com/questions/46910/efficiently-finding-k-smallest-elements-of-cartesian-product.
3 |
4 | import numpy as np
5 | from copy import deepcopy
6 |
7 | def compute_best_cost(alpha, indices):
8 | cost = 1.0
9 | for i in range(len(alpha)):
10 | cost *= alpha[i][indices[i]]
11 | return cost
12 |
13 | def covert_indices(alpha_indices, alpha_sorted, indices):
14 | best_indices = np.zeros(indices.shape)
15 | for i in range(len(alpha_sorted)):
16 | best_indices[i] = alpha_indices[i][indices[i]]
17 | return best_indices
18 |
19 | # Takes the per-agent probability vectors alpha and sorts each one internally in descending order.
20 | def BFS(alpha, top=3):
21 | N = len(alpha)
22 | num_modes = len(alpha[0])
23 |
24 | alpha_sorted = []
25 | alpha_indices = []
26 | for i in range(len(alpha)):
27 | alpha_indices.append(np.argsort(alpha[i])[::-1])
28 | alpha_sorted.append(alpha[i][alpha_indices[-1]])
29 |
30 | Q = [np.zeros((N,)).astype('int')]
31 | best_index = np.array([])
32 | best_cost = np.array([])
33 | cost = np.array([compute_best_cost(alpha_sorted, Q[-1])])
34 |
35 | # Do BFS to obtain the top num_modes most likely products.
36 | while len(best_index) < top:
37 | idx = np.argmax(cost)
38 | if len(best_cost) == 0:
39 | best_cost = np.array([cost[idx]])
40 | best_index = np.array([covert_indices(alpha_indices, alpha_sorted, Q[idx])])
41 | else:
42 | best_cost = np.vstack((best_cost, cost[idx]))
43 | best_index = np.vstack((best_index, covert_indices(alpha_indices, alpha_sorted, Q[idx])))
44 | last = Q.pop(idx)
45 | cost = np.delete(cost, idx)
46 | if len(Q) == 0:
47 | for i in range(N):
48 | array = np.zeros((N,)).astype('int')
49 | array[i] = 1
50 | Q.append(array)
51 | if len(cost) == 0:
52 | cost = np.array([compute_best_cost(alpha_sorted, Q[-1])])
53 | else:
54 | cost = np.hstack((cost, compute_best_cost(alpha_sorted, Q[-1])))
55 | else:
56 | for i in range(N):
57 | array = deepcopy(last)
58 | if array[i] < num_modes-1:
59 | array[i] += 1
60 | else:
61 | continue
62 | Q.append(array)
63 | cost = np.hstack((cost, compute_best_cost(alpha_sorted, Q[-1])))
64 | if last[i] > 0:
65 | break
66 | return best_index, best_cost
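
For context, BFS above enumerates index combinations of the per-agent mode probabilities in decreasing product order and returns the top most likely combinations. A small usage sketch with illustrative numbers:

import numpy as np
from utils.combinations import BFS

# Two agents, each with a categorical distribution over three modes.
alpha = [np.array([0.7, 0.2, 0.1]), np.array([0.5, 0.3, 0.2])]
best_index, best_cost = BFS(alpha, top=3)
# best_cost contains the three largest products (0.35, 0.21, 0.14),
# and best_index the corresponding per-agent mode indices.
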
--------------------------------------------------------------------------------
/src/utils/data_generator.py:
--------------------------------------------------------------------------------
1 | # Code for preprocessing and data generation. Code is adapted from: https://github.com/sisl/EvidentialSparsification.
2 | import torch
3 | torch.manual_seed(123)
4 | torch.backends.cudnn.deterministic = True
5 | torch.backends.cudnn.benchmark = False
6 | import torch.nn as nn
7 | import torch.utils.data as data
8 | import torchvision
9 |
10 | import numpy as np
11 | np.random.seed(0)
12 | import hickle as hkl
13 | import pdb
14 |
15 | def unnormalize(X, nt=1, norm=True):
16 | mean = np.array([9.9923511e+02, 9.9347858e+02, 1.3369133e-01, -3.7689242e-01, -6.0116798e-01, -1.3141568e-01, -1.2848811e-01])
17 | std = np.array([1.9491451e+01, 1.2455911e+01, 2.0100091e+00, 5.9114690e+00, 2.5912049e+00, 9.5518535e-01, 8.1346464e-01])
18 |
19 | if norm:
20 | X = X * std + mean
21 | return X.astype(np.float32)
22 |
23 |
24 | def preprocess(X, norm=True):
25 | mean = np.array([9.9923511e+02, 9.9347858e+02, 1.3369133e-01, -3.7689242e-01, -6.0116798e-01, -1.3141568e-01, -1.2848811e-01])
26 | std = np.array([1.9491451e+01, 1.2455911e+01, 2.0100091e+00, 5.9114690e+00, 2.5912049e+00, 9.5518535e-01, 8.1346464e-01])
27 | if norm:
28 | X = (X - mean)/std
29 | return X.astype(np.float32)
30 |
31 | # Data generator that creates sequences for input.
32 | class SequenceGenerator(data.Dataset):
33 | def __init__(self, data_file_state, data_file_grid, source_file, nt,
34 | batch_size=8, shuffle=False, sequence_start_mode='all', norm=True):
35 | self.state = hkl.load(data_file_state)
36 | self.grid = hkl.load(data_file_grid)
37 | self.grid = np.reshape(self.grid, (self.grid.shape[0], -1))
38 | self.grid = self.grid.astype(np.float32)
39 |
40 | print(self.grid.shape)
41 |
42 |         # Source for each grid, so that when creating sequences we can ensure consecutive frames come from the same data run.
43 | self.sources = hkl.load(source_file)
44 | self.nt = nt
45 | self.norm = norm
46 |
47 |         if batch_size is None:
48 | self.batch_size = self.state.shape[0]
49 | else:
50 | self.batch_size = batch_size
51 | assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}'
52 | self.sequence_start_mode = sequence_start_mode
53 |
54 | # Allow for any possible sequence, starting from any frame.
55 | if self.sequence_start_mode == 'all':
56 | self.possible_starts = np.array([i for i in range(self.state.shape[0] - self.nt) if self.sources[i] == self.sources[i + self.nt - 1]])
57 | # Create sequences where each unique frame is in at most one sequence.
58 | elif self.sequence_start_mode == 'unique':
59 | curr_location = 0
60 | possible_starts = []
61 | while curr_location < self.state.shape[0] - self.nt + 1:
62 | if self.sources[curr_location] == self.sources[curr_location + self.nt - 1]:
63 | possible_starts.append(curr_location)
64 | curr_location += self.nt
65 | else:
66 | curr_location += 1
67 | self.possible_starts = possible_starts
68 |
69 | if shuffle:
70 | self.possible_starts = np.random.permutation(self.possible_starts)
71 | self.N_sequences = len(self.possible_starts)
72 |
73 | def __getitem__(self, idx):
74 | if torch.is_tensor(idx):
75 | idx = idx.tolist()
76 | idx = self.possible_starts[idx]
77 | batch_x = self.preprocess(self.state[idx:idx+self.nt])
78 | batch_y = self.grid[idx+self.nt-1]
79 | sources = self.sources[idx+self.nt-1]
80 | return batch_x, batch_y, sources
81 |
82 | def __len__(self):
83 | return self.N_sequences
84 |
85 | def preprocess(self, X):
86 | mean = np.array([9.9923511e+02, 9.9347858e+02, 1.3369133e-01, -3.7689242e-01, -6.0116798e-01, -1.3141568e-01, -1.2848811e-01])
87 | std = np.array([1.9491451e+01, 1.2455911e+01, 2.0100091e+00, 5.9114690e+00, 2.5912049e+00, 9.5518535e-01, 8.1346464e-01])
88 |
89 | if self.norm:
90 | X = (X - mean)/std
91 | return X.astype(np.float32)
92 |
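
SequenceGenerator is a standard torch Dataset, so it can be wrapped in a DataLoader. A hedged usage sketch with placeholder file paths and sequence length (the actual .hkl files are produced by the preprocessing scripts):

import torch.utils.data as data
from utils.data_generator import SequenceGenerator

dataset = SequenceGenerator('states_shuffled_train.hkl', 'label_grids_shuffled_train.hkl',
                            'sources_shuffled_train.hkl', nt=10, batch_size=8, shuffle=True)
loader = data.DataLoader(dataset, batch_size=8)
for batch_x, batch_y, sources in loader:
    # batch_x: (8, nt, state_dim) normalized states; batch_y: (8, H*W) flattened label grids.
    break
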
--------------------------------------------------------------------------------
/src/utils/dataset_reader.py:
--------------------------------------------------------------------------------
1 | # Code is from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | import csv
4 |
5 | from utils.dataset_types import MotionState, Track
6 |
7 |
8 | class Key:
9 | track_id = "track_id"
10 | frame_id = "frame_id"
11 | time_stamp_ms = "timestamp_ms"
12 | agent_type = "agent_type"
13 | x = "x"
14 | y = "y"
15 | vx = "vx"
16 | vy = "vy"
17 | psi_rad = "psi_rad"
18 | length = "length"
19 | width = "width"
20 |
21 |
22 | class KeyEnum:
23 | track_id = 0
24 | frame_id = 1
25 | time_stamp_ms = 2
26 | agent_type = 3
27 | x = 4
28 | y = 5
29 | vx = 6
30 | vy = 7
31 | psi_rad = 8
32 | length = 9
33 | width = 10
34 |
35 |
36 | def read_tracks(filename):
37 |
38 | with open(filename) as csv_file:
39 | csv_reader = csv.reader(csv_file, delimiter=',')
40 |
41 | track_dict = dict()
42 | track_id = None
43 |
44 | for i, row in enumerate(list(csv_reader)):
45 |
46 | if i == 0:
47 | # check first line with key names
48 | assert(row[KeyEnum.track_id] == Key.track_id)
49 | assert(row[KeyEnum.frame_id] == Key.frame_id)
50 | assert(row[KeyEnum.time_stamp_ms] == Key.time_stamp_ms)
51 | assert(row[KeyEnum.agent_type] == Key.agent_type)
52 | assert(row[KeyEnum.x] == Key.x)
53 | assert(row[KeyEnum.y] == Key.y)
54 | assert(row[KeyEnum.vx] == Key.vx)
55 | assert(row[KeyEnum.vy] == Key.vy)
56 | assert(row[KeyEnum.psi_rad] == Key.psi_rad)
57 | assert(row[KeyEnum.length] == Key.length)
58 | assert(row[KeyEnum.width] == Key.width)
59 | continue
60 |
61 | if int(row[KeyEnum.track_id]) != track_id:
62 | # new track
63 | track_id = int(row[KeyEnum.track_id])
64 | assert(track_id not in track_dict.keys()), \
65 | "Line %i: Track id %i already in dict, track file not sorted properly" % (i+1, track_id)
66 | track = Track(track_id)
67 | track.agent_type = row[KeyEnum.agent_type]
68 | track.length = float(row[KeyEnum.length])
69 | track.width = float(row[KeyEnum.width])
70 | track.time_stamp_ms_first = int(row[KeyEnum.time_stamp_ms])
71 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms])
72 | track_dict[track_id] = track
73 |
74 | track = track_dict[track_id]
75 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms])
76 | ms = MotionState(int(row[KeyEnum.time_stamp_ms]))
77 | ms.x = float(row[KeyEnum.x])
78 | ms.y = float(row[KeyEnum.y])
79 | ms.vx = float(row[KeyEnum.vx])
80 | ms.vy = float(row[KeyEnum.vy])
81 | ms.psi_rad = float(row[KeyEnum.psi_rad])
82 | track.motion_states[ms.time_stamp_ms] = ms
83 |
84 | return track_dict
85 |
86 |
87 | def read_pedestrian(filename):
88 |
89 | with open(filename) as csv_file:
90 | csv_reader = csv.reader(csv_file, delimiter=',')
91 |
92 | track_dict = dict()
93 | track_id = None
94 |
95 | for i, row in enumerate(list(csv_reader)):
96 |
97 | if i == 0:
98 | # check first line with key names
99 | assert (row[KeyEnum.track_id] == Key.track_id)
100 | assert (row[KeyEnum.frame_id] == Key.frame_id)
101 | assert (row[KeyEnum.time_stamp_ms] == Key.time_stamp_ms)
102 | assert (row[KeyEnum.agent_type] == Key.agent_type)
103 | assert (row[KeyEnum.x] == Key.x)
104 | assert (row[KeyEnum.y] == Key.y)
105 | assert (row[KeyEnum.vx] == Key.vx)
106 | assert (row[KeyEnum.vy] == Key.vy)
107 | continue
108 |
109 | if row[KeyEnum.track_id] != track_id:
110 | # new track
111 | track_id = row[KeyEnum.track_id]
112 | assert (track_id not in track_dict.keys()), \
113 | "Line %i: Track id %s already in dict, track file not sorted properly" % (i + 1, track_id)
114 | track = Track(track_id)
115 | track.agent_type = row[KeyEnum.agent_type]
116 | track.time_stamp_ms_first = int(row[KeyEnum.time_stamp_ms])
117 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms])
118 | track_dict[track_id] = track
119 |
120 | track = track_dict[track_id]
121 | track.time_stamp_ms_last = int(row[KeyEnum.time_stamp_ms])
122 | ms = MotionState(int(row[KeyEnum.time_stamp_ms]))
123 | ms.x = float(row[KeyEnum.x])
124 | ms.y = float(row[KeyEnum.y])
125 | ms.vx = float(row[KeyEnum.vx])
126 | ms.vy = float(row[KeyEnum.vy])
127 | track.motion_states[ms.time_stamp_ms] = ms
128 |
129 | return track_dict
130 |
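
A minimal usage sketch for the reader above (the CSV path is a placeholder):

from utils.dataset_reader import read_tracks

tracks = read_tracks('vehicle_tracks_000.csv')
track = next(iter(tracks.values()))
ms = track.motion_states[track.time_stamp_ms_first]
print(track.track_id, ms.x, ms.y, ms.vx, ms.vy, ms.psi_rad)
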
--------------------------------------------------------------------------------
/src/utils/dataset_types.py:
--------------------------------------------------------------------------------
1 | # Code is from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | DELTA_TIMESTAMP_MS = 100 # similar throughout the whole dataset
4 |
5 | class MotionState:
6 | def __init__(self, time_stamp_ms):
7 | assert isinstance(time_stamp_ms, int)
8 | self.time_stamp_ms = time_stamp_ms
9 | self.x = None
10 | self.y = None
11 | self.vx = None
12 | self.vy = None
13 | self.psi_rad = None
14 |
15 | def __str__(self):
16 | return "MotionState: " + str(self.__dict__)
17 |
18 |
19 | class Track:
20 | def __init__(self, id):
21 | # assert isinstance(id, int)
22 | self.track_id = id
23 | self.agent_type = None
24 | self.length = None
25 | self.width = None
26 | self.time_stamp_ms_first = None
27 | self.time_stamp_ms_last = None
28 | self.motion_states = dict()
29 |
30 | def __str__(self):
31 | string = "Track: track_id=" + str(self.track_id) + ", agent_type=" + str(self.agent_type) + \
32 | ", length=" + str(self.length) + ", width=" + str(self.width) + \
33 | ", time_stamp_ms_first=" + str(self.time_stamp_ms_first) + \
34 | ", time_stamp_ms_last=" + str(self.time_stamp_ms_last) + \
35 | "\n motion_states:"
36 | for key, value in sorted(self.motion_states.items()):
37 | string += "\n " + str(key) + ": " + str(value)
38 | return string
39 |
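
For reference, each MotionState is keyed by its timestamp (in ms) inside Track.motion_states. A tiny construction sketch with made-up values:

from utils.dataset_types import MotionState, Track

track = Track(1)
ms = MotionState(100)
ms.x, ms.y, ms.vx, ms.vy, ms.psi_rad = 0.0, 0.0, 1.0, 0.0, 0.0
track.motion_states[ms.time_stamp_ms] = ms
track.time_stamp_ms_first = track.time_stamp_ms_last = 100
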
--------------------------------------------------------------------------------
/src/utils/grid_fuse.py:
--------------------------------------------------------------------------------
1 | # Code to transform the driver sensor OGMs to the ego vehicle's OGM frame of reference.
2 |
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import math
6 | import copy
7 | from utils.grid_utils import global_grid
8 | import time
9 | from scipy.spatial import cKDTree
10 | import pdb
11 |
12 | def mask_in_EgoGrid(global_grid_x, global_grid_y, ref_xy, ego_xy, pred_egoGrid, pred_maps, res, mask_unk=None, tolerance=1):
13 |
14 |     # Consider only the unknown cells in pred_egoGrid (ego sensor grid before transferring values).
15 | indices = np.where(mask_unk)
16 | ego_x = ego_xy[0][indices]
17 | ego_y = ego_xy[1][indices]
18 | ego_xy = [ego_x, ego_y]
19 | flat_indicies = indices[0]*pred_egoGrid.shape[1]+indices[1]
20 |
21 | # ref indx --> global indx
22 | ref_x_ind = np.floor(global_grid_x.shape[1]*(ref_xy[0]-x_min+res/2.)/(x_max-x_min+res)).astype(int) # column index
23 | ref_y_ind = np.floor(global_grid_y.shape[0]*(ref_xy[1]-y_min+res/2.)/(y_max-y_min+res)).astype(int) # row index
24 |
25 | ref_global_ind = np.vstack((ref_y_ind.flatten(), ref_x_ind.flatten())).T
26 |
27 | # ego indx --> global indx
28 | ego_x_ind = np.floor(global_grid_x.shape[1]*(ego_xy[0]-x_min+res/2.)/(x_max-x_min+res)).astype(int) # column index
29 | ego_y_ind = np.floor(global_grid_y.shape[0]*(ego_xy[1]-y_min+res/2.)/(y_max-y_min+res)).astype(int) # row index
30 | ego_global_ind = np.vstack((ego_y_ind.flatten(), ego_x_ind.flatten())).T
31 |
32 | # Look for the matching global_grid indices between the ref_grid and ego_grid.
33 | kdtree = cKDTree(ref_global_ind)
34 | dists, inds = kdtree.query(ego_global_ind)
35 |
36 | pred_egoGrid_flat = pred_egoGrid.flatten()
37 | pred_maps_flat = pred_maps.flatten()
38 |
39 | # Back to the local grid indices. Tolerance should be an integer because kd tree is comparing indices.
40 | ego_ind = flat_indicies[np.where(dists<=tolerance)]
41 | ref_ind = inds[np.where(dists<=tolerance)]
42 |
43 | # Assign the values for the corresponding cells.
44 | pred_egoGrid_flat[ego_ind] = pred_maps_flat[ref_ind]
45 | pred_egoGrid = pred_egoGrid_flat.reshape(pred_egoGrid.shape)
46 | return pred_egoGrid
47 |
48 | def Transfer_to_EgoGrid(ref_local_xy, pred_maps, ego_local_xy, ego_sensor_grid, endpoint, res=0.1, mask_unk=None):
49 | global x_min, x_max, y_min, y_max
50 | #####################################################################################################################################
51 | ## Goal : Transfer pred_maps (in driver sensor's grid) cell information to the unknown cells of ego car's sensor_grid
52 | ## Method : Use global grid as an intermediate (ref indx --> global indx --> ego indx).
53 | ## ref_local_xy (N, 2, w, h) & pred_maps (N, w, h)
54 | ## ego_xy (2, w', h') & & ego_sensor_grid (w', h')
55 | ## return pred_maps_egoGrid(N, w', h')
56 | ## * N : number of agents
57 | #####################################################################################################################################
58 |
59 | x_min = endpoint[0]
60 | x_max = endpoint[2]
61 | y_min = endpoint[1]
62 | y_max = endpoint[3]
63 |
64 | global_res = 1.0
65 | global_grid_x, global_grid_y = global_grid(np.array([x_min,y_min]),np.array([x_max,y_max]),global_res)
66 |
67 | if np.any(ref_local_xy[0] == None):
68 |         pred_egoGrid = None  # no valid reference grid; nothing to transfer
69 |
70 | else:
71 |         # Initialize the ego-frame prediction grid as all unknown (2).
72 |         pred_egoGrid = np.ones(ego_sensor_grid.shape)*2
73 |
74 | pred_egoGrid = mask_in_EgoGrid(global_grid_x, global_grid_y, ref_local_xy, ego_local_xy, pred_egoGrid, pred_maps, res, mask_unk)
75 |
76 | return pred_egoGrid
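
The index mapping used in mask_in_EgoGrid converts a continuous world coordinate into a global-grid index with a floor operation. A minimal sketch of that mapping (the helper name coord_to_index is illustrative):

import numpy as np

def coord_to_index(coord, v_min, v_max, n_cells, res):
    # Mirror the floor-based mapping used above for both the reference and ego coordinates.
    return np.floor(n_cells * (coord - v_min + res / 2.) / (v_max - v_min + res)).astype(int)

# With 10 cells of 1 m resolution whose centers lie at 0..9 m, coordinate 3.4 falls into cell 3.
print(coord_to_index(np.array([3.4]), 0., 9., 10, 1.))  # [3]
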
--------------------------------------------------------------------------------
/src/utils/grid_utils.py:
--------------------------------------------------------------------------------
1 | # Code includes utilities for evidential fusion and OGMs.
2 |
3 | import numpy as np
4 | from utils.dataset_types import Track, MotionState
5 | import pdb
6 | import time
7 | from matplotlib import pyplot as plt
8 |
9 | # Outputs grid_dst: [2,w,h].
10 | # Channel 0 holds the occupied belief mass and channel 1 the free belief mass.
11 | def get_belief_mass(grid, ego_flag=True, m=None):
12 | grid_dst = np.zeros((2,grid.shape[0], grid.shape[1]))
13 |
14 | if ego_flag:
15 | mask_occ = grid == 1
16 | mask_free = grid == 0
17 | mask_unk = grid == 2
18 |
19 | else:
20 | mask_not_unk = grid != 2
21 | mask_unk = None
22 |
23 | if m is not None:
24 | mass = m
25 | elif ego_flag:
26 | mass = 1.0
27 | else:
28 | mass = 0.75
29 |
30 | if ego_flag:
31 | grid_dst[0,mask_occ] = mass
32 | grid_dst[1,mask_free] = mass
33 |
34 | else:
35 | grid_dst[0, mask_not_unk] = grid[mask_not_unk] * mass
36 | grid_dst[1, mask_not_unk] = (1.0-grid[mask_not_unk]) * mass
37 |
38 | return grid_dst, mask_unk
39 |
40 | def dst_fusion(sensor_grid_dst, ego_grid_dst, mask_unk):
41 |
42 | fused_grid_dst = np.zeros(ego_grid_dst.shape)
43 |
44 | # predicted unknown mass
45 | ego_unknown = 1. - ego_grid_dst[0] - ego_grid_dst[1]
46 |
47 |     # measurement (driver sensor) masses: index 0 = occupied, index 1 = free
48 | sensor_unknown = 1. - sensor_grid_dst[0] - sensor_grid_dst[1]
49 |
50 | # Implement DST rule of combination.
51 | K = np.multiply(ego_grid_dst[1], sensor_grid_dst[0]) + np.multiply(ego_grid_dst[0], sensor_grid_dst[1])
52 |
53 | fused_grid_dst[0] = np.divide((np.multiply(ego_grid_dst[0], sensor_unknown) + np.multiply(ego_unknown, sensor_grid_dst[0]) + np.multiply(ego_grid_dst[0], sensor_grid_dst[0])), (1. - K))
54 | fused_grid_dst[1] = np.divide((np.multiply(ego_grid_dst[1], sensor_unknown) + np.multiply(ego_unknown, sensor_grid_dst[1]) + np.multiply(ego_grid_dst[1], sensor_grid_dst[1])), (1. - K))
55 |
56 | pignistic_grid = pignistic(fused_grid_dst)
57 |
58 | return fused_grid_dst, pignistic_grid
59 |
60 | def pignistic(grid_dst):
61 | grid = 0.5*grid_dst[0] + 0.5*(1.-grid_dst[1])
62 | return grid
63 |
64 | def rotate_around_center(pts, center, yaw):
65 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center
66 |
67 | # Create a grid of x and y values that have an associated x,y centre coordinate in the global frame.
68 | def global_grid(origin,endpoint,res):
69 |
70 | xmin = min(origin[0],endpoint[0])
71 | xmax = max(origin[0],endpoint[0])
72 | ymin = min(origin[1],endpoint[1])
73 | ymax = max(origin[1],endpoint[1])
74 |
75 | x_coords = np.arange(xmin,xmax,res)
76 | y_coords = np.arange(ymin,ymax,res)
77 |
78 | gridx,gridy = np.meshgrid(x_coords,y_coords)
79 | return gridx.T,gridy.T
80 |
81 | def local_grid(ms, width, length, res, ego_flag=True, grid_shape=None):
82 |
83 | center = np.array([ms.x, ms.y])
84 |
85 | if ego_flag:
86 | minx = center[0] - 10.
87 | miny = center[1] - 35.
88 | maxx = center[0] + 50.
89 | maxy = center[1] + 35.
90 | else:
91 | if grid_shape is not None:
92 | minx = center[0] + length/2.
93 | miny = center[1] - grid_shape[0]/2.
94 | maxx = center[0] + length/2. + grid_shape[1]
95 | maxy = center[1] + grid_shape[0]/2.
96 | else:
97 | minx = center[0] + length/2.
98 | miny = center[1] - 35.
99 | maxx = center[0] + length/2. + 50.
100 | maxy = center[1] + 35.
101 |
102 | x_coords = np.arange(minx,maxx,res)
103 | y_coords = np.arange(miny,maxy,res)
104 |
105 | mesh_x, mesh_y = np.meshgrid(x_coords,y_coords)
106 | pre_local_x = mesh_x
107 | pre_local_y = np.flipud(mesh_y)
108 |
109 | xy_local = rotate_around_center(np.vstack((pre_local_x.flatten(), pre_local_y.flatten())).T, center, ms.psi_rad)
110 |
111 | x_local = xy_local[:,0].reshape(mesh_x.shape)
112 | y_local = xy_local[:,1].reshape(mesh_y.shape)
113 |
114 | if grid_shape is not None:
115 | x_local = grid_reshape(x_local, grid_shape)
116 | y_local = grid_reshape(y_local, grid_shape)
117 |
118 | elif ego_flag:
119 | grid_shape = (int(70/res),int(60/res))
120 |
121 | x_local = grid_reshape(x_local, grid_shape)
122 | y_local = grid_reshape(y_local, grid_shape)
123 |
124 | return x_local, y_local, pre_local_x, pre_local_y
125 |
126 | # Element in nd array closest to the scalar value v.
127 | def find_nearest(n,v,v0,vn,res):
128 | idx = int(np.floor( n*(v-v0+res/2.)/(vn-v0+res) ))
129 | return idx
130 |
131 | def grid_reshape(data, grid_shape):
132 | if len(data)==1:
133 | d = data[0]
134 | if d.shape[0]!= grid_shape[0] or d.shape[1]!= grid_shape[1]:
135 | data = [d[:grid_shape[0],:grid_shape[1]]]
136 |
137 | else:
138 | if len(data.shape) == 3:
139 | if data.shape[1] > grid_shape[0]:
140 | data = data[:,:grid_shape[0]]
141 | if data.shape[2] > grid_shape[1]:
142 | data = data[:,:,:grid_shape[1]]
143 |
144 | elif len(data.shape) == 2:
145 | if data.shape[0]!= grid_shape[0] or data.shape[1]!= grid_shape[1]:
146 | data = data[:grid_shape[0],:grid_shape[1]]
147 | return data
148 |
149 | ############## LabelGrid ###########################
150 | # Channel 0: 0 = empty, 1 = vehicle or pedestrian, 2 = ego vehicle
151 | # Channel 1: vx of the occupying agent
152 | # Channel 2: vy of the occupying agent
153 | # Channel 3: track id of the occupying agent (NaN if empty)
154 | # Pedestrians have negative track ids (e.g., P1 -> -1).
155 | #####################################################
156 | def generateLabelGrid(timestamp, track_dict, ego_id, object_id, ego_flag, res=1., grid_shape=None, track_pedes_dict=None, pedes_id=None):
157 |
158 | # Initialize the labels.
159 | labels = []
160 | boxes_vehicles = []
161 | boxes_persons = []
162 | dynamics_vehicles = []
163 | dynamics_persons = []
164 |
165 | # The global reference direction of all vehicles and people is in the direction of the ego vehicle.
166 | for key, value in track_dict.items():
167 | assert isinstance(value, Track)
168 | if key == ego_id:
169 | ms = value.motion_states[timestamp]
170 | assert isinstance(ms, MotionState)
171 | x_ego = ms.x
172 | y_ego = ms.y
173 | theta_ego = ms.psi_rad
174 | vx_ego = ms.vx
175 | vy_ego = ms.vy
176 | w_ego = value.width
177 | l_ego = value.length
178 |
179 | center = np.array([x_ego, y_ego])
180 | ms = getstate(timestamp, track_dict, ego_id)
181 |
182 | x_local, y_local, pre_local_x, pre_local_y = local_grid(ms, w_ego, l_ego, res, ego_flag=ego_flag, grid_shape=grid_shape)
183 |
184 | label_grid = np.zeros((4,x_local.shape[0],x_local.shape[1]))
185 | label_grid[3] = np.nan
186 |
187 | for key, value in track_dict.items():
188 | assert isinstance(value, Track)
189 | if key in object_id:
190 | ms = value.motion_states[timestamp]
191 | assert isinstance(ms, MotionState)
192 |
193 | d = np.sqrt((ms.x-x_ego)**2 + (ms.y-y_ego)**2)
194 | if d < (np.sqrt(2.)*22.+2.*l_ego):
195 | vx = ms.vx
196 | vy = ms.vy
197 |
198 | w = value.width
199 | l = value.length
200 | ID = key
201 |
202 | if key == ego_id:
203 | w += res/4.
204 |
205 | coords = polygon_xy_from_motionstate(ms, w, l)
206 | mask = point_in_rectangle(x_local, y_local, coords)
207 | if ID == ego_id:
208 | # Mark as occupied with ego vehicle.
209 | label_grid[0,mask] = 2.
210 |
211 | else:
212 | # Mark as occupied with vehicle.
213 |                     label_grid[0,mask] = 1.
214 |
215 | label_grid[3,mask] = value.track_id
216 |
217 | label_grid[1,mask] = vx
218 | label_grid[2,mask] = vy
219 |
220 | if track_pedes_dict != None:
221 | for key, value in track_pedes_dict.items():
222 | assert isinstance(value, Track)
223 | if key in pedes_id:
224 | ms = value.motion_states[timestamp]
225 | assert isinstance(ms, MotionState)
226 |
227 | d = np.sqrt((ms.x-x_ego)**2 + (ms.y-y_ego)**2)
228 | if d < (np.sqrt(2.)*22.+2.*l_ego):
229 | vx = ms.vx
230 | vy = ms.vy
231 |
232 | w = 1.5
233 | l = 1.5
234 | ID = key
235 |
236 | coords = polygon_xy_from_motionstate_pedest(ms, w, l)
237 | mask = point_in_rectangle(x_local, y_local, coords)
238 |                     # Mark as occupied with the pedestrian.
239 |                     label_grid[0,mask] = 1.
240 |
241 | label_grid[3,mask] = -1.0 * float(value.track_id[1:])
242 |
243 | label_grid[1,mask] = vx
244 | label_grid[2,mask] = vy
245 |
246 | if grid_shape == None:
247 | if ego_flag:
248 | grid_shape = (int(70/res),int(60/res))
249 | else:
250 | grid_shape = (int(70/res),int(50/res))
251 |
252 | label_grid = grid_reshape(label_grid, grid_shape)
253 | x_local = grid_reshape(x_local, grid_shape)
254 | y_local = grid_reshape(y_local, grid_shape)
255 | pre_local_x = grid_reshape(pre_local_x, grid_shape)
256 | pre_local_y = grid_reshape(pre_local_y, grid_shape)
257 |
258 | return label_grid, center, x_local, y_local, pre_local_x, pre_local_y
259 |
260 | ############## SensorGrid ##################
261 | # 0: empty
262 | # 1 : occupied
263 | # 2 : unknown
264 | #############################################
265 | def generateSensorGrid(labels_grid, pre_local_x, pre_local_y, ms, width, length, res=1., ego_flag=True, grid_shape=None):
266 | center_ego = (ms.x, ms.y)
267 | theta_ego = ms.psi_rad
268 | occluded_id = []
269 | visible_id = []
270 | # All the angles of the LiDAR simulation.
271 | angle_res = 0.01 #0.002
272 | if ego_flag:
273 | angles = np.arange(0., 2.*np.pi+angle_res, angle_res)
274 | else:
275 | angles = np.arange(0., 2.*np.pi+angle_res, angle_res)
276 |
277 | # Get the maximum and minimum x and y values in the local grids.
278 | x_min = np.amin(pre_local_x)
279 | x_max = np.amax(pre_local_x)
280 | y_min = np.amin(pre_local_y)
281 | y_max = np.amax(pre_local_y)
282 |
283 | x_shape = pre_local_x.shape[0]
284 | y_shape = pre_local_y.shape[1]
285 |
286 | # Get the cells not occupied by the ego vehicle.
287 | mask = np.where(labels_grid[0]!=2, True,False)
288 |
289 | # No need to do ray tracing if no object on the grid.
290 | if np.all(labels_grid[0,mask]==0.):
291 | sensor_grid = np.zeros((x_shape, y_shape))
292 |
293 | else:
294 | # Generate a line from the center indices to the edge of the local grid: sqrt(2)*128./3. meters away (LiDAR distance).
295 | r = (np.sqrt(x_shape**2 + y_shape**2) + 10) * 1.0 * res
296 | x = (r*np.cos(angles)+center_ego[0]) # length of angles
297 | y = (r*np.sin(angles)+center_ego[1]) # length of angles
298 |
299 | sensor_grid = np.zeros((x_shape, y_shape))
300 |
301 | for i in range(x.shape[0]):
302 |
303 | if x[i] < center_ego[0]:
304 | x_range = np.arange(center_ego[0],np.maximum(x[i], x_min-res),-res*angle_res)
305 | else:
306 | x_range = np.arange(center_ego[0],np.minimum(x[i]+res, x_max+res),res*angle_res)
307 |
308 | # Find the corresponding ys.
309 | y_range = linefunction(center_ego[0],center_ego[1],x[i],y[i],x_range)
310 |
311 | y_temp = np.floor(y_shape*(x_range-x_min-res/2.)/(x_max-x_min+res)).astype(int)
312 | x_temp = np.floor(x_shape*(y_range-y_min-res/2.)/(y_max-y_min+res)).astype(int)
313 |
314 | # Take only the indices inside the local grid.
315 | indices = np.where(np.logical_and(np.logical_and(np.logical_and((x_temp < x_shape), (x_temp >= 0)), (y_temp < y_shape)), (y_temp >= 0)))
316 | x_temp = x_temp[indices]
317 | y_temp = y_temp[indices]
318 |
319 | # Found first occupied cell.
320 | labels_reduced = labels_grid[0,x_temp,y_temp]
321 |
322 | if len(labels_reduced)!=0 :
323 | unique_labels = np.unique(labels_reduced)
324 |
325 | # Check if there are any occupied cells.
326 | if np.any(unique_labels==1.):
327 | ind = np.where(labels_reduced == 1)
328 | sensor_grid[x_temp[ind[0][0]:],y_temp[ind[0][0]:]] = 2
329 | sensor_grid[x_temp[ind[0][0]], y_temp[ind[0][0]]] = 1
330 | else:
331 | sensor_grid[x_temp, y_temp] = 0
332 |
333 | # No ego id included.
334 | unique_id = np.unique(labels_grid[3,:,:])
335 | unique_id= np.delete(unique_id, np.where(np.isnan(unique_id)))
336 |
337 |         # Mark each vehicle's cells as fully visible or fully occluded.
338 | for id in unique_id:
339 | mask = (labels_grid[3,:,:]==id)
340 | if np.any(sensor_grid[mask] == 1):
341 | # Set as occupied/visible.
342 | sensor_grid[mask] = 1
343 | visible_id.append(id)
344 | # Ignore ego vehicle.
345 | elif np.any(labels_grid[0,mask]== 2): # except ego car
346 | pass
347 | else:
348 | # Set as occluded/invisible.
349 | sensor_grid[mask] = 2
350 | occluded_id.append(id)
351 |
352 | # Set the ego vehicle as occupied.
353 | mask = (labels_grid[0,:,:] == 2)
354 | sensor_grid[mask] = 1
355 |
356 | if grid_shape == None:
357 | if ego_flag:
358 | grid_shape = (int(70/res),int(60/res))
359 | else:
360 | grid_shape = (int(70/res),int(50/res))
361 |
362 | sensor_grid = grid_reshape(sensor_grid, grid_shape)
363 |
364 | return sensor_grid, occluded_id, visible_id
365 |
366 | def linefunction(velx,vely,indx,indy,x_range):
367 | m = (indy-vely)/(indx-velx)
368 | b = vely-m*velx
369 | return m*x_range + b
370 |
371 | def point_in_rectangle(x, y, rectangle):
372 | A = rectangle[0]
373 | B = rectangle[1]
374 | C = rectangle[2]
375 |
376 | M = np.array([x,y]).transpose((1,2,0))
377 |
378 | AB = B-A
379 | AM = M-A
380 | BC = C-B
381 | BM = M-B
382 |
383 | dotABAM = np.dot(AM,AB)
384 | dotABAB = np.dot(AB,AB)
385 | dotBCBM = np.dot(BM,BC)
386 | dotBCBC = np.dot(BC,BC)
387 |
388 | return np.logical_and(np.logical_and(np.logical_and((0. <= dotABAM), (dotABAM <= dotABAB)), (0. <= dotBCBM)), (dotBCBM <= dotBCBC)) # nxn
389 |
390 | def rotate_around_center(pts, center, yaw):
391 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center
392 |
393 |
394 | def polygon_xy_from_motionstate(ms, width, length):
395 | assert isinstance(ms, MotionState)
396 | lowleft = (ms.x - length / 2., ms.y - width / 2.)
397 | lowright = (ms.x + length / 2., ms.y - width / 2.)
398 | upright = (ms.x + length / 2., ms.y + width / 2.)
399 | upleft = (ms.x - length / 2., ms.y + width / 2.)
400 | return rotate_around_center(np.array([lowleft, lowright, upright, upleft]), np.array([ms.x, ms.y]), yaw=ms.psi_rad)
401 |
402 | def polygon_xy_from_motionstate_pedest(ms, width, length):
403 | assert isinstance(ms, MotionState)
404 | lowleft = (ms.x - length / 2., ms.y - width / 2.)
405 | lowright = (ms.x + length / 2., ms.y - width / 2.)
406 | upright = (ms.x + length / 2., ms.y + width / 2.)
407 | upleft = (ms.x - length / 2., ms.y + width / 2.)
408 | return np.array([lowleft, lowright, upright, upleft])
409 |
410 | def getstate(timestamp, track_dict, id):
411 | for key, value in track_dict.items():
412 | if key==id:
413 | return value.motion_states[timestamp]
414 |
415 | # All objects whose track interval [time_stamp_ms_first, time_stamp_ms_last] contains the given time step.
416 | def SceneObjects(track_dict, time_step, track_pedes_dict=None):
417 | object_id = []
418 | pedes_id = []
419 | if track_dict != None:
420 | for key, value in track_dict.items():
421 | assert isinstance(value, Track)
422 |
423 | if value.time_stamp_ms_first <= time_step <= value.time_stamp_ms_last:
424 | object_id.append(value.track_id)
425 |
426 | if track_pedes_dict != None:
427 | for key, value in track_pedes_dict.items():
428 | assert isinstance(value, Track)
429 |
430 | if value.time_stamp_ms_first <= time_step <= value.time_stamp_ms_last:
431 | pedes_id.append(value.track_id)
432 |
433 | return object_id, pedes_id
434 |
435 | # All objects in the scene.
436 | def AllObjects(track_dict, track_pedes_dict=None):
437 | object_id = []
438 | for key, value in track_dict.items():
439 | assert isinstance(value, Track)
440 |
441 | object_id.append(value.track_id)
442 |
443 | pedes_id = []
444 | if track_pedes_dict != None:
445 | for key, value in track_pedes_dict.items():
446 | assert isinstance(value, Track)
447 |
448 | pedes_id.append(value.track_id)
449 |
450 | return object_id, pedes_id
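
To make the fusion rule above concrete, here is a tiny worked example on a single cell, assuming the default masses of get_belief_mass (1.0 for the ego sensor, 0.75 for a driver sensor prediction):

import numpy as np
from utils.grid_utils import get_belief_mass, dst_fusion

ego_grid = np.array([[2.]])      # ego sensor: unknown
driver_grid = np.array([[0.9]])  # transferred driver prediction: 0.9 occupancy probability

ego_dst, mask_unk = get_belief_mass(ego_grid, ego_flag=True)    # m_occ = m_free = 0
sensor_dst, _ = get_belief_mass(driver_grid, ego_flag=False)    # m_occ = 0.675, m_free = 0.075
fused_dst, pignistic_grid = dst_fusion(sensor_dst, ego_dst, mask_unk)
print(fused_dst[:, 0, 0], pignistic_grid[0, 0])  # pignistic occupancy probability 0.8
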
--------------------------------------------------------------------------------
/src/utils/interaction_utils.py:
--------------------------------------------------------------------------------
1 | # Code for computing the IS metric and visualizing metrics. IS metric code is modified from: https://github.com/BenQLange/Occupancy-Grid-Metrics.
2 |
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 | import matplotlib.patches as mpatches
6 | import os
7 | import imageio
8 | from copy import deepcopy
9 | import pdb
10 | from tqdm import tqdm
11 |
12 | # Metric visualization.
13 | def plot_metric_clusters(K, states, cluster_ids, metric_grids, metric, dir, sources_val):
14 | patches = []
15 | color = plt.cm.jet(np.linspace(0,1,K))
16 |
17 | for c in range(K):
18 | patches.append(mpatches.Patch(color=color[c], label=c))
19 |
20 | unique_sources = np.unique(np.array(sources_val))
21 |
22 | for source in unique_sources:
23 | if source[-1] != scenario:
24 | continue
25 |
26 | mask = (np.array(sources_val) == source)
27 |
28 | batch_x_np_val_scenario = states[mask]
29 | cluster_ids_y_val_scenario = cluster_ids[mask]
30 | metric_grids_scenario = metric_grids[mask]
31 |
32 | if dict_scenario[scenario] == 0:
33 | fig, (ax1, ax2, ax3, ax4) = plt.subplots(4)
34 | symbol = '.'
35 | else:
36 | symbol = 'x'
37 |
38 | for k in range(batch_x_np_val_scenario.shape[0]):
39 | ax1.scatter(batch_x_np_val_scenario[k,:,1], batch_x_np_val_scenario[k,:,1], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol)
40 | ax2.scatter(batch_x_np_val_scenario[k,:,1], batch_x_np_val_scenario[k,:,2], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol)
41 | ax3.scatter(batch_x_np_val_scenario[k,:,1], batch_x_np_val_scenario[k,:,3], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol)
42 | ax4.scatter(batch_x_np_val_scenario[k,0,1], metric_grids_scenario[k], color=color[cluster_ids_y_val_scenario[k]], label=str(cluster_ids_y_val_scenario[k]), linewidths=1.0, marker=symbol)
43 |
44 | if dict_scenario[scenario] == 1:
45 | ax1.set_ylabel("Pos (m)")
46 | ax1.set_ylim(-5,120)
47 | ax1.set_xlim(-5,120)
48 | ax2.set_ylabel("Vel (m/s)")
49 | ax2.set_ylim(0,8)
50 | ax2.set_xlim(-5,120)
51 | ax3.set_ylabel("Acc (m/s^2)")
52 | ax3.set_ylim(-3,3)
53 | ax3.set_xlim(-5,120)
54 | ax4.set_ylim(0,1)
55 | if metric == 'IM':
56 | ax4.set_ylim(0,50)
57 | if metric == 'IM2':
58 | ax4.set_ylim(0, 50)
59 | ax4.set_xlim(-5,120)
60 | ax4.set_xlabel("Position (m)")
61 | ax4.set_ylabel(metric)
62 | fig.tight_layout(pad=.05)
63 |
64 | picture_file = os.path.join(dir, 'val_examples_K_' + str(K) + '_scenario_' + scenario + '_' + metric + '.png')
65 | fig.savefig(picture_file)
66 | fig.clf()
67 |
68 | dict_scenario[scenario] += 1
69 |
70 | def plot_train(K, p_m_a_np, cluster_centers_np, dir, grid_shape):
71 | fig, (ax1, ax2, ax3) = plt.subplots(3)
72 | fig.suptitle('State Clusters')
73 |
74 | for k in range(K):
75 | fig_occ, ax_occ = plt.subplots(1)
76 | image = np.flip(np.transpose(1.0-np.reshape(p_m_a_np[k], grid_shape), (1,0)), axis=0)
77 | ax_occ.imshow(image, cmap='gray')
78 | picture_file = os.path.join(dir, str(K) + '_cluster_' + str(k) + '.png')
79 | plt.savefig(picture_file)
80 | fig_occ.clf()
81 |
82 | def plot_scatter(K, states, metric_vals, metric, dir):
83 | x = states[:,0,0]
84 | y = states[:,0,1]
85 | plt.figure()
86 | plt.scatter(x, y, marker='.', s=150, linewidths=1, c=metric_vals, cmap=plt.cm.coolwarm)
87 | cb = plt.colorbar()
88 | plt.clim(0,80)
89 | plt.xlabel('x (m)')
90 | plt.ylabel('y (m)')
91 | plt.title('IM over First Driver State')
92 | picture_file = os.path.join(dir, str(K) + '_' + metric + '_first_driver_state.png')
93 | plt.savefig(picture_file)
94 | plt.clf()
95 | del(cb)
96 |
97 | def plot_scatter_clusters(K, states, labels, dir):
98 | x = states[:,0,0]
99 | y = states[:,0,1]
100 | plt.figure()
101 | plt.scatter(x, y, marker='.', s=150, linewidths=1, c=labels, cmap=plt.cm.coolwarm)
102 | cb = plt.colorbar()
103 | plt.clim(0,K)
104 | plt.xlabel('x (m)')
105 | plt.ylabel('y (m)')
106 | plt.title('Clusters over First Driver State')
107 | picture_file = os.path.join(dir, str(K) + '_clusters_first_driver_state.png')
108 | plt.savefig(picture_file)
109 | plt.clf()
110 | del(cb)
111 |
112 | def plot_multimodal_metrics(Ks, acc_nums, mse_nums, im_nums, acc_nums_best, mse_nums_best, im_nums_best, dir):
113 | plt.figure()
114 | plt.plot(Ks, acc_nums, label='Acc')
115 | plt.plot(Ks, acc_nums_best, label='Best 3 Acc')
116 | plt.title('Multimodality Performance Accuracy')
117 | plt.xlabel('K')
118 | plt.ylabel('Acc')
119 | plt.legend()
120 | picture_file = os.path.join(dir, 'multimodality_acc_best_3.png')
121 | plt.savefig(picture_file)
122 |
123 | plt.figure()
124 | plt.plot(Ks, mse_nums, label='MSE')
125 | plt.plot(Ks, mse_nums_best, label='Best 3 MSE')
126 | plt.title('Multimodality Performance MSE')
127 | plt.xlabel('K')
128 | plt.ylabel('MSE')
129 | plt.legend()
130 | picture_file = os.path.join(dir, 'multimodality_mse_best_3.png')
131 | plt.savefig(picture_file)
132 |
133 | plt.figure()
134 | plt.plot(Ks, im_nums, label='IM')
135 | plt.plot(Ks, im_nums_best, label='Best 3 IM')
136 | plt.title('Multimodality Performance IM')
137 | plt.xlabel('K')
138 | plt.ylabel('IM')
139 | plt.legend()
140 | picture_file = os.path.join(dir, 'multimodality_im_best_3.png')
141 | plt.savefig(picture_file)
142 |
143 | # IS metric.
144 | def MapSimilarityMetric(grids_pred, grids_actual):
145 |
146 | num_samples,_,_ = grids_pred.shape
147 |     score, score_occupied, score_free, score_occluded = np.zeros((num_samples,)), np.zeros((num_samples,)), np.zeros((num_samples,)), np.zeros((num_samples,))
148 |     for sample in range(num_samples):
149 | occupied, free, occluded = computeSimilarityMetric(grids_actual[sample,:,:], grids_pred[sample,:,:])
150 | score[sample] += occupied
151 | score[sample] += occluded
152 | score[sample] += free
153 |
154 | score_occupied[sample] += occupied
155 | score_occluded[sample] += occluded
156 | score_free[sample] += free
157 |
158 | return score, score_occupied, score_free, score_occluded
159 |
160 | def toDiscrete(m):
161 | """
162 | Args:
163 | - m (m,n) : np.array with the occupancy grid
164 | Returns:
165 | - discrete_m : thresholded m
166 | """
167 |
168 | y_size, x_size = m.shape
169 | m_occupied = np.zeros(m.shape)
170 | m_free = np.zeros(m.shape)
171 | m_occluded = np.zeros(m.shape)
172 |
173 | m_occupied[m == 1.0] = 1.0
174 | m_occluded[m == 0.5] = 1.0
175 | m_free[m == 0.0] = 1.0
176 |
177 | return m_occupied, m_free, m_occluded
178 |
179 | def todMap(m):
180 |
181 | """
182 | Extra if statements are for edge cases.
183 | """
184 |
185 | y_size, x_size = m.shape
186 | dMap = np.ones(m.shape) * np.Inf
187 | dMap[m == 1] = 0.0
188 |
189 | for y in range(0,y_size):
190 | if y == 0:
191 | for x in range(1,x_size):
192 | h = dMap[y,x-1]+1
193 | dMap[y,x] = min(dMap[y,x], h)
194 |
195 | else:
196 | for x in range(0,x_size):
197 | if x == 0:
198 | h = dMap[y-1,x]+1
199 | dMap[y,x] = min(dMap[y,x], h)
200 | else:
201 | h = min(dMap[y,x-1]+1, dMap[y-1,x]+1)
202 | dMap[y,x] = min(dMap[y,x], h)
203 |
204 | for y in range(y_size-1,-1,-1):
205 |
206 | if y == y_size-1:
207 | for x in range(x_size-2,-1,-1):
208 | h = dMap[y,x+1]+1
209 | dMap[y,x] = min(dMap[y,x], h)
210 |
211 | else:
212 | for x in range(x_size-1,-1,-1):
213 | if x == x_size-1:
214 | h = dMap[y+1,x]+1
215 | dMap[y,x] = min(dMap[y,x], h)
216 | else:
217 | h = min(dMap[y+1,x]+1, dMap[y,x+1]+1)
218 | dMap[y,x] = min(dMap[y,x], h)
219 |
220 | return dMap
221 |
222 | def computeDistance(m1,m2):
223 |
224 | y_size, x_size = m1.shape
225 | dMap = todMap(m2)
226 |
227 | d = np.sum(dMap[m1 == 1])
228 | num_cells = np.sum(m1 == 1)
229 |
230 | # If either of the grids does not have a particular class,
231 | # set to x_size + y_size (proxy for infinity - worst case Manhattan distance).
232 | # If both of the grids do not have a class, set to zero.
233 | if ((num_cells != 0) and (np.sum(dMap == np.Inf) == 0)):
234 | output = d/num_cells
235 | elif ((num_cells == 0) and (np.sum(dMap == np.Inf) != 0)):
236 | output = 0.0
237 | elif ((num_cells == 0) or (np.sum(dMap == np.Inf) != 0)):
238 | output = x_size + y_size
239 |
240 | if output == np.Inf:
241 | pdb.set_trace()
242 |
243 | return output
244 |
245 | def computeSimilarityMetric(m1, m2):
246 |
247 | m1_occupied, m1_free, m1_occluded = toDiscrete(m1)
248 | m2_occupied, m2_free, m2_occluded = toDiscrete(m2)
249 |
250 | occupied = computeDistance(m1_occupied,m2_occupied) + computeDistance(m2_occupied,m1_occupied)
251 | occluded = computeDistance(m2_occluded, m1_occluded) + computeDistance(m1_occluded, m2_occluded)
252 | free = computeDistance(m1_free,m2_free) + computeDistance(m2_free,m1_free)
253 |
254 | return occupied, free, occluded
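255 |
256 | # Example usage (illustration only, not part of the original pipeline): a sanity
257 | # check of the IS metric on two toy 4x4 grids with values 0.0 (free), 0.5 (occluded),
258 | # and 1.0 (occupied). Identical grids score 0; every mismatch adds its symmetric
259 | # average Manhattan distance for that class, so lower is better.
260 | if __name__ == "__main__":
261 |     actual = np.zeros((1, 4, 4))
262 |     actual[0, 1, 1] = 1.0
263 |     actual[0, 2, 2] = 0.5
264 |     predicted = actual.copy()
265 |     predicted[0, 1, 1] = 0.0
266 |     predicted[0, 1, 2] = 1.0  # occupied cell shifted one column to the right
267 |     score, score_occ, score_free, score_occl = MapSimilarityMetric(predicted, actual)
268 |     print(score, score_occ, score_free, score_occl)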
--------------------------------------------------------------------------------
/src/utils/map_vis_without_lanelet.py:
--------------------------------------------------------------------------------
1 | # Code is from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | import matplotlib
4 | import matplotlib.axes
5 | import matplotlib.pyplot as plt
6 |
7 | import xml.etree.ElementTree as xml
8 | import pyproj
9 | import math
10 | import numpy as np
11 |
12 | # from utils import dict_utils
13 |
14 |
15 | class Point:
16 | def __init__(self):
17 | self.x = None
18 | self.y = None
19 |
20 |
21 | class LL2XYProjector:
22 | def __init__(self, lat_origin, lon_origin):
23 | self.lat_origin = lat_origin
24 | self.lon_origin = lon_origin
25 | self.zone = math.floor((lon_origin+180.)/6)+1 # works for most tiles, and for all in the dataset
26 | self.p = pyproj.Proj(proj='utm', ellps='WGS84', zone=self.zone, datum='WGS84')
27 | [self.x_origin, self.y_origin] = self.p(lon_origin, lat_origin)
28 |
29 | def latlon2xy(self, lat, lon):
30 | [x, y] = self.p(lon, lat)
31 | return [x-self.x_origin, y-self.y_origin]
32 |
33 |
34 | def get_type(element):
35 | for tag in element.findall("tag"):
36 | if tag.get("k") == "type":
37 | return tag.get("v")
38 | return None
39 |
40 |
41 | def get_subtype(element):
42 | for tag in element.findall("tag"):
43 | if tag.get("k") == "subtype":
44 | return tag.get("v")
45 | return None
46 |
47 |
48 | def get_x_y_lists(element, point_dict):
49 | x_list = list()
50 | y_list = list()
51 | for nd in element.findall("nd"):
52 | pt_id = int(nd.get("ref"))
53 | point = point_dict[pt_id]
54 | x_list.append(point.x)
55 | y_list.append(point.y)
56 | return x_list, y_list
57 |
58 |
59 | def set_visible_area(point_dict, axes):
60 | min_x = 10e9
61 | min_y = 10e9
62 | max_x = -10e9
63 | max_y = -10e9
64 |
65 | for id, point in point_dict.items():
66 | min_x = min(point.x, min_x)
67 | min_y = min(point.y, min_y)
68 | max_x = max(point.x, max_x)
69 | max_y = max(point.y, max_y)
70 |
71 | axes.set_aspect('equal', adjustable='box')
72 | axes.set_xlim([min_x - 10, max_x + 10])
73 | axes.set_ylim([min_y - 10, max_y + 10])
74 |
75 | return np.array([min_x, min_y, max_x, max_y])
76 |
77 |
78 | def draw_map_without_lanelet(filename, axes, lat_origin, lon_origin):
79 |
80 | assert isinstance(axes, matplotlib.axes.Axes)
81 |
82 | axes.set_aspect('equal', adjustable='box')
83 | axes.patch.set_facecolor('lightgrey')
84 |
85 | projector = LL2XYProjector(lat_origin, lon_origin)
86 |
87 | e = xml.parse(filename).getroot()
88 |
89 | point_dict = dict()
90 | for node in e.findall("node"):
91 | point = Point()
92 | point.x, point.y = projector.latlon2xy(float(node.get('lat')), float(node.get('lon')))
93 | point_dict[int(node.get('id'))] = point
94 |
95 | endpoint = set_visible_area(point_dict, axes)
96 |
97 | unknown_linestring_types = list()
98 |
99 | for way in e.findall('way'):
100 | way_type = get_type(way)
101 | if way_type is None:
102 | raise RuntimeError("Linestring type must be specified")
103 | elif way_type == "curbstone":
104 | type_dict = dict(color="black", linewidth=1, zorder=10)
105 | elif way_type == "line_thin":
106 | way_subtype = get_subtype(way)
107 | if way_subtype == "dashed":
108 | type_dict = dict(color="white", linewidth=1, zorder=10, dashes=[10, 10])
109 | else:
110 | type_dict = dict(color="white", linewidth=1, zorder=10)
111 | elif way_type == "line_thick":
112 | way_subtype = get_subtype(way)
113 | if way_subtype == "dashed":
114 | type_dict = dict(color="white", linewidth=2, zorder=10, dashes=[10, 10])
115 | else:
116 | type_dict = dict(color="white", linewidth=2, zorder=10)
117 | elif way_type == "pedestrian_marking":
118 | type_dict = dict(color="white", linewidth=1, zorder=10, dashes=[5, 10])
119 | elif way_type == "bike_marking":
120 | type_dict = dict(color="white", linewidth=1, zorder=10, dashes=[5, 10])
121 | elif way_type == "stop_line":
122 | type_dict = dict(color="white", linewidth=3, zorder=10)
123 | elif way_type == "virtual":
124 | type_dict = dict(color="blue", linewidth=1, zorder=10, dashes=[2, 5])
125 | elif way_type == "road_border":
126 | type_dict = dict(color="black", linewidth=1, zorder=10)
127 | elif way_type == "guard_rail":
128 | type_dict = dict(color="black", linewidth=1, zorder=10)
129 | elif way_type == "traffic_sign":
130 | continue
131 | else:
132 | if way_type not in unknown_linestring_types:
133 | unknown_linestring_types.append(way_type)
134 | continue
135 |
136 | x_list, y_list = get_x_y_lists(way, point_dict)
137 | plt.plot(x_list, y_list, **type_dict)
138 |
139 | if len(unknown_linestring_types) != 0:
140 | print("Found the following unknown types, did not plot them: " + str(unknown_linestring_types))
141 |
142 | return endpoint
143 |
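144 | # Example usage (illustration only; the origin values and file path below are
145 | # placeholders, not dataset constants): project a lat/lon pair into the local metric
146 | # frame, then draw an .osm map file onto a matplotlib axes.
147 | if __name__ == "__main__":
148 |     projector = LL2XYProjector(lat_origin=0.0, lon_origin=0.0)
149 |     print(projector.latlon2xy(0.001, 0.001))  # offset in metres from the origin
150 |     # fig, axes = plt.subplots()
151 |     # draw_map_without_lanelet("path/to/map.osm", axes, 0.0, 0.0)
152 |     # plt.show()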
--------------------------------------------------------------------------------
/src/utils/tracks_save.py:
--------------------------------------------------------------------------------
1 | # Save the full pipeline occlusion inference output. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | import matplotlib
4 | import matplotlib.patches
5 | import matplotlib.transforms
6 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox, AnchoredOffsetbox
7 | from scipy import ndimage
8 | import skimage.transform
9 | from PIL import Image
10 | import pdb
11 | from matplotlib import pyplot as plt
12 | from matplotlib.colors import ListedColormap
13 | from copy import deepcopy
14 |
15 | import os
16 | import time
17 |
18 | seed = 123
19 |
20 | import numpy as np
21 | np.random.seed(seed)
22 | from matplotlib import pyplot as plt
23 | import torch
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed(seed)
26 | torch.backends.cudnn.deterministic = True
27 | torch.backends.cudnn.benchmark = False
28 | import torch.nn as nn
29 | torch.autograd.set_detect_anomaly(True)
30 |
31 | import io
32 | from tqdm import tqdm
33 | import time
34 |
35 | import argparse
36 | import pandas as pd
37 | import seaborn as sns
38 | import matplotlib.pyplot as plt
39 | from collections import OrderedDict, defaultdict
40 | import torch._utils
41 |
42 | from src.utils.dataset_types import Track, MotionState
43 | from src.utils.grid_utils import *
44 | from src.utils.grid_fuse import *
45 | from src.utils.utils_model import to_var
46 | from src.driver_sensor_model.models_cvae import VAE
47 | from src.utils.interaction_utils import *
48 |
49 | try:
50 | torch._utils._rebuild_tensor_v2
51 | except AttributeError:
52 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
53 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
54 | tensor.requires_grad = requires_grad
55 | tensor._backward_hooks = backward_hooks
56 | return tensor
57 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
58 |
59 | def rotate_around_center(pts, center, yaw):
60 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center
61 |
62 |
63 | def polygon_xy_from_motionstate(ms, width, length):
64 | assert isinstance(ms, MotionState)
65 | lowleft = (ms.x - length / 2., ms.y - width / 2.)
66 | lowright = (ms.x + length / 2., ms.y - width / 2.)
67 | upright = (ms.x + length / 2., ms.y + width / 2.)
68 | upleft = (ms.x - length / 2., ms.y + width / 2.)
69 | return rotate_around_center(np.array([lowleft, lowright, upright, upleft]), np.array([ms.x, ms.y]), yaw=ms.psi_rad)
70 |
71 |
72 | def polygon_xy_from_motionstate_pedest(ms, width, length):
73 | assert isinstance(ms, MotionState)
74 | lowleft = (ms.x - length / 2., ms.y - width / 2.)
75 | lowright = (ms.x + length / 2., ms.y - width / 2.)
76 | upright = (ms.x + length / 2., ms.y + width / 2.)
77 | upleft = (ms.x - length / 2., ms.y + width / 2.)
78 | return np.array([lowleft, lowright, upright, upleft])
79 |
80 | def update_objects_plot(timestamp, track_dict=None, pedest_dict=None,
81 | data=None, car_ids=None, sensor_grids=None, id_grids=None, label_grids=None,
82 | driver_sensor_data=None, driver_sensor_state_data=None, driver_sensor_state=None, endpoint=None, models=None, results=None, mode='evidential', model='vae'):
83 |
84 | ego_id = data[1][1]
85 |
86 | update = False
87 |
88 | if track_dict is not None:
89 |
90 | # Plot and save the ego-vehicle first.
91 | assert isinstance(track_dict[ego_id], Track)
92 | value = track_dict[ego_id]
93 | if (value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last):
94 | ms_ego = value.motion_states[timestamp]
95 | assert isinstance(ms_ego, MotionState)
96 |
97 | width = value.width
98 | length = value.length
99 |
100 | # Obtain the ego vehicle's grid.
101 | res = 1.0
102 | object_id, pedes_id = SceneObjects(track_dict, timestamp, track_pedes_dict=pedest_dict)
103 | local_x_ego, local_y_ego, _, _ = local_grid(ms_ego, width, length, res=res, ego_flag=True)
104 | label_grid_ego, center, x_local, y_local, pre_local_x, pre_local_y = generateLabelGrid(timestamp, track_dict, ego_id, object_id, ego_flag=True, res=res, track_pedes_dict=pedest_dict, pedes_id=pedes_id)
105 | sensor_grid_ego, occluded_id, visible_id = generateSensorGrid(label_grid_ego, pre_local_x, pre_local_y, ms_ego, width, length, ego_flag=True, res=res)
106 | visible_id += [ego_id]
107 |
108 | full_sensor_grid_dst, mask_unk = get_belief_mass(sensor_grid_ego, ego_flag=True)
109 | full_sensor_grid = pignistic(full_sensor_grid_dst)
110 |
111 | run_time = 0.
112 |
113 | # Initialize the variables to keep for later computation.
114 | if model != 'kmeans':
115 | all_latent_classes = []
116 | ref_local_xy_list = []
117 | ego_local_xy_list = []
118 | alpha_p_list = []
119 | sensor_grid_ego_dst = [full_sensor_grid_dst]
120 |
121 | if mode == 'average':
122 | average_mask = np.zeros(full_sensor_grid.shape)
123 | driver_sensor_grid = np.zeros(full_sensor_grid.shape)
124 |
125 | # Consider the rest of the agents.
126 | for key, value in track_dict.items():
127 | assert isinstance(value, Track)
128 | if ((value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last)):
129 | ms = value.motion_states[timestamp]
130 | assert isinstance(ms, MotionState)
131 |
132 | width = value.width
133 | length = value.length
134 |
135 | # Consider all the visible drivers.
136 | if ((key in visible_id) and (key != ego_id)):
137 | res = 1.0
138 | if key in driver_sensor_state_data.keys():
139 | for state in driver_sensor_state_data[key]:
140 | if state[0] == timestamp:
141 | if key in driver_sensor_state.keys():
142 |
143 | # Make sure that the states are contiguous.
144 | if state[0] - driver_sensor_state[key][-1,0] == 100:
145 | driver_sensor_state[key] = np.concatenate((driver_sensor_state[key], np.reshape(state, (1,-1))))
146 | else:
147 | driver_sensor_state[key] = np.reshape(state, (1,-1))
148 | else:
149 | driver_sensor_state[key] = np.reshape(state, (1,-1))
150 |
151 | # Perform occlusion inference once at least 1 second (10 steps at 100 ms) of driver state has been observed.
152 | if driver_sensor_state[key].shape[0] == 10:
153 |
154 | # Flag that the map is being updated.
155 | update = True
156 |
157 | start = time.time()
158 |
159 | x_local, y_local, _, _ = local_grid(ms, width, length, res=res, ego_flag=False, grid_shape=(20,30))
160 |
161 | # Merge grid with driver_sensor_grid.
162 | ref_local_xy = np.stack((x_local, y_local), axis=0)
163 | ego_local_xy = np.stack((local_x_ego, local_y_ego), axis=0)
164 |
165 | if model == 'kmeans':
166 | input_state = driver_sensor_state[key][:,1:]
167 | input_state = np.expand_dims(input_state.flatten(), 0)
168 | [kmeans, p_m_a_np] = models['kmeans']
169 | cluster_ids_y_val = kmeans.predict(input_state.astype('float32'))
170 | pred_maps = p_m_a_np[cluster_ids_y_val]
171 | pred_maps = np.reshape(pred_maps[0], (20,30))
172 |
173 | elif model == 'gmm':
174 | input_state = driver_sensor_state[key][:,1:]
175 | input_state = np.expand_dims(input_state.flatten(), 0)
176 | [gmm, p_m_a_np] = models['gmm']
177 | cluster_ids_y_val = gmm.predict(input_state.astype('float32'))
178 | alpha_p = gmm.predict_proba(input_state.astype('float32'))
179 | pred_maps = p_m_a_np[cluster_ids_y_val]
180 | pred_maps = np.reshape(pred_maps[0], (20,30))
181 |
182 | if len(all_latent_classes) == 0:
183 | all_latent_classes = [np.reshape(p_m_a_np, (100,20,30))]
184 |
185 | elif model == 'vae':
186 | input_state = preprocess(driver_sensor_state[key][:,1:])
187 | input_state = torch.unsqueeze(to_var(torch.from_numpy(input_state)), 0).float().cuda()
188 |
189 | models['vae'].eval()
190 | with torch.no_grad():
191 | pred_maps, alpha_p, _, _, z = models['vae'].inference(n=1, c=input_state, mode='most_likely')
192 | if len(all_latent_classes) == 0:
193 | recon_x_inf, _, _, _, _ = models['vae'].inference(n=100, c=input_state, mode='all')
194 | all_latent_classes = [recon_x_inf.cpu().numpy()]
195 |
196 | pred_maps = pred_maps[0][0].cpu().numpy()
197 | alpha_p = alpha_p.cpu().numpy()
198 |
199 | # Transfer the driver sensor model prediction to the ego vehicle's frame of reference.
200 | predEgoMaps = Transfer_to_EgoGrid(ref_local_xy, pred_maps, ego_local_xy, full_sensor_grid, endpoint=endpoint, res=res, mask_unk=mask_unk)
201 |
202 | if model != 'kmeans':
203 | if not np.all(predEgoMaps == 2):
204 | alpha_p_list.append(alpha_p[0])
205 | ref_local_xy_list.append(ref_local_xy)
206 | ego_local_xy_list.append(ego_local_xy)
207 |
208 | # Fuse the driver sensor model prediction into the ego vehicle's grid.
209 | if mode == 'evidential':
210 | driver_sensor_grid_dst, _ = get_belief_mass(predEgoMaps, ego_flag=False, m=0.95)
211 | full_sensor_grid_dst, full_sensor_grid = dst_fusion(driver_sensor_grid_dst, full_sensor_grid_dst, mask_unk)
212 | elif mode == 'average':
213 | average_mask[predEgoMaps != 2] += 1.0
214 | driver_sensor_grid[predEgoMaps != 2] += predEgoMaps[predEgoMaps != 2]
215 |
216 | run_time += (time.time() - start)
217 |
218 | # Remove the oldest state.
219 | driver_sensor_state[key] = driver_sensor_state[key][1:]
220 |
221 | break
222 |
223 | # Save data if the map is updated.
224 | if update:
225 |
226 | if mode == 'average':
227 | driver_sensor_grid[average_mask != 0] /= average_mask[average_mask != 0]
228 | full_sensor_grid[average_mask != 0] = driver_sensor_grid[average_mask != 0]
229 |
230 | sensor_grid_ego[sensor_grid_ego == 2] = 0.5
231 | results['ego_sensor'].append(sensor_grid_ego)
232 | results['ego_label'].append(label_grid_ego[0])
233 | results['vae'].append(full_sensor_grid)
234 | results['timestamp'].append(timestamp)
235 | results['run_time'].append(run_time)
236 | if model != 'kmeans':
237 | results['all_latent_classes'].append(all_latent_classes[0])
238 | results['ref_local_xy'].append(ref_local_xy_list)
239 | results['ego_local_xy'].append(ego_local_xy_list)
240 | results['alpha_p'].append(alpha_p_list)
241 | results['ego_sensor_dst'].append(sensor_grid_ego_dst[0])
242 | results['endpoint'].append(endpoint)
243 | results['res'].append(res)
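244 |
245 | # Example (illustration only, not part of the original pipeline): a standalone sketch
246 | # of the rolling driver-state buffer used in update_objects_plot. States arrive every
247 | # 100 ms, the buffer resets whenever a gap appears, and inference fires once 10
248 | # contiguous states (1 second) have accumulated; the oldest state is then dropped.
249 | if __name__ == "__main__":
250 |     buffer = None
251 |     for t in range(0, 2000, 100):
252 |         state = np.array([[t, 0.0, 0.0]])  # [timestamp_ms, x, y] placeholder
253 |         if buffer is not None and state[0, 0] - buffer[-1, 0] == 100:
254 |             buffer = np.concatenate((buffer, state))
255 |         else:
256 |             buffer = state
257 |         if buffer.shape[0] == 10:
258 |             print("would run occlusion inference at t =", t, "ms")
259 |             buffer = buffer[1:]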
--------------------------------------------------------------------------------
/src/utils/tracks_vis.py:
--------------------------------------------------------------------------------
1 | # Plot the full pipeline occlusion inference. Code is adapted from: https://github.com/interaction-dataset/interaction-dataset.
2 |
3 | import matplotlib
4 | import matplotlib.patches
5 | import matplotlib.transforms
6 | from matplotlib.offsetbox import OffsetImage, AnnotationBbox, AnchoredOffsetbox
7 | from scipy import ndimage
8 | import skimage.transform
9 | from PIL import Image
10 | import pdb
11 | from matplotlib import pyplot as plt
12 | from matplotlib.colors import ListedColormap
13 | from copy import deepcopy
14 |
15 | import os
16 | import time
17 |
18 | seed = 123
19 |
20 | import numpy as np
21 | np.random.seed(seed)
22 | from matplotlib import pyplot as plt
23 | import torch
24 | torch.manual_seed(seed)
25 | torch.cuda.manual_seed(seed)
26 | torch.backends.cudnn.deterministic = True
27 | torch.backends.cudnn.benchmark = False
28 | import torch.nn as nn
29 | torch.autograd.set_detect_anomaly(True)
30 |
31 | import io
32 | from tqdm import tqdm
33 | import time
34 |
35 | import argparse
36 | import pandas as pd
37 | import seaborn as sns
38 | import matplotlib.pyplot as plt
39 | from collections import OrderedDict, defaultdict
40 | import torch._utils
41 |
42 | from src.utils.dataset_types import Track, MotionState
43 | from src.utils.grid_utils import *
44 | from src.utils.grid_fuse import *
45 | from src.utils.utils_model import to_var
46 | from src.driver_sensor_model.models_cvae import VAE
47 | from src.utils.interaction_utils import *
48 |
49 | try:
50 | torch._utils._rebuild_tensor_v2
51 | except AttributeError:
52 | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
53 | tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
54 | tensor.requires_grad = requires_grad
55 | tensor._backward_hooks = backward_hooks
56 | return tensor
57 | torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2
58 |
59 | def rotate_around_center(pts, center, yaw):
60 | return np.dot(pts - center, np.array([[np.cos(yaw), np.sin(yaw)], [-np.sin(yaw), np.cos(yaw)]])) + center
61 |
62 | def polygon_xy_from_motionstate(ms, width, length):
63 | assert isinstance(ms, MotionState)
64 | lowleft = (ms.x - length / 2., ms.y - width / 2.)
65 | lowright = (ms.x + length / 2., ms.y - width / 2.)
66 | upright = (ms.x + length / 2., ms.y + width / 2.)
67 | upleft = (ms.x - length / 2., ms.y + width / 2.)
68 | return rotate_around_center(np.array([lowleft, lowright, upright, upleft]), np.array([ms.x, ms.y]), yaw=ms.psi_rad)
69 |
70 |
71 | def polygon_xy_from_motionstate_pedest(ms, width, length):
72 | assert isinstance(ms, MotionState)
73 | lowleft = (ms.x - length / 2., ms.y - width / 2.)
74 | lowright = (ms.x + length / 2., ms.y - width / 2.)
75 | upright = (ms.x + length / 2., ms.y + width / 2.)
76 | upleft = (ms.x - length / 2., ms.y + width / 2.)
77 | return np.array([lowleft, lowright, upright, upleft])
78 |
79 | def update_objects_plot(timestamp, patches_dict, text_dict, axes, track_dict=None, pedest_dict=None,
80 | data=None, car_ids=None, sensor_grids=None, id_grids=None, label_grids=None, grids_dict=None,
81 | driver_sensor_data=None, driver_sensor_state_data=None, driver_sensor_state=None, driver_sensor_state_dict=None, endpoint=None, models=None, mode='evidential', model='vae'):
82 |
83 | ego_id = data[1][1]
84 |
85 | if track_dict is not None:
86 |
87 | # Plot and save the ego-vehicle first.
88 | assert isinstance(track_dict[ego_id], Track)
89 | value = track_dict[ego_id]
90 | if (value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last):
91 |
92 | ms_ego = value.motion_states[timestamp]
93 | assert isinstance(ms_ego, MotionState)
94 |
95 | # Check if the ego vehicle already exists.
96 | if ego_id not in patches_dict:
97 | width = value.width
98 | length = value.length
99 |
100 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate(ms_ego, width, length), closed=True,
101 | zorder=40, color='green')
102 |
103 | axes.add_patch(rect)
104 | patches_dict[ego_id] = rect
105 | text_dict[ego_id] = axes.text(ms_ego.x, ms_ego.y + 3, str(ego_id), horizontalalignment='center', zorder=50, fontsize='xx-large')
106 |
107 | else:
108 | width = value.width
109 | length = value.length
110 | patches_dict[ego_id].set_xy(polygon_xy_from_motionstate(ms_ego, width, length))
111 |
112 | text_dict[ego_id].set_position((ms_ego.x, ms_ego.y + 3))
113 |
114 | # Set colors of the ego vehicle.
115 | patches_dict[ego_id].set_color('green')
116 |
117 | if ego_id in grids_dict.keys():
118 | grids_dict[ego_id].remove()
119 |
120 | # Plot the ego vehicle's vanilla grid.
121 | res = 1.0
122 | colormap = ListedColormap(['white','black','gray'])
123 | object_id, pedes_id = SceneObjects(track_dict, timestamp, track_pedes_dict=pedest_dict)
124 | local_x_ego, local_y_ego, _, _ = local_grid(ms_ego, width, length, res=res, ego_flag=True)
125 | label_grid_ego, center, x_local, y_local, pre_local_x, pre_local_y = generateLabelGrid(timestamp, track_dict, ego_id, object_id, ego_flag=True, res=res, track_pedes_dict=pedest_dict, pedes_id=pedes_id)
126 | start = time.time()
127 | sensor_grid_ego, occluded_id, visible_id = generateSensorGrid(label_grid_ego, pre_local_x, pre_local_y, ms_ego, width, length, ego_flag=True, res=res)
128 | visible_id += [ego_id]
129 |
130 |
131 | full_sensor_grid_dst, mask_unk = get_belief_mass(sensor_grid_ego, ego_flag=True)
132 | full_sensor_grid = pignistic(full_sensor_grid_dst)
133 | full_sensor_grid_dst[0,label_grid_ego[0] == 2] = 0.0
134 | full_sensor_grid_dst[1,label_grid_ego[0] == 2] = 1.0
135 | full_sensor_grid[label_grid_ego[0] == 2] = 0.0
136 | sensor_grid_ego[label_grid_ego[0] == 2] = 0.0
137 |
138 | box = axes.pcolormesh(local_x_ego, local_y_ego, full_sensor_grid, cmap='gray_r', zorder=30, alpha=0.7, vmin=0, vmax=1)
139 | grids_dict[ego_id] = box
140 |
141 | # Initialize the variables to keep for later computation.
142 | if model != 'kmeans':
143 | all_latent_classes = []
144 | ref_local_xy_list = []
145 | ego_local_xy_list = []
146 | alpha_p_list = []
147 | sensor_grid_ego_dst = [full_sensor_grid_dst]
148 |
149 | if mode == 'average':
150 | average_mask = np.zeros(full_sensor_grid.shape)
151 | driver_sensor_grid = np.zeros(full_sensor_grid.shape)
152 |
153 | # Plot the rest of the agents.
154 | for key, value in track_dict.items():
155 | assert isinstance(value, Track)
156 | if ((value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last)):
157 |
158 | ms = value.motion_states[timestamp]
159 | assert isinstance(ms, MotionState)
160 |
161 | if key not in patches_dict:
162 | width = value.width
163 | length = value.length
164 |
165 | if ((key in visible_id) and (key != ego_id)):
166 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate(ms, width, length), closed=True,
167 | zorder=40, color='cyan')
168 | colormap = ListedColormap(['white','black','gray'])
169 | res = 1.0
170 | if key in driver_sensor_state_data.keys():
171 | for state in driver_sensor_state_data[key]:
172 | if state[0] == timestamp:
173 | driver_sensor_state[key] = np.reshape(state, (1,-1))
174 | break
175 |
176 | elif ((key not in visible_id) and (key != ego_id)) :
177 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate(ms, width, length), closed=True,
178 | zorder=20)
179 | patches_dict[key] = rect
180 | axes.add_patch(rect)
181 | text_dict[key] = axes.text(ms.x, ms.y + 3, str(key), horizontalalignment='center', zorder=50, fontsize='xx-large')
182 |
183 | else:
184 | width = value.width
185 | length = value.length
186 | patches_dict[key].set_xy(polygon_xy_from_motionstate(ms, width, length))
187 | if (key in visible_id):
188 | patches_dict[key].set_zorder(40)
189 | elif (key not in visible_id):
190 | patches_dict[key].set_zorder(20)
191 | text_dict[key].set_position((ms.x, ms.y + 3))
192 |
193 | # Consider all the visible drivers.
194 | if ((key in visible_id) and (key != ego_id)):
195 | patches_dict[key].set_color('cyan')
196 |
197 | colormap = ListedColormap(['white','black','gray'])
198 | res = 1.0
199 | if key in driver_sensor_state_data.keys():
200 | for state in driver_sensor_state_data[key]:
201 | if state[0] == timestamp:
202 | if key in driver_sensor_state.keys():
203 |
204 | # Make sure that the states are contiguous.
205 | if state[0] - driver_sensor_state[key][-1,0] == 100:
206 | driver_sensor_state[key] = np.concatenate((driver_sensor_state[key], np.reshape(state, (1,-1))))
207 | else:
208 | driver_sensor_state[key] = np.reshape(state, (1,-1))
209 | else:
210 | driver_sensor_state[key] = np.reshape(state, (1,-1))
211 |
212 | # Perform occlusion inference once at least 1 second (10 steps at 100 ms) of driver state has been observed.
213 | if driver_sensor_state[key].shape[0] == 10:
214 |
215 | # Clear the existing plot.
216 | if key in driver_sensor_state_dict.keys():
217 | driver_sensor_state_dict[key].remove()
218 |
219 | box_sensor_state = axes.scatter(driver_sensor_state[key][:,1], driver_sensor_state[key][:,2], s=20, color='orange', zorder=50, alpha=1.0)
220 | driver_sensor_state_dict[key] = box_sensor_state
221 |
222 | x_local, y_local, _, _ = local_grid(ms, width, length, res=res, ego_flag=False, grid_shape=(20,30))
223 |
224 | # Clear the ego grid plot.
225 | if ego_id in grids_dict.keys():
226 | grids_dict[ego_id].remove()
227 |
228 | # Merge grid with driver_sensor_grid.
229 | ref_local_xy = np.stack((x_local, y_local), axis=0)
230 | ego_local_xy = np.stack((local_x_ego, local_y_ego), axis=0)
231 |
232 | if model == 'kmeans':
233 | input_state = driver_sensor_state[key][:,1:]
234 | input_state = np.expand_dims(input_state.flatten(), 0)
235 | [kmeans, p_m_a_np] = models['kmeans']
236 | cluster_ids_y_val = kmeans.predict(input_state.astype('float32'))
237 | pred_maps = p_m_a_np[cluster_ids_y_val]
238 | pred_maps = np.reshape(pred_maps[0], (20,30))
239 |
240 | elif model == 'gmm':
241 | input_state = driver_sensor_state[key][:,1:]
242 | input_state = np.expand_dims(input_state.flatten(), 0)
243 | [gmm, p_m_a_np] = models['gmm']
244 | cluster_ids_y_val = gmm.predict(input_state.astype('float32'))
245 | alpha_p = gmm.predict_proba(input_state.astype('float32'))
246 | pred_maps = p_m_a_np[cluster_ids_y_val]
247 | pred_maps = np.reshape(pred_maps[0], (20,30))
248 |
249 | if len(all_latent_classes) == 0:
250 | all_latent_classes = [np.reshape(p_m_a_np, (100,20,30))]
251 |
252 | elif model == 'vae':
253 | input_state = preprocess(driver_sensor_state[key][:,1:])
254 | input_state = torch.unsqueeze(to_var(torch.from_numpy(input_state)), 0).float().cuda()
255 |
256 | models['vae'].eval()
257 | with torch.no_grad():
258 | pred_maps, alpha_p, _, _, z = models['vae'].inference(n=1, c=input_state, mode='most_likely')
259 | if len(all_latent_classes) == 0:
260 | recon_x_inf, _, _, _, _ = models['vae'].inference(n=100, c=input_state, mode='all')
261 | all_latent_classes = [recon_x_inf.cpu().numpy()]
262 |
263 | pred_maps = pred_maps[0][0].cpu().numpy()
264 | alpha_p = alpha_p.cpu().numpy()
265 |
266 |
267 |
268 | # Transfer the driver sensor model prediction to the ego vehicle's frame of reference.
269 | predEgoMaps = Transfer_to_EgoGrid(ref_local_xy, pred_maps, ego_local_xy, full_sensor_grid, endpoint=endpoint, res=res, mask_unk=mask_unk)
270 |
271 | if model != 'kmeans':
272 | if not np.all(predEgoMaps == 2):
273 | alpha_p_list.append(alpha_p[0])
274 | ref_local_xy_list.append(ref_local_xy)
275 | ego_local_xy_list.append(ego_local_xy)
276 |
277 | # Fuse the driver sensor model prediction into the ego vehicle's grid.
278 | if mode == 'evidential':
279 | driver_sensor_grid_dst, _ = get_belief_mass(predEgoMaps, ego_flag=False, m=0.95)
280 | full_sensor_grid_dst, full_sensor_grid = dst_fusion(driver_sensor_grid_dst, full_sensor_grid_dst, mask_unk)
281 |
282 | elif mode == 'average':
283 | average_mask[predEgoMaps != 2] += 1.0
284 | driver_sensor_grid[predEgoMaps != 2] += predEgoMaps[predEgoMaps != 2]
285 |
286 | grids_dict[ego_id] = axes.pcolormesh(local_x_ego, local_y_ego, full_sensor_grid, cmap='gray_r', zorder=30, alpha=0.7, vmin=0, vmax=1)
287 |
288 | # Remove the oldest state.
289 | driver_sensor_state[key] = driver_sensor_state[key][1:]
290 |
291 | break
292 | elif ((key not in visible_id) and (key != ego_id)):
293 | patches_dict[key].set_color('#1f77b4')
294 | if key in driver_sensor_state_dict.keys():
295 | driver_sensor_state_dict[key].remove()
296 | del driver_sensor_state_dict[key]
297 |
298 | else:
299 | if key in patches_dict:
300 | patches_dict[key].remove()
301 | patches_dict.pop(key)
302 | text_dict[key].remove()
303 | text_dict.pop(key)
304 | if key in driver_sensor_state_dict.keys():
305 | driver_sensor_state_dict[key].remove()
306 | del driver_sensor_state_dict[key]
307 |
308 | # Plot the pedestrians.
309 | if pedest_dict is not None:
310 |
311 | for key, value in pedest_dict.items():
312 | assert isinstance(value, Track)
313 | if value.time_stamp_ms_first <= timestamp <= value.time_stamp_ms_last:
314 | ms = value.motion_states[timestamp]
315 | assert isinstance(ms, MotionState)
316 |
317 | if key not in patches_dict:
318 | width = 1.5
319 | length = 1.5
320 |
321 | rect = matplotlib.patches.Polygon(polygon_xy_from_motionstate_pedest(ms, width, length), closed=True,
322 | zorder=20, color='red')
323 | patches_dict[key] = rect
324 | axes.add_patch(rect)
325 | text_dict[key] = axes.text(ms.x, ms.y + 3, str(key), horizontalalignment='center', zorder=50, fontsize='xx-large')
326 | else:
327 | width = 1.5
328 | length = 1.5
329 | patches_dict[key].set_xy(polygon_xy_from_motionstate_pedest(ms, width, length))
330 | text_dict[key].set_position((ms.x, ms.y + 3))
331 | else:
332 | if key in patches_dict:
333 | patches_dict[key].remove()
334 | patches_dict.pop(key)
335 | text_dict[key].remove()
336 | text_dict.pop(key)
337 |
338 |
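339 | # Example (illustration only, not part of the original pipeline): rotate the corners
340 | # of a 4 m x 2 m vehicle footprint by 90 degrees about its centre, mirroring what
341 | # polygon_xy_from_motionstate does with a MotionState's (x, y, psi_rad).
342 | if __name__ == "__main__":
343 |     corners = np.array([[-2.0, -1.0], [2.0, -1.0], [2.0, 1.0], [-2.0, 1.0]])
344 |     print(rotate_around_center(corners, center=np.array([0.0, 0.0]), yaw=np.pi / 2))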
--------------------------------------------------------------------------------
/src/utils/utils_model.py:
--------------------------------------------------------------------------------
1 | # Helper code for the CVAE driver sensor model. Code is adapted from: https://github.com/sisl/EvidentialSparsification.
2 |
3 | seed = 123
4 | import numpy as np
5 | np.random.seed(seed)
6 | import torch
7 | torch.manual_seed(seed)
8 | torch.cuda.manual_seed(seed)
9 |
10 | from torch.autograd import Variable
11 |
12 | def to_var(x, volatile=False):
13 | if torch.cuda.is_available():
14 | x = x.cuda()
15 | return Variable(x, volatile=volatile)
16 |
17 | def idx2onehot(idx, n):
18 |
19 | assert idx.size(1) == 1
20 | assert torch.max(idx).data < n
21 |
22 | onehot = torch.zeros(idx.size(0), n).cuda()
23 | onehot.scatter_(1, idx.data, 1)
24 | onehot = to_var(onehot)
25 |
26 | return onehot
27 |
28 | def sample_p(alpha, batch_size=1):
29 | zdist = torch.distributions.one_hot_categorical.OneHotCategorical(probs = alpha)
30 | return zdist.sample(torch.Size([batch_size]))
31 |
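32 | # Example usage (illustration only): one-hot encode class indices, then draw
33 | # categorical samples from a probability vector. A CUDA device is required, since
34 | # idx2onehot builds the one-hot tensor on the GPU.
35 | if __name__ == "__main__":
36 |     idx = torch.tensor([[0], [2]]).cuda()
37 |     print(idx2onehot(idx, n=3))           # -> [[1., 0., 0.], [0., 0., 1.]]
38 |     alpha = torch.tensor([0.7, 0.2, 0.1]).cuda()
39 |     print(sample_p(alpha, batch_size=5))  # 5 one-hot draws from alpha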
--------------------------------------------------------------------------------