├── ACKNOWLEDGEMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── multiple_futures_prediction
│   ├── __init__.py
│   ├── assets
│   │   ├── __init__.py
│   │   └── imgs
│   │       ├── .gitignore
│   │       ├── mfp_comp_graph.png
│   │       └── neurips_mfp_poster.pdf
│   ├── checkpts
│   │   └── .gitignore
│   ├── cmd
│   │   ├── __init__.py
│   │   └── train_ngsim_cmd.py
│   ├── configs
│   │   └── mfp2_ngsim.gin
│   ├── dataset_ngsim.py
│   ├── model_ngsim.py
│   ├── my_utils.py
│   ├── ngsim_data
│   │   └── .gitignore
│   ├── py.typed
│   └── train_ngsim.py
├── mypy.ini
├── pyproject.toml
└── requirements.txt
/ACKNOWLEDGEMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Software have utilized the following copyrighted material, the use of which is hereby acknowledged.
3 |
4 | _____________________
5 | PyTorch (https://pytorch.org)
6 | We use PyTorch as the training framework for our model.
7 |
8 | From PyTorch:
9 |
10 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
11 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
12 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
13 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
14 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
15 | Copyright (c) 2011-2013 NYU (Clement Farabet)
16 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
17 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
18 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
19 |
20 | _____________________
21 | Nachiket Deo and Mohan M. Trivedi (https://github.com/nachiket92/conv-social-pooling)
22 | Our initial prototype and certain loss functions are based on the code base provided above, which is distributed under the MIT license.
23 |
24 | MIT License
25 |
26 | Copyright (c) 2018 Nachiket Deo
27 |
28 | Permission is hereby granted, free of charge, to any person obtaining a copy
29 | of this software and associated documentation files (the "Software"), to deal
30 | in the Software without restriction, including without limitation the rights
31 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
32 | copies of the Software, and to permit persons to whom the Software is
33 | furnished to do so, subject to the following conditions:
34 |
35 | The above copyright notice and this permission notice shall be included in all
36 | copies or substantial portions of the Software.
37 |
38 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
39 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
40 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
41 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
42 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
43 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
44 | SOFTWARE.
45 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | ## Before you get started
6 |
7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2019 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
41 | -------------------------------------------------------------------------------
42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
43 |
44 | This software includes a number of subcomponents with separate
45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
46 | -------------------------------------------------------------------------------
47 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multiple Futures Prediction
2 | ## Paper
3 | This software accompanies the paper [**Multiple Futures Prediction**](https://arxiv.org/abs/1911.00997). ([Poster](multiple_futures_prediction/assets/imgs/neurips_mfp_poster.pdf))
4 | [Yichuan Charlie Tang](https://www.cs.toronto.edu/~tang) and Ruslan Salakhutdinov
5 | Neural Information Processing Systems, 2019. (NeurIPS 2019)
6 |
7 |
8 | Please cite our paper if you find our work useful for your research:
9 | ```
10 | @inproceedings{tang2019mfp,
11 | title={Multiple Futures Prediction},
12 | author={Tang, Yichuan Charlie and Salakhutdinov, Ruslan},
13 | booktitle={Advances in neural information processing systems},
14 | year={2019}
15 | }
16 | ```
17 |
18 | ## Introduction
19 | Multiple Futures Prediction (MFP) is a framework for forecasting the future trajectories of agents, such as vehicles or pedestrians. A key feature of the framework is that it learns multiple modes, i.e. multiple possible futures, directly from trajectory data without mode annotations. Multi-agent interactions are also taken into account, and the framework scales to an arbitrary number of agents in the scene via a novel dynamic attention mechanism. It currently achieves state-of-the-art results on three vehicle forecasting datasets.
20 |
21 | This research code is for demonstration purposes only. Please see the paper for more details.
22 |
23 | ### Overall Architecture
24 |
25 | <img src="multiple_futures_prediction/assets/imgs/mfp_comp_graph.png" width="700" alt="MFP computation graph">
26 |
27 | The Multiple Futures Prediction (MFP) architecture is shown above. For an arbitrary number of agents in the scene, we first use RNNs to encode their past trajectories into feature vectors. Dynamic attentional contextual encoding aggregates interactions and relational information. For each agent, a distribution over its latent modes is then predicted. You can think of the latent modes as representing conservative/aggressive behaviors or directional (left vs. right turn) intentions. Given a distribution over the latent modes, decoding RNNs are then employed to decode or forecast future temporal trajectories. The MFP is a latent-variable graphical model, and we use the EM algorithm to optimize the evidence lower bound.
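To make the data flow concrete, below is a minimal, self-contained PyTorch sketch of the encode/mode/decode pipeline. It is illustrative only: `TinyMFP` and its layer sizes are simplified stand-ins, the dynamic attention over neighbors is omitted, and the real implementation lives in `multiple_futures_prediction/model_ngsim.py`.
```python
import torch
import torch.nn as nn

class TinyMFP(nn.Module):
    def __init__(self, modes: int = 2, enc: int = 32, dec: int = 64) -> None:
        super().__init__()
        self.encoder = nn.GRU(2, enc)              # encode past (x, y) steps
        self.mode_head = nn.Linear(enc, modes)     # distribution over latent modes
        self.decoders = nn.ModuleList([nn.GRU(enc, dec) for _ in range(modes)])
        self.heads = nn.ModuleList([nn.Linear(dec, 2) for _ in range(modes)])

    def forward(self, hist: torch.Tensor, fut_len: int):
        _, h = self.encoder(hist)                  # hist: [T_past, agents, 2]
        feat = h[-1]                               # [agents, enc]
        mode_probs = torch.softmax(self.mode_head(feat), dim=1)
        futures = []
        for dec, head in zip(self.decoders, self.heads):
            steps, _ = dec(feat.unsqueeze(0).repeat(fut_len, 1, 1))
            futures.append(head(steps))            # one [T_fut, agents, 2] per mode
        return futures, mode_probs

hist = torch.randn(16, 4, 2)                       # 16 past steps, 4 agents
futures, mode_probs = TinyMFP()(hist, fut_len=25)
```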
28 |
29 |
30 | ## Getting Started
31 |
32 | ### Prerequisites
33 | This code is tested with Python 3.6 and PyTorch 1.1.0. Conda or virtualenv is recommended.
34 | Use pip (recent version, e.g. 20.1) to install dependencies, for example:
35 | ```
36 | python3.6 -m venv .venv # Create new venv
37 | source .venv/bin/activate # Activate it
38 | pip install -U pip # Update to latest version of pip
39 | pip install -r requirements.txt # Install everything
40 | ```
41 |
42 | ### Datasets
43 |
44 | #### How to Obtain NGSIM Data:
45 |
46 | 1. Obtain the NGSIM dataset (US-101 and I-80) here:
47 | (https://data.transportation.gov/Automobiles/Next-Generation-Simulation-NGSIM-Vehicle-Trajector/8ect-6jqj)
48 | ```
49 | Specifically you will need these files:
50 | US-101:
51 | '0750am-0805am/trajectories-0750am-0805am.txt'
52 | '0805am-0820am/trajectories-0805am-0820am.txt'
53 | '0820am-0835am/trajectories-0820am-0835am.txt'
54 |
55 | I-80:
56 | '0400pm-0415pm/trajectories-0400-0415.txt'
57 | '0500pm-0515pm/trajectories-0500-0515.txt'
58 | '0515pm-0530pm/trajectories-0515-0530.txt'
59 | ```
60 | 2. Preprocess the dataset with code from Nachiket Deo and Mohan M. Trivedi:
61 | [Convolutional Social Pooling for Vehicle Trajectory Prediction](https://github.com/nachiket92/conv-social-pooling) (CVPRW, 2018)
62 |
63 | 3. From the conv-social-pooling repo, run `preprocess_data.m`; this should produce three files:
64 | TrainSet.mat, ValSet.mat, and TestSet.mat. Copy them into the `multiple_futures_prediction/ngsim_data` folder.
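As a quick sanity check of the preprocessed files, one can load them with scipy (a sketch; the path assumes the layout above, and the `traj`/`tracks` keys match what `dataset_ngsim.py` loads):
```python
import scipy.io as scp

mat = scp.loadmat('multiple_futures_prediction/ngsim_data/TrainSet.mat')
print(mat['traj'].shape)    # one row per (dataset id, vehicle id, frame) sample
print(mat['tracks'].shape)  # per-(dataset, vehicle) trajectory tracks
```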
65 |
66 | ### Usage
67 |
68 | #### Training
69 | ```bash
70 | train_ngsim --config multiple_futures_prediction/configs/mfp2_ngsim.gin
71 | ```
72 | or
73 | ```bash
74 | python -m multiple_futures_prediction.cmd.train_ngsim_cmd --config multiple_futures_prediction/configs/mfp2_ngsim.gin
75 | ```
76 | Hyperparameters (e.g. the number of MFP modes) can be specified in the .gin config files.
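For example, to train a model with a different number of modes, copy `multiple_futures_prediction/configs/mfp2_ngsim.gin` and edit the relevant bindings (hypothetical values shown; the keys mirror the shipped config):
```
Params.modes = 4
Params.lr_init = 0.0005
```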
77 |
78 | Expected training outputs:
79 | ```
80 | Epoch no: 0 update: 99 | Avg train loss: 57.3198 learning_rate:0.00100
81 | Epoch no: 0 update: 199 | Avg train loss: 4.7679 learning_rate:0.00100
82 | Epoch no: 0 update: 299 | Avg train loss: 4.3250 learning_rate:0.00100
83 | Epoch no: 0 update: 399 | Avg train loss: 4.0717 learning_rate:0.00100
84 | Epoch no: 0 update: 499 | Avg train loss: 3.9722 learning_rate:0.00100
85 | Epoch no: 0 update: 599 | Avg train loss: 3.8525 learning_rate:0.00100
86 | Epoch no: 0 update: 699 | Avg train loss: 3.5253 learning_rate:0.00100
87 | Epoch no: 0 update: 799 | Avg train loss: 3.6077 learning_rate:0.00100
88 | Epoch no: 0 update: 899 | Avg train loss: 3.4526 learning_rate:0.00100
89 | Epoch no: 0 update: 999 | Avg train loss: 3.5830 learning_rate:0.00100
90 | Starting eval
91 | eval val_dl nll
92 | tensor([-1.5164, -0.3173, 0.3902, 0.9374, 1.3751, 1.7362, 2.0362, 2.3008,
93 | 2.5510, 2.7974, 3.0370, 3.2702, 3.4920, 3.7007, 3.8979, 4.0836,
94 | 4.2569, 4.4173, 4.5682, 4.7082, 4.8378, 4.9581, 5.0716, 5.1855,
95 | 5.3239])
96 | ```
97 | Depending on the CPU/GPU available, it can take from one to two days to complete
98 | 300K training updates on the NGSIM dataset and match the results in Table 5 of the paper.
99 |
100 | ## License
101 | This code is released under the [LICENSE](LICENSE) terms.
102 |
--------------------------------------------------------------------------------
/multiple_futures_prediction/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2020 Apple Inc. All rights reserved.
3 | #
4 |
5 | """Demo code for Multiple Futures Prediction Paper (NeurIPS 2019)."""
6 |
7 | __version__ = "0.1.0"
8 |
--------------------------------------------------------------------------------
/multiple_futures_prediction/assets/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2020 Apple Inc. All rights reserved.
3 | #
4 |
5 | """This subpackage contains small resources necessary to run this package."""
6 |
--------------------------------------------------------------------------------
/multiple_futures_prediction/assets/imgs/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/assets/imgs/.gitignore
--------------------------------------------------------------------------------
/multiple_futures_prediction/assets/imgs/mfp_comp_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/assets/imgs/mfp_comp_graph.png
--------------------------------------------------------------------------------
/multiple_futures_prediction/assets/imgs/neurips_mfp_poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/assets/imgs/neurips_mfp_poster.pdf
--------------------------------------------------------------------------------
/multiple_futures_prediction/checkpts/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/checkpts/.gitignore
--------------------------------------------------------------------------------
/multiple_futures_prediction/cmd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/cmd/__init__.py
--------------------------------------------------------------------------------
/multiple_futures_prediction/cmd/train_ngsim_cmd.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from multiple_futures_prediction.train_ngsim import train, Params
3 | import gin
4 | import argparse
5 |
6 | def parse_args() -> Any:
7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
8 | parser.add_argument('--config', type=str, default='')
9 | return parser.parse_args()
10 |
11 | def main() -> None:
12 | args = parse_args()
13 | gin.parse_config_file( args.config )
14 | params = Params()()
15 | train(params)
16 |
17 | # python -m multiple_futures_prediction.cmd.train_ngsim_cmd --config multiple_futures_prediction/configs/mfp2_ngsim.gin
18 | if __name__ == '__main__':
19 | main()
20 |
21 |
22 |
--------------------------------------------------------------------------------
/multiple_futures_prediction/configs/mfp2_ngsim.gin:
--------------------------------------------------------------------------------
1 | Params.log = 0 # log results and save model checkpoint
2 | Params.modes = 2 # MFP modes
3 |
4 | Params.subsampling = 2 # subsampling factor
5 | Params.hist_len_orig_hz = 30 # history steps at original sampling rate (e.g. 10 Hz)
6 | Params.fut_len_orig_hz = 50 # future steps at original sampling rate (e.g. 10 Hz)
7 |
8 | Params.encoder_size = 64 # sizes of various neural layers
9 | Params.decoder_size = 128
10 | Params.dyn_embedding_size = 32
11 | Params.input_embedding_size = 32
12 | Params.dec_nbr_enc_size = 8
13 | Params.nbr_atten_embedding_size = 80
14 |
15 | Params.seed = 1234
16 | Params.use_cuda = True
17 |
18 | Params.remove_y_mean = True # remove the mean of the future trajectory
19 | Params.use_gru = True # GRU or LSTM
20 | Params.bi_direc = True # bidirectional RNN
21 | Params.self_norm = True # Normalize prediction targets
22 | Params.data_aug = False # Add noise to data?
23 | Params.use_context = False # use contextual bird's eye view map?
24 | Params.nll = True # Use negative log-likelihood loss
25 | Params.use_forcing = 0 # 0: no forcing. 1: teacher-forcing. 2: classmates-forcing.
26 |
27 | Params.iter_per_err = 100
28 | Params.iter_per_eval = 1000
29 | Params.iters_per_save = 5000
30 | Params.pre_train_num_updates = 200000
31 | Params.updates_div_by_10 = 100000
32 | Params.nbr_search_depth = 10 # depth of searching for finding 'neighbors' (only for NGSIM dataset)
33 |
34 | Params.lr_init = 0.001 # initial learning rate
35 | Params.min_lr = 0.00005 # minimum learning rate
36 |
--------------------------------------------------------------------------------
/multiple_futures_prediction/dataset_ngsim.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from typing import List, Set, Dict, Tuple, Optional, Union, Any
7 | from torch.utils.data import Dataset, DataLoader
8 | import scipy.io as scp
9 | import numpy as np
10 | import torch
11 | import pickle
12 | import os
13 | import cv2
14 |
15 | # Dataset for PyTorch training.
16 | class NgsimDataset(Dataset):
17 | def __init__(self, mat_file:str, t_h:int=30, t_f:int=50, d_s:int = 2,
18 | enc_size:int=64, use_gru:bool=False, self_norm:bool=False,
19 | data_aug:bool=False, use_context:bool=False, nbr_search_depth:int= 3,
20 | ds_seed:int=1234) -> None:
21 | self.D = scp.loadmat(mat_file)['traj']
22 | self.T = scp.loadmat(mat_file)['tracks']
23 | self.t_h = t_h # length of track history
24 | self.t_f = t_f # length of predicted trajectory
25 | self.d_s = d_s # down sampling rate of all sequences
26 | self.enc_size = enc_size # size of encoder LSTM
27 | self.grid_size = (13,3) # size of context grid
28 | self.enc_fac = 2 if use_gru else 1
29 | self.self_norm = self_norm
30 | self.data_aug = data_aug
31 | self.noise = np.array([[0.5, 2.0]])
32 | self.dt = 0.1*self.d_s
33 | self.ft_to_m = 0.3048
34 | self.use_context = use_context
35 | if self.use_context:
36 | self.maps = pickle.load(open('data/maps.pkl', 'rb'))
37 | self.nbr_search_depth = nbr_search_depth
38 |
39 | cache_file = 'multiple_futures_prediction/ngsim_data/NgsimIndex_%s.p'%os.path.basename(mat_file)
40 | #build index of [dataset (0 based), veh_id_0b, frame(time)] into a dictionary
41 | if not os.path.exists(cache_file):
42 | self.Index = {}
43 | print('building index...')
44 | for i, row in enumerate(self.D):
45 | key = (int(row[0]-1), int(row[1]-1), int(row[2]))
46 | self.Index[key] = i
47 | print('build index done')
48 | pickle.dump( self.Index, open(cache_file,'wb'))
49 | else:
50 | self.Index = pickle.load( open(cache_file,'rb'))
51 |
52 | self.ind_random = np.arange(len(self.D))
53 | self.seed = ds_seed
54 | np.random.seed(self.seed)
55 | np.random.shuffle(self.ind_random)
56 |
57 | def __len__(self) -> int:
58 | return len(self.D)
59 |
60 | def convert_pt(self, pt: np.ndarray, dsId0b: int) -> np.ndarray:
61 | """Convert a point from abs coords to pixel coords.
62 | Args:
63 | pt is a 2d x,y coordinate
64 | dsId0b - data set id 0 index based
65 | """
66 | ft_per_pixel = self.maps[dsId0b]['ft_per_pixel']
67 | return np.array([int((np.round(pt[0])-self.maps[dsId0b]['x0'])/ft_per_pixel),
68 | int((np.round(pt[1]) - self.maps[dsId0b]['y0'])/ft_per_pixel)])
69 |
70 | def compute_vel_theta(self, hist: torch.Tensor, frac: Optional[float]=0.5) -> Tuple[np.ndarray, np.ndarray]:
71 | """Estimate velocity and orientation from history trajectory."""
72 | if hist.shape[0] <= 1:
73 | return np.array([0.0]), np.array([0.0])
74 | else:
75 | total_wts = 0.0
76 | counter = 0.0
77 | vel = theta = 0.0
78 | for t in range(hist.shape[0]-1,0,-1):
79 | counter += 1.0
80 | wt = np.power(frac, counter)
81 | total_wts += wt
82 | diff = hist[t,:] - hist[t-1,:]
83 | vel += wt*np.linalg.norm(diff)*self.ft_to_m/self.dt
84 | theta += wt*np.arctan2(diff[1],diff[0])
85 | return np.array([vel/total_wts]), np.array([theta/total_wts])
86 |
87 | def find_neighbors(self, dsID_0b: int, vehId_0b: int, frame: int) -> Dict:
88 | """Find a list of neighbors w.r.t. self."""
89 | key = ( int(dsID_0b), int(vehId_0b), int(frame))
90 | if key not in self.Index:
91 | return {}
92 |
93 | idx = self.Index[key]
94 | grid1b = self.D[idx,8:]
95 | nonzero = np.nonzero(grid1b)[0]
96 | return {vehId_0b: list(zip( grid1b[nonzero].astype(np.int64)-1, nonzero)) }
97 |
98 | def __getitem__(self, idx_: int) -> Tuple[ List, List, Dict, Union[None, np.ndarray] ] :
99 | idx = self.ind_random[idx_]
100 | dsId_1b = self.D[idx, 0].astype(int)
101 | vehId_1b = self.D[idx, 1].astype(int)
102 | dsId_0b = dsId_1b-1
103 | vehId_0b = vehId_1b-1
104 | t = self.D[idx, 2]
105 | grid = self.D[idx,8:] #1-based
106 |
107 | ids = {} #0-based keys
108 | leafs = [vehId_0b]
109 | for _ in range( self.nbr_search_depth ):
110 | new_leafs = []
111 | for id0b in leafs:
112 | nbr_dict = self.find_neighbors(dsId_0b, id0b, t)
113 | if len(nbr_dict) > 0:
114 | ids.update( nbr_dict )
115 | if len(nbr_dict[id0b]) > 0:
116 | nbr_id0b = list( zip(*nbr_dict[id0b]))[0]
117 | new_leafs.extend ( nbr_id0b )
118 | leafs = np.unique( new_leafs )
119 | leafs = [x for x in leafs if x not in ids]
120 |
121 | sorted_keys = sorted(ids.keys()) # e.g. [1, 3, 4, 5, ... , 74]
122 | id_map = {key: value for (key, value) in zip(sorted_keys, np.arange(len(sorted_keys)).tolist()) }
123 | #obj id to index within a batch
124 | sz = len(ids)
125 | assert sz > 0
126 |
127 | hist = []
128 | fut = []
129 | neighbors: Dict[int,List] = {} # key is batch ind, followed by a list of (batch_ind, ego_id, grid/nbr ind)
130 |
131 | for ind, vehId0b in enumerate(sorted_keys):
132 | hist.append( self.getHistory0b(vehId0b,t,dsId_0b) ) # no normalization
133 | fut.append( self.getFuture(vehId0b+1,t,dsId_0b+1 ) ) #subtract off ref pos
134 | neighbors[ind] = []
135 | for v_id, nbr_ind in ids[ vehId0b ]:
136 | if v_id not in id_map:
137 | k2 = -1
138 | else:
139 | k2 = id_map[v_id]
140 | neighbors[ind].append( (k2, v_id, nbr_ind) )
141 |
142 | if self.use_context:
143 | x_range_ft = np.array([-15, 15])
144 | y_range_ft = np.array([-30, 300])
145 |
146 | pad = int(np.ceil(300/self.maps[dsId_1b-1]['ft_per_pixel'])) # max of all ranges
147 |
148 | if not 'im_color' in self.maps[dsId_1b-1]:
149 | im_big = np.pad(self.maps[dsId_1b-1]['im'], ((pad, pad), (pad, pad)), 'constant', constant_values= 0.0)
150 | self.maps[dsId_1b-1]['im_color'] = (im_big[np.newaxis,...].repeat(3, axis=0)*255.0).astype(np.uint8)
151 |
152 | im = self.maps[dsId_1b-1]['im_color']
153 | height, width = im.shape[1:]
154 |
155 | ref_pos = self.D[idx,3:5]
156 | im_x, im_y = self.convert_pt( ref_pos, dsId_1b-1 )
157 | im_x += pad
158 | im_y += pad
159 |
160 | x_range = (x_range_ft/self.maps[dsId_1b-1]['ft_per_pixel']).astype(int)
161 | y_range = (y_range_ft/self.maps[dsId_1b-1]['ft_per_pixel']).astype(int)
162 |
163 | x_range[0] = np.maximum( 0, x_range[0]+im_x )
164 | x_range[1] = np.minimum( width-1, x_range[1]+im_x )
165 | y_range[0] = np.maximum( 0, y_range[0]+im_y )
166 | y_range[1] = np.minimum( height-1, y_range[1]+im_y )
167 |
168 | im_crop = np.ascontiguousarray(im[:, y_range[0]:y_range[1], x_range[0]:x_range[1]].transpose((1,2,0)))
169 | im_crop[:,:,[0, 1]] = 0
170 |
171 | for _, other in neighbors.items():
172 | if len(other) == 0:
173 | continue
174 | for k in range( len(other)-1 ):
175 | x1, y1 = self.convert_pt( other[k]+ref_pos, dsId_1b-1 )
176 | x2, y2 = self.convert_pt( other[k+1]+ref_pos, dsId_1b-1 )
177 | x1+=pad; y1+=pad; x2+=pad; y2+=pad
178 | cv2.line(im_crop,(x1-x_range[0],y1-y_range[0]),(x2-x_range[0],y2-y_range[0]), (255, 0, 0), 2 )
179 |
180 | x, y = self.convert_pt( other[-1]+ref_pos, dsId_1b-1 )
181 | x+=pad; y+=pad
182 | cv2.circle(im_crop, (x-x_range[0], y-y_range[0]), 4, (255, 0, 0), -1)
183 |
184 | for k in range( len(hist)-1 ):
185 | x1, y1 = self.convert_pt( hist[k]+ref_pos, dsId_1b-1 )
186 | x2, y2 = self.convert_pt( hist[k+1]+ref_pos, dsId_1b-1 )
187 | x1+=pad; y1+=pad; x2+=pad; y2+=pad
188 | cv2.line(im_crop,(x1-x_range[0],y1-y_range[0]),(x2-x_range[0],y2-y_range[0]), (0, 255, 0), 3 )
189 |
190 | cv2.circle(im_crop, (im_x-x_range[0], im_y-y_range[0]), 5, (0, 255, 0), -1)
191 | assert im_crop.shape == (660,60,3)
192 | im_crop = im_crop.transpose((2,0,1))
193 | else:
194 | im_crop = None
195 | return hist, fut, neighbors, im_crop # neighbors is a list of all vehicles in the batch
196 |
197 | def getHistory(self, vehId: int, t: int, refVehId: int, dsId: int) -> np.ndarray:
198 | """Get trajectory history. VehId and refVehId are 1-based."""
199 | if vehId == 0:
200 | return np.empty([0,2])
201 | else:
202 | if self.T.shape[1]<=vehId-1:
203 | return np.empty([0,2])
204 | vehTrack = self.T[dsId-1][vehId-1].transpose()
205 | if vehTrack.size==0 or np.argwhere(vehTrack[:, 0] == t).size==0:
206 | return np.empty([0,2])
207 | else:
208 | refTrack = self.T[dsId-1][refVehId-1].transpose()
209 | found = np.where(refTrack[:,0]==t)
210 | refPos = refTrack[found][0,1:3]
211 |
212 | stpt = np.maximum(0, np.argwhere(vehTrack[:, 0] == t).item() - self.t_h)
213 | enpt = np.argwhere(vehTrack[:, 0] == t).item() + 1
214 | hist = vehTrack[stpt:enpt:self.d_s,1:3]-refPos
215 |
216 | if self.data_aug:
217 | hist += np.random.randn( hist.shape[0],hist.shape[1] )*self.noise
218 |
219 | if len(hist) < self.t_h//self.d_s + 1:
220 | return np.empty([0,2])
221 | return hist
222 |
223 | def getFuture(self, vehId:int, t:int, dsId: int) -> np.ndarray :
224 | """Get future trajectory. VehId and dsId are 1-based."""
225 | vehTrack = self.T[dsId-1][vehId-1].transpose()
226 | refPos = vehTrack[np.where(vehTrack[:, 0] == t)][0, 1:3]
227 | stpt = np.argwhere(vehTrack[:, 0] == t).item() + self.d_s
228 | enpt = np.minimum(len(vehTrack), np.argwhere(vehTrack[:, 0] == t).item() + self.t_f + 1)
229 | fut = vehTrack[stpt:enpt:self.d_s,1:3]-refPos
230 | return fut
231 |
232 | def getHistory0b(self, vehId:int, t:int, dsId:int ) -> np.ndarray :
233 | """Get track history trajectory. VehId and dsId are zero-based.
234 | No normalizations are performed.
235 | """
236 | if vehId < 0:
237 | return np.empty([0,2])
238 | else:
239 | if vehId >= self.T.shape[1]:
240 | return np.empty([0,2])
241 | vehTrack = self.T[dsId][vehId].transpose()
242 | if vehTrack.size==0 or np.argwhere(vehTrack[:, 0] == t).size==0:
243 | return np.empty([0,2])
244 | else:
245 | stpt = np.maximum(0, np.argwhere(vehTrack[:, 0] == t).item() - self.t_h)
246 | enpt = np.argwhere(vehTrack[:, 0] == t).item() + 1
247 | hist = vehTrack[stpt:enpt:self.d_s,1:3]
248 | if len(hist) < self.t_h//self.d_s + 1:
249 | return np.empty([0,2])
250 | return hist
251 |
252 | def getFuture0b(self, vehId:int, t:int, dsId:int, refPos:int) -> np.ndarray :
253 | """Get track future trajectory. VehId and dsId are zero-based."""
254 | vehTrack = self.T[dsId][vehId].transpose()
255 | stpt = np.argwhere(vehTrack[:, 0] == t).item() + self.d_s
256 | enpt = np.minimum(len(vehTrack), np.argwhere(vehTrack[:, 0] == t).item() + self.t_f + 1)
257 | fut = vehTrack[stpt:enpt:self.d_s,1:3]-refPos
258 | return fut
259 |
260 | def collate_fn(self, samples: List[Any]) -> Tuple[Any,Any,Any,Any,Any,Union[Any,None],Any] :
261 | """Prepare a batch suitable for MFP training."""
262 | nbr_batch_size = 0
263 | num_samples = 0
264 | for _,_,nbrs,im_crop in samples:
265 | nbr_batch_size += sum([len(nbr) for nbr in nbrs.values() ])
266 | num_samples += len(nbrs)
267 |
268 | maxlen = self.t_h//self.d_s + 1
269 | if nbr_batch_size <= 0:
270 | nbrs_batch = torch.zeros(maxlen,1,2)
271 | else:
272 | nbrs_batch = torch.zeros(maxlen,nbr_batch_size,2)
273 |
274 | pos = [0, 0]
275 | nbr_inds_batch = torch.zeros( num_samples, self.grid_size[1],self.grid_size[0], self.enc_size*self.enc_fac)
276 | nbr_inds_batch = nbr_inds_batch.byte()
277 |
278 | hist_batch = torch.zeros(maxlen, num_samples, 2) #e.g. (31, 41, 2)
279 | fut_batch = torch.zeros(self.t_f//self.d_s, num_samples, 2)
280 | mask_batch = torch.zeros(self.t_f//self.d_s, num_samples, 2)
281 | if self.use_context:
282 | context_batch = torch.zeros(num_samples, im_crop.shape[0], im_crop.shape[1], im_crop.shape[2] )
283 | else:
284 | context_batch: Union[None, torch.Tensor] = None # type: ignore
285 |
286 | nbrs_infos = []
287 | count = 0
288 | samples_so_far = 0
289 | for sampleId,(hist, fut, nbrs, context) in enumerate(samples):
290 | num = len(nbrs)
291 | for j in range(num):
292 | hist_batch[0:len(hist[j]), samples_so_far+j, :] = torch.from_numpy(hist[j])
293 | fut_batch[0:len(fut[j]), samples_so_far+j, :] = torch.from_numpy(fut[j])
294 | mask_batch[0:len(fut[j]),samples_so_far+j,:] = 1
295 | samples_so_far += num
296 |
297 | nbrs_infos.append(nbrs)
298 |
299 | if self.use_context:
300 | context_batch[sampleId,:,:,:] = torch.from_numpy(context)
301 |
302 | # nbrs is a dictionary of key to list of nbr (batch_index, veh_id, grid_ind)
303 | for batch_ind, list_of_nbr in nbrs.items():
304 | for batch_id, vehid, grid_ind in list_of_nbr:
305 | if batch_id >= 0:
306 | nbr_hist = hist[batch_id]
307 | nbrs_batch[0:len(nbr_hist),count,:] = torch.from_numpy( nbr_hist )
308 | pos[0] = grid_ind % self.grid_size[0]
309 | pos[1] = grid_ind // self.grid_size[0]
310 | nbr_inds_batch[batch_ind,pos[1],pos[0],:] = torch.ones(self.enc_size*self.enc_fac).byte()
311 | count+=1
312 |
313 | return (hist_batch, nbrs_batch, nbr_inds_batch, fut_batch, mask_batch, context_batch, nbrs_infos)
314 |
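# Example usage (an illustrative sketch, not part of the original file; the
# .mat path assumes the preprocessed NGSIM files described in the README):
#
#   dataset = NgsimDataset('multiple_futures_prediction/ngsim_data/TrainSet.mat',
#                          t_h=30, t_f=50, d_s=2)
#   loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)
#   hist, nbrs, nbr_inds, fut, mask, context, nbrs_info = next(iter(loader))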
--------------------------------------------------------------------------------
/multiple_futures_prediction/model_ngsim.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from typing import List, Set, Dict, Tuple, Optional, Union, Any
7 | import torch
8 | import torch.nn as nn
9 | from multiple_futures_prediction.my_utils import *
10 |
11 | # Multiple Futures Prediction Network
12 | class mfpNet(nn.Module):
13 | def __init__(self, args: Dict) -> None:
14 | super(mfpNet, self).__init__() #type: ignore
15 | self.use_cuda = args['use_cuda']
16 | self.encoder_size = args['encoder_size']
17 | self.decoder_size = args['decoder_size']
18 | self.out_length = args['fut_len_orig_hz']//args['subsampling']
19 |
20 | self.dyn_embedding_size = args['dyn_embedding_size']
21 | self.input_embedding_size = args['input_embedding_size']
22 |
23 | self.nbr_atten_embedding_size = args['nbr_atten_embedding_size']
24 | self.st_enc_hist_size = self.nbr_atten_embedding_size
25 | self.st_enc_pos_size = args['dec_nbr_enc_size']
26 | self.use_gru = args['use_gru']
27 | self.bi_direc = args['bi_direc']
28 | self.use_context = args['use_context']
29 | self.modes = args['modes']
30 | self.use_forcing = args['use_forcing'] # 1: Teacher forcing. 2:classmates forcing.
31 |
32 | self.hidden_fac = 2 if args['use_gru'] else 1
33 | self.bi_direc_fac = 2 if args['bi_direc'] else 1
34 | self.dec_fac = 2 if args['bi_direc'] else 1
35 |
36 | self.init_rbf_state_enc( in_dim=self.encoder_size*self.hidden_fac )
37 | self.posi_enc_dim = self.st_enc_pos_size
38 | self.posi_enc_ego_dim = 2
39 |
40 | # Input embedding layer
41 | self.ip_emb = torch.nn.Linear(2,self.input_embedding_size) #type: ignore
42 |
43 | # Encoding RNN.
44 | if not self.use_gru:
45 | self.enc_lstm = torch.nn.LSTM(self.input_embedding_size,self.encoder_size,1) # type: ignore
46 | else:
47 | self.num_layers=2
48 | self.enc_lstm = torch.nn.GRU(self.input_embedding_size,self.encoder_size, # type: ignore
49 | num_layers=self.num_layers, bidirectional=False)
50 |
51 | # Dynamics embeddings.
52 | self.dyn_emb = torch.nn.Linear(self.encoder_size*self.hidden_fac, self.dyn_embedding_size) #type: ignore
53 |
54 | context_feat_size = 64 if self.use_context else 0
55 | self.dec_lstm = []
56 | self.op = []
57 | for k in range(self.modes):
58 | if not self.use_gru:
59 | self.dec_lstm.append( torch.nn.LSTM(self.nbr_atten_embedding_size + self.dyn_embedding_size + #type: ignore
60 | context_feat_size+self.posi_enc_dim+self.posi_enc_ego_dim, self.decoder_size) )
61 | else:
62 | self.num_layers=2
63 | self.dec_lstm.append( torch.nn.GRU(self.nbr_atten_embedding_size + self.dyn_embedding_size + context_feat_size+self.posi_enc_dim+self.posi_enc_ego_dim, # type: ignore
64 | self.decoder_size, num_layers=self.num_layers, bidirectional=self.bi_direc ))
65 |
66 | self.op.append( torch.nn.Linear(self.decoder_size*self.dec_fac, 5) ) #type: ignore
67 |
70 |
71 | self.dec_lstm = torch.nn.ModuleList(self.dec_lstm) # type: ignore
72 | self.op = torch.nn.ModuleList(self.op ) # type: ignore
73 |
74 | self.op_modes = torch.nn.Linear(self.nbr_atten_embedding_size + self.dyn_embedding_size + context_feat_size, self.modes) #type: ignore
75 |
76 | # Nonlinear activations.
77 | self.leaky_relu = torch.nn.LeakyReLU(0.1) #type: ignore
78 | self.relu = torch.nn.ReLU() #type: ignore
79 | self.softmax = torch.nn.Softmax(dim=1) #type: ignore
80 |
81 | if self.use_context:
82 | self.context_conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=2) #type: ignore
83 | self.context_conv2 = torch.nn.Conv2d(16, 16, kernel_size=3, stride=2) #type: ignore
84 | self.context_maxpool = torch.nn.MaxPool2d(kernel_size=(4,2)) #type: ignore
85 | self.context_conv3 = torch.nn.Conv2d(16, 16, kernel_size=3, stride=2) #type: ignore
86 | self.context_fc = torch.nn.Linear(16*20*3, context_feat_size) #type: ignore
87 |
88 | def init_rbf_state_enc(self, in_dim: int ) -> None:
89 | """Initialize the dynamic attentional RBF encoder.
90 | Args:
91 | in_dim is the input dim of the observation.
92 | """
93 | self.sec_in_dim = in_dim
94 | self.extra_pos_dim = 2
95 |
96 | self.sec_in_pos_dim = 2
97 | self.sec_key_dim = 8
98 | self.sec_key_hidden_dim = 32
99 |
100 | self.sec_hidden_dim = 32
101 | self.scale = 1.0
102 | self.slot_key_scale = 1.0
103 | self.num_slots = 8
104 | self.slot_keys = []
105 |
106 | # Network for computing the 'key'
107 | self.sec_key_net = torch.nn.Sequential( #type: ignore
108 | torch.nn.Linear(self.sec_in_dim+self.extra_pos_dim, self.sec_key_hidden_dim),
109 | torch.nn.ReLU(),
110 | torch.nn.Linear(self.sec_key_hidden_dim, self.sec_key_dim)
111 | )
112 |
113 | for ss in range(self.num_slots):
114 | self.slot_keys.append( torch.nn.Parameter( self.slot_key_scale*torch.randn( self.sec_key_dim, 1, dtype=torch.float32) ) ) #type: ignore
115 | self.slot_keys = torch.nn.ParameterList( self.slot_keys ) # type: ignore
116 |
117 | # Network for encoding a scene-level contextual feature.
118 | self.sec_hist_net = torch.nn.Sequential( #type: ignore
119 | torch.nn.Linear(self.sec_in_dim*self.num_slots, self.sec_hidden_dim),
120 | torch.nn.ReLU(),
121 | torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim),
122 | torch.nn.ReLU(),
123 | torch.nn.Linear(self.sec_hidden_dim, self.st_enc_hist_size)
124 | )
125 |
126 | # Encoder position of other's into a feature network, input should be normalized to ref_pos.
127 | self.sec_pos_net = torch.nn.Sequential( #type: ignore
128 | torch.nn.Linear(self.sec_in_pos_dim*self.num_slots, self.sec_hidden_dim),
129 | torch.nn.ReLU(),
130 | torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim),
131 | torch.nn.ReLU(),
132 | torch.nn.Linear(self.sec_hidden_dim, self.st_enc_pos_size)
133 | )
134 |
135 | def rbf_state_enc_get_attens(self, nbrs_enc: torch.Tensor, ref_pos: torch.Tensor, nbrs_info_this: List ) -> List[torch.Tensor]:
136 | """Computing the attention over other agents.
137 | Args:
138 | nbrs_info_this is a list of list of (nbr_batch_ind, nbr_id, nbr_ctx_ind)
139 | Returns:
140 | attention weights over the neighbors.
141 | """
142 | assert len(nbrs_info_this) == ref_pos.shape[0]
143 | if self.extra_pos_dim > 0:
144 | pos_enc = torch.zeros(nbrs_enc.shape[0],2, device=nbrs_enc.device)
145 | counter = 0
146 | for n in range(len(nbrs_info_this)):
147 | for nbr in nbrs_info_this[n]:
148 | pos_enc[counter,:] = ref_pos[nbr[0],:] - ref_pos[n,:]
149 | counter += 1
150 | Key = self.sec_key_net( torch.cat( (nbrs_enc,pos_enc),dim=1) )
151 | # e.g. num_agents by self.sec_key_dim
152 | else:
153 | Key = self.sec_key_net( nbrs_enc ) # e.g. num_agents by self.sec_key_dim
154 |
155 | attens0 = []
156 | for slot in self.slot_keys:
157 | attens0.append( torch.exp( -self.scale*(Key-torch.t(slot)).norm(dim=1)) )
158 |
159 | Atten = torch.stack(attens0, dim=0) # e.g. num_keys x num_agents
160 | attens = []
161 | counter = 0
162 | for n in range(len(nbrs_info_this)):
163 | list_of_nbrs = nbrs_info_this[n]
164 | counter2 = counter+len(list_of_nbrs)
165 | attens.append( Atten[:, counter:counter2 ] )
166 | counter = counter2
167 | return attens
168 |
169 | def rbf_state_enc_hist_fwd(self, attens: List, nbrs_enc: torch.Tensor, nbrs_info_this: List) -> torch.Tensor:
170 | """Computes dynamic state encoding.
171 | Computes dynica state encoding with precomputed attention tensor and the
172 | RNN based encoding.
173 | Args:
174 | attens is a list of [ [slots x num_neighbors]]
175 | nbrs_enc is num_agents by input_dim
176 | Returns:
177 | feature vector
178 | """
179 | out = []
180 | counter = 0
181 | for n in range(len(nbrs_info_this)):
182 | list_of_nbrs = nbrs_info_this[n]
183 | if len(list_of_nbrs) > 0:
184 | counter2 = counter+len(list_of_nbrs)
185 | nbr_feat = nbrs_enc[counter:counter2,:]
186 | out.append( torch.mm( attens[n], nbr_feat ) )
187 | counter = counter2
188 | else:
189 | out.append( torch.zeros(self.num_slots, nbrs_enc.shape[1] ).to(nbrs_enc.device) )
190 | # if no neighbors found, use all zeros.
191 | st_enc = torch.stack(out, dim=0).view(len(out),-1) # num_agents by slots*enc dim
192 | return self.sec_hist_net(st_enc)
193 |
194 | def rbf_state_enc_pos_fwd(self, attens: List, ref_pos: torch.Tensor, fut_t: torch.Tensor, flatten_inds: torch.Tensor, chunks: List) -> torch.Tensor:
195 | """Computes the features from dynamic attention for interactive rollouts.
196 | Args:
197 | attens is a list of [ [slots x num_neighbors]]
198 | ref_pos should be (num_agents by 2)
199 | Returns:
200 | feature vector
201 | """
202 | fut = fut_t + ref_pos #convert to 'global' frame
203 | nbr_feat = torch.index_select( fut, 0, flatten_inds)
204 | splits = torch.split(nbr_feat, chunks, dim=0) #type: ignore
205 | out = []
206 | for n, nbr_feat in enumerate(splits):
207 | out.append( torch.mm( attens[n], nbr_feat - ref_pos[n,:] ) )
208 | pos_enc = torch.stack(out, dim=0).view(len(attens),-1) # num_agents by slots*enc dim
209 | return self.sec_pos_net(pos_enc)
210 |
211 |
212 | def forward_mfp(self, hist:torch.Tensor, nbrs:torch.Tensor, masks:torch.Tensor, context:Any,
213 | nbrs_info:List, fut:torch.Tensor, bStepByStep:bool,
214 | use_forcing:Optional[Union[None,int]]=None) -> Tuple[List[torch.Tensor], Any]:
215 | """Forward propagation function for the MFP
216 |
217 | Encodes trajectory histories, attends over neighboring agents, predicts
218 | a distribution over latent modes, and decodes future trajectories.
219 | Args:
220 | hist: Trajectory history.
221 | nbrs: Neighbors.
222 | masks: Neighbors mask.
223 | context: contextual information in image form (if used).
224 | nbrs_info: information as to which other agents are neighbors.
225 | fut: Future Trajectory.
226 | bStepByStep: During rollout, interactive or independent.
227 | use_forcing: Teacher-forcing or classmate forcing.
228 |
229 | Returns:
230 | fut_pred: a list of predictions, one for each mode.
231 | modes_pred: prediction over latent modes.
232 | """
233 | use_forcing = self.use_forcing if use_forcing is None else use_forcing
234 |
235 | # Normalize to reference position.
236 | ref_pos = hist[-1,:,:]
237 | hist = hist - ref_pos.view(1,-1,2)
238 |
239 | # Encode history trajectories.
240 | if isinstance(self.enc_lstm, torch.nn.modules.rnn.GRU):
241 | _, hist_enc = self.enc_lstm(self.leaky_relu(self.ip_emb(hist)))
242 | else:
243 | _,(hist_enc,_) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist))) #hist torch.Size([16, 128, 2])
244 |
245 | if self.use_gru:
246 | hist_enc = hist_enc.permute(1,0,2).contiguous()
247 | hist_enc = self.leaky_relu(self.dyn_emb( hist_enc.view(hist_enc.shape[0], -1) ))
248 | else:
249 | hist_enc = self.leaky_relu(self.dyn_emb(hist_enc.view(hist_enc.shape[1],hist_enc.shape[2]))) #torch.Size([128, 32])
250 |
251 | num_nbrs = sum([len(nbs) for nb_id, nbs in nbrs_info[0].items() ])
252 | if num_nbrs > 0:
253 | nbrs_ref_pos = nbrs[-1,:,:]
254 | nbrs = nbrs - nbrs_ref_pos.view(1,-1,2) # normalize
255 |
256 | # Forward pass for all neighbors.
257 | if isinstance(self.enc_lstm, torch.nn.modules.rnn.GRU):
258 | _, nbrs_enc = self.enc_lstm(self.leaky_relu(self.ip_emb(nbrs)))
259 | else:
260 | _, (nbrs_enc,_) = self.enc_lstm(self.leaky_relu(self.ip_emb(nbrs)))
261 |
262 | if self.use_gru:
263 | nbrs_enc = nbrs_enc.permute(1,0,2).contiguous()
264 | nbrs_enc = nbrs_enc.view(nbrs_enc.shape[0], -1)
265 | else:
266 | nbrs_enc = nbrs_enc.view(nbrs_enc.shape[1], nbrs_enc.shape[2])
267 |
268 | attens = self.rbf_state_enc_get_attens(nbrs_enc, ref_pos, nbrs_info[0])
269 | nbr_atten_enc = self.rbf_state_enc_hist_fwd(attens, nbrs_enc, nbrs_info[0])
270 |
271 | else: # if have no neighbors
272 | attens = None # type: ignore
273 | nbr_atten_enc = torch.zeros( 1, self.nbr_atten_embedding_size, dtype=torch.float32, device=masks.device )
274 |
275 | if self.use_context: #context encoding
276 | context_enc = self.relu(self.context_conv( context ))
277 | context_enc = self.context_maxpool( self.context_conv2( context_enc ))
278 | context_enc = self.relu(self.context_conv3(context_enc))
279 | context_enc = self.context_fc( context_enc.view( context_enc.shape[0], -1) )
280 |
281 | enc = torch.cat((nbr_atten_enc, hist_enc, context_enc),1)
282 | else:
283 | enc = torch.cat((nbr_atten_enc, hist_enc),1)
284 | # e.g. nbr_atten_enc: [num_agents by 80], hist_enc: [num_agents by 32], enc would be [num_agents by 112]
285 |
286 | ######################################################################################################
287 | modes_pred = None if self.modes==1 else self.softmax(self.op_modes(enc))
288 | fut_pred = self.decode(enc, attens, nbrs_info[0], ref_pos, fut, bStepByStep, use_forcing)
289 | return fut_pred, modes_pred
290 |
291 | def decode(self, enc: torch.Tensor, attens:List, nbrs_info_this:List, ref_pos:torch.Tensor, fut:torch.Tensor, bStepByStep:bool, use_forcing:Any ) -> List[torch.Tensor]:
292 | """Decode the future trajectory using RNNs.
293 |
294 | Given computed feature vector, decode the future with multimodes, using
295 | dynamic attention and either interactive or non-interactive rollouts.
296 | Args:
297 | enc: encoded features, one per agent.
298 | attens: attentional weights, a list of tensors, each with dimensions of e.g. [8 x 4].
299 | nbrs_info_this: information on which agents are neighbors.
300 | ref_pos: the current position (reference position) of the agents.
301 | fut: future trajectory (only useful for teacher or classmate forcing)
302 | bStepByStep: interactive or non-interactive rollout
303 | use_forcing: 0: None. 1: Teacher-forcing. 2: classmate forcing.
304 |
305 | Returns:
306 | fut_preds: a list of predictions, one for each mode
307 | (the distribution over modes is predicted separately in forward_mfp).
308 | """
309 | if not bStepByStep: # Non-interactive rollouts
310 | enc = enc.repeat(self.out_length, 1, 1)
311 | pos_enc = torch.zeros( self.out_length, enc.shape[1], self.posi_enc_dim+self.posi_enc_ego_dim, device=enc.device )
312 | enc2 = torch.cat( (enc, pos_enc), dim=2)
313 | fut_preds = []
314 | for k in range(self.modes):
315 | h_dec, _ = self.dec_lstm[k](enc2)
316 | h_dec = h_dec.permute(1, 0, 2)
317 | fut_pred = self.op[k](h_dec)
318 | fut_pred = fut_pred.permute(1, 0, 2) #torch.Size([nSteps, num_agents, 5])
319 |
320 | fut_pred = Gaussian2d(fut_pred)
321 | fut_preds.append(fut_pred)
322 | return fut_preds
323 | else:
324 | batch_sz = enc.shape[0]
325 | inds = []
326 | chunks = []
327 | for n in range(len(nbrs_info_this)):
328 | chunks.append( len(nbrs_info_this[n]) )
329 | for nbr in nbrs_info_this[n]:
330 | inds.append(nbr[0])
331 | flat_index = torch.LongTensor(inds).to(ref_pos.device) # type: ignore
332 |
333 | fut_preds = []
334 | for k in range(self.modes):
335 | direc = 2 if self.bi_direc else 1
336 | hidden = torch.zeros(self.num_layers*direc, batch_sz, self.decoder_size).to(fut.device)
337 | preds: List[torch.Tensor] = []
338 | for t in range(self.out_length):
339 | if t == 0: # Initial timestep.
340 | if use_forcing == 0:
341 | pred_fut_t = torch.zeros_like(fut[t,:,:])
342 | ego_fut_t = pred_fut_t
343 | elif use_forcing == 1:
344 | pred_fut_t = fut[t,:,:]
345 | ego_fut_t = pred_fut_t
346 | else:
347 | pred_fut_t = fut[t,:,:]
348 | ego_fut_t = torch.zeros_like(pred_fut_t)
349 | else:
350 | if use_forcing == 0:
351 | pred_fut_t = preds[-1][:,:,:2].squeeze()
352 | ego_fut_t = pred_fut_t
353 | elif use_forcing == 1:
354 | pred_fut_t = fut[t,:,:]
355 | ego_fut_t = pred_fut_t
356 | else:
357 | pred_fut_t = fut[t,:,:]
358 | ego_fut_t = preds[-1][:,:,:2]
359 |
360 | if attens is None:
361 | pos_enc = torch.zeros(batch_sz, self.posi_enc_dim, device=enc.device )
362 | else:
363 | pos_enc = self.rbf_state_enc_pos_fwd(attens, ref_pos, pred_fut_t, flat_index, chunks )
364 |
365 | enc_large = torch.cat( ( enc.view(1,enc.shape[0],enc.shape[1]),
366 | pos_enc.view(1,batch_sz, self.posi_enc_dim),
367 | ego_fut_t.view(1, batch_sz, self.posi_enc_ego_dim ) ), dim=2 )
368 |
369 | out, hidden = self.dec_lstm[k]( enc_large, hidden)
370 | pred = Gaussian2d(self.op[k](out))
371 | preds.append( pred )
372 | fut_pred_k = torch.cat(preds,dim=0)
373 | fut_preds.append(fut_pred_k)
374 | return fut_preds
375 |
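# Example usage (an illustrative sketch, not part of the original file; the
# args dict mirrors the Params keys in configs/mfp2_ngsim.gin, and the input
# tensors have the shapes produced by NgsimDataset.collate_fn):
#
#   net = mfpNet(args)
#   fut_preds, modes_pred = net.forward_mfp(hist, nbrs, masks, context,
#                                           nbrs_info, fut, bStepByStep=False)
#   # fut_preds: one [T_fut, num_agents, 5] tensor per mode;
#   # modes_pred: [num_agents, modes] probabilities (None when modes == 1).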
--------------------------------------------------------------------------------
/multiple_futures_prediction/my_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from typing import List, Set, Dict, Tuple, Optional, Union
7 | from scipy import special
8 | import numpy as np
9 | import torch
10 | import json
11 | import os
12 |
13 |
14 | """Helper functions for loss or normalizations.
15 | """
16 |
17 | def Gaussian2d(x: torch.Tensor) -> torch.Tensor :
18 | """Computes the parameters of a bivariate 2D Gaussian."""
19 | x_mean = x[:,:,0]
20 | y_mean = x[:,:,1]
21 | sigma_x = torch.exp(x[:,:,2])
22 | sigma_y = torch.exp(x[:,:,3])
23 | rho = torch.tanh(x[:,:,4])
24 | return torch.stack([x_mean, y_mean, sigma_x, sigma_y, rho], dim=2)
25 |
26 | def nll_loss( pred: torch.Tensor, data: torch.Tensor, mask:torch.Tensor ) -> torch.Tensor :
27 | """NLL averages across steps, samples, and dimensions(x,y)."""
28 | x_mean = pred[:,:,0]
29 | y_mean = pred[:,:,1]
30 | x_sigma = pred[:,:,2]
31 | y_sigma = pred[:,:,3]
32 | rho = pred[:,:,4]
33 | ohr = torch.pow(1-torch.pow(rho,2),-0.5) # type: ignore
34 | x = data[:,:, 0]; y = data[:,:, 1]
35 | results = torch.pow(ohr, 2)*(torch.pow(x_sigma, 2)*torch.pow(x-x_mean, 2) + torch.pow(y_sigma, 2)*torch.pow(y-y_mean, 2)
36 | -2*rho*torch.pow(x_sigma, 1)*torch.pow(y_sigma, 1)*(x-x_mean)*(y-y_mean)) - torch.log(x_sigma*y_sigma*ohr)
37 |
38 | results = results*mask[:,:,0]
39 | assert torch.sum(mask) > 0.0
40 | return torch.sum(results)/torch.sum(mask[:,:,0])
41 |
42 | def nll_loss_per_sample( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor ) -> torch.Tensor :
43 | """NLL averages across steps and dimensions, but not samples (agents)."""
44 | x_mean = pred[:,:,0]
45 | y_mean = pred[:,:,1]
46 | x_sigma = pred[:,:,2]
47 | y_sigma = pred[:,:,3]
48 | rho = pred[:,:,4]
49 | ohr = torch.pow(1-torch.pow(rho,2),-0.5) # type: ignore
50 | x = data[:,:, 0]; y = data[:,:, 1]
51 | results = torch.pow(ohr, 2)*(torch.pow(x_sigma, 2)*torch.pow(x-x_mean, 2) + torch.pow(y_sigma, 2)*torch.pow(y-y_mean, 2)
52 | -2*rho*torch.pow(x_sigma, 1)*torch.pow(y_sigma, 1)*(x-x_mean)*(y-y_mean)) - torch.log(x_sigma*y_sigma*ohr)
53 | results = results*mask[:,:,0] # nSteps by nBatch
54 | return torch.sum(results, dim=0)/torch.sum(mask[:,:,0], dim=0)
55 |
56 | def nll_loss_test( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
57 | """NLL for testing cases, returns a vector over future timesteps."""
58 | x_mean = pred[:,:,0]
59 | y_mean = pred[:,:,1]
60 | x_sigma = pred[:,:,2]
61 | y_sigma = pred[:,:,3]
62 | rho = pred[:,:,4]
63 | ohr = torch.pow(1-torch.pow(rho,2),-0.5) # type: ignore
64 | x = data[:,:, 0]; y = data[:,:, 1]
65 | results = torch.pow(ohr, 2)*(torch.pow(x_sigma, 2)*torch.pow(x-x_mean, 2) + torch.pow(y_sigma, 2)*torch.pow(y-y_mean, 2)
66 | -2*rho*torch.pow(x_sigma, 1)*torch.pow(y_sigma, 1)*(x-x_mean)*(y-y_mean)) - torch.log(x_sigma*y_sigma*ohr)
67 | results = results*mask[:,:,0] # nSteps by nBatch
68 | assert torch.sum(mask) > 0.0
69 | counts = torch.sum(mask[:, :, 0], dim=1)
70 | return torch.sum(results, dim=1), counts
71 |
72 | def mse_loss( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor ) -> torch.Tensor:
73 | """Mean squared error loss."""
74 | x_mean = pred[:,:,0]
75 | y_mean = pred[:,:,1]
76 | x = data[:,:, 0]
77 | y = data[:,:, 1]
78 | results = torch.pow(x-x_mean, 2) + torch.pow(y-y_mean, 2)
79 | results = results*mask[:,:,0]
80 | return torch.sum(results)/torch.sum(mask[:,:,0])
81 |
82 | def mse_loss_test( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]:
83 | """Mean squared error loss for test time."""
84 | x_mean = pred[:,:,0]
85 | y_mean = pred[:,:,1]
86 | x = data[:,:,0]
87 | y = data[:,:,1]
88 | results = torch.pow(x-x_mean, 2) + torch.pow(y-y_mean, 2)
89 | results = results*mask[:,:,0]
90 | counts = torch.sum(mask[:,:,0],dim=1)
91 | lossVal = torch.sum(results,dim=1)
92 | return lossVal, counts
93 |
94 | def logsumexp(inputs: torch.Tensor, dim: Optional[int] =None, keepdim: Optional[bool] =False) -> torch.Tensor:
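"""Numerically stable log-sum-exp: subtracts the max before exponentiating."""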
95 | if dim is None:
96 | inputs = inputs.view(-1)
97 | dim = 0
98 | s, _ = torch.max(inputs, dim=dim, keepdim=True)
99 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log()
100 | if not keepdim:
101 | outputs = outputs.squeeze(dim)
102 | return outputs
103 |
104 | def nll_loss_test_multimodes(pred: List[torch.Tensor], data: torch.Tensor, mask: torch.Tensor, modes_pred: torch.Tensor, y_mean: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor] :
105 | """NLL loss multimodes for test time."""
106 | modes = len(pred)
107 | nSteps, batch_sz, dim = pred[0].shape
108 | total = torch.zeros(mask.shape[0],mask.shape[1], modes).to(y_mean.device)
109 | count = 0
110 | for k in range(modes):
111 | wts = modes_pred[:,k]
112 | wts = wts.repeat(nSteps,1)
113 | y_pred = pred[k]
114 | if y_mean is not None:
115 | x_pred_mean = y_pred[:, :, 0]+y_mean[:,0].view(-1,1)
116 | y_pred_mean = y_pred[:, :, 1]+y_mean[:,1].view(-1,1)
117 | else:
118 | x_pred_mean = y_pred[:, :, 0]
119 | y_pred_mean = y_pred[:, :, 1]
120 | x_sigma = y_pred[:, :, 2]
121 | y_sigma = y_pred[:, :, 3]
122 | rho = y_pred[:, :, 4]
123 | ohr = torch.pow(1 - torch.pow(rho, 2), -0.5) # type: ignore
124 | x = data[:, :, 0]
125 | y = data[:, :, 1]
126 | out = -(torch.pow(ohr, 2) * (torch.pow(x_sigma, 2) * torch.pow(x - x_pred_mean, 2) + torch.pow(y_sigma, 2) * torch.pow(y - y_pred_mean,2)
127 | -2 * rho * torch.pow(x_sigma, 1) * torch.pow(y_sigma, 1) * (x - x_pred_mean) * (y - y_pred_mean)) - torch.log(x_sigma * y_sigma * ohr))
128 | total[:, :, count] = out + torch.log(wts)
129 | count += 1
130 | total = -logsumexp(total,dim = 2)
131 | total = total * mask[:,:,0]
132 | lossVal = torch.sum(total,dim=1)
133 | counts = torch.sum(mask[:,:,0],dim=1)
134 | return lossVal, counts
135 |
136 | def nll_loss_multimodes(pred: List[torch.Tensor], data: torch.Tensor, mask: torch.Tensor, modes_pred: torch.Tensor, noise: Optional[float]=0.0 ) -> torch.Tensor:
137 | """NLL loss multimodes for training.
138 | Args:
139 | pred is a list (with N modes) of predictions
140 | data is ground truth
141 | noise is optional
142 | """
143 | modes = len(pred)
144 | nSteps, batch_sz, dim = pred[0].shape
145 | log_lik = np.zeros( (batch_sz, modes) )
146 | with torch.no_grad():
147 | for kk in range(modes):
148 | nll = nll_loss_per_sample(pred[kk], data, mask)
149 | log_lik[:,kk] = -nll.cpu().numpy()
150 |
151 | priors = modes_pred.detach().cpu().numpy()
152 |
153 | log_posterior_unnorm = log_lik + np.log(priors).reshape((-1, modes)) #[TotalObjs, net.modes]
154 | log_posterior_unnorm += np.random.randn( *log_posterior_unnorm.shape)*noise
155 | log_posterior = log_posterior_unnorm - special.logsumexp( log_posterior_unnorm, axis=1 ).reshape((batch_sz, 1))
156 | post_pr = np.exp(log_posterior) #[TotalObjs, net.modes]
157 |
158 | post_pr = torch.tensor(post_pr).float().to(data.device)
159 | loss = 0.0
160 | for kk in range(modes):
161 | nll_k = nll_loss_per_sample(pred[kk], data, mask)*post_pr[:,kk]
162 | loss += nll_k.sum()/float(batch_sz)
163 |
164 | kl_loss = torch.nn.KLDivLoss(reduction='batchmean') #type: ignore
165 | loss += kl_loss( torch.log(modes_pred), post_pr)
166 | return loss
167 |
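   | # Editorial note: nll_loss_multimodes is, in effect, one EM-style update. The
   | # no_grad block is the E-step, computing a posterior over modes
   | # post_pr[i, k] proportional to prior_k(i) * exp(log_lik[i, k]) (optionally
   | # perturbed by Gaussian noise in the log domain). The posterior-weighted sum of
   | # per-mode NLLs is the M-step for the decoders, and the KL term trains the mode
   | # classifier modes_pred toward the fixed (detached) posterior.
   |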
168 | ################################################################################
169 | def load_json_file(json_filename: str) -> dict:
170 | with open(json_filename) as json_file:
171 | json_dictionary = json.load(json_file)
172 | return json_dictionary
173 |
174 | def write_json_file(json_filename: str, json_dict: dict, pretty: bool = False) -> None:
175 | with open(os.path.expanduser(json_filename), 'w') as outfile:
176 | if pretty:
177 | json.dump(json_dict, outfile, sort_keys=True, indent=2)
178 | else:
179 | json.dump(json_dict, outfile, sort_keys=True)
180 |
181 | def pi(obj: Union[torch.Tensor, np.ndarray]) -> None:
182 | """ Prints out some info."""
183 | if isinstance(obj, torch.Tensor):
184 | print(str(obj.shape), end=' ')
185 | print(str(obj.device), end=' ')
186 | print( 'min:', float(obj.min() ), end=' ')
187 | print( 'max:', float(obj.max() ), end=' ')
188 | print( 'std:', float(obj.std() ), end=' ')
189 | print(str(obj.dtype) )
190 | elif isinstance(obj, np.ndarray):
191 | print(str(obj.shape), end=' ')
192 | print( 'min:', float(obj.min() ), end=' ')
193 | print( 'max:', float(obj.max() ), end=' ')
194 | print( 'std:', float(obj.std() ), end=' ')
195 | print(str(obj.dtype) )
196 |
197 | def compute_angles(x_mean: torch.Tensor, num_steps: int = 3) -> np.ndarray:
198 | """Compute each trajectory's mean 2D heading angle over the first num_steps steps.
199 | Args:
200 | x_mean is [nSteps, nObjs, dim]
201 | """
202 | nSteps, nObjs, dim = x_mean.shape
203 | thetas = np.zeros( (nObjs, num_steps))
204 | for k in range(num_steps):
205 | for o in range(nObjs):
206 | diff = x_mean[k+1,o,:] - x_mean[k,o,:]
207 | thetas[o,k] = np.arctan2(diff[1], diff[0])
208 | return thetas.mean(axis=1)
209 |
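   | # Caveat (editorial note): averaging raw angles can wrap around +/-pi; the mean of
   | # +179 and -179 degrees comes out near 0 rather than 180. Where that matters, a
   | # circular mean is one possible alternative (a sketch, not part of the original):
   | #   theta = np.arctan2(np.sin(thetas).mean(axis=1), np.cos(thetas).mean(axis=1))
   |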
210 | def rotate_to(data: np.ndarray, theta0: np.ndarray, x0: np.ndarray) -> np.ndarray:
211 | """Rotate data about location x0 with theta0 in radians.
212 | Args:
213 | data is [nSteps, dim] or [nSteps, nObjs, dim]
214 | """
215 | rot = np.array( [ [np.cos(theta0), np.sin(theta0) ],
216 | [ -np.sin(theta0), np.cos(theta0)] ] )
217 | if len(data.shape) == 2:
218 | return np.dot( data - x0, rot.T)
219 | else:
220 | nSteps, nObjs, dim = data.shape
221 | return np.dot( data.reshape((-1,dim))-x0, rot.T).reshape((nSteps, nObjs, dim))
222 |
223 | def rotate_to_inv(data: np.ndarray, theta0: np.ndarray, x0: np.ndarray) -> np.ndarray:
224 | """Inverse rotate data about location x0 with theta0 in radians.
225 | Args:
226 | data is [nSteps, dim] or [nSteps, nObjs, dim]
227 | """
228 | rot = np.array( [ [np.cos(theta0), -np.sin(theta0) ],
229 | [ np.sin(theta0), np.cos(theta0)] ] )
230 | if len(data.shape) == 2:
231 | return np.dot( data, rot.T) + x0
232 | else:
233 | nSteps, nObjs, dim = data.shape
234 | return (np.dot( data.reshape((-1,dim)), rot.T)+x0).reshape((nSteps, nObjs, dim))
235 |
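   | # Quick sanity check for the pair above (illustrative, assuming a 2-D [nSteps, dim] input):
   | # >>> d = np.random.randn(5, 2); x0 = np.array([1.0, 2.0])
   | # >>> np.allclose(rotate_to_inv(rotate_to(d, 0.3, x0), 0.3, x0), d)
   | # True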
236 |
--------------------------------------------------------------------------------
/multiple_futures_prediction/ngsim_data/.gitignore:
--------------------------------------------------------------------------------
1 | *.mat
2 |
--------------------------------------------------------------------------------
/multiple_futures_prediction/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/py.typed
--------------------------------------------------------------------------------
/multiple_futures_prediction/train_ngsim.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | from typing import List, Set, Dict, Tuple, Optional, Union, Any
7 | import numpy as np
8 | np.set_printoptions(suppress=True)
9 | import time
10 | import math
11 | import glob
12 | from attrdict import AttrDict
13 | import gin
14 | import torch
15 | from torch.utils.data import DataLoader
16 | from multiple_futures_prediction.dataset_ngsim import *
17 | from multiple_futures_prediction.model_ngsim import mfpNet
18 | from multiple_futures_prediction.my_utils import *
19 |
20 | def eval(metric: str, net: torch.nn.Module, params: AttrDict, data_loader: DataLoader, bStepByStep: bool,
21 | use_forcing: int, y_mean: np.ndarray, num_batches: int, dataset_name: str) -> torch.Tensor:
22 | """Evaluation function for validation and test data.
23 |
24 | Given an MFP network and a data loader, evaluate either NLL or RMSE error.
25 | """
26 | print('eval ', dataset_name)
27 | num = params.fut_len_orig_hz//params.subsampling
28 | lossVals = torch.zeros(num)
29 | counts = torch.zeros(num)
30 |
31 | for i, data in enumerate(data_loader):
32 | if i >= num_batches:
33 | break
34 | hist, nbrs, mask, fut, fut_mask, context, nbrs_info = data # mask: social grid mask; fut_mask: valid future steps
35 | if params.use_cuda:
36 | hist = hist.cuda()
37 | nbrs = nbrs.cuda()
38 | mask = mask.cuda()
39 | fut = fut.cuda()
40 | fut_mask = fut_mask.cuda()
41 | if context is not None:
42 | context = context.cuda()
43 |
44 | if metric == 'nll':
45 | fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep, use_forcing=use_forcing)
46 | if params.modes == 1:
47 | if params.remove_y_mean:
48 | fut_preds[0][:,:,:2] += y_mean.unsqueeze(1).to(fut.device)
49 | l, c = nll_loss_test(fut_preds[0], fut, fut_mask)
50 | else:
51 | l, c = nll_loss_test_multimodes(fut_preds, fut, fut_mask, modes_pred, y_mean.to(fut.device))
52 | else: # RMSE error
53 | assert params.modes == 1
54 | fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep, use_forcing=use_forcing)
55 | if params.modes == 1:
56 | if params.remove_y_mean:
57 | fut_preds[0][:,:,:2] += y_mean.unsqueeze(1).to(fut.device)
58 | l, c = mse_loss_test(fut_preds[0], fut, fut_mask)
59 |
60 | lossVals += l.detach().cpu()
61 | counts += c.detach().cpu()
62 |
63 | if metric == 'nll':
64 | err = lossVals / counts
65 | print(lossVals / counts)
66 | else:
67 | err = torch.pow(lossVals / counts, 0.5)*0.3048 # RMSE, converted from feet to meters (NGSIM units)
68 | print(err)
69 | return err
70 |
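   | # Note: both metrics are reported per future timestep -- lossVals/counts is a
   | # vector of length fut_len_orig_hz // subsampling -- and the RMSE branch converts
   | # from NGSIM's native feet to meters via the 0.3048 factor.
   |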
71 | def get_mean(train_data_loader: DataLoader, batches: int = 200) -> np.ndarray:
72 | """Compute the means over some samples from the training data."""
73 | yy = []
74 | counters = None
75 | for i, data in enumerate(train_data_loader):
76 | if i > batches:
77 | break
78 | hist, nbrs, _, fut, fut_mask, _, _ = data
79 | target = fut.cpu().numpy()
80 | valid = fut_mask.cpu().numpy().sum(axis=1)
81 |
82 | if counters is None:
83 | counters = np.zeros_like( valid )
84 | counters += valid
85 |
86 | isinvalid = (fut_mask == 0)
87 | target[isinvalid] = 0
88 | yy.append( target )
89 |
90 | Y = np.concatenate(yy, axis=1)
91 | y_mean = np.divide(np.sum(Y, axis=1), counters)
92 | return y_mean
93 |
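   | # Note: get_mean computes a masked per-timestep mean. Invalid future steps are
   | # zeroed before summing and `counters` accumulates only the valid entries, so the
   | # final divide is an average over observed (x, y) values in the sampled batches.
   |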
94 | def setup_logger(root_dir: str, scenario_name: str) -> Tuple[Any, Any]:
95 | """Set up the logging directory and open the log file."""
96 | import time
97 | import datetime
98 | import os
101 |
102 | timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y.%m.%d_%H.%M.%S')
103 | logging_dir = root_dir + "%s_%s/" % (scenario_name, timestamp)
104 | if not os.path.isdir(logging_dir):
105 | os.makedirs(logging_dir)
106 | os.makedirs(logging_dir+'/checkpoints')
107 | print ("! " + logging_dir + " CREATED!")
108 |
109 | logger_file = open(logging_dir + 'log', 'w')
110 | return logger_file, logging_dir
111 |
112 | ################################################################################
113 | @gin.configurable
114 | class Params(object):
115 | def __init__(self, log:bool=False, # save checkpoints?
116 | modes:int=2, # how many latent modes
117 | use_cuda:bool=True,
118 | encoder_size:int=16, # encoder latent layer size
119 | decoder_size:int=16, # decoder latent layer size
120 | subsampling:int=2, # temporal subsampling factor
121 | hist_len_orig_hz:int=30, # number of original history samples
122 | fut_len_orig_hz:int=50, # number of original future samples
123 | dyn_embedding_size:int=32, # dynamic embedding size
124 | input_embedding_size:int=32, # input embedding size
125 | dec_nbr_enc_size:int=8, # decoder neighbors encode size
126 | nbr_atten_embedding_size:int=80, # neighborhood attention embedding size
127 | seed:int=1234,
128 | remove_y_mean:bool=False, # normalize by removing the mean of the future trajectory
129 | use_gru:bool=True, # GRUs instead of LSTMs
130 | bi_direc:bool=False, # bidirectional encoder
131 | self_norm:bool=False, # normalize with respect to the current time
132 | data_aug:bool=False, # data augmentation
133 | use_context:bool=False, # use contextual image as additional input
134 | nll:bool=True, # negative log-likelihood loss
135 | use_forcing:int=0, # teacher forcing
136 | iter_per_err:int=100, # iterations to display errors
137 | iter_per_eval:int=1000, # iterations to eval on validation set
138 | pre_train_num_updates:int=200000, # how many iterations for pretraining
139 | updates_div_by_10:int=100000, # at what iteration to divide the learning rate by 10.0
140 | nbr_search_depth:int=10, # how deep do we search for neighbors
141 | lr_init:float=0.001, # initial learning rate
142 | min_lr:float=0.00005, # minimal learning rate
143 | iters_per_save:int=1500 ) -> None :
144 | self.params = AttrDict(locals())
145 | def __call__(self) -> Any:
146 | return self.params
147 | ################################################################################
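   | # Params is @gin.configurable, so every keyword default above can be overridden
   | # from a .gin file; the shipped config is configs/mfp2_ngsim.gin. A minimal
   | # illustrative binding (example values, not the shipped ones) would look like:
   | #
   | #   Params.modes = 2
   | #   Params.encoder_size = 64
   | #   Params.log = True
   | #
   | # parsed with gin.parse_config_file(<path>) before Params() is constructed
   | # (presumably what cmd/train_ngsim_cmd.py does).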
148 |
149 | def train(params: AttrDict) -> Any:
150 | """Main training function."""
151 | torch.manual_seed( params.seed ) #type: ignore
152 | np.random.seed( params.seed )
153 |
154 | ############################
155 | batch_size = 1
156 | data_hz = 10
157 | ns_between_samples = (1.0/data_hz)*1e9
158 | d_s = params.subsampling
159 | t_h = params.hist_len_orig_hz
160 | t_f = params.fut_len_orig_hz
161 | NUM_WORKERS = 1
162 |
163 | DATA_PATH = 'multiple_futures_prediction/'
164 |
165 | # Loading the dataset.
166 | train_set = NgsimDataset(DATA_PATH + 'ngsim_data/TrainSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
167 | params.data_aug, params.use_context, params.nbr_search_depth)
168 | val_set = NgsimDataset(DATA_PATH + 'ngsim_data/ValSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
169 | params.data_aug, params.use_context, params.nbr_search_depth)
170 | test_set = NgsimDataset(DATA_PATH + 'ngsim_data/TestSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
171 | params.data_aug, params.use_context, params.nbr_search_depth)
172 | train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, collate_fn=train_set.collate_fn, drop_last=True)
173 | val_data_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, collate_fn=val_set.collate_fn, drop_last=True)
174 | test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, collate_fn=test_set.collate_fn, drop_last=False)
175 |
176 | # Compute or load existing mean over future trajectories.
177 | if os.path.exists(DATA_PATH+'ngsim_data/y_mean.pkl'):
178 | y_mean = pickle.load( open(DATA_PATH+'ngsim_data/y_mean.pkl', 'rb'))
179 | else:
180 | y_mean = get_mean(train_data_loader)
181 | pickle.dump( y_mean, open(DATA_PATH+'ngsim_data/y_mean.pkl', 'wb') )
182 |
183 | # Initialize network
184 | net = mfpNet( params )
185 | if params.use_cuda:
186 | net = net.cuda() #type: ignore
187 |
188 | net.y_mean = y_mean
189 | y_mean = torch.tensor(net.y_mean)
190 |
191 | if params.log:
192 | logger_file, logging_dir = setup_logger(DATA_PATH + "checkpts/", 'NGSIM')
193 |
194 | train_loss: List = []
195 | val_loss: List = []
196 |
197 | MODE='Pre' # For efficiency, we first pre-train w/o interactive rollouts.
198 | num_updates = 0
199 | optimizer = None
200 |
201 | for epoch_num in range(20):
202 | if MODE == 'EndPre': MODE = 'Train'
203 | if MODE == 'Train': # once pretraining ends, every remaining epoch uses interactive rollouts
204 | print('Training with interactive rollouts.')
205 | bStepByStep = True
206 | else:
207 | print('Pre-training without interactive rollouts.')
208 | bStepByStep = False
209 |
210 | # Average losses.
211 | avg_tr_loss = 0.
212 | avg_tr_time = 0.
213 | loss_counter = 0.0
214 |
215 | for i, data in enumerate(train_data_loader):
216 | if num_updates > params.pre_train_num_updates and MODE == 'Pre':
217 | MODE = 'EndPre'
218 | break
219 |
220 | lr_fac = np.power(0.1, num_updates // params.updates_div_by_10 )
221 | lr = max( params.min_lr, params.lr_init*lr_fac)
222 | if optimizer is None:
223 | optimizer = torch.optim.Adam(net.parameters(), lr=lr) #type: ignore
224 | elif lr != optimizer.defaults['lr']:
225 | optimizer = torch.optim.Adam(net.parameters(), lr=lr) # note: re-creating Adam resets its running moment estimates
226 |
227 | st_time = time.time()
228 | hist, nbrs, mask, fut, fut_mask, context, nbrs_info = data # mask: social grid mask; fut_mask: valid future steps
229 |
230 | if params.remove_y_mean:
231 | fut = fut-y_mean.unsqueeze(1)
232 |
233 | if params.use_cuda:
234 | hist = hist.cuda()
235 | nbrs = nbrs.cuda()
236 | mask = mask.cuda()
237 | fut = fut.cuda()
238 | fut_mask = fut_mask.cuda()
239 | if context is not None:
240 | context = context.cuda()
241 |
242 | # Forward pass.
243 | fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep)
244 | if params.modes == 1:
245 | l = nll_loss(fut_preds[0], fut, fut_mask)
246 | else:
247 | l = nll_loss_multimodes(fut_preds, fut, fut_mask, modes_pred)
248 |
249 | # Backprop.
250 | optimizer.zero_grad()
251 | l.backward()
252 | torch.nn.utils.clip_grad_norm_(net.parameters(), 10) #type: ignore
253 | optimizer.step()
254 | num_updates += 1
255 |
256 | batch_time = time.time()-st_time
257 | avg_tr_loss += l.item()
258 | avg_tr_time += batch_time
259 |
260 | effective_batch_sz = float(hist.shape[1])
261 | if num_updates % params.iter_per_err == params.iter_per_err-1:
262 | avg_loss = avg_tr_loss/params.iter_per_err # average over the reporting window rather than a hardcoded 100
263 | msg_str = "Epoch no: %d update: %d | Avg train loss: %.4f learning_rate: %.5f" % (epoch_num, num_updates, avg_loss, lr)
264 | print(msg_str)
265 | train_loss.append(avg_loss)
266 |
267 | if params.log:
268 | logger_file.write(msg_str + '\n')
269 | logger_file.flush()
270 |
271 | avg_tr_loss = 0.
274 | if num_updates % params.iter_per_eval == params.iter_per_eval-1:
275 | print("Starting eval")
276 | val_nll_err = eval( 'nll', net, params, val_data_loader, bStepByStep,
277 | use_forcing=params.use_forcing, y_mean=y_mean,
278 | num_batches=500, dataset_name='val_dl nll')
279 |
280 | if params.log:
281 | logger_file.write('val nll: ' + str(val_nll_err)+'\n')
282 | logger_file.flush()
283 |
284 | # Save weights.
285 | if params.log and num_updates % params.iters_per_save == params.iters_per_save-1:
286 | msg_str = '\nSaving state, update iter:%d %s'%(num_updates, logging_dir)
287 | print(msg_str)
288 | logger_file.write( msg_str ); logger_file.flush()
289 | torch.save(net.state_dict(), logging_dir + '/checkpoints/ngsim_%06d'%num_updates + '.pth') #type: ignore
290 |
291 |
292 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2020 Apple Inc. All rights reserved.
3 | #
4 |
5 | [mypy]
6 |
7 | ignore_missing_imports = True
8 | # You may need to copy, paste and uncomment the following snippet for libraries that do not
9 | # support mypy yet.
10 |
11 | #[mypy-