├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── multiple_futures_prediction ├── __init__.py ├── assets │ ├── __init__.py │ └── imgs │ │ ├── .gitignore │ │ ├── mfp_comp_graph.png │ │ └── neurips_mfp_poster.pdf ├── checkpts │ └── .gitignore ├── cmd │ ├── __init__.py │ └── train_ngsim_cmd.py ├── configs │ └── mfp2_ngsim.gin ├── dataset_ngsim.py ├── model_ngsim.py ├── my_utils.py ├── ngsim_data │ └── .gitignore ├── py.typed └── train_ngsim.py ├── mypy.ini ├── pyproject.toml └── requirements.txt /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software have utilized the following copyrighted material, the use of which is hereby acknowledged. 3 | 4 | _____________________ 5 | PyTorch (https://pytorch.org) 6 | We use PyTorch as the training framework for our model. 7 | 8 | From PyTorch: 9 | 10 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 11 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 12 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 13 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 14 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 15 | Copyright (c) 2011-2013 NYU (Clement Farabet) 16 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 17 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 18 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 19 | 20 | _____________________ 21 | Nachiket Deo and Mohan M. Trivedi (https://github.com/nachiket92/conv-social-pooling) 22 | Our initial prototype and certain loss functions are based on the code base provided above, which is distributed under the MIT license. 23 | 24 | MIT License 25 | 26 | Copyright (c) 2018 Nachiket Deo 27 | 28 | Permission is hereby granted, free of charge, to any person obtaining a copy 29 | of this software and associated documentation files (the "Software"), to deal 30 | in the Software without restriction, including without limitation the rights 31 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 32 | copies of the Software, and to permit persons to whom the Software is 33 | furnished to do so, subject to the following conditions: 34 | 35 | The above copyright notice and this permission notice shall be included in all 36 | copies or substantial portions of the Software. 37 | 38 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 39 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 40 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 41 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 42 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 43 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 44 | SOFTWARE. 
45 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 
63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | ## Before you get started 6 | 7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 
31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY: 43 | 44 | This software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multiple Futures Prediction 2 | ## Paper 3 | This software accompanies the paper [**Multiple Futures Prediction**](https://arxiv.org/abs/1911.00997). ([Poster](multiple_futures_prediction/assets/imgs/neurips_mfp_poster.pdf))
4 | [Yichuan Charlie Tang](https://www.cs.toronto.edu/~tang) and Ruslan Salakhutdinov
5 | Neural Information Processing Systems, 2019. (NeurIPS 2019) 6 | 7 | 8 | Please cite our paper if you find our work useful for your research: 9 | ``` 10 | @inproceedings{tang2019mfp, 11 | title={Multiple Futures Prediction}, 12 | author={Tang, Yichuan Charlie and Salakhutdinov, Ruslan}, 13 | booktitle={Advances in neural information processing systems}, 14 | year={2019} 15 | } 16 | ``` 17 | 18 | ## Introduction 19 | Multiple Futures Prediction (MFP) is a framework for learning to forecast or predict future trajectories of agents, such as vehicles or pedestrians. A key feature of our framework is that it learns multiple modes, i.e. multiple possible futures, directly from trajectory data without annotations. Multi-agent interactions are also taken into account, and the framework scales to an arbitrary number of agents in the scene by using a novel dynamic attention mechanism. It currently achieves state-of-the-art results on three vehicle forecasting datasets. 20 | 21 | This research code is for demonstration purposes only. Please see the paper for more details. 22 | 23 | ### Overall Architecture 24 |
<img src="multiple_futures_prediction/assets/imgs/mfp_comp_graph.png" />
25 | 26 | 27 | The Multiple Futures Prediction (MFP) architecture is shown above. For an arbitrary number of agents in the scene, we first use RNNs to encode their past trajectories into feature vectors. Dynamic attentional contextual encoding aggregates interactions and relational information. For each agent, a distribution over its latent modes is then predicted. You can think of the latent modes as representing conservative/aggressive behaviors or directional (left vs. right turn) intentions. Given a distribution over the latent modes, decoding RNNs are then employed to decode or forecast future temporal trajectories. The MFP is a latent-variable graphical model and we use the EM algorithm to optimize the evidence lower-bound. 28 | 29 | 30 | ## Getting Started 31 | 32 | ### Prerequisites 33 | This code is tested with Python 3.6 and PyTorch 1.1.0. Conda or Virtualenv is recommended.
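If you prefer Conda, an equivalent environment setup is sketched below (the environment name `mfp` is an arbitrary choice; the pip steps that follow work inside either kind of environment):
```bash
conda create -n mfp python=3.6   # Create a new conda environment with Python 3.6
conda activate mfp               # Activate it
```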
34 | Use pip (a recent version, e.g. 20.1) to install dependencies, for example: 35 | ``` 36 | python3.6 -m venv .venv # Create new venv 37 | source .venv/bin/activate # Activate it 38 | pip install -U pip # Update to latest version of pip 39 | pip install -r requirements.txt # Install everything 40 | ``` 41 | 42 | ### Datasets 43 | 44 | #### How to Obtain NGSIM Data: 45 | 46 | 1. Obtain the NGSIM dataset here (US-101 and I-80):
47 | (https://data.transportation.gov/Automobiles/Next-Generation-Simulation-NGSIM-Vehicle-Trajector/8ect-6jqj) 48 | ``` 49 | Specifically you will need these files: 50 | US-101: 51 | '0750am-0805am/trajectories-0750am-0805am.txt' 52 | '0805am-0820am/trajectories-0805am-0820am.txt' 53 | '0820am-0835am/trajectories-0820am-0835am.txt' 54 | 55 | I-80: 56 | '0400pm-0415pm/trajectories-0400-0415.txt' 57 | '0500pm-0515pm/trajectories-0500-0515.txt' 58 | '0515pm-0530pm/trajectories-0515-0530.txt' 59 | ``` 60 | 2. Preprocess dataset with code from Nachiket Deo and Mohan M. Trivedi:
*Convolutional Social Pooling for Vehicle Trajectory Prediction* (CVPRW, 2018)
61 | (https://github.com/nachiket92/conv-social-pooling)
62 | 63 | 3. From the conv-social-pooling repo, run preprocess_data.m; this should produce three files:
64 | TrainSet.mat, ValSet.mat, and TestSet.mat. Copy them to the multiple_futures_prediction/ngsim_data folder. 65 | 66 | ### Usage 67 | 68 | #### Training 69 | ```bash 70 | train_ngsim --config multiple_futures_prediction/configs/mfp2_ngsim.gin 71 | ``` 72 | or 73 | ```bash 74 | python -m multiple_futures_prediction.cmd.train_ngsim_cmd --config multiple_futures_prediction/configs/mfp2_ngsim.gin 75 | ``` 76 | Hyperparameters (e.g. how many modes the MFP uses) are specified in the .gin config files. 77 | 78 | Expected training outputs: 79 | ``` 80 | Epoch no: 0 update: 99 | Avg train loss: 57.3198 learning_rate:0.00100 81 | Epoch no: 0 update: 199 | Avg train loss: 4.7679 learning_rate:0.00100 82 | Epoch no: 0 update: 299 | Avg train loss: 4.3250 learning_rate:0.00100 83 | Epoch no: 0 update: 399 | Avg train loss: 4.0717 learning_rate:0.00100 84 | Epoch no: 0 update: 499 | Avg train loss: 3.9722 learning_rate:0.00100 85 | Epoch no: 0 update: 599 | Avg train loss: 3.8525 learning_rate:0.00100 86 | Epoch no: 0 update: 699 | Avg train loss: 3.5253 learning_rate:0.00100 87 | Epoch no: 0 update: 799 | Avg train loss: 3.6077 learning_rate:0.00100 88 | Epoch no: 0 update: 899 | Avg train loss: 3.4526 learning_rate:0.00100 89 | Epoch no: 0 update: 999 | Avg train loss: 3.5830 learning_rate:0.00100 90 | Starting eval 91 | eval val_dl nll 92 | tensor([-1.5164, -0.3173, 0.3902, 0.9374, 1.3751, 1.7362, 2.0362, 2.3008, 93 | 2.5510, 2.7974, 3.0370, 3.2702, 3.4920, 3.7007, 3.8979, 4.0836, 94 | 4.2569, 4.4173, 4.5682, 4.7082, 4.8378, 4.9581, 5.0716, 5.1855, 95 | 5.3239]) 96 | ``` 97 | Depending on the CPU/GPU available, it can take from one to two days to complete 98 | 300K training updates on the NGSIM dataset and match the results in Table 5 of the paper. 99 | 100 | ## License 101 | This code is released under the [LICENSE](LICENSE) terms. 102 | -------------------------------------------------------------------------------- /multiple_futures_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2020 Apple Inc. All rights reserved. 3 | # 4 | 5 | """Demo code for Multiple Futures Prediction Paper (NeurIPS 2019).""" 6 | 7 | __version__ = "0.1.0" 8 | -------------------------------------------------------------------------------- /multiple_futures_prediction/assets/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2020 Apple Inc. All rights reserved. 
3 | # 4 | 5 | """This subpackage contains small resources necessary to run this package.""" 6 | -------------------------------------------------------------------------------- /multiple_futures_prediction/assets/imgs/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/assets/imgs/.gitignore -------------------------------------------------------------------------------- /multiple_futures_prediction/assets/imgs/mfp_comp_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/assets/imgs/mfp_comp_graph.png -------------------------------------------------------------------------------- /multiple_futures_prediction/assets/imgs/neurips_mfp_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/assets/imgs/neurips_mfp_poster.pdf -------------------------------------------------------------------------------- /multiple_futures_prediction/checkpts/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/checkpts/.gitignore -------------------------------------------------------------------------------- /multiple_futures_prediction/cmd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/cmd/__init__.py -------------------------------------------------------------------------------- /multiple_futures_prediction/cmd/train_ngsim_cmd.py: -------------------------------------------------------------------------------- 1 | from typing import List, Set, Dict, Tuple, Optional, Union, Any 2 | from multiple_futures_prediction.train_ngsim import train, Params 3 | import gin 4 | import argparse 5 | 6 | def parse_args() -> Any: 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | parser.add_argument('--config', type=str, default='') 9 | return parser.parse_args() 10 | 11 | def main() -> None: 12 | args = parse_args() 13 | gin.parse_config_file( args.config ) 14 | params = Params()() 15 | train(params) 16 | 17 | # python -m multiple_futures_prediction.cmd.train_ngsim_cmd --config multiple_futures_prediction/configs/mfp2_ngsim.gin 18 | if __name__ == '__main__': 19 | main() 20 | 21 | 22 | -------------------------------------------------------------------------------- /multiple_futures_prediction/configs/mfp2_ngsim.gin: -------------------------------------------------------------------------------- 1 | Params.log = 0 # log results and save model checkpoint 2 | Params.modes = 2 # MFP modes 3 | 4 | Params.subsampling = 2 # subsampling factor 5 | Params.hist_len_orig_hz = 30 # history steps at original sampling rate (e.g. 10 hz) 6 | Params.fut_len_orig_hz = 50 # future steps at original sampling rate (e.g. 
10 hz) 7 | 8 | Params.encoder_size = 64 # sizes of various neural layers 9 | Params.decoder_size = 128 10 | Params.dyn_embedding_size = 32 11 | Params.input_embedding_size = 32 12 | Params.dec_nbr_enc_size = 8 13 | Params.nbr_atten_embedding_size = 80 14 | 15 | Params.seed = 1234 16 | Params.use_cuda = True 17 | 18 | Params.remove_y_mean = True # remove the mean of the future trajectory 19 | Params.use_gru = True # GRU or LSTM 20 | Params.bi_direc = True # bidirectional RNN 21 | Params.self_norm = True # Normalize prediction targets 22 | Params.data_aug = False # Add noise to data? 23 | Params.use_context = False # use contextual bird's eye view map? 24 | Params.nll = True # Use negative log-likelihood loss 25 | Params.use_forcing = 0 # 0: no forcing. 1: teacher-forcing. 2: classmates-forcing. 26 | 27 | Params.iter_per_err = 100 28 | Params.iter_per_eval = 1000 29 | Params.iters_per_save = 5000 30 | Params.pre_train_num_updates = 200000 31 | Params.updates_div_by_10 = 100000 32 | Params.nbr_search_depth = 10 # depth of searching for finding 'neighbors' (only for NGSIM dataset) 33 | 34 | Params.lr_init = 0.001 # initial learning rate 35 | Params.min_lr = 0.00005 # minimum learning rate 36 | -------------------------------------------------------------------------------- /multiple_futures_prediction/dataset_ngsim.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import List, Set, Dict, Tuple, Optional, Union, Any 7 | from torch.utils.data import Dataset, DataLoader 8 | import scipy.io as scp 9 | import numpy as np 10 | import torch 11 | import pickle 12 | import os 13 | import cv2 14 | 15 | # Dataset for PyTorch training 16 | class NgsimDataset(Dataset): 17 | def __init__(self, mat_file:str, t_h:int=30, t_f:int=50, d_s:int = 2, 18 | enc_size:int=64, use_gru:bool=False, self_norm:bool=False, 19 | data_aug:bool=False, use_context:bool=False, nbr_search_depth:int= 3, 20 | ds_seed:int=1234) -> None: 21 | self.D = scp.loadmat(mat_file)['traj'] 22 | self.T = scp.loadmat(mat_file)['tracks'] 23 | self.t_h = t_h # length of track history 24 | self.t_f = t_f # length of predicted trajectory 25 | self.d_s = d_s # down sampling rate of all sequences 26 | self.enc_size = enc_size # size of encoder LSTM 27 | self.grid_size = (13,3) # size of context grid 28 | self.enc_fac = 2 if use_gru else 1 29 | self.self_norm = self_norm 30 | self.data_aug = data_aug 31 | self.noise = np.array([[0.5, 2.0]]) 32 | self.dt = 0.1*self.d_s 33 | self.ft_to_m = 0.3048 34 | self.use_context = use_context 35 | if self.use_context: 36 | self.maps = pickle.load(open('data/maps.pkl', 'rb')) 37 | self.nbr_search_depth = nbr_search_depth 38 | 39 | cache_file = 'multiple_futures_prediction/ngsim_data/NgsimIndex_%s.p'%os.path.basename(mat_file) 40 | #build index of [dataset (0 based), veh_id_0b, frame(time)] into a dictionary 41 | if not os.path.exists(cache_file): 42 | self.Index = {} 43 | print('building index...') 44 | for i, row in enumerate(self.D): 45 | key = (int(row[0]-1), int(row[1]-1), int(row[2])) 46 | self.Index[key] = i 47 | print('build index done') 48 | pickle.dump( self.Index, open(cache_file,'wb')) 49 | else: 50 | self.Index = pickle.load( open(cache_file,'rb')) 51 | 52 | self.ind_random = np.arange(len(self.D)) 53 | self.seed = ds_seed 54 | np.random.seed(self.seed) 55 | np.random.shuffle(self.ind_random) 56 | 57 | def __len__(self) -> int: 
58 | return len(self.D) 59 | 60 | def convert_pt(self, pt: np.ndarray, dsId0b: int) -> np.ndarray: 61 | """Convert a point from abs coords to pixel coords. 62 | Args: 63 | pt is a 2d x,y coordinate 64 | dsId0b - data set id 0 index based 65 | """ 66 | ft_per_pixel = self.maps[dsId0b]['ft_per_pixel'] 67 | return np.array([int((np.round(pt[0])-self.maps[dsId0b]['x0'])/ft_per_pixel), 68 | int((np.round(pt[1]) - self.maps[dsId0b]['y0'])/ft_per_pixel)]) 69 | 70 | def compute_vel_theta(self, hist: torch.Tensor, frac: Optional[float]=0.5) -> Tuple[np.ndarray, np.ndarray]: 71 | """Estimate velocity and orientation from history trajectory.""" 72 | if hist.shape[0] <= 1: 73 | return np.array([0.0]), np.array([0.0]) 74 | else: 75 | total_wts = 0.0 76 | counter = 0.0 77 | vel = theta = 0.0 78 | for t in range(hist.shape[0]-1,0,-1): 79 | counter += 1.0 80 | wt = np.power(frac, counter) 81 | total_wts += wt 82 | diff = hist[t,:] - hist[t-1,:] 83 | vel += wt*np.linalg.norm(diff)*self.ft_to_m/self.dt 84 | theta += wt*np.arctan2(diff[1],diff[0]) 85 | return np.array([vel/total_wts]), np.array([theta/total_wts]) 86 | 87 | def find_neighbors(self, dsID_0b: int, vehId_0b: int, frame: int) -> Dict: 88 | """Find a list of neighbors w.r.t. self.""" 89 | key = ( int(dsID_0b), int(vehId_0b), int(frame)) 90 | if key not in self.Index: 91 | return {} 92 | 93 | idx = self.Index[key] 94 | grid1b = self.D[idx,8:] 95 | nonzero = np.nonzero(grid1b)[0] 96 | return {vehId_0b: list(zip( grid1b[nonzero].astype(np.int64)-1, nonzero)) } 97 | 98 | def __getitem__(self, idx_: int) -> Tuple[ List, List, Dict, Union[None, np.ndarray] ] : 99 | idx = self.ind_random[idx_] 100 | dsId_1b = self.D[idx, 0].astype(int) 101 | vehId_1b = self.D[idx, 1].astype(int) 102 | dsId_0b = dsId_1b-1 103 | vehId_0b = vehId_1b-1 104 | t = self.D[idx, 2] 105 | grid = self.D[idx,8:] #1-based 106 | 107 | ids = {} #0-based keys 108 | leafs = [vehId_0b] 109 | for _ in range( self.nbr_search_depth ): 110 | new_leafs = [] 111 | for id0b in leafs: 112 | nbr_dict = self.find_neighbors(dsId_0b, id0b, t) 113 | if len(nbr_dict) > 0: 114 | ids.update( nbr_dict ) 115 | if len(nbr_dict[id0b]) > 0: 116 | nbr_id0b = list( zip(*nbr_dict[id0b]))[0] 117 | new_leafs.extend ( nbr_id0b ) 118 | leafs = np.unique( new_leafs ) 119 | leafs = [x for x in leafs if x not in ids] 120 | 121 | sorted_keys = sorted(ids.keys()) # e.g. [1, 3, 4, 5, ... 
, 74] 122 | id_map = {key: value for (key, value) in zip(sorted_keys, np.arange(len(sorted_keys)).tolist()) } 123 | #obj id to index within a batch 124 | sz = len(ids) 125 | assert sz > 0 126 | 127 | hist = [] 128 | fut = [] 129 | neighbors: Dict[int,List] = {} # key is batch ind, followed by a list of (batch_ind, ego_id, grid/nbr ind) 130 | 131 | for ind, vehId0b in enumerate(sorted_keys): 132 | hist.append( self.getHistory0b(vehId0b,t,dsId_0b) ) # no normalization 133 | fut.append( self.getFuture(vehId0b+1,t,dsId_0b+1 ) ) #subtract off ref pos 134 | neighbors[ind] = [] 135 | for v_id, nbr_ind in ids[ vehId0b ]: 136 | if v_id not in id_map: 137 | k2 = -1 138 | else: 139 | k2 = id_map[v_id] 140 | neighbors[ind].append( (k2, v_id, nbr_ind) ) 141 | 142 | if self.use_context: 143 | x_range_ft = np.array([-15, 15]) 144 | y_range_ft = np.array([-30, 300]) 145 | 146 | pad = int(np.ceil(300/self.maps[dsId_1b-1]['ft_per_pixel'])) # max of all ranges 147 | 148 | if not 'im_color' in self.maps[dsId_1b-1]: 149 | im_big = np.pad(self.maps[dsId_1b-1]['im'], ((pad, pad), (pad, pad)), 'constant', constant_values= 0.0) 150 | self.maps[dsId_1b-1]['im_color'] = (im_big[np.newaxis,...].repeat(3, axis=0)*255.0).astype(np.uint8) 151 | 152 | im = self.maps[dsId_1b-1]['im_color'] 153 | height, width = im.shape[1:] 154 | 155 | ref_pos = self.D[idx,3:5] 156 | im_x, im_y = self.convert_pt( ref_pos, dsId_1b-1 ) 157 | im_x += pad 158 | im_y += pad 159 | 160 | x_range = (x_range_ft/self.maps[dsId_1b-1]['ft_per_pixel']).astype(int) 161 | y_range = (y_range_ft/self.maps[dsId_1b-1]['ft_per_pixel']).astype(int) 162 | 163 | x_range[0] = np.maximum( 0, x_range[0]+im_x ) 164 | x_range[1] = np.minimum( width-1, x_range[1]+im_x ) 165 | y_range[0] = np.maximum( 0, y_range[0]+im_y ) 166 | y_range[1] = np.minimum( height-1, y_range[1]+im_y ) 167 | 168 | im_crop = np.ascontiguousarray(im[:, y_range[0]:y_range[1], x_range[0]:x_range[1]].transpose((1,2,0))) 169 | im_crop[:,:,[0, 1]] = 0 170 | 171 | for _, other in neighbors.items(): 172 | if len(other) == 0: 173 | continue 174 | for k in range( len(other)-1 ): 175 | x1, y1 = self.convert_pt( other[k]+ref_pos, dsId_1b-1 ) 176 | x2, y2 = self.convert_pt( other[k+1]+ref_pos, dsId_1b-1 ) 177 | x1+=pad; y1+=pad; x2+=pad; y2+=pad 178 | cv2.line(im_crop,(x1-x_range[0],y1-y_range[0]),(x2-x_range[0],y2-y_range[0]), (255, 0, 0), 2 ) 179 | 180 | x, y = self.convert_pt( other[-1]+ref_pos, dsId_1b-1 ) 181 | x+=pad; y+=pad 182 | cv2.circle(im_crop, (x-x_range[0], y-y_range[0]), 4, (255, 0, 0), -1) 183 | 184 | for k in range( len(hist)-1 ): 185 | x1, y1 = self.convert_pt( hist[k]+ref_pos, dsId_1b-1 ) 186 | x2, y2 = self.convert_pt( hist[k+1]+ref_pos, dsId_1b-1 ) 187 | x1+=pad; y1+=pad; x2+=pad; y2+=pad 188 | cv2.line(im_crop,(x1-x_range[0],y1-y_range[0]),(x2-x_range[0],y2-y_range[0]), (0, 255, 0), 3 ) 189 | 190 | cv2.circle(im_crop, (im_x-x_range[0], im_y-y_range[0]), 5, (0, 255, 0), -1) 191 | assert im_crop.shape == (660,60,3) 192 | im_crop = im_crop.transpose((2,0,1)) 193 | else: 194 | im_crop = None 195 | return hist, fut, neighbors, im_crop # neighbors is a list of all vehicles in the batch 196 | 197 | def getHistory(self, vehId: int, t: int, refVehId: int, dsId: int) -> np.ndarray: 198 | """Get trajectory history. 
VehId and refVehId are 1-based.""" 199 | if vehId == 0: 200 | return np.empty([0,2]) 201 | else: 202 | if self.T.shape[1]<=vehId-1: 203 | return np.empty([0,2]) 204 | vehTrack = self.T[dsId-1][vehId-1].transpose() 205 | if vehTrack.size==0 or np.argwhere(vehTrack[:, 0] == t).size==0: 206 | return np.empty([0,2]) 207 | else: 208 | refTrack = self.T[dsId-1][refVehId-1].transpose() 209 | found = np.where(refTrack[:,0]==t) 210 | refPos = refTrack[found][0,1:3] 211 | 212 | stpt = np.maximum(0, np.argwhere(vehTrack[:, 0] == t).item() - self.t_h) 213 | enpt = np.argwhere(vehTrack[:, 0] == t).item() + 1 214 | hist = vehTrack[stpt:enpt:self.d_s,1:3]-refPos 215 | 216 | if self.data_aug: 217 | hist += np.random.randn( hist.shape[0],hist.shape[1] )*self.noise 218 | 219 | if len(hist) < self.t_h//self.d_s + 1: 220 | return np.empty([0,2]) 221 | return hist 222 | 223 | def getFuture(self, vehId:int, t:int, dsId: int) -> np.ndarray : 224 | """Get future trajectory. VehId and dsId are 1-based.""" 225 | vehTrack = self.T[dsId-1][vehId-1].transpose() 226 | refPos = vehTrack[np.where(vehTrack[:, 0] == t)][0, 1:3] 227 | stpt = np.argwhere(vehTrack[:, 0] == t).item() + self.d_s 228 | enpt = np.minimum(len(vehTrack), np.argwhere(vehTrack[:, 0] == t).item() + self.t_f + 1) 229 | fut = vehTrack[stpt:enpt:self.d_s,1:3]-refPos 230 | return fut 231 | 232 | def getHistory0b(self, vehId:int, t:int, dsId:int ) -> np.ndarray : 233 | """Get track history trajectory. VehId and dsId are zero-based. 234 | No normalizations are performed. 235 | """ 236 | if vehId < 0: 237 | return np.empty([0,2]) 238 | else: 239 | if vehId >= self.T.shape[1]: 240 | return np.empty([0,2]) 241 | vehTrack = self.T[dsId][vehId].transpose() 242 | if vehTrack.size==0 or np.argwhere(vehTrack[:, 0] == t).size==0: 243 | return np.empty([0,2]) 244 | else: 245 | stpt = np.maximum(0, np.argwhere(vehTrack[:, 0] == t).item() - self.t_h) 246 | enpt = np.argwhere(vehTrack[:, 0] == t).item() + 1 247 | hist = vehTrack[stpt:enpt:self.d_s,1:3] 248 | if len(hist) < self.t_h//self.d_s + 1: 249 | return np.empty([0,2]) 250 | return hist 251 | 252 | def getFuture0b(self, vehId:int, t:int, dsId:int, refPos: np.ndarray) -> np.ndarray : 253 | """Get track future trajectory. VehId and dsId are zero-based.""" 254 | vehTrack = self.T[dsId][vehId].transpose() 255 | stpt = np.argwhere(vehTrack[:, 0] == t).item() + self.d_s 256 | enpt = np.minimum(len(vehTrack), np.argwhere(vehTrack[:, 0] == t).item() + self.t_f + 1) 257 | fut = vehTrack[stpt:enpt:self.d_s,1:3]-refPos 258 | return fut 259 | 260 | def collate_fn(self, samples: List[Any]) -> Tuple[Any,Any,Any,Any,Any,Union[Any,None],Any] : 261 | """Prepare a batch suitable for MFP training.""" 262 | nbr_batch_size = 0 263 | num_samples = 0 264 | for _,_,nbrs,im_crop in samples: 265 | nbr_batch_size += sum([len(nbr) for nbr in nbrs.values() ]) 266 | num_samples += len(nbrs) 267 | 268 | maxlen = self.t_h//self.d_s + 1 269 | if nbr_batch_size <= 0: 270 | nbrs_batch = torch.zeros(maxlen,1,2) 271 | else: 272 | nbrs_batch = torch.zeros(maxlen,nbr_batch_size,2) 273 | 274 | pos = [0, 0] 275 | nbr_inds_batch = torch.zeros( num_samples, self.grid_size[1],self.grid_size[0], self.enc_size*self.enc_fac) 276 | nbr_inds_batch = nbr_inds_batch.byte() 277 | 278 | hist_batch = torch.zeros(maxlen, num_samples, 2) #e.g. 
(31, 41, 2) 279 | fut_batch = torch.zeros(self.t_f//self.d_s, num_samples, 2) 280 | mask_batch = torch.zeros(self.t_f//self.d_s, num_samples, 2) 281 | if self.use_context: 282 | context_batch = torch.zeros(num_samples, im_crop.shape[0], im_crop.shape[1], im_crop.shape[2] ) 283 | else: 284 | context_batch: Union[None, torch.Tensor] = None # type: ignore 285 | 286 | nbrs_infos = [] 287 | count = 0 288 | samples_so_far = 0 289 | for sampleId,(hist, fut, nbrs, context) in enumerate(samples): 290 | num = len(nbrs) 291 | for j in range(num): 292 | hist_batch[0:len(hist[j]), samples_so_far+j, :] = torch.from_numpy(hist[j]) 293 | fut_batch[0:len(fut[j]), samples_so_far+j, :] = torch.from_numpy(fut[j]) 294 | mask_batch[0:len(fut[j]),samples_so_far+j,:] = 1 295 | samples_so_far += num 296 | 297 | nbrs_infos.append(nbrs) 298 | 299 | if self.use_context: 300 | context_batch[sampleId,:,:,:] = torch.from_numpy(context) 301 | 302 | # nbrs is a dictionary of key to list of nbr (batch_index, veh_id, grid_ind) 303 | for batch_ind, list_of_nbr in nbrs.items(): 304 | for batch_id, vehid, grid_ind in list_of_nbr: 305 | if batch_id >= 0: 306 | nbr_hist = hist[batch_id] 307 | nbrs_batch[0:len(nbr_hist),count,:] = torch.from_numpy( nbr_hist ) 308 | pos[0] = grid_ind % self.grid_size[0] 309 | pos[1] = grid_ind // self.grid_size[0] 310 | nbr_inds_batch[batch_ind,pos[1],pos[0],:] = torch.ones(self.enc_size*self.enc_fac).byte() 311 | count+=1 312 | 313 | return (hist_batch, nbrs_batch, nbr_inds_batch, fut_batch, mask_batch, context_batch, nbrs_infos) 314 | -------------------------------------------------------------------------------- /multiple_futures_prediction/model_ngsim.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import List, Set, Dict, Tuple, Optional, Union, Any 7 | import torch 8 | import torch.nn as nn 9 | from multiple_futures_prediction.my_utils import * 10 | 11 | # Multiple Futures Prediction Network 12 | class mfpNet(nn.Module): 13 | def __init__(self, args: Dict) -> None: 14 | super(mfpNet, self).__init__() #type: ignore 15 | self.use_cuda = args['use_cuda'] 16 | self.encoder_size = args['encoder_size'] 17 | self.decoder_size = args['decoder_size'] 18 | self.out_length = args['fut_len_orig_hz']//args['subsampling'] 19 | 20 | self.dyn_embedding_size = args['dyn_embedding_size'] 21 | self.input_embedding_size = args['input_embedding_size'] 22 | 23 | self.nbr_atten_embedding_size = args['nbr_atten_embedding_size'] 24 | self.st_enc_hist_size = self.nbr_atten_embedding_size 25 | self.st_enc_pos_size = args['dec_nbr_enc_size'] 26 | self.use_gru = args['use_gru'] 27 | self.bi_direc = args['bi_direc'] 28 | self.use_context = args['use_context'] 29 | self.modes = args['modes'] 30 | self.use_forcing = args['use_forcing'] # 1: Teacher forcing. 2:classmates forcing. 31 | 32 | self.hidden_fac = 2 if args['use_gru'] else 1 33 | self.bi_direc_fac = 2 if args['bi_direc'] else 1 34 | self.dec_fac = 2 if args['bi_direc'] else 1 35 | 36 | self.init_rbf_state_enc( in_dim=self.encoder_size*self.hidden_fac ) 37 | self.posi_enc_dim = self.st_enc_pos_size 38 | self.posi_enc_ego_dim = 2 39 | 40 | # Input embedding layer 41 | self.ip_emb = torch.nn.Linear(2,self.input_embedding_size) #type: ignore 42 | 43 | # Encoding RNN. 
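# A note on shapes: with use_gru the encoder below is a 2-layer GRU whose final
# hidden state is [num_layers, batch, encoder_size]; forward_mfp flattens it to
# [batch, num_layers*encoder_size], which is why hidden_fac is 2 in the GRU case.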
44 | if not self.use_gru: 45 | self.enc_lstm = torch.nn.LSTM(self.input_embedding_size,self.encoder_size,1) # type: ignore 46 | else: 47 | self.num_layers=2 48 | self.enc_lstm = torch.nn.GRU(self.input_embedding_size,self.encoder_size, # type: ignore 49 | num_layers=self.num_layers, bidirectional=False) 50 | 51 | # Dynamics embeddings. 52 | self.dyn_emb = torch.nn.Linear(self.encoder_size*self.hidden_fac, self.dyn_embedding_size) #type: ignore 53 | 54 | context_feat_size = 64 if self.use_context else 0 55 | self.dec_lstm = [] 56 | self.op = [] 57 | for k in range(self.modes): 58 | if not self.use_gru: 59 | self.dec_lstm.append( torch.nn.LSTM(self.nbr_atten_embedding_size + self.dyn_embedding_size + #type: ignore 60 | context_feat_size+self.posi_enc_dim+self.posi_enc_ego_dim, self.decoder_size) ) 61 | else: 62 | self.num_layers=2 63 | self.dec_lstm.append( torch.nn.GRU(self.nbr_atten_embedding_size + self.dyn_embedding_size + context_feat_size+self.posi_enc_dim+self.posi_enc_ego_dim, # type: ignore 64 | self.decoder_size, num_layers=self.num_layers, bidirectional=self.bi_direc )) 65 | 66 | self.op.append( torch.nn.Linear(self.decoder_size*self.dec_fac, 5) ) #type: ignore 67 | 68 | self.op[k] = self.op[k] 69 | self.dec_lstm[k] = self.dec_lstm[k] 70 | 71 | self.dec_lstm = torch.nn.ModuleList(self.dec_lstm) # type: ignore 72 | self.op = torch.nn.ModuleList(self.op ) # type: ignore 73 | 74 | self.op_modes = torch.nn.Linear(self.nbr_atten_embedding_size + self.dyn_embedding_size + context_feat_size, self.modes) #type: ignore 75 | 76 | # Nonlinear activations. 77 | self.leaky_relu = torch.nn.LeakyReLU(0.1) #type: ignore 78 | self.relu = torch.nn.ReLU() #type: ignore 79 | self.softmax = torch.nn.Softmax(dim=1) #type: ignore 80 | 81 | if self.use_context: 82 | self.context_conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=2) #type: ignore 83 | self.context_conv2 = torch.nn.Conv2d(16, 16, kernel_size=3, stride=2) #type: ignore 84 | self.context_maxpool = torch.nn.MaxPool2d(kernel_size=(4,2)) #type: ignore 85 | self.context_conv3 = torch.nn.Conv2d(16, 16, kernel_size=3, stride=2) #type: ignore 86 | self.context_fc = torch.nn.Linear(16*20*3, context_feat_size) #type: ignore 87 | 88 | def init_rbf_state_enc(self, in_dim: int ) -> None: 89 | """Initialize the dynamic attentional RBF encoder. 90 | Args: 91 | in_dim is the input dim of the observation. 92 | """ 93 | self.sec_in_dim = in_dim 94 | self.extra_pos_dim = 2 95 | 96 | self.sec_in_pos_dim = 2 97 | self.sec_key_dim = 8 98 | self.sec_key_hidden_dim = 32 99 | 100 | self.sec_hidden_dim = 32 101 | self.scale = 1.0 102 | self.slot_key_scale = 1.0 103 | self.num_slots = 8 104 | self.slot_keys = [] 105 | 106 | # Network for computing the 'key' 107 | self.sec_key_net = torch.nn.Sequential( #type: ignore 108 | torch.nn.Linear(self.sec_in_dim+self.extra_pos_dim, self.sec_key_hidden_dim), 109 | torch.nn.ReLU(), 110 | torch.nn.Linear(self.sec_key_hidden_dim, self.sec_key_dim) 111 | ) 112 | 113 | for ss in range(self.num_slots): 114 | self.slot_keys.append( torch.nn.Parameter( self.slot_key_scale*torch.randn( self.sec_key_dim, 1, dtype=torch.float32) ) ) #type: ignore 115 | self.slot_keys = torch.nn.ParameterList( self.slot_keys ) # type: ignore 116 | 117 | # Network for encoding a scene-level contextual feature. 
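# Its input is the attention-pooled neighbor encodings: rbf_state_enc_hist_fwd
# produces one pooled feature per slot for each agent and concatenates them,
# giving a vector of size num_slots*in_dim.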
118 | self.sec_hist_net = torch.nn.Sequential( #type: ignore 119 | torch.nn.Linear(self.sec_in_dim*self.num_slots, self.sec_hidden_dim), 120 | torch.nn.ReLU(), 121 | torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim), 122 | torch.nn.ReLU(), 123 | torch.nn.Linear(self.sec_hidden_dim, self.st_enc_hist_size) 124 | ) 125 | 126 | # Encode positions of others into a feature vector; input should be normalized to ref_pos. 127 | self.sec_pos_net = torch.nn.Sequential( #type: ignore 128 | torch.nn.Linear(self.sec_in_pos_dim*self.num_slots, self.sec_hidden_dim), 129 | torch.nn.ReLU(), 130 | torch.nn.Linear(self.sec_hidden_dim, self.sec_hidden_dim), 131 | torch.nn.ReLU(), 132 | torch.nn.Linear(self.sec_hidden_dim, self.st_enc_pos_size) 133 | ) 134 | 135 | def rbf_state_enc_get_attens(self, nbrs_enc: torch.Tensor, ref_pos: torch.Tensor, nbrs_info_this: List ) -> List[torch.Tensor]: 136 | """Computes the attention over other agents. 137 | Args: 138 | nbrs_info_this is a list of list of (nbr_batch_ind, nbr_id, nbr_ctx_ind) 139 | Returns: 140 | attention weights over the neighbors. 141 | """ 142 | assert len(nbrs_info_this) == ref_pos.shape[0] 143 | if self.extra_pos_dim > 0: 144 | pos_enc = torch.zeros(nbrs_enc.shape[0],2, device=nbrs_enc.device) 145 | counter = 0 146 | for n in range(len(nbrs_info_this)): 147 | for nbr in nbrs_info_this[n]: 148 | pos_enc[counter,:] = ref_pos[nbr[0],:] - ref_pos[n,:] 149 | counter += 1 150 | Key = self.sec_key_net( torch.cat( (nbrs_enc,pos_enc),dim=1) ) 151 | # e.g. num_agents by self.sec_key_dim 152 | else: 153 | Key = self.sec_key_net( nbrs_enc ) # e.g. num_agents by self.sec_key_dim 154 | 155 | attens0 = [] 156 | for slot in self.slot_keys: 157 | attens0.append( torch.exp( -self.scale*(Key-torch.t(slot)).norm(dim=1)) ) 158 | 159 | Atten = torch.stack(attens0, dim=0) # e.g. num_keys x num_agents 160 | attens = [] 161 | counter = 0 162 | for n in range(len(nbrs_info_this)): 163 | list_of_nbrs = nbrs_info_this[n] 164 | counter2 = counter+len(list_of_nbrs) 165 | attens.append( Atten[:, counter:counter2 ] ) 166 | counter = counter2 167 | return attens 168 | 169 | def rbf_state_enc_hist_fwd(self, attens: List, nbrs_enc: torch.Tensor, nbrs_info_this: List) -> torch.Tensor: 170 | """Computes dynamic state encoding. 171 | Pools the RNN-based neighbor encodings with the precomputed attention 172 | weights, one pooled feature per slot. 173 | Args: 174 | attens is a list of [ [slots x num_neighbors]] 175 | nbrs_enc is num_agents by input_dim 176 | Returns: 177 | feature vector 178 | """ 179 | out = [] 180 | counter = 0 181 | for n in range(len(nbrs_info_this)): 182 | list_of_nbrs = nbrs_info_this[n] 183 | if len(list_of_nbrs) > 0: 184 | counter2 = counter+len(list_of_nbrs) 185 | nbr_feat = nbrs_enc[counter:counter2,:] 186 | out.append( torch.mm( attens[n], nbr_feat ) ) 187 | counter = counter2 188 | else: 189 | out.append( torch.zeros(self.num_slots, nbrs_enc.shape[1] ).to(nbrs_enc.device) ) 190 | # if no neighbors found, use all zeros. 191 | st_enc = torch.stack(out, dim=0).view(len(out),-1) # num_agents by slots*enc dim 192 | return self.sec_hist_net(st_enc) 193 | 194 | def rbf_state_enc_pos_fwd(self, attens: List, ref_pos: torch.Tensor, fut_t: torch.Tensor, flatten_inds: torch.Tensor, chunks: List) -> torch.Tensor: 195 | """Computes the features from dynamic attention for interactive rollouts. 
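At each rollout step, the just-predicted neighbor positions (fut_t, shifted
into the global frame by ref_pos) are re-expressed relative to each agent and
pooled with the precomputed attention weights.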
196 | Args: 197 | attens is a list of [ [slots x num_neighbors]] 198 | ref_pos should be (num_agents by 2) 199 | Returns: 200 | feature vector 201 | """ 202 | fut = fut_t + ref_pos #convert to 'global' frame 203 | nbr_feat = torch.index_select( fut, 0, flatten_inds) 204 | splits = torch.split(nbr_feat, chunks, dim=0) #type: ignore 205 | out = [] 206 | for n, nbr_feat in enumerate(splits): 207 | out.append( torch.mm( attens[n], nbr_feat - ref_pos[n,:] ) ) 208 | pos_enc = torch.stack(out, dim=0).view(len(attens),-1) # num_agents by slots*pos dim 209 | return self.sec_pos_net(pos_enc) 210 | 211 | 212 | def forward_mfp(self, hist:torch.Tensor, nbrs:torch.Tensor, masks:torch.Tensor, context:Any, 213 | nbrs_info:List, fut:torch.Tensor, bStepByStep:bool, 214 | use_forcing:Optional[Union[None,int]]=None) -> Tuple[List[torch.Tensor], Any]: 215 | """Forward propagation function for the MFP. 216 | 217 | Encodes past trajectories, aggregates neighbor information via dynamic 218 | attention, and decodes a multimodal distribution over future trajectories. 219 | Args: 220 | hist: Trajectory history. 221 | nbrs: Neighbors. 222 | masks: Neighbors mask. 223 | context: contextual information in image form (if used). 224 | nbrs_info: information as to which other agents are neighbors. 225 | fut: Future Trajectory. 226 | bStepByStep: During rollout, interactive or independent. 227 | use_forcing: 0: no forcing. 1: teacher forcing. 2: classmate forcing. 228 | 229 | Returns: 230 | fut_pred: a list of predictions, one for each mode. 231 | modes_pred: prediction over latent modes. 232 | """ 233 | use_forcing = self.use_forcing if use_forcing is None else use_forcing 234 | 235 | # Normalize to reference position. 236 | ref_pos = hist[-1,:,:] 237 | hist = hist - ref_pos.view(1,-1,2) 238 | 239 | # Encode history trajectories. 240 | if isinstance(self.enc_lstm, torch.nn.modules.rnn.GRU): 241 | _, hist_enc = self.enc_lstm(self.leaky_relu(self.ip_emb(hist))) 242 | else: 243 | _,(hist_enc,_) = self.enc_lstm(self.leaky_relu(self.ip_emb(hist))) #hist torch.Size([16, 128, 2]) 244 | 245 | if self.use_gru: 246 | hist_enc = hist_enc.permute(1,0,2).contiguous() 247 | hist_enc = self.leaky_relu(self.dyn_emb( hist_enc.view(hist_enc.shape[0], -1) )) 248 | else: 249 | hist_enc = self.leaky_relu(self.dyn_emb(hist_enc.view(hist_enc.shape[1],hist_enc.shape[2]))) #torch.Size([128, 32]) 250 | 251 | num_nbrs = sum([len(nbs) for nb_id, nbs in nbrs_info[0].items() ]) 252 | if num_nbrs > 0: 253 | nbrs_ref_pos = nbrs[-1,:,:] 254 | nbrs = nbrs - nbrs_ref_pos.view(1,-1,2) # normalize 255 | 256 | # Forward pass for all neighbors. 
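# Neighbor histories are normalized to their own last positions and encoded
# with the same encoder RNN used for the ego histories; the RBF attention
# then pools these encodings per agent.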
257 | if isinstance(self.enc_lstm, torch.nn.modules.rnn.GRU): 258 | _, nbrs_enc = self.enc_lstm(self.leaky_relu(self.ip_emb(nbrs))) 259 | else: 260 | _, (nbrs_enc,_) = self.enc_lstm(self.leaky_relu(self.ip_emb(nbrs))) 261 | 262 | if self.use_gru: 263 | nbrs_enc = nbrs_enc.permute(1,0,2).contiguous() 264 | nbrs_enc = nbrs_enc.view(nbrs_enc.shape[0], -1) 265 | else: 266 | nbrs_enc = nbrs_enc.view(nbrs_enc.shape[1], nbrs_enc.shape[2]) 267 | 268 | attens = self.rbf_state_enc_get_attens(nbrs_enc, ref_pos, nbrs_info[0]) 269 | nbr_atten_enc = self.rbf_state_enc_hist_fwd(attens, nbrs_enc, nbrs_info[0]) 270 | 271 | else: # if there are no neighbors 272 | attens = None # type: ignore 273 | nbr_atten_enc = torch.zeros( 1, self.nbr_atten_embedding_size, dtype=torch.float32, device=masks.device ) 274 | 275 | if self.use_context: #context encoding 276 | context_enc = self.relu(self.context_conv( context )) 277 | context_enc = self.context_maxpool( self.context_conv2( context_enc )) 278 | context_enc = self.relu(self.context_conv3(context_enc)) 279 | context_enc = self.context_fc( context_enc.view( context_enc.shape[0], -1) ) 280 | 281 | enc = torch.cat((nbr_atten_enc, hist_enc, context_enc),1) 282 | else: 283 | enc = torch.cat((nbr_atten_enc, hist_enc),1) 284 | # e.g. nbr_atten_enc: [num_agents by 80], hist_enc: [num_agents by 32], enc would be [num_agents by 112] 285 | 286 | ###################################################################################################### 287 | modes_pred = None if self.modes==1 else self.softmax(self.op_modes(enc)) 288 | fut_pred = self.decode(enc, attens, nbrs_info[0], ref_pos, fut, bStepByStep, use_forcing) 289 | return fut_pred, modes_pred 290 | 291 | def decode(self, enc: torch.Tensor, attens:List, nbrs_info_this:List, ref_pos:torch.Tensor, fut:torch.Tensor, bStepByStep:bool, use_forcing:Any ) -> List[torch.Tensor]: 292 | """Decode the future trajectory using RNNs. 293 | 294 | Given the computed feature vectors, decode the future with multimodes, using 295 | dynamic attention and either interactive or non-interactive rollouts. 296 | Args: 297 | enc: encoded features, one per agent. 298 | attens: attentional weights, list of objs, each with dimension of [8 x 4] (e.g.) 299 | nbrs_info_this: information on which agents are neighbors 300 | ref_pos: the current position (reference position) of the agents. 301 | fut: future trajectory (only useful for teacher or classmate forcing) 302 | bStepByStep: interactive or non-interactive rollout 303 | use_forcing: 0: None. 1: Teacher-forcing. 2: classmate forcing. 304 | 305 | Returns: 306 | fut_pred: a list of predicted trajectories, one tensor 307 | per latent mode. 
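Rollout conditioning: with use_forcing 0 the model feeds back its own
predictions; with 1 (teacher forcing) every agent is fed the ground-truth
future; with 2 (classmate forcing) other agents' positions come from the
ground truth while each agent feeds back its own prediction.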
308 | """ 309 | if not bStepByStep: # Non-interactive rollouts 310 | enc = enc.repeat(self.out_length, 1, 1) 311 | pos_enc = torch.zeros( self.out_length, enc.shape[1], self.posi_enc_dim+self.posi_enc_ego_dim, device=enc.device ) 312 | enc2 = torch.cat( (enc, pos_enc), dim=2) 313 | fut_preds = [] 314 | for k in range(self.modes): 315 | h_dec, _ = self.dec_lstm[k](enc2) 316 | h_dec = h_dec.permute(1, 0, 2) 317 | fut_pred = self.op[k](h_dec) 318 | fut_pred = fut_pred.permute(1, 0, 2) #torch.Size([nSteps, num_agents, 5]) 319 | 320 | fut_pred = Gaussian2d(fut_pred) 321 | fut_preds.append(fut_pred) 322 | return fut_preds 323 | else: 324 | batch_sz = enc.shape[0] 325 | inds = [] 326 | chunks = [] 327 | for n in range(len(nbrs_info_this)): 328 | chunks.append( len(nbrs_info_this[n]) ) 329 | for nbr in nbrs_info_this[n]: 330 | inds.append(nbr[0]) 331 | flat_index = torch.LongTensor(inds).to(ref_pos.device) # type: ignore 332 | 333 | fut_preds = [] 334 | for k in range(self.modes): 335 | direc = 2 if self.bi_direc else 1 336 | hidden = torch.zeros(self.num_layers*direc, batch_sz, self.decoder_size).to(fut.device) 337 | preds: List[torch.Tensor] = [] 338 | for t in range(self.out_length): 339 | if t == 0: # Intial timestep. 340 | if use_forcing == 0: 341 | pred_fut_t = torch.zeros_like(fut[t,:,:]) 342 | ego_fut_t = pred_fut_t 343 | elif use_forcing == 1: 344 | pred_fut_t = fut[t,:,:] 345 | ego_fut_t = pred_fut_t 346 | else: 347 | pred_fut_t = fut[t,:,:] 348 | ego_fut_t = torch.zeros_like(pred_fut_t) 349 | else: 350 | if use_forcing == 0: 351 | pred_fut_t = preds[-1][:,:,:2].squeeze() 352 | ego_fut_t = pred_fut_t 353 | elif use_forcing == 1: 354 | pred_fut_t = fut[t,:,:] 355 | ego_fut_t = pred_fut_t 356 | else: 357 | pred_fut_t = fut[t,:,:] 358 | ego_fut_t = preds[-1][:,:,:2] 359 | 360 | if attens == None: 361 | pos_enc = torch.zeros(batch_sz, self.posi_enc_dim, device=enc.device ) 362 | else: 363 | pos_enc = self.rbf_state_enc_pos_fwd(attens, ref_pos, pred_fut_t, flat_index, chunks ) 364 | 365 | enc_large = torch.cat( ( enc.view(1,enc.shape[0],enc.shape[1]), 366 | pos_enc.view(1,batch_sz, self.posi_enc_dim), 367 | ego_fut_t.view(1, batch_sz, self.posi_enc_ego_dim ) ), dim=2 ) 368 | 369 | out, hidden = self.dec_lstm[k]( enc_large, hidden) 370 | pred = Gaussian2d(self.op[k](out)) 371 | preds.append( pred ) 372 | fut_pred_k = torch.cat(preds,dim=0) 373 | fut_preds.append(fut_pred_k) 374 | return fut_preds 375 | -------------------------------------------------------------------------------- /multiple_futures_prediction/my_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from typing import List, Set, Dict, Tuple, Optional, Union 7 | from scipy import special 8 | import numpy as np 9 | import torch 10 | import json 11 | import os 12 | 13 | 14 | """Helper functions for loss or normalizations. 
15 | """ 16 | 17 | def Gaussian2d(x: torch.Tensor) -> torch.Tensor : 18 | """Computes the parameters of a bivariate 2D Gaussian.""" 19 | x_mean = x[:,:,0] 20 | y_mean = x[:,:,1] 21 | sigma_x = torch.exp(x[:,:,2]) 22 | sigma_y = torch.exp(x[:,:,3]) 23 | rho = torch.tanh(x[:,:,4]) 24 | return torch.stack([x_mean, y_mean, sigma_x, sigma_y, rho], dim=2) 25 | 26 | def nll_loss( pred: torch.Tensor, data: torch.Tensor, mask:torch.Tensor ) -> torch.Tensor : 27 | """NLL averages across steps, samples, and dimensions(x,y).""" 28 | x_mean = pred[:,:,0] 29 | y_mean = pred[:,:,1] 30 | x_sigma = pred[:,:,2] 31 | y_sigma = pred[:,:,3] 32 | rho = pred[:,:,4] 33 | ohr = torch.pow(1-torch.pow(rho,2),-0.5) # type: ignore 34 | x = data[:,:, 0]; y = data[:,:, 1] 35 | results = torch.pow(ohr, 2)*(torch.pow(x_sigma, 2)*torch.pow(x-x_mean, 2) + torch.pow(y_sigma, 2)*torch.pow(y-y_mean, 2) 36 | -2*rho*torch.pow(x_sigma, 1)*torch.pow(y_sigma, 1)*(x-x_mean)*(y-y_mean)) - torch.log(x_sigma*y_sigma*ohr) 37 | 38 | results = results*mask[:,:,0] 39 | assert torch.sum(mask) > 0.0 40 | return torch.sum(results)/torch.sum(mask[:,:,0]) 41 | 42 | def nll_loss_per_sample( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor ) -> torch.Tensor : 43 | """NLL averages across steps and dimensions, but not samples (agents).""" 44 | x_mean = pred[:,:,0] 45 | y_mean = pred[:,:,1] 46 | x_sigma = pred[:,:,2] 47 | y_sigma = pred[:,:,3] 48 | rho = pred[:,:,4] 49 | ohr = torch.pow(1-torch.pow(rho,2),-0.5) # type: ignore 50 | x = data[:,:, 0]; y = data[:,:, 1] 51 | results = torch.pow(ohr, 2)*(torch.pow(x_sigma, 2)*torch.pow(x-x_mean, 2) + torch.pow(y_sigma, 2)*torch.pow(y-y_mean, 2) 52 | -2*rho*torch.pow(x_sigma, 1)*torch.pow(y_sigma, 1)*(x-x_mean)*(y-y_mean)) - torch.log(x_sigma*y_sigma*ohr) 53 | results = results*mask[:,:,0] # nSteps by nBatch 54 | return torch.sum(results, dim=0)/torch.sum(mask[:,:,0], dim=0) 55 | 56 | def nll_loss_test( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 57 | """NLL for testing cases, returns a vector over future timesteps.""" 58 | x_mean = pred[:,:,0] 59 | y_mean = pred[:,:,1] 60 | x_sigma = pred[:,:,2] 61 | y_sigma = pred[:,:,3] 62 | rho = pred[:,:,4] 63 | ohr = torch.pow(1-torch.pow(rho,2),-0.5) # type: ignore 64 | x = data[:,:, 0]; y = data[:,:, 1] 65 | results = torch.pow(ohr, 2)*(torch.pow(x_sigma, 2)*torch.pow(x-x_mean, 2) + torch.pow(y_sigma, 2)*torch.pow(y-y_mean, 2) 66 | -2*rho*torch.pow(x_sigma, 1)*torch.pow(y_sigma, 1)*(x-x_mean)*(y-y_mean)) - torch.log(x_sigma*y_sigma*ohr) 67 | results = results*mask[:,:,0] # nSteps by nBatch 68 | assert torch.sum(mask) > 0.0 69 | counts = torch.sum(mask[:, :, 0], dim=1) 70 | return torch.sum(results, dim=1), counts 71 | 72 | def mse_loss( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor ) -> torch.Tensor: 73 | """Mean squared error loss.""" 74 | x_mean = pred[:,:,0] 75 | y_mean = pred[:,:,1] 76 | x = data[:,:, 0] 77 | y = data[:,:, 1] 78 | results = torch.pow(x-x_mean, 2) + torch.pow(y-y_mean, 2) 79 | results = results*mask[:,:,0] 80 | return torch.sum(results)/torch.sum(mask[:,:,0]) 81 | 82 | def mse_loss_test( pred: torch.Tensor, data: torch.Tensor, mask: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: 83 | """Mean squared error loss for test time.""" 84 | x_mean = pred[:,:,0] 85 | y_mean = pred[:,:,1] 86 | x = data[:,:,0] 87 | y = data[:,:,1] 88 | results = torch.pow(x-x_mean, 2) + torch.pow(y-y_mean, 2) 89 | results = results*mask[:,:,0] 90 | counts = torch.sum(mask[:,:,0],dim=1) 91 | 
lossVal = torch.sum(results,dim=1) 92 | return lossVal, counts 93 | 94 | def logsumexp(inputs: torch.Tensor, dim: Optional[int] =None, keepdim: Optional[bool] =False) -> torch.Tensor: 95 | if dim is None: 96 | inputs = inputs.view(-1) 97 | dim = 0 98 | s, _ = torch.max(inputs, dim=dim, keepdim=True) 99 | outputs = s + (inputs - s).exp().sum(dim=dim, keepdim=True).log() 100 | if not keepdim: 101 | outputs = outputs.squeeze(dim) 102 | return outputs 103 | 104 | def nll_loss_test_multimodes(pred: List[torch.Tensor], data: torch.Tensor, mask: torch.Tensor, modes_pred: torch.Tensor, y_mean: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor] : 105 | """NLL loss multimodes for test time.""" 106 | modes = len(pred) 107 | nSteps, batch_sz, dim = pred[0].shape 108 | total = torch.zeros(mask.shape[0],mask.shape[1], modes).to(y_mean.device) 109 | count = 0 110 | for k in range(modes): 111 | wts = modes_pred[:,k] 112 | wts = wts.repeat(nSteps,1) 113 | y_pred = pred[k] 114 | if y_mean is not None: 115 | x_pred_mean = y_pred[:, :, 0]+y_mean[:,0].view(-1,1) 116 | y_pred_mean = y_pred[:, :, 1]+y_mean[:,1].view(-1,1) 117 | else: 118 | x_pred_mean = y_pred[:, :, 0] 119 | y_pred_mean = y_pred[:, :, 1] 120 | x_sigma = y_pred[:, :, 2] 121 | y_sigma = y_pred[:, :, 3] 122 | rho = y_pred[:, :, 4] 123 | ohr = torch.pow(1 - torch.pow(rho, 2), -0.5) # type: ignore 124 | x = data[:, :, 0] 125 | y = data[:, :, 1] 126 | out = -(torch.pow(ohr, 2) * (torch.pow(x_sigma, 2) * torch.pow(x - x_pred_mean, 2) + torch.pow(y_sigma, 2) * torch.pow(y - y_pred_mean,2) 127 | -2 * rho * torch.pow(x_sigma, 1) * torch.pow(y_sigma, 1) * (x - x_pred_mean) * (y - y_pred_mean)) - torch.log(x_sigma * y_sigma * ohr)) 128 | total[:, :, count] = out + torch.log(wts) 129 | count += 1 130 | total = -logsumexp(total,dim = 2) 131 | total = total * mask[:,:,0] 132 | lossVal = torch.sum(total,dim=1) 133 | counts = torch.sum(mask[:,:,0],dim=1) 134 | return lossVal, counts 135 | 136 | def nll_loss_multimodes(pred: List[torch.Tensor], data: torch.Tensor, mask: torch.Tensor, modes_pred: torch.Tensor, noise: Optional[float]=0.0 ) -> float: 137 | """NLL loss multimodes for training. 
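EM-style objective: from the per-sample, per-mode NLLs l_k and the predicted
mode priors p_k, a detached posterior q_k = softmax_k(log p_k - l_k) is formed
(optionally perturbed by noise); the loss is the posterior-weighted NLL
sum_k q_k*l_k, averaged over the batch, plus KL(q || p) fitting the mode
predictor to that posterior.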
138 | Args: 139 | pred is a list (with N modes) of predictions 140 | data is ground truth 141 | noise is optional 142 | """ 143 | modes = len(pred) 144 | nSteps, batch_sz, dim = pred[0].shape 145 | log_lik = np.zeros( (batch_sz, modes) ) 146 | with torch.no_grad(): 147 | for kk in range(modes): 148 | nll = nll_loss_per_sample(pred[kk], data, mask) 149 | log_lik[:,kk] = -nll.cpu().numpy() 150 | 151 | priors = modes_pred.detach().cpu().numpy() 152 | 153 | log_posterior_unnorm = log_lik + np.log(priors).reshape((-1, modes)) #[TotalObjs, net.modes] 154 | log_posterior_unnorm += np.random.randn( *log_posterior_unnorm.shape)*noise 155 | log_posterior = log_posterior_unnorm - special.logsumexp( log_posterior_unnorm, axis=1 ).reshape((batch_sz, 1)) 156 | post_pr = np.exp(log_posterior) #[TotalObjs, net.modes] 157 | 158 | post_pr = torch.tensor(post_pr).float().to(data.device) 159 | loss = 0.0 160 | for kk in range(modes): 161 | nll_k = nll_loss_per_sample(pred[kk], data, mask)*post_pr[:,kk] 162 | loss += nll_k.sum()/float(batch_sz) 163 | 164 | kl_loss = torch.nn.KLDivLoss(reduction='batchmean') #type: ignore 165 | loss += kl_loss( torch.log(modes_pred), post_pr) 166 | return loss 167 | 168 | ################################################################################ 169 | def load_json_file(json_filename: str) -> dict: 170 | with open(json_filename) as json_file: 171 | json_dictionary = json.load(json_file) 172 | return json_dictionary 173 | 174 | def write_json_file(json_filename: str, json_dict: dict, pretty: Optional[bool]=False) -> None: 175 | with open(os.path.expanduser(json_filename), 'w') as outfile: 176 | if pretty: 177 | json.dump(json_dict, outfile, sort_keys=True, indent = 2) 178 | else: 179 | json.dump(json_dict, outfile, sort_keys=True,) 180 | 181 | def pi(obj: Union[torch.Tensor, np.ndarray]) -> None: 182 | """ Prints out some info.""" 183 | if isinstance(obj, torch.Tensor): 184 | print(str(obj.shape), end=' ') 185 | print(str(obj.device), end=' ') 186 | print( 'min:', float(obj.min() ), end=' ') 187 | print( 'max:', float(obj.max() ), end=' ') 188 | print( 'std:', float(obj.std() ), end=' ') 189 | print(str(obj.dtype) ) 190 | elif isinstance(obj, np.ndarray): 191 | print(str(obj.shape), end=' ') 192 | print( 'min:', float(obj.min() ), end=' ') 193 | print( 'max:', float(obj.max() ), end=' ') 194 | print( 'std:', float(obj.std() ), end=' ') 195 | print(str(obj.dtype) ) 196 | 197 | def compute_angles(x_mean: torch.Tensor, num_steps:int=3) -> torch.Tensor: 198 | """Compute the 2d angle of trajectories. 199 | Args: 200 | x_mean is [nSteps, nObjs, dim] 201 | """ 202 | nSteps, nObjs, dim = x_mean.shape 203 | thetas = np.zeros( (nObjs, num_steps)) 204 | for k in range(num_steps): 205 | for o in range(nObjs): 206 | diff = x_mean[k+1,o,:] - x_mean[k,o,:] 207 | thetas[o,k] = np.arctan2(diff[1], diff[0]) 208 | return thetas.mean(axis=1) 209 | 210 | def rotate_to(data: np.ndarray, theta0: np.ndarray, x0: np.ndarray) -> np.ndarray: 211 | """Rotate data about location x0 with theta0 in radians. 
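Maps points into the frame centered at x0 whose +x axis points along theta0;
rotate_to_inv below applies the inverse transform.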
212 |   Args:
213 |     data is [nSteps, dim] or [nSteps, nObjs, dim]
214 |   """
215 |   rot = np.array( [ [ np.cos(theta0), np.sin(theta0)],
216 |                     [-np.sin(theta0), np.cos(theta0)] ] )
217 |   if len(data.shape) == 2:
218 |     return np.dot( data - x0, rot.T )
219 |   else:
220 |     nSteps, nObjs, dim = data.shape
221 |     return np.dot( data.reshape((-1,dim))-x0, rot.T ).reshape((nSteps, nObjs, dim))
222 | 
223 | def rotate_to_inv(data: np.ndarray, theta0: np.ndarray, x0: np.ndarray) -> np.ndarray:
224 |   """Inverse of rotate_to: rotate data by theta0 (radians) and translate back to x0.
225 |   Args:
226 |     data is [nSteps, dim] or [nSteps, nObjs, dim]
227 |   """
228 |   rot = np.array( [ [np.cos(theta0), -np.sin(theta0)],
229 |                     [np.sin(theta0),  np.cos(theta0)] ] )
230 |   if len(data.shape) == 2:
231 |     return np.dot( data, rot.T ) + x0
232 |   else:
233 |     nSteps, nObjs, dim = data.shape
234 |     return (np.dot( data.reshape((-1,dim)), rot.T ) + x0).reshape((nSteps, nObjs, dim))
235 | 
236 | 
--------------------------------------------------------------------------------
/multiple_futures_prediction/ngsim_data/.gitignore:
--------------------------------------------------------------------------------
1 | *.mat
--------------------------------------------------------------------------------
/multiple_futures_prediction/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-multiple-futures-prediction/021c872fcc3529ae698e16ab565f0f5fa23857ab/multiple_futures_prediction/py.typed
--------------------------------------------------------------------------------
/multiple_futures_prediction/train_ngsim.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019-2020 Apple Inc. All Rights Reserved.
4 | #
5 | 
6 | from typing import List, Set, Dict, Tuple, Optional, Union, Any
7 | import numpy as np
8 | np.set_printoptions(suppress=True)
9 | import time
10 | import math
11 | import glob
12 | from attrdict import AttrDict
13 | import gin
14 | import torch
15 | from torch.utils.data import DataLoader
16 | from multiple_futures_prediction.dataset_ngsim import *
17 | from multiple_futures_prediction.model_ngsim import mfpNet
18 | from multiple_futures_prediction.my_utils import *
19 | 
20 | def eval(metric: str, net: torch.nn.Module, params: AttrDict, data_loader: DataLoader, bStepByStep: bool,
21 |          use_forcing: int, y_mean: torch.Tensor, num_batches: int, dataset_name: str) -> torch.Tensor:
22 |   """Evaluation function for validation and test data.
23 | 
24 |   Given an MFP network and a data loader, evaluate either the NLL or the RMSE error.
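  metric selects the error: 'nll' computes the (multi-modal, if
  params.modes > 1) negative log-likelihood; anything else computes RMSE,
  which is only supported for params.modes == 1 and is reported in meters
  (converted from NGSIM's feet). Evaluation stops after num_batches batches,
  and the returned tensor holds one error value per future timestep.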
25 |   """
26 |   print('eval ', dataset_name)
27 |   num = params.fut_len_orig_hz//params.subsampling
28 |   lossVals = torch.zeros(num)
29 |   counts = torch.zeros(num)
30 | 
31 |   for i, data in enumerate(data_loader):
32 |     if i >= num_batches:
33 |       break
34 |     hist, nbrs, mask, fut, mask, context, nbrs_info = data  # NB: `mask` is bound twice; the future-validity mask (5th element) is the one kept
35 |     if params.use_cuda:
36 |       hist = hist.cuda()
37 |       nbrs = nbrs.cuda()
38 |       mask = mask.cuda()
39 |       fut = fut.cuda()
40 |       mask = mask.cuda()  # no-op repeat of the call two lines above
41 |       if context is not None:
42 |         context = context.cuda()
43 | 
44 |     if metric == 'nll':
45 |       fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep, use_forcing=use_forcing)
46 |       if params.modes == 1:
47 |         if params.remove_y_mean:
48 |           fut_preds[0][:,:,:2] += y_mean.unsqueeze(1).to(fut.device)
49 |         l, c = nll_loss_test(fut_preds[0], fut, mask)
50 |       else:
51 |         l, c = nll_loss_test_multimodes(fut_preds, fut, mask, modes_pred, y_mean.to(fut.device))
52 |     else: # RMSE error
53 |       assert params.modes == 1
54 |       fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep, use_forcing=use_forcing)
55 |       if params.modes == 1:
56 |         if params.remove_y_mean:
57 |           fut_preds[0][:,:,:2] += y_mean.unsqueeze(1).to(fut.device)
58 |         l, c = mse_loss_test(fut_preds[0], fut, mask)
59 | 
60 |     lossVals += l.detach().cpu()
61 |     counts += c.detach().cpu()
62 | 
63 |   if metric == 'nll':
64 |     err = lossVals / counts
65 |     print(lossVals / counts)
66 |   else:
67 |     err = torch.pow(lossVals / counts, 0.5)*0.3048  # RMSE, converted from feet to meters (1 ft = 0.3048 m)
68 |     print( err )
69 |   return err
70 | 
71 | def get_mean( train_data_loader: DataLoader, batches: Optional[int]=200 ) -> np.ndarray:
72 |   """Compute the per-timestep mean of the future trajectories over some training batches."""
73 |   yy = []
74 |   counters = None
75 |   for i, data in enumerate(train_data_loader):
76 |     if i > batches: # type: ignore
77 |       break
78 |     hist, nbrs, _, fut, fut_mask, _, _ = data
79 |     target = fut.cpu().numpy()
80 |     valid = fut_mask.cpu().numpy().sum(axis=1)
81 | 
82 |     if counters is None:
83 |       counters = np.zeros_like( valid )
84 |     counters += valid
85 | 
86 |     isinvalid = (fut_mask == 0)
87 |     target[isinvalid] = 0
88 |     yy.append( target )
89 | 
90 |   Y = np.concatenate(yy, axis=1)
91 |   y_mean = np.divide( np.sum(Y, axis=1), counters )
92 |   return y_mean
93 | 
94 | def setup_logger(root_dir: str, SCENARIO_NAME: str) -> Tuple[Any, Any]:
95 |   """Create a timestamped logging directory (with a checkpoints subdir) and open a log file."""
96 |   import glob
97 |   from subprocess import call
98 |   import time
99 |   import datetime
100 |   import os
101 | 
102 |   timestamp = datetime.datetime.fromtimestamp(time.time()).strftime('%Y.%m.%d_%H.%M.%S')
103 |   logging_dir = root_dir+"%s_%s/"%(SCENARIO_NAME, timestamp)
104 |   if not os.path.isdir(logging_dir):
105 |     os.makedirs(logging_dir)
106 |     os.makedirs(logging_dir+'/checkpoints')
107 |     print("! " + logging_dir + " CREATED!")
108 | 
109 |   logger_file = open(logging_dir+'/log', 'w')
110 |   return logger_file, logging_dir
111 | 
112 | ################################################################################
113 | @gin.configurable
114 | class Params(object):
115 |   def __init__(self, log:bool=False, # save checkpoints?
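               # All of the constructor arguments are gin-configurable; a gin
               # file such as configs/mfp2_ngsim.gin can bind them with lines
               # like the following (values illustrative, not taken from the
               # shipped config):
               #   Params.modes = 2
               #   Params.remove_y_mean = True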
116 |                modes:int=2, # number of latent modes
117 |                use_cuda:bool=True,
118 |                encoder_size:int=16, # encoder latent layer size
119 |                decoder_size:int=16, # decoder latent layer size
120 |                subsampling:int=2, # temporal subsampling factor
121 |                hist_len_orig_hz:int=30, # number of original history samples
122 |                fut_len_orig_hz:int=50, # number of original future samples
123 |                dyn_embedding_size:int=32, # dynamics embedding size
124 |                input_embedding_size:int=32, # input embedding size
125 |                dec_nbr_enc_size:int=8, # decoder neighbor encoding size
126 |                nbr_atten_embedding_size:int=80, # neighborhood attention embedding size
127 |                seed:int=1234,
128 |                remove_y_mean:bool=False, # normalize by removing the mean of the future trajectory
129 |                use_gru:bool=True, # GRUs instead of LSTMs
130 |                bi_direc:bool=False, # bidirectional
131 |                self_norm:bool=False, # normalize with respect to the current time
132 |                data_aug:bool=False, # data augmentation
133 |                use_context:bool=False, # use contextual image as additional input
134 |                nll:bool=True, # negative log-likelihood loss
135 |                use_forcing:int=0, # teacher forcing
136 |                iter_per_err:int=100, # iterations between displaying errors
137 |                iter_per_eval:int=1000, # iterations between evals on the validation set
138 |                pre_train_num_updates:int=200000, # number of iterations for pre-training
139 |                updates_div_by_10:int=100000, # every how many iterations to divide the learning rate by 10.0
140 |                nbr_search_depth:int=10, # how deep to search for neighbors
141 |                lr_init:float=0.001, # initial learning rate
142 |                min_lr:float=0.00005, # minimal learning rate
143 |                iters_per_save:int=1500 ) -> None:
144 |     self.params = AttrDict(locals())
145 |   def __call__(self) -> Any:
146 |     return self.params
147 | ################################################################################
148 | 
149 | def train( params: AttrDict ) -> Any:
150 |   """Main training function."""
151 |   torch.manual_seed( params.seed ) #type: ignore
152 |   np.random.seed( params.seed )
153 | 
154 |   ############################
155 |   batch_size = 1
156 |   data_hz = 10
157 |   ns_between_samples = (1.0/data_hz)*1e9
158 |   d_s = params.subsampling
159 |   t_h = params.hist_len_orig_hz
160 |   t_f = params.fut_len_orig_hz
161 |   NUM_WORKERS = 1
162 | 
163 |   DATA_PATH = 'multiple_futures_prediction/'
164 | 
165 |   # Load the datasets.
166 |   train_set = NgsimDataset(DATA_PATH + 'ngsim_data/TrainSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
167 |                            params.data_aug, params.use_context, params.nbr_search_depth)
168 |   val_set = NgsimDataset(DATA_PATH + 'ngsim_data/ValSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
169 |                          params.data_aug, params.use_context, params.nbr_search_depth)
170 |   test_set = NgsimDataset(DATA_PATH + 'ngsim_data/TestSet.mat', t_h, t_f, d_s, params.encoder_size, params.use_gru, params.self_norm,
171 |                           params.data_aug, params.use_context, params.nbr_search_depth)
172 |   train_data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS, collate_fn=train_set.collate_fn, drop_last=True) # type: ignore
173 |   val_data_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, collate_fn=val_set.collate_fn, drop_last=True) #type: ignore
174 |   test_data_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS, collate_fn=test_set.collate_fn, drop_last=False) #type: ignore
175 | 
176 |   # Compute or load the existing mean over future trajectories (see the caching note below).
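  # The mean is cached on first use; if the training data or the subsampling
  # rate changes, delete ngsim_data/y_mean.pkl so a stale mean is not
  # silently reused.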
177 |   if os.path.exists(DATA_PATH+'ngsim_data/y_mean.pkl'):
178 |     y_mean = pickle.load( open(DATA_PATH+'ngsim_data/y_mean.pkl', 'rb') )
179 |   else:
180 |     y_mean = get_mean(train_data_loader)
181 |     pickle.dump( y_mean, open(DATA_PATH+'ngsim_data/y_mean.pkl', 'wb') )
182 | 
183 |   # Initialize the network.
184 |   net = mfpNet( params )
185 |   if params.use_cuda:
186 |     net = net.cuda() #type: ignore
187 | 
188 |   net.y_mean = y_mean
189 |   y_mean = torch.tensor(net.y_mean)
190 | 
191 |   if params.log:
192 |     logger_file, logging_dir = setup_logger(DATA_PATH+"checkpts/", 'NGSIM')
193 | 
194 |   train_loss: List = []
195 |   val_loss: List = []
196 | 
197 |   MODE='Pre' # For efficiency, we first pre-train w/o interactive rollouts.
198 |   num_updates = 0
199 |   optimizer = None
200 | 
201 |   for epoch_num in range(20):
202 |     if MODE == 'EndPre':
203 |       MODE = 'Train'
204 |       print('Training with interactive rollouts.')
205 |       bStepByStep = True
206 |     else:
207 |       print('Pre-training without interactive rollouts.')
208 |       bStepByStep = False
209 | 
210 |     # Average losses.
211 |     avg_tr_loss = 0.
212 |     avg_tr_time = 0.
213 |     loss_counter = 0.0
214 | 
215 |     for i, data in enumerate(train_data_loader):
216 |       if num_updates > params.pre_train_num_updates and MODE == 'Pre':
217 |         MODE = 'EndPre'
218 |         break
219 | 
220 |       lr_fac = np.power(0.1, num_updates // params.updates_div_by_10)
221 |       lr = max( params.min_lr, params.lr_init*lr_fac )
222 |       if optimizer is None:
223 |         optimizer = torch.optim.Adam(net.parameters(), lr=lr) #type: ignore
224 |       elif lr != optimizer.defaults['lr']:
225 |         optimizer = torch.optim.Adam(net.parameters(), lr=lr)  # NB: re-creating Adam on an LR change resets its moment estimates
226 | 
227 |       st_time = time.time()
228 |       hist, nbrs, mask, fut, mask, context, nbrs_info = data  # NB: `mask` is bound twice; the future-validity mask (5th element) is the one kept
229 | 
230 |       if params.remove_y_mean:
231 |         fut = fut-y_mean.unsqueeze(1)
232 | 
233 |       if params.use_cuda:
234 |         hist = hist.cuda()
235 |         nbrs = nbrs.cuda()
236 |         mask = mask.cuda()
237 |         fut = fut.cuda()
238 |         mask = mask.cuda()  # no-op repeat of the call two lines above
239 |         if context is not None:
240 |           context = context.cuda()
241 | 
242 |       # Forward pass.
243 |       fut_preds, modes_pred = net.forward_mfp(hist, nbrs, mask, context, nbrs_info, fut, bStepByStep)
244 |       if params.modes == 1:
245 |         l = nll_loss(fut_preds[0], fut, mask)
246 |       else:
247 |         l = nll_loss_multimodes(fut_preds, fut, mask, modes_pred)
248 | 
249 |       # Backprop.
250 |       optimizer.zero_grad()
251 |       l.backward()
252 |       torch.nn.utils.clip_grad_norm_(net.parameters(), 10) #type: ignore
253 |       optimizer.step()
254 |       num_updates += 1
255 | 
256 |       batch_time = time.time()-st_time
257 |       avg_tr_loss += l.item()
258 |       avg_tr_time += batch_time
259 | 
260 |       effective_batch_sz = float(hist.shape[1])
261 |       if num_updates % params.iter_per_err == params.iter_per_err-1:
262 |         print("Epoch no:", epoch_num, "update:", num_updates, "| Avg train loss:",
263 |               format(avg_tr_loss/100, '0.4f'), " learning_rate:%.5f"%lr)  # the /100 assumes iter_per_err == 100
264 |         train_loss.append(avg_tr_loss/100)
265 | 
266 |         if params.log:
267 |           msg_str_ = ("Epoch no:", epoch_num, "update:", num_updates, "| Avg train loss:",
268 |                       format(avg_tr_loss/100, '0.4f'), " learning_rate:%.5f"%lr)
269 |           msg_str = str([str(ss) for ss in msg_str_])
270 |           logger_file.write(msg_str+'\n')
271 |           logger_file.flush()
272 | 
273 |         avg_tr_loss = 0.
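      # Learning-rate schedule sanity check for the optimizer setup above
      # (worked example with the defaults): lr_init=0.001 and
      # updates_div_by_10=100000 give lr=0.001 for updates 0-99999 and
      # lr=0.0001 for updates 100000-199999; from update 200000 on, the
      # max(min_lr, ...) floor holds lr at 0.00005.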
274 |       if num_updates % params.iter_per_eval == params.iter_per_eval-1:
275 |         print("Starting eval")
276 |         val_nll_err = eval( 'nll', net, params, val_data_loader, bStepByStep,
277 |                             use_forcing=params.use_forcing, y_mean=y_mean,
278 |                             num_batches=500, dataset_name='val_dl nll')
279 | 
280 |         if params.log:
281 |           logger_file.write('val nll: ' + str(val_nll_err)+'\n')
282 |           logger_file.flush()
283 | 
284 |       # Save weights.
285 |       if params.log and num_updates % params.iters_per_save == params.iters_per_save-1:
286 |         msg_str = '\nSaving state, update iter:%d %s'%(num_updates, logging_dir)
287 |         print(msg_str)
288 |         logger_file.write( msg_str ); logger_file.flush()
289 |         torch.save(net.state_dict(), logging_dir + '/checkpoints/ngsim_%06d'%num_updates + '.pth') #type: ignore
290 | 
291 | 
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2020 Apple Inc. All rights reserved.
3 | #
4 | 
5 | [mypy]
6 | 
7 | ignore_missing_imports = True
8 | # You may need to copy, paste and uncomment the following snippet for libraries that do not
9 | # support mypy yet.
10 | 
11 | #[mypy-.*]
12 | #ignore_missing_imports = True
13 | 
14 | # These are the settings for your own code
15 | [mypy-multiple_futures_prediction.*]
16 | # Disallow calls from functions with type annotations to functions with no type annotations
17 | disallow_untyped_calls = True
18 | # Disallow defs with no or incomplete type annotations
19 | disallow_untyped_defs = True
20 | # Type-check inside functions with no type annotations
21 | check_untyped_defs = True
22 | # Warn about unneeded ignore comments
23 | warn_unused_ignores = True
24 | 
25 | ignore_missing_imports = True
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright (C) 2020 Apple Inc. All rights reserved.
3 | #
4 | 
5 | # Tell pip to use flit to build this package
6 | [build-system]
7 | requires = ["flit"]
8 | build-backend = "flit.buildapi"
9 | 
10 | [tool.flit.metadata]
11 | module = "multiple_futures_prediction"
12 | author = "Charlie Tang"
13 | author-email = "yichuan_tang@apple.com"
14 | 
15 | license = "Apple Software"
16 | requires-python = ">=3.6"
17 | description-file="README.md"
18 | 
19 | # List here all your dependencies
20 | requires = [
21 |     "attrdict",
22 |     "gin_config",
23 |     "torch",
24 |     "scipy",
25 |     "numpy",
26 |     "opencv-python"
27 | ]
28 | 
29 | [tool.flit.metadata.requires-extra]
30 | # Packages required for testing
31 | test = ["mypy"]
32 | 
33 | [tool.flit.scripts]
34 | # Register your scripts here (feel free to rename the sample script)
35 | train_ngsim = "multiple_futures_prediction.cmd.train_ngsim_cmd:main"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | attrdict==2.0.1
2 | gin_config==0.2.1
3 | torch==1.1.0
4 | scipy==1.3.1
5 | numpy==1.17.2
6 | opencv-python
7 | .
8 | 
--------------------------------------------------------------------------------
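--------------------------------------------------------------------------------
Quick start (a minimal sketch; the exact CLI flags, if any, live in
multiple_futures_prediction/cmd/train_ngsim_cmd.py, which is not shown here,
and the NGSIM .mat files must first be placed in
multiple_futures_prediction/ngsim_data/):

  pip install -r requirements.txt   # pinned dependencies, plus this package via the "." line
  train_ngsim                       # flit entry point registered in [tool.flit.scripts] above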