├── requirements.txt
├── no_title_thumbnail.png
├── datasets
│   ├── __pycache__
│   │   ├── helper.cpython-36.pyc
│   │   └── argoverse_pickle_loader.cpython-36.pyc
│   ├── argoverse_lane_loader.py
│   ├── helper.py
│   └── preprocess_data.py
├── scripts
│   ├── train_utils.py
│   ├── evaluate_network.py
│   └── train.py
├── README.md
└── models
    ├── EquiLinear.py
    ├── rho1_ECCO.py
    ├── rho_reg_ECCO.py
    ├── RelEquiCtsConv.py
    └── EquiCtsConv.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.18.3
2 | pandas>=1.0.3
3 | torch==1.5.0
--------------------------------------------------------------------------------
/no_title_thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/ECCO/HEAD/no_title_thumbnail.png
--------------------------------------------------------------------------------
/datasets/__pycache__/helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/ECCO/HEAD/datasets/__pycache__/helper.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/argoverse_pickle_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/ECCO/HEAD/datasets/__pycache__/argoverse_pickle_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/scripts/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import gc
3 | 
4 | 
5 | def euclidean_distance(a, b, epsilon=1e-9):
6 |     return torch.sqrt(torch.sum((a - b)**2, axis=-1) + epsilon)
7 | 
8 | 
9 | def loss_fn(pr_pos, gt_pos, num_fluid_neighbors, car_mask):
10 |     gamma = 0.5
11 |     neighbor_scale = 1 / 40
12 |     importance = torch.exp(-neighbor_scale * num_fluid_neighbors)
13 |     dist = euclidean_distance(pr_pos, gt_pos)**gamma
14 |     mask_dist = dist * car_mask
15 |     batch_losses = torch.mean(importance * mask_dist, axis=-1)
16 |     # print(batch_losses)
17 |     return torch.sum(batch_losses)
18 | 
19 | def clean_cache(device):
20 |     if device == torch.device('cuda'):
21 |         torch.cuda.empty_cache()
22 |     if device == torch.device('cpu'):
23 |         # gc.collect()
24 |         pass
25 | 
26 | def get_lr(optimizer):
27 |     for param_group in optimizer.param_groups:
28 |         return param_group['lr']
29 | 
30 | def unsqueeze_n(tensor, n):
31 |     for i in range(n):
32 |         tensor = tensor.unsqueeze(-1)
33 |     return tensor
34 | 
35 | 
36 | def normalize_input(tensor_dict, scale, train_window):
37 |     pos_keys = (['pos' + str(i) for i in range(train_window + 1)] + 
38 |                 ['pos_2s', 'lane', 'lane_norm'])
39 |     vel_keys = (['vel' + str(i) for i in range(train_window + 1)] + 
40 |                 ['vel_2s'])
41 |     max_pos = torch.cat([torch.max(tensor_dict[pk].reshape(tensor_dict[pk].shape[0], -1), axis=-1)
42 |                          .values.unsqueeze(-1)
43 |                          for pk in pos_keys], axis=-1)
44 |     max_pos = torch.max(max_pos, axis=-1).values
45 | 
46 |     for pk in pos_keys:
47 |         tensor_dict[pk][...,:2] = (tensor_dict[pk][...,:2] - unsqueeze_n(max_pos, len(tensor_dict[pk].shape)-1)) / scale
48 | 
49 |     for vk in vel_keys:
50 |         tensor_dict[vk] = tensor_dict[vk] / scale
51 | 
52 |     return tensor_dict, max_pos
53 | 
54 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Trajectory Prediction using Equivariant Continuous Convolution (ECCO)
2 | 
3 | This is the codebase for the ICLR 2021 paper [Trajectory Prediction using Equivariant Continuous Convolution](https://arxiv.org/abs/2010.11344), by Robin Walters, Jinxi Li and Rose Yu.
4 | 
5 | ![Thumbnail](no_title_thumbnail.png)
6 | 
7 | ## Installation
8 | 
9 | This codebase was developed with Python 3.6.6+. To use the Argoverse dataset, the [argoverse-api](https://github.com/argoai/argoverse-api) is required; we recommend following its guide to install the complete API and datasets. Other requirements include:
10 | - numpy>=1.18.3
11 | - pandas>=1.0.3
12 | - torch==1.5.0
13 | 
14 | Dependencies can be installed using the following command:
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | ## Data Preparation
20 | 
21 | The original data can be downloaded from [argoverse](https://www.argoverse.org/data.html). To generate the training and validation data:
22 | 1. Set the path to `argoverse_forecasting` in the `datasets/preprocess_data.py` script.
23 | 2. Run the script:
24 | ```bash
25 | python preprocess_data.py
26 | ```
27 | The data will be stored in `path/to/argoverse_forecasting/train(val)/lane_data`.
28 | 
29 | ## Data Download
30 | 
31 | If you want to skip the data generation step, a link to the preprocessed data will be provided soon.
32 | 
33 | ## Model Training and Evaluation
34 | 
35 | Here are the commands to train the models; evaluation runs automatically after training finishes.
36 | 
37 | For ρ₁-ECCO, run the following command:
38 | ```bash
39 | python train.py --dataset_path /path/to/argoverse_forecasting/ --rho1 --model_name rho_1_ecco --train --evaluation
40 | ```
41 | 
42 | For ρ_reg-ECCO, run the following command:
43 | ```bash
44 | python train.py --dataset_path /path/to/argoverse_forecasting/ --rho-reg --model_name rho_reg_ecco --train --evaluation
45 | ```
46 | 
47 | For the baseline evaluation, you can refer to the [Argoverse Official Baseline](https://github.com/jagjeet-singh/argoverse-forecasting). Note: the constant-velocity baseline is evaluated on the validation set (with scenes containing more than 60 cars filtered out), using the velocity at the final observed timestamp as the constant velocity.
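
Before training, you can sanity-check the preprocessed pickles with a minimal loading sketch (the path and batch size below are illustrative; `read_pkl_data` is the batch generator defined in `datasets/argoverse_lane_loader.py`):
```python
import numpy as np

from datasets.argoverse_lane_loader import read_pkl_data

# read_pkl_data yields dict batches; each value is a list holding one
# zero-padded array per scene in the batch
loader = read_pkl_data('/path/to/argoverse_forecasting/val/lane_data',
                       batch_size=2, shuffle=False, repeat=False)
batch = next(loader)
print(sorted(batch.keys()))           # car_mask, lane, lane_mask, pos0..pos30, vel0..vel30, ...
print(np.stack(batch['pos0']).shape)  # (batch, max_car_num, 2): car positions at t = 0
print(np.stack(batch['lane']).shape)  # (batch, max_lane_nodes, ...): padded lane nodes
```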
48 | 
49 | ## Citation
50 | 
51 | If you find this repository useful in your research, please cite our paper:
52 | ```
53 | @article{Walters2021ECCO,
54 |   title={Trajectory Prediction using Equivariant Continuous Convolution},
55 |   author={Robin Walters and Jinxi Li and Rose Yu},
56 |   journal={International Conference on Learning Representations},
57 |   year={2021},
58 | }
59 | ```
--------------------------------------------------------------------------------
/datasets/argoverse_lane_loader.py:
--------------------------------------------------------------------------------
1 | "Functions for loading the preprocessed .pkl data"
2 | from glob import glob
3 | import pickle
4 | import os
5 | import numpy as np
6 | from argoverse.map_representation.map_api import ArgoverseMap
7 | from torch.utils.data import IterableDataset, DataLoader
8 | 
9 | 
10 | class ArgoverseDataset(IterableDataset):
11 |     def __init__(self, data_path: str, transform=None, 
12 |                  max_lane_nodes=650, min_lane_nodes=0, shuffle=True):
13 |         super(ArgoverseDataset, self).__init__()
14 |         self.data_path = data_path
15 |         self.transform = transform
16 |         self.pkl_list = glob(os.path.join(self.data_path, '*'))
17 |         if shuffle:
18 |             np.random.shuffle(self.pkl_list)
19 |         else:
20 |             self.pkl_list.sort()
21 |         self.max_lane_nodes = max_lane_nodes
22 |         self.min_lane_nodes = min_lane_nodes
23 | 
24 |     def __len__(self):
25 |         return len(self.pkl_list)
26 | 
27 |     def __iter__(self):
28 |         # pkl_path = self.pkl_list[idx]
29 |         for pkl_path in self.pkl_list:
30 |             with open(pkl_path, 'rb') as f:
31 |                 data = pickle.load(f)
32 |             # data = {k:v[0] for k, v in data.items()}
33 |             lane_mask = np.zeros(self.max_lane_nodes, dtype=np.float32)
34 |             lane_mask[:len(data['lane'][0])] = 1.0
35 |             data['lane_mask'] = [lane_mask]
36 | 
37 |             if data['lane'][0].shape[0] > self.max_lane_nodes:
38 |                 continue
39 | 
40 |             if data['lane'][0].shape[0] < self.min_lane_nodes:
41 |                 continue
42 | 
43 |             data['lane'] = [self.expand_particle(data['lane'][0], self.max_lane_nodes, 0)]
44 |             data['lane_norm'] = [self.expand_particle(data['lane_norm'][0], self.max_lane_nodes, 0)]
45 | 
46 |             if self.transform:
47 |                 data = self.transform(data)
48 | 
49 |             yield data
50 | 
51 |     @classmethod
52 |     def expand_particle(cls, arr, max_num, axis, value_type='int'):
53 |         dummy_shape = list(arr.shape)
54 |         dummy_shape[axis] = max_num - arr.shape[axis]
55 |         dummy = np.zeros(dummy_shape)
56 |         if value_type == 'str':
57 |             dummy = np.array(['dummy' + str(i) for i in range(np.prod(dummy_shape))]).reshape(dummy_shape)
58 |         return np.concatenate([arr, dummy], axis=axis)
59 | 
60 | 
61 | def cat_key(data, key):
62 |     result = []
63 |     for d in data:
64 |         result = result + d[key]
65 |     return result
66 | 
67 | 
68 | def dict_collate_func(data):
69 |     keys = data[0].keys()
70 |     data = {key: cat_key(data, key) for key in keys}
71 |     return data
72 | 
73 | 
74 | def read_pkl_data(data_path: str, batch_size: int, 
75 |                   shuffle: bool=False, repeat: bool=False, **kwargs):
76 |     dataset = ArgoverseDataset(data_path=data_path, shuffle=shuffle, **kwargs)
77 |     loader = DataLoader(dataset, batch_size=batch_size, collate_fn=dict_collate_func)
78 |     if repeat:
79 |         while True:
80 |             for data in loader:
81 |                 yield data
82 |     else:
83 |         for data in loader:
84 |             yield data
85 | 
86 | 
--------------------------------------------------------------------------------
/models/EquiLinear.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import 
torch.nn.functional as F 5 | import math 6 | 7 | 8 | # rho1 --> rho1 9 | class EquiLinear(nn.Module): 10 | def __init__(self, in_features, out_features): 11 | super(EquiLinear, self).__init__() 12 | self.linear = nn.Linear(in_features, out_features, bias=False) 13 | 14 | def forward(self, field_feat): 15 | """ 16 | inputs: 17 | @field_feat: [batch, num_part, in_feat, 2] 18 | 19 | output: 20 | [batch, num_part, out_feat, 2] 21 | """ 22 | return self.linear(field_feat.permute(0, 1, 3, 2)).permute(0, 1, 3, 2) 23 | 24 | 25 | # Reg --> Reg 26 | class EquiLinearRegToReg(nn.Module): 27 | def __init__(self, in_features, out_features, k): 28 | super(EquiLinearRegToReg, self).__init__() 29 | self.k = k 30 | self.weights = nn.parameter.Parameter(torch.rand(in_features, out_features, k) / in_features) 31 | # print(self.weights) 32 | # kernel = self.update_kernel() 33 | # print(self.kernel) 34 | 35 | def update_kernel(self): 36 | # i or -i ??? stack -2 or -1 ??? torch.flip ??? 37 | return torch.stack([torch.roll(self.weights, i, 2) for i in range(0,self.k)],-1) 38 | 39 | def forward(self, field_feat): 40 | """ 41 | inputs: 42 | k: int -- number of slices of the circle for regular rep 43 | @field_feat: [batch, num_part, in_feat, k] 44 | kernel: [in_feat, out_feat, k, k] 45 | 46 | f*k(\theta) = \sum_\psi K(\psi)f(\theta - \psi) 47 | 48 | output: 49 | [batch, num_part, out_feat, k] 50 | """ 51 | # x or y ??? 52 | kernel = self.update_kernel() 53 | return torch.einsum('ijyx,...ix->...jy',kernel,field_feat) 54 | 55 | 56 | # Rho1 --> Reg 57 | class EquiLinearRho1ToReg(nn.Module): 58 | def __init__(self, k): 59 | super(EquiLinearRho1ToReg, self).__init__() 60 | self.k = k 61 | SinVec = torch.tensor([math.sin(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False) 62 | CosVec = torch.tensor([math.cos(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False) 63 | Rho1ToReg = torch.stack([CosVec,SinVec],1) #[k,2] 64 | self.register_buffer('Rho1ToReg', Rho1ToReg) 65 | 66 | def forward(self, field_feat): 67 | """ 68 | k: int -- number of slices of the circle for regular rep 69 | inputs: 70 | @field_feat: [batch, num_part, in_feat, 2] 71 | output: [batch, num_part, in_feat, k] 72 | 73 | (a,b) --> a Sin + b Cos 74 | """ 75 | return torch.einsum('yx,...x->...y',self.Rho1ToReg, field_feat) 76 | 77 | 78 | # Reg --> Rho1 79 | class EquiLinearRegToRho1(nn.Module): 80 | def __init__(self, k): 81 | super(EquiLinearRegToRho1, self).__init__() 82 | self.k = k 83 | SinVec = torch.tensor([math.sin(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False) 84 | CosVec = torch.tensor([math.cos(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False) 85 | RegToRho1 = torch.stack([CosVec,SinVec],0) #[2,k] 86 | self.register_buffer('RegToRho1', RegToRho1) 87 | 88 | def forward(self, field_feat): 89 | ''' 90 | k: int -- number of slices of the circle for regular rep 91 | inputs: 92 | @field_feat: [batch, num_part, in_feat, k] 93 | output: 94 | retval: [batch, num_part, in_feat, 2] 95 | 96 | f is a function on circle divided into k parts 97 | f(i) means f(2\pi *i /k) 98 | f --> ( \sum_{i=0}^k ( f(i) cos(2 \pi i /k) , f(i) sin(2 \pi i /k) ) 99 | This is a fourier transform. 
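        (equivalently: it extracts the frequency-1 Fourier coefficients of f on the cyclic group Z_k)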
100 | ''' 101 | return torch.einsum('yx,...x->...y',self.RegToRho1, field_feat) 102 | -------------------------------------------------------------------------------- /scripts/evaluate_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import sys 4 | import numpy as np 5 | import time 6 | import importlib 7 | import torch 8 | import pickle 9 | 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 12 | from train_utils import * 13 | 14 | 15 | def get_agent(pr, gt, pr_id, gt_id, agent_id, device='cpu'): 16 | 17 | pr_agent = pr[pr_id == agent_id,:] 18 | gt_agent = gt[gt_id == agent_id,:] 19 | 20 | return pr_agent, gt_agent 21 | 22 | 23 | def evaluate(model, val_dataset, train_window=3, max_iter=2500, device='cpu', start_iter=0, 24 | batch_size=32): 25 | 26 | print('evaluating.. ', end='', flush=True) 27 | 28 | count = 0 29 | prediction_gt = {} 30 | losses = [] 31 | val_iter = iter(val_dataset) 32 | 33 | for i, sample in enumerate(val_dataset): 34 | 35 | if i >= max_iter: 36 | break 37 | 38 | if i < start_iter: 39 | continue 40 | 41 | pred = [] 42 | gt = [] 43 | 44 | if count % 1 == 0: 45 | print('{}'.format(count + 1), end=' ', flush=True) 46 | 47 | count += 1 48 | 49 | data = {} 50 | convert_keys = (['pos' + str(i) for i in range(31)] + 51 | ['vel' + str(i) for i in range(31)] + 52 | ['pos_2s', 'vel_2s', 'lane', 'lane_norm']) 53 | 54 | for k in convert_keys: 55 | data[k] = torch.tensor(np.stack(sample[k])[...,:2], dtype=torch.float32, device=device) 56 | 57 | 58 | for k in ['track_id' + str(i) for i in range(31)] + ['city', 'agent_id', 'scene_idx']: 59 | data[k] = np.stack(sample[k]) 60 | 61 | for k in ['car_mask', 'lane_mask']: 62 | data[k] = torch.tensor(np.stack(sample[k]), dtype=torch.float32, device=device).unsqueeze(-1) 63 | 64 | scenes = data['scene_idx'].tolist() 65 | 66 | data['agent_id'] = data['agent_id'][:,np.newaxis] 67 | 68 | data['car_mask'] = data['car_mask'].squeeze(-1) 69 | accel = torch.zeros(1, 1, 2).to(device) 70 | data['accel'] = accel 71 | 72 | lane = data['lane'] 73 | lane_normals = data['lane_norm'] 74 | agent_id = data['agent_id'] 75 | city = data['city'] 76 | 77 | inputs = ([ 78 | data['pos_2s'], data['vel_2s'], 79 | data['pos0'], data['vel0'], 80 | data['accel'], None, 81 | data['lane'], data['lane_norm'], 82 | data['car_mask'], data['lane_mask'] 83 | ]) 84 | 85 | pr_pos1, pr_vel1, states = model(inputs) 86 | gt_pos1 = data['pos1'] 87 | 88 | l = 0.5 * loss_fn(pr_pos1, gt_pos1, 89 | torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1)) 90 | 91 | pr_agent, gt_agent = get_agent(pr_pos1, data['pos1'], 92 | data['track_id0'], 93 | data['track_id1'], 94 | agent_id, device) 95 | pred.append(pr_agent.unsqueeze(1).detach().cpu()) 96 | gt.append(gt_agent.unsqueeze(1).detach().cpu()) 97 | del pr_agent, gt_agent 98 | clean_cache(device) 99 | 100 | pos0 = data['pos0'] 101 | vel0 = data['vel0'] 102 | for i in range(29): 103 | pos_enc = torch.unsqueeze(pos0, 2) 104 | vel_enc = torch.unsqueeze(vel0, 2) 105 | inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, data['accel'], None, 106 | data['lane'], data['lane_norm'], data['car_mask'], data['lane_mask']) 107 | pos0, vel0 = pr_pos1, pr_vel1 108 | pr_pos1, pr_vel1, states = model(inputs, states) 109 | clean_cache(device) 110 | 111 | if i < train_window - 1: 112 | gt_pos1 = data['pos'+str(i+2)] 113 | l += 0.5 * loss_fn(pr_pos1, gt_pos1, 114 | torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1)) 
115 | 116 | pr_agent, gt_agent = get_agent(pr_pos1, data['pos'+str(i+2)], 117 | data['track_id0'], 118 | data['track_id'+str(i+2)], 119 | agent_id, device) 120 | 121 | pred.append(pr_agent.unsqueeze(1).detach().cpu()) 122 | gt.append(gt_agent.unsqueeze(1).detach().cpu()) 123 | 124 | clean_cache(device) 125 | 126 | losses.append(l) 127 | 128 | predict_result = (torch.cat(pred, axis=1), torch.cat(gt, axis=1)) 129 | for idx, scene_id in enumerate(scenes): 130 | prediction_gt[scene_id] = (predict_result[0][idx], predict_result[1][idx]) 131 | 132 | total_loss = 128 * torch.sum(torch.stack(losses),axis=0) / max_iter 133 | 134 | result = {} 135 | de = {} 136 | 137 | for k, v in prediction_gt.items(): 138 | de[k] = torch.sqrt((v[0][:,0] - v[1][:,0])**2 + 139 | (v[0][:,1] - v[1][:,1])**2) 140 | 141 | ade = [] 142 | de1s = [] 143 | de2s = [] 144 | de3s = [] 145 | for k, v in de.items(): 146 | ade.append(np.mean(v.numpy())) 147 | de1s.append(v.numpy()[10]) 148 | de2s.append(v.numpy()[20]) 149 | de3s.append(v.numpy()[-1]) 150 | 151 | result['ADE'] = np.mean(ade) 152 | result['ADE_std'] = np.std(ade) 153 | result['DE@1s'] = np.mean(de1s) 154 | result['DE@1s_std'] = np.std(de1s) 155 | result['DE@2s'] = np.mean(de2s) 156 | result['DE@2s_std'] = np.std(de2s) 157 | result['DE@3s'] = np.mean(de3s) 158 | result['DE@3s_std'] = np.std(de3s) 159 | 160 | print(result) 161 | print('done') 162 | 163 | return total_loss, prediction_gt 164 | 165 | 166 | 167 | 168 | 169 | -------------------------------------------------------------------------------- /datasets/helper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | A modified visualization function and get_lanes function 5 | """ 6 | 7 | import argparse 8 | import os 9 | import shutil 10 | import sys 11 | from collections import defaultdict 12 | from typing import Dict, Optional 13 | 14 | import matplotlib.animation as anim 15 | import matplotlib.lines as mlines 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import pandas as pd 19 | import scipy.interpolate as interp 20 | 21 | from argoverse.map_representation.map_api import ArgoverseMap 22 | 23 | _ZORDER = {"AGENT": 15, "AV": 10, "OTHERS": 5} 24 | 25 | 26 | def get_batch_lane_direction(pos, city, am): 27 | if am is None: 28 | from argoverse.map_representation.map_api import ArgoverseMap 29 | am = ArgoverseMap() 30 | else: 31 | pass 32 | drct_conf = list() 33 | 34 | for ps, c in zip(pos, city): 35 | drct_conf.append(np.array([np.append(*am.get_lane_direction(p[:2], c)) for p in ps])) 36 | 37 | return drct_conf 38 | 39 | 40 | def get_lane_direction(pos, city, am): 41 | if am is None: 42 | from argoverse.map_representation.map_api import ArgoverseMap 43 | am = ArgoverseMap() 44 | else: 45 | pass 46 | drct_conf = np.array([np.append(*am.get_lane_direction(p[:2], city)) for p in pos]) 47 | 48 | return drct_conf 49 | 50 | 51 | def interpolate_polyline(polyline: np.ndarray, num_points: int) -> np.ndarray: 52 | duplicates = [] 53 | for i in range(1, len(polyline)): 54 | if np.allclose(polyline[i], polyline[i - 1]): 55 | duplicates.append(i) 56 | if polyline.shape[0] - len(duplicates) < 4: 57 | return polyline 58 | if duplicates: 59 | polyline = np.delete(polyline, duplicates, axis=0) 60 | tck, u = interp.splprep(polyline.T, s=0) 61 | u = np.linspace(0.0, 1.0, num_points) 62 | return np.column_stack(interp.splev(u, tck)) 63 | 64 | 65 | def get_lanes(df: pd.DataFrame, city_name: str, avm: Optional[ArgoverseMap] = None) -> list: 66 
| 67 | # Get API for Argo Dataset map 68 | avm = ArgoverseMap() if avm is None else avm 69 | seq_lane_bbox = avm.city_halluc_bbox_table[city_name] 70 | seq_lane_props = avm.city_lane_centerlines_dict[city_name] 71 | 72 | x_min = min(df["X"]) 73 | x_max = max(df["X"]) 74 | y_min = min(df["Y"]) 75 | y_max = max(df["Y"]) 76 | 77 | lane_centerlines = [] 78 | 79 | # Get lane centerlines which lie within the range of trajectories 80 | for lane_id, lane_props in seq_lane_props.items(): 81 | 82 | lane_cl = lane_props.centerline 83 | 84 | if ( 85 | np.min(lane_cl[:, 0]) < x_max 86 | and np.min(lane_cl[:, 1]) < y_max 87 | and np.max(lane_cl[:, 0]) > x_min 88 | and np.max(lane_cl[:, 1]) > y_min 89 | ): 90 | lane_centerlines.append(lane_cl) 91 | 92 | return lane_centerlines 93 | 94 | 95 | def get_all_lanes(city_name: str, avm: Optional[ArgoverseMap] = None) -> list: 96 | 97 | # Get API for Argo Dataset map 98 | avm = ArgoverseMap() if avm is None else avm 99 | seq_lane_bbox = avm.city_halluc_bbox_table[city_name] 100 | seq_lane_props = avm.city_lane_centerlines_dict[city_name] 101 | 102 | lane_centerlines = [lane.centerline for lane in seq_lane_props.values()] 103 | 104 | return lane_centerlines 105 | 106 | 107 | def visualize_trajectory( 108 | df: pd.DataFrame, lane_centerlines: Optional[np.ndarray] = None, show: bool = True, smoothen: bool = False 109 | ) -> None: 110 | 111 | # Seq data 112 | # time_list = np.sort(np.unique(df["TIMESTAMP"].values)) 113 | city_name = df["CITY_NAME"].values[0] 114 | 115 | lane_centerlines = get_lanes(df, city_name) if lane_centerlines is None else lane_centerlines 116 | 117 | plt.figure(0, figsize=(8, 7)) 118 | 119 | x_min = min(df["X"]) 120 | x_max = max(df["X"]) 121 | y_min = min(df["Y"]) 122 | y_max = max(df["Y"]) 123 | 124 | plt.xlim(x_min, x_max) 125 | plt.ylim(y_min, y_max) 126 | 127 | 128 | for lane_cl in lane_centerlines: 129 | plt.plot(lane_cl[:, 0], lane_cl[:, 1], "--", color="grey", alpha=1, linewidth=1, zorder=0) 130 | frames = df.groupby("TRACK_ID") 131 | 132 | plt.xlabel("Map X") 133 | plt.ylabel("Map Y") 134 | 135 | color_dict = {"AGENT": "#d33e4c", "OTHERS": "#13d4f2", "AV": "#007672"} 136 | object_type_tracker: Dict[int, int] = defaultdict(int) 137 | 138 | # Plot all the tracks up till current frame 139 | for group_name, group_data in frames: 140 | object_type = group_data["OBJECT_TYPE"].values[0] 141 | 142 | cor_x = group_data["X"].values 143 | cor_y = group_data["Y"].values 144 | 145 | if smoothen: 146 | polyline = np.column_stack((cor_x, cor_y)) 147 | num_points = cor_x.shape[0] * 3 148 | smooth_polyline = interpolate_polyline(polyline, num_points) 149 | cor_x = smooth_polyline[:, 0] 150 | cor_y = smooth_polyline[:, 1] 151 | 152 | plt.plot( 153 | cor_x, 154 | cor_y, 155 | "-", 156 | color=color_dict[object_type], 157 | label=object_type if not object_type_tracker[object_type] else "", 158 | alpha=1, 159 | linewidth=1, 160 | zorder=_ZORDER[object_type], 161 | ) 162 | 163 | final_x = cor_x[-1] 164 | final_y = cor_y[-1] 165 | 166 | if object_type == "AGENT": 167 | marker_type = "o" 168 | marker_size = 7 169 | elif object_type == "OTHERS": 170 | marker_type = "o" 171 | marker_size = 7 172 | elif object_type == "AV": 173 | marker_type = "o" 174 | marker_size = 7 175 | 176 | plt.plot( 177 | final_x, 178 | final_y, 179 | marker_type, 180 | color=color_dict[object_type], 181 | label=object_type if not object_type_tracker[object_type] else "", 182 | alpha=1, 183 | markersize=marker_size, 184 | zorder=_ZORDER[object_type], 185 | ) 186 | 187 | 
object_type_tracker[object_type] += 1 188 | 189 | red_star = mlines.Line2D([], [], color="red", marker="*", linestyle="None", markersize=7, label="Agent") 190 | green_circle = mlines.Line2D([], [], color="green", marker="o", linestyle="None", markersize=7, label="Others") 191 | black_triangle = mlines.Line2D([], [], color="black", marker="^", linestyle="None", markersize=7, label="AV") 192 | 193 | plt.axis("off") 194 | 195 | if show: 196 | plt.show() 197 | 198 | 199 | 200 | 201 | 202 | -------------------------------------------------------------------------------- /models/rho1_ECCO.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import argoverse 7 | 8 | import sys 9 | import os 10 | sys.path.append(os.path.dirname(__file__)) 11 | 12 | from EquiCtsConv import * 13 | from EquiLinear import * 14 | 15 | #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 16 | 17 | class ECCONetwork(nn.Module): 18 | def __init__(self, 19 | num_radii = 3, 20 | num_theta = 16, 21 | radius_scale = 40, 22 | timestep = 0.1, 23 | encoder_hidden_size = 19, 24 | layer_channels = [16, 32, 32, 32, 1] 25 | ): 26 | super(ECCONetwork, self).__init__() 27 | 28 | # init parameters 29 | 30 | self.num_radii = num_radii 31 | self.num_theta = num_theta 32 | self.radius_scale = radius_scale 33 | self.timestep = timestep 34 | self.layer_channels = layer_channels 35 | 36 | self.encoder_hidden_size = encoder_hidden_size 37 | self.in_channel = 1 + self.encoder_hidden_size 38 | self.activation = F.relu 39 | # self.relu_shift = torch.nn.parameter.Parameter(torch.tensor(0.2)) 40 | relu_shift = torch.tensor(0.2) 41 | self.register_buffer('relu_shift', relu_shift) 42 | 43 | # create continuous convolution and fully-connected layers 44 | 45 | convs = [] 46 | denses = [] 47 | # c_in, c_out, radius, num_radii, num_theta 48 | self.conv_fluid = EquiCtsConv2d(in_channels = self.in_channel, 49 | out_channels = self.layer_channels[0], 50 | num_radii = self.num_radii, 51 | num_theta = self.num_theta, 52 | radius = self.radius_scale) 53 | 54 | self.conv_obstacle = EquiCtsConv2d(in_channels = 1, 55 | out_channels = self.layer_channels[0], 56 | num_radii = self.num_radii, 57 | num_theta = self.num_theta, 58 | radius = self.radius_scale) 59 | 60 | self.dense_fluid = EquiLinear(self.in_channel, self.layer_channels[0]) 61 | 62 | # concat conv_obstacle, conv_fluid, dense_fluid 63 | in_ch = 3 * self.layer_channels[0] 64 | for i in range(1, len(self.layer_channels)): 65 | out_ch = self.layer_channels[i] 66 | dense = EquiLinear(in_ch, out_ch) 67 | denses.append(dense) 68 | conv = EquiCtsConv2d(in_channels = in_ch, 69 | out_channels = out_ch, 70 | num_radii = self.num_radii, 71 | num_theta = self.num_theta, 72 | radius = self.radius_scale) 73 | convs.append(conv) 74 | in_ch = self.layer_channels[i] 75 | 76 | self.convs = nn.ModuleList(convs) 77 | self.denses = nn.ModuleList(denses) 78 | 79 | 80 | def update_pos_vel(self, p0, v0, a): 81 | """Apply acceleration and integrate position and velocity. 82 | Assume the particle has constant acceleration during timestep. 
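        (trapezoidal rule, exact for constant a: v1 = v0 + dt * a, p1 = p0 + dt * (v0 + v1) / 2)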
83 |         Return particle's position and velocity after 1 unit timestep."""
84 | 
85 |         dt = self.timestep
86 |         v1 = v0 + dt * a
87 |         p1 = p0 + dt * (v0 + v1) / 2
88 |         return p1, v1
89 | 
90 |     def apply_correction(self, p0, p1, correction):
91 |         """Apply the position correction
92 |         p0, p1: the position of the particle before/after basic integration. """
93 |         dt = self.timestep
94 |         p_corrected = p1 + correction
95 |         v_corrected = (p_corrected - p0) / dt
96 |         return p_corrected, v_corrected
97 | 
98 |     def compute_correction(self, p, v, other_feats, box, box_feats, fluid_mask, box_mask):
99 |         """Precondition: p and v were updated with acceleration"""
100 | 
101 |         fluid_feats = [v.unsqueeze(-2)]
102 |         if other_feats is not None:
103 |             fluid_feats.append(other_feats)
104 |         fluid_feats = torch.cat(fluid_feats, -2)
105 | 
106 |         # compute the correction by accumulating the output through the network layers
107 |         output_conv_fluid = self.conv_fluid(p, p, fluid_feats, fluid_mask)
108 |         output_dense_fluid = self.dense_fluid(fluid_feats)
109 |         output_conv_obstacle = self.conv_obstacle(box, p, box_feats.unsqueeze(-2), box_mask)
110 | 
111 |         feats = torch.cat((output_conv_obstacle, output_conv_fluid, output_dense_fluid), -2)
112 |         # self.outputs = [feats]
113 |         output = feats
114 | 
115 |         for conv, dense in zip(self.convs, self.denses):
116 |             # pass input features to conv and fully-connected layers
117 |             mags = (torch.sum(output**2,axis=-1) + 1e-6).unsqueeze(-1)
118 |             in_feats = output/mags * self.activation(mags - self.relu_shift)
119 |             # in_feats = self.activation(output)
120 |             # in_feats = output
121 |             output_conv = conv(p, p, in_feats, fluid_mask)
122 |             output_dense = dense(in_feats)
123 | 
124 |             # if the dense layer keeps the channel dimension unchanged,
125 |             # add a residual connection to the previous output
126 |             if output_dense.shape[-2] == output.shape[-2]:
127 |                 output = output_conv + output_dense + output
128 |             else:
129 |                 output = output_conv + output_dense
130 |             # self.outputs.append(output)
131 | 
132 |         # compute the number of fluid particle neighbors. 
133 |         # this info is used in the loss function during training.
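        # num_fluid_neighbors = (number of valid cars in the mask) - 1, excluding the
        # particle itself; loss_fn in scripts/train_utils.py turns this count (or the
        # equivalent torch.sum(car_mask, dim=-2) - 1) into an exp(-n / 40) importance weight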
134 |         # TODO: test this block of code
135 |         self.num_fluid_neighbors = torch.sum(fluid_mask, dim = -1) - 1
136 | 
137 |         # self.last_features = self.outputs[-2]
138 | 
139 |         # scale to better match the scale of the output distribution
140 |         self.pos_correction = (1.0 / 128) * output
141 |         return self.pos_correction
142 | 
143 |     def forward(self, inputs, states=None):
144 |         """ inputs: 10-element tuple
145 |             p0_enc, v0_enc, p0, v0, a, other_feats, box, box_feats, fluid_mask, box_mask
146 |             v0_enc: [batch, num_part, timestamps, 2]
147 |             Computes 1 simulation timestep"""
148 |         p0_enc, v0_enc, p0, v0, a, other_feats, box, box_feats, fluid_mask, box_mask = inputs
149 | 
150 |         if states is None:
151 |             if other_feats is None:
152 |                 feats = v0_enc
153 |             else:
154 |                 feats = torch.cat((other_feats, v0_enc), -2)
155 |         else:
156 |             if other_feats is None:
157 |                 feats = v0_enc
158 |                 feats = torch.cat((states[0][...,1:,:], feats), -2)
159 |             else:
160 |                 feats = torch.cat((other_feats, states[0][...,1:,:], v0_enc), -2)
161 |         # print(feats.shape)
162 | 
163 |         # a = (v0 - v0_enc[...,-1,:]) / self.timestep
164 |         p1, v1 = self.update_pos_vel(p0, v0, a)
165 |         pos_correction = self.compute_correction(p1, v1, feats, box, box_feats, fluid_mask, box_mask)
166 |         p_corrected, v_corrected = self.apply_correction(p0, p1, pos_correction.squeeze(-2))
167 | 
168 |         return p_corrected, v_corrected, (feats, None)
169 | 
170 | 
--------------------------------------------------------------------------------
/models/rho_reg_ECCO.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import argoverse
7 | 
8 | import sys
9 | import os
10 | sys.path.append(os.path.dirname(__file__))
11 | 
12 | from EquiCtsConv import *
13 | from EquiLinear import *
14 | 
15 | #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 | 
17 | class ECCONetwork(nn.Module):
18 |     def __init__(self, 
19 |                  num_radii = 3, 
20 |                  num_theta = 16, 
21 |                  reg_dim = 8,
22 |                  radius_scale = 40,
23 |                  timestep = 0.1,
24 |                  encoder_hidden_size = 19,
25 |                  layer_channels = [8, 16, 16, 16, 1]
26 |                  ):
27 |         super(ECCONetwork, self).__init__()
28 | 
29 |         # init parameters
30 | 
31 |         self.num_radii = num_radii
32 |         self.num_theta = num_theta
33 |         self.reg_dim = reg_dim
34 |         self.radius_scale = radius_scale
35 |         self.timestep = timestep
36 |         self.layer_channels = layer_channels
37 |         # self.filter_extent = np.float32(self.radius_scale * 6 *
38 |         #                                 self.particle_radius)  # unused leftover: self.particle_radius is never defined
39 | 
40 |         self.encoder_hidden_size = encoder_hidden_size
41 |         self.in_channel = 1 + self.encoder_hidden_size
42 |         self.activation = F.relu
43 |         # self.relu_shift = torch.nn.parameter.Parameter(torch.tensor(0.2))
44 |         relu_shift = torch.tensor(0.2)
45 |         self.register_buffer('relu_shift', relu_shift)
46 | 
47 |         # create continuous convolution and fully-connected layers
48 | 
49 |         convs = []
50 |         denses = []
51 |         # c_in, c_out, radius, num_radii, num_theta
52 |         self.conv_fluid = EquiCtsConv2dRho1ToReg(in_channels = self.in_channel,
53 |                                                  out_channels = self.layer_channels[0],
54 |                                                  num_radii = self.num_radii,
55 |                                                  num_theta = self.num_theta,
56 |                                                  radius = self.radius_scale,
57 |                                                  k = self.reg_dim)
58 | 
59 |         self.conv_obstacle = EquiCtsConv2dRho1ToReg(in_channels = 1, 
60 |                                                     out_channels = self.layer_channels[0],
61 |                                                     num_radii = self.num_radii,
62 |                                                     num_theta = self.num_theta,
63 |                                                     radius = self.radius_scale,
64 |                                                     k = self.reg_dim)
65 | 
66 |         self.dense_fluid = nn.Sequential(
67 |             EquiLinearRho1ToReg(self.reg_dim),
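            # EquiLinearRho1ToReg lifts each 2-vector (rho1) feature onto the k-slice
            # regular representation; EquiLinearRegToReg below then mixes channels with
            # a rotation-equivariant circular kernel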
68 |             EquiLinearRegToReg(self.in_channel, self.layer_channels[0], self.reg_dim)
69 |         )
70 | 
71 |         # concat conv_obstacle, conv_fluid, dense_fluid
72 |         in_ch = 3 * self.layer_channels[0]
73 |         for i in range(1, len(self.layer_channels)-1):
74 |             out_ch = self.layer_channels[i]
75 |             dense = EquiLinearRegToReg(in_ch, out_ch, self.reg_dim)
76 |             denses.append(dense)
77 |             conv = EquiCtsConv2dRegToReg(in_channels = in_ch, 
78 |                                          out_channels = out_ch,
79 |                                          num_radii = self.num_radii,
80 |                                          num_theta = self.num_theta,
81 |                                          radius = self.radius_scale,
82 |                                          k = self.reg_dim)
83 |             convs.append(conv)
84 |             in_ch = self.layer_channels[i]
85 | 
86 |         out_ch = self.layer_channels[-1]
87 |         dense = nn.Sequential(
88 |             EquiLinearRegToReg(in_ch, out_ch, self.reg_dim),
89 |             EquiLinearRegToRho1(self.reg_dim),
90 |         )
91 |         denses.append(dense)
92 |         conv = EquiCtsConv2dRegToRho1(in_channels = in_ch, 
93 |                                       out_channels = out_ch,
94 |                                       num_radii = self.num_radii,
95 |                                       num_theta = self.num_theta,
96 |                                       radius = self.radius_scale,
97 |                                       k = self.reg_dim)
98 |         convs.append(conv)
99 | 
100 |         self.convs = nn.ModuleList(convs)
101 |         self.denses = nn.ModuleList(denses)
102 | 
103 | 
104 |     def update_pos_vel(self, p0, v0, a):
105 |         """Apply acceleration and integrate position and velocity. 
106 |         Assume the particle has constant acceleration during timestep. 
107 |         Return particle's position and velocity after 1 unit timestep."""
108 | 
109 |         dt = self.timestep
110 |         v1 = v0 + dt * a
111 |         p1 = p0 + dt * (v0 + v1) / 2
112 |         return p1, v1
113 | 
114 |     def apply_correction(self, p0, p1, correction):
115 |         """Apply the position correction
116 |         p0, p1: the position of the particle before/after basic integration. """
117 |         dt = self.timestep
118 |         p_corrected = p1 + correction
119 |         v_corrected = (p_corrected - p0) / dt
120 |         return p_corrected, v_corrected
121 | 
122 |     def compute_correction(self, p, v, other_feats, box, box_feats, fluid_mask, box_mask):
123 |         """Precondition: p and v were updated with acceleration"""
124 | 
125 |         fluid_feats = [v.unsqueeze(-2)]
126 |         if other_feats is not None:
127 |             fluid_feats.append(other_feats)
128 |         fluid_feats = torch.cat(fluid_feats, -2)
129 | 
130 |         # compute the correction by accumulating the output through the network layers
131 |         output_conv_fluid = self.conv_fluid(p, p, fluid_feats, fluid_mask)
132 |         output_dense_fluid = self.dense_fluid(fluid_feats)
133 |         output_conv_obstacle = self.conv_obstacle(box, p, box_feats.unsqueeze(-2), box_mask)
134 | 
135 |         feats = torch.cat((output_conv_obstacle, output_conv_fluid, output_dense_fluid), -2)
136 |         # self.outputs = [feats]
137 |         output = feats
138 | 
139 |         for conv, dense in zip(self.convs, self.denses):
140 |             # pass input features to conv and fully-connected layers
141 |             # mags = (torch.sum(output**2,axis=-1) + 1e-6).unsqueeze(-1)
142 |             # in_feats = output/mags * self.activation(mags - self.relu_shift)
143 |             in_feats = self.activation(output)
144 |             # in_feats = output
145 |             output_conv = conv(p, p, in_feats, fluid_mask)
146 |             output_dense = dense(in_feats)
147 | 
148 |             # if the dense layer keeps the channel dimension unchanged,
149 |             # add a residual connection to the previous output
150 |             if output_dense.shape[-2] == output.shape[-2]:
151 |                 output = output_conv + output_dense + output
152 |             else:
153 |                 output = output_conv + output_dense
154 |             # self.outputs.append(output)
155 | 
156 |         # compute the number of fluid particle neighbors. 
157 |         # this info is used in the loss function during training.
158 |         # TODO: test this block of code
159 |         self.num_fluid_neighbors = torch.sum(fluid_mask, dim = -1) - 1
160 | 
161 |         # self.last_features = self.outputs[-2]
162 | 
163 |         # scale to better match the scale of the output distribution
164 |         self.pos_correction = (1.0 / 128) * output
165 |         return self.pos_correction
166 | 
167 |     def forward(self, inputs, states=None):
168 |         """ inputs: 10-element tuple
169 |             p0_enc, v0_enc, p0, v0, a, other_feats, box, box_feats, fluid_mask, box_mask
170 |             v0_enc: [batch, num_part, timestamps, 2]
171 |             Computes 1 simulation timestep"""
172 |         p0_enc, v0_enc, p0, v0, a, other_feats, box, box_feats, fluid_mask, box_mask = inputs
173 | 
174 |         if states is None:
175 |             if other_feats is None:
176 |                 feats = v0_enc
177 |             else:
178 |                 feats = torch.cat((other_feats, v0_enc), -2)
179 |         else:
180 |             if other_feats is None:
181 |                 feats = v0_enc
182 |                 feats = torch.cat((states[0][...,1:,:], feats), -2)
183 |             else:
184 |                 feats = torch.cat((other_feats, states[0][...,1:,:], v0_enc), -2)
185 |         # print(feats.shape)
186 | 
187 |         # a = (v0 - v0_enc[...,-1,:]) / self.timestep
188 |         p1, v1 = self.update_pos_vel(p0, v0, a)
189 |         pos_correction = self.compute_correction(p1, v1, feats, box, box_feats, fluid_mask, box_mask)
190 |         p_corrected, v_corrected = self.apply_correction(p0, p1, pos_correction.squeeze(-2))
191 | 
192 |         return p_corrected, v_corrected, (feats, None)
193 | 
194 | 
195 | 
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import sys
4 | import numpy as np
5 | sys.path.append('..')
6 | from collections import namedtuple
7 | import time
8 | import pickle
9 | import argparse
10 | from evaluate_network import evaluate
11 | from argoverse.map_representation.map_api import ArgoverseMap
12 | from datasets.argoverse_lane_loader import read_pkl_data
13 | from train_utils import *
14 | 
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | import torch.nn.functional as F
19 | 
20 | 
21 | parser = argparse.ArgumentParser(description="Training setting and hyperparameters")
22 | parser.add_argument('--cuda_visible_devices', default='0,1,2,3')
23 | parser.add_argument('--dataset_path', default='/path/to/argoverse_forecasting/', 
24 |                     help='path to dataset folder, which contains train and val folders')
25 | parser.add_argument('--train_window', default=4, type=int, help='how many timestamps to iterate in training')
26 | parser.add_argument('--batch_divide', default=1, type=int, 
27 |                     help='divide one batch into several packs, and train them iteratively.')
28 | parser.add_argument('--epochs', default=70, type=int)
29 | parser.add_argument('--batches_per_epoch', default=600, type=int, 
30 |                     help='determine the number of batches to train in one epoch')
31 | parser.add_argument('--base_lr', default=0.001, type=float)
32 | parser.add_argument('--batch_size', default=16, type=int)
33 | parser.add_argument('--model_name', default='ecco_trained_model', type=str)
34 | parser.add_argument('--val_batches', default=50, type=int, 
35 |                     help='the number of batches of data to split as validation set')
36 | parser.add_argument('--val_batch_size', default=32, type=int)
37 | parser.add_argument('--train', default=False, action='store_true')
38 | parser.add_argument('--evaluation', default=False, action='store_true')
39 | 
40 | feature_parser = parser.add_mutually_exclusive_group(required=False)
41 | feature_parser.add_argument('--rho1', dest='representation', action='store_false')
42 | feature_parser.add_argument('--rho-reg', dest='representation', action='store_true')
43 | parser.set_defaults(representation=True)
44 | 
45 | args = parser.parse_args()
46 | 
47 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
48 | 
49 | model_name = args.model_name
50 | 
51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52 | 
53 | val_path = os.path.join(args.dataset_path, 'val', 'lane_data')
54 | train_path = os.path.join(args.dataset_path, 'train', 'lane_data')
55 | 
56 | def create_model():
57 |     if args.representation:
58 |         from models.rho_reg_ECCO import ECCONetwork
59 |         # returns an instance of the network for training and evaluation
60 |         model = ECCONetwork(radius_scale = 40,
61 |                             layer_channels = [8, 16, 8, 8, 1], 
62 |                             encoder_hidden_size=18)
63 |     else:
64 |         from models.rho1_ECCO import ECCONetwork
65 |         # returns an instance of the network for training and evaluation
66 |         model = ECCONetwork(radius_scale = 40, encoder_hidden_size=18,
67 |                             layer_channels = [16, 32, 32, 32, 1], 
68 |                             num_radii = 3)
69 |     return model
70 | 
71 | class MyDataParallel(torch.nn.DataParallel):
72 |     """
73 |     Allow nn.DataParallel to call model's attributes.
74 |     """
75 |     def __getattr__(self, name):
76 |         try:
77 |             return super().__getattr__(name)
78 |         except AttributeError:
79 |             return getattr(self.module, name)
80 | 
81 | def train():
82 |     am = ArgoverseMap()
83 | 
84 |     val_dataset = read_pkl_data(val_path, batch_size=args.val_batch_size, shuffle=True, repeat=False, max_lane_nodes=700)
85 | 
86 |     dataset = read_pkl_data(train_path, batch_size=args.batch_size // args.batch_divide, 
87 |                             repeat=True, shuffle=True, max_lane_nodes=900)
88 | 
89 |     data_iter = iter(dataset)
90 | 
91 |     model = create_model().to(device)
92 |     # model_ = torch.load(model_name + '.pth')
93 |     # model = model_
94 |     model = MyDataParallel(model)
95 |     optimizer = torch.optim.Adam(model.parameters(), args.base_lr, betas=(0.9, 0.999), weight_decay=4e-4)
96 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
97 | 
98 |     def train_one_batch(model, batch, train_window=2):
99 | 
100 |         batch_size = args.batch_size
101 | 
102 |         inputs = ([
103 |             batch['pos_2s'], batch['vel_2s'], 
104 |             batch['pos0'], batch['vel0'], 
105 |             batch['accel'], None, 
106 |             batch['lane'], batch['lane_norm'], 
107 |             batch['car_mask'], batch['lane_mask']
108 |         ])
109 | 
110 |         # print_inputs_shape(inputs)
111 |         # print(batch['pos0'])
112 |         pr_pos1, pr_vel1, states = model(inputs)
113 |         gt_pos1 = batch['pos1']
114 |         # print(pr_pos1)
115 | 
116 |         # losses = 0.5 * loss_fn(pr_pos1, gt_pos1, model.num_fluid_neighbors.unsqueeze(-1), batch['car_mask'])
117 |         losses = 0.5 * loss_fn(pr_pos1, gt_pos1, torch.sum(batch['car_mask'], dim = -2) - 1, batch['car_mask'].squeeze(-1))
118 |         del gt_pos1
119 | 
120 |         pos0 = batch['pos0']
121 |         vel0 = batch['vel0']
122 |         for i in range(train_window-1):
123 |             pos_enc = torch.unsqueeze(pos0, 2)
124 |             vel_enc = torch.unsqueeze(vel0, 2)
125 |             inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, batch['accel'], None, 
126 |                       batch['lane'], 
127 |                       batch['lane_norm'], batch['car_mask'], batch['lane_mask'])
128 |             pos0, vel0 = pr_pos1, pr_vel1
129 |             # del pos_enc, vel_enc
130 | 
131 |             pr_pos1, pr_vel1, states = model(inputs, states)
132 |             gt_pos1 = batch['pos'+str(i+2)]
133 | 
134 |             losses += 0.5 * loss_fn(pr_pos1, gt_pos1, 
135 |                                     torch.sum(batch['car_mask'], dim = -2) - 1, batch['car_mask'].squeeze(-1))
136 | 
137 | 
138 |         total_loss = 128 * torch.sum(losses, axis=0) / batch_size
139 | 
140 |         return total_loss
141 | 
142 |     epochs = args.epochs
143 |     batches_per_epoch = args.batches_per_epoch  # the full dataset is too large to iterate in one epoch
144 |     data_load_times = []  # Per batch
145 |     train_losses = []
146 |     valid_losses = []
147 |     valid_metrics_list = []
148 |     min_loss = None
149 | 
150 |     for i in range(epochs):
151 |         epoch_start_time = time.time()
152 | 
153 |         model.train()
154 |         epoch_train_loss = 0
155 |         sub_idx = 0
156 | 
157 |         print("training ... epoch " + str(i + 1), end='')
158 |         for batch_itr in range(batches_per_epoch * args.batch_divide):
159 | 
160 |             data_fetch_start = time.time()
161 |             batch = next(data_iter)
162 | 
163 |             if sub_idx == 0:
164 |                 optimizer.zero_grad()
165 |                 if (batch_itr // args.batch_divide) % 25 == 0:
166 |                     print("... batch " + str((batch_itr // args.batch_divide) + 1), end='', flush=True)
167 |             sub_idx += 1
168 | 
169 |             batch_size = len(batch['pos0'])
170 | 
171 |             batch_tensor = {}
172 |             convert_keys = (['pos' + str(i) for i in range(args.train_window + 1)] + 
173 |                             ['vel' + str(i) for i in range(args.train_window + 1)] + 
174 |                             ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])
175 | 
176 |             for k in convert_keys:
177 |                 batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device)
178 | 
179 |             for k in ['car_mask', 'lane_mask']:
180 |                 batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device).unsqueeze(-1)
181 | 
182 |             for k in ['track_id' + str(i) for i in range(31)] + ['city']:
183 |                 batch_tensor[k] = batch[k]
184 | 
185 |             batch_tensor['car_mask'] = batch_tensor['car_mask'].squeeze(-1)
186 |             accel = torch.zeros(batch_size, 1, 2).to(device)
187 |             batch_tensor['accel'] = accel
188 |             del batch
189 | 
190 |             data_fetch_latency = time.time() - data_fetch_start
191 |             data_load_times.append(data_fetch_latency)
192 | 
193 |             current_loss = train_one_batch(model, batch_tensor, train_window=args.train_window)
194 | 
195 |             if sub_idx < args.batch_divide:
196 |                 current_loss.backward(retain_graph=True)
197 |             else:
198 |                 current_loss.backward()
199 |                 optimizer.step()
200 |                 sub_idx = 0
201 |             del batch_tensor
202 | 
203 |             epoch_train_loss += float(current_loss)
204 |             del current_loss
205 |             clean_cache(device)
206 | 
207 |             if batch_itr == batches_per_epoch - 1:
208 |                 print("... DONE", flush=True)
DONE", flush=True) 209 | 210 | train_losses.append(epoch_train_loss) 211 | 212 | model.eval() 213 | with torch.no_grad(): 214 | valid_total_loss, valid_metrics = evaluate(model.module, val_dataset, 215 | train_window=args.train_window, 216 | max_iter=args.val_batches, 217 | device=device, 218 | batch_size=args.val_batch_size) 219 | 220 | valid_losses.append(float(valid_total_loss)) 221 | valid_metrics_list.append(valid_metrics) 222 | 223 | if min_loss is None: 224 | min_loss = valid_losses[-1] 225 | 226 | if valid_losses[-1] < min_loss: 227 | print('update weights') 228 | min_loss = valid_losses[-1] 229 | best_model = model 230 | torch.save(model.module, model_name + ".pth") 231 | 232 | epoch_end_time = time.time() 233 | 234 | print('epoch: {}, train loss: {}, val loss: {}, epoch time: {}, lr: {}, {}'.format( 235 | i + 1, train_losses[-1], valid_losses[-1], 236 | round((epoch_end_time - epoch_start_time) / 60, 5), 237 | format(get_lr(optimizer), "5.2e"), model_name 238 | )) 239 | 240 | scheduler.step() 241 | 242 | 243 | def evaluation(): 244 | am = ArgoverseMap() 245 | 246 | val_dataset = read_pkl_data(val_path, batch_size=args.val_batch_size, shuffle=False, repeat=False) 247 | 248 | trained_model = torch.load(model_name + '.pth') 249 | trained_model.eval() 250 | 251 | with torch.no_grad(): 252 | valid_total_loss, valid_metrics = evaluate(trained_model, val_dataset, 253 | train_window=args.train_window, max_iter=len(val_dataset), 254 | device=device, start_iter=args.val_batches, use_lane=args.use_lane, 255 | batch_size=args.val_batch_size) 256 | 257 | with open('results/{}_predictions.pickle'.format(model_name), 'wb') as f: 258 | pickle.dump(valid_metrics, f) 259 | 260 | 261 | if __name__ == '__main__': 262 | if args.train: 263 | train() 264 | 265 | if args.evaluation: 266 | evaluation() 267 | 268 | 269 | -------------------------------------------------------------------------------- /datasets/preprocess_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import numpy as np 4 | import pandas as pd 5 | import sys 6 | sys.path.append('..') 7 | from datasets.helper import get_lane_direction 8 | # from tensorpack import dataflow 9 | import time 10 | import gc 11 | import pickle 12 | import helper 13 | import time 14 | import glob 15 | from argoverse.map_representation.map_api import ArgoverseMap 16 | from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader 17 | 18 | dataset_path = '/path/to/dataset/argoverse_forecasting/' 19 | 20 | val_path = os.path.join(dataset_path, 'val', 'data') 21 | train_path = os.path.join(dataset_path, 'train', 'data') 22 | 23 | class ArgoverseTest(object): 24 | """ 25 | Data flow for argoverse dataset 26 | """ 27 | 28 | def __init__(self, file_path: str, shuffle: bool = True, random_rotation: bool = False, 29 | max_car_num: int = 50, freq: int = 10, use_interpolate: bool = False, 30 | use_lane: bool = False, use_mask: bool = True): 31 | if not os.path.exists(file_path): 32 | raise Exception("Path does not exist.") 33 | 34 | self.afl = ArgoverseForecastingLoader(file_path) 35 | self.shuffle = shuffle 36 | self.random_rotation = random_rotation 37 | self.max_car_num = max_car_num 38 | self.freq = freq 39 | self.use_interpolate = use_interpolate 40 | self.am = ArgoverseMap() 41 | self.use_mask = use_mask 42 | self.file_path = file_path 43 | 44 | 45 | def get_feat(self, scene): 46 | 47 | data, city = self.afl[scene].seq_df, self.afl[scene].city 48 | 49 | 
lane = np.array([[0., 0.]], dtype=np.float32) 50 | lane_drct = np.array([[0., 0.]], dtype=np.float32) 51 | 52 | 53 | tstmps = data.TIMESTAMP.unique() 54 | tstmps.sort() 55 | 56 | data = self._filter_imcomplete_data(data, tstmps, 50) 57 | 58 | data = self._calc_vel(data, self.freq) 59 | 60 | agent = data[data['OBJECT_TYPE'] == 'AGENT']['TRACK_ID'].values[0] 61 | 62 | car_mask = np.zeros((self.max_car_num, 1), dtype=np.float32) 63 | car_mask[:len(data.TRACK_ID.unique())] = 1.0 64 | 65 | feat_dict = {'city': city, 66 | 'lane': lane, 67 | 'lane_norm': lane_drct, 68 | 'scene_idx': scene, 69 | 'agent_id': agent, 70 | 'car_mask': car_mask} 71 | 72 | pos_enc = [subdf[['X', 'Y']].values[np.newaxis,:] 73 | for _, subdf in data[data['TIMESTAMP'].isin(tstmps[:19])].groupby('TRACK_ID')] 74 | pos_enc = np.concatenate(pos_enc, axis=0) 75 | # pos_enc = self._expand_dim(pos_enc) 76 | feat_dict['pos_2s'] = self._expand_particle(pos_enc, self.max_car_num, 0) 77 | 78 | vel_enc = [subdf[['vel_x', 'vel_y']].values[np.newaxis,:] 79 | for _, subdf in data[data['TIMESTAMP'].isin(tstmps[:19])].groupby('TRACK_ID')] 80 | vel_enc = np.concatenate(vel_enc, axis=0) 81 | # vel_enc = self._expand_dim(vel_enc) 82 | feat_dict['vel_2s'] = self._expand_particle(vel_enc, self.max_car_num, 0) 83 | 84 | pos = data[data['TIMESTAMP'] == tstmps[19]][['X', 'Y']].values 85 | # pos = self._expand_dim(pos) 86 | feat_dict['pos0'] = self._expand_particle(pos, self.max_car_num, 0) 87 | vel = data[data['TIMESTAMP'] == tstmps[19]][['vel_x', 'vel_y']].values 88 | # vel = self._expand_dim(vel) 89 | feat_dict['vel0'] = self._expand_particle(vel, self.max_car_num, 0) 90 | track_id = data[data['TIMESTAMP'] == tstmps[19]]['TRACK_ID'].values 91 | feat_dict['track_id0'] = self._expand_particle(track_id, self.max_car_num, 0, 'str') 92 | feat_dict['frame_id0'] = 0 93 | 94 | for t in range(31): 95 | pos = data[data['TIMESTAMP'] == tstmps[19 + t]][['X', 'Y']].values 96 | # pos = self._expand_dim(pos) 97 | feat_dict['pos' + str(t)] = self._expand_particle(pos, self.max_car_num, 0) 98 | vel = data[data['TIMESTAMP'] == tstmps[19 + t]][['vel_x', 'vel_y']].values 99 | # vel = self._expand_dim(vel) 100 | feat_dict['vel' + str(t)] = self._expand_particle(vel, self.max_car_num, 0) 101 | track_id = data[data['TIMESTAMP'] == tstmps[19 + t]]['TRACK_ID'].values 102 | feat_dict['track_id' + str(t)] = self._expand_particle(track_id, self.max_car_num, 0, 'str') 103 | feat_dict['frame_id' + str(t)] = t 104 | 105 | return feat_dict 106 | 107 | def __len__(self): 108 | return len(glob.glob(os.path.join(self.file_path, '*'))) 109 | 110 | @classmethod 111 | def _expand_df(cls, data, city_name): 112 | timestps = data['TIMESTAMP'].unique().tolist() 113 | ids = data['TRACK_ID'].unique().tolist() 114 | df = pd.DataFrame({'TIMESTAMP': timestps * len(ids)}).sort_values('TIMESTAMP') 115 | df['TRACK_ID'] = ids * len(timestps) 116 | df['CITY_NAME'] = city_name 117 | return pd.merge(data, df, on=['TIMESTAMP', 'TRACK_ID'], how='right') 118 | 119 | 120 | @classmethod 121 | def __calc_vel_generator(cls, df, freq=10): 122 | for idx, subdf in df.groupby('TRACK_ID'): 123 | sub_df = subdf.copy().sort_values('TIMESTAMP') 124 | sub_df[['vel_x', 'vel_y']] = sub_df[['X', 'Y']].diff() * freq 125 | yield sub_df.iloc[1:, :] 126 | 127 | @classmethod 128 | def _calc_vel(cls, df, freq=10): 129 | return pd.concat(cls.__calc_vel_generator(df, freq=freq), axis=0) 130 | 131 | @classmethod 132 | def _expand_dim(cls, ndarr, dtype=np.float32): 133 | return np.insert(ndarr, 2, values=0, 
axis=-1).astype(dtype) 134 | 135 | @classmethod 136 | def _linear_interpolate_generator(cls, data, col=['X', 'Y']): 137 | for idx, df in data.groupby('TRACK_ID'): 138 | sub_df = df.copy().sort_values('TIMESTAMP') 139 | sub_df[col] = sub_df[col].interpolate(limit_direction='both') 140 | yield sub_df.ffill().bfill() 141 | 142 | @classmethod 143 | def _linear_interpolate(cls, data, col=['X', 'Y']): 144 | return pd.concat(cls._linear_interpolate_generator(data, col), axis=0) 145 | 146 | @classmethod 147 | def _filter_imcomplete_data(cls, data, tstmps, window=20): 148 | complete_id = list() 149 | for idx, subdf in data[data['TIMESTAMP'].isin(tstmps[:window])].groupby('TRACK_ID'): 150 | if len(subdf) == window: 151 | complete_id.append(idx) 152 | return data[data['TRACK_ID'].isin(complete_id)] 153 | 154 | @classmethod 155 | def _expand_particle(cls, arr, max_num, axis, value_type='int'): 156 | dummy_shape = list(arr.shape) 157 | dummy_shape[axis] = max_num - arr.shape[axis] 158 | dummy = np.zeros(dummy_shape) 159 | if value_type == 'str': 160 | dummy = np.array(['dummy' + str(i) for i in range(np.product(dummy_shape))]).reshape(dummy_shape) 161 | return np.concatenate([arr, dummy], axis=axis) 162 | 163 | 164 | class process_utils(object): 165 | 166 | @classmethod 167 | def expand_dim(cls, ndarr, dtype=np.float32): 168 | return np.insert(ndarr, 2, values=0, axis=-1).astype(dtype) 169 | 170 | @classmethod 171 | def expand_particle(cls, arr, max_num, axis, value_type='int'): 172 | dummy_shape = list(arr.shape) 173 | dummy_shape[axis] = max_num - arr.shape[axis] 174 | dummy = np.zeros(dummy_shape) 175 | if value_type == 'str': 176 | dummy = np.array(['dummy' + str(i) for i in range(np.product(dummy_shape))]).reshape(dummy_shape) 177 | return np.concatenate([arr, dummy], axis=axis) 178 | 179 | 180 | def get_max_min(datas): 181 | mask = datas['car_mask'] 182 | slicer = mask[0].astype(bool).flatten() 183 | pos_keys = ['pos0'] + ['pos_2s'] 184 | max_x = np.concatenate([np.max(np.stack(datas[pk])[0,slicer,...,0] 185 | .reshape(np.stack(datas[pk]).shape[0], -1), 186 | axis=-1)[...,np.newaxis] 187 | for pk in pos_keys], axis=-1) 188 | min_x = np.concatenate([np.min(np.stack(datas[pk])[0,slicer,...,0] 189 | .reshape(np.stack(datas[pk]).shape[0], -1), 190 | axis=-1)[...,np.newaxis] 191 | for pk in pos_keys], axis=-1) 192 | max_y = np.concatenate([np.max(np.stack(datas[pk])[0,slicer,...,1] 193 | .reshape(np.stack(datas[pk]).shape[0], -1), 194 | axis=-1)[...,np.newaxis] 195 | for pk in pos_keys], axis=-1) 196 | min_y = np.concatenate([np.min(np.stack(datas[pk])[0,slicer,...,1] 197 | .reshape(np.stack(datas[pk]).shape[0], -1), 198 | axis=-1)[...,np.newaxis] 199 | for pk in pos_keys], axis=-1) 200 | max_x = np.max(max_x, axis=-1) + 10 201 | max_y = np.max(max_y, axis=-1) + 10 202 | min_x = np.max(min_x, axis=-1) - 10 203 | min_y = np.max(min_y, axis=-1) - 10 204 | return min_x, max_x, min_y, max_y 205 | 206 | 207 | def process_func(putil, datas, am): 208 | 209 | city = datas['city'][0] 210 | x_min, x_max, y_min, y_max = get_max_min(datas) 211 | 212 | seq_lane_props = am.city_lane_centerlines_dict[city] 213 | 214 | lane_centerlines = [] 215 | lane_directions = [] 216 | 217 | # Get lane centerlines which lie within the range of trajectories 218 | for lane_id, lane_props in seq_lane_props.items(): 219 | 220 | lane_cl = lane_props.centerline 221 | 222 | if ( 223 | np.min(lane_cl[:, 0]) < x_max 224 | and np.min(lane_cl[:, 1]) < y_max 225 | and np.max(lane_cl[:, 0]) > x_min 226 | and np.max(lane_cl[:, 1]) > y_min 227 
| ): 228 | lane_centerlines.append(lane_cl[1:]) 229 | lane_drct = np.diff(lane_cl, axis=0) 230 | lane_directions.append(lane_drct) 231 | if len(lane_centerlines) > 0: 232 | lane = np.concatenate(lane_centerlines, axis=0) 233 | # lane = putil.expand_dim(lane) 234 | lane_drct = np.concatenate(lane_directions, axis=0) 235 | # lane_drct = putil.expand_dim(lane_drct)[...,:3] 236 | 237 | datas['lane'] = [lane] 238 | datas['lane_norm'] = [lane_drct] 239 | return datas 240 | else: 241 | return datas 242 | 243 | 244 | if __name__ == '__main__': 245 | am = ArgoverseMap() 246 | putil = process_utils() 247 | 248 | afl_train = ArgoverseForecastingLoader(os.path.join(dataset_path, 'train', 'data')) 249 | afl_val = ArgoverseForecastingLoader(os.path.join(dataset_path, 'val', 'data')) 250 | at_train = ArgoverseTest(os.path.join(dataset_path, 'train', 'data'), max_car_num=60) 251 | at_val = ArgoverseTest(os.path.join(dataset_path, 'val', 'data'), max_car_num=60) 252 | 253 | 254 | print("++++++++++++++++++++ START TRAIN ++++++++++++++++++++") 255 | train_num = len(afl_train) 256 | batch_start = time.time() 257 | os.mkdir(os.path.join(dataset_path, 'train/lane_data')) 258 | for i, scene in enumerate(range(train_num)): 259 | if i % 1000 == 0: 260 | batch_end = time.time() 261 | print("SAVED ============= {} / {} ....... {}".format(i, train_num, batch_end - batch_start)) 262 | batch_start = time.time() 263 | 264 | data = {k:[v] for k, v in at_train.get_feat(scene).items()} 265 | datas = process_func(putil, data, am) 266 | with open(os.path.join(dataset_path, 'train/lane_data', str(datas['scene_idx'][0])+'.pkl'), 'wb') as f: 267 | pickle.dump(datas, f) 268 | 269 | print("++++++++++++++++++++ START VAL ++++++++++++++++++++") 270 | val_num = len(afl_val) 271 | batch_start = time.time() 272 | os.mkdir(os.path.join(dataset_path, 'val/lane_data')) 273 | for i, scene in enumerate(range(val_num)): 274 | if i % 1000 == 0: 275 | batch_end = time.time() 276 | print("SAVED ============= {} / {} ....... 
{}".format(i, val_num, batch_end - batch_start)) 277 | batch_start = time.time() 278 | 279 | data = {k:[v] for k, v in at_val.get_feat(scene).items()} 280 | datas = process_func(putil, data, am) 281 | with open(os.path.join(dataset_path, 'val/lane_data', str(datas['scene_idx'][0])+'.pkl'), 'wb') as f: 282 | pickle.dump(datas, f) 283 | 284 | 285 | -------------------------------------------------------------------------------- /models/RelEquiCtsConv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math 6 | 7 | from EquiCtsConv import * 8 | 9 | 10 | class RelEquiCtsConv2d(EquiCtsConv2d): 11 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, matrix_dim=2, 12 | use_attention=True, normalize_attention=True): 13 | super(RelEquiCtsConv2d, self).__init__(in_channels, out_channels, radius, num_radii, num_theta, 14 | matrix_dim, use_attention, normalize_attention) 15 | 16 | def ContinuousConv( 17 | self, field, center, field_feat, 18 | field_mask, ctr_feat 19 | ): 20 | """ 21 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2] 22 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim] 23 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim] 24 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2] 25 | @ctr_feat: [batch, 1, feat_dim] 26 | @field_mask: [batch, num_n, 1] 27 | """ 28 | kernel = self.computeKernel() 29 | 30 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius 31 | # relative_field: [batch, num_m, num_n, pos_dim] 32 | 33 | 34 | polar_field = self.PolarCoords(relative_field) 35 | # polar_field: [batch, num_m, num_n, pos_dim] 36 | 37 | kernel_on_field = self.InterpolateKernel(kernel, polar_field) 38 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2] 39 | 40 | if self.use_attention: 41 | # print(relative_field.shape) 42 | # print(field_mask.unsqueeze(1).shape) 43 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1) 44 | # attention: [batch, num_m, num_n, 1] 45 | 46 | if self.normalize_attention: 47 | psi = torch.sum(attention, axis=2).squeeze(-1) 48 | psi[psi == 0.] = 1 49 | psi = psi.unsqueeze(-1).unsqueeze(-1) 50 | else: 51 | psi = 1.0 52 | else: 53 | attention = torch.ones(*relative_field.shape[0:3],1) 54 | 55 | if self.normalize_attention: 56 | psi = torch.sum(attention, axis=2).squeeze(-1) 57 | psi[psi == 0.] 
= 1 58 | psi = psi.unsqueeze(-1).unsqueeze(-1) 59 | else: 60 | psi = 1.0 61 | 62 | field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2) 63 | attention_field_feat = field_feat * attention.unsqueeze(-1) 64 | # attention_field_feat: [batch, num_m, num_n, c_in, 2] 65 | 66 | # print(kernel_on_field.shape, attention_field_feat.shape) 67 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat) 68 | # out: [batch, num_m, c_out, 2] 69 | 70 | return out / psi 71 | 72 | def forward( 73 | self, field, center, field_feat, 74 | field_mask, ctr_feat, normalize_attention=False 75 | ): 76 | out = self.ContinuousConv( 77 | field, center, field_feat, field_mask, 78 | ctr_feat 79 | ) 80 | return out 81 | 82 | class RelEquiCtsConv2dRegToRho1(EquiCtsConv2dRegToRho1): 83 | 84 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2, 85 | use_attention=True, normalize_attention=True): 86 | super(RelEquiCtsConv2dRegToRho1, self).__init__(in_channels, out_channels, radius, num_radii, num_theta, 87 | k, matrix_dim, use_attention, normalize_attention) 88 | 89 | def ContinuousConv( 90 | self, field, center, field_feat, 91 | field_mask, ctr_feat 92 | ): 93 | """ 94 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2] 95 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim] 96 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim] 97 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2] 98 | @ctr_feat: [batch, 1, feat_dim] 99 | @field_mask: [batch, num_n, 1] 100 | """ 101 | kernel = self.computeKernel() 102 | 103 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius 104 | # relative_field: [batch, num_m, num_n, pos_dim] 105 | 106 | 107 | polar_field = self.PolarCoords(relative_field) 108 | # polar_field: [batch, num_m, num_n, pos_dim] 109 | 110 | kernel_on_field = self.InterpolateKernel(kernel, polar_field) 111 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2] 112 | 113 | if self.use_attention: 114 | # print(relative_field.shape) 115 | # print(field_mask.unsqueeze(1).shape) 116 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1) 117 | # attention: [batch, num_m, num_n, 1] 118 | 119 | if self.normalize_attention: 120 | psi = torch.sum(attention, axis=2).squeeze(-1) 121 | psi[psi == 0.] = 1 122 | psi = psi.unsqueeze(-1).unsqueeze(-1) 123 | else: 124 | psi = 1.0 125 | else: 126 | attention = torch.ones(*relative_field.shape[0:3],1) 127 | 128 | if self.normalize_attention: 129 | psi = torch.sum(attention, axis=2).squeeze(-1) 130 | psi[psi == 0.] 
= 1
131 |                 psi = psi.unsqueeze(-1).unsqueeze(-1)
132 |             else:
133 |                 psi = 1.0
134 | 
135 |         field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2)
136 | 
137 |         attention_field_feat = field_feat*attention.unsqueeze(-1)
138 |         # attention_field_feat: [batch, num_m, num_n, c_in, 2]
139 | 
140 |         out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat)
141 |         # out: [batch, num_m, c_out, 2]
142 | 
143 |         return out / psi
144 | 
145 |     def forward(
146 |         self, field, center, field_feat, 
147 |         field_mask, ctr_feat=None
148 |     ):
149 |         out = self.ContinuousConv(
150 |             field, center, field_feat, field_mask, 
151 |             ctr_feat
152 |         )
153 |         return out
154 | 
155 | class RelEquiCtsConv2dRho1ToReg(EquiCtsConv2dRho1ToReg):
156 | 
157 |     def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2, 
158 |                  use_attention=True, normalize_attention=True):
159 |         super(RelEquiCtsConv2dRho1ToReg, self).__init__(in_channels, out_channels, radius, num_radii, num_theta, 
160 |                                                         k, matrix_dim, use_attention, normalize_attention)
161 | 
162 | 
163 |     def ContinuousConv(
164 |         self, field, center, field_feat, 
165 |         field_mask, ctr_feat
166 |     ):
167 |         """
168 |         @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2]
169 |         @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim]
170 |         @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim]
171 |         @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2]
172 |         @ctr_feat: [batch, num_m, c_in, 2]
173 |         @field_mask: [batch, num_n, 1]
174 |         """
175 |         kernel = self.computeKernel()
176 | 
177 |         relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
178 |         # relative_field: [batch, num_m, num_n, pos_dim]
179 | 
180 | 
181 |         polar_field = self.PolarCoords(relative_field)
182 |         # polar_field: [batch, num_m, num_n, pos_dim]
183 | 
184 |         kernel_on_field = self.InterpolateKernel(kernel, polar_field)
185 |         # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2]
186 | 
187 |         if self.use_attention:
188 |             # print(relative_field.shape)
189 |             # print(field_mask.unsqueeze(1).shape)
190 |             attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1)
191 |             # attention: [batch, num_m, num_n, 1]
192 | 
193 |             if self.normalize_attention:
194 |                 psi = torch.sum(attention, axis=2).squeeze(-1)
195 |                 psi[psi == 0.] = 1
196 |                 psi = psi.unsqueeze(-1).unsqueeze(-1)
197 |             else:
198 |                 psi = 1.0
199 |         else:
200 |             attention = torch.ones(*relative_field.shape[0:3], 1, device=relative_field.device)
201 | 
202 |             if self.normalize_attention:
203 |                 psi = torch.sum(attention, axis=2).squeeze(-1)
204 |                 psi[psi == 0.]
= 1 205 | psi = psi.unsqueeze(-1).unsqueeze(-1) 206 | else: 207 | psi = 1.0 208 | 209 | field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2) 210 | 211 | attention_field_feat = field_feat*attention.unsqueeze(-1) 212 | # attention_field_feat: [batch, num_m, num_n, c_in, 2] 213 | 214 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat) 215 | # out: [batch, num_m, c_out, 2] 216 | 217 | return out / psi 218 | 219 | def forward( 220 | self, field, center, field_feat, 221 | field_mask, ctr_feat=None 222 | ): 223 | out = self.ContinuousConv( 224 | field, center, field_feat, field_mask, 225 | ctr_feat 226 | ) 227 | return out 228 | 229 | class RelEquiCtsConv2dRegToReg(EquiCtsConv2dRegToReg): 230 | 231 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2, 232 | use_attention=True, normalize_attention=True): 233 | super(RelEquiCtsConv2dRegToReg, self).__init__(in_channels, out_channels, radius, num_radii, num_theta, 234 | k, matrix_dim, use_attention, normalize_attention) 235 | self.kernel = self.computeKernel() 236 | 237 | 238 | def ContinuousConv( 239 | self, field, center, field_feat, 240 | field_mask, ctr_feat 241 | ): 242 | """ 243 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2] 244 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim] 245 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim] 246 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2] 247 | @ctr_feat: [batch, 1, feat_dim] 248 | @field_mask: [batch, num_n, 1] 249 | """ 250 | kernel = self.computeKernel() 251 | 252 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius 253 | # relative_field: [batch, num_m, num_n, pos_dim] 254 | 255 | 256 | polar_field = self.PolarCoords(relative_field) 257 | # polar_field: [batch, num_m, num_n, pos_dim] 258 | 259 | kernel_on_field = self.InterpolateKernel(kernel, polar_field) 260 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2] 261 | 262 | if self.use_attention: 263 | # print(relative_field.shape) 264 | # print(field_mask.unsqueeze(1).shape) 265 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1) 266 | # attention: [batch, num_m, num_n, 1] 267 | 268 | if self.normalize_attention: 269 | psi = torch.sum(attention, axis=2).squeeze(-1) 270 | psi[psi == 0.] = 1 271 | psi = psi.unsqueeze(-1).unsqueeze(-1) 272 | else: 273 | psi = 1.0 274 | else: 275 | attention = torch.ones(*relative_field.shape[0:3],1) 276 | 277 | if self.normalize_attention: 278 | psi = torch.sum(attention, axis=2).squeeze(-1) 279 | psi[psi == 0.] 
= 1
280 |                 psi = psi.unsqueeze(-1).unsqueeze(-1)
281 |             else:
282 |                 psi = 1.0
283 | 
284 |         field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2)
285 | 
286 |         attention_field_feat = field_feat*attention.unsqueeze(-1)
287 |         # attention_field_feat: [batch, num_m, num_n, c_in, 2]
288 | 
289 |         out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat)
290 |         # out: [batch, num_m, c_out, 2]
291 | 
292 |         return out / psi
293 | 
294 |     def forward(
295 |         self, field, center, field_feat, 
296 |         field_mask, ctr_feat=None
297 |     ):
298 |         out = self.ContinuousConv(
299 |             field, center, field_feat, field_mask, 
300 |             ctr_feat
301 |         )
302 |         return out
303 | 
--------------------------------------------------------------------------------
/models/EquiCtsConv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 | from abc import ABCMeta, abstractmethod
7 | from EquiLinear import *
8 | 
9 | 
10 | class EquiCtsConvBase(nn.Module, metaclass=ABCMeta):
11 |     def __init__(self):
12 |         super(EquiCtsConvBase, self).__init__()
13 | 
14 |     @abstractmethod
15 |     def computeKernel(self):
16 |         pass
17 | 
18 |     def GetAttention(self, relative_field):
19 |         r = torch.sum(relative_field ** 2, axis=-1)
20 |         return torch.relu((1 - r) ** 3).unsqueeze(-1)
21 | 
22 |     @classmethod
23 |     def RotMat(cls, theta):
24 |         m = torch.tensor([
25 |             [torch.cos(theta), -torch.sin(theta)], 
26 |             [torch.sin(theta), torch.cos(theta)]
27 |         ], requires_grad=False)
28 |         return m
29 | 
30 |     @classmethod
31 |     def Rho1RotMat(cls, theta):
32 |         m = torch.tensor([
33 |             [torch.cos(theta), -torch.sin(theta)], 
34 |             [torch.sin(theta), torch.cos(theta)]
35 |         ], requires_grad=False)
36 |         return m
37 | 
38 |     @classmethod
39 |     def RegRotMat(cls, theta, k):
40 |         slice_angle = 2 * math.pi / k
41 |         index_shift = theta / slice_angle
42 |         i = np.floor(index_shift).astype(int)
43 |         # divide weights between i and i+1: first_col = [ 0 0 0 ... 0 w_i w_{i+1} 0 0 0 ... 0 0 0]
44 |         first_col = torch.zeros(k)
45 | 
46 |         offset = (theta - slice_angle * i) / slice_angle
47 |         w_i = 1 - offset
48 |         w_ip = offset
49 |         first_col[np.mod(i,k)], first_col[np.mod(i+1,k)] = w_i, w_ip
50 | 
51 |         m = torch.stack([torch.roll(first_col, i, 0) for i in range(k)], -1)
52 |         # like a permutation matrix which sends i -> i + \theta / ( 2 \pi / k )
53 |         # Note if k = num_theta, then it is a permutation matrix
54 |         return m
55 | 
56 |     def PolarCoords(self, vec, epsilon = 1e-9):
57 |         # vec: [batch, num_m, num_n, pos_dim]
58 |         # Convert to Polar
59 |         r = torch.sqrt(vec[...,0] **2 + vec[...,1] **2 + epsilon)
60 | 
61 |         cond_nonzero = ~((vec[...,0] == 0.) 
& (vec[...,1] == 0.)) 62 | 63 | theta = torch.zeros(vec[...,0].shape, device=self.outer_weights.device) 64 | theta[cond_nonzero] = torch.atan2(vec[...,1][cond_nonzero], vec[...,0][cond_nonzero]) 65 | 66 | out = [r, theta] 67 | out = torch.stack(out, -1) 68 | return out 69 | 70 | def InterpolateKernel(self, kernel, pos): 71 | """ 72 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2] -> [batch, C=c_out*c_in*4, r, theta] 73 | @pos: [batch, num_m, num_n, 2] -> [batch, num_m, num_n, 2] 74 | 75 | return out: [batch, C=c_out*c_in*4, num_m, num_n] -> [batch, num_m, num_n, c_out, c_in, 2, 2] 76 | """ 77 | # kernel: [c_out, c_in=feat_dim, r, theta, 2, 2] 78 | kernels = kernel.permute(0, 1, 4, 5, 2, 3) 79 | # kernels: [c_out, c_in=feat_dim, 2, 2, r, theta] 80 | 81 | kernels = kernels.reshape(-1, *kernels.shape[4:]).unsqueeze(0) 82 | # kernels: [1, c_out*c_in*2*2, r, theta] 83 | 84 | kernels = kernels.expand((pos.shape[0], *kernels.shape[1:])) 85 | # kernels: [batch_size, c_out*c_in*2*2, r, theta] 86 | #[N, C, H, W] 87 | 88 | 89 | # Copy first and last column to wrap thetas. 90 | padded_kernels = torch.cat([ 91 | kernels[..., -1].unsqueeze(-1), 92 | kernels, 93 | kernels[..., 0].unsqueeze(-1) 94 | ],dim = -1) 95 | padded_kernels = padded_kernels.permute(0,1,3,2) 96 | # padded_kernels: [batch, C=c_out*c_in*4, theta+2, r] 97 | 98 | 99 | grid = pos 100 | # adjust radii [0,1] -> [-1,1] 101 | grid[...,0] = 2*grid[...,0] - 1 102 | # adjust angles [-pi,pi] -> [-1,1] 103 | grid[...,1] *= 1/math.pi 104 | # shrink thetas slightly to account for padding 105 | grid[...,1] *= self.num_theta / (self.num_theta + 2) 106 | # grid [batch, num_m, num_n, 2] 107 | # [N, H_out, W_out, 2] 108 | 109 | # print("grid",grid) 110 | # print("padded_kernels_shape [batch_size, c_out*c_in*2*2, theta+2, r]:",padded_kernels.shape) 111 | #print("kernels",padded_kernels) 112 | 113 | out = F.grid_sample(padded_kernels, grid, padding_mode='zeros', 114 | mode='bilinear', align_corners=False) #bilinear 115 | # out: [batch, C=c_out*c_in*4, num_m, num_n] 116 | # [N, C, H_out, W_out] 117 | 118 | out = out.permute(0, 2, 3, 1) 119 | # out: [batch, num_m, num_n, C=c_out*c_in*4] 120 | out = out.reshape(*pos.shape[:-1], *kernel.shape[0:2], *kernel.shape[-2:]) 121 | # out: [batch, num_m, num_n, c_out, c_in, 2, 2] 122 | return out 123 | 124 | def ContinuousConv( 125 | self, field, center, field_feat, 126 | field_mask, ctr_feat=None 127 | ): 128 | """ 129 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2] 130 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim] 131 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim] 132 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2] 133 | @ctr_feat: [batch, 1, feat_dim] 134 | @field_mask: [batch, num_n, 1] 135 | """ 136 | kernel = self.computeKernel() 137 | 138 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius 139 | # relative_field: [batch, num_m, num_n, pos_dim] 140 | 141 | 142 | polar_field = self.PolarCoords(relative_field) 143 | # polar_field: [batch, num_m, num_n, pos_dim] 144 | 145 | kernel_on_field = self.InterpolateKernel(kernel, polar_field) 146 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2] 147 | 148 | if self.use_attention: 149 | # print(relative_field.shape) 150 | # print(field_mask.unsqueeze(1).shape) 151 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1) 152 | # attention: [batch, num_m, num_n, 1] 153 | 154 | if self.normalize_attention: 155 | psi = torch.sum(attention, 
axis=2).squeeze(-1) 156 | psi[psi == 0.] = 1 157 | psi = psi.unsqueeze(-1).unsqueeze(-1) 158 | else: 159 | psi = 1.0 160 | else: 161 | attention = torch.ones(*relative_field.shape[0:3],1) 162 | 163 | if self.normalize_attention: 164 | psi = torch.sum(attention, axis=2).squeeze(-1) 165 | psi[psi == 0.] = 1 166 | psi = psi.unsqueeze(-1).unsqueeze(-1) 167 | else: 168 | psi = 1.0 169 | 170 | attention_field_feat = field_feat.unsqueeze(1)*attention.unsqueeze(-1) 171 | # attention_field_feat: [batch, num_m, num_n, c_in, 2] 172 | 173 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat) 174 | # out: [batch, num_m, c_out, 2] 175 | 176 | return out / psi 177 | 178 | def forward( 179 | self, field, center, field_feat, 180 | field_mask, ctr_feat=None 181 | ): 182 | out = self.ContinuousConv( 183 | field, center, field_feat, field_mask, 184 | ctr_feat 185 | ) 186 | return out 187 | 188 | 189 | class EquiCtsConv2d(EquiCtsConvBase): 190 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, matrix_dim=2, 191 | use_attention=True, normalize_attention=True): 192 | super(EquiCtsConv2d, self).__init__() 193 | self.num_theta = num_theta 194 | self.num_radii = num_radii 195 | 196 | kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim) 197 | self.register_buffer('kernel_basis_outer', kernel_basis_outer) 198 | self.register_buffer('kernel_bullseye', kernel_bullseye) 199 | 200 | outer_weights = torch.rand(in_channels, out_channels, num_radii, matrix_dim, matrix_dim) 201 | outer_weights -= 0.5 202 | k = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float)) 203 | outer_weights *= 1 * k 204 | self.outer_weights = torch.nn.parameter.Parameter(outer_weights) 205 | 206 | bullseye_weights = torch.rand(in_channels, out_channels) 207 | bullseye_weights -= 0.5 208 | bullseye_weights *= 1 * k 209 | self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights) 210 | 211 | self.radius = radius 212 | 213 | self.use_attention = use_attention 214 | self.normalize_attention = normalize_attention 215 | 216 | def computeKernel(self): 217 | # print("[r, d, d, r, theta, d, d] ",self.kernel_basis.shape) 218 | # print("[c_in,c_out, r, d , d]", self.weights.shape) 219 | kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) + 220 | torch.einsum('rtij,xy->yxrtij',self.kernel_bullseye,self.bullseye_weights)) 221 | return kernel 222 | 223 | def GenerateKernelBasis(self, r, theta, matrix_dim=2): 224 | """ 225 | output: KB : [r+1, d, d, r+1, theta, d, d] 226 | KB_bullseye : 227 | """ 228 | d = matrix_dim 229 | KB_outer = torch.zeros(r, d, d, r+1, theta, d, d, requires_grad=False) 230 | K_bullseye = self.GenerateKernelBullseyeElement(r+1, theta, d) 231 | 232 | for i in range(d): 233 | for j in range(d): 234 | for r1 in range(0, r): 235 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d) 236 | 237 | return KB_outer, K_bullseye 238 | 239 | def GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2): 240 | """ 241 | output: K: [r, theta, d, d] 242 | """ 243 | d = matrix_dim 244 | K = torch.zeros(r, theta, d, d, requires_grad=False) 245 | K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d) 246 | return K 247 | 248 | def GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2): 249 | # d = matrix_dim 250 | # 0 <= i,j <= d-1 251 | # C = kernelcolumn: [theta, d, d] 252 | # C[0,:,:] = 0 253 | # C[0,i,j] = 1 254 | # for k in range(1,theta): 255 | # 
C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta) 256 | # # K[g v] = g K[v] g^{-1} 257 | d = matrix_dim 258 | C = torch.zeros(theta, d, d, requires_grad=False) 259 | C[0,i,j] = 1 260 | # TODO: rho 1 -> rho n 261 | for k in range(1, theta): 262 | theta_i = torch.tensor(k*2*math.pi/theta) 263 | C[k] = self.RotMat(theta_i).matmul(C[0]).matmul(self.RotMat(-theta_i)) 264 | return C 265 | 266 | def GenerateKernelBullseyeElement(self, r, theta, matrix_dim=2): 267 | """ 268 | output: K: [r, theta, d, d] 269 | """ 270 | d = matrix_dim 271 | K = torch.zeros(r, theta, d, d, requires_grad=False) 272 | K[0] = self.GenerateKernelBullseyeElementColumn(theta, d) 273 | return K 274 | 275 | def GenerateKernelBullseyeElementColumn(self, theta, matrix_dim=2): 276 | d = matrix_dim 277 | C = torch.zeros(theta, d, d, requires_grad=False) 278 | C[:,0,0] = 1 279 | C[:,1,1] = 1 280 | return C 281 | 282 | 283 | class EquiCtsConv2dRegToRho1(EquiCtsConvBase): 284 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2, 285 | use_attention=True, normalize_attention=True): 286 | super(EquiCtsConv2dRegToRho1, self).__init__() 287 | self.num_theta = num_theta 288 | self.num_radii = num_radii 289 | self.k = k 290 | 291 | RegToRho1Mat = EquiLinearRegToRho1(k).RegToRho1 #[2,k] 292 | self.register_buffer('RegToRho1Mat', RegToRho1Mat) 293 | 294 | 295 | kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim) 296 | self.register_buffer('kernel_basis_outer', kernel_basis_outer) 297 | self.register_buffer('kernel_bullseye', kernel_bullseye) 298 | 299 | outer_weights = torch.rand(in_channels, out_channels, num_radii, matrix_dim, k) 300 | outer_weights -= 0.5 301 | scale_norm = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float)) 302 | outer_weights *= 1 * scale_norm 303 | self.outer_weights = torch.nn.parameter.Parameter(outer_weights) 304 | 305 | bullseye_weights = torch.rand(in_channels, out_channels) 306 | bullseye_weights -= 0.5 307 | bullseye_weights *= 1 * scale_norm 308 | self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights) 309 | 310 | self.radius = radius 311 | 312 | self.use_attention = use_attention 313 | self.normalize_attention = normalize_attention 314 | 315 | def computeKernel(self): 316 | # print("[r, d, d, r, theta, d, k] ",self.kernel_basis.shape) 317 | # print("[c_in,c_out, r, d, k]", self.weights.shape) 318 | kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) + 319 | torch.einsum('rtij,xy->yxrtij',self.kernel_bullseye,self.bullseye_weights)) 320 | return kernel 321 | 322 | def GenerateKernelBasis(self, r, theta, matrix_dim=2): 323 | """ 324 | output: KB : [r+1, d, k, r+1, theta, d, k] 325 | KB_bullseye : 326 | """ 327 | d = matrix_dim 328 | k = self.k 329 | 330 | KB_outer = torch.zeros(r, d, k, r+1, theta, d, k, requires_grad=False) 331 | K_bullseye = self.GenerateKernelBullseyeElement(r+1, theta, d) 332 | 333 | for i in range(d): 334 | for j in range(k): 335 | for r1 in range(0, r): 336 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d) 337 | 338 | return KB_outer, K_bullseye 339 | 340 | def GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2): 341 | """ 342 | output: K: [r, theta, d, k] 343 | """ 344 | d = matrix_dim 345 | k = self.k 346 | 347 | K = torch.zeros(r, theta, d, k, requires_grad=False) 348 | K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d) 349 | return K 350 | 351 | def 
GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2): 352 | # d = matrix_dim 353 | # 0 <= i,j <= d-1 354 | # C = kernelcolumn: [theta, d, d] 355 | # C[0,:,:] = 0 356 | # C[0,i,j] = 1 357 | # for k in range(1,theta): 358 | # C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta) 359 | # # K[g v] = g K[v] g^{-1} 360 | d = matrix_dim 361 | k = self.k 362 | 363 | C = torch.zeros(theta, d, k, requires_grad=False) 364 | C[0,i,j] = 1 365 | # TODO: rho 1 -> rho n 366 | for ind in range(1, theta): 367 | theta_i = torch.tensor(ind*2*math.pi/theta) 368 | C[ind] = self.Rho1RotMat(theta_i).matmul(C[0]).matmul(self.RegRotMat(-theta_i.numpy(), k)) 369 | return C 370 | 371 | def GenerateKernelBullseyeElement(self, r, theta, matrix_dim=2): 372 | """ 373 | output: K: [r, theta, d, k] 374 | """ 375 | d = matrix_dim 376 | k = self.k 377 | 378 | K = torch.zeros(r, theta, d, k, requires_grad=False) 379 | K[0] = self.GenerateKernelBullseyeElementColumn(theta, d) 380 | return K 381 | 382 | def GenerateKernelBullseyeElementColumn(self, theta, matrix_dim=2): 383 | d = matrix_dim 384 | k = self.k 385 | 386 | C = torch.zeros(theta, d, k, requires_grad=False) 387 | C[:] = self.RegToRho1Mat 388 | return C 389 | 390 | 391 | class EquiCtsConv2dRho1ToReg(EquiCtsConvBase): 392 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2, 393 | use_attention=True, normalize_attention=True): 394 | super(EquiCtsConv2dRho1ToReg, self).__init__() 395 | self.num_theta = num_theta 396 | self.num_radii = num_radii 397 | self.k = k 398 | 399 | Rho1ToRegMat = EquiLinearRho1ToReg(k).Rho1ToReg #[k,2] 400 | self.register_buffer('Rho1ToRegMat', Rho1ToRegMat) 401 | 402 | kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim) 403 | self.register_buffer('kernel_basis_outer', kernel_basis_outer) 404 | self.register_buffer('kernel_bullseye', kernel_bullseye) 405 | 406 | outer_weights = torch.rand(in_channels, out_channels, num_radii, k, matrix_dim) 407 | outer_weights -= 0.5 408 | scale_norm = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float)) 409 | outer_weights *= 1 * scale_norm 410 | self.outer_weights = torch.nn.parameter.Parameter(outer_weights) 411 | 412 | bullseye_weights = torch.rand(in_channels, out_channels) 413 | bullseye_weights -= 0.5 414 | bullseye_weights *= 1 * scale_norm 415 | self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights) 416 | 417 | self.radius = radius 418 | 419 | self.use_attention = use_attention 420 | self.normalize_attention = normalize_attention 421 | 422 | def computeKernel(self): 423 | # print("[r, d, d, r, theta, k, d] ",self.kernel_basis.shape) 424 | # print("[c_in,c_out, r, k, d]", self.weights.shape) 425 | kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) + 426 | torch.einsum('rtij,xy->yxrtij',self.kernel_bullseye,self.bullseye_weights)) 427 | return kernel 428 | 429 | def GenerateKernelBasis(self, r, theta, matrix_dim=2): 430 | """ 431 | output: KB : [r+1, k, d, r+1, theta, k, d] 432 | KB_bullseye : 433 | """ 434 | d = matrix_dim 435 | k = self.k 436 | 437 | KB_outer = torch.zeros(r, k, d, r+1, theta, k, d, requires_grad=False) 438 | K_bullseye = self.GenerateKernelBullseyeElement(r+1, theta, d) 439 | 440 | for i in range(k): 441 | for j in range(d): 442 | for r1 in range(0, r): 443 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d) 444 | 445 | return KB_outer, K_bullseye 446 | 447 | def 
GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2):
448 |         """
449 |         output: K: [r, theta, k, d]
450 |         """
451 |         d = matrix_dim
452 |         k = self.k
453 | 
454 |         K = torch.zeros(r, theta, k, d, requires_grad=False)
455 |         K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d)
456 |         return K
457 | 
458 |     def GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2):
459 |         # d = matrix_dim
460 |         # 0 <= i,j <= d-1
461 |         # C = kernelcolumn: [theta, d, d]
462 |         # C[0,:,:] = 0
463 |         # C[0,i,j] = 1
464 |         # for k in range(1,theta):
465 |         #     C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta)
466 |         # # K[g v] = g K[v] g^{-1}
467 |         d = matrix_dim
468 |         k = self.k
469 | 
470 |         C = torch.zeros(theta, k, d, requires_grad=False)
471 |         C[0,i,j] = 1
472 |         # TODO: rho 1 -> rho n
473 |         for ind in range(1, theta):
474 |             theta_i = torch.tensor(ind*2*math.pi/theta)
475 |             C[ind] = self.RegRotMat(theta_i.numpy(), k).matmul(C[0]).matmul(self.Rho1RotMat(-theta_i))
476 |         return C
477 | 
478 |     def GenerateKernelBullseyeElement(self, r, theta, matrix_dim=2):
479 |         """
480 |         output: K: [r, theta, k, d]
481 |         """
482 |         d = matrix_dim
483 |         k = self.k
484 | 
485 |         K = torch.zeros(r, theta, k, d, requires_grad=False)
486 |         K[0] = self.GenerateKernelBullseyeElementColumn(theta, d)
487 |         return K
488 | 
489 |     def GenerateKernelBullseyeElementColumn(self, theta, matrix_dim=2):
490 |         d = matrix_dim
491 |         k = self.k
492 | 
493 |         C = torch.zeros(theta, k, d, requires_grad=False)
494 |         C[:] = self.Rho1ToRegMat
495 |         return C
496 | 
497 | 
498 | class EquiCtsConv2dRegToReg(EquiCtsConvBase):
499 |     def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2, 
500 |                  use_attention=True, normalize_attention=True):
501 |         super(EquiCtsConv2dRegToReg, self).__init__()
502 |         self.num_theta = num_theta
503 |         self.num_radii = num_radii
504 |         self.k = k
505 | 
506 | 
507 |         kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim)
508 |         self.register_buffer('kernel_basis_outer', kernel_basis_outer)
509 |         self.register_buffer('kernel_bullseye', kernel_bullseye)
510 | 
511 |         outer_weights = torch.rand(in_channels, out_channels, num_radii, k, k)
512 |         outer_weights -= 0.5
513 |         scale_norm = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float))
514 |         outer_weights *= 1 * scale_norm
515 |         self.outer_weights = torch.nn.parameter.Parameter(outer_weights)
516 | 
517 | 
518 | 
519 |         bullseye_weights = torch.rand(in_channels, out_channels, k)
520 |         bullseye_weights -= 0.5
521 |         bullseye_weights *= 1 * scale_norm
522 |         self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights)
523 | 
524 |         self.radius = radius
525 | 
526 |         self.use_attention = use_attention
527 |         self.normalize_attention = normalize_attention
528 | 
529 |     def computeKernel(self):
530 |         # print("[r, d, d, r, theta, d, k] ",self.kernel_basis.shape)
531 |         # print("[c_in,c_out, r, d, k]", self.weights.shape)
532 |         kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) + 
533 |                   torch.einsum('lrtij,xyl->yxrtij',self.kernel_bullseye,self.bullseye_weights))
534 |         return kernel
535 | 
536 |     def GenerateKernelBasis(self, r, theta, matrix_dim=2):
537 |         """
538 |         output: KB : [r, k, k, r+1, theta, k, k]
539 |                 KB_bullseye : 
540 |         """
541 |         d = matrix_dim
542 |         k = self.k
543 | 
544 |         KB_outer = torch.zeros(r, k, k, r+1, theta, k, k, requires_grad=False)
545 |         K_bullseye = self.GenerateKernelBullseyeBasis(r+1, theta, d)
546 | 
547 |         for i in range(k):
548 |             for j
in range(k): 549 | for r1 in range(0, r): 550 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d) 551 | 552 | return KB_outer, K_bullseye 553 | 554 | def GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2): 555 | """ 556 | output: K: [r, theta, k, k] 557 | """ 558 | d = matrix_dim 559 | k = self.k 560 | 561 | K = torch.zeros(r, theta, k, k, requires_grad=False) 562 | K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d) 563 | return K 564 | 565 | def GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2): 566 | # d = matrix_dim 567 | # 0 <= i,j <= d-1 568 | # C = kernelcolumn: [theta, d, d] 569 | # C[0,:,:] = 0 570 | # C[0,i,j] = 1 571 | # for k in range(1,theta): 572 | # C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta) 573 | # # K[g v] = g K[v] g^{-1} 574 | d = matrix_dim 575 | k = self.k 576 | 577 | C = torch.zeros(theta, k, k, requires_grad=False) 578 | C[0,i,j] = 1 579 | # TODO: rho 1 -> rho n 580 | for ind in range(1, theta): 581 | theta_i = torch.tensor(ind*2*math.pi/theta) 582 | C[ind] = self.RegRotMat(theta_i.numpy(), k).matmul(C[0]).matmul(self.RegRotMat(-theta_i.numpy(), k)) 583 | return C 584 | 585 | def GenerateKernelBullseyeBasis(self, r, theta, matrix_dim=2): 586 | """ 587 | output: K: [k, r, theta, k, k] 588 | """ 589 | d = matrix_dim 590 | k = self.k 591 | 592 | K = torch.zeros(k, r, theta, k, k, requires_grad=False) 593 | for l in range(k): 594 | K[l,0] = self.GenerateKernelBullseyeElementColumn(theta, l, d) 595 | return K 596 | 597 | def GenerateKernelBullseyeElementColumn(self, theta, l, matrix_dim=2): 598 | d = matrix_dim 599 | k = self.k 600 | 601 | first_col = torch.zeros(k) 602 | first_col[l] = 1. 603 | C = torch.zeros(theta, k, k, requires_grad=False) 604 | C[:] = torch.stack([torch.roll(first_col, i, 0) for i in range(0,self.k)],-1) 605 | return C 606 | --------------------------------------------------------------------------------
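
Two usage sketches follow; they are editorial additions, not files from the repository.

`preprocess_data.py` writes each scene as a pickled dict of one-element lists keyed by feature name. Below is a minimal sketch for inspecting one preprocessed scene, assuming `dataset_path` points at the same `argoverse_forecasting` root used by the script; the filename `12345.pkl` stands in for a real `scene_idx`:

```python
import os
import pickle

dataset_path = '/path/to/argoverse_forecasting'  # assumption: same root as in preprocess_data.py
scene_file = os.path.join(dataset_path, 'train/lane_data', '12345.pkl')  # hypothetical scene id

with open(scene_file, 'rb') as f:
    scene = pickle.load(f)

# Keys used in preprocess_data.py include 'scene_idx', 'city', 'car_mask',
# 'pos0', 'pos_2s', and, when lane centerlines overlap the scene's bounding
# box, 'lane' and 'lane_norm' (per-point lane direction vectors).
print(sorted(scene.keys()))
print(scene['lane'][0].shape, scene['lane_norm'][0].shape)
```

Note that `process_func` leaves `'lane'` and `'lane_norm'` unset when no centerline intersects the bounding box from `get_max_min`, so downstream code should not assume those keys exist.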
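The rotation equivariance that `EquiCtsConv2d` is built around can also be spot-checked numerically. The sketch below is an assumption-laden test rather than part of the codebase: it assumes `models/` is on `PYTHONPATH` (the files import each other as top-level modules), uses the constructor and `forward` signatures from `EquiCtsConv.py`, and rotates by one angular grid step `2*pi/num_theta` so the polar-grid sampling lines up with the kernel basis; expect agreement only up to floating-point and interpolation error.

```python
import math
import torch
from EquiCtsConv import EquiCtsConv2d  # assumes models/ is on PYTHONPATH

torch.manual_seed(0)
batch, num_n, num_m, c_in, c_out = 1, 8, 4, 3, 5
num_theta = 16

conv = EquiCtsConv2d(in_channels=c_in, out_channels=c_out,
                     radius=2.0, num_radii=4, num_theta=num_theta)

field = torch.randn(batch, num_n, 2)             # neighbor positions
center = torch.randn(batch, num_m, 2)            # query positions
field_feat = torch.randn(batch, num_n, c_in, 2)  # rho-1 (vector) features
field_mask = torch.ones(batch, num_n, 1)

# Rotation by one angular grid step of the kernel.
t = torch.tensor(2 * math.pi / num_theta)
R = torch.tensor([[torch.cos(t), -torch.sin(t)],
                  [torch.sin(t),  torch.cos(t)]])

with torch.no_grad():
    out = conv(field, center, field_feat, field_mask)
    out_rot = conv(field @ R.T, center @ R.T, field_feat @ R.T, field_mask)

# Equivariance: rotating all inputs should rotate the output features,
# i.e. out_rot should approximately equal out @ R.T.
print(torch.max(torch.abs(out_rot - out @ R.T)))
```

If the layer behaves as intended, the printed residual should be near zero. The same check applies to the `Rel*` variants in `RelEquiCtsConv.py` once a `ctr_feat` tensor of shape `[batch, num_m, c_in, 2]` is supplied (it must be rotated by `R` as well).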