├── requirements.txt
├── no_title_thumbnail.png
├── datasets
│   ├── __pycache__
│   │   ├── helper.cpython-36.pyc
│   │   └── argoverse_pickle_loader.cpython-36.pyc
│   ├── argoverse_lane_loader.py
│   ├── helper.py
│   └── preprocess_data.py
├── scripts
│   ├── train_utils.py
│   ├── evaluate_network.py
│   └── train.py
├── README.md
└── models
    ├── EquiLinear.py
    ├── rho1_ECCO.py
    ├── rho_reg_ECCO.py
    ├── RelEquiCtsConv.py
    └── EquiCtsConv.py
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.18.3
2 | pandas>=1.0.3
3 | torch==1.5.0
--------------------------------------------------------------------------------
/no_title_thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/ECCO/HEAD/no_title_thumbnail.png
--------------------------------------------------------------------------------
/datasets/__pycache__/helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/ECCO/HEAD/datasets/__pycache__/helper.cpython-36.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/argoverse_pickle_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rose-STL-Lab/ECCO/HEAD/datasets/__pycache__/argoverse_pickle_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/scripts/train_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import gc
3 |
4 |
5 | def euclidean_distance(a, b, epsilon=1e-9):
6 | return torch.sqrt(torch.sum((a - b)**2, axis=-1) + epsilon)
7 |
8 |
9 | def loss_fn(pr_pos, gt_pos, num_fluid_neighbors, car_mask):
10 | gamma = 0.5
11 | neighbor_scale = 1 / 40
12 | importance = torch.exp(-neighbor_scale * num_fluid_neighbors)
13 | dist = euclidean_distance(pr_pos, gt_pos)**gamma
14 | mask_dist = dist * car_mask
15 | batch_losses = torch.mean(importance * mask_dist, axis=-1)
16 | # print(batch_losses)
17 | return torch.sum(batch_losses)
18 |
19 | def clean_cache(device):
20 | if device == torch.device('cuda'):
21 | torch.cuda.empty_cache()
22 | if device == torch.device('cpu'):
23 | # gc.collect()
24 | pass
25 |
26 | def get_lr(optimizer):
27 | for param_group in optimizer.param_groups:
28 | return param_group['lr']
29 |
30 | def unsqueeze_n(tensor, n):
31 | for i in range(n):
32 | tensor = tensor.unsqueeze(-1)
33 | return tensor
34 |
35 |
36 | def normalize_input(tensor_dict, scale, train_window):
37 | pos_keys = (['pos' + str(i) for i in range(train_window + 1)] +
38 | ['pos_2s', 'lane', 'lane_norm'])
39 | vel_keys = (['vel' + str(i) for i in range(train_window + 1)] +
40 | ['vel_2s'])
41 | max_pos = torch.cat([torch.max(tensor_dict[pk].reshape(tensor_dict[pk].shape[0], -1), axis=-1)
42 | .values.unsqueeze(-1)
43 | for pk in pos_keys], axis=-1)
44 | max_pos = torch.max(max_pos, axis=-1).values
45 |
46 | for pk in pos_keys:
47 | tensor_dict[pk][...,:2] = (tensor_dict[pk][...,:2] - unsqueeze_n(max_pos, len(tensor_dict[pk].shape)-1)) / scale
48 |
49 | for vk in vel_keys:
50 | tensor_dict[vk] = tensor_dict[vk] / scale
51 |
52 | return tensor_dict, max_pos
53 |
54 |
--------------------------------------------------------------------------------
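
A minimal usage sketch of `loss_fn` above, with made-up tensor shapes (2 scenes, 60 cars); the mask and neighbor-count bookkeeping mirrors how `scripts/train.py` calls it:

```python
import torch
from train_utils import loss_fn  # assuming scripts/ is on sys.path

pr_pos = torch.randn(2, 60, 2)   # predicted positions [batch, num_cars, 2]
gt_pos = torch.randn(2, 60, 2)   # ground-truth positions
car_mask = torch.ones(2, 60)     # 1.0 for real cars, 0.0 for padding
num_neighbors = torch.sum(car_mask, dim=-1, keepdim=True) - 1

loss = loss_fn(pr_pos, gt_pos, num_neighbors, car_mask)
print(loss)  # scalar loss, summed over the batch
```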
/README.md:
--------------------------------------------------------------------------------
1 | # Trajectory Prediction using Equivariant Continuous Convolution (ECCO)
2 |
3 | This is the codebase for the ICLR 2021 paper [Trajectory Prediction using Equivariant Continuous Convolution](https://arxiv.org/abs/2010.11344), by Robin Walters, Jinxi Li and Rose Yu.
4 |
5 | 
6 |
7 | ## Installation
8 |
9 | This codebase requires Python 3.6.6+. To use the Argoverse dataset, the [argoverse-api](https://github.com/argoai/argoverse-api) is required; we recommend following its guide to install the complete API and datasets. Other requirements include:
10 | - numpy>=1.18.3
11 | - pandas>=1.0.3
12 | - torch==1.5.0
13 |
14 | Dependencies can be installed with:
15 | ```bash
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ## Data Preparation
20 |
21 | The original data can be downloaded from [argoverse](https://www.argoverse.org/data.html). To generate the training and validation data:
22 | 1. Set the path to `argoverse_forecasting` in the `datasets/preprocess_data.py` script (see the snippet below).
23 | 2. Run the script from the `datasets` folder:
24 | ```bash
25 | python preprocess_data.py
26 | ```
27 | The data will be stored in `path/to/argoverse_forecasting/train(val)/lane_data`.
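
Concretely, the path is the `dataset_path` variable near the top of `datasets/preprocess_data.py`; it should point at the folder that contains the `train` and `val` subfolders:

```python
# datasets/preprocess_data.py
dataset_path = '/path/to/dataset/argoverse_forecasting/'
```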
28 |
29 | ## Data Download
30 |
31 | If you want to skip the data generation step, a link to the preprocessed data will be provided soon.
32 |
33 | ## Model Training and Evaluation
34 |
35 | Below are the commands to train each model. Evaluation is run after training when the `--evaluation` flag is passed.
36 |
37 | For ρ1-ECCO, run the following command:
38 | ```bash
39 | python train.py --dataset_path /path/to/argoverse_forecasting/ --rho1 --model_name rho_1_ecco --train --evaluation
40 | ```
41 |
42 | For ρreg-ECCO, run the following command:
43 | ```bash
44 | python train.py --dataset_path /path/to/argoverse_forecasting/ --rho-reg --model_name rho_reg_ecco --train --evaluation
45 | ```
46 |
47 | For the baseline evaluation, you can refer to the [Argoverse Official Baseline](https://github.com/jagjeet-singh/argoverse-forecasting). Note: the constant-velocity baseline is evaluated on the validation set (excluding scenes with more than 60 cars), using the velocity at the final observed timestamp as the constant velocity.
48 |
49 | ## Citation
50 |
51 | If you find this repository useful in your research, please cite our paper:
52 | ```
53 | @article{Walters2021ECCO,
54 | title={Trajectory Prediction using Equivariant Continuous Convolution},
55 | author={Robin Walters and Jinxi Li and Rose Yu},
56 | journal={International Conference on Learning Representations},
57 | year={2021},
58 | }
59 | ```
--------------------------------------------------------------------------------
/datasets/argoverse_lane_loader.py:
--------------------------------------------------------------------------------
1 | "Functions loading the .pkl version preprocessed data"
2 | from glob import glob
3 | import pickle
4 | import os
5 | import numpy as np
6 | from argoverse.map_representation.map_api import ArgoverseMap
7 | from torch.utils.data import IterableDataset, DataLoader
8 |
9 |
10 | class ArgoverseDataset(IterableDataset):
11 | def __init__(self, data_path: str, transform=None,
12 | max_lane_nodes=650, min_lane_nodes=0, shuffle=True):
13 | super(ArgoverseDataset, self).__init__()
14 | self.data_path = data_path
15 | self.transform = transform
16 | self.pkl_list = glob(os.path.join(self.data_path, '*'))
17 | if shuffle:
18 | np.random.shuffle(self.pkl_list)
19 | else:
20 | self.pkl_list.sort()
21 | self.max_lane_nodes = max_lane_nodes
22 | self.min_lane_nodes = min_lane_nodes
23 |
24 | def __len__(self):
25 | return len(self.pkl_list)
26 |
27 | def __iter__(self):
28 | # pkl_path = self.pkl_list[idx]
29 | for pkl_path in self.pkl_list:
30 | with open(pkl_path, 'rb') as f:
31 | data = pickle.load(f)
32 | # data = {k:v[0] for k, v in data.items()}
33 | lane_mask = np.zeros(self.max_lane_nodes, dtype=np.float32)
34 | lane_mask[:len(data['lane'][0])] = 1.0
35 | data['lane_mask'] = [lane_mask]
36 |
37 | if data['lane'][0].shape[0] > self.max_lane_nodes:
38 | continue
39 |
40 | if data['lane'][0].shape[0] < self.min_lane_nodes:
41 | continue
42 |
43 | data['lane'] = [self.expand_particle(data['lane'][0], self.max_lane_nodes, 0)]
44 | data['lane_norm'] = [self.expand_particle(data['lane_norm'][0], self.max_lane_nodes, 0)]
45 |
46 | if self.transform:
47 | data = self.transform(data)
48 |
49 | yield data
50 |
51 | @classmethod
52 | def expand_particle(cls, arr, max_num, axis, value_type='int'):
53 | dummy_shape = list(arr.shape)
54 | dummy_shape[axis] = max_num - arr.shape[axis]
55 | dummy = np.zeros(dummy_shape)
56 | if value_type == 'str':
57 |             dummy = np.array(['dummy' + str(i) for i in range(np.prod(dummy_shape))]).reshape(dummy_shape)
58 | return np.concatenate([arr, dummy], axis=axis)
59 |
60 |
61 | def cat_key(data, key):
62 | result = []
63 | for d in data:
64 | result = result + d[key]
65 | return result
66 |
67 |
68 | def dict_collate_func(data):
69 | keys = data[0].keys()
70 | data = {key: cat_key(data, key) for key in keys}
71 | return data
72 |
73 |
74 | def read_pkl_data(data_path: str, batch_size: int,
75 | shuffle: bool=False, repeat: bool=False, **kwargs):
76 | dataset = ArgoverseDataset(data_path=data_path, shuffle=shuffle, **kwargs)
77 | loader = DataLoader(dataset, batch_size=batch_size, collate_fn=dict_collate_func)
78 | if repeat:
79 | while True:
80 | for data in loader:
81 | yield data
82 | else:
83 | for data in loader:
84 | yield data
85 |
86 |
--------------------------------------------------------------------------------
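
A usage sketch for `read_pkl_data` above (the path is hypothetical; `lane_data` folders are produced by `datasets/preprocess_data.py`). Because the collate function concatenates lists, each batch entry is a list of per-scene numpy arrays rather than a stacked tensor:

```python
from argoverse_lane_loader import read_pkl_data  # assuming datasets/ is on sys.path

loader = read_pkl_data('/path/to/argoverse_forecasting/train/lane_data',
                       batch_size=16, shuffle=True, repeat=False)
batch = next(loader)
print(batch.keys())        # 'pos0', 'vel0', 'lane', 'lane_mask', 'car_mask', ...
print(len(batch['pos0']))  # batch_size entries, one [max_car_num, 2] array each
```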
/models/EquiLinear.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 |
8 | # rho1 --> rho1
9 | class EquiLinear(nn.Module):
10 | def __init__(self, in_features, out_features):
11 | super(EquiLinear, self).__init__()
12 | self.linear = nn.Linear(in_features, out_features, bias=False)
13 |
14 | def forward(self, field_feat):
15 | """
16 | inputs:
17 | @field_feat: [batch, num_part, in_feat, 2]
18 |
19 | output:
20 | [batch, num_part, out_feat, 2]
21 | """
22 | return self.linear(field_feat.permute(0, 1, 3, 2)).permute(0, 1, 3, 2)
23 |
24 |
25 | # Reg --> Reg
26 | class EquiLinearRegToReg(nn.Module):
27 | def __init__(self, in_features, out_features, k):
28 | super(EquiLinearRegToReg, self).__init__()
29 | self.k = k
30 | self.weights = nn.parameter.Parameter(torch.rand(in_features, out_features, k) / in_features)
31 | # print(self.weights)
32 | # kernel = self.update_kernel()
33 | # print(self.kernel)
34 |
35 | def update_kernel(self):
36 | # i or -i ??? stack -2 or -1 ??? torch.flip ???
37 | return torch.stack([torch.roll(self.weights, i, 2) for i in range(0,self.k)],-1)
38 |
39 | def forward(self, field_feat):
40 | """
41 | inputs:
42 | k: int -- number of slices of the circle for regular rep
43 | @field_feat: [batch, num_part, in_feat, k]
44 | kernel: [in_feat, out_feat, k, k]
45 |
46 | f*k(\theta) = \sum_\psi K(\psi)f(\theta - \psi)
47 |
48 | output:
49 | [batch, num_part, out_feat, k]
50 | """
51 | # x or y ???
52 | kernel = self.update_kernel()
53 | return torch.einsum('ijyx,...ix->...jy',kernel,field_feat)
54 |
55 |
56 | # Rho1 --> Reg
57 | class EquiLinearRho1ToReg(nn.Module):
58 | def __init__(self, k):
59 | super(EquiLinearRho1ToReg, self).__init__()
60 | self.k = k
61 | SinVec = torch.tensor([math.sin(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False)
62 | CosVec = torch.tensor([math.cos(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False)
63 | Rho1ToReg = torch.stack([CosVec,SinVec],1) #[k,2]
64 | self.register_buffer('Rho1ToReg', Rho1ToReg)
65 |
66 | def forward(self, field_feat):
67 | """
68 | k: int -- number of slices of the circle for regular rep
69 | inputs:
70 | @field_feat: [batch, num_part, in_feat, 2]
71 | output: [batch, num_part, in_feat, k]
72 |
73 |         (a,b) --> a Cos + b Sin
74 | """
75 | return torch.einsum('yx,...x->...y',self.Rho1ToReg, field_feat)
76 |
77 |
78 | # Reg --> Rho1
79 | class EquiLinearRegToRho1(nn.Module):
80 | def __init__(self, k):
81 | super(EquiLinearRegToRho1, self).__init__()
82 | self.k = k
83 | SinVec = torch.tensor([math.sin(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False)
84 | CosVec = torch.tensor([math.cos(i * 2 * math.pi / self.k) for i in range(k)],requires_grad=False)
85 | RegToRho1 = torch.stack([CosVec,SinVec],0) #[2,k]
86 | self.register_buffer('RegToRho1', RegToRho1)
87 |
88 | def forward(self, field_feat):
89 | '''
90 | k: int -- number of slices of the circle for regular rep
91 | inputs:
92 | @field_feat: [batch, num_part, in_feat, k]
93 | output:
94 | retval: [batch, num_part, in_feat, 2]
95 |
96 | f is a function on circle divided into k parts
97 | f(i) means f(2\pi *i /k)
98 |         f --> \sum_{i=0}^{k-1} ( f(i) cos(2\pi i/k), f(i) sin(2\pi i/k) )
99 |         This is a Fourier transform (the first-frequency component).
100 | '''
101 | return torch.einsum('yx,...x->...y',self.RegToRho1, field_feat)
102 |
--------------------------------------------------------------------------------
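
A quick sanity check of the equivariance encoded by `EquiLinearRho1ToReg` (a sketch, assuming `models/` is on `sys.path`): rotating the ρ1 input by one slice of the circle, 2π/k, cyclically shifts the regular-representation output by one slot:

```python
import math
import torch
from EquiLinear import EquiLinearRho1ToReg

k = 8
layer = EquiLinearRho1ToReg(k)

x = torch.randn(1, 4, 3, 2)               # [batch, num_part, in_feat, 2]
phi = 2 * math.pi / k
rot = torch.tensor([[math.cos(phi), -math.sin(phi)],
                    [math.sin(phi),  math.cos(phi)]])
x_rot = torch.einsum('ij,...j->...i', rot, x)

# rotating the input by 2*pi/k == rolling the output by one slot
assert torch.allclose(layer(x_rot), torch.roll(layer(x), 1, dims=-1), atol=1e-5)
```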
/scripts/evaluate_network.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import sys
4 | import numpy as np
5 | import time
6 | import importlib
7 | import torch
8 | import pickle
9 |
10 |
11 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
12 | from train_utils import *
13 |
14 |
15 | def get_agent(pr, gt, pr_id, gt_id, agent_id, device='cpu'):
16 |
17 | pr_agent = pr[pr_id == agent_id,:]
18 | gt_agent = gt[gt_id == agent_id,:]
19 |
20 | return pr_agent, gt_agent
21 |
22 |
23 | def evaluate(model, val_dataset, train_window=3, max_iter=2500, device='cpu', start_iter=0,
24 | batch_size=32):
25 |
26 | print('evaluating.. ', end='', flush=True)
27 |
28 | count = 0
29 | prediction_gt = {}
30 | losses = []
31 | val_iter = iter(val_dataset)
32 |
33 | for i, sample in enumerate(val_dataset):
34 |
35 | if i >= max_iter:
36 | break
37 |
38 | if i < start_iter:
39 | continue
40 |
41 | pred = []
42 | gt = []
43 |
44 | if count % 1 == 0:
45 | print('{}'.format(count + 1), end=' ', flush=True)
46 |
47 | count += 1
48 |
49 | data = {}
50 | convert_keys = (['pos' + str(i) for i in range(31)] +
51 | ['vel' + str(i) for i in range(31)] +
52 | ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])
53 |
54 | for k in convert_keys:
55 | data[k] = torch.tensor(np.stack(sample[k])[...,:2], dtype=torch.float32, device=device)
56 |
57 |
58 | for k in ['track_id' + str(i) for i in range(31)] + ['city', 'agent_id', 'scene_idx']:
59 | data[k] = np.stack(sample[k])
60 |
61 | for k in ['car_mask', 'lane_mask']:
62 | data[k] = torch.tensor(np.stack(sample[k]), dtype=torch.float32, device=device).unsqueeze(-1)
63 |
64 | scenes = data['scene_idx'].tolist()
65 |
66 | data['agent_id'] = data['agent_id'][:,np.newaxis]
67 |
68 | data['car_mask'] = data['car_mask'].squeeze(-1)
69 | accel = torch.zeros(1, 1, 2).to(device)
70 | data['accel'] = accel
71 |
72 | lane = data['lane']
73 | lane_normals = data['lane_norm']
74 | agent_id = data['agent_id']
75 | city = data['city']
76 |
77 | inputs = ([
78 | data['pos_2s'], data['vel_2s'],
79 | data['pos0'], data['vel0'],
80 | data['accel'], None,
81 | data['lane'], data['lane_norm'],
82 | data['car_mask'], data['lane_mask']
83 | ])
84 |
85 | pr_pos1, pr_vel1, states = model(inputs)
86 | gt_pos1 = data['pos1']
87 |
88 | l = 0.5 * loss_fn(pr_pos1, gt_pos1,
89 | torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1))
90 |
91 | pr_agent, gt_agent = get_agent(pr_pos1, data['pos1'],
92 | data['track_id0'],
93 | data['track_id1'],
94 | agent_id, device)
95 | pred.append(pr_agent.unsqueeze(1).detach().cpu())
96 | gt.append(gt_agent.unsqueeze(1).detach().cpu())
97 | del pr_agent, gt_agent
98 | clean_cache(device)
99 |
100 | pos0 = data['pos0']
101 | vel0 = data['vel0']
102 |         for j in range(29):
103 |             pos_enc = torch.unsqueeze(pos0, 2)
104 |             vel_enc = torch.unsqueeze(vel0, 2)
105 |             inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, data['accel'], None,
106 |                       data['lane'], data['lane_norm'], data['car_mask'], data['lane_mask'])
107 |             pos0, vel0 = pr_pos1, pr_vel1
108 |             pr_pos1, pr_vel1, states = model(inputs, states)
109 |             clean_cache(device)
110 | 
111 |             if j < train_window - 1:
112 |                 gt_pos1 = data['pos'+str(j+2)]
113 |                 l += 0.5 * loss_fn(pr_pos1, gt_pos1,
114 |                                    torch.sum(data['car_mask'], dim = -2) - 1, data['car_mask'].squeeze(-1))
115 | 
116 |             pr_agent, gt_agent = get_agent(pr_pos1, data['pos'+str(j+2)],
117 |                                            data['track_id0'],
118 |                                            data['track_id'+str(j+2)],
119 |                                            agent_id, device)
120 |
121 | pred.append(pr_agent.unsqueeze(1).detach().cpu())
122 | gt.append(gt_agent.unsqueeze(1).detach().cpu())
123 |
124 | clean_cache(device)
125 |
126 | losses.append(l)
127 |
128 | predict_result = (torch.cat(pred, axis=1), torch.cat(gt, axis=1))
129 | for idx, scene_id in enumerate(scenes):
130 | prediction_gt[scene_id] = (predict_result[0][idx], predict_result[1][idx])
131 |
132 | total_loss = 128 * torch.sum(torch.stack(losses),axis=0) / max_iter
133 |
134 | result = {}
135 | de = {}
136 |
137 | for k, v in prediction_gt.items():
138 | de[k] = torch.sqrt((v[0][:,0] - v[1][:,0])**2 +
139 | (v[0][:,1] - v[1][:,1])**2)
140 |
141 | ade = []
142 | de1s = []
143 | de2s = []
144 | de3s = []
145 | for k, v in de.items():
146 | ade.append(np.mean(v.numpy()))
147 | de1s.append(v.numpy()[10])
148 | de2s.append(v.numpy()[20])
149 | de3s.append(v.numpy()[-1])
150 |
151 | result['ADE'] = np.mean(ade)
152 | result['ADE_std'] = np.std(ade)
153 | result['DE@1s'] = np.mean(de1s)
154 | result['DE@1s_std'] = np.std(de1s)
155 | result['DE@2s'] = np.mean(de2s)
156 | result['DE@2s_std'] = np.std(de2s)
157 | result['DE@3s'] = np.mean(de3s)
158 | result['DE@3s_std'] = np.std(de3s)
159 |
160 | print(result)
161 | print('done')
162 |
163 | return total_loss, prediction_gt
164 |
165 |
166 |
167 |
168 |
169 |
--------------------------------------------------------------------------------
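
The metric block above reduces each scene's agent trajectory to displacement errors; a self-contained sketch of that arithmetic for one made-up 30-step (3 s at 10 Hz) trajectory:

```python
import numpy as np
import torch

pred = torch.randn(30, 2)  # hypothetical predicted agent trajectory
gt = torch.randn(30, 2)    # hypothetical ground truth

de = torch.sqrt((pred[:, 0] - gt[:, 0])**2 + (pred[:, 1] - gt[:, 1])**2)
ade = np.mean(de.numpy())   # average displacement error over all steps
de_1s = de.numpy()[10]      # displacement error after 1 s (10 Hz data)
de_3s = de.numpy()[-1]      # final displacement error (3 s)
print(ade, de_1s, de_3s)
```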
/datasets/helper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """
4 | A modified visualization function and get_lanes function
5 | """
6 |
7 | import argparse
8 | import os
9 | import shutil
10 | import sys
11 | from collections import defaultdict
12 | from typing import Dict, Optional
13 |
14 | import matplotlib.animation as anim
15 | import matplotlib.lines as mlines
16 | import matplotlib.pyplot as plt
17 | import numpy as np
18 | import pandas as pd
19 | import scipy.interpolate as interp
20 |
21 | from argoverse.map_representation.map_api import ArgoverseMap
22 |
23 | _ZORDER = {"AGENT": 15, "AV": 10, "OTHERS": 5}
24 |
25 |
26 | def get_batch_lane_direction(pos, city, am):
27 | if am is None:
28 | from argoverse.map_representation.map_api import ArgoverseMap
29 | am = ArgoverseMap()
32 | drct_conf = list()
33 |
34 | for ps, c in zip(pos, city):
35 | drct_conf.append(np.array([np.append(*am.get_lane_direction(p[:2], c)) for p in ps]))
36 |
37 | return drct_conf
38 |
39 |
40 | def get_lane_direction(pos, city, am):
41 | if am is None:
42 | from argoverse.map_representation.map_api import ArgoverseMap
43 | am = ArgoverseMap()
46 | drct_conf = np.array([np.append(*am.get_lane_direction(p[:2], city)) for p in pos])
47 |
48 | return drct_conf
49 |
50 |
51 | def interpolate_polyline(polyline: np.ndarray, num_points: int) -> np.ndarray:
52 | duplicates = []
53 | for i in range(1, len(polyline)):
54 | if np.allclose(polyline[i], polyline[i - 1]):
55 | duplicates.append(i)
56 | if polyline.shape[0] - len(duplicates) < 4:
57 | return polyline
58 | if duplicates:
59 | polyline = np.delete(polyline, duplicates, axis=0)
60 | tck, u = interp.splprep(polyline.T, s=0)
61 | u = np.linspace(0.0, 1.0, num_points)
62 | return np.column_stack(interp.splev(u, tck))
63 |
64 |
65 | def get_lanes(df: pd.DataFrame, city_name: str, avm: Optional[ArgoverseMap] = None) -> list:
66 |
67 | # Get API for Argo Dataset map
68 | avm = ArgoverseMap() if avm is None else avm
69 | seq_lane_bbox = avm.city_halluc_bbox_table[city_name]
70 | seq_lane_props = avm.city_lane_centerlines_dict[city_name]
71 |
72 | x_min = min(df["X"])
73 | x_max = max(df["X"])
74 | y_min = min(df["Y"])
75 | y_max = max(df["Y"])
76 |
77 | lane_centerlines = []
78 |
79 | # Get lane centerlines which lie within the range of trajectories
80 | for lane_id, lane_props in seq_lane_props.items():
81 |
82 | lane_cl = lane_props.centerline
83 |
84 | if (
85 | np.min(lane_cl[:, 0]) < x_max
86 | and np.min(lane_cl[:, 1]) < y_max
87 | and np.max(lane_cl[:, 0]) > x_min
88 | and np.max(lane_cl[:, 1]) > y_min
89 | ):
90 | lane_centerlines.append(lane_cl)
91 |
92 | return lane_centerlines
93 |
94 |
95 | def get_all_lanes(city_name: str, avm: Optional[ArgoverseMap] = None) -> list:
96 |
97 | # Get API for Argo Dataset map
98 | avm = ArgoverseMap() if avm is None else avm
99 | seq_lane_bbox = avm.city_halluc_bbox_table[city_name]
100 | seq_lane_props = avm.city_lane_centerlines_dict[city_name]
101 |
102 | lane_centerlines = [lane.centerline for lane in seq_lane_props.values()]
103 |
104 | return lane_centerlines
105 |
106 |
107 | def visualize_trajectory(
108 | df: pd.DataFrame, lane_centerlines: Optional[np.ndarray] = None, show: bool = True, smoothen: bool = False
109 | ) -> None:
110 |
111 | # Seq data
112 | # time_list = np.sort(np.unique(df["TIMESTAMP"].values))
113 | city_name = df["CITY_NAME"].values[0]
114 |
115 | lane_centerlines = get_lanes(df, city_name) if lane_centerlines is None else lane_centerlines
116 |
117 | plt.figure(0, figsize=(8, 7))
118 |
119 | x_min = min(df["X"])
120 | x_max = max(df["X"])
121 | y_min = min(df["Y"])
122 | y_max = max(df["Y"])
123 |
124 | plt.xlim(x_min, x_max)
125 | plt.ylim(y_min, y_max)
126 |
127 |
128 | for lane_cl in lane_centerlines:
129 | plt.plot(lane_cl[:, 0], lane_cl[:, 1], "--", color="grey", alpha=1, linewidth=1, zorder=0)
130 | frames = df.groupby("TRACK_ID")
131 |
132 | plt.xlabel("Map X")
133 | plt.ylabel("Map Y")
134 |
135 | color_dict = {"AGENT": "#d33e4c", "OTHERS": "#13d4f2", "AV": "#007672"}
136 | object_type_tracker: Dict[int, int] = defaultdict(int)
137 |
138 | # Plot all the tracks up till current frame
139 | for group_name, group_data in frames:
140 | object_type = group_data["OBJECT_TYPE"].values[0]
141 |
142 | cor_x = group_data["X"].values
143 | cor_y = group_data["Y"].values
144 |
145 | if smoothen:
146 | polyline = np.column_stack((cor_x, cor_y))
147 | num_points = cor_x.shape[0] * 3
148 | smooth_polyline = interpolate_polyline(polyline, num_points)
149 | cor_x = smooth_polyline[:, 0]
150 | cor_y = smooth_polyline[:, 1]
151 |
152 | plt.plot(
153 | cor_x,
154 | cor_y,
155 | "-",
156 | color=color_dict[object_type],
157 | label=object_type if not object_type_tracker[object_type] else "",
158 | alpha=1,
159 | linewidth=1,
160 | zorder=_ZORDER[object_type],
161 | )
162 |
163 | final_x = cor_x[-1]
164 | final_y = cor_y[-1]
165 |
166 | if object_type == "AGENT":
167 | marker_type = "o"
168 | marker_size = 7
169 | elif object_type == "OTHERS":
170 | marker_type = "o"
171 | marker_size = 7
172 | elif object_type == "AV":
173 | marker_type = "o"
174 | marker_size = 7
175 |
176 | plt.plot(
177 | final_x,
178 | final_y,
179 | marker_type,
180 | color=color_dict[object_type],
181 | label=object_type if not object_type_tracker[object_type] else "",
182 | alpha=1,
183 | markersize=marker_size,
184 | zorder=_ZORDER[object_type],
185 | )
186 |
187 | object_type_tracker[object_type] += 1
188 |
189 | red_star = mlines.Line2D([], [], color="red", marker="*", linestyle="None", markersize=7, label="Agent")
190 | green_circle = mlines.Line2D([], [], color="green", marker="o", linestyle="None", markersize=7, label="Others")
191 | black_triangle = mlines.Line2D([], [], color="black", marker="^", linestyle="None", markersize=7, label="AV")
192 |
193 | plt.axis("off")
194 |
195 | if show:
196 | plt.show()
197 |
198 |
199 |
200 |
201 |
202 |
--------------------------------------------------------------------------------
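
A small usage sketch for `interpolate_polyline` above (it needs at least four distinct points for the cubic spline fit; with fewer it returns the input unchanged):

```python
import numpy as np
from helper import interpolate_polyline  # assuming datasets/ is on sys.path and argoverse-api is installed

polyline = np.array([[0.0, 0.0], [1.0, 0.5], [2.0, 2.0], [3.0, 4.5]])
smooth = interpolate_polyline(polyline, num_points=12)
print(smooth.shape)  # (12, 2): the polyline resampled along a fitted spline
```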
/models/rho1_ECCO.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import argoverse
7 |
8 | import sys
9 | import os
10 | sys.path.append(os.path.dirname(__file__))
11 |
12 | from EquiCtsConv import *
13 | from EquiLinear import *
14 |
15 | #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 | class ECCONetwork(nn.Module):
18 | def __init__(self,
19 | num_radii = 3,
20 | num_theta = 16,
21 | radius_scale = 40,
22 | timestep = 0.1,
23 | encoder_hidden_size = 19,
24 | layer_channels = [16, 32, 32, 32, 1]
25 | ):
26 | super(ECCONetwork, self).__init__()
27 |
28 | # init parameters
29 |
30 | self.num_radii = num_radii
31 | self.num_theta = num_theta
32 | self.radius_scale = radius_scale
33 | self.timestep = timestep
34 | self.layer_channels = layer_channels
35 |
36 | self.encoder_hidden_size = encoder_hidden_size
37 | self.in_channel = 1 + self.encoder_hidden_size
38 | self.activation = F.relu
39 | # self.relu_shift = torch.nn.parameter.Parameter(torch.tensor(0.2))
40 | relu_shift = torch.tensor(0.2)
41 | self.register_buffer('relu_shift', relu_shift)
42 |
43 | # create continuous convolution and fully-connected layers
44 |
45 | convs = []
46 | denses = []
47 | # c_in, c_out, radius, num_radii, num_theta
48 | self.conv_fluid = EquiCtsConv2d(in_channels = self.in_channel,
49 | out_channels = self.layer_channels[0],
50 | num_radii = self.num_radii,
51 | num_theta = self.num_theta,
52 | radius = self.radius_scale)
53 |
54 | self.conv_obstacle = EquiCtsConv2d(in_channels = 1,
55 | out_channels = self.layer_channels[0],
56 | num_radii = self.num_radii,
57 | num_theta = self.num_theta,
58 | radius = self.radius_scale)
59 |
60 | self.dense_fluid = EquiLinear(self.in_channel, self.layer_channels[0])
61 |
62 | # concat conv_obstacle, conv_fluid, dense_fluid
63 | in_ch = 3 * self.layer_channels[0]
64 | for i in range(1, len(self.layer_channels)):
65 | out_ch = self.layer_channels[i]
66 | dense = EquiLinear(in_ch, out_ch)
67 | denses.append(dense)
68 | conv = EquiCtsConv2d(in_channels = in_ch,
69 | out_channels = out_ch,
70 | num_radii = self.num_radii,
71 | num_theta = self.num_theta,
72 | radius = self.radius_scale)
73 | convs.append(conv)
74 | in_ch = self.layer_channels[i]
75 |
76 | self.convs = nn.ModuleList(convs)
77 | self.denses = nn.ModuleList(denses)
78 |
79 |
80 | def update_pos_vel(self, p0, v0, a):
81 | """Apply acceleration and integrate position and velocity.
82 | Assume the particle has constant acceleration during timestep.
83 | Return particle's position and velocity after 1 unit timestep."""
84 |
85 | dt = self.timestep
86 | v1 = v0 + dt * a
87 | p1 = p0 + dt * (v0 + v1) / 2
88 | return p1, v1
89 |
90 | def apply_correction(self, p0, p1, correction):
91 | """Apply the position correction
92 | p0, p1: the position of the particle before/after basic integration. """
93 | dt = self.timestep
94 | p_corrected = p1 + correction
95 | v_corrected = (p_corrected - p0) / dt
96 | return p_corrected, v_corrected
97 |
98 | def compute_correction(self, p, v, other_feats, box, box_feats, fluid_mask, box_mask):
99 | """Precondition: p and v were updated with accerlation"""
100 |
101 | fluid_feats = [v.unsqueeze(-2)]
102 |         if other_feats is not None:
103 | fluid_feats.append(other_feats)
104 | fluid_feats = torch.cat(fluid_feats, -2)
105 |
106 | # compute the correction by accumulating the output through the network layers
107 | output_conv_fluid = self.conv_fluid(p, p, fluid_feats, fluid_mask)
108 | output_dense_fluid = self.dense_fluid(fluid_feats)
109 | output_conv_obstacle = self.conv_obstacle(box, p, box_feats.unsqueeze(-2), box_mask)
110 |
111 | feats = torch.cat((output_conv_obstacle, output_conv_fluid, output_dense_fluid), -2)
112 | # self.outputs = [feats]
113 | output = feats
114 |
115 | for conv, dense in zip(self.convs, self.denses):
116 | # pass input features to conv and fully-connected layers
117 | mags = (torch.sum(output**2,axis=-1) + 1e-6).unsqueeze(-1)
118 | in_feats = output/mags * self.activation(mags - self.relu_shift)
119 | # in_feats = self.activation(output)
120 | # in_feats = output
121 | output_conv = conv(p, p, in_feats, fluid_mask)
122 | output_dense = dense(in_feats)
123 |
124 |             # if the channel count of this layer's output matches the previous
125 |             # output, add a residual connection
126 | if output_dense.shape[-2] == output.shape[-2]:
127 | output = output_conv + output_dense + output
128 | else:
129 | output = output_conv + output_dense
130 | # self.outputs.append(output)
131 |
132 | # compute the number of fluid particle neighbors.
133 | # this info is used in the loss function during training.
134 | # TODO: test this block of code
135 | self.num_fluid_neighbors = torch.sum(fluid_mask, dim = -1) - 1
136 |
137 | # self.last_features = self.outputs[-2]
138 |
139 | # scale to better match the scale of the output distribution
140 | self.pos_correction = (1.0 / 128) * output
141 | return self.pos_correction
142 |
143 | def forward(self, inputs, states=None):
144 | """ inputs: 8 elems tuple
145 | p0_enc, v0_enc, p0, v0, a, feats, box, box_feats
146 | v0_enc: [batch, num_part, timestamps, 2]
147 | Computes 1 simulation timestep"""
148 | p0_enc, v0_enc, p0, v0, a, other_feats, box, box_feats, fluid_mask, box_mask = inputs
149 |
150 | if states is None:
151 | if other_feats is None:
152 | feats = v0_enc
153 | else:
154 | feats = torch.cat((other_feats, v0_enc), -2)
155 | else:
156 | if other_feats is None:
157 | feats = v0_enc
158 | feats = torch.cat((states[0][...,1:,:], feats), -2)
159 | else:
160 | feats = torch.cat((other_feats, states[0][...,1:,:], v0_enc), -2)
161 | # print(feats.shape)
162 |
163 | # a = (v0 - v0_enc[...,-1,:]) / self.timestep
164 | p1, v1 = self.update_pos_vel(p0, v0, a)
165 | pos_correction = self.compute_correction(p1, v1, feats, box, box_feats, fluid_mask, box_mask)
166 | p_corrected, v_corrected = self.apply_correction(p0, p1, pos_correction.squeeze(-2))
167 |
168 | return p_corrected, v_corrected, (feats, None)
169 |
170 |
--------------------------------------------------------------------------------
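
`update_pos_vel` and `apply_correction` above form a constant-acceleration integrator plus a learned position correction; a self-contained numeric sketch of just that arithmetic, with `dt = 0.1` matching the default `timestep` and a made-up correction standing in for the network output:

```python
import torch

dt = 0.1
p0 = torch.tensor([[0.0, 0.0]])    # position at time t
v0 = torch.tensor([[10.0, 0.0]])   # velocity at time t
a = torch.tensor([[0.0, 2.0]])     # assumed constant acceleration

# update_pos_vel: trapezoidal integration over one timestep
v1 = v0 + dt * a                   # -> [[10.0, 0.2]]
p1 = p0 + dt * (v0 + v1) / 2       # -> [[1.0, 0.01]]

# apply_correction: shift the integrated position by the (learned) correction
# and recompute velocity from the corrected displacement
correction = torch.tensor([[0.05, -0.01]])
p_corrected = p1 + correction            # -> [[1.05, 0.0]]
v_corrected = (p_corrected - p0) / dt    # -> [[10.5, 0.0]]
```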
/models/rho_reg_ECCO.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 | import argoverse
7 |
8 | import sys
9 | import os
10 | sys.path.append(os.path.dirname(__file__))
11 |
12 | from EquiCtsConv import *
13 | from EquiLinear import *
14 |
15 | #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
16 |
17 | class ECCONetwork(nn.Module):
18 | def __init__(self,
19 | num_radii = 3,
20 | num_theta = 16,
21 | reg_dim = 8,
22 | radius_scale = 40,
23 | timestep = 0.1,
24 | encoder_hidden_size = 19,
25 | layer_channels = [8, 16, 16, 16, 1]
26 | ):
27 | super(ECCONetwork, self).__init__()
28 |
29 | # init parameters
30 |
31 | self.num_radii = num_radii
32 | self.num_theta = num_theta
33 | self.reg_dim = reg_dim
34 | self.radius_scale = radius_scale
35 | self.timestep = timestep
36 | self.layer_channels = layer_channels
39 |
40 | self.encoder_hidden_size = encoder_hidden_size
41 | self.in_channel = 1 + self.encoder_hidden_size
42 | self.activation = F.relu
43 | # self.relu_shift = torch.nn.parameter.Parameter(torch.tensor(0.2))
44 | relu_shift = torch.tensor(0.2)
45 | self.register_buffer('relu_shift', relu_shift)
46 |
47 | # create continuous convolution and fully-connected layers
48 |
49 | convs = []
50 | denses = []
51 | # c_in, c_out, radius, num_radii, num_theta
52 | self.conv_fluid = EquiCtsConv2dRho1ToReg(in_channels = self.in_channel,
53 | out_channels = self.layer_channels[0],
54 | num_radii = self.num_radii,
55 | num_theta = self.num_theta,
56 | radius = self.radius_scale,
57 | k = self.reg_dim)
58 |
59 | self.conv_obstacle = EquiCtsConv2dRho1ToReg(in_channels = 1,
60 | out_channels = self.layer_channels[0],
61 | num_radii = self.num_radii,
62 | num_theta = self.num_theta,
63 | radius = self.radius_scale,
64 | k = self.reg_dim)
65 |
66 | self.dense_fluid = nn.Sequential(
67 | EquiLinearRho1ToReg(self.reg_dim),
68 | EquiLinearRegToReg(self.in_channel, self.layer_channels[0], self.reg_dim)
69 | )
70 |
71 | # concat conv_obstacle, conv_fluid, dense_fluid
72 | in_ch = 3 * self.layer_channels[0]
73 | for i in range(1, len(self.layer_channels)-1):
74 | out_ch = self.layer_channels[i]
75 | dense = EquiLinearRegToReg(in_ch, out_ch, self.reg_dim)
76 | denses.append(dense)
77 | conv = EquiCtsConv2dRegToReg(in_channels = in_ch,
78 | out_channels = out_ch,
79 | num_radii = self.num_radii,
80 | num_theta = self.num_theta,
81 | radius = self.radius_scale,
82 | k = self.reg_dim)
83 | convs.append(conv)
84 | in_ch = self.layer_channels[i]
85 |
86 | out_ch = self.layer_channels[-1]
87 | dense = nn.Sequential(
88 | EquiLinearRegToReg(in_ch, out_ch, self.reg_dim),
89 | EquiLinearRegToRho1(self.reg_dim),
90 | )
91 | denses.append(dense)
92 | conv = EquiCtsConv2dRegToRho1(in_channels = in_ch,
93 | out_channels = out_ch,
94 | num_radii = self.num_radii,
95 | num_theta = self.num_theta,
96 | radius = self.radius_scale,
97 | k = self.reg_dim)
98 | convs.append(conv)
99 |
100 | self.convs = nn.ModuleList(convs)
101 | self.denses = nn.ModuleList(denses)
102 |
103 |
104 | def update_pos_vel(self, p0, v0, a):
105 | """Apply acceleration and integrate position and velocity.
106 | Assume the particle has constant acceleration during timestep.
107 | Return particle's position and velocity after 1 unit timestep."""
108 |
109 | dt = self.timestep
110 | v1 = v0 + dt * a
111 | p1 = p0 + dt * (v0 + v1) / 2
112 | return p1, v1
113 |
114 | def apply_correction(self, p0, p1, correction):
115 | """Apply the position correction
116 | p0, p1: the position of the particle before/after basic integration. """
117 | dt = self.timestep
118 | p_corrected = p1 + correction
119 | v_corrected = (p_corrected - p0) / dt
120 | return p_corrected, v_corrected
121 |
122 | def compute_correction(self, p, v, other_feats, box, box_feats, fluid_mask, box_mask):
123 | """Precondition: p and v were updated with accerlation"""
124 |
125 | fluid_feats = [v.unsqueeze(-2)]
126 |         if other_feats is not None:
127 | fluid_feats.append(other_feats)
128 | fluid_feats = torch.cat(fluid_feats, -2)
129 |
130 | # compute the correction by accumulating the output through the network layers
131 | output_conv_fluid = self.conv_fluid(p, p, fluid_feats, fluid_mask)
132 | output_dense_fluid = self.dense_fluid(fluid_feats)
133 | output_conv_obstacle = self.conv_obstacle(box, p, box_feats.unsqueeze(-2), box_mask)
134 |
135 | feats = torch.cat((output_conv_obstacle, output_conv_fluid, output_dense_fluid), -2)
136 | # self.outputs = [feats]
137 | output = feats
138 |
139 | for conv, dense in zip(self.convs, self.denses):
140 | # pass input features to conv and fully-connected layers
141 | # mags = (torch.sum(output**2,axis=-1) + 1e-6).unsqueeze(-1)
142 | # in_feats = output/mags * self.activation(mags - self.relu_shift)
143 | in_feats = self.activation(output)
144 | # in_feats = output
145 | output_conv = conv(p, p, in_feats, fluid_mask)
146 | output_dense = dense(in_feats)
147 |
148 |             # if the channel count of this layer's output matches the previous
149 |             # output, add a residual connection
150 | if output_dense.shape[-2] == output.shape[-2]:
151 | output = output_conv + output_dense + output
152 | else:
153 | output = output_conv + output_dense
154 | # self.outputs.append(output)
155 |
156 | # compute the number of fluid particle neighbors.
157 | # this info is used in the loss function during training.
158 | # TODO: test this block of code
159 | self.num_fluid_neighbors = torch.sum(fluid_mask, dim = -1) - 1
160 |
161 | # self.last_features = self.outputs[-2]
162 |
163 | # scale to better match the scale of the output distribution
164 | self.pos_correction = (1.0 / 128) * output
165 | return self.pos_correction
166 |
167 | def forward(self, inputs, states=None):
168 | """ inputs: 8 elems tuple
169 | p0_enc, v0_enc, p0, v0, a, feats, box, box_feats
170 | v0_enc: [batch, num_part, timestamps, 2]
171 | Computes 1 simulation timestep"""
172 | p0_enc, v0_enc, p0, v0, a, other_feats, box, box_feats, fluid_mask, box_mask = inputs
173 |
174 | if states is None:
175 | if other_feats is None:
176 | feats = v0_enc
177 | else:
178 | feats = torch.cat((other_feats, v0_enc), -2)
179 | else:
180 | if other_feats is None:
181 | feats = v0_enc
182 | feats = torch.cat((states[0][...,1:,:], feats), -2)
183 | else:
184 | feats = torch.cat((other_feats, states[0][...,1:,:], v0_enc), -2)
185 | # print(feats.shape)
186 |
187 | # a = (v0 - v0_enc[...,-1,:]) / self.timestep
188 | p1, v1 = self.update_pos_vel(p0, v0, a)
189 | pos_correction = self.compute_correction(p1, v1, feats, box, box_feats, fluid_mask, box_mask)
190 | p_corrected, v_corrected = self.apply_correction(p0, p1, pos_correction.squeeze(-2))
191 |
192 | return p_corrected, v_corrected, (feats, None)
193 |
194 |
195 |
--------------------------------------------------------------------------------
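
The regular-representation pipeline above lifts ρ1 features into the regular representation, processes them there, and projects back (see `dense_fluid` and the final layer of `denses`); a tiny round-trip sketch using the `EquiLinear` layers, assuming `models/` is on `sys.path`:

```python
import torch
from EquiLinear import (EquiLinearRho1ToReg, EquiLinearRegToReg,
                        EquiLinearRegToRho1)

k = 8
lift = EquiLinearRho1ToReg(k)       # rho1 -> regular representation
mix = EquiLinearRegToReg(3, 5, k)   # 3 -> 5 channels within the regular rep
project = EquiLinearRegToRho1(k)    # regular representation -> rho1

x = torch.randn(1, 4, 3, 2)         # [batch, num_part, feat, 2]
y = project(mix(lift(x)))
print(y.shape)                      # torch.Size([1, 4, 5, 2])
```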
/scripts/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import sys
4 | import numpy as np
5 | sys.path.append('..')
6 | from collections import namedtuple
7 | import time
8 | import pickle
9 | import argparse
10 | from evaluate_network import evaluate
11 | from argoverse.map_representation.map_api import ArgoverseMap
12 | from datasets.argoverse_lane_loader import read_pkl_data
13 | from train_utils import *
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.optim as optim
18 | import torch.nn.functional as F
19 |
20 |
21 | parser = argparse.ArgumentParser(description="Training setting and hyperparameters")
22 | parser.add_argument('--cuda_visible_devices', default='0,1,2,3')
23 | parser.add_argument('--dataset_path', default='/path/to/argoverse_forecasting/',
24 | help='path to dataset folder, which contains train and val folders')
25 | parser.add_argument('--train_window', default=4, type=int, help='how many timestamps to iterate in training')
26 | parser.add_argument('--batch_divide', default=1, type=int,
27 |                     help='divide one batch into several packs and train them iteratively.')
28 | parser.add_argument('--epochs', default=70, type=int)
29 | parser.add_argument('--batches_per_epoch', default=600, type=int,
30 | help='determine the number of batches to train in one epoch')
31 | parser.add_argument('--base_lr', default=0.001, type=float)
32 | parser.add_argument('--batch_size', default=16, type=int)
33 | parser.add_argument('--model_name', default='ecco_trained_model', type=str)
34 | parser.add_argument('--val_batches', default=50, type=int,
35 | help='the number of batches of data to split as validation set')
36 | parser.add_argument('--val_batch_size', default=32, type=int)
37 | parser.add_argument('--train', default=False, action='store_true')
38 | parser.add_argument('--evaluation', default=False, action='store_true')
39 |
40 | feature_parser = parser.add_mutually_exclusive_group(required=False)
41 | feature_parser.add_argument('--rho1', dest='representation', action='store_false')
42 | feature_parser.add_argument('--rho-reg', dest='representation', action='store_true')
43 | parser.set_defaults(representation=True)
44 |
45 | args = parser.parse_args()
46 |
47 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_visible_devices
48 |
49 | model_name = args.model_name
50 |
51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52 |
53 | val_path = os.path.join(args.dataset_path, 'val', 'lane_data')
54 | train_path = os.path.join(args.dataset_path, 'train', 'lane_data')
55 |
56 | def create_model():
57 | if args.representation:
58 | from models.rho_reg_ECCO import ECCONetwork
59 | """Returns an instance of the network for training and evaluation"""
60 |         model = ECCONetwork(radius_scale = 40,
61 | layer_channels = [8, 16, 8, 8, 1],
62 | encoder_hidden_size=18)
63 | else:
64 | from models.rho1_ECCO import ECCONetwork
65 | """Returns an instance of the network for training and evaluation"""
66 | model = ECCONetwork(radius_scale = 40, encoder_hidden_size=18,
67 | layer_channels = [16, 32, 32, 32, 1],
68 | num_radii = 3)
69 | return model
70 |
71 | class MyDataParallel(torch.nn.DataParallel):
72 | """
73 | Allow nn.DataParallel to call model's attributes.
74 | """
75 | def __getattr__(self, name):
76 | try:
77 | return super().__getattr__(name)
78 | except AttributeError:
79 | return getattr(self.module, name)
80 |
81 | def train():
82 | am = ArgoverseMap()
83 |
84 | val_dataset = read_pkl_data(val_path, batch_size=args.val_batch_size, shuffle=True, repeat=False, max_lane_nodes=700)
85 |
86 | dataset = read_pkl_data(train_path, batch_size=args.batch_size // args.batch_divide,
87 | repeat=True, shuffle=True, max_lane_nodes=900)
88 |
89 | data_iter = iter(dataset)
90 |
91 | model = create_model().to(device)
92 | # model_ = torch.load(model_name + '.pth')
93 | # model = model_
94 | model = MyDataParallel(model)
95 | optimizer = torch.optim.Adam(model.parameters(), args.base_lr,betas=(0.9, 0.999), weight_decay=4e-4)
96 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size= 1, gamma=0.95)
97 |
98 | def train_one_batch(model, batch, train_window=2):
99 |
100 | batch_size = args.batch_size
101 |
102 | inputs = ([
103 | batch['pos_2s'], batch['vel_2s'],
104 | batch['pos0'], batch['vel0'],
105 | batch['accel'], None,
106 | batch['lane'], batch['lane_norm'],
107 | batch['car_mask'], batch['lane_mask']
108 | ])
109 |
110 | # print_inputs_shape(inputs)
111 | # print(batch['pos0'])
112 | pr_pos1, pr_vel1, states = model(inputs)
113 | gt_pos1 = batch['pos1']
114 | # print(pr_pos1)
115 |
116 | # losses = 0.5 * loss_fn(pr_pos1, gt_pos1, model.num_fluid_neighbors.unsqueeze(-1), batch['car_mask'])
117 | losses = 0.5 * loss_fn(pr_pos1, gt_pos1, torch.sum(batch['car_mask'], dim = -2) - 1, batch['car_mask'].squeeze(-1))
118 | del gt_pos1
119 |
120 | pos0 = batch['pos0']
121 | vel0 = batch['vel0']
122 | for i in range(train_window-1):
123 | pos_enc = torch.unsqueeze(pos0, 2)
124 | vel_enc = torch.unsqueeze(vel0, 2)
125 | inputs = (pos_enc, vel_enc, pr_pos1, pr_vel1, batch['accel'], None,
126 | batch['lane'],
127 | batch['lane_norm'],batch['car_mask'], batch['lane_mask'])
128 | pos0, vel0 = pr_pos1, pr_vel1
129 | # del pos_enc, vel_enc
130 |
131 | pr_pos1, pr_vel1, states = model(inputs, states)
132 | gt_pos1 = batch['pos'+str(i+2)]
133 |
134 | losses += 0.5 * loss_fn(pr_pos1, gt_pos1,
135 | torch.sum(batch['car_mask'], dim = -2) - 1, batch['car_mask'].squeeze(-1))
136 |
137 |
138 | total_loss = 128 * torch.sum(losses,axis=0) / batch_size
139 |
140 | return total_loss
141 |
142 | epochs = args.epochs
143 |     batches_per_epoch = args.batches_per_epoch  # the dataset is too large to iterate fully in one epoch
144 | data_load_times = [] # Per batch
145 | train_losses = []
146 | valid_losses = []
147 | valid_metrics_list = []
148 | min_loss = None
149 |
150 | for i in range(epochs):
151 | epoch_start_time = time.time()
152 |
153 | model.train()
154 | epoch_train_loss = 0
155 | sub_idx = 0
156 |
157 | print("training ... epoch " + str(i + 1), end='')
158 | for batch_itr in range(batches_per_epoch * args.batch_divide):
159 |
160 | data_fetch_start = time.time()
161 | batch = next(data_iter)
162 |
163 | if sub_idx == 0:
164 | optimizer.zero_grad()
165 | if (batch_itr // args.batch_divide) % 25 == 0:
166 | print("... batch " + str((batch_itr // args.batch_divide) + 1), end='', flush=True)
167 | sub_idx += 1
168 |
169 | batch_size = len(batch['pos0'])
170 |
171 | batch_tensor = {}
172 | convert_keys = (['pos' + str(i) for i in range(args.train_window + 1)] +
173 | ['vel' + str(i) for i in range(args.train_window + 1)] +
174 | ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])
175 |
176 | for k in convert_keys:
177 | batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device)
178 |
179 | for k in ['car_mask', 'lane_mask']:
180 | batch_tensor[k] = torch.tensor(np.stack(batch[k]), dtype=torch.float32, device=device).unsqueeze(-1)
181 |
182 | for k in ['track_id' + str(i) for i in range(31)] + ['city']:
183 | batch_tensor[k] = batch[k]
184 |
185 | batch_tensor['car_mask'] = batch_tensor['car_mask'].squeeze(-1)
186 | accel = torch.zeros(batch_size, 1, 2).to(device)
187 | batch_tensor['accel'] = accel
188 | del batch
189 |
190 | data_fetch_latency = time.time() - data_fetch_start
191 | data_load_times.append(data_fetch_latency)
192 |
193 | current_loss = train_one_batch(model, batch_tensor, train_window=args.train_window)
194 |
195 | if sub_idx < args.batch_divide:
196 | current_loss.backward(retain_graph=True)
197 | else:
198 | current_loss.backward()
199 | optimizer.step()
200 | sub_idx = 0
201 | del batch_tensor
202 |
203 | epoch_train_loss += float(current_loss)
204 | del current_loss
205 | clean_cache(device)
206 |
207 | if batch_itr == batches_per_epoch - 1:
208 | print("... DONE", flush=True)
209 |
210 | train_losses.append(epoch_train_loss)
211 |
212 | model.eval()
213 | with torch.no_grad():
214 | valid_total_loss, valid_metrics = evaluate(model.module, val_dataset,
215 | train_window=args.train_window,
216 | max_iter=args.val_batches,
217 | device=device,
218 | batch_size=args.val_batch_size)
219 |
220 | valid_losses.append(float(valid_total_loss))
221 | valid_metrics_list.append(valid_metrics)
222 |
223 | if min_loss is None:
224 | min_loss = valid_losses[-1]
225 |
226 | if valid_losses[-1] < min_loss:
227 | print('update weights')
228 | min_loss = valid_losses[-1]
229 | best_model = model
230 | torch.save(model.module, model_name + ".pth")
231 |
232 | epoch_end_time = time.time()
233 |
234 | print('epoch: {}, train loss: {}, val loss: {}, epoch time: {}, lr: {}, {}'.format(
235 | i + 1, train_losses[-1], valid_losses[-1],
236 | round((epoch_end_time - epoch_start_time) / 60, 5),
237 | format(get_lr(optimizer), "5.2e"), model_name
238 | ))
239 |
240 | scheduler.step()
241 |
242 |
243 | def evaluation():
244 | am = ArgoverseMap()
245 |
246 | val_dataset = read_pkl_data(val_path, batch_size=args.val_batch_size, shuffle=False, repeat=False)
247 |
248 | trained_model = torch.load(model_name + '.pth')
249 | trained_model.eval()
250 |
251 |     with torch.no_grad():
252 |         valid_total_loss, valid_metrics = evaluate(trained_model, val_dataset,
253 |                                 train_window=args.train_window, device=device,
254 |                                 max_iter=int(np.ceil(len(os.listdir(val_path)) / args.val_batch_size)),  # the loader is a generator with no len()
255 |                                 batch_size=args.val_batch_size, start_iter=args.val_batches)
256 |
257 |     os.makedirs('results', exist_ok=True)  # ensure the output folder exists
258 |     with open('results/{}_predictions.pickle'.format(model_name), 'wb') as f:
259 |         pickle.dump(valid_metrics, f)
259 |
260 |
261 | if __name__ == '__main__':
262 | if args.train:
263 | train()
264 |
265 | if args.evaluation:
266 | evaluation()
267 |
268 |
269 |
--------------------------------------------------------------------------------
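
A stripped-down sketch of the `--batch_divide` gradient-accumulation pattern used in the training loop above (hypothetical toy model and data; `batch_divide` packs several small backward passes into one optimizer step):

```python
import torch

model = torch.nn.Linear(2, 2)  # toy stand-in for ECCONetwork
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_divide = 2

for step in range(8):
    if step % batch_divide == 0:
        optimizer.zero_grad()
    x = torch.randn(4, 2)                        # one "pack" of the batch
    loss = model(x).pow(2).mean() / batch_divide
    loss.backward()                              # gradients accumulate across packs
    if step % batch_divide == batch_divide - 1:
        optimizer.step()
```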
/datasets/preprocess_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import numpy as np
4 | import pandas as pd
5 | import sys
6 | sys.path.append('..')
7 | from datasets.helper import get_lane_direction
8 | # from tensorpack import dataflow
9 | import time
10 | import gc
11 | import pickle
12 | import helper
14 | import glob
15 | from argoverse.map_representation.map_api import ArgoverseMap
16 | from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
17 |
18 | dataset_path = '/path/to/dataset/argoverse_forecasting/'
19 |
20 | val_path = os.path.join(dataset_path, 'val', 'data')
21 | train_path = os.path.join(dataset_path, 'train', 'data')
22 |
23 | class ArgoverseTest(object):
24 | """
25 | Data flow for argoverse dataset
26 | """
27 |
28 | def __init__(self, file_path: str, shuffle: bool = True, random_rotation: bool = False,
29 | max_car_num: int = 50, freq: int = 10, use_interpolate: bool = False,
30 | use_lane: bool = False, use_mask: bool = True):
31 | if not os.path.exists(file_path):
32 | raise Exception("Path does not exist.")
33 |
34 | self.afl = ArgoverseForecastingLoader(file_path)
35 | self.shuffle = shuffle
36 | self.random_rotation = random_rotation
37 | self.max_car_num = max_car_num
38 | self.freq = freq
39 | self.use_interpolate = use_interpolate
40 | self.am = ArgoverseMap()
41 | self.use_mask = use_mask
42 | self.file_path = file_path
43 |
44 |
45 | def get_feat(self, scene):
46 |
47 | data, city = self.afl[scene].seq_df, self.afl[scene].city
48 |
49 | lane = np.array([[0., 0.]], dtype=np.float32)
50 | lane_drct = np.array([[0., 0.]], dtype=np.float32)
51 |
52 |
53 | tstmps = data.TIMESTAMP.unique()
54 | tstmps.sort()
55 |
56 |         data = self._filter_incomplete_data(data, tstmps, 50)
57 |
58 | data = self._calc_vel(data, self.freq)
59 |
60 | agent = data[data['OBJECT_TYPE'] == 'AGENT']['TRACK_ID'].values[0]
61 |
62 | car_mask = np.zeros((self.max_car_num, 1), dtype=np.float32)
63 | car_mask[:len(data.TRACK_ID.unique())] = 1.0
64 |
65 | feat_dict = {'city': city,
66 | 'lane': lane,
67 | 'lane_norm': lane_drct,
68 | 'scene_idx': scene,
69 | 'agent_id': agent,
70 | 'car_mask': car_mask}
71 |
72 | pos_enc = [subdf[['X', 'Y']].values[np.newaxis,:]
73 | for _, subdf in data[data['TIMESTAMP'].isin(tstmps[:19])].groupby('TRACK_ID')]
74 | pos_enc = np.concatenate(pos_enc, axis=0)
75 | # pos_enc = self._expand_dim(pos_enc)
76 | feat_dict['pos_2s'] = self._expand_particle(pos_enc, self.max_car_num, 0)
77 |
78 | vel_enc = [subdf[['vel_x', 'vel_y']].values[np.newaxis,:]
79 | for _, subdf in data[data['TIMESTAMP'].isin(tstmps[:19])].groupby('TRACK_ID')]
80 | vel_enc = np.concatenate(vel_enc, axis=0)
81 | # vel_enc = self._expand_dim(vel_enc)
82 | feat_dict['vel_2s'] = self._expand_particle(vel_enc, self.max_car_num, 0)
83 |
84 | pos = data[data['TIMESTAMP'] == tstmps[19]][['X', 'Y']].values
85 | # pos = self._expand_dim(pos)
86 | feat_dict['pos0'] = self._expand_particle(pos, self.max_car_num, 0)
87 | vel = data[data['TIMESTAMP'] == tstmps[19]][['vel_x', 'vel_y']].values
88 | # vel = self._expand_dim(vel)
89 | feat_dict['vel0'] = self._expand_particle(vel, self.max_car_num, 0)
90 | track_id = data[data['TIMESTAMP'] == tstmps[19]]['TRACK_ID'].values
91 | feat_dict['track_id0'] = self._expand_particle(track_id, self.max_car_num, 0, 'str')
92 | feat_dict['frame_id0'] = 0
93 |
94 | for t in range(31):
95 | pos = data[data['TIMESTAMP'] == tstmps[19 + t]][['X', 'Y']].values
96 | # pos = self._expand_dim(pos)
97 | feat_dict['pos' + str(t)] = self._expand_particle(pos, self.max_car_num, 0)
98 | vel = data[data['TIMESTAMP'] == tstmps[19 + t]][['vel_x', 'vel_y']].values
99 | # vel = self._expand_dim(vel)
100 | feat_dict['vel' + str(t)] = self._expand_particle(vel, self.max_car_num, 0)
101 | track_id = data[data['TIMESTAMP'] == tstmps[19 + t]]['TRACK_ID'].values
102 | feat_dict['track_id' + str(t)] = self._expand_particle(track_id, self.max_car_num, 0, 'str')
103 | feat_dict['frame_id' + str(t)] = t
104 |
105 | return feat_dict
106 |
107 | def __len__(self):
108 | return len(glob.glob(os.path.join(self.file_path, '*')))
109 |
110 | @classmethod
111 | def _expand_df(cls, data, city_name):
112 | timestps = data['TIMESTAMP'].unique().tolist()
113 | ids = data['TRACK_ID'].unique().tolist()
114 | df = pd.DataFrame({'TIMESTAMP': timestps * len(ids)}).sort_values('TIMESTAMP')
115 | df['TRACK_ID'] = ids * len(timestps)
116 | df['CITY_NAME'] = city_name
117 | return pd.merge(data, df, on=['TIMESTAMP', 'TRACK_ID'], how='right')
118 |
119 |
120 | @classmethod
121 | def __calc_vel_generator(cls, df, freq=10):
122 | for idx, subdf in df.groupby('TRACK_ID'):
123 | sub_df = subdf.copy().sort_values('TIMESTAMP')
124 | sub_df[['vel_x', 'vel_y']] = sub_df[['X', 'Y']].diff() * freq
125 | yield sub_df.iloc[1:, :]
126 |
127 | @classmethod
128 | def _calc_vel(cls, df, freq=10):
129 | return pd.concat(cls.__calc_vel_generator(df, freq=freq), axis=0)
130 |
131 | @classmethod
132 | def _expand_dim(cls, ndarr, dtype=np.float32):
133 | return np.insert(ndarr, 2, values=0, axis=-1).astype(dtype)
134 |
135 | @classmethod
136 | def _linear_interpolate_generator(cls, data, col=['X', 'Y']):
137 | for idx, df in data.groupby('TRACK_ID'):
138 | sub_df = df.copy().sort_values('TIMESTAMP')
139 | sub_df[col] = sub_df[col].interpolate(limit_direction='both')
140 | yield sub_df.ffill().bfill()
141 |
142 | @classmethod
143 | def _linear_interpolate(cls, data, col=['X', 'Y']):
144 | return pd.concat(cls._linear_interpolate_generator(data, col), axis=0)
145 |
146 | @classmethod
147 |     def _filter_incomplete_data(cls, data, tstmps, window=20):
148 | complete_id = list()
149 | for idx, subdf in data[data['TIMESTAMP'].isin(tstmps[:window])].groupby('TRACK_ID'):
150 | if len(subdf) == window:
151 | complete_id.append(idx)
152 | return data[data['TRACK_ID'].isin(complete_id)]
153 |
154 | @classmethod
155 | def _expand_particle(cls, arr, max_num, axis, value_type='int'):
156 | dummy_shape = list(arr.shape)
157 | dummy_shape[axis] = max_num - arr.shape[axis]
158 | dummy = np.zeros(dummy_shape)
159 | if value_type == 'str':
160 |             dummy = np.array(['dummy' + str(i) for i in range(np.prod(dummy_shape))]).reshape(dummy_shape)
161 | return np.concatenate([arr, dummy], axis=axis)
162 |
163 |
164 | class process_utils(object):
165 |
166 | @classmethod
167 | def expand_dim(cls, ndarr, dtype=np.float32):
168 | return np.insert(ndarr, 2, values=0, axis=-1).astype(dtype)
169 |
170 | @classmethod
171 | def expand_particle(cls, arr, max_num, axis, value_type='int'):
172 | dummy_shape = list(arr.shape)
173 | dummy_shape[axis] = max_num - arr.shape[axis]
174 | dummy = np.zeros(dummy_shape)
175 | if value_type == 'str':
176 |             dummy = np.array(['dummy' + str(i) for i in range(np.prod(dummy_shape))]).reshape(dummy_shape)
177 | return np.concatenate([arr, dummy], axis=axis)
178 |
179 |
180 | def get_max_min(datas):
181 | mask = datas['car_mask']
182 | slicer = mask[0].astype(bool).flatten()
183 | pos_keys = ['pos0'] + ['pos_2s']
184 | max_x = np.concatenate([np.max(np.stack(datas[pk])[0,slicer,...,0]
185 | .reshape(np.stack(datas[pk]).shape[0], -1),
186 | axis=-1)[...,np.newaxis]
187 | for pk in pos_keys], axis=-1)
188 | min_x = np.concatenate([np.min(np.stack(datas[pk])[0,slicer,...,0]
189 | .reshape(np.stack(datas[pk]).shape[0], -1),
190 | axis=-1)[...,np.newaxis]
191 | for pk in pos_keys], axis=-1)
192 | max_y = np.concatenate([np.max(np.stack(datas[pk])[0,slicer,...,1]
193 | .reshape(np.stack(datas[pk]).shape[0], -1),
194 | axis=-1)[...,np.newaxis]
195 | for pk in pos_keys], axis=-1)
196 | min_y = np.concatenate([np.min(np.stack(datas[pk])[0,slicer,...,1]
197 | .reshape(np.stack(datas[pk]).shape[0], -1),
198 | axis=-1)[...,np.newaxis]
199 | for pk in pos_keys], axis=-1)
200 | max_x = np.max(max_x, axis=-1) + 10
201 | max_y = np.max(max_y, axis=-1) + 10
202 |     min_x = np.min(min_x, axis=-1) - 10
203 |     min_y = np.min(min_y, axis=-1) - 10
204 | return min_x, max_x, min_y, max_y
205 |
206 |
207 | def process_func(putil, datas, am):
208 |
209 | city = datas['city'][0]
210 | x_min, x_max, y_min, y_max = get_max_min(datas)
211 |
212 | seq_lane_props = am.city_lane_centerlines_dict[city]
213 |
214 | lane_centerlines = []
215 | lane_directions = []
216 |
217 | # Get lane centerlines which lie within the range of trajectories
218 | for lane_id, lane_props in seq_lane_props.items():
219 |
220 | lane_cl = lane_props.centerline
221 |
222 | if (
223 | np.min(lane_cl[:, 0]) < x_max
224 | and np.min(lane_cl[:, 1]) < y_max
225 | and np.max(lane_cl[:, 0]) > x_min
226 | and np.max(lane_cl[:, 1]) > y_min
227 | ):
228 | lane_centerlines.append(lane_cl[1:])
229 | lane_drct = np.diff(lane_cl, axis=0)
230 | lane_directions.append(lane_drct)
231 | if len(lane_centerlines) > 0:
232 | lane = np.concatenate(lane_centerlines, axis=0)
233 | # lane = putil.expand_dim(lane)
234 | lane_drct = np.concatenate(lane_directions, axis=0)
235 | # lane_drct = putil.expand_dim(lane_drct)[...,:3]
236 |
237 | datas['lane'] = [lane]
238 | datas['lane_norm'] = [lane_drct]
239 | return datas
240 | else:
241 | return datas
242 |
243 |
244 | if __name__ == '__main__':
245 | am = ArgoverseMap()
246 | putil = process_utils()
247 |
248 | afl_train = ArgoverseForecastingLoader(os.path.join(dataset_path, 'train', 'data'))
249 | afl_val = ArgoverseForecastingLoader(os.path.join(dataset_path, 'val', 'data'))
250 | at_train = ArgoverseTest(os.path.join(dataset_path, 'train', 'data'), max_car_num=60)
251 | at_val = ArgoverseTest(os.path.join(dataset_path, 'val', 'data'), max_car_num=60)
252 |
253 |
254 | print("++++++++++++++++++++ START TRAIN ++++++++++++++++++++")
255 | train_num = len(afl_train)
256 | batch_start = time.time()
257 |     os.makedirs(os.path.join(dataset_path, 'train/lane_data'), exist_ok=True)
258 | for i, scene in enumerate(range(train_num)):
259 | if i % 1000 == 0:
260 | batch_end = time.time()
261 | print("SAVED ============= {} / {} ....... {}".format(i, train_num, batch_end - batch_start))
262 | batch_start = time.time()
263 |
264 | data = {k:[v] for k, v in at_train.get_feat(scene).items()}
265 | datas = process_func(putil, data, am)
266 | with open(os.path.join(dataset_path, 'train/lane_data', str(datas['scene_idx'][0])+'.pkl'), 'wb') as f:
267 | pickle.dump(datas, f)
268 |
269 | print("++++++++++++++++++++ START VAL ++++++++++++++++++++")
270 | val_num = len(afl_val)
271 | batch_start = time.time()
272 | os.makedirs(os.path.join(dataset_path, 'val/lane_data'), exist_ok=True)
273 | for scene in range(val_num):
274 | if scene % 1000 == 0:
275 | batch_end = time.time()
276 | print("SAVED ============= {} / {} ....... {}".format(scene, val_num, batch_end - batch_start))
277 | batch_start = time.time()
278 |
279 | data = {k:[v] for k, v in at_val.get_feat(scene).items()}
280 | datas = process_func(putil, data, am)
281 | with open(os.path.join(dataset_path, 'val/lane_data', str(datas['scene_idx'][0])+'.pkl'), 'wb') as f:
282 | pickle.dump(datas, f)
283 |
284 |
285 |
--------------------------------------------------------------------------------
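For orientation, a minimal sketch of reading one preprocessed scene back; it assumes the same `dataset_path` and `lane_data` layout written by the script above, and the scene index `10` is purely hypothetical:

import os
import pickle

scene_file = os.path.join(dataset_path, 'train/lane_data', '10.pkl')  # hypothetical scene index
with open(scene_file, 'rb') as f:
    datas = pickle.load(f)

# every entry is a length-1 list wrapping the per-scene array, e.g.
# datas['lane'][0]      -> [num_lane_points, 2] centerline points (x, y)
# datas['lane_norm'][0] -> [num_lane_points, 2] centerline direction vectors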
/models/RelEquiCtsConv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 |
7 | from EquiCtsConv import *
8 |
9 |
10 | class RelEquiCtsConv2d(EquiCtsConv2d):
11 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, matrix_dim=2,
12 | use_attention=True, normalize_attention=True):
13 | super(RelEquiCtsConv2d, self).__init__(in_channels, out_channels, radius, num_radii, num_theta,
14 | matrix_dim, use_attention, normalize_attention)
15 |
16 | def ContinuousConv(
17 | self, field, center, field_feat,
18 | field_mask, ctr_feat
19 | ):
20 | """
21 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2]
22 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim]
23 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim]
24 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2]
25 | @ctr_feat: [batch, num_m, c_in=feat_dim, 2] -> [batch, num_m, 1, c_in, 2]
26 | @field_mask: [batch, num_n, 1]
27 | """
28 | kernel = self.computeKernel()
29 |
30 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
31 | # relative_field: [batch, num_m, num_n, pos_dim]
32 |
33 |
34 | polar_field = self.PolarCoords(relative_field)
35 | # polar_field: [batch, num_m, num_n, pos_dim]
36 |
37 | kernel_on_field = self.InterpolateKernel(kernel, polar_field)
38 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2]
39 |
40 | if self.use_attention:
41 | # print(relative_field.shape)
42 | # print(field_mask.unsqueeze(1).shape)
43 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1)
44 | # attention: [batch, num_m, num_n, 1]
45 |
46 | if self.normalize_attention:
47 | psi = torch.sum(attention, axis=2).squeeze(-1)
48 | psi[psi == 0.] = 1
49 | psi = psi.unsqueeze(-1).unsqueeze(-1)
50 | else:
51 | psi = 1.0
52 | else:
53 | attention = torch.ones(*relative_field.shape[0:3], 1, device=relative_field.device)
54 |
55 | if self.normalize_attention:
56 | psi = torch.sum(attention, axis=2).squeeze(-1)
57 | psi[psi == 0.] = 1
58 | psi = psi.unsqueeze(-1).unsqueeze(-1)
59 | else:
60 | psi = 1.0
61 |
62 | field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2)
63 | attention_field_feat = field_feat * attention.unsqueeze(-1)
64 | # attention_field_feat: [batch, num_m, num_n, c_in, 2]
65 |
66 | # print(kernel_on_field.shape, attention_field_feat.shape)
67 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat)
68 | # out: [batch, num_m, c_out, 2]
69 |
70 | return out / psi
71 |
72 | def forward(
73 | self, field, center, field_feat,
74 | field_mask, ctr_feat
75 | ):
76 | out = self.ContinuousConv(
77 | field, center, field_feat, field_mask,
78 | ctr_feat
79 | )
80 | return out
81 |
82 | class RelEquiCtsConv2dRegToRho1(EquiCtsConv2dRegToRho1):
83 |
84 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2,
85 | use_attention=True, normalize_attention=True):
86 | super(RelEquiCtsConv2dRegToRho1, self).__init__(in_channels, out_channels, radius, num_radii, num_theta,
87 | k, matrix_dim, use_attention, normalize_attention)
88 |
89 | def ContinuousConv(
90 | self, field, center, field_feat,
91 | field_mask, ctr_feat
92 | ):
93 | """
94 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2]
95 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim]
96 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim]
97 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2]
98 | @ctr_feat: [batch, num_m, c_in=feat_dim, 2] -> [batch, num_m, 1, c_in, 2]
99 | @field_mask: [batch, num_n, 1]
100 | """
101 | kernel = self.computeKernel()
102 |
103 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
104 | # relative_field: [batch, num_m, num_n, pos_dim]
105 |
106 |
107 | polar_field = self.PolarCoords(relative_field)
108 | # polar_field: [batch, num_m, num_n, pos_dim]
109 |
110 | kernel_on_field = self.InterpolateKernel(kernel, polar_field)
111 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2]
112 |
113 | if self.use_attention:
114 | # print(relative_field.shape)
115 | # print(field_mask.unsqueeze(1).shape)
116 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1)
117 | # attention: [batch, num_m, num_n, 1]
118 |
119 | if self.normalize_attention:
120 | psi = torch.sum(attention, axis=2).squeeze(-1)
121 | psi[psi == 0.] = 1
122 | psi = psi.unsqueeze(-1).unsqueeze(-1)
123 | else:
124 | psi = 1.0
125 | else:
126 | attention = torch.ones(*relative_field.shape[0:3], 1, device=relative_field.device)
127 |
128 | if self.normalize_attention:
129 | psi = torch.sum(attention, axis=2).squeeze(-1)
130 | psi[psi == 0.] = 1
131 | psi = psi.unsqueeze(-1).unsqueeze(-1)
132 | else:
133 | psi = 1.0
134 |
135 | field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2)
136 |
137 | attention_field_feat = field_feat*attention.unsqueeze(-1)
138 | # attention_field_feat: [batch, num_m, num_n, c_in, 2]
139 |
140 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat)
141 | # out: [batch, num_m, c_out, 2]
142 |
143 | return out / psi
144 |
145 | def forward(
146 | self, field, center, field_feat,
147 | field_mask, ctr_feat
148 | ):
149 | out = self.ContinuousConv(
150 | field, center, field_feat, field_mask,
151 | ctr_feat
152 | )
153 | return out
154 |
155 | class RelEquiCtsConv2dRho1ToReg(EquiCtsConv2dRho1ToReg):
156 |
157 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2,
158 | use_attention=True, normalize_attention=True):
159 | super(RelEquiCtsConv2dRho1ToReg, self).__init__(in_channels, out_channels, radius, num_radii, num_theta,
160 | k, matrix_dim, use_attention, normalize_attention)
161 |
162 |
163 | def ContinuousConv(
164 | self, field, center, field_feat,
165 | field_mask, ctr_feat
166 | ):
167 | """
168 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2]
169 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim]
170 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim]
171 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2]
172 | @ctr_feat: [batch, num_m, c_in=feat_dim, 2] -> [batch, num_m, 1, c_in, 2]
173 | @field_mask: [batch, num_n, 1]
174 | """
175 | kernel = self.computeKernel()
176 |
177 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
178 | # relative_field: [batch, num_m, num_n, pos_dim]
179 |
180 |
181 | polar_field = self.PolarCoords(relative_field)
182 | # polar_field: [batch, num_m, num_n, pos_dim]
183 |
184 | kernel_on_field = self.InterpolateKernel(kernel, polar_field)
185 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2]
186 |
187 | if self.use_attention:
188 | # print(relative_field.shape)
189 | # print(field_mask.unsqueeze(1).shape)
190 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1)
191 | # attention: [batch, num_m, num_n, 1]
192 |
193 | if self.normalize_attention:
194 | psi = torch.sum(attention, axis=2).squeeze(-1)
195 | psi[psi == 0.] = 1
196 | psi = psi.unsqueeze(-1).unsqueeze(-1)
197 | else:
198 | psi = 1.0
199 | else:
200 | attention = torch.ones(*relative_field.shape[0:3], 1, device=relative_field.device)
201 |
202 | if self.normalize_attention:
203 | psi = torch.sum(attention, axis=2).squeeze(-1)
204 | psi[psi == 0.] = 1
205 | psi = psi.unsqueeze(-1).unsqueeze(-1)
206 | else:
207 | psi = 1.0
208 |
209 | field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2)
210 |
211 | attention_field_feat = field_feat*attention.unsqueeze(-1)
212 | # attention_field_feat: [batch, num_m, num_n, c_in, 2]
213 |
214 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat)
215 | # out: [batch, num_m, c_out, 2]
216 |
217 | return out / psi
218 |
219 | def forward(
220 | self, field, center, field_feat,
221 | field_mask, ctr_feat
222 | ):
223 | out = self.ContinuousConv(
224 | field, center, field_feat, field_mask,
225 | ctr_feat
226 | )
227 | return out
228 |
229 | class RelEquiCtsConv2dRegToReg(EquiCtsConv2dRegToReg):
230 |
231 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2,
232 | use_attention=True, normalize_attention=True):
233 | super(RelEquiCtsConv2dRegToReg, self).__init__(in_channels, out_channels, radius, num_radii, num_theta,
234 | k, matrix_dim, use_attention, normalize_attention)
235 | 
236 |
237 |
238 | def ContinuousConv(
239 | self, field, center, field_feat,
240 | field_mask, ctr_feat
241 | ):
242 | """
243 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2]
244 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim]
245 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim]
246 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2]
247 | @ctr_feat: [batch, num_m, c_in=feat_dim, 2] -> [batch, num_m, 1, c_in, 2]
248 | @field_mask: [batch, num_n, 1]
249 | """
250 | kernel = self.computeKernel()
251 |
252 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
253 | # relative_field: [batch, num_m, num_n, pos_dim]
254 |
255 |
256 | polar_field = self.PolarCoords(relative_field)
257 | # polar_field: [batch, num_m, num_n, pos_dim]
258 |
259 | kernel_on_field = self.InterpolateKernel(kernel, polar_field)
260 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2]
261 |
262 | if self.use_attention:
263 | # print(relative_field.shape)
264 | # print(field_mask.unsqueeze(1).shape)
265 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1)
266 | # attention: [batch, num_m, num_n, 1]
267 |
268 | if self.normalize_attention:
269 | psi = torch.sum(attention, axis=2).squeeze(-1)
270 | psi[psi == 0.] = 1
271 | psi = psi.unsqueeze(-1).unsqueeze(-1)
272 | else:
273 | psi = 1.0
274 | else:
275 | attention = torch.ones(*relative_field.shape[0:3], 1, device=relative_field.device)
276 |
277 | if self.normalize_attention:
278 | psi = torch.sum(attention, axis=2).squeeze(-1)
279 | psi[psi == 0.] = 1
280 | psi = psi.unsqueeze(-1).unsqueeze(-1)
281 | else:
282 | psi = 1.0
283 |
284 | field_feat = field_feat.unsqueeze(1) - ctr_feat.unsqueeze(2)
285 |
286 | attention_field_feat = field_feat*attention.unsqueeze(-1)
287 | # attention_field_feat: [batch, num_m, num_n, c_in, 2]
288 |
289 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat)
290 | # out: [batch, num_m, c_out, 2]
291 |
292 | return out / psi
293 |
294 | def forward(
295 | self, field, center, field_feat,
296 | field_mask, ctr_feat
297 | ):
298 | out = self.ContinuousConv(
299 | field, center, field_feat, field_mask,
300 | ctr_feat
301 | )
302 | return out
303 |
--------------------------------------------------------------------------------
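A minimal usage sketch for the relative variant above (editorial; every size here is an assumption, not a value taken from the repo). The only difference from the base classes is that the center feature `ctr_feat` is subtracted from each neighbor feature before the kernel is applied, so the convolution acts on relative rather than absolute features:

import torch
from RelEquiCtsConv import RelEquiCtsConv2d

conv = RelEquiCtsConv2d(in_channels=8, out_channels=16, radius=2.0, num_radii=3, num_theta=16)
field = torch.rand(1, 60, 2)     # neighbor positions [batch, num_n, 2]
feat = torch.rand(1, 60, 8, 2)   # rho1 features, one 2-vector per channel
mask = torch.ones(1, 60, 1)      # all 60 particles valid
out = conv(field, field, feat, mask, feat)  # centers = field -> out: [1, 60, 16, 2]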
/models/EquiCtsConv.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 | from abc import ABCMeta, abstractmethod
7 | from EquiLinear import *
8 |
9 |
10 | class EquiCtsConvBase(nn.Module, metaclass=ABCMeta):
11 | def __init__(self):
12 | super(EquiCtsConvBase, self).__init__()
13 |
14 | @abstractmethod
15 | def computeKernel(self):
16 | pass
17 |
18 | def GetAttention(self, relative_field):
19 | r = torch.sum(relative_field ** 2, axis=-1)
20 | return torch.relu((1 - r) ** 3).unsqueeze(-1)
21 |
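# Editorial note: since `r` above is the squared norm of the radius-scaled offset,
# relu((1 - r)**3) is a smooth radial window that equals 1 at the center point and
# decays to 0 at the kernel radius, so neighbors outside the radius contribute nothing.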
22 | @classmethod
23 | def RotMat(cls, theta):
24 | m = torch.tensor([
25 | [torch.cos(theta), -torch.sin(theta)],
26 | [torch.sin(theta), torch.cos(theta)]
27 | ], requires_grad=False)
28 | return m
29 |
30 | @classmethod
31 | def Rho1RotMat(cls, theta):
32 | m = torch.tensor([
33 | [torch.cos(theta), -torch.sin(theta)],
34 | [torch.sin(theta), torch.cos(theta)]
35 | ], requires_grad=False)
36 | return m
37 |
38 | @classmethod
39 | def RegRotMat(cls, theta, k):
40 | slice_angle = 2 * math.pi / k
41 | index_shift = theta / slice_angle
42 | i = np.floor(index_shift).astype(int)
43 | # divide weights between i and i+1 first_col = [ 0 0 0 ... 0 w_i w_{i+1} 0 0 0 ... 0 0 0]
44 | first_col = torch.zeros(k)
45 |
46 | offset = (theta - slice_angle * i) / slice_angle
47 | w_i = 1 - offset
48 | w_ip = offset
49 | first_col[np.mod(i, k)], first_col[np.mod(i + 1, k)] = w_i, w_ip
50 |
51 | m = torch.stack([torch.roll(first_col, i, 0) for i in range(k)], -1)
52 | #like a permuation matrix which sends i -> i + \theta / ( 2 \pi / k )
53 | #Note if k = num_theta, then it is a permutation matrix
54 | return m
55 |
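# Editorial note: RegRotMat builds the regular-representation action of a rotation.
# For angles that are exact multiples of 2*pi/k it is a cyclic permutation matrix
# (e.g. k=4, theta=pi/2 shifts every channel by one slot); for intermediate angles
# the weight is split linearly between the two nearest shifts, giving a soft permutation.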
56 | def PolarCoords(self, vec, epsilon = 1e-9):
57 | # vec: [batch, num_m, num_n, pos_dim]
58 | # Convert to Polar
59 | r = torch.sqrt(vec[...,0] **2 + vec[...,1] **2 + epsilon)
60 |
61 | cond_nonzero = ~((vec[...,0] == 0.) & (vec[...,1] == 0.))
62 |
63 | theta = torch.zeros(vec[...,0].shape, device=vec.device)
64 | theta[cond_nonzero] = torch.atan2(vec[...,1][cond_nonzero], vec[...,0][cond_nonzero])
65 |
66 | out = [r, theta]
67 | out = torch.stack(out, -1)
68 | return out
69 |
70 | def InterpolateKernel(self, kernel, pos):
71 | """
72 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2] -> [batch, C=c_out*c_in*4, r, theta]
73 | @pos: [batch, num_m, num_n, 2] -> [batch, num_m, num_n, 2]
74 |
75 | return out: [batch, C=c_out*c_in*4, num_m, num_n] -> [batch, num_m, num_n, c_out, c_in, 2, 2]
76 | """
77 | # kernel: [c_out, c_in=feat_dim, r, theta, 2, 2]
78 | kernels = kernel.permute(0, 1, 4, 5, 2, 3)
79 | # kernels: [c_out, c_in=feat_dim, 2, 2, r, theta]
80 |
81 | kernels = kernels.reshape(-1, *kernels.shape[4:]).unsqueeze(0)
82 | # kernels: [1, c_out*c_in*2*2, r, theta]
83 |
84 | kernels = kernels.expand((pos.shape[0], *kernels.shape[1:]))
85 | # kernels: [batch_size, c_out*c_in*2*2, r, theta]
86 | #[N, C, H, W]
87 |
88 |
89 | # Copy first and last column to wrap thetas.
90 | padded_kernels = torch.cat([
91 | kernels[..., -1].unsqueeze(-1),
92 | kernels,
93 | kernels[..., 0].unsqueeze(-1)
94 | ],dim = -1)
95 | padded_kernels = padded_kernels.permute(0,1,3,2)
96 | # padded_kernels: [batch, C=c_out*c_in*4, theta+2, r]
97 |
98 |
99 | grid = pos.clone()  # clone so the caller's polar coordinates are not modified in-place
100 | # adjust radii [0,1] -> [-1,1]
101 | grid[...,0] = 2*grid[...,0] - 1
102 | # adjust angles [-pi,pi] -> [-1,1]
103 | grid[...,1] *= 1/math.pi
104 | # shrink thetas slightly to account for padding
105 | grid[...,1] *= self.num_theta / (self.num_theta + 2)
106 | # grid [batch, num_m, num_n, 2]
107 | # [N, H_out, W_out, 2]
108 |
109 | # print("grid",grid)
110 | # print("padded_kernels_shape [batch_size, c_out*c_in*2*2, theta+2, r]:",padded_kernels.shape)
111 | #print("kernels",padded_kernels)
112 |
113 | out = F.grid_sample(padded_kernels, grid, padding_mode='zeros',
114 | mode='bilinear', align_corners=False) #bilinear
115 | # out: [batch, C=c_out*c_in*4, num_m, num_n]
116 | # [N, C, H_out, W_out]
117 |
118 | out = out.permute(0, 2, 3, 1)
119 | # out: [batch, num_m, num_n, C=c_out*c_in*4]
120 | out = out.reshape(*pos.shape[:-1], *kernel.shape[0:2], *kernel.shape[-2:])
121 | # out: [batch, num_m, num_n, c_out, c_in, 2, 2]
122 | return out
123 |
124 | def ContinuousConv(
125 | self, field, center, field_feat,
126 | field_mask, ctr_feat=None
127 | ):
128 | """
129 | @kernel: [c_out, c_in=feat_dim, r, theta, 2, 2]
130 | @field: [batch, num_n, pos_dim=2] -> [batch, 1, num_n, pos_dim]
131 | @center: [batch, num_m, pos_dim=2] -> [batch, num_m, 1, pos_dim]
132 | @field_feat: [batch, num_n, c_in=feat_dim, 2] -> [batch, 1, num_n, c_in, 2]
133 | @ctr_feat: unused in the base class (kept for a uniform interface)
134 | @field_mask: [batch, num_n, 1]
135 | """
136 | kernel = self.computeKernel()
137 |
138 | relative_field = (field.unsqueeze(1) - center.unsqueeze(2)) / self.radius
139 | # relative_field: [batch, num_m, num_n, pos_dim]
140 |
141 |
142 | polar_field = self.PolarCoords(relative_field)
143 | # polar_field: [batch, num_m, num_n, pos_dim]
144 |
145 | kernel_on_field = self.InterpolateKernel(kernel, polar_field)
146 | # kernel_on_field: [batch, num_m, num_n, c_out, c_in, 2, 2]
147 |
148 | if self.use_attention:
149 | # print(relative_field.shape)
150 | # print(field_mask.unsqueeze(1).shape)
151 | attention = self.GetAttention(relative_field) * field_mask.unsqueeze(1)
152 | # attention: [batch, num_m, num_n, 1]
153 |
154 | if self.normalize_attention:
155 | psi = torch.sum(attention, axis=2).squeeze(-1)
156 | psi[psi == 0.] = 1
157 | psi = psi.unsqueeze(-1).unsqueeze(-1)
158 | else:
159 | psi = 1.0
160 | else:
161 | attention = torch.ones(*relative_field.shape[0:3], 1, device=relative_field.device)
162 |
163 | if self.normalize_attention:
164 | psi = torch.sum(attention, axis=2).squeeze(-1)
165 | psi[psi == 0.] = 1
166 | psi = psi.unsqueeze(-1).unsqueeze(-1)
167 | else:
168 | psi = 1.0
169 |
170 | attention_field_feat = field_feat.unsqueeze(1)*attention.unsqueeze(-1)
171 | # attention_field_feat: [batch, num_m, num_n, c_in, 2]
172 |
173 | out = torch.einsum('bmnoiyx,bmnix->bmoy', kernel_on_field, attention_field_feat)
174 | # out: [batch, num_m, c_out, 2]
175 |
176 | return out / psi
177 |
178 | def forward(
179 | self, field, center, field_feat,
180 | field_mask, ctr_feat=None
181 | ):
182 | out = self.ContinuousConv(
183 | field, center, field_feat, field_mask,
184 | ctr_feat
185 | )
186 | return out
187 |
188 |
189 | class EquiCtsConv2d(EquiCtsConvBase):
190 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, matrix_dim=2,
191 | use_attention=True, normalize_attention=True):
192 | super(EquiCtsConv2d, self).__init__()
193 | self.num_theta = num_theta
194 | self.num_radii = num_radii
195 |
196 | kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim)
197 | self.register_buffer('kernel_basis_outer', kernel_basis_outer)
198 | self.register_buffer('kernel_bullseye', kernel_bullseye)
199 |
200 | outer_weights = torch.rand(in_channels, out_channels, num_radii, matrix_dim, matrix_dim)
201 | outer_weights -= 0.5
202 | scale_norm = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float))
203 | outer_weights *= scale_norm
204 | self.outer_weights = torch.nn.parameter.Parameter(outer_weights)
205 | 
206 | bullseye_weights = torch.rand(in_channels, out_channels)
207 | bullseye_weights -= 0.5
208 | bullseye_weights *= scale_norm
209 | self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights)
210 |
211 | self.radius = radius
212 |
213 | self.use_attention = use_attention
214 | self.normalize_attention = normalize_attention
215 |
216 | def computeKernel(self):
217 | # print("[r, d, d, r, theta, d, d] ",self.kernel_basis.shape)
218 | # print("[c_in,c_out, r, d , d]", self.weights.shape)
219 | kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) +
220 | torch.einsum('rtij,xy->yxrtij',self.kernel_bullseye,self.bullseye_weights))
221 | return kernel
222 |
223 | def GenerateKernelBasis(self, r, theta, matrix_dim=2):
224 | """
225 | output: KB_outer : [r, d, d, r+1, theta, d, d]
226 | KB_bullseye : [r+1, theta, d, d]
227 | """
228 | d = matrix_dim
229 | KB_outer = torch.zeros(r, d, d, r+1, theta, d, d, requires_grad=False)
230 | K_bullseye = self.GenerateKernelBullseyeElement(r+1, theta, d)
231 |
232 | for i in range(d):
233 | for j in range(d):
234 | for r1 in range(0, r):
235 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d)
236 |
237 | return KB_outer, K_bullseye
238 |
239 | def GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2):
240 | """
241 | output: K: [r, theta, d, d]
242 | """
243 | d = matrix_dim
244 | K = torch.zeros(r, theta, d, d, requires_grad=False)
245 | K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d)
246 | return K
247 |
248 | def GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2):
249 | # d = matrix_dim
250 | # 0 <= i,j <= d-1
251 | # C = kernelcolumn: [theta, d, d]
252 | # C[0,:,:] = 0
253 | # C[0,i,j] = 1
254 | # for k in range(1,theta):
255 | # C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta)
256 | # # K[g v] = g K[v] g^{-1}
257 | d = matrix_dim
258 | C = torch.zeros(theta, d, d, requires_grad=False)
259 | C[0,i,j] = 1
260 | # TODO: rho 1 -> rho n
261 | for k in range(1, theta):
262 | theta_i = torch.tensor(k*2*math.pi/theta)
263 | C[k] = self.RotMat(theta_i).matmul(C[0]).matmul(self.RotMat(-theta_i))
264 | return C
265 |
266 | def GenerateKernelBullseyeElement(self, r, theta, matrix_dim=2):
267 | """
268 | output: K: [r, theta, d, d]
269 | """
270 | d = matrix_dim
271 | K = torch.zeros(r, theta, d, d, requires_grad=False)
272 | K[0] = self.GenerateKernelBullseyeElementColumn(theta, d)
273 | return K
274 |
275 | def GenerateKernelBullseyeElementColumn(self, theta, matrix_dim=2):
276 | d = matrix_dim
277 | C = torch.zeros(theta, d, d, requires_grad=False)
278 | C[:,0,0] = 1
279 | C[:,1,1] = 1
280 | return C
281 |
282 |
283 | class EquiCtsConv2dRegToRho1(EquiCtsConvBase):
284 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2,
285 | use_attention=True, normalize_attention=True):
286 | super(EquiCtsConv2dRegToRho1, self).__init__()
287 | self.num_theta = num_theta
288 | self.num_radii = num_radii
289 | self.k = k
290 |
291 | RegToRho1Mat = EquiLinearRegToRho1(k).RegToRho1 #[2,k]
292 | self.register_buffer('RegToRho1Mat', RegToRho1Mat)
293 |
294 |
295 | kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim)
296 | self.register_buffer('kernel_basis_outer', kernel_basis_outer)
297 | self.register_buffer('kernel_bullseye', kernel_bullseye)
298 |
299 | outer_weights = torch.rand(in_channels, out_channels, num_radii, matrix_dim, k)
300 | outer_weights -= 0.5
301 | scale_norm = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float))
302 | outer_weights *= scale_norm
303 | self.outer_weights = torch.nn.parameter.Parameter(outer_weights)
304 |
305 | bullseye_weights = torch.rand(in_channels, out_channels)
306 | bullseye_weights -= 0.5
307 | bullseye_weights *= scale_norm
308 | self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights)
309 |
310 | self.radius = radius
311 |
312 | self.use_attention = use_attention
313 | self.normalize_attention = normalize_attention
314 |
315 | def computeKernel(self):
316 | # print("[r, d, d, r, theta, d, k] ",self.kernel_basis.shape)
317 | # print("[c_in,c_out, r, d, k]", self.weights.shape)
318 | kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) +
319 | torch.einsum('rtij,xy->yxrtij',self.kernel_bullseye,self.bullseye_weights))
320 | return kernel
321 |
322 | def GenerateKernelBasis(self, r, theta, matrix_dim=2):
323 | """
324 | output: KB_outer : [r, d, k, r+1, theta, d, k]
325 | KB_bullseye : [r+1, theta, d, k]
326 | """
327 | d = matrix_dim
328 | k = self.k
329 |
330 | KB_outer = torch.zeros(r, d, k, r+1, theta, d, k, requires_grad=False)
331 | K_bullseye = self.GenerateKernelBullseyeElement(r+1, theta, d)
332 |
333 | for i in range(d):
334 | for j in range(k):
335 | for r1 in range(0, r):
336 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d)
337 |
338 | return KB_outer, K_bullseye
339 |
340 | def GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2):
341 | """
342 | output: K: [r, theta, d, k]
343 | """
344 | d = matrix_dim
345 | k = self.k
346 |
347 | K = torch.zeros(r, theta, d, k, requires_grad=False)
348 | K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d)
349 | return K
350 |
351 | def GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2):
352 | # d = matrix_dim
353 | # 0 <= i,j <= d-1
354 | # C = kernelcolumn: [theta, d, d]
355 | # C[0,:,:] = 0
356 | # C[0,i,j] = 1
357 | # for k in range(1,theta):
358 | # C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta)
359 | # # K[g v] = g K[v] g^{-1}
360 | d = matrix_dim
361 | k = self.k
362 |
363 | C = torch.zeros(theta, d, k, requires_grad=False)
364 | C[0,i,j] = 1
365 | # TODO: rho 1 -> rho n
366 | for ind in range(1, theta):
367 | theta_i = torch.tensor(ind*2*math.pi/theta)
368 | C[ind] = self.Rho1RotMat(theta_i).matmul(C[0]).matmul(self.RegRotMat(-theta_i.numpy(), k))
369 | return C
370 |
371 | def GenerateKernelBullseyeElement(self, r, theta, matrix_dim=2):
372 | """
373 | output: K: [r, theta, d, k]
374 | """
375 | d = matrix_dim
376 | k = self.k
377 |
378 | K = torch.zeros(r, theta, d, k, requires_grad=False)
379 | K[0] = self.GenerateKernelBullseyeElementColumn(theta, d)
380 | return K
381 |
382 | def GenerateKernelBullseyeElementColumn(self, theta, matrix_dim=2):
383 | d = matrix_dim
384 | k = self.k
385 |
386 | C = torch.zeros(theta, d, k, requires_grad=False)
387 | C[:] = self.RegToRho1Mat
388 | return C
389 |
390 |
391 | class EquiCtsConv2dRho1ToReg(EquiCtsConvBase):
392 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2,
393 | use_attention=True, normalize_attention=True):
394 | super(EquiCtsConv2dRho1ToReg, self).__init__()
395 | self.num_theta = num_theta
396 | self.num_radii = num_radii
397 | self.k = k
398 |
399 | Rho1ToRegMat = EquiLinearRho1ToReg(k).Rho1ToReg #[k,2]
400 | self.register_buffer('Rho1ToRegMat', Rho1ToRegMat)
401 |
402 | kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim)
403 | self.register_buffer('kernel_basis_outer', kernel_basis_outer)
404 | self.register_buffer('kernel_bullseye', kernel_bullseye)
405 |
406 | outer_weights = torch.rand(in_channels, out_channels, num_radii, k, matrix_dim)
407 | outer_weights -= 0.5
408 | scale_norm = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float))
409 | outer_weights *= scale_norm
410 | self.outer_weights = torch.nn.parameter.Parameter(outer_weights)
411 |
412 | bullseye_weights = torch.rand(in_channels, out_channels)
413 | bullseye_weights -= 0.5
414 | bullseye_weights *= scale_norm
415 | self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights)
416 |
417 | self.radius = radius
418 |
419 | self.use_attention = use_attention
420 | self.normalize_attention = normalize_attention
421 |
422 | def computeKernel(self):
423 | # print("[r, d, d, r, theta, k, d] ",self.kernel_basis.shape)
424 | # print("[c_in,c_out, r, k, d]", self.weights.shape)
425 | kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) +
426 | torch.einsum('rtij,xy->yxrtij',self.kernel_bullseye,self.bullseye_weights))
427 | return kernel
428 |
429 | def GenerateKernelBasis(self, r, theta, matrix_dim=2):
430 | """
431 | output: KB_outer : [r, k, d, r+1, theta, k, d]
432 | KB_bullseye : [r+1, theta, k, d]
433 | """
434 | d = matrix_dim
435 | k = self.k
436 |
437 | KB_outer = torch.zeros(r, k, d, r+1, theta, k, d, requires_grad=False)
438 | K_bullseye = self.GenerateKernelBullseyeElement(r+1, theta, d)
439 |
440 | for i in range(k):
441 | for j in range(d):
442 | for r1 in range(0, r):
443 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d)
444 |
445 | return KB_outer, K_bullseye
446 |
447 | def GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2):
448 | """
449 | output: K: [r, theta, k, d]
450 | """
451 | d = matrix_dim
452 | k = self.k
453 |
454 | K = torch.zeros(r, theta, k, d, requires_grad=False)
455 | K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d)
456 | return K
457 |
458 | def GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2):
459 | # d = matrix_dim
460 | # 0 <= i,j <= d-1
461 | # C = kernelcolumn: [theta, d, d]
462 | # C[0,:,:] = 0
463 | # C[0,i,j] = 1
464 | # for k in range(1,theta):
465 | # C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta)
466 | # # K[g v] = g K[v] g^{-1}
467 | d = matrix_dim
468 | k = self.k
469 |
470 | C = torch.zeros(theta, k, d, requires_grad=False)
471 | C[0,i,j] = 1
472 | # TODO: rho 1 -> rho n
473 | for ind in range(1, theta):
474 | theta_i = torch.tensor(ind*2*math.pi/theta)
475 | C[ind] = self.RegRotMat(theta_i.numpy(), k).matmul(C[0]).matmul(self.Rho1RotMat(-theta_i))
476 | return C
477 |
478 | def GenerateKernelBullseyeElement(self, r, theta, matrix_dim=2):
479 | """
480 | output: K: [r, theta, k, d]
481 | """
482 | d = matrix_dim
483 | k = self.k
484 |
485 | K = torch.zeros(r, theta, k, d, requires_grad=False)
486 | K[0] = self.GenerateKernelBullseyeElementColumn(theta, d)
487 | return K
488 |
489 | def GenerateKernelBullseyeElementColumn(self, theta, matrix_dim=2):
490 | d = matrix_dim
491 | k = self.k
492 |
493 | C = torch.zeros(theta, k, d, requires_grad=False)
494 | C[:] = self.Rho1ToRegMat
495 | return C
496 |
497 |
498 | class EquiCtsConv2dRegToReg(EquiCtsConvBase):
499 | def __init__(self, in_channels, out_channels, radius, num_radii, num_theta, k, matrix_dim=2,
500 | use_attention=True, normalize_attention=True):
501 | super(EquiCtsConv2dRegToReg, self).__init__()
502 | self.num_theta = num_theta
503 | self.num_radii = num_radii
504 | self.k = k
505 |
506 |
507 | kernel_basis_outer, kernel_bullseye = self.GenerateKernelBasis(num_radii, num_theta, matrix_dim)
508 | self.register_buffer('kernel_basis_outer', kernel_basis_outer)
509 | self.register_buffer('kernel_bullseye', kernel_bullseye)
510 |
511 | outer_weights = torch.rand(in_channels, out_channels, num_radii, k, k)
512 | outer_weights -= 0.5
513 | scale_norm = 1 / torch.sqrt(torch.tensor(in_channels, dtype=torch.float))
514 | outer_weights *= scale_norm
515 | self.outer_weights = torch.nn.parameter.Parameter(outer_weights)
516 |
517 |
518 |
519 | bullseye_weights = torch.rand(in_channels, out_channels, k)
520 | bullseye_weights -= 0.5
521 | bullseye_weights *= scale_norm
522 | self.bullseye_weights = torch.nn.parameter.Parameter(bullseye_weights)
523 |
524 | self.radius = radius
525 |
526 | self.use_attention = use_attention
527 | self.normalize_attention = normalize_attention
528 |
529 | def computeKernel(self):
530 | # print("[r, d, d, r, theta, d, k] ",self.kernel_basis.shape)
531 | # print("[c_in,c_out, r, d, k]", self.weights.shape)
532 | kernel = (torch.einsum('pabrtij,xypab->yxrtij',self.kernel_basis_outer, self.outer_weights) +
533 | torch.einsum('lrtij,xyl->yxrtij',self.kernel_bullseye,self.bullseye_weights))
534 | return kernel
535 |
536 | def GenerateKernelBasis(self, r, theta, matrix_dim=2):
537 | """
538 | output: KB_outer : [r, k, k, r+1, theta, k, k]
539 | KB_bullseye : [k, r+1, theta, k, k]
540 | """
541 | d = matrix_dim
542 | k = self.k
543 |
544 | KB_outer = torch.zeros(r, k, k, r+1, theta, k, k, requires_grad=False)
545 | K_bullseye = self.GenerateKernelBullseyeBasis(r+1, theta, d)
546 |
547 | for i in range(k):
548 | for j in range(k):
549 | for r1 in range(0, r):
550 | KB_outer[r1,i,j] = self.GenerateKernelBasisElement(r+1, theta, i, j, r1+1, d)
551 |
552 | return KB_outer, K_bullseye
553 |
554 | def GenerateKernelBasisElement(self, r, theta, i, j, r1, matrix_dim=2):
555 | """
556 | output: K: [r, theta, k, k]
557 | """
558 | d = matrix_dim
559 | k = self.k
560 |
561 | K = torch.zeros(r, theta, k, k, requires_grad=False)
562 | K[r1] = self.GenerateKernelBasisElementColumn(theta, i, j, d)
563 | return K
564 |
565 | def GenerateKernelBasisElementColumn(self, theta, i, j, matrix_dim=2):
566 | # d = matrix_dim
567 | # 0 <= i,j <= d-1
568 | # C = kernelcolumn: [theta, d, d]
569 | # C[0,:,:] = 0
570 | # C[0,i,j] = 1
571 | # for k in range(1,theta):
572 | # C[k] = RotMat(k*2*pi/theta) * C[0] * RotMat(-k*2*pi/theta)
573 | # # K[g v] = g K[v] g^{-1}
574 | d = matrix_dim
575 | k = self.k
576 |
577 | C = torch.zeros(theta, k, k, requires_grad=False)
578 | C[0,i,j] = 1
579 | # TODO: rho 1 -> rho n
580 | for ind in range(1, theta):
581 | theta_i = torch.tensor(ind*2*math.pi/theta)
582 | C[ind] = self.RegRotMat(theta_i.numpy(), k).matmul(C[0]).matmul(self.RegRotMat(-theta_i.numpy(), k))
583 | return C
584 |
585 | def GenerateKernelBullseyeBasis(self, r, theta, matrix_dim=2):
586 | """
587 | output: K: [k, r, theta, k, k]
588 | """
589 | d = matrix_dim
590 | k = self.k
591 |
592 | K = torch.zeros(k, r, theta, k, k, requires_grad=False)
593 | for l in range(k):
594 | K[l,0] = self.GenerateKernelBullseyeElementColumn(theta, l, d)
595 | return K
596 |
597 | def GenerateKernelBullseyeElementColumn(self, theta, l, matrix_dim=2):
598 | d = matrix_dim
599 | k = self.k
600 |
601 | first_col = torch.zeros(k)
602 | first_col[l] = 1.
603 | C = torch.zeros(theta, k, k, requires_grad=False)
604 | C[:] = torch.stack([torch.roll(first_col, i, 0) for i in range(0,self.k)],-1)
605 | return C
606 |
--------------------------------------------------------------------------------
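As a sanity check, a short editorial sketch that probes the rotation equivariance of EquiCtsConv2d numerically (module sizes are assumptions; the residual is only approximately zero because the kernel is bilinearly interpolated between the num_theta sampled angles):

import math
import torch
from EquiCtsConv import EquiCtsConv2d

conv = EquiCtsConv2d(in_channels=4, out_channels=4, radius=1.5, num_radii=3, num_theta=16)
theta = torch.tensor(2 * math.pi * 3 / 16)  # a multiple of the angular grid spacing
R = EquiCtsConv2d.RotMat(theta)

field = torch.rand(1, 10, 2)
center = torch.zeros(1, 1, 2)               # rotation about the origin fixes the center
feat = torch.rand(1, 10, 4, 2)
mask = torch.ones(1, 10, 1)

out = conv(field, center, feat, mask)
out_rot = conv(field @ R.T, center, feat @ R.T, mask)

# rho1 equivariance: rotating all inputs should rotate the output
print((out_rot - out @ R.T).abs().max())    # expected to be small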