├── LICENSE
├── eval.py
├── model
│   ├── obs_encoder.py
│   ├── traj_ogm.py
│   ├── ConvRNN.py
│   ├── traj_cluster.py
│   ├── policy_network.py
│   ├── transformer.py
│   ├── utils.py
│   ├── full_model.py
│   └── traj_decoder.py
├── README.md
└── train.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kguo-cs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | 5 | from model.full_model import Model 6 | 7 | # Initialize device: 8 | device = torch.device( "cuda:0") 9 | 10 | dataset="sdd" 11 | 12 | if dataset=="ind": 13 | horizon = 30 14 | fut_len = 30 15 | grid_extent = 25 16 | nei_dim=0 17 | type="test" 18 | 19 | from data.IND.inD import inD as DS 20 | else: 21 | horizon = 20 22 | fut_len = 12 23 | grid_extent = 20 24 | nei_dim=2 25 | type="sddtest" 26 | 27 | from data.SDD.sdd import sdd as DS 28 | 29 | 30 | net = Model(horizon, fut_len,nei_dim,grid_extent).float().to(device) 31 | 32 | 33 | if dataset=="ind": 34 | checkpoint = torch.load("./pretrained/indend.tar",map_location='cuda:0') 35 | elif dataset=="trajnet": 36 | checkpoint = torch.load("./pretrained/trajnetend.tar",map_location='cuda:0') 37 | else: 38 | checkpoint = torch.load("./pretrained/sddend.tar",map_location='cuda:0') 39 | 40 | test_set = DS(dataset,horizon=horizon, fut_len=fut_len, type="test", grid_extent=grid_extent) 41 | 42 | 43 | 44 | test_dl = DataLoader(test_set, 45 | batch_size=16, 46 | shuffle=True, 47 | num_workers=8 48 | ) 49 | 50 | 51 | net.load_state_dict(checkpoint['model_state_dict']) 52 | temp=checkpoint["temp"] 53 | 54 | net.eval() 55 | 56 | Minade = 0 57 | Minfde = 0 58 | Offroad = 0 59 | Offroad_count = 0 60 | val_batch_count = 0 61 | 62 | 63 | for epoch in range(10): 64 | 65 | with torch.no_grad(): 66 | # Load batch 67 | for k, data_val in enumerate(test_dl): 68 | 69 | min_ade,min_fde,off_road,off_road_count,count=net(data_val,temp=temp,type=type,device=device,num_samples=1000) 70 | 71 | Minade += min_ade.item()*count 72 | Minfde += min_fde.item()*count 73 | Offroad += off_road.item() 74 | Offroad_count += off_road_count.item() 75 | val_batch_count += count 76 | 77 | print("Epoch no:", epoch, 78 | "| temp", format(temp, '0.5f'), 79 | "| ade", format(Minade / val_batch_count, 
'0.3f'), 80 | "| fde", format(Minfde / val_batch_count, '0.3f'), 81 | "| offroad_rate", format(1-Offroad / Offroad_count, '0.3f')) -------------------------------------------------------------------------------- /model/obs_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torchvision.models as mdl 4 | import numpy as np 5 | 6 | class ObsEncoder(nn.Module): 7 | 8 | def __init__(self,nei_dim=2,hist_dim=2,scene_dim=32,motion_dim=64,grid_dim=25): 9 | 10 | super(ObsEncoder, self).__init__() 11 | 12 | self.nei_dim=nei_dim 13 | self.grid_dim=grid_dim 14 | 15 | coordinate = np.zeros((2, grid_dim, grid_dim)) 16 | centers = np.linspace(-1 + 1/ grid_dim , 1 - 1 / grid_dim , grid_dim) 17 | 18 | coordinate[0] = centers.reshape(-1, 1).repeat(grid_dim, axis=1).transpose() 19 | coordinate[1] = centers.reshape(-1, 1).repeat(grid_dim, axis=1) 20 | 21 | self.coordinate=torch.from_numpy(coordinate).float()[None] 22 | 23 | resnet34 = mdl.resnet34(pretrained=False) 24 | 25 | self.scene_enc=nn.Sequential(resnet34.conv1, resnet34.bn1, resnet34.relu, resnet34.maxpool, resnet34.layer1,nn.Conv2d(64, scene_dim, (2, 2), (2, 2)),nn.LeakyReLU(0.1)) 26 | 27 | self.hist_enc = nn.GRU(hist_dim, motion_dim, batch_first=True) 28 | 29 | if nei_dim!=0: 30 | 31 | self.nei_enc=nn.Sequential(nn.Conv2d(nei_dim,2,kernel_size=(5,5)),nn.MaxPool2d(2),nn.LeakyReLU(0.1),nn.Conv2d(2,2,kernel_size=(5,5)) ) 32 | 33 | self.hist_emb=nn.Sequential(nn.Linear(hist_dim+6*6*2, motion_dim),nn.LeakyReLU(0.1)) 34 | 35 | self.histrot_enc =nn.GRU(motion_dim, motion_dim,batch_first=True)# 36 | 37 | def forward(self,hist, neighbors, img,r_mat,type,device): 38 | 39 | scene_feats = self.scene_enc(img) 40 | 41 | motion_feats =self.hist_enc(hist)[1][0] 42 | 43 | motion_grid=torch.cat([motion_feats[:,:,None,None].repeat(1,1,self.grid_dim,self.grid_dim),self.coordinate.to(device).repeat(len(hist), 1, 1, 1)],dim=1) 44 | 45 | if type == "dist": 46 | scene_feats = scene_feats.detach() 47 | motion_feats = motion_feats.detach() 48 | motion_grid=motion_grid.detach() 49 | 50 | if self.nei_dim!=0: 51 | hist_rot = torch.einsum('nab,ntb->nta', r_mat, hist) 52 | 53 | nei_feats=self.nei_enc(neighbors.view(-1,self.nei_dim,self.grid_dim,self.grid_dim)).view(len(hist),hist.shape[1], -1) 54 | 55 | hist_feats=self.hist_emb(torch.cat([nei_feats,hist_rot],dim=-1)) 56 | 57 | motion_feats = self.histrot_enc(hist_feats)[1][0] 58 | 59 | return scene_feats, motion_feats,motion_grid 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /model/traj_ogm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .ConvRNN import ConvLSTM 4 | 5 | class OGMDecoder(torch.nn.Module): 6 | def __init__(self,fut_len,motion_dim=32,scene_dim=32,ogm_dim=32,n_layers=2,filter_size=5,grid_dim=25): 7 | super(OGMDecoder, self).__init__() 8 | 9 | self.fut_len=fut_len 10 | self.ogm_dim=ogm_dim 11 | self.conv1= nn.Sequential(nn.Conv2d( motion_dim+2, ogm_dim,1),nn.LeakyReLU(0.1)) 12 | self.conv2= nn.Sequential(nn.Conv2d( motion_dim+2, ogm_dim, 1),nn.LeakyReLU(0.1)) 13 | 14 | self.convlstm = ConvLSTM( input_dim=scene_dim, hidden_dims=[ogm_dim,ogm_dim], n_layers=n_layers,kernel_size=(3, 3)) 15 | 16 | self.x0=torch.zeros([grid_dim*grid_dim]) 17 | self.x0=nn.Parameter(self.x0) 18 | 19 | self.softmax2d = nn.Softmax2d() 20 | self.pad=filter_size//2 21 
| self.grid_dim=grid_dim 22 | self.filter_size=filter_size 23 | self.unfold = nn.Unfold(kernel_size=(filter_size, filter_size), padding=0) 24 | self.Padding = nn.ConstantPad2d(filter_size//2, -10000) 25 | self.mask = self.unfold(self.Padding(torch.zeros(1, 1, grid_dim, grid_dim))).reshape(1, -1, grid_dim, grid_dim) 26 | self.fold= nn.Fold(output_size=(self.pad*2+grid_dim, self.pad*2+grid_dim), kernel_size=(filter_size, filter_size)) 27 | 28 | self.output_Conv = nn.Conv2d(ogm_dim,filter_size*filter_size,1) 29 | 30 | 31 | def forward(self, f_s,H,type,device): 32 | 33 | h=self.conv1(H) 34 | c=self.conv2(H) 35 | 36 | h_outputs = [] 37 | 38 | for t in range(self.fut_len): 39 | 40 | h = self.convlstm(f_s, first_timestep=(t == 0),h=h,c=c) #32,8,64,64 41 | 42 | h_outputs.append(h) 43 | 44 | h_outputs = torch.stack(h_outputs, 1).view(-1,self.ogm_dim,self.grid_dim,self.grid_dim) 45 | 46 | if type=="end" or type=="cluster": 47 | return None,h_outputs 48 | 49 | outputs = [] 50 | 51 | x = torch.softmax(self.x0, dim=0).view(1, 1, self.grid_dim, self.grid_dim).repeat(len(f_s), 1, 1, 1) 52 | 53 | weight_raw=self.output_Conv(h_outputs) #n*op_len,3*3,25,25 54 | 55 | weight=self.softmax2d(weight_raw+self.mask.to(device)).view(-1,self.fut_len,self.filter_size*self.filter_size,self.grid_dim,self.grid_dim) #8,9,64,64+self.mask #weight = weight_raw / weight_raw.sum(1, keepdim=True).clamp(min=1e-7) 56 | 57 | for t in range(self.fut_len): 58 | 59 | filter_x=weight[:,t]*x#8,9,64,64 60 | 61 | x=self.fold(filter_x.reshape(-1,self.filter_size*self.filter_size,self.grid_dim*self.grid_dim))[:,:,self.pad:-self.pad,self.pad:-self.pad] 62 | 63 | outputs.append(x) 64 | 65 | outputs = torch.stack(outputs, 1).view(-1,1,self.grid_dim,self.grid_dim) 66 | 67 | return outputs,h_outputs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repo contains the official implementation of our paper: "End-to-End Trajectory Distribution Prediction Based on Occupancy Grid Maps". 2 | Ke Guo, Wenxi Liu, Jia Pan. 3 | 4 | **CVPR 2022** 5 | [paper](http://arxiv.org/abs/2203.16910) 6 | 7 | # Installation 8 | 9 | ### Environment 10 | * Python >= 3.7 11 | * PyTorch == 1.8.0 12 | 13 | 14 | ### Data and pretrained model 15 | Please download the pretrained model and data from OneDrive([https://connecthkuhk-my.sharepoint.com/:u:/g/personal/u3006612_connect_hku_hk/EXqC6hjGTphKh8TkjrwtByEB3FFZ_dpCu0Rs6N7CTG2gag?e=5q4Knz](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/u3006612_connect_hku_hk/Ei8gZNibG4lBhtJGGkmAGggB1aKgCXg2sYpxViE7PFqkwQ?e=YbB6c4)). Extract the zip file into the main folder. 16 | 17 | ### Data Preprocessing 18 | 19 | Here are the details of data preprocessing. You can skip this step by using the preprocessed data from the OneDrive link above. 20 | 21 | * SDD (Trajnet split) 22 | 23 | 1. Download the Trajnet split data from [Y-Net](https://github.com/HarshayuGirase/Human-Path-Prediction/tree/master/ynet). Put the data under [data/SDD](data/SDD) 24 | 25 | 2. Run [script](process_trajnet.py) to process the downloaded "train_trajnet.pkl" and "test_trajnet.pkl": 26 | ``` 27 | python data/SDD/process_trajnet.py 28 | ``` 29 | 30 | 31 | * SDD (P2T split) 32 | 1. Download the P2T split data from [P2T](https://github.com/nachiket92/P2T/tree/main/data/sdd). Put the data under [data/SDD](data/SDD) 33 | 34 | 2. 
Run [script](process_p2t.py) to process the downloaded "SDDtrain.mat", "SDDval.mat" and "SDDtest.mat": 35 | ``` 36 | python data/SDD/process_p2t.py 37 | ``` 38 | 39 | 40 | * inD 41 | 42 | 1. Obtain the processed inD data from [Y-Net](https://github.com/HarshayuGirase/Human-Path-Prediction/tree/master/ynet). Put the data under [data/IND](data/IND) 43 | 44 | 2. Run [script](process_inD.py) to process the downloaded "inD_train.pickle" and "inD_test.pickle": 45 | ``` 46 | python data/SDD/process_inD.py 47 | ``` 48 | 49 | ### Training 50 | 51 | 52 | To train the model on the Trajnet split: 53 | 54 | ``` 55 | python train.py --dataset "trajnet" 56 | ``` 57 | For SDD (P2T split) or inD, replace "trajnet" with "sdd" or "ind". 58 | 59 | ### Evaluation 60 | 61 | To evaluate on the Trajnet dataset: 62 | 63 | ``` 64 | python eval.py --dataset "trajnet" 65 | ``` 66 | For SDD (P2T split) or inD, replace "trajnet" with "sdd" or "ind". 67 | 68 | ## Citation 69 | 70 | ``` 71 | @inproceedings{guo2022end, 72 | title={End-to-End Trajectory Distribution Prediction Based on Occupancy Grid Maps}, 73 | author={Guo, Ke and Liu, Wenxi and Pan, Jia}, 74 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 75 | year={2022} 76 | } 77 | ``` 78 | 79 | -------------------------------------------------------------------------------- /model/ConvRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvLSTM_Cell(nn.Module): 5 | def __init__(self, input_dim, hidden_dim, kernel_size, bias=True): 6 | super(ConvLSTM_Cell, self).__init__() 7 | 8 | self.input_dim = input_dim 9 | self.hidden_dim = hidden_dim 10 | self.kernel_size = kernel_size 11 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 12 | self.bias = bias 13 | # img_size=64 14 | 15 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 16 | out_channels=4 * self.hidden_dim, 17 | kernel_size=self.kernel_size, 18 | padding=self.padding, bias=self.bias) 19 | 20 | # we implement an LSTM that processes only one timestep 21 | def forward(self, x, hidden): # x [batch, hidden_dim, width, height] 22 | h_cur, c_cur = hidden 23 | 24 | # x,h_cur=self.CEBlock(x,h_cur) 25 | 26 | combined = torch.cat([x, h_cur], dim=1) # concatenate along channel axis 27 | combined_conv = self.conv(combined) 28 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 29 | i = torch.sigmoid(cc_i) 30 | f = torch.sigmoid(cc_f) 31 | o = torch.sigmoid(cc_o) 32 | g = torch.tanh(cc_g) 33 | 34 | c_next = f * c_cur + i * g 35 | h_next = o * torch.tanh(c_next) 36 | 37 | # h_next, c_next, self.attentions = self.SEBlock(h_next, c_next) 38 | return h_next, c_next 39 | 40 | class ConvLSTM(nn.Module): 41 | def __init__(self, input_dim, hidden_dims, n_layers, kernel_size): 42 | super(ConvLSTM, self).__init__() 43 | self.input_dim = input_dim 44 | self.hidden_dims = hidden_dims 45 | self.n_layers = n_layers 46 | self.kernel_size = kernel_size 47 | self.H, self.C = [], [] 48 | 49 | cell_list = [] 50 | for i in range(0, self.n_layers): 51 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dims[i - 1] 52 | cell_list.append(ConvLSTM_Cell(input_dim=cur_input_dim, 53 | hidden_dim=self.hidden_dims[i], 54 | kernel_size=self.kernel_size)) 55 | 56 | self.cell_list = nn.ModuleList(cell_list) 57 | 58 | def forward(self, input_,first_timestep=True, h=None,c=None): # input_ [batch_size, 1, channels, width, height] 59 | #batch_size = 
input_.data.size()[0] 60 | if first_timestep==True: 61 | self.H=[h] 62 | self.C=[c] 63 | for i in range(self.n_layers-1): 64 | self.H.append(torch.zeros_like(h)) 65 | self.C.append(torch.zeros_like(h)) 66 | 67 | 68 | for j, cell in enumerate(self.cell_list): 69 | if j == 0: # bottom layer 70 | self.H[j], self.C[j] = cell(input_, (self.H[j], self.C[j])) 71 | else: 72 | self.H[j], self.C[j] = cell(self.H[j - 1], (self.H[j], self.C[j])) 73 | 74 | return self.H[-1] # (hidden, output) 75 | -------------------------------------------------------------------------------- /model/traj_cluster.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .transformer import MultiHeadedAttention,PositionwiseFeedForward,DecoderLayer,Decoder,EncoderLayer,Encoder 4 | import copy 5 | 6 | 7 | 8 | 9 | class TrajCluster(torch.nn.Module): 10 | 11 | def __init__(self,fut_len=25,num_cluster=20,motion_dim=32,dropout=0.1): 12 | super(TrajCluster, self).__init__() 13 | 14 | h=8 15 | 16 | d_model=64 17 | 18 | d_ff=128 19 | 20 | N=3 21 | 22 | self.fut_len=fut_len 23 | 24 | self.num_cluster = num_cluster 25 | 26 | attn = MultiHeadedAttention(h, d_model,dropout) 27 | 28 | #spatial_attn=MultiHeadedAttention_spatial(h,d_model,dropout) 29 | 30 | ff = PositionwiseFeedForward(d_model, d_ff, dropout) 31 | 32 | c = copy.deepcopy 33 | 34 | self.temporal_embed = nn.Sequential(nn.Linear(fut_len*2,d_model),nn.ReLU()) 35 | 36 | #self.temporal_encoder=STEncoder(STEncoderLayer(d_model, c(attn),c(spatial_attn),c(ff), c(ff), dropout), N) 37 | self.temporal_encoder=Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N) 38 | 39 | #self.dest_decoder = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), 1) 40 | 41 | self.decoder = Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N) 42 | 43 | 44 | self.dmodel=d_model 45 | 46 | self.generator= nn.Linear(d_model,fut_len*2) 47 | 48 | self.tgt_embed= nn.Parameter(torch.randn([num_cluster,d_model])) 49 | 50 | self.hist_emb1=nn.Sequential(nn.Linear(motion_dim, d_model),nn.LeakyReLU(0.1)) 51 | # self.hist_emb2=nn.Sequential(nn.Linear(32, d_model-2),nn.LeakyReLU(0.1)) 52 | 53 | #self.dest_genetator=nn.Sequential(nn.Linear(d_model, 2)) 54 | 55 | # self.dest_decoder=nn.Sequential(nn.Linear(d_model, 2)) 56 | 57 | 58 | #self.encoder_dest =nn.Sequential(nn.Linear(2, 32),nn.LeakyReLU(0.1)) 59 | 60 | def forward(self,traj,hist_feats): 61 | 62 | #num_samples=traj.shape[1] 63 | 64 | #hist_feats = hist_feats[:, None].repeat(1, num_samples, 1) 65 | 66 | traj=traj.reshape(len(traj),traj.shape[1],-1) 67 | 68 | traj_vec = traj#torch.cat([hist_feats, traj], dim=-1) 69 | 70 | rel_embedding = self.temporal_embed(traj_vec) # n,6,d_model 71 | 72 | # rel_s = (batch_abs[:obs_len, :, None] - batch_abs[:obs_len, None]).permute(1,2,0,3) # batch_abs : 20,263,2 73 | # 74 | # edge_mask = nei_list[:obs_len].bool() & seq_list[:obs_len, :, None].bool() 75 | # 76 | # edge_mask=edge_mask.permute(1,2,0) #a,b,t 77 | # 78 | # edge_list=torch.where(edge_mask==1) 79 | # 80 | # spatial_embedding = self.spatial_embed(rel_s[edge_list]) 81 | 82 | memory = self.temporal_encoder(rel_embedding, None) # num,7,d_Model 83 | 84 | #tgt_embedding =torch.einsum("rab,na->nrb",self.tgt_embed,hist_feats) 85 | 86 | hist_feats=self.hist_emb1(hist_feats) 87 | 88 | tgt_embedding=self.tgt_embed.repeat(len(traj), 1, 1)+hist_feats[:,None] 89 | 90 | tgt_embedding=self.decoder(tgt_embedding, memory, None, None) 91 | 92 | 
#tgt_embedding=self.decoder(dest_features, memory, None, None) 93 | 94 | #dest = self.dest_genetator(dest_features) 95 | 96 | output = self.generator(tgt_embedding) 97 | 98 | # output=torch.cat([inter,dest],dim=-1) 99 | 100 | 101 | return output.view(-1,self.num_cluster,self.fut_len,2) -------------------------------------------------------------------------------- /model/policy_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | from .ConvRNN import ConvLSTM 5 | 6 | class PolicyNet(nn.Module): 7 | 8 | def __init__(self,horizon,scene_dim=32,motion_dim=32,grid_dim=25,ns_dim=32): 9 | 10 | super(PolicyNet, self).__init__() 11 | 12 | self.horizon=horizon 13 | self.grid_dim=grid_dim 14 | self.ns_dim=ns_dim 15 | 16 | self.action= torch.tensor([[0, 2], [2, 0], [0, -2], [-2, 0]])/grid_dim 17 | self.transition = torch.tensor([[[[0, 0, 0], 18 | [0, 0, 0], 19 | [0, 1.0, 0]]], 20 | 21 | [[[0, 0, 0], 22 | [0, 0, 1], 23 | [0, 0, 0]]], 24 | 25 | [[[0, 1, 0], 26 | [0, 0, 0], 27 | [0, 0, 0]]], 28 | 29 | [[[0, 0, 0], 30 | [1, 0, 0], 31 | [0, 0, 0]]], 32 | 33 | [[[0, 0, 0], 34 | [0, 0, 0], 35 | [0, 0, 0]]]]) 36 | 37 | self.conv_h= nn.Sequential(nn.Conv2d( motion_dim+2, ns_dim, 1),nn.LeakyReLU(0.1)) 38 | self.conv_c= nn.Sequential(nn.Conv2d( motion_dim+2, ns_dim, 1),nn.LeakyReLU(0.1)) 39 | self.convlstm = ConvLSTM(input_dim=scene_dim, hidden_dims=[ns_dim], n_layers=1,kernel_size=(3, 3)) 40 | self.conv_r = nn.Conv2d(ns_dim, 5, 1) 41 | 42 | def sample_policy(self,pi,waypts_e, num_samples,temp,device): 43 | 44 | state=waypts_e[:,:1].repeat(1,num_samples,1) 45 | 46 | waypts=[state] 47 | 48 | waypts_length = self.horizon + torch.zeros([len(pi) * num_samples]).int() 49 | 50 | for t in range(self.horizon - 1): 51 | 52 | policy_sample = F.grid_sample(pi[:, t, :4], grid=state[:, :, None],align_corners=False).permute(0, 2, 3, 1).reshape( -1, 4) # policy: N,4,25,25 state : N, num_samples,1,2 -> N,4,numsamples,1 53 | # #input {N,C,H_in,W_in} grid {N,H_out,W_out,2} => {N,C,H_out,W_out},padding_mode="border" 54 | policy_sample = torch.clamp_min_(policy_sample, min=1e-10) 55 | 56 | move_prob = torch.sum(policy_sample, dim=1) 57 | 58 | end = (move_prob < torch.rand_like(move_prob)) 59 | 60 | waypts_length[end]=torch.clamp_max_(waypts_length[end], t + 1) 61 | 62 | prob = policy_sample / move_prob[:, None] 63 | 64 | if temp==0: 65 | value_indexes = torch.distributions.Categorical(prob).sample()[:,None] 66 | soft_samples_gumble = torch.zeros_like(prob).scatter_(1, value_indexes, 1) 67 | else: 68 | gumble_samples = prob.log() - torch.log(1e-10 - torch.log(torch.rand_like(prob) + 1e-10)) 69 | soft_samples_gumble = F.softmax(gumble_samples / temp, dim=1) 70 | 71 | gumbel_action_mean = torch.matmul(soft_samples_gumble, self.action.to(device)) 72 | 73 | state = state + gumbel_action_mean.view(-1, num_samples, 2) 74 | 75 | waypts.append(state) 76 | 77 | waypts=torch.stack(waypts,dim=1) 78 | 79 | return waypts, waypts_length 80 | 81 | def forward(self,scene_feats, motion_grid,device): 82 | 83 | ns_feats=[] 84 | h=self.conv_h(motion_grid) 85 | c=self.conv_c(motion_grid) 86 | 87 | for t in range(self.horizon): 88 | h = self.convlstm(scene_feats, first_timestep=(t == 0),h=h,c=c) #32,8,64,64 89 | ns_feats.append(h) 90 | 91 | ns_feats=torch.stack(ns_feats,dim=1) 92 | 93 | r_n=self.conv_r(ns_feats.view(-1,self.ns_dim,self.grid_dim,self.grid_dim)).view(-1,self.horizon,5,self.grid_dim,self.grid_dim) 94 | 95 | v = 
torch.zeros_like(r_n[:, 0, :1]) 96 | 97 | pi = torch.zeros_like(r_n) # 98 | 99 | for k in range(self.horizon - 1, -1, -1): 100 | v_pad = F.pad(v, pad=(1, 1, 1, 1), mode='constant', value=-1000) # # v_pad = F.pad(v, pad=(1, 1, 1, 1), mode='replicate') 101 | 102 | q = r_n[:, k] + F.conv2d(v_pad, self.transition.to(device), stride=1) 103 | 104 | v = torch.logsumexp(q, dim=1, keepdim=True) 105 | 106 | pi[:, k] = torch.exp(q - v) 107 | 108 | return pi,ns_feats#[:,:-1] 109 | -------------------------------------------------------------------------------- /model/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math, copy, time 5 | 6 | 7 | def clones(module, N): 8 | "Produce N identical layers." 9 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 10 | 11 | class Encoder(nn.Module): 12 | "Core encoder is a stack of N layers" 13 | 14 | def __init__(self, layer, N): 15 | super(Encoder, self).__init__() 16 | self.layers = clones(layer, N) 17 | self.norm = torch.nn.LayerNorm(layer.size) 18 | 19 | def forward(self, x, mask): 20 | "Pass the input (and mask) through each layer in turn." 21 | for layer in self.layers: 22 | x = layer(x, mask) 23 | 24 | return self.norm(x) 25 | 26 | class SublayerConnection(nn.Module): 27 | """ 28 | A residual connection followed by a layer norm. 29 | Note for code simplicity the norm is first as opposed to last. 30 | """ 31 | def __init__(self, size, dropout): 32 | super(SublayerConnection, self).__init__() 33 | self.norm = torch.nn.LayerNorm(size) 34 | self.dropout = nn.Dropout(dropout) 35 | 36 | def forward(self, x, sublayer,e=None): 37 | "Apply residual connection to any sublayer with the same size." 38 | 39 | return x + self.dropout(sublayer(self.norm(x))) 40 | 41 | 42 | def attention(query, key, value, mask=None, dropout=None): 43 | "Compute 'Scaled Dot Product Attention'" 44 | d_k = query.size(-1) 45 | #scores = torch.sum(query[:,:,:,None]+key[:,:,None],dim=-1)/ math.sqrt(d_k) 46 | 47 | # torch.matmul(query, key.transpose(-2, -1)) #n,h,6,d_k *n,h,d_k,6 scores:n,h,6,6 48 | 49 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)#n,h,6,d_k *n,h,d_k,6 scores:n,h,6,6 50 | if mask is not None:#n,1,1,t 51 | scores = scores.masked_fill(mask == 0, -1e9) 52 | #scores = scores.masked_fill(mask == 0, 0) 53 | # p_attn = scores / (torch.sum(mask, dim=-1, keepdim=True) + 1e-9) 54 | # else: 55 | # p_attn = scores / (scores.shape[-1]) 56 | p_attn = F.softmax(scores, dim = -1) 57 | if dropout is not None: 58 | p_attn = dropout(p_attn) 59 | 60 | return torch.matmul(p_attn, value)#, p_attn#n,h,6,6 n,h,6,k 61 | 62 | 63 | class MultiHeadedAttention(nn.Module): 64 | def __init__(self, h, d_model, dropout=0.1): 65 | "Take in model size and number of heads." 66 | super(MultiHeadedAttention, self).__init__() 67 | assert d_model % h == 0 68 | # We assume d_v always equals d_k 69 | self.d_k = d_model // h 70 | self.h = h 71 | self.linears = clones(nn.Linear(d_model, d_model), 4) 72 | self.attn = None 73 | self.dropout = nn.Dropout(p=dropout) 74 | 75 | def forward(self, query, key, value, mask=None): 76 | "Implements Figure 2" 77 | if mask is not None: 78 | # Same mask applied to all h heads. 
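# Shape note: a provided mask is expected as [batch, 1, L_k]; the unsqueeze(1)
# below turns it into [batch, 1, 1, L_k], so the single mask broadcasts over the
# h head (and query) dimensions of the [batch, h, L_q, L_k] scores in attention().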
79 | mask = mask.unsqueeze(1) 80 | nbatches = query.size(0) 81 | 82 | # 1) Do all the linear projections in batch from d_model => h x d_k 83 | query, key, value = [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 84 | for l, x in zip(self.linears, (query, key, value))]#n,h,6,d_k 85 | 86 | # 2) Apply attention on all the projected vectors in batch. 87 | x = attention(query, key, value, mask=mask, dropout=self.dropout)#n,h,6,d_k, self.attn 88 | 89 | # 3) "Concat" using a view and apply a final linear. 90 | x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) 91 | return self.linears[-1](x) 92 | 93 | class PositionwiseFeedForward(nn.Module): 94 | "Implements FFN equation." 95 | def __init__(self, d_model, d_ff, dropout=0.1): 96 | super(PositionwiseFeedForward, self).__init__() 97 | self.w_1 = nn.Linear(d_model, d_ff) 98 | self.w_2 = nn.Linear(d_ff, d_model) 99 | self.dropout = nn.Dropout(dropout) 100 | 101 | def forward(self, x): 102 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 103 | 104 | class Decoder(nn.Module): 105 | "Generic N layer decoder with masking." 106 | 107 | def __init__(self, layer, N): 108 | super(Decoder, self).__init__() 109 | self.layers = clones(layer, N) 110 | self.norm = torch.nn.LayerNorm(layer.size) 111 | 112 | def forward(self, x, memory, src_mask, tgt_mask): 113 | for layer in self.layers: 114 | x = layer(x, memory, src_mask, tgt_mask) 115 | return self.norm(x) 116 | 117 | 118 | class DecoderLayer(nn.Module): 119 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)" 120 | 121 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 122 | super(DecoderLayer, self).__init__() 123 | self.size = size 124 | self.self_attn = self_attn 125 | self.src_attn = src_attn 126 | self.feed_forward = feed_forward 127 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 128 | 129 | def forward(self, x, memory, src_mask, tgt_mask): 130 | "Follow Figure 1 (right) for connections." 
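# The three sublayers below run in order: self-attention over the target queries
# (with tgt_mask), cross-attention against the encoder memory (in TrajCluster, the
# learned cluster embeddings attend to the encoded sampled trajectories), and the
# position-wise feed-forward network, each wrapped in a pre-norm residual
# SublayerConnection.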
131 | m = memory 132 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 133 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 134 | return self.sublayer[2](x, self.feed_forward) 135 | 136 | class EncoderLayer(nn.Module): 137 | "Encoder is made up of self-attn and feed forward (defined below)" 138 | def __init__(self, size, self_attn, feed_forward, dropout): 139 | super(EncoderLayer, self).__init__() 140 | self.self_attn = self_attn 141 | self.feed_forward = feed_forward 142 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 143 | self.size = size 144 | 145 | def forward(self, x, mask,e=None): 146 | 147 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 148 | 149 | return self.sublayer[1](x, self.feed_forward) 150 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def gaussian_2d( mu1mu2s1s2rho, x1x2): 6 | 7 | x1, x2 = x1x2[:, 0], x1x2[:, 1] 8 | mu1, mu2, s1, s2, rho = ( 9 | mu1mu2s1s2rho[:, 0], 10 | mu1mu2s1s2rho[:, 1], 11 | mu1mu2s1s2rho[:, 2], 12 | mu1mu2s1s2rho[:, 3], 13 | mu1mu2s1s2rho[:, 4], 14 | ) 15 | 16 | norm1 = x1 - mu1 17 | norm2 = x2 - mu2 18 | #print(torch.min(s1),torch.min(s2),torch.max(torch.abs(rho))) 19 | 20 | s1s2 = s1 * s2 21 | 22 | z = (norm1 / s1) ** 2 + (norm2 / s2) ** 2 - 2 * rho * norm1 * norm2 / s1s2 23 | 24 | neg_rho = 1 - rho ** 2 25 | 26 | 27 | #log_prob=-z/(2*neg_rho)-torch.log(s1s2)-1/2*torch.log(neg_rho) 28 | 29 | #ent= 1/2*torch.log(neg_rho)+torch.log(s1s2)#+(1+torch.log(2*np.pi)) 30 | ent= 1/2*torch.log(neg_rho)+torch.log(s1s2)+np.log(2*np.pi) 31 | 32 | neg_log_prob_ent=z/(2*neg_rho)#neg_(log_prob+ent) 33 | 34 | neg_log_prob=neg_log_prob_ent+ent 35 | 36 | #print(torch.max(neg_log_prob)) 37 | # numerator = torch.exp(-z / (2 * neg_rho)) 38 | # denominator = 2 * np.pi * s1s2 * torch.sqrt(neg_rho) 39 | # 40 | # neg_log_prob1=-torch.log(numerator/denominator) 41 | 42 | return neg_log_prob 43 | 44 | import scipy.io as scp 45 | 46 | 47 | or_lbls = scp.loadmat('./data/SDD/img_lbls.mat') 48 | img_lbls = or_lbls['img_lbls'] 49 | 50 | 51 | def offroad_rate(y_pred, ref_pos, ds_ids, y_gt,scale=1, all_timestamps=False): 52 | """ 53 | Computes offroad rate for Stanford drone dataset 54 | 55 | Inputs 56 | y_pred, y_gt, all_timestamps: Similar to minADE_k and minFDE_k functions 57 | img_lbls: path/obstacle labels, binary images from SDD 58 | ref_pos: global co-ordinates of agent location at the time of prediction, for each instance in the batch 59 | dsIds: scene Ids for each instance in the batch 60 | 61 | Output 62 | offroad rate for batch 63 | """ 64 | 65 | # Transform to global co-ordinates 66 | y_gt_global = (y_gt+ref_pos[:,None])/scale[:,None,None]# N,12,2 67 | y_pred_global = (y_pred+ref_pos[:,None,None])/scale[:,None,None,None]# N, 20,12,2 68 | # Compute offroad rate 69 | num_path = torch.zeros(y_pred.shape[2]) 70 | counts = torch.zeros(y_pred.shape[2]) 71 | 72 | N,s,op_len,k=y_pred.shape 73 | 74 | for k in range(N): 75 | lbl_img = img_lbls[0][ds_ids[k]] 76 | 77 | for m in range(op_len): 78 | row_gt = int(y_gt_global[k, m, 1].item()) 79 | col_gt = int(y_gt_global[k, m, 0].item()) 80 | if lbl_img[row_gt, col_gt]: 81 | 82 | for n in range(s): 83 | counts[m] += 1 84 | # If predicted location is on a path and within image boundaries: 85 | row = int(y_pred_global[k, n, m, 1].item()) 86 | col = int(y_pred_global[k, n, m, 
0].item()) 87 | 88 | if -1nta', r_mat, fut) 99 | 100 | if type=="dist": 101 | expert_prob = self.net_t(motion_feats, scene_feats, ns_feats, waypts_e[:, :, None], waypt_lengths_e, 102 | r_mat, 103 | omg_feats, fut_rot,device) 104 | 105 | policy_l = - pi[bc_targets].log().sum() 106 | 107 | traj_l = gaussian_2d(expert_prob.reshape(-1, 5), fut_rot.reshape(-1, 2)).sum() 108 | 109 | traj_back = torch.einsum('nab,nsta->ntsb', r_mat, traj_generated).reshape(-1, 1, num_samples, 2) 110 | 111 | ogms_rce = -F.grid_sample(ogms, grid=traj_back / self.grid_extent, padding_mode="border", 112 | align_corners=False).log().sum() / num_samples 113 | 114 | loss = (policy_l + traj_l + ogms_rce * beta) / n_batch 115 | 116 | else: 117 | 118 | if type=="cluster": 119 | traj_generated=traj_generated.detach() 120 | motion_feats=motion_feats.detach() 121 | 122 | traj_clustered = self.net_c(traj_generated, motion_feats)## n,20,12,2 123 | 124 | loss =min_ade = min_ade_k(traj_clustered, fut_rot, scale)# 125 | 126 | # if type=="multi_task": 127 | # 128 | # loss=(ogms_ce+policy_l + traj_l + ogms_rce * beta) / n_batch+min_ade 129 | 130 | if type=="test": 131 | 132 | min_fde=min_fde_k(traj_clustered,fut_rot,scale) 133 | 134 | return min_ade,min_fde,torch.tensor(0.0),torch.tensor(1.0),n_batch 135 | 136 | elif type=="sddtest": 137 | 138 | min_fde=min_fde_k(traj_clustered,fut_rot,scale) 139 | 140 | ref_pos=ref_pos.float().to(device) 141 | 142 | traj_clustered_back=torch.einsum('nab,nsta->nstb', r_mat, traj_clustered) 143 | 144 | offroad,offroad_sum=offroad_rate(traj_clustered_back,ref_pos,ds_id,fut,scale) 145 | 146 | return min_ade,min_fde,offroad,offroad_sum,n_batch 147 | 148 | # elif type=="ns_test": 149 | # return traj_clustered,ref_pos,ds_id 150 | 151 | return loss, policy_l, traj_l, ogms_rce, ogms_ce, min_ade, n_batch 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /model/traj_decoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn import functional as F 4 | from .transformer import attention 5 | 6 | class Hidden2Normal(torch.nn.Module): 7 | def __init__(self, hidden_dim): 8 | super(Hidden2Normal, self).__init__() 9 | self.linear = torch.nn.Linear(hidden_dim, 5) 10 | 11 | def forward(self, hidden_state): 12 | normal = self.linear(hidden_state) 13 | 14 | normal[..., 2] =torch.exp(normal[..., 2]) 15 | normal[..., 3] =torch.exp(normal[..., 3]) 16 | normal[..., 4] = torch.tanh(normal[..., 4]) 17 | 18 | return normal 19 | 20 | class TrajGenerator(nn.Module): 21 | 22 | def __init__(self,horizon=25,fut_len=25,grid_extent=20,motion_dim=32,scene_dim=32,d_model=64,ns_dim=32,ogm_dim=32,head=4): 23 | 24 | super(TrajGenerator, self).__init__() 25 | 26 | self.grid_dim=25 27 | self.grid_extent=grid_extent 28 | self.head=head 29 | self.ns_dim=ns_dim 30 | self.ogm_dim=ogm_dim 31 | self.d_model = d_model 32 | self.scene_dim = scene_dim 33 | self.fut_len=fut_len 34 | self.horizon =horizon 35 | 36 | self.hist_emb = nn.Sequential(nn.Linear(motion_dim, d_model), nn.LeakyReLU(0.1)) 37 | 38 | self.waypt_emb = nn.Sequential(nn.Linear(scene_dim+ns_dim+2, d_model),nn.LeakyReLU(0.1)) 39 | 40 | self.waypt_enc_gru = nn.GRU(d_model, d_model,batch_first=True) 41 | 42 | self.waypt_att_emb = nn.Sequential(nn.Linear(d_model+scene_dim+ogm_dim+2, d_model),nn.LeakyReLU(0.1)) 43 | 44 | self.d_k = d_model // head 45 | self.linear_k =nn.Linear(d_model, d_model) 46 | self.linear_q =nn.Linear(d_model, d_model) 
47 | self.linear_v =nn.Linear(d_model, d_model) 48 | 49 | self.dec_gru = nn.GRUCell(d_model, d_model) 50 | 51 | self.op_traj =Hidden2Normal(d_model) 52 | 53 | self.dropout = nn.Dropout(p=0.1) 54 | 55 | def forward(self,motion_feats,scene_feats,ns_feats,waypts,waypt_lengths,r_mat,omg_feats, fut_rot,device): 56 | 57 | num_samples=waypts.shape[2] 58 | 59 | local_ns = F.grid_sample(ns_feats.view(-1,self.ns_dim,self.grid_dim,self.grid_dim),waypts.view(-1,num_samples,1,2),padding_mode="border", align_corners=False) 60 | 61 | local_ns = local_ns.reshape(-1,self.horizon,self.ns_dim,num_samples).permute(0,3,1,2).reshape(-1, self.horizon, self.ns_dim) 62 | 63 | local_scene = F.grid_sample(scene_feats, grid=waypts,padding_mode="border", align_corners=False).permute(0, 3, 2,1).reshape(-1, self.horizon, self.scene_dim) 64 | 65 | waypts_rot = torch.einsum('nab,ntsb->nsta', r_mat, waypts).reshape(-1, self.horizon, 2) 66 | 67 | h_feats=self.hist_emb(motion_feats) 68 | 69 | h=h_feats.repeat_interleave(num_samples,dim=0) 70 | 71 | # Encode waypoints: 72 | waypts_cat = torch.cat((waypts_rot, local_scene,local_ns), dim=-1) 73 | waypts_feats_all = self.waypt_emb(waypts_cat) 74 | 75 | emb_packed = nn.utils.rnn.pack_padded_sequence(waypts_feats_all, waypt_lengths,enforce_sorted=False, batch_first=True) 76 | h_waypt_packed, _ = self.waypt_enc_gru(emb_packed) 77 | h_waypt, _ = nn.utils.rnn.pad_packed_sequence(h_waypt_packed, batch_first=True) 78 | 79 | nbatches = h_waypt.shape[0] 80 | 81 | traj = [] 82 | 83 | mask = torch.zeros_like(h_waypt[:,:,0]) 84 | 85 | for i in range(nbatches): 86 | mask[i][:waypt_lengths[i]]=1 87 | 88 | mask=mask[:,None,None] 89 | 90 | pos_rot=waypts_rot[:, 0] 91 | 92 | if fut_rot is None: 93 | 94 | omg_feats = omg_feats.view(-1, self.fut_len, self.ogm_dim, self.grid_dim, self.grid_dim) 95 | else: 96 | fut_prev_rot=torch.cat([pos_rot[:,None],fut_rot[:,:-1]],dim=1) 97 | 98 | fut_prev = torch.einsum('nab,ntsa->ntsb', r_mat, fut_prev_rot[:,:,None])/self.grid_extent 99 | 100 | local_scene_all = F.grid_sample(scene_feats, grid=fut_prev,padding_mode="border" , align_corners=False).permute(0,3, 2, 1).reshape(-1,self.fut_len,self.scene_dim) 101 | 102 | local_omg_all=F.grid_sample(omg_feats, grid=fut_prev.reshape(-1,1,1,2) ,padding_mode="border", align_corners=False).reshape(-1,self.fut_len,self.ogm_dim) 103 | 104 | local_feats_all=torch.cat([local_scene_all,local_omg_all,fut_prev_rot],dim=-1) 105 | 106 | key=self.linear_k(h_waypt).view(nbatches, -1, self.head, self.d_k).transpose(1, 2) 107 | 108 | value=self.linear_v(h_waypt).view(nbatches, -1, self.head, self.d_k).transpose(1, 2) 109 | 110 | for t in range(self.fut_len): 111 | 112 | query=self.linear_q(h).view(nbatches, self.head, 1, self.d_k) 113 | 114 | ip1=attention(query,key,value, mask=mask).view(-1,self.head*self.d_k) 115 | 116 | if fut_rot is None: 117 | pos = torch.einsum('nab,ntsa->ntsb', r_mat, pos_rot.view(-1, 1, num_samples, 2))/self.grid_extent 118 | 119 | scene_omg_feats=torch.cat([scene_feats,omg_feats[:,t]],dim=1) 120 | local_feats = F.grid_sample(scene_omg_feats, grid=pos,padding_mode="border", align_corners=False)[:, :, 0].permute(0, 2,1).reshape( -1, scene_omg_feats.shape[1]) 121 | ip2 = torch.cat([ip1, local_feats, pos_rot], dim=-1) 122 | else: 123 | ip2 = torch.cat([ip1,local_feats_all[:,t]],dim=-1) 124 | 125 | ip=self.waypt_att_emb(ip2) 126 | 127 | h = self.dec_gru(ip, h) 128 | 129 | fut_v=self.op_traj(h) 130 | 131 | if fut_rot is None: 132 | 133 | mu=fut_v[:,:2] 134 | sigmax=fut_v[:,2] 135 | sigmay=fut_v[:,3] 136 | 
rho=fut_v[:,4] 137 | 138 | a=torch.sqrt(1+rho) 139 | b=torch.sqrt(1-rho) 140 | 141 | A=torch.zeros([nbatches,2,2]).to(device) 142 | 143 | A[:,0,0]=(a+b)*sigmax 144 | A[:,0,1]=(a-b)*sigmax 145 | A[:,1,0]=(a-b)*sigmay 146 | A[:,1,1]=(a+b)*sigmay 147 | 148 | z=torch.randn_like(mu) 149 | 150 | pos_rot=torch.einsum('nab,nb->na',A/2,z) + mu+pos_rot 151 | 152 | else: 153 | pos_rot=fut_v 154 | pos_rot[:,:2]=fut_v[:,:2]+fut_prev_rot[:,t] 155 | 156 | traj.append(pos_rot) 157 | 158 | traj = torch.stack(traj,dim=1) 159 | 160 | return traj 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | --------------------------------------------------------------------------------
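A minimal self-contained sketch (not part of the repository) of the bivariate-Gaussian reparameterization used at the end of `TrajGenerator.forward`: with `a = sqrt(1+rho)` and `b = sqrt(1-rho)`, the matrix `A/2` built from `(sigmax, sigmay, rho)` is a square root of the covariance matrix, so `torch.einsum('nab,nb->na', A/2, z) + mu + pos_rot` with `z ~ N(0, I)` adds a correctly correlated random displacement to the previous position. The parameter values below are arbitrary illustrative numbers.

```
import math

import torch

# Example parameters (arbitrary): one (sigmax, sigmay, rho) triple, as produced
# by Hidden2Normal after its exp/exp/tanh transforms.
sigmax, sigmay, rho = 1.5, 0.7, 0.3

a = math.sqrt(1.0 + rho)
b = math.sqrt(1.0 - rho)

# Same construction as in TrajGenerator.forward.
A = torch.tensor([[(a + b) * sigmax, (a - b) * sigmax],
                  [(a - b) * sigmay, (a + b) * sigmay]])

cov = (A / 2) @ (A / 2).T
expected = torch.tensor([[sigmax ** 2, rho * sigmax * sigmay],
                         [rho * sigmax * sigmay, sigmay ** 2]])
print(torch.allclose(cov, expected, atol=1e-5))  # True: (A/2)(A/2)^T is the covariance

# Monte-Carlo check: zero-mean samples (A/2) @ z have roughly this covariance.
z = torch.randn(100000, 2)
samples = z @ (A / 2).T
print(samples.T @ samples / samples.shape[0])  # approximately `expected`
```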