├── .gitignore ├── LICENSE ├── README.md ├── policy_network └── neighborhood_v4_ddqn │ ├── eval.py │ ├── models.py │ ├── modules.py │ ├── train_tr5.py │ ├── utils.py │ └── visualize_episode.py └── road_interactions_environment ├── __init__.py ├── gym_road_interactions ├── __init__.py ├── core.py ├── envs │ ├── __init__.py │ ├── maps │ │ ├── neighborhood_v0_intersection_id_dict.pkl │ │ ├── neighborhood_v0_map_constants.pkl │ │ └── neighborhood_v0_map_lane_segments.pkl │ └── neighborhood_v4 │ │ ├── __init__.py │ │ ├── neighborhood_env_v4.py │ │ ├── neighborhood_env_v4_agents.py │ │ └── neighborhood_env_v4_utils.py ├── utils.py └── viz_utils.py ├── neighborhood_v4_collision_set_gen.py ├── neighborhood_v4_interaction_set_creation.ipynb └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. To view a copy of this license, visit http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIDAS: Multi-agent Interaction-aware Decision-making with Adaptive Strategies for Urban Autonomous Navigation 2 | Xiaoyi Chen, Pratik Chaudhari 3 | 4 | GRASP Lab, University of Pennsylvania 5 | 6 | arXiv: https://arxiv.org/abs/2008.07081 7 | 8 | ## Prepare the environment 9 | Install Argoverse following the instructions here: https://github.com/argoai/argoverse-api 10 | 11 | Install `ffmpeg`. 12 | 13 | ## Create collision and interaction sets 14 | 1. To create collision sets, change `na` on line 6 to the number of agents in the environment, and change `date` on line 7 of `road_interactions_environment/neighborhood_v4_collision_set_gen.py`. Then run `python road_interactions_environment/neighborhood_v4_collision_set_gen.py`. 15 | 2. To create interaction sets, follow the steps in `road_interactions_environment/neighborhood_v4_interaction_set_creation.ipynb`. 16 | 17 | ## Train the model 18 | In `policy_network/neighborhood_v4_ddqn/train_tr5.py`: 19 | 1. Change the filepaths on lines 237-243 to point to your generated collision sets, interaction set and evaluation set. 20 | 2. Change the environment and training hyperparameters from line 48 to 158 for your training purposes. The default values are for MIDAS. To run MLP, DeepSet, or SocialAttention with the same hyperparameters, simply change the value of `value_net` on line 119 to `vanilla`, `deep_set` or `social_attention`. 21 | 3. Run `python policy_network/neighborhood_v4_ddqn/train_tr5.py`. Arguments: 22 | ``` 23 | --date Training date 24 | --code ID of your experiment, e.g. c0-0 25 | --seed Experiment seed. Any integer between 0 and 65535. 26 | ``` 27 | 28 | ## Visualize an episode with a model checkpoint 29 | In `policy_network/neighborhood_v4_ddqn/visualize_episode.py`: 30 | 1. Update the variables on lines 199-218 to set the date, checkpoint ID and filepath, dataset filepath, and the IDs of the episodes that you want to visualize. 31 | 2. Run `python policy_network/neighborhood_v4_ddqn/visualize_episode.py` 32 | 33 | # Code References 34 | Argoverse https://github.com/argoai/argoverse-api 35 | 36 | Set Transformer https://github.com/juho-lee/set_transformer 37 | 38 | # License 39 | This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
To view a copy of this license, visit http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 40 | -------------------------------------------------------------------------------- /policy_network/neighborhood_v4_ddqn/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pdb 3 | import glob 4 | import math 5 | import argparse 6 | import time 7 | from datetime import datetime 8 | import io, sys, os 9 | import gc 10 | import _pickle as pickle 11 | import base64 12 | import argparse 13 | 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | 17 | import torch 18 | from torchsummary import summary 19 | import torch.optim as optim 20 | import torch.nn as nn 21 | from torch.autograd import Variable 22 | import torch.nn.functional as F 23 | from torch.utils import data 24 | 25 | from gym import logger as gymlogger 26 | from gym.wrappers import Monitor 27 | import gym 28 | 29 | import gym_road_interactions 30 | from neighborhood_v4_ddqn.models import * 31 | from neighborhood_v4_ddqn.utils import * 32 | 33 | # === Set up environment === 34 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 35 | logger = logging.getLogger(__name__) 36 | 37 | def eval(agent, env, train_configs, env_configs, ep_ids_to_skip, device, log_name, 38 | episode_lengths=False, collision_ep_agents_init=None, 39 | collision_folder_name=None, during_training=False): 40 | 41 | ni_ep_len_dict = {1: 200, 2: 200, 3: 200} 42 | total_reward = 0 43 | total_episode_lengths = 0 44 | total_dist_driven = 0 45 | episode_score = 0 46 | episode_length = 0 47 | collision_cnt = 0 48 | success_cnt = 0 # success means ego reaches the end before timeout 49 | avg_velocity = 0 50 | 51 | # how many saved episodes you wanna use 52 | num_saved_episodes = len(collision_ep_agents_init) 53 | episodes_run = 0 54 | 55 | for episode in range(1, num_saved_episodes + 1): 56 | episodes_run += 1 57 | # decide whether to reset using saved episodes or start fresh 58 | # The caller should instantiate the env with the correct env_config. 
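# (Descriptive note, inferred from how `info` is indexed later in this function:
#  env.step() is assumed to return `info` as a tuple where info[0] is the ego's
#  average velocity, info[1] is a collision flag, and info[2] is the distance
#  driven, apparently in meters since it is divided by 1000.0 below when
#  computing km per collision.)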
59 | # num_other_agents is different across episodes 60 | if episode <= num_saved_episodes and (collision_ep_agents_init is not None): # use collision_ep_agents_init 61 | ep_idx = episode - 1 62 | saved_env_config = collision_ep_agents_init[ep_idx]['env_config'] 63 | loaded_agents = collision_ep_agents_init[ep_idx]['agents_init'] 64 | ep_id = collision_ep_agents_init[ep_idx]['ep_id'] 65 | 66 | if ep_id in ep_ids_to_skip: 67 | continue 68 | 69 | # adjust ni, max_na, max_ep_len based on saved episode 70 | env.env_config_['ego_num_intersections_in_path'] = saved_env_config['ego_num_intersections_in_path'] 71 | env.env_config_['max_num_other_agents'] = saved_env_config['max_num_other_agents'] 72 | env.env_config_['max_episode_timesteps'] = ni_ep_len_dict[saved_env_config['ego_num_intersections_in_path']] 73 | train_configs['max_episode_timesteps'] = env.env_config_['max_episode_timesteps'] 74 | env.env_config_['agent_action_noise'] = saved_env_config['agent_action_noise'] 75 | log(log_name, 'Updated env_config: ' + str(env.env_config_)) 76 | 77 | # the saved env config for each episode should have these fields coincide with the env config during eval 78 | for key in ['agent_stochastic_stop', 'agent_shuffle_ids', 79 | 'expanded_lane_set_depth', 'c1', 'ego_expand_path_depth', 80 | 'single_intersection_type','ego_velocity_coeff','agent_velocity_coeff']: # 'ego_expand_path_depth' and 'c1' not tested for 0313 eval sets 81 | if ((key not in env.env_config_) and (key in saved_env_config)) or \ 82 | ((key in env.env_config_) and (key not in saved_env_config)) or \ 83 | (env.env_config_[key] != saved_env_config[key]): 84 | raise Exception(f'eval env_config key mismatch {key}. saved: {saved_env_config[key]}. env: {env.env_config_[key]}') 85 | 86 | # add this check to avoid the issue of incorrect save env config 87 | if len(loaded_agents) != (saved_env_config['num_other_agents'] + 1): 88 | env.env_config_['num_other_agents'] = len(loaded_agents) - 1 89 | else: 90 | env.env_config_['num_other_agents'] = saved_env_config['num_other_agents'] 91 | parametric_state, ego_b1 = env.reset(use_saved_agents=loaded_agents) 92 | 93 | # debug 94 | num_other_agents = env.env_config_['num_other_agents'] 95 | max_num_other_agents = env.env_config_['max_num_other_agents'] 96 | 97 | record_str = f"[Eval Episode {episode} | episodes_run={episodes_run} | ep_id={ep_id}] num other agents: {num_other_agents}. 
max_num_other_agents: {max_num_other_agents}" 98 | log(log_name, record_str) 99 | if not during_training: 100 | print(record_str) 101 | else: 102 | parametric_state, ego_b1 = env.reset() 103 | log(log_name, f"[Eval episode {episode}] Fresh reset") 104 | 105 | done = False 106 | episode_length = 0 107 | episode_score = 0 108 | 109 | while (done == False) and (episode_length < env.env_config_['max_episode_timesteps']): 110 | # select action 111 | if train_configs['num_future_states'] > 0: 112 | parametric_state_till_now = truncate_state_till_now(parametric_state, env_configs) 113 | else: 114 | parametric_state_till_now = parametric_state 115 | parametric_state_ts = torch.from_numpy(parametric_state_till_now).unsqueeze(0).float().to(device) 116 | action = agent.select_action(parametric_state_ts, ego_b1, 0, test=True) 117 | next_state, reward, done, info = env.step(action) # step processed action 118 | action_str = "selected action: " + str(action) 119 | log(log_name, f'[Eval episode {episode} ts={episode_length}] reward={reward:.1f} | {action_str}') 120 | 121 | parametric_state = next_state 122 | 123 | episode_score += reward 124 | episode_length += 1 125 | 126 | if (info[1] == True): 127 | collision_cnt += 1 128 | log(log_name, f"[Eval episode {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | collision") 129 | if not during_training: 130 | print(f"[Eval episode {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | collision") 131 | if collision_folder_name is not None: 132 | if not os.path.exists(f'collision/{collision_folder_name}'): 133 | os.mkdir(f'collision/{collision_folder_name}') 134 | if not os.path.exists(f'collision/{collision_folder_name}/{episode}'): 135 | os.mkdir(f'collision/{collision_folder_name}/{episode}') 136 | env.render(f'collision/{collision_folder_name}/{episode}') 137 | elif done == True: 138 | success_cnt += 1 139 | log(log_name, f"[Eval episode {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | success") 140 | if not during_training: 141 | print(f"[Eval episode {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | success") 142 | else: 143 | log(log_name, f"[Eval episode {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | timeout") 144 | if not during_training: 145 | print(f"[Eval episode {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | timeout") 146 | 147 | total_reward += episode_score 148 | avg_velocity += info[0] 149 | total_episode_lengths += episode_length 150 | total_dist_driven += info[2] 151 | 152 | avg_score = total_reward / episodes_run 153 | avg_velocity = avg_velocity / episodes_run 154 | collision_pct = collision_cnt / episodes_run * 100 155 | success_pct = success_cnt / episodes_run * 100 156 | avg_episode_lengths = total_episode_lengths / episodes_run 157 | if collision_cnt == 0: 158 | km_per_collision = total_dist_driven / 1000.0 159 | else: 160 | km_per_collision = (total_dist_driven / 1000.0) / collision_cnt 161 | 162 | if episode_lengths: 163 | return avg_score, collision_pct, success_pct, avg_velocity, km_per_collision, avg_episode_lengths 164 | else: 165 | return avg_score, collision_pct, success_pct, avg_velocity, km_per_collision 166 | -------------------------------------------------------------------------------- 
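A minimal sketch of how the `eval` routine above could be driven offline with a saved evaluation set. It is illustrative only (the actual entry points are `train_tr5.py` and `visualize_episode.py`); `env_configs`, `train_configs`, `eval_set`, the checkpoint path and the log path are assumed to be set up as in `train_tr5.py` and are not exact values from this repository.

```python
# Illustrative sketch only -- not an exact invocation from this repository.
import gym
import torch
import gym_road_interactions  # registers the Neighborhood-v4 environment
from neighborhood_v4_ddqn.models import TwinDDQNAgent
from neighborhood_v4_ddqn.eval import eval

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
log_name = './logs/eval_demo.log'  # hypothetical log file

env = gym.make('Neighborhood-v4')
env.set_env_config(env_configs)                          # dicts as defined in train_tr5.py
env.set_train_config_and_device(train_configs, device)

agent = TwinDDQNAgent(train_configs, device, log_name)
agent.load('./checkpoints/your_checkpoint.pt')           # hypothetical checkpoint path

# eval_set would be loaded with load_obj(eval_set_path), as in train_tr5.py
avg_score, collision_pct, success_pct, avg_vel, km_per_coll = eval(
    agent, env, train_configs, env_configs, ep_ids_to_skip=[],
    device=device, log_name=log_name, collision_ep_agents_init=eval_set)
```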
/policy_network/neighborhood_v4_ddqn/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pdb 3 | import glob 4 | import math 5 | import time 6 | import random 7 | from datetime import datetime 8 | import io, sys, os, copy 9 | import base64 10 | from collections import deque 11 | 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | 15 | import torch 16 | from torchsummary import summary 17 | import torch.optim as optim 18 | import torch.nn as nn 19 | from torch.autograd import Variable 20 | import torch.nn.functional as F 21 | from torch.utils import data 22 | import torch.optim as optim 23 | from torch.distributions.categorical import Categorical 24 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 25 | 26 | from gym import logger as gymlogger 27 | from gym.wrappers import Monitor 28 | import gym 29 | import gym_road_interactions 30 | 31 | from neighborhood_v4_ddqn.utils import ReplayBuffer, log 32 | from neighborhood_v4_ddqn.modules import * 33 | 34 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 35 | logger = logging.getLogger(__name__) 36 | 37 | class DQN(nn.Module): 38 | """ 39 | Network that has a backbone MLP and a head for value estimation for each action 40 | """ 41 | def __init__(self, train_configs): 42 | super(DQN, self).__init__() 43 | self.train_configs = train_configs 44 | 45 | self.s_encoder = nn.Sequential( 46 | nn.Linear(self.train_configs['state_dim'], 128), 47 | nn.ReLU(), 48 | nn.Linear(128,128) 49 | ) 50 | 51 | self.b1_fc = nn.Sequential( 52 | nn.Linear(1,64), 53 | nn.ReLU(), 54 | nn.Linear(64,128) 55 | ) 56 | 57 | self.int_layer = nn.Sequential( 58 | nn.Linear(128,128), 59 | nn.ReLU(), 60 | nn.Linear(128,128) 61 | ) 62 | 63 | # Value estimation 64 | self.v_fc = nn.Linear(128, self.train_configs['action_dim']) 65 | 66 | def forward(self, s, b1): 67 | h = self.s_encoder(s).view(s.size(0),-1) # bsize*128 68 | z_b1 = self.b1_fc(b1) # bsize*128 69 | z = self.int_layer(h + z_b1) # bsize*128 70 | # value 71 | v = self.v_fc(z) 72 | return v 73 | 74 | class TwinDQN(nn.Module): 75 | """ 76 | Network that has a backbone MLP and 2 heads for value estimations 77 | """ 78 | def __init__(self, train_configs, input_size=None): 79 | super(TwinDQN, self).__init__() 80 | self.train_configs = train_configs 81 | 82 | if input_size is None: 83 | self.input_size = self.train_configs['agent_total_state_dim'] * (self.train_configs['max_num_other_agents_in_range'] + 1) 84 | else: 85 | self.input_size = input_size 86 | 87 | self.enc = nn.Sequential( 88 | nn.Linear(self.input_size, 128), 89 | nn.ReLU(), 90 | nn.Linear(128,128) 91 | ) 92 | 93 | self.b1_fc = nn.Sequential( 94 | nn.Linear(1,64), 95 | nn.ReLU(), 96 | nn.Linear(64,128) 97 | ) 98 | 99 | self.v_fc1 = nn.Sequential( 100 | nn.Linear(128,128), 101 | nn.ReLU(), 102 | nn.Linear(128, self.train_configs['action_dim']) 103 | ) 104 | 105 | self.v_fc2 = nn.Sequential( 106 | nn.Linear(128,128), 107 | nn.ReLU(), 108 | nn.Linear(128, self.train_configs['action_dim']) 109 | ) 110 | 111 | def forward(self, o, b1): 112 | z_o = self.enc(o).view(o.size(0),-1) # bsize*128 113 | z_b1 = self.b1_fc(b1) # bsize*128 114 | q1 = self.v_fc1(z_o + z_b1) 115 | q2 = self.v_fc2(z_o + z_b1) 116 | return q1, q2 117 | 118 | def Q1(self, o, b1): 119 | z_o = self.enc(o).view(o.size(0),-1) # bsize*128 120 | z_b1 = self.b1_fc(b1) # bsize*128 121 | q1 = self.v_fc1(z_o + z_b1) 122 | return q1 123 | 124 | def encode(self, o, b1): 125 | z_o = 
self.enc(o).view(o.size(0),-1) # bsize*128 126 | z_b1 = self.b1_fc(b1) # bsize*128 127 | return (z_o + z_b1) # bsize*128 128 | 129 | def heads(self, h): 130 | # h: bsize*128 131 | q1 = self.v_fc1(h) 132 | q2 = self.v_fc2(h) 133 | return q1, q2 134 | 135 | def head_Q1(self, h): 136 | q1 = self.v_fc1(h) 137 | return q1 138 | 139 | class DeepSet(nn.Module): 140 | """ 141 | Network that uses MLP to encode every agent state and combine them using mean or max pooling 142 | and a head for value estimation for each action 143 | """ 144 | def __init__(self, train_configs): 145 | super(DeepSet, self).__init__() 146 | self.train_configs = train_configs 147 | 148 | # encode each agent state vector 149 | self.s_encoder = nn.Sequential( 150 | nn.Linear(self.train_configs['agent_total_state_dim'], 64), 151 | nn.ReLU(), 152 | nn.Linear(64,64), 153 | nn.ReLU(), 154 | nn.Linear(64,128) 155 | ) 156 | 157 | self.b1_fc = nn.Sequential( 158 | nn.Linear(1,64), 159 | nn.ReLU(), 160 | nn.Linear(64,128) 161 | ) 162 | 163 | self.int_layer = nn.Sequential( 164 | nn.Linear(128,128), 165 | nn.ReLU(), 166 | nn.Linear(128,128) 167 | ) 168 | 169 | # value estimation using the mean or max-pooled state vector 170 | self.v_fc1 = nn.Sequential( 171 | nn.Linear(128,128), 172 | nn.ReLU(), 173 | nn.Linear(128, self.train_configs['action_dim']) 174 | ) 175 | 176 | def forward(self, s, b1): 177 | # s: bsize * state_dim 178 | s = s.view(s.size(0), -1, self.train_configs['agent_total_state_dim']) # bsize * max_num_agents * agent_state_dim 179 | z = self.s_encoder(s) # bsize * max_num_agents * 64 180 | z_b1 = self.b1_fc(b1).unsqueeze(1) # bsize * 1 * 128 181 | z_0 = z[:,0,:].unsqueeze(1) + z_b1 182 | z = torch.cat([z_0, z[:,1:,:]], dim=1) # bsize * max_num_agents * 128 183 | z = self.int_layer(z) # bsize * max_num_agents * 128 184 | if self.train_configs['pooling'] == 'mean': 185 | z = z.mean(-2) # bsize * 128 186 | elif self.train_configs['pooling'] == 'max': 187 | z = z.max(-2)[0] # bsize * 128 188 | # value 189 | q1 = self.v_fc1(z) 190 | return q1 191 | 192 | class TwinDeepSet(nn.Module): 193 | """ 194 | Network that uses MLP to encode every agent state and combine them using mean or max pooling 195 | and a head for value estimation for each action 196 | """ 197 | def __init__(self, train_configs): 198 | super(TwinDeepSet, self).__init__() 199 | self.train_configs = train_configs 200 | 201 | # encode each agent state vector 202 | self.s_encoder = nn.Sequential( 203 | nn.Linear(self.train_configs['agent_state_dim'], 64), 204 | nn.ReLU(), 205 | nn.Linear(64,64), 206 | nn.ReLU(), 207 | nn.Linear(64,128) 208 | ) 209 | 210 | self.b1_fc = nn.Sequential( 211 | nn.Linear(1,64), 212 | nn.ReLU(), 213 | nn.Linear(64,128) 214 | ) 215 | 216 | self.int_layer = nn.Sequential( 217 | nn.Linear(128,128), 218 | nn.ReLU(), 219 | nn.Linear(128,128) 220 | ) 221 | 222 | # value estimation using the mean or max-pooled state vector 223 | self.v_fc1 = nn.Sequential( 224 | nn.Linear(128,128), 225 | nn.ReLU(), 226 | nn.Linear(128, self.train_configs['action_dim']) 227 | ) 228 | 229 | self.v_fc2 = nn.Sequential( 230 | nn.Linear(128,128), 231 | nn.ReLU(), 232 | nn.Linear(128, self.train_configs['action_dim']) 233 | ) 234 | 235 | def forward(self, s, b1): 236 | # s: bsize * state_dim 237 | s = s.view(s.size(0), -1, self.train_configs['agent_state_dim']) # bsize * max_num_agents * agent_state_dim 238 | z = self.s_encoder(s) # bsize * max_num_agents * 64 239 | z_b1 = self.b1_fc(b1).unsqueeze(1) # bsize * 1 * 128 240 | z_0 = z[:,0,:].unsqueeze(1) + z_b1 241 | z = 
torch.cat([z_0, z[:,1:,:]], dim=1) # bsize * max_num_agents * 128 242 | z = self.int_layer(z) # bsize * max_num_agents * 128 243 | if self.train_configs['pooling'] == 'mean': 244 | z = z.mean(-2) # bsize * 128 245 | elif self.train_configs['pooling'] == 'max': 246 | z = z.max(-2)[0] # bsize * 128 247 | # value 248 | q1 = self.v_fc1(z) 249 | q2 = self.v_fc2(z) 250 | return q1, q2 251 | 252 | def Q1(self, s, b1): 253 | # s: bsize * state_dim 254 | s = s.view(s.size(0), -1, self.train_configs['agent_state_dim']) # bsize * max_num_agents * agent_state_dim 255 | z = self.s_encoder(s) # bsize * max_num_agents * 64 256 | z_b1 = self.b1_fc(b1).unsqueeze(1) # bsize * 1 * 128 257 | z_0 = z[:,0,:].unsqueeze(1) + z_b1 258 | z = torch.cat([z_0, z[:,1:,:]], dim=1) # bsize * max_num_agents * 128 259 | z = self.int_layer(z) # bsize * max_num_agents * 128 260 | if self.train_configs['pooling'] == 'mean': 261 | z = z.mean(-2) # bsize * 128 262 | elif self.train_configs['pooling'] == 'max': 263 | z = z.max(-2)[0] # bsize * 128 264 | # value 265 | q1 = self.v_fc1(z) 266 | return q1 267 | 268 | class SetTransformer(nn.Module): 269 | def __init__(self, train_configs, input_size=None): 270 | super(SetTransformer, self).__init__() 271 | self.train_configs = train_configs 272 | ln = train_configs['layer_norm'] 273 | dim_hidden = 128 274 | self.dim_hidden = dim_hidden 275 | num_heads = 4 276 | num_inds = 32 277 | 278 | if input_size is None: 279 | self.input_size = self.train_configs['agent_total_state_dim'] 280 | else: 281 | self.input_size = input_size 282 | 283 | self.b1_fc = nn.Sequential( 284 | nn.Linear(1,64), 285 | nn.ReLU(), 286 | nn.Linear(64,128), 287 | nn.ReLU(), 288 | nn.Linear(128,128) 289 | ) 290 | 291 | if ('layers' in list(self.train_configs.keys())) and (self.train_configs['layers'] == 1): 292 | self.enc = nn.Sequential( 293 | ISAB(dim_in=self.input_size, dim_out=dim_hidden, 294 | num_heads=num_heads, num_inds=num_inds, ln=ln)) 295 | self.dec = nn.Sequential( 296 | PMA(dim=dim_hidden, num_heads=num_heads, 297 | num_seeds=self.train_configs['action_dim'], ln=ln), 298 | SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 299 | nn.Linear(dim_hidden, 1)) 300 | if ('model' in list(self.train_configs.keys())) and (self.train_configs['model'] == 'TwinDDQN'): 301 | self.dec2 = nn.Sequential( 302 | PMA(dim=dim_hidden, num_heads=num_heads, 303 | num_seeds=self.train_configs['action_dim'], ln=ln), 304 | SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 305 | nn.Linear(dim_hidden, 1)) 306 | else: 307 | self.enc = nn.Sequential( 308 | ISAB(dim_in=self.input_size, dim_out=dim_hidden, 309 | num_heads=num_heads, num_inds=num_inds, ln=ln), 310 | ISAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, num_inds=num_inds, ln=ln)) 311 | self.dec = nn.Sequential( 312 | PMA(dim=dim_hidden, num_heads=num_heads, 313 | num_seeds=self.train_configs['action_dim'], ln=ln), 314 | SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 315 | SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 316 | nn.Linear(dim_hidden, 1)) 317 | if ('model' in list(self.train_configs.keys())) and (self.train_configs['model'] == 'TwinDDQN'): 318 | self.dec2 = nn.Sequential( 319 | PMA(dim=dim_hidden, num_heads=num_heads, 320 | num_seeds=self.train_configs['action_dim'], ln=ln), 321 | SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 322 | SAB(dim_in=dim_hidden, dim_out=dim_hidden, num_heads=num_heads, ln=ln), 323 | nn.Linear(dim_hidden, 1))
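# Architecture note (following the cited set_transformer reference): the ISAB
# blocks in self.enc apply induced self-attention over the per-agent tokens,
# PMA in self.dec pools them into `action_dim` learned seed vectors, and the
# final nn.Linear(dim_hidden, 1) maps each seed to a scalar, so dec(z) reshapes
# into one Q-value per action in forward()/Q1() below. The ego token z[:, 0, :]
# is shifted by the b1 embedding before decoding; dec2 provides the second
# Q-head when the TwinDDQN variant is configured.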
324 | 325 | def forward(self, s, b1): 326 | # s: bsize * state_dim 327 | # b1: bsize * 1 328 | s = s.view(s.size(0), -1, self.input_size) 329 | z = self.enc(s) 330 | z_b1 = self.b1_fc(b1).unsqueeze(1) 331 | z_0 = z[:,0,:].unsqueeze(1) + z_b1 332 | z = torch.cat([z_0, z[:,1:,:]], dim=1) # bsize * max_num_agents * agent_state_dim 333 | q1 = self.dec(z).view(s.size(0), self.train_configs['action_dim']) # bsize * 2 334 | if ('model' in list(self.train_configs.keys())) and (self.train_configs['model'] == 'TwinDDQN'): 335 | q2 = self.dec2(z).view(s.size(0), self.train_configs['action_dim']) # bsize * 2 336 | return q1, q2 337 | else: 338 | return q1 339 | 340 | def Q1(self, s, b1): 341 | # s: bsize * state_dim 342 | # b1: bsize * 1 343 | s = s.view(s.size(0), -1, self.input_size) 344 | z = self.enc(s) 345 | z_b1 = self.b1_fc(b1).unsqueeze(1) 346 | z_0 = z[:,0,:].unsqueeze(1) + z_b1 347 | z = torch.cat([z_0, z[:,1:,:]], dim=1) # bsize * max_num_agents * agent_state_dim 348 | q1 = self.dec(z).view(s.size(0), self.train_configs['action_dim']) # bsize * 2 349 | return q1 350 | 351 | class SocialAttention(nn.Module): 352 | def __init__(self, train_configs): 353 | super(SocialAttention, self).__init__() 354 | self.train_configs = train_configs 355 | dim_hidden = 64 356 | num_heads = 2 357 | 358 | self.enc = nn.Sequential( 359 | nn.Linear(self.train_configs['agent_state_dim'], dim_hidden), 360 | nn.ReLU(), 361 | nn.Linear(dim_hidden, dim_hidden), 362 | nn.ReLU(), 363 | nn.Linear(dim_hidden, dim_hidden)) 364 | 365 | self.b1_fc = nn.Sequential( 366 | nn.Linear(1,dim_hidden), 367 | nn.ReLU(), 368 | nn.Linear(dim_hidden,dim_hidden), 369 | nn.ReLU(), 370 | nn.Linear(dim_hidden,dim_hidden) 371 | ) 372 | 373 | self.int_layer = nn.Sequential( 374 | nn.Linear(dim_hidden,dim_hidden), 375 | nn.ReLU(), 376 | nn.Linear(dim_hidden,dim_hidden) 377 | ) 378 | 379 | self.attention_module = EgoAttention(dim_hidden, num_heads) 380 | 381 | self.dec = nn.Sequential( 382 | nn.Linear(dim_hidden, dim_hidden), 383 | nn.ReLU(), 384 | nn.Linear(dim_hidden, self.train_configs['action_dim'])) 385 | 386 | def forward(self, s, b1): 387 | # s: bsize * state_dim 388 | s = s.view(s.size(0), -1, self.train_configs['agent_state_dim']) 389 | ego_s = s[:,0,:].view(s.size(0),1,self.train_configs['agent_state_dim']) 390 | 391 | z_b1 = self.b1_fc(b1).unsqueeze(1) # bsize * dim_hidden 392 | 393 | h_ego_s = self.enc(ego_s) 394 | z_ego_s = h_ego_s + z_b1 395 | z_ego_s = self.int_layer(z_ego_s) 396 | 397 | h_s = self.enc(s) 398 | z_0 = h_s[:,0,:].unsqueeze(1) + z_b1 399 | z_s = torch.cat([z_0, h_s[:,1:,:]], dim=1) # bsize * max_num_agents * dim_hidden 400 | z_s = self.int_layer(z_s) 401 | 402 | rst = self.attention_module(z_ego_s, z_s) 403 | rst = self.dec(rst) 404 | 405 | return rst.view(s.size(0), self.train_configs['action_dim']) # bsize * 2 406 | 407 | class TwinSocialAttention(nn.Module): 408 | def __init__(self, train_configs): 409 | super(TwinSocialAttention, self).__init__() 410 | self.train_configs = train_configs 411 | dim_hidden = 64 412 | num_heads = 2 413 | 414 | self.enc = nn.Sequential( 415 | nn.Linear(self.train_configs['agent_total_state_dim'], dim_hidden), 416 | nn.ReLU(), 417 | nn.Linear(dim_hidden, dim_hidden), 418 | nn.ReLU(), 419 | nn.Linear(dim_hidden, dim_hidden)) 420 | 421 | self.b1_fc = nn.Sequential( 422 | nn.Linear(1,dim_hidden), 423 | nn.ReLU(), 424 | nn.Linear(dim_hidden,dim_hidden), 425 | nn.ReLU(), 426 | nn.Linear(dim_hidden,dim_hidden) 427 | ) 428 | 429 | self.int_layer = nn.Sequential( 430 | 
nn.Linear(dim_hidden,dim_hidden), 431 | nn.ReLU(), 432 | nn.Linear(dim_hidden,dim_hidden) 433 | ) 434 | 435 | self.attention_module = EgoAttention(dim_hidden, num_heads) 436 | 437 | self.dec1 = nn.Sequential( 438 | nn.Linear(dim_hidden, dim_hidden), 439 | nn.ReLU(), 440 | nn.Linear(dim_hidden, self.train_configs['action_dim'])) 441 | self.dec2 = nn.Sequential( 442 | nn.Linear(dim_hidden, dim_hidden), 443 | nn.ReLU(), 444 | nn.Linear(dim_hidden, self.train_configs['action_dim'])) 445 | 446 | def forward(self, s, b1): 447 | # s: bsize * state_dim 448 | s = s.view(s.size(0), -1, self.train_configs['agent_total_state_dim']) 449 | ego_s = s[:,0,:].view(s.size(0),1,self.train_configs['agent_total_state_dim']) 450 | 451 | z_b1 = self.b1_fc(b1).unsqueeze(1) # bsize * dim_hidden 452 | 453 | h_ego_s = self.enc(ego_s) 454 | z_ego_s = h_ego_s + z_b1 455 | z_ego_s = self.int_layer(z_ego_s) 456 | 457 | h_s = self.enc(s) 458 | z_0 = h_s[:,0,:].unsqueeze(1) + z_b1 459 | z_s = torch.cat([z_0, h_s[:,1:,:]], dim=1) # bsize * max_num_agents * dim_hidden 460 | z_s = self.int_layer(z_s) 461 | 462 | rst = self.attention_module(z_ego_s, z_s) 463 | 464 | q1 = self.dec1(rst).view(s.size(0), self.train_configs['action_dim']) 465 | q2 = self.dec2(rst).view(s.size(0), self.train_configs['action_dim']) 466 | 467 | return q1, q2 468 | 469 | def Q1(self, s, b1): 470 | # s: bsize * state_dim 471 | s = s.view(s.size(0), -1, self.train_configs['agent_total_state_dim']) 472 | ego_s = s[:,0,:].view(s.size(0),1,self.train_configs['agent_total_state_dim']) 473 | 474 | z_b1 = self.b1_fc(b1).unsqueeze(1) # bsize * dim_hidden 475 | 476 | h_ego_s = self.enc(ego_s) 477 | z_ego_s = h_ego_s + z_b1 478 | z_ego_s = self.int_layer(z_ego_s) 479 | 480 | h_s = self.enc(s) 481 | z_0 = h_s[:,0,:].unsqueeze(1) + z_b1 482 | z_s = torch.cat([z_0, h_s[:,1:,:]], dim=1) # bsize * max_num_agents * dim_hidden 483 | z_s = self.int_layer(z_s) 484 | 485 | rst = self.attention_module(z_ego_s, z_s) 486 | 487 | q1 = self.dec1(rst).view(s.size(0), self.train_configs['action_dim']) 488 | 489 | return q1 490 | 491 | # TwinDDQN agent 492 | class TwinDDQNAgent(object): 493 | def __init__(self, train_configs, device, log_name): 494 | self.steps_done = 0 # not concerned with exploration steps 495 | self.train_configs = train_configs 496 | self.replay_buffer = ReplayBuffer(int(train_configs['max_buffer_size'])) 497 | self.log_name = log_name 498 | 499 | self.device = device 500 | if 'use_lstm' in list(self.train_configs.keys()) and self.train_configs['use_lstm']: 501 | log(self.log_name, f"Using LSTMValueNet") 502 | self.value_net = LSTMValueNet(train_configs, self.device).to(self.device) 503 | self.value_net_target = LSTMValueNet(train_configs, self.device).to(self.device) 504 | else: 505 | if 'value_net' in list(self.train_configs.keys()): 506 | if self.train_configs['value_net'] == 'vanilla': 507 | self.value_net = TwinDQN(train_configs).to(self.device) 508 | self.value_net_target = TwinDQN(train_configs).to(self.device) 509 | elif self.train_configs['value_net'] == 'deep_set': 510 | self.value_net = TwinDeepSet(train_configs).to(self.device) 511 | self.value_net_target = TwinDeepSet(train_configs).to(self.device) 512 | elif self.train_configs['value_net'] == 'deeper_deep_set': 513 | raise Exception('Not implemented') 514 | elif self.train_configs['value_net'] == 'set_transformer': 515 | self.value_net = SetTransformer(train_configs).to(self.device) 516 | self.value_net_target = SetTransformer(train_configs).to(self.device) 517 | elif 
self.train_configs['value_net'] == 'social_attention': 518 | self.value_net = TwinSocialAttention(train_configs).to(self.device) 519 | self.value_net_target = TwinSocialAttention(train_configs).to(self.device) 520 | else: 521 | value_net_type = self.train_configs['value_net'] 522 | raise Exception(f'Invalid value_net type: {value_net_type}') 523 | else: 524 | self.value_net = TwinDQN(train_configs).to(self.device) 525 | self.value_net_target = TwinDQN(train_configs).to(self.device) 526 | 527 | self.value_net_target.load_state_dict(self.value_net.state_dict()) 528 | self.value_net_target.eval() 529 | self.value_net_optim = optim.Adam(self.value_net.parameters(), lr=self.train_configs['lrt']) 530 | 531 | def select_action(self, parametric_state, b1, total_timesteps, test=False): 532 | if test or (total_timesteps > self.train_configs['exploration_timesteps']): 533 | # parametric_state: 1*(state_dim), tensor 534 | parametric_state = parametric_state.to(self.device) 535 | b1 = torch.tensor([[b1]]).to(self.device) # (1,1) 536 | 537 | sample = np.random.random() 538 | if test: 539 | eps_threshold = self.train_configs['eps_end'] 540 | else: 541 | eps_threshold = self.train_configs['eps_end'] + (self.train_configs['eps_start'] - self.train_configs['eps_end']) * \ 542 | math.exp(-1. * total_timesteps / self.train_configs['eps_decay']) 543 | if sample > eps_threshold: 544 | with torch.no_grad(): 545 | Q1, Q2 = self.value_net(parametric_state, b1) 546 | action = (Q1 + Q2).max(1)[1].item() 547 | 548 | return action 549 | else: 550 | return np.random.randint(2) 551 | else: 552 | return np.random.randint(2) 553 | 554 | def __dropout_future_states(self, s, s_next): 555 | sample = np.random.uniform() 556 | if sample <= self.train_configs['future_state_dropout']: 557 | log(self.log_name, f"Dropout s, s_next") 558 | s = copy.deepcopy(s) 559 | s_next = copy.deepcopy(s_next) 560 | 561 | s = np.reshape(s, (-1, self.train_configs['num_ts_in_state'], self.train_configs['agent_state_dim'])) 562 | s[:,-(self.train_configs['num_future_states']):,:] = 0 563 | s = np.reshape(s, (-1, self.train_configs['state_dim'])) 564 | s_next = np.reshape(s_next, (-1, self.train_configs['num_ts_in_state'], self.train_configs['agent_state_dim'])) 565 | s_next[:,-(self.train_configs['num_future_states']):,:] = 0 566 | s_next = np.reshape(s_next, (-1, self.train_configs['state_dim'])) 567 | return s, s_next 568 | 569 | def train(self): 570 | # perform one model update 571 | # sample replay buffer 572 | # batch_size*state_dim, batch_size*state_dim, 573 | # batch_size*action_dim, batch_size*1, batch_size*1 574 | # tensor 575 | s,s_next,a,r,d,b1 = self.replay_buffer.sample(self.train_configs['batch_size']) 576 | if self.train_configs['num_future_states'] > 0: 577 | s,s_next = self.__dropout_future_states(s,s_next) 578 | 579 | s = torch.from_numpy(s).float().to(self.device) 580 | s_next = torch.from_numpy(s_next).float().to(self.device) 581 | a = torch.from_numpy(a).to(self.device).view(-1,1) 582 | r = torch.from_numpy(r).float().to(self.device).view(-1,1) 583 | not_d = torch.from_numpy(1-d).float().to(self.device).view(-1,1) 584 | b1 = torch.from_numpy(b1).float().to(self.device).view(-1,1) 585 | 586 | # calculate Q(s_t,a_t) 587 | # pdb.set_trace() 588 | curr_Q1_values, curr_Q2_values = self.value_net(s, b1) 589 | curr_Q1 = curr_Q1_values.gather(1,a) 590 | curr_Q2 = curr_Q2_values.gather(1,a) 591 | 592 | next_Q1_values, next_Q2_values = self.value_net(s_next, b1) 593 | next_Q1_values_target, next_Q2_values_target = 
self.value_net_target(s_next, b1) 594 | 595 | # calculate Q(s_{t+1}, a') 596 | actions_1 = next_Q1_values.max(1)[1].long().view(-1,1) 597 | actions_2 = next_Q2_values.max(1)[1].long().view(-1,1) 598 | next_Q = torch.min(next_Q1_values_target.gather(1, actions_2), next_Q2_values_target.gather(1, actions_1)) 599 | # Compute the target of the current Q values 600 | expected_Q = r + (self.train_configs['gamma'] * not_d * next_Q) 601 | # Compute loss 602 | loss = F.mse_loss(curr_Q1, expected_Q.detach()) + F.mse_loss(curr_Q2, expected_Q.detach()) 603 | 604 | # Optimize 605 | self.value_net.zero_grad() 606 | loss.backward() 607 | if self.train_configs['grad_clamp'] == True: 608 | for param in self.value_net.parameters(): 609 | param.grad.data.clamp_(-1, 1) 610 | self.value_net_optim.step() 611 | 612 | # Delayed policy updates 613 | self.steps_done += 1 614 | if ((self.steps_done + 1) % self.train_configs['target_update'] == 0): 615 | if self.log_name is not None: 616 | log(self.log_name, f"[TwinDDQNAgent] Updating Target Network at steps_done = {self.steps_done}") 617 | for param, target_param in zip(self.value_net.parameters(), self.value_net_target.parameters()): 618 | target_param.data.copy_(self.train_configs['tau'] * param.data + (1 - self.train_configs['tau']) * target_param.data) 619 | 620 | return loss.item() 621 | 622 | def reduce_lrt(self, new_lrt : float): 623 | for param_group in self.value_net_optim.param_groups: 624 | param_group['lr'] = new_lrt 625 | 626 | def save(self, scores, file_path, info=None, env_configs=None, 627 | train_configs=None, episode_lengths=None, eval_scores=None, 628 | collision_pcts=None, avg_velocities=None, train_collisions=None, 629 | eval_timeout_pcts=None, train_timeouts=None, eval_km_per_collision=None): # optional error info to save 630 | torch.save({'value_net': self.value_net.state_dict(), 631 | 'scores': scores, 632 | 'info': info, 633 | 'env_configs': env_configs, 634 | 'train_configs': train_configs, 635 | 'eval_scores': eval_scores, 636 | 'collision_pcts': collision_pcts, 637 | 'avg_velocities': avg_velocities, 638 | 'episode_lengths': episode_lengths, 639 | 'train_collisions': train_collisions, 640 | 'eval_timeout_pcts': eval_timeout_pcts, 641 | 'train_timeouts': train_timeouts, 642 | 'eval_km_per_collision': eval_km_per_collision}, file_path) 643 | 644 | def load(self, file_path): 645 | checkpoint = torch.load(file_path) 646 | self.value_net.load_state_dict(checkpoint['value_net']) 647 | self.value_net_target.load_state_dict(checkpoint['value_net']) 648 | return checkpoint 649 | -------------------------------------------------------------------------------- /policy_network/neighborhood_v4_ddqn/modules.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/juho-lee/set_transformer 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | import pdb 8 | 9 | class EgoAttention(nn.Module): 10 | def __init__(self, dim_hidden, num_heads): 11 | super().__init__() 12 | self.dim_hidden = dim_hidden 13 | self.num_heads = num_heads 14 | 15 | self.fc_q = nn.Linear(dim_hidden, dim_hidden, bias=False) 16 | self.fc_k = nn.Linear(dim_hidden, dim_hidden, bias=False) 17 | self.fc_v = nn.Linear(dim_hidden, dim_hidden, bias=False) 18 | self.fc_o = nn.Linear(dim_hidden, dim_hidden) 19 | 20 | def forward(self, ego, all_agents, mask=None): 21 | Q = self.fc_q(ego) 22 | K, V = self.fc_k(all_agents), self.fc_v(all_agents) 23 | 24 | dim_split = self.dim_hidden // 
self.num_heads 25 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 26 | K_ = torch.cat(K.split(dim_split, 2), 0) 27 | V_ = torch.cat(V.split(dim_split, 2), 0) 28 | 29 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_hidden), 2) # attention matrix 30 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 31 | return O 32 | 33 | class MAB(nn.Module): 34 | def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): 35 | super(MAB, self).__init__() 36 | self.dim_V = dim_V 37 | self.num_heads = num_heads 38 | self.fc_q = nn.Linear(dim_Q, dim_V) 39 | self.fc_k = nn.Linear(dim_K, dim_V) 40 | self.fc_v = nn.Linear(dim_K, dim_V) 41 | if ln: 42 | self.ln0 = nn.LayerNorm(dim_V) 43 | self.ln1 = nn.LayerNorm(dim_V) 44 | self.fc_o = nn.Linear(dim_V, dim_V) 45 | 46 | def forward(self, Q, K): 47 | Q = self.fc_q(Q) 48 | K, V = self.fc_k(K), self.fc_v(K) 49 | 50 | dim_split = self.dim_V // self.num_heads 51 | Q_ = torch.cat(Q.split(dim_split, 2), 0) 52 | K_ = torch.cat(K.split(dim_split, 2), 0) 53 | V_ = torch.cat(V.split(dim_split, 2), 0) 54 | 55 | A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2) 56 | O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) 57 | O = O if getattr(self, 'ln0', None) is None else self.ln0(O) 58 | O = O + F.relu(self.fc_o(O)) 59 | O = O if getattr(self, 'ln1', None) is None else self.ln1(O) 60 | return O 61 | 62 | class SAB(nn.Module): 63 | def __init__(self, dim_in, dim_out, num_heads, ln=False): 64 | super(SAB, self).__init__() 65 | self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln) 66 | 67 | def forward(self, X): 68 | return self.mab(X, X) 69 | 70 | class ISAB(nn.Module): 71 | def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False): 72 | super(ISAB, self).__init__() 73 | self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out)) 74 | nn.init.xavier_uniform_(self.I) 75 | self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) 76 | self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) 77 | 78 | def forward(self, X): 79 | H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) 80 | return self.mab1(X, H) 81 | 82 | class PMA(nn.Module): 83 | def __init__(self, dim, num_heads, num_seeds, ln=False): 84 | super(PMA, self).__init__() 85 | self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) # 1 * k * d 86 | nn.init.xavier_uniform_(self.S) 87 | self.mab = MAB(dim, dim, dim, num_heads, ln=ln) 88 | 89 | def forward(self, X): 90 | return self.mab(self.S.repeat(X.size(0), 1, 1), X) 91 | -------------------------------------------------------------------------------- /policy_network/neighborhood_v4_ddqn/train_tr5.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pdb 3 | import glob 4 | import math 5 | import time 6 | import pickle 7 | from datetime import datetime 8 | import io, sys, os, copy 9 | import base64 10 | import argparse 11 | 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | 15 | import torch 16 | from torchsummary import summary 17 | import torch.optim as optim 18 | import torch.nn as nn 19 | from torch.autograd import Variable 20 | import torch.nn.functional as F 21 | from torch.utils import data 22 | 23 | from gym import logger as gymlogger 24 | from gym.wrappers import Monitor 25 | import gym 26 | import gym_road_interactions 27 | 28 | from neighborhood_v4_ddqn.models import * 29 | from neighborhood_v4_ddqn.utils import * 30 | from neighborhood_v4_ddqn.eval import eval 31 | 32 | # === Set up environment === 33 | 
logging.basicConfig(stream=sys.stdout, level=logging.INFO) 34 | logger = logging.getLogger(__name__) 35 | 36 | # === Loading function === 37 | def load_obj(name): 38 | with open(name + '.pkl', 'rb') as f: 39 | return pickle.load(f) 40 | 41 | # === Hyperparameters === 42 | # env configs 43 | ni_ep_len_dict = {1: 200, 2: 200, 3: 200} 44 | # experiment_date = '0414' 45 | ni = 1 46 | RL1_model_path = '' 47 | 48 | env_configs = {# parametric state 49 | 'use_global_frame': False, # whether to use global frame for the state vector 50 | 'normalize': True, # whether to normalize the state vector 51 | 'include_ttc': True, 52 | 'include_future_waypoints': 10, 53 | 'use_future_horizon_positions': False, # if false, just use the 10 future waypoints (each 10 waypoints apart) 54 | 'num_history_states': 10, # number of history states included in parametric state. 55 | 'num_future_states': 10, # number of future states included in parametric state. 56 | 'time_gap': 1, # time gap between two states. 57 | # total number of states included is: num_history_states + num_future_states + 1 58 | # t - time_gap * n_h, t - time_gap * (n_h-1), ..., t, t + time_gap, ..., t + time_gap * n_f 59 | 'stalemate_horizon': 5, # number of past ts (including current) that we track to determine stalemate 60 | 'include_polygon_dist': True, # if true, include sigmoid(polygon distance between ego and agent) 61 | 'include_agents_within_range': 10.0, # the range within which agents will be included. 62 | 'agent_state_dim': 6+4+10*2+1, # state dim at a single ts. related to 'include_ttc', 'include_polygon_dist', 'include_future_waypoints'' 63 | # num agents (change training env_config with this) 64 | 'num_other_agents': 0, # this will change 65 | 'max_num_other_agents': 25, # could only be 25, 40, 60 (one of the upper limits) 66 | 'max_num_other_agents_in_range': 25, # >=6. max number of other agents in range. Must be <= max_num_other_agents. default 25. 67 | # agent behavior 68 | 'agent_stochastic_stop': False, # (check) whether the other agent can choose to stop for ego with a certain probability 69 | 'agent_shuffle_ids': True, # (check) if yes, the id of other agents will be shuffled during reset 70 | 'rl_agent_configs' : [('',0.0)], 71 | 'all_selfplay': False, # if true, all agents will be changed to the most recent rl model, whatever the episode mix is 72 | # path 73 | 'ego_num_intersections_in_path': ni, # (check) 74 | 'ego_expand_path_depth': 2, # (check) the number of extending lanes from center intersection in ego path 75 | 'expanded_lane_set_depth': 1, # (check) 76 | 'single_intersection_type': 'mix', # (check) ['t-intersection', 'roundabout', 'mix'] 77 | 'c1': 2, # (check) 0: left/right turn at t4, 1: all possible depth-1 paths at t4, 2: all possible depth-1 paths at random intersection 78 | # training settings 79 | 'gamma': 0.99, 80 | 'max_episode_timesteps': ni_ep_len_dict[ni], 81 | # NOTE: if you change velocity_coeff, you have to change the whole dataset 82 | # (interaction set, collision set, etc) bc they are based on b1 and locations are 83 | # based on v_desired 84 | 'ego_velocity_coeff': (2.7, 8.3), # (check) (w, b). v_desired = w * b1 + b # experiments change this one 85 | 'agent_velocity_coeff': (2.7, 8.3), # (check) (w, b). v_desired = w * b1 + b # this is fixed! 86 | # reward = w * b1 + b # Be careful with signs! 
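# Illustrative arithmetic for the (w, b) coefficient pairs below: with
# time_penalty_coeff = (-1/20, -3/20), the per-step time penalty evaluates to
# w * b1 + b = -0.05 * 1.0 - 0.15 = -0.2 for b1 = 1.0, and to -0.15 for b1 = 0.0.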
87 | 'reward': 'default_fad', # [default, default_ttc, simplified, default_fad] 88 | 'time_penalty_coeff': (-1./20., -3./20.), 89 | 'velocity_reward_coeff': (0.5, 1.5), 90 | 'collision_penalty_coeff': (-5.0, -45.0), 91 | 'fad_penalty_coeff': 1.0, 92 | 'timeout_penalty_coeff': (-5.0, -20.0), 93 | 'stalemate_penalty_coeff': (-0.5, -1.5), 94 | # ego config 95 | 'use_default_ego': False, 96 | 'no_ego': False, # TODO: remember to turn this off if you want ego! 97 | # action noises 98 | 'ego_action_noise': 0.0, # lambda of the poisson noise applied to ego action 99 | 'agent_action_noise': 0.0, # lambda of the poisson noise applied to agent action 100 | # added on 0414 101 | 'ego_baseline': 5, # if None, run oracle. If >=0, choose the corresponding baseline setup (note: baseline 4 = oracle with ttc_break_tie=random) 102 | 'ttc_break_tie': 'id', 103 | 'agent_baseline': 5, 104 | 'stalemate_breaker': True} # 'agg_level=0' or 'b1' 105 | # the dimension of an agent across all ts 106 | env_configs['num_ts_in_state'] = env_configs['num_history_states'] + env_configs['num_future_states'] + 1 107 | env_configs['agent_total_state_dim'] = env_configs['agent_state_dim'] * env_configs['num_ts_in_state'] 108 | 109 | # training configs 110 | train_configs = {# model 111 | 'model': 'TwinDDQN', # [TwinDDQN, DDQN] # TODO check TwinDDQN Version!!!! 112 | 'gamma': env_configs['gamma'], 113 | 'target_update': 100, # number of policy updates until target update 114 | 'max_buffer_size': 200000, # max size for buffer that saves state-action-reward transitions 115 | 'batch_size': 128, 116 | 'lrt': 2e-5, 117 | 'tau': 0.2, 118 | 'exploration_timesteps': 0, 119 | 'value_net': 'set_transformer', # [vanilla, deep_set, set_transformer, social_attention] 120 | 'future_state_dropout': 0.7, # probability of dropping out future states during training 121 | 'use_lstm': True, # if true, will use LSTMValueNet along with the configured value_net 122 | # deep set 123 | 'pooling': 'mean', # ['mean', 'max'] 124 | # set transformer 125 | 'layer_norm' : True, # whether to use layer_norm in set transformer 126 | 'layers' : 2, 127 | # 'train_every_timesteps': 4, 128 | # epsilon decay for epsilon greedy 129 | 'eps_start': 1.0, 130 | 'eps_end': 0.01, 131 | 'eps_decay': 500, 132 | # training 133 | 'reward_threshold': 1000, # set impossible 134 | 'max_timesteps': 200000, 135 | 'train_every_episodes': 1, # if -1: train at every timestep. if n >= 1: train after every n timesteps 136 | 'save_every_timesteps': 5000, 137 | 'eval_every_episodes': 50, 138 | 'log_every_timesteps': 100, 139 | 'record_every_episodes': 30, 140 | 'seed': 0, 141 | 'grad_clamp': False, 142 | 'moving_avg_window': 100, 143 | 'replay_collision_episode': 0, # if 0, start every new episode fresh 144 | 'replay_episode': 0, # if 0, start every new episode fresh 145 | 'collision_episode_ratio': 0.25, # the ratio of saved collision episodes used in both training and testing. if set to 0, then don't use saved episodes 146 | 'interaction_episode_ratio': 0.5, # the ratio of saved interaction episodes used in both training and testing. 
if set to 0, then don't use interaction episodes 147 | 'buffer_agent_states': False, # if True, include agent-centric states in replay buffer 148 | # env 149 | 'agent_state_dim': env_configs['agent_state_dim'], 150 | 'agent_total_state_dim': env_configs['agent_total_state_dim'], # state_dim of each agent 151 | 'max_num_other_agents_in_range': env_configs['max_num_other_agents_in_range'], 152 | 'state_dim': env_configs['agent_total_state_dim'] * (env_configs['max_num_other_agents_in_range']+1), # total state_dim 153 | 'num_ts_in_state': env_configs['num_ts_in_state'], 154 | 'num_history_states': env_configs['num_history_states'], 155 | 'num_future_states': env_configs['num_future_states'], 156 | 'action_dim': 2, 157 | 'env_action_dim': 1, 158 | 'max_episode_timesteps': ni_ep_len_dict[ni]} 159 | 160 | # parse arguments 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument('-seed', type=int, default=0) 163 | parser.add_argument('-code', type=str, default='cXX') 164 | parser.add_argument('-date', type=str, default='0000') 165 | opt = parser.parse_args() 166 | train_configs['seed'] = opt.seed 167 | experiment_date = opt.date 168 | 169 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 170 | print(device) 171 | torch.manual_seed(train_configs['seed']) 172 | np.random.seed(train_configs['seed']) 173 | 174 | model_dir = './checkpoints' 175 | if not os.path.exists(model_dir): 176 | os.mkdir(model_dir) 177 | plot_dir = './learning_curves' 178 | if not os.path.exists(plot_dir): 179 | os.mkdir(plot_dir) 180 | log_dir = './logs' 181 | if not os.path.exists(log_dir): 182 | os.mkdir(log_dir) 183 | 184 | experiment_name = f'{experiment_date}_neighborhoodv4_ddqn_{opt.code}_seed={opt.seed}' 185 | RL2_model_path = f'./checkpoints/{experiment_name}_most-recent.pt' 186 | log_name = f'{log_dir}/{experiment_name}.log' 187 | if os.path.exists(log_name): 188 | os.remove(log_name) 189 | 190 | log(log_name, str(env_configs)) 191 | log(log_name, str(train_configs)) 192 | 193 | # === Training Prep === 194 | env = gym.make('Neighborhood-v4') 195 | env.set_env_config(env_configs) 196 | env.set_train_config_and_device(train_configs, device) 197 | env.log_name_ = log_name 198 | agent = TwinDDQNAgent(train_configs, device, log_name) 199 | 200 | ## for plot 201 | # train 202 | scores = [] 203 | moving_averages = [] 204 | avg_velocities = [] 205 | avg_velocities_moving_averages = [] 206 | episode_lengths = [] 207 | episode_lengths_moving_averages = [] 208 | train_collisions = [] 209 | train_collision_pcts_moving_averages = [] 210 | train_timeouts = [] 211 | train_timeout_pcts_moving_averages = [] 212 | # eval 213 | eval_scores = [] 214 | collision_pcts = [] 215 | eval_timeout_pcts = [] 216 | eval_avg_velocities = [] 217 | eval_km_per_collision = [] 218 | running_score = -float('Inf') 219 | total_timesteps = 0 220 | prev_save_at_timestep = 0 221 | prev_eval_at_timestep = 0 222 | episode_timesteps = 0 223 | episode_collided = False 224 | episode_replay_counter = 0 225 | reduced_lrt = False # only reduce lrt once 226 | prev_highest_running_score = -float('Inf') 227 | # format for eval pfmc: [eval_score, collision_pct] 228 | prev_best_eval_pfmc_1 = [] 229 | prev_best_eval_pfmc_2 = [] 230 | prev_best_eval_pfmc_3 = [] 231 | prev_best_eval_pfmc_4 = [] 232 | episode = 1 233 | 234 | # load all collision episode agent inits 235 | # TODO: make sure the env_config is the same as our env_config! 
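# The relative paths below are assumed to point at the pickled sets produced by
# road_interactions_environment/neighborhood_v4_collision_set_gen.py and the
# interaction-set notebook (see README); load_obj() appends the '.pkl' suffix.
# The ego velocity coefficients (w, b) are embedded in the filenames, and the
# consistency loop below additionally checks key fields of each saved env_config.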
236 | w,b = env_configs['ego_velocity_coeff'] 237 | collision_set_path_list = [f'collision_sets/0321_collision-set_train_ego-vel-coeff={w},{b}_ni=1_na=5', 238 | f'collision_sets/0321_collision-set_train_ego-vel-coeff={w},{b}_ni=1_na=10', 239 | f'collision_sets/0321_collision-set_train_ego-vel-coeff={w},{b}_ni=1_na=15', 240 | f'collision_sets/0321_collision-set_train_ego-vel-coeff={w},{b}_ni=1_na=20', 241 | f'collision_sets/0321_collision-set_train_ego-vel-coeff={w},{b}_ni=1_na=25'] 242 | interaction_set_path = f'train_sets/0601_tr5_training_set_interaction_ego-vel-coeff={w},{b}' 243 | eval_set_path = f'eval_sets/0601_tr5_eval_set_ego-vel-coeff={w},{b}' 244 | collision_ep_agents_init = [] 245 | for collision_set_path in collision_set_path_list: 246 | curr_collision_set = load_obj(collision_set_path) 247 | if isinstance(curr_collision_set, dict): 248 | saved_env_config = curr_collision_set['env_config'] 249 | curr_collision_set = curr_collision_set['agents_init'] 250 | log(log_name, 'saved env_config: ' + str(saved_env_config)) 251 | for key in ['agent_stochastic_stop', 252 | 'agent_shuffle_ids', 'ego_num_intersections_in_path', 253 | 'expanded_lane_set_depth', 'c1', 'ego_expand_path_depth', 254 | 'single_intersection_type','ego_velocity_coeff','agent_velocity_coeff']: 255 | if ((key not in env_configs) and (key in saved_env_config)) or \ 256 | ((key in env_configs) and (key not in saved_env_config)) or \ 257 | (env_configs[key] != saved_env_config[key]): 258 | raise Exception(f'env_config key mismatch: {key}') 259 | collision_ep_agents_init.append(curr_collision_set) 260 | 261 | interaction_set = load_obj(interaction_set_path) # it's a list 262 | eval_set = load_obj(eval_set_path) # it's a list 263 | 264 | # CHANGE ALL THOSE IF YOU WANT TO LOAD CHECKPOINT 265 | # Example: 266 | # load_model = True 267 | # checkpoint_path = './ppo_checkpoint_episode=3000_ts=10000.pt' 268 | # total_timesteps = 10000 269 | # episode = 3000 270 | load_model = False 271 | checkpoint_path = './checkpoints/XX.pt' 272 | 273 | if load_model: 274 | total_timesteps = 0 275 | episode = 0 276 | prev_save_at_timestep = total_timesteps 277 | prev_eval_at_timestep = total_timesteps 278 | 279 | has_problems = False 280 | log(log_name, f"Loading model from {checkpoint_path}") 281 | print(f"Loading model from {checkpoint_path}") 282 | # we will treat everything as new - just the value_net itself is from previous checkpoint 283 | checkpoint = agent.load(checkpoint_path) 284 | checkpoint_env_configs = checkpoint['env_configs'] 285 | 286 | for key in env_configs.keys(): 287 | if key in ['rl_agent_configs', 'num_other_agents']: 288 | continue 289 | if key not in checkpoint_env_configs.keys(): 290 | print(f"key {key} not in checkpoint_env_configs. current: {key}:{env_configs[key]}") 291 | has_problems = True 292 | continue 293 | if (checkpoint_env_configs[key] != env_configs[key]): 294 | print(f"mismatch in env_config: {key} : saved: {checkpoint_env_configs[key]}, current: {env_configs[key]}") 295 | has_problems = True 296 | continue 297 | 298 | checkpoint_train_configs = checkpoint['train_configs'] 299 | for key in train_configs.keys(): 300 | if key in ['state_dim','seed']: # TODO 301 | continue 302 | if key not in checkpoint_train_configs.keys(): 303 | print(f"key {key} not in checkpoint_train_configs. 
current: {key}:{train_configs[key]}") 304 | has_problems = True 305 | continue 306 | if (checkpoint_train_configs[key] != train_configs[key]): 307 | print(f"mismatch in train_configs: {key} : saved: {checkpoint_train_configs[key]}, current: {train_configs[key]}") 308 | has_problems = True 309 | continue 310 | 311 | if has_problems: 312 | raise Exception("Problem in loading model") 313 | 314 | scores = checkpoint['scores'] 315 | eval_scores = checkpoint['eval_scores'] 316 | collision_pcts = checkpoint['collision_pcts'] 317 | avg_velocities = checkpoint['avg_velocities'] 318 | episode_lengths = checkpoint['episode_lengths'] 319 | train_collisions = checkpoint['train_collisions'] 320 | eval_timeout_pcts = checkpoint['eval_timeout_pcts'] 321 | train_timeouts = checkpoint['train_timeouts'] 322 | eval_km_per_collision = checkpoint['eval_km_per_collision'] 323 | 324 | running_score = moving_average(scores, train_configs['moving_avg_window']) 325 | 326 | agent.save(checkpoint, RL2_model_path) 327 | 328 | else: 329 | # save init of RL2_model 330 | agent.save([], RL2_model_path) 331 | 332 | # timing 333 | time_start = time.time() 334 | 335 | def plot_curves(): 336 | # plot 337 | plt.plot(np.arange(len(moving_averages)) + train_configs['moving_avg_window'], moving_averages, color='b', label='mov avg scores') 338 | plt.plot(np.arange(len(train_collision_pcts_moving_averages)) + train_configs['moving_avg_window'], train_collision_pcts_moving_averages, color='r', label='mov avg collision pct') 339 | plt.plot(np.arange(len(train_timeout_pcts_moving_averages)) + train_configs['moving_avg_window'], train_timeout_pcts_moving_averages, color='m', label='mov avg timeout pct') 340 | plt.title(experiment_name + ' ts=' + str(total_timesteps) + ' training curve') 341 | plt.xlabel('Episodes') 342 | plt.ylabel('Reward') 343 | plt.legend() 344 | plt.savefig( 345 | f"{plot_dir}/{experiment_name}_train.png", 346 | dpi=400, 347 | ) 348 | plt.close() 349 | 350 | if train_configs['eval_every_episodes'] > 0: 351 | plt.plot(train_configs['moving_avg_window'] + np.arange(1, len(eval_scores)+1) * train_configs['eval_every_episodes'], eval_scores, color='b', label='eval score') 352 | plt.plot(train_configs['moving_avg_window'] + np.arange(1, len(collision_pcts)+1) * train_configs['eval_every_episodes'], collision_pcts, color='r', label='collision rate') 353 | plt.plot(train_configs['moving_avg_window'] + np.arange(1, len(collision_pcts)+1) * train_configs['eval_every_episodes'], eval_timeout_pcts, color='m', label='timeout rate') 354 | plt.plot(train_configs['moving_avg_window'] + np.arange(1, len(collision_pcts)+1) * train_configs['eval_every_episodes'], eval_km_per_collision, color='g', label='km per collision') 355 | 356 | plt.title(experiment_name + ' ts=' + str(total_timesteps) + 'eval curve') 357 | plt.xlabel('Episodes') 358 | plt.ylabel('Eval Scores / Collision Rate (%) / Km per Collision') 359 | plt.legend() 360 | plt.savefig( 361 | f"{plot_dir}/{experiment_name}_eval.png", 362 | dpi=400, 363 | ) 364 | plt.close() 365 | 366 | plt.plot(np.arange(len(moving_averages)) + train_configs['moving_avg_window'], avg_velocities_moving_averages, color='b', label='train mov avg') 367 | plt.plot(train_configs['moving_avg_window'] + np.arange(1, len(eval_avg_velocities)+1) * train_configs['eval_every_episodes'], eval_avg_velocities, color='g', label='eval') 368 | plt.title(experiment_name + ' ts=' + str(total_timesteps) + 'Avg ego vel') 369 | plt.xlabel('Episodes') 370 | plt.ylabel('Avg ego velocity') 371 | plt.legend() 372 | 
plt.savefig( 373 | f"{plot_dir}/{experiment_name}_avg_velocities.png", 374 | dpi=400, 375 | ) 376 | plt.close() 377 | 378 | plt.plot(np.arange(len(moving_averages)) + train_configs['moving_avg_window'], episode_lengths_moving_averages, color='g') 379 | plt.title(experiment_name + ' ts=' + str(total_timesteps) + 'ep lengths (mov avg)') 380 | plt.xlabel('Episodes') 381 | plt.ylabel('Episode lengths') 382 | plt.legend() 383 | plt.savefig( 384 | f"{plot_dir}/{experiment_name}_ep_lengths.png", 385 | dpi=400, 386 | ) 387 | plt.close() 388 | 389 | def change_to_self_play(loaded_agents): 390 | updated_loaded_agents = copy.deepcopy(loaded_agents) 391 | for agent_id, agent in updated_loaded_agents.items(): 392 | if int(agent_id) == 0: 393 | continue 394 | # this essentially changes this agent to an rl agent 395 | agent.rl_model_path_ = RL2_model_path 396 | agent.train_config_ = train_configs 397 | agent.device_ = device 398 | return updated_loaded_agents 399 | 400 | def append_action_and_reward(action, reward, hist_actions, hist_rewards): 401 | num_a_r_saved = env_configs['num_future_states'] * env_configs['time_gap']+1 402 | if len(hist_actions) == num_a_r_saved: 403 | hist_actions = hist_actions[1:] 404 | hist_rewards = hist_rewards[1:] 405 | hist_actions.append(action) 406 | hist_rewards.append(reward) 407 | return hist_actions, hist_rewards 408 | 409 | 410 | print("=== Training ===") 411 | 412 | while total_timesteps < train_configs['max_timesteps']: 413 | log(log_name, f"[Episode {episode}] Starting...") 414 | # check replay config 415 | if train_configs['replay_collision_episode'] > 0 and train_configs['replay_episode'] > 0: 416 | raise Exception('Only one of replay_collision_episode and replay_episode can be > 0!') 417 | # decide whether to use a saved collision episode or to use a randomly generated episode 418 | ep_sample = np.random.random() 419 | if (ep_sample < train_configs['collision_episode_ratio']): # use collision_ep_agents_init 420 | # pick a set 421 | random_set_idx = np.random.randint(len(collision_ep_agents_init)) 422 | curr_set = collision_ep_agents_init[random_set_idx] 423 | # pick an episode from that set 424 | random_ep_idx = np.random.randint(len(curr_set)) 425 | env.env_config_['num_other_agents'] = len(curr_set[random_ep_idx])-1 426 | log(log_name, f"[Episode {episode}] Using saved collision episode {random_ep_idx} from {collision_set_path_list[random_set_idx]}") 427 | 428 | loaded_agents = curr_set[random_ep_idx] 429 | if 'all_selfplay' in env_configs.keys() and env_configs['all_selfplay'] == True: 430 | updated_loaded_agents = change_to_self_play(loaded_agents) 431 | parametric_state, ego_b1 = env.reset(use_saved_agents=updated_loaded_agents) 432 | else: 433 | parametric_state, ego_b1 = env.reset(use_saved_agents=loaded_agents) 434 | 435 | elif (ep_sample < train_configs['collision_episode_ratio'] + train_configs['interaction_episode_ratio']): # use interaction_set 436 | # pick an episode from the set 437 | random_ep_idx = np.random.randint(len(interaction_set)) 438 | # interaction set is saved as a list of dicts, each with its own env_config 439 | saved_env_config = interaction_set[random_ep_idx]['env_config'] 440 | loaded_agents = interaction_set[random_ep_idx]['agents_init'] 441 | 442 | # the saved env config for each episode should have these fields coincide with the env config during eval 443 | for key in ['agent_stochastic_stop', 444 | 'agent_shuffle_ids', 'ego_num_intersections_in_path', 445 | 'expanded_lane_set_depth', 'c1', 'ego_expand_path_depth', 446 | 
'single_intersection_type','ego_velocity_coeff','agent_velocity_coeff']: 447 | if ((key not in env.env_config_) and (key in saved_env_config)) or \ 448 | ((key in env.env_config_) and (key not in saved_env_config)) or \ 449 | (env.env_config_[key] != saved_env_config[key]): 450 | raise Exception(f'env_config key mismatch {key}. saved: {saved_env_config[key]}. env: {env.env_config_[key]}') 451 | 452 | # only change num_other_agents! 453 | env.env_config_['num_other_agents'] = saved_env_config['num_other_agents'] 454 | 455 | if 'all_selfplay' in env_configs.keys() and env_configs['all_selfplay'] == True: 456 | updated_loaded_agents = change_to_self_play(loaded_agents) 457 | parametric_state, ego_b1 = env.reset(use_saved_agents=updated_loaded_agents) 458 | else: 459 | parametric_state, ego_b1 = env.reset(use_saved_agents=loaded_agents) 460 | 461 | log(log_name, f"[Episode {episode}] Using saved interaction episode {random_ep_idx}") 462 | else: # randomly generate or consult replay configs 463 | # reset depending on replay config 464 | if train_configs['replay_collision_episode'] > 0: 465 | if ((episode_replay_counter == 0) and (episode_collided == True)) or \ 466 | ((episode_replay_counter > 0) and (episode_replay_counter < train_configs['replay_collision_episode'])): 467 | episode_replay_counter += 1 468 | parametric_state, ego_b1 = env.reset(use_prev_episode=True) 469 | log(log_name, f"[Episode {episode}] Using previous episode for the {episode_replay_counter}-th time") 470 | log(log_name, str(parametric_state)) 471 | else: 472 | episode_replay_counter = 0 473 | parametric_state, ego_b1 = env.reset() 474 | elif train_configs['replay_episode'] > 0: 475 | if (episode > 1 and episode_replay_counter < train_configs['replay_episode']): 476 | episode_replay_counter += 1 477 | parametric_state, ego_b1 = env.reset(use_prev_episode=True) 478 | log(log_name, f"[Episode {episode}] Using previous episode for the {episode_replay_counter}-th time") 479 | log(log_name, str(parametric_state)) 480 | else: 481 | episode_replay_counter = 0 482 | parametric_state, ego_b1 = env.reset() 483 | else: 484 | # start fresh with random na 485 | num_other_agents = np.random.randint(env_configs['max_num_other_agents']+1) 486 | env_configs['num_other_agents'] = num_other_agents 487 | log(log_name, f"[Episode {episode}] Start fresh. 
num_other_agents={num_other_agents}") 488 | 489 | if 'all_selfplay' in env_configs.keys() and env_configs['all_selfplay'] == True: 490 | if os.path.exists(RL2_model_path): 491 | rl_agent_configs = [(RL2_model_path,1.0)] 492 | env_configs['rl_agent_configs'] = rl_agent_configs 493 | log(log_name, "Updated env_config: " + str(env_configs)) 494 | else: 495 | raise Exception(f"{RL2_model_path} doesn't exist") 496 | 497 | env.set_env_config(env_configs) 498 | parametric_state, ego_b1 = env.reset() 499 | 500 | hist_parametric_state = parametric_state 501 | hist_actions = [] 502 | hist_rewards = [] 503 | hist_agent_states,_,_,_,_ = env.generate_agent_states() 504 | 505 | episode_timesteps = 0 506 | score = 0 507 | episode_collided = False 508 | done = False 509 | 510 | while (done == False) and (episode_timesteps < train_configs['max_episode_timesteps']): 511 | total_timesteps += 1 512 | episode_timesteps += 1 513 | 514 | # select action 515 | if env_configs['num_future_states'] > 0: 516 | parametric_state_till_now = truncate_state_till_now(parametric_state, env_configs) 517 | else: 518 | parametric_state_till_now = parametric_state 519 | parametric_state_ts = torch.from_numpy(parametric_state_till_now).unsqueeze(0).float().to(device) 520 | ## for printing 521 | action = int(agent.select_action(parametric_state_ts, ego_b1, total_timesteps)) 522 | # take action in env: 523 | next_state, reward, done, info = env.step(action) # step processed action 524 | action_str = "selected action: " + str(action) 525 | log(log_name, f'[Episode {episode} ts={episode_timesteps}] reward={reward:.1f} | {action_str}') 526 | 527 | # process next state 528 | parametric_state = next_state 529 | 530 | # save 531 | hist_actions, hist_rewards = append_action_and_reward(action, reward, hist_actions, hist_rewards) 532 | if env_configs['num_future_states'] == 0: 533 | agent.replay_buffer.add(hist_parametric_state, parametric_state, 534 | hist_actions[0], hist_rewards[0], float(done), info[4]) # info[4] is ego_b1 535 | else: 536 | if len(hist_actions) == (env_configs['num_future_states'] * env_configs['time_gap'] + 1): 537 | agent.replay_buffer.add(hist_parametric_state, parametric_state, 538 | hist_actions[0], hist_rewards[0], 0.0, info[4]) # info[4] is ego_b1 539 | hist_parametric_state = parametric_state 540 | 541 | # if buffer_agent_states 542 | if train_configs['buffer_agent_states']: 543 | agent_states, agent_rewards, agent_actions, agent_dones, agent_b1s = env.generate_agent_states() 544 | for i in range(len(agent_states)): 545 | agent.replay_buffer.add(hist_agent_states[i], agent_states[i], agent_actions[i], agent_rewards[i], float(agent_dones[i]), agent_b1s[i]) 546 | hist_agent_states = agent_states 547 | 548 | # update agent if train_every_episodes == -1 (train at every ts) 549 | if train_configs['train_every_episodes'] == -1: 550 | if (total_timesteps >= train_configs['exploration_timesteps']) and \ 551 | (agent.replay_buffer.size() > train_configs['batch_size']): 552 | loss = agent.train() 553 | agent.save(scores, RL2_model_path, train_configs=train_configs, env_configs=env_configs, \ 554 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 555 | episode_lengths=episode_lengths, avg_velocities=avg_velocities,\ 556 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 557 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 558 | if ((total_timesteps + 1) % train_configs['log_every_timesteps'] == 0): 559 | log(log_name, f'[Total timesteps {total_timesteps}] At 
every timestamp, Update and save agent to {RL2_model_path}') 560 | 561 | score += reward 562 | 563 | if (done == True): 564 | episode_collided = info[1] 565 | 566 | # when an episode ends, repeat the last state t_future times until 1 entry with done=1 is saved 567 | saved_done = 0.0 568 | num_repeat_steps = env_configs['num_future_states'] * env_configs['time_gap'] 569 | for t in range(1, num_repeat_steps+1): 570 | # take action in env: 571 | next_state, reward, done, info = env.step(0) 572 | log(log_name, f'[Episode {episode} ts={episode_timesteps}] repeat {t}/{num_repeat_steps} after done. reward={reward:.1f}') 573 | 574 | # process next state 575 | parametric_state = next_state 576 | # save 577 | hist_actions, hist_rewards = append_action_and_reward(0, reward, hist_actions, hist_rewards) 578 | if t == num_repeat_steps: 579 | saved_done = 1.0 580 | agent.replay_buffer.add(hist_parametric_state, parametric_state, 581 | hist_actions[0], hist_rewards[0], saved_done, info[4]) # info[4] is ego_b1 582 | hist_parametric_state = parametric_state 583 | 584 | episode += 1 585 | scores.append(score) 586 | avg_velocities.append(info[0]) 587 | if (info[1] == True): 588 | train_collisions.append(100) 589 | else: 590 | train_collisions.append(0) 591 | if (done == False): 592 | train_timeouts.append(100) 593 | else: 594 | train_timeouts.append(0) 595 | episode_lengths.append(episode_timesteps) 596 | running_score = moving_average(scores, train_configs['moving_avg_window']) 597 | 598 | # update agent if train_every_episodes > 0 (train after complete episodes) 599 | if train_configs['train_every_episodes'] > 0 and (episode % train_configs['train_every_episodes'] == 0): 600 | if (total_timesteps >= train_configs['exploration_timesteps']) and \ 601 | (agent.replay_buffer.size() > train_configs['batch_size']): 602 | num_train_iterations = sum(episode_lengths[(-train_configs['train_every_episodes']):]) 603 | for i in range(1, num_train_iterations+1): 604 | loss = agent.train() 605 | dt = (int)(time.time() - time_start) 606 | if i % 10 == 0: 607 | log(log_name, f'[Episode {episode} | Total timesteps = {total_timesteps}] Updating agent {i} / {num_train_iterations} | Time: {dt//3600:02}:{dt%3600//60:02}:{dt%60:02}') 608 | log(log_name, f'[Episode {episode} | Total timesteps = {total_timesteps}] Saving most recent model to {RL2_model_path} ...') 609 | agent.save(scores, RL2_model_path, train_configs=train_configs, env_configs=env_configs, \ 610 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 611 | episode_lengths=episode_lengths, avg_velocities=avg_velocities,\ 612 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 613 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 614 | 615 | 616 | # append running score and save highest running score model 617 | if (total_timesteps >= train_configs['exploration_timesteps']) and (len(scores) > train_configs['moving_avg_window']): 618 | # trackers of moving averages 619 | moving_averages.append(running_score) 620 | avg_velocities_moving_averages.append(moving_average(avg_velocities, train_configs['moving_avg_window'])) 621 | episode_lengths_moving_averages.append(moving_average(episode_lengths, train_configs['moving_avg_window'])) 622 | train_collision_pcts_moving_averages.append(moving_average(train_collisions, train_configs['moving_avg_window'])) 623 | train_timeout_pcts_moving_averages.append(moving_average(train_timeouts, train_configs['moving_avg_window'])) 624 | 625 | # Adjust learning rate 626 | if 
(len(train_collision_pcts_moving_averages) > 50): 627 | max_mov_avg_collision_pct = max(train_collision_pcts_moving_averages[-50:]) 628 | if (max_mov_avg_collision_pct < 10) and (reduced_lrt == False): 629 | target_lrt = 1e-5 # keep min at 1e-5 630 | log(log_name, f'[Total timesteps {total_timesteps}] Reducing lrt with max last 50 mov avg collision: {max_mov_avg_collision_pct}. Reducing to {target_lrt}') 631 | agent.reduce_lrt(target_lrt) 632 | reduced_lrt = True 633 | 634 | # Eval model at every several ts 635 | if train_configs['eval_every_episodes'] > 0: 636 | if (episode % train_configs['eval_every_episodes'] == 0): 637 | log(log_name, f'[Total timesteps {total_timesteps}] Evaluating model ...') 638 | agent.value_net.eval() 639 | # == Eval == 640 | # save current random state and set np random state to a new one 641 | log(log_name, 'generating eval seed...') 642 | eval_seed = np.random.randint(65536) 643 | log(log_name, str(eval_seed)) 644 | while (eval_seed == train_configs['seed']): 645 | eval_seed = np.random.randint(65536) 646 | log(log_name, str(eval_seed)) 647 | log(log_name, f'eval seed: {eval_seed}') 648 | train_random_state = np.random.get_state() 649 | np.random.seed(eval_seed) 650 | # create eval env 651 | eval_env = gym.make('Neighborhood-v4') 652 | copy_eval_env_configs = copy.deepcopy(env_configs) # because eval will change env_configs, we pass in a copy 653 | eval_env.set_env_config(copy_eval_env_configs) 654 | eval_env.log_name_ = log_name 655 | eval_env.max_v_value_ = max_v_value 656 | eval_env.max_rel_v_value_ = max_rel_v_value 657 | # call eval 658 | eval_score, collision_pct, success_pct, eval_avg_velocity, km_per_collision = eval(agent, eval_env, train_configs, env_configs, [], device, log_name, collision_ep_agents_init=eval_set, during_training=True) 659 | timeout_pct = 100 - success_pct - collision_pct 660 | log(log_name, f'[Eval] avg score = {eval_score:.2f} | collision = {collision_pct}% | success = {success_pct}% | timeout = {timeout_pct}% | km_per_collision = {km_per_collision} | avg_velocity = {eval_avg_velocity}') 661 | # set np random state back to the train seed 662 | np.random.set_state(train_random_state) 663 | agent.value_net.train() 664 | 665 | # == Eval trackers == 666 | eval_scores.append(eval_score) 667 | collision_pcts.append(collision_pct) 668 | eval_timeout_pcts.append(timeout_pct) 669 | eval_avg_velocities.append(eval_avg_velocity) 670 | eval_km_per_collision.append(km_per_collision) 671 | 672 | # == Eval saving == 673 | # eval_pfmc_1 674 | if (len(prev_best_eval_pfmc_1) == 0) or \ 675 | ((eval_score > prev_best_eval_pfmc_1[0]) and (collision_pct <= prev_best_eval_pfmc_1[1])): 676 | prev_best_eval_pfmc_1 = [eval_score, collision_pct] 677 | log(log_name, f"Best eval pfmc 1: [{eval_score:.2f}, {collision_pct:.1f}%] at {timestring}") 678 | filename = f'{model_dir}/{experiment_name}_best-eval-pfmc-1.pt' 679 | agent.save(scores, filename, train_configs=train_configs, env_configs=env_configs, \ 680 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 681 | episode_lengths=episode_lengths, avg_velocities=avg_velocities,\ 682 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 683 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 684 | # eval_pfmc_2 685 | if (len(prev_best_eval_pfmc_2) == 0) or \ 686 | (eval_score - collision_pct > prev_best_eval_pfmc_2[0] - prev_best_eval_pfmc_2[1]): 687 | prev_best_eval_pfmc_2 = [eval_score, collision_pct] 688 | log(log_name, f"Best eval pfmc 2: 
[{eval_score:.2f}, {collision_pct:.1f}%] with difference {eval_score - collision_pct} at {timestring}") 689 | filename = f'{model_dir}/{experiment_name}_best-eval-pfmc-2.pt' 690 | agent.save(scores, filename, train_configs=train_configs, env_configs=env_configs, \ 691 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 692 | episode_lengths=episode_lengths, avg_velocities=avg_velocities,\ 693 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 694 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 695 | # eval_pfmc_3 696 | if (len(prev_best_eval_pfmc_3) == 0) or \ 697 | (eval_score > prev_best_eval_pfmc_3[0]): 698 | prev_best_eval_pfmc_3 = [eval_score, collision_pct] 699 | log(log_name, f"Best eval pfmc 3: [{eval_score:.2f}, {collision_pct:.1f}%] at {timestring}") 700 | filename = f'{model_dir}/{experiment_name}_best-eval-pfmc-3.pt' 701 | agent.save(scores, filename, train_configs=train_configs, env_configs=env_configs, \ 702 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 703 | episode_lengths=episode_lengths, avg_velocities=avg_velocities,\ 704 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 705 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 706 | 707 | # eval_pfmc_4 708 | if (len(prev_best_eval_pfmc_4) == 0) or \ 709 | (success_pct > prev_best_eval_pfmc_4[1]): 710 | prev_best_eval_pfmc_4 = [eval_score, success_pct] 711 | log(log_name, f"Best eval pfmc 4: [{eval_score:.2f}, success_pct={success_pct}] at {timestring}") 712 | filename = f'{model_dir}/{experiment_name}_best-eval-pfmc-4.pt' 713 | agent.save(scores, filename, train_configs=train_configs, env_configs=env_configs, \ 714 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 715 | episode_lengths=episode_lengths, avg_velocities=avg_velocities,\ 716 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 717 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 718 | 719 | # save learning curves 720 | if (episode % train_configs['record_every_episodes'] == 0): 721 | plot_curves() 722 | 723 | # if avg reward > 1000 then save and stop training: 724 | if (running_score > prev_highest_running_score): 725 | # logger.info("##### Success! #####") 726 | timestring = (datetime.now()).strftime("%m%d-%H:%M") 727 | log(log_name, f"Highest training running score {running_score} at {timestring}") 728 | filename = f'{model_dir}/{experiment_name}_highest-running-score.pt' 729 | agent.save(scores, filename, train_configs=train_configs, env_configs=env_configs, \ 730 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 731 | episode_lengths=episode_lengths, avg_velocities=avg_velocities,\ 732 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 733 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 734 | prev_highest_running_score = running_score 735 | 736 | if running_score >= train_configs['reward_threshold']: 737 | log(log_name, f"##### Success! ##### running_score={running_score}") 738 | break 739 | 740 | dt = (int)(time.time() - time_start) 741 | log(log_name, "[Episode {}] [Total timesteps {}] length: {} | score: {:.1f} | Avg. 
last {} scores: {:.3f} | Time: {:02}:{:02}:{:02}\n"\ 742 | .format(episode, total_timesteps, episode_timesteps, score, train_configs['moving_avg_window'], running_score, dt//3600, dt%3600//60, dt%60)) 743 | 744 | # Save episode 745 | if total_timesteps > train_configs['exploration_timesteps'] and \ 746 | (total_timesteps - prev_save_at_timestep >= train_configs['save_every_timesteps']): 747 | log(log_name, f'[Total timesteps {total_timesteps}] Saving model ...') 748 | timestring = (datetime.now()).strftime("%m%d-%H:%M") 749 | filename = f'{model_dir}/{experiment_name}_{timestring}_ts={total_timesteps}_running-score={running_score}.pt' 750 | agent.save(scores, filename, train_configs=train_configs, env_configs=env_configs, \ 751 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 752 | episode_lengths=episode_lengths, avg_velocities=avg_velocities, \ 753 | train_collisions=train_collisions, eval_timeout_pcts=eval_timeout_pcts, \ 754 | train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 755 | prev_save_at_timestep += train_configs['save_every_timesteps'] 756 | 757 | # Finishing training 758 | log(log_name, f'[Episode {episode}][Total timesteps {total_timesteps}] Done Training. Saving final model') 759 | filename = f'{model_dir}/{experiment_name}_ep={episode}_final_ts={total_timesteps}.pt' 760 | agent.save(scores, filename, train_configs=train_configs, env_configs=env_configs, \ 761 | eval_scores=eval_scores, collision_pcts=collision_pcts, \ 762 | episode_lengths=episode_lengths, avg_velocities=avg_velocities, train_collisions=train_collisions, \ 763 | eval_timeout_pcts=eval_timeout_pcts, train_timeouts=train_timeouts, eval_km_per_collision=eval_km_per_collision) 764 | 765 | # plot 766 | plot_curves() 767 | -------------------------------------------------------------------------------- /policy_network/neighborhood_v4_ddqn/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pdb 3 | import glob 4 | import math 5 | import time 6 | from datetime import datetime 7 | import io, sys, os, copy 8 | import base64 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | 13 | import torch 14 | from torchsummary import summary 15 | import torch.optim as optim 16 | import torch.nn as nn 17 | from torch.autograd import Variable 18 | import torch.nn.functional as F 19 | from torch.utils import data 20 | 21 | from gym import logger as gymlogger 22 | from gym.wrappers import Monitor 23 | import gym 24 | import gym_road_interactions 25 | 26 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 27 | logger = logging.getLogger(__name__) 28 | 29 | # Simple replay buffer that saves a list of tuples of (state, next_state, action, reward, done) (tensors) 30 | class ReplayBuffer(object): 31 | def __init__(self, max_size=1e6): 32 | self.storage = [] 33 | self.max_size = max_size 34 | self.ptr = 0 35 | self.cnt = 0 36 | 37 | def add(self, state, new_state, action, reward, done_bool, b1): 38 | # state: (state_dim,) 39 | # new_state: (state_dim,) 40 | # action: (action_dim,) 41 | # reward: scalar 42 | # done_bool: scalar 43 | # b1: scalar 44 | # numpy arrays 45 | 46 | data = (state, new_state, action, reward, done_bool, b1) 47 | if len(self.storage) == self.max_size: 48 | self.storage[int(self.ptr)] = data 49 | self.ptr = (self.ptr + 1) % self.max_size 50 | else: 51 | self.storage.append(data) 52 | self.cnt += 1 53 | 54 | def sample(self, batch_size): 55 | # output: 56 | # batch_size*state_dim, 
batch_size*state_dim, 57 | # batch_size*action_dim, batch_size*1, batch_size*1 58 | # np array 59 | ind = np.random.randint(0, len(self.storage), size=batch_size) 60 | s,ns,u,r,d,b1 = [],[],[],[],[],[] 61 | for i in ind: 62 | S, NS, U, R, D, B1 = self.storage[i] 63 | s.append(S) 64 | ns.append(NS) 65 | u.append(U) 66 | r.append(R) 67 | d.append(D) 68 | b1.append(B1) 69 | return np.array(s), np.array(ns), np.array(u), np.array(r).reshape(-1, 1), np.array(d).reshape(-1, 1), np.array(b1).reshape(-1, 1) 70 | 71 | def size(self): 72 | return self.cnt 73 | 74 | # helper for calculating moving average of last 10 scores 75 | def moving_average(scores, window): 76 | if len(scores) < window: 77 | return sum(scores) / len(scores) 78 | else: 79 | return sum(scores[(-window):]) / float(window) 80 | 81 | # Logging function 82 | def log(fname, s): 83 | # if not os.path.isdir(os.path.dirname(fname)): 84 | # os.system(f'mkdir -p {os.path.dirname(fname)}') 85 | f = open(fname, 'a') 86 | f.write(f'{str(datetime.now())}: {s}\n') 87 | f.close() 88 | 89 | -------------------------------------------------------------------------------- /policy_network/neighborhood_v4_ddqn/visualize_episode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pdb, copy 3 | import glob 4 | import math 5 | import argparse 6 | import time 7 | from datetime import datetime 8 | import io, sys, os, pickle 9 | import base64 10 | import argparse 11 | import shutil 12 | 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | 16 | import torch 17 | from torchsummary import summary 18 | import torch.optim as optim 19 | import torch.nn as nn 20 | from torch.autograd import Variable 21 | import torch.nn.functional as F 22 | from torch.utils import data 23 | 24 | from gym import logger as gymlogger 25 | from gym.wrappers import Monitor 26 | import gym 27 | import gym_road_interactions 28 | from gym_road_interactions.viz_utils import write_nonsequential_idx_video 29 | 30 | from neighborhood_v4_ddqn.models import * 31 | from neighborhood_v4_ddqn.utils import * 32 | 33 | def save_obj(obj, name): 34 | with open(name + '.pkl', 'wb') as f: 35 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 36 | 37 | def load_obj(name): 38 | with open(name + '.pkl', 'rb') as f: 39 | return pickle.load(f) 40 | 41 | # === Set up environment === 42 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 43 | logger = logging.getLogger(__name__) 44 | 45 | # make this part independent of the checkpoint seeds 46 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 47 | print(device) 48 | eval_seed = np.random.randint(65536) 49 | torch.manual_seed(eval_seed) 50 | np.random.seed(eval_seed) 51 | 52 | def model_test_run(agent, env, ep_ids, set_stats, 53 | train_configs, env_configs, device, log_name, model_viz_path, 54 | curr_ego_b1 = None, # when set to None, use the ego_b1 provided in the set_stats. Otherwise, use provided curr_ego_b1 for the whole set 55 | ego_action_noise = None # when set to None, use the ego_action_noise provided in train_configs. 
Otherwise, use provided ego_action_noise for the whole set 56 | ): 57 | ni_ep_len_dict = {1: 200, 2:200, 3:200} 58 | 59 | total_reward = 0 60 | total_dist_driven = 0 61 | # total_normalized_brake = 0 62 | total_time_to_finish = 0 63 | collision_cnt = 0 64 | success_cnt = 0 # success means ego reaches the end before timeout 65 | 66 | episode = 0 67 | episode_score = 0 68 | episode_length = 0 69 | 70 | num_episodes = len(ep_ids) 71 | log(log_name, f'number of test run episodes: {num_episodes}') 72 | 73 | # reset ego_action_noise if not None 74 | if ego_action_noise is not None: 75 | print(f"Reset env ego_action_noise to {ego_action_noise}") 76 | env.env_config_['ego_action_noise'] = ego_action_noise 77 | 78 | for ep_id in ep_ids: 79 | episode += 1 80 | 81 | episode_path = f'{model_viz_path}/{ep_id}' 82 | 83 | if os.path.exists(episode_path): 84 | shutil.rmtree(episode_path) 85 | os.mkdir(episode_path) 86 | logger.info(f'visualizing episode {ep_id} at {episode_path}') 87 | 88 | # else: 89 | # ONLY CHANGE num_other_agents in env.env_config_! The caller should instantiate the env with the correct env_config. 90 | # Only num_other_agents is different across episodes 91 | if isinstance(set_stats, dict): 92 | saved_env_config = set_stats[ep_id]['env_config'] 93 | loaded_agents = set_stats[ep_id]['agents_init'] 94 | else: 95 | for ep_entry in set_stats: 96 | if ep_entry['ep_id'] == ep_id: 97 | saved_env_config = ep_entry['env_config'] 98 | loaded_agents = ep_entry['agents_init'] 99 | break 100 | 101 | # the saved env config for each episode should have these fields coincide with the env config during eval 102 | # TODO 'ttc_break_tie','stalemate_breaker', 'use_default_ego' 103 | for key in ['agent_stochastic_stop', 104 | 'agent_shuffle_ids', 'single_intersection_type', 105 | 'expanded_lane_set_depth', 'ego_velocity_coeff', 'agent_velocity_coeff', 106 | 'ego_expand_path_depth', 'c1']: 107 | if ((key not in env.env_config_) and (key in saved_env_config)) or \ 108 | ((key in env.env_config_) and (key not in saved_env_config)) or \ 109 | (env.env_config_[key] != saved_env_config[key]): 110 | raise Exception(f'eval env_config key mismatch {key}. saved: {saved_env_config[key]}. env: {env.env_config_[key]}') 111 | 112 | # change num_other_agents! 
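# note: loaded_agents includes the ego vehicle, so if its length disagrees with the
# recorded num_other_agents (a stale saved env_config), the count is recomputed as
# len(loaded_agents) - 1; the episode-defining env_config keys were already verified above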
113 | # add this check to avoid the issue of incorrect save env config 114 | if len(loaded_agents) != (saved_env_config['num_other_agents'] + 1): 115 | env.env_config_['num_other_agents'] = len(loaded_agents) - 1 116 | else: 117 | env.env_config_['num_other_agents'] = saved_env_config['num_other_agents'] 118 | # change num intersections in path and max num other agents according to the recorded episode 119 | print('Model ni=' + str(env.env_config_['ego_num_intersections_in_path']) + ' max_na=' + str(env.env_config_['max_num_other_agents'])) 120 | env.env_config_['ego_num_intersections_in_path'] = saved_env_config['ego_num_intersections_in_path'] 121 | env.env_config_['max_num_other_agents'] = saved_env_config['max_num_other_agents'] 122 | env.env_config_['max_episode_timesteps'] = ni_ep_len_dict[saved_env_config['ego_num_intersections_in_path']] 123 | train_configs['max_episode_timesteps'] = ni_ep_len_dict[saved_env_config['ego_num_intersections_in_path']] 124 | print('Updated to saved ni=' + str(env.env_config_['ego_num_intersections_in_path']) + ' max_na=' + str(env.env_config_['max_num_other_agents'])) 125 | parametric_state, ego_b1 = env.reset(use_saved_agents=loaded_agents) 126 | env.render(episode_path) 127 | # reset ego_b1 if set within range 128 | if curr_ego_b1 is not None and curr_ego_b1 <= 1 and curr_ego_b1 >= -1: 129 | print(f"Reset env ego_b1 to {curr_ego_b1}") 130 | env.reset_ego_b1(curr_ego_b1) 131 | 132 | # debug 133 | num_other_agents = env.env_config_['num_other_agents'] 134 | max_num_other_agents = env.env_config_['max_num_other_agents'] 135 | log(log_name, f"[Viz Ep {episode}] Using ep_id={ep_id}. num other agents: {num_other_agents}. max_num_other_agents: {max_num_other_agents}") 136 | print(f"[Viz Ep {episode}] Using ep_id={ep_id}. num other agents: {num_other_agents}. max_num_other_agents: {max_num_other_agents}") 137 | 138 | done = False 139 | episode_length = 0 140 | episode_score = 0 141 | 142 | while (done == False) and (episode_length < train_configs['max_episode_timesteps']): 143 | # select action 144 | if not env.env_config_['use_default_ego']: 145 | if train_configs['num_future_states'] > 0: 146 | parametric_state_till_now = truncate_state_till_now(parametric_state, env_configs) 147 | else: 148 | parametric_state_till_now = parametric_state 149 | parametric_state_ts = torch.from_numpy(parametric_state_till_now).unsqueeze(0).float().to(device) # 1*(state_dim) 150 | action = agent.select_action(parametric_state_ts, ego_b1, 0, test=True) 151 | else: 152 | action = 1 # will be overriden in env anyways 153 | next_state, reward, done, info = env.step(action) # step processed action 154 | action_str = "selected action: " + str(action) 155 | log(log_name, f'[Viz Ep {episode} ts={episode_length}] reward={reward:.1f} | {action_str}') 156 | 157 | parametric_state = next_state # !!! 
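# carry the newly observed state into the next action selection, then accumulate the
# step reward, advance the step counter, and render this frame so the per-episode PNGs
# can later be stitched into the episode video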
158 | 159 | episode_score += reward 160 | episode_length += 1 161 | 162 | env.render(episode_path) 163 | 164 | if (info[1] == True): 165 | collision_cnt += 1 166 | end_status = 'collision' 167 | elif done == True: 168 | success_cnt += 1 169 | end_status = 'success' 170 | else: 171 | end_status = 'timeout' 172 | 173 | log(log_name, f"[Viz Ep {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | {end_status}") 174 | print(f"[Viz Ep {episode} len={episode_length}] episode score={episode_score:.1f} | episode_dist_driven = {info[2]} | {end_status}") 175 | 176 | total_reward += episode_score 177 | total_dist_driven += info[2] 178 | total_time_to_finish += episode_length 179 | 180 | # make a video 181 | fps = 5 182 | img_wildcard = f"{episode_path}/*.png" 183 | output_fpath = f"{episode_path}/len={episode_length:.2f}_score={episode_score:.2f}_{end_status}.mp4" 184 | cmd = write_nonsequential_idx_video(img_wildcard, output_fpath, fps, True) 185 | 186 | # trackers 187 | avg_score = total_reward / num_episodes 188 | avg_time_to_finish = total_time_to_finish / float(num_episodes) 189 | # avg_normalized_brake = total_normalized_brake / num_episodes 190 | collision_pct = collision_cnt / num_episodes * 100 191 | success_pct = success_cnt / num_episodes * 100 192 | if collision_cnt == 0: 193 | km_per_collision = total_dist_driven / 1000.0 194 | else: 195 | km_per_collision = (total_dist_driven / 1000.0) / collision_cnt 196 | 197 | return avg_score, avg_time_to_finish, collision_pct, success_pct, km_per_collision 198 | 199 | # ======== INPUT ======== 200 | seeds = [0] 201 | model_date = '0717' 202 | seed_date_hyperparam_combos = [(0, model_date)] 203 | gen_date = '0724' 204 | code = 'c4-1' 205 | lrt = '2e-05' 206 | curr_ego_b1s = [None] # [-1,-0.5,0.0,0.5,1] 207 | ego_action_noise = None 208 | set_stats_name = 'testing_set_stats' 209 | model_suffix = 'best-eval-pfmc-2' 210 | model_suffix_short = 'bep2' 211 | 212 | # TODO: Specify the ep_ids # This specifies which episodes you want to run 213 | ep_ids = [] 214 | 215 | # TODO: Specify the path of the dataset where the ep_ids are from 216 | set_stats_path = f'./datasets/testing_set' 217 | 218 | # ===================== 219 | ni_ep_len_dict = {1: 200, 2:200, 3:200} 220 | ni = 1 221 | max_na = 25 222 | 223 | filename_suffix = '' 224 | 225 | if __name__ == '__main__': 226 | parser = argparse.ArgumentParser() 227 | parser.add_argument('-name', type=str, default=code + "-" + model_suffix_short) 228 | opt = parser.parse_args() 229 | name = opt.name 230 | 231 | all_seeds_avg_scores = 0 232 | all_seeds_avg_time_to_finish = 0 233 | all_seeds_avg_collision_pct = 0 234 | all_seeds_avg_success_pct = 0 235 | all_seeds_avg_km_per_collision = 0 236 | 237 | time_start = time.time() 238 | 239 | for curr_ego_b1 in curr_ego_b1s: 240 | for combo in seed_date_hyperparam_combos: 241 | seed, date = combo 242 | # === Load checkpoint === 243 | experiment_name = f'{date}_neighborhoodv4_ddqn_{code}' 244 | model_name = f'{experiment_name}_seed={seed}_{model_suffix}' 245 | checkpoint_path = f'./checkpoints/{model_name}.pt' 246 | log_name = f'./viz_logs/{gen_date}_viz_model={name}_{set_stats_name}' 247 | 248 | if curr_ego_b1 is not None: 249 | log_name += f'_ego-b1={curr_ego_b1}' 250 | if ego_action_noise is not None: 251 | log_name += f'_ego-action-noise={ego_action_noise}' 252 | if filename_suffix is not None: 253 | log_name += filename_suffix 254 | log_name += '.log' 255 | 256 | checkpoint = torch.load(checkpoint_path) 257 | train_configs 
= checkpoint['train_configs'] 258 | env_configs = checkpoint['env_configs'] 259 | print(checkpoint_path) 260 | print(train_configs) 261 | print(env_configs) 262 | 263 | # use v3 or v4 c0 checkpoints in v4 c1 or later env 264 | if 'num_history_states' not in env_configs.keys(): 265 | env_configs['num_history_states'] = 0 266 | if 'num_future_states' not in env_configs.keys(): 267 | env_configs['num_future_states'] = 0 268 | train_configs['num_future_states'] = 0 269 | if 'time_gap' not in env_configs.keys(): 270 | env_configs['time_gap'] = 0 271 | if 'stalemate_horizon' not in env_configs.keys(): 272 | env_configs['stalemate_horizon'] = 4 # make it larger than 1 cuz otherwise every state is considered stalemate 273 | if 'agent_total_state_dim' not in env_configs.keys(): 274 | env_configs['num_ts_in_state'] = env_configs['num_history_states'] + env_configs['num_future_states'] + 1 275 | env_configs['agent_total_state_dim'] = env_configs['agent_state_dim'] * env_configs['num_ts_in_state'] 276 | train_configs['agent_total_state_dim'] = env_configs['agent_total_state_dim'] 277 | train_configs['agent_state_dim'] = env_configs['agent_state_dim'] 278 | train_configs['state_dim'] = env_configs['agent_total_state_dim'] * (env_configs['max_num_other_agents']+1) 279 | train_configs['num_ts_in_state'] = env_configs['num_ts_in_state'] 280 | 281 | print('Setting test env diff. from training... ') 282 | env_configs['use_default_ego'] = use_default_ego 283 | env_configs['stalemate_breaker'] = True 284 | env_configs['ttc_break_tie'] = 'id' 285 | print(env_configs) 286 | 287 | # model_viz_path 288 | model_viz_path = f'./visualization/{model_name}_{set_stats_name}' 289 | if curr_ego_b1 is not None: 290 | model_viz_path += f'_ego-b1={curr_ego_b1}' 291 | if ego_action_noise is not None: 292 | model_viz_path += f'_ego-action-noise={ego_action_noise}' 293 | 294 | # make folder 295 | if not os.path.exists(model_viz_path): 296 | os.mkdir(model_viz_path) 297 | 298 | # set_stats 299 | set_stats = load_obj(set_stats_path) 300 | 301 | # === Prep Env === 302 | # Remember to set the env_config of env before passing it to model_test_run()! 
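# NOTE: use_default_ego is read above (env_configs['use_default_ego']) and again below when
# deciding whether to build a DDQN/TwinDDQN agent, but it is never assigned in the INPUT
# block; presumably it should be defined there (e.g. use_default_ego = False to drive the
# ego with the loaded checkpoint), otherwise this script raises a NameError.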
303 | env = gym.make('Neighborhood-v4') 304 | env.set_env_config(env_configs) 305 | env.set_train_config_and_device(train_configs, device) 306 | env.log_name_ = log_name 307 | if not use_default_ego: 308 | if train_configs['model'] == 'TwinDDQN': 309 | agent = TwinDDQNAgent(train_configs, device, log_name) 310 | else: 311 | agent = DDQNAgent(train_configs, device) 312 | agent.value_net.eval() 313 | agent.load(checkpoint_path) 314 | else: 315 | agent = None 316 | 317 | # === eval === 318 | model_test_run_results = model_test_run(agent, env, ep_ids, set_stats, 319 | train_configs, env_configs, device, log_name, model_viz_path, 320 | curr_ego_b1=curr_ego_b1, ego_action_noise=ego_action_noise) 321 | avg_score, avg_time_to_finish, collision_pct, success_pct, km_per_collision = model_test_run_results 322 | 323 | # === tracker === 324 | all_seeds_avg_scores += avg_score 325 | all_seeds_avg_time_to_finish += avg_time_to_finish 326 | all_seeds_avg_collision_pct += collision_pct 327 | all_seeds_avg_success_pct += success_pct 328 | all_seeds_avg_km_per_collision += km_per_collision 329 | 330 | dt = (int)(time.time() - time_start) 331 | print(f'[Seed {seed}] Model test run: avg score = {avg_score:.3f} | ttf = {avg_time_to_finish} | collision = {collision_pct}% | success = {success_pct}% | | km_per_collision = {km_per_collision}') 332 | print("Time: {:02}:{:02}:{:02}".format(dt//3600, dt%3600//60, dt%60)) 333 | 334 | all_seeds_avg_scores /= len(seeds) 335 | all_seeds_avg_time_to_finish /= len(seeds) 336 | all_seeds_avg_collision_pct /= len(seeds) 337 | all_seeds_avg_success_pct /= len(seeds) 338 | all_seeds_avg_km_per_collision /= len(seeds) 339 | 340 | print(f'Model test run over seeds: avg score = {all_seeds_avg_scores:.3f} | ttf = {all_seeds_avg_time_to_finish} | collision = {all_seeds_avg_collision_pct}% | success = {all_seeds_avg_success_pct}% | km_per_collision = {all_seeds_avg_km_per_collision}') 341 | -------------------------------------------------------------------------------- /road_interactions_environment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sherrychen1120/MIDAS/d5f5c2de00630454db7a2bc8c59fd030358f4dfa/road_interactions_environment/__init__.py -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | register( 4 | id='Neighborhood-v4', 5 | entry_point='gym_road_interactions.envs.neighborhood_v4:NeighborhoodEnvV4', 6 | ) 7 | -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/core.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import math 3 | import copy 4 | import numpy as np 5 | from typing import List, Optional, Sequence 6 | 7 | # b 8 | class AgentType(Enum): 9 | ego_vehicle = 0 10 | other_vehicle = 1 11 | pedestrian = 2 12 | cyclist = 3 13 | motorcycle = 4 14 | on_road_obstacle = 5 15 | other_mover = 6 16 | 17 | # c 18 | class MapCategory(Enum): 19 | driveway = 0 20 | intersection = 1 21 | crosswalk = 2 22 | sidewalk = 3 23 | 24 | class Position: 25 | def __init__(self, x: float, y: float, heading: float): 26 | """ 27 | Represents a position vector p = [x, y, heading] 28 | Args: 29 | x, y: city coordinates, in meters 30 | 
heading: heading angle w.r.t. city coordinates, in radians 31 | """ 32 | self.x_ = x 33 | self.y_ = y 34 | self.heading_ = ((heading + math.pi) % (2 * math.pi) - math.pi) 35 | self.world_to_position_se3_ = None 36 | 37 | def calculate_distance(self, other_position: 'Position') -> float: 38 | """ 39 | Calculates the distance between self and other position 40 | """ 41 | dist = math.sqrt((other_position.x_ - self.x_)**2 + \ 42 | (other_position.y_ - self.y_)**2) 43 | return dist 44 | 45 | class ObservableState: 46 | def __init__(self, position: Position, velocity: float, yaw_rate: float, turn_signal: int = 0, stop_light: int = 0): 47 | """ 48 | Represents an observable state vector s = [p, v] 49 | Args: 50 | position: position [x, y, heading] 51 | velocity: longitudinal velocity, in m/s 52 | yaw_rate: change in yaw, in rad/s 53 | turn_signal: 0 none, -1 left, 1 right 54 | stop_light: 0 off, 1 on 55 | """ 56 | self.position_ = position 57 | self.velocity_ = velocity 58 | self.yaw_rate_ = yaw_rate 59 | self.turn_signal_ = turn_signal 60 | self.stop_light_ = stop_light 61 | 62 | class Observation: 63 | def __init__(self, observable_state: ObservableState, 64 | agent_type: AgentType, lane_type: str = None): 65 | """ 66 | Represents an observation vector o = [s, b] 67 | Args: 68 | observable_state: the observable state of the corresponding agent 69 | agent_type: the type of the corresponding agent 70 | """ 71 | self.observable_state_ = observable_state 72 | self.agent_type_ = agent_type 73 | # This optional string contains info about the lane segment property 74 | # (eg. roundabout, straight_lane, transition) 75 | self.lane_type_ = lane_type 76 | 77 | class MapRange: 78 | def __init__(self, x_min: float, x_max: float, 79 | y_min: float, y_max: float, city_name: str): 80 | """ 81 | Represents the range of map of concern in a world 82 | Args: 83 | x_min, x_max, y_min, y_max: city coordinates of the map boundary 84 | city_name: either 'MIA' for Miami or 'PIT' for Pittsburgh 85 | """ 86 | self.x_min_ = x_min 87 | self.x_max_ = x_max 88 | self.y_min_ = y_min 89 | self.y_max_ = y_max 90 | self.city_name_ = city_name 91 | 92 | class LaneSegment: 93 | def __init__( 94 | self, 95 | id: int, # in neighborhood-v0 it's str 96 | centerline: np.ndarray, 97 | length: float, 98 | priority: Optional[int] = None, 99 | priority2: Optional[int] = None, 100 | curve_center: Optional[np.ndarray] = None, 101 | curve_radius: Optional[float] = None, 102 | turn_direction: Optional[str] = None, 103 | lane_heading: Optional[float] = None, 104 | is_intersection: Optional[bool] = False, 105 | intersection_id: Optional[str] = None, 106 | has_traffic_control: Optional[bool] = None, 107 | l_neighbor_id: Optional[int] = None, 108 | r_neighbor_id: Optional[int] = None, 109 | predecessors: Optional[Sequence[int]] = [], 110 | successors: Optional[Sequence[int]] = [], 111 | ) -> None: 112 | """Initialize the lane segment. 113 | 114 | Args: 115 | id: Unique lane ID that serves as identifier for this "Way" 116 | centerline: The coordinates of the lane segment's center line. 117 | length: length of this lane 118 | priority: integer indicating lane priority. Larger is more important 119 | curve_center: circle center of the circle which this curve is a part of 120 | curve_radius: radius of the circle which this curve is a part of 121 | has_traffic_control: T/F 122 | turn_direction: 'right', 'left', or None 123 | lane_heading: heading of the direction of a straight lane in radian (eg. 
S->N lane is pi/2) 124 | is_intersection: Whether or not this lane segment is an intersection - if yes, the intersection_id is filled in 125 | intersection_id: intersection_id 126 | l_neighbor_id: Unique ID for left neighbor 127 | r_neighbor_id: Unique ID for right neighbor 128 | predecessors: The IDs of the lane segments that come after this one 129 | successors: The IDs of the lane segments that come before this one. 130 | 131 | """ 132 | self.id = id 133 | self.centerline = centerline 134 | self.length = length 135 | self.priority = priority 136 | self.priority2 = priority2 137 | self.curve_center = curve_center 138 | self.curve_radius = curve_radius 139 | self.has_traffic_control = has_traffic_control 140 | self.turn_direction = turn_direction 141 | self.lane_heading = lane_heading 142 | self.is_intersection = is_intersection 143 | self.intersection_id = intersection_id 144 | self.l_neighbor_id = l_neighbor_id 145 | self.r_neighbor_id = r_neighbor_id 146 | self.predecessors = predecessors 147 | self.successors = successors 148 | 149 | class Agent: 150 | def __init__(self, 151 | id: str, 152 | observable_state: ObservableState, 153 | agent_type: AgentType, 154 | goal: ObservableState, 155 | width: float, 156 | length: float, 157 | height: float, 158 | observations: list = None, 159 | map_category: MapCategory = None, 160 | use_saved_path: bool = False, 161 | path: list = None, 162 | closest_lane_id_order: int = -1, 163 | closest_waypoint_idx: int = -1): 164 | """ 165 | Represents an agent 166 | Args: 167 | id: unique id representing this agent 168 | observable_state: observable state of the agent itself 169 | agent_type: agent type 170 | goal: observable state of the goal 171 | width, length: lateral and longitudinal size in meters 172 | observations: list of Observations, observations within a certain range. Default empty list 173 | map_category: map category of the current location of the agent. Default None. 174 | use_saved_trajectory: whether to update this agent using saved trajectory. 175 | path: list of np.ndarray, each represents a lane centerline 176 | closest_lane_id_order: the order of the currently closest lane_id in path. Should be updated if path is set. 177 | """ 178 | self.id_ = id 179 | self.observable_state_ = observable_state 180 | self.agent_type_ = agent_type 181 | self.goal_ = goal 182 | self.closest_lane_id_order_ = closest_lane_id_order 183 | self.closest_waypoint_idx_ = closest_waypoint_idx 184 | if observations is not None: 185 | self.observations_ = observations 186 | else: 187 | self.observations_ = [] 188 | self.map_category_ = map_category 189 | # radius in which the agent observes others 190 | self.observation_radius_ = 20.0 191 | # size 192 | self.width_ = width # lateral 193 | self.length_ = length # longitudinal 194 | self.height_ = height # vertical 195 | # use saved trajectory? 
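# when this flag is False, set_path() below refuses to overwrite the path, so agents that
# must follow a pre-recorded route are constructed with use_saved_path=True
# (e.g. NeighborhoodV4AgentInterface passes use_saved_path=True in its __init__)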
196 | self.use_saved_path_ = use_saved_path 197 | if path is not None: 198 | self.path_ = path 199 | else: 200 | self.path_ = [] 201 | # bounding box in world frame 202 | self.bbox_world_frame_ = None 203 | 204 | def set_path(self, path: List[str]) -> None: 205 | """ 206 | Save path 207 | Args: 208 | path: List[str] 209 | """ 210 | if not (self.use_saved_path_): 211 | raise Exception('Cannot set path of an agent that doesn\'t use saved path') 212 | self.path_ = path 213 | 214 | def set_goal(self, goal: ObservableState) -> None: 215 | self.goal_ = goal 216 | 217 | -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sherrychen1120/MIDAS/d5f5c2de00630454db7a2bc8c59fd030358f4dfa/road_interactions_environment/gym_road_interactions/envs/__init__.py -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/envs/maps/neighborhood_v0_intersection_id_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sherrychen1120/MIDAS/d5f5c2de00630454db7a2bc8c59fd030358f4dfa/road_interactions_environment/gym_road_interactions/envs/maps/neighborhood_v0_intersection_id_dict.pkl -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/envs/maps/neighborhood_v0_map_constants.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sherrychen1120/MIDAS/d5f5c2de00630454db7a2bc8c59fd030358f4dfa/road_interactions_environment/gym_road_interactions/envs/maps/neighborhood_v0_map_constants.pkl -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/envs/maps/neighborhood_v0_map_lane_segments.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sherrychen1120/MIDAS/d5f5c2de00630454db7a2bc8c59fd030358f4dfa/road_interactions_environment/gym_road_interactions/envs/maps/neighborhood_v0_map_lane_segments.pkl -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/envs/neighborhood_v4/__init__.py: -------------------------------------------------------------------------------- 1 | from gym_road_interactions.envs.neighborhood_v4.neighborhood_env_v4 import NeighborhoodEnvV4 2 | -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/envs/neighborhood_v4/neighborhood_env_v4_agents.py: -------------------------------------------------------------------------------- 1 | # Closed neighborhood environment with a roundabout, 4 t-intersections 2 | 3 | # Python 4 | import pdb, copy, os, sys 5 | import pickle 6 | import numpy as np 7 | import math 8 | import matplotlib.pyplot as plt 9 | from pathlib import Path 10 | import matplotlib 11 | import logging 12 | from queue import Queue 13 | from typing import Any, Dict, Tuple, List 14 | import cv2 15 | import torch 16 | # Argoverse 17 | from argoverse.utils.se3 import SE3 18 | # gym 19 | import gym 20 | from gym import error, spaces, utils 21 | from gym.utils import seeding 22 | # utils 23 | from 
gym_road_interactions.utils import create_bbox_world_frame, conditional_log, wrap_to_pi 24 | from gym_road_interactions.viz_utils import visualize_agent 25 | from gym_road_interactions.core import AgentType, Position, Agent, ObservableState, Observation, LaneSegment 26 | from shapely.geometry import Point, Polygon 27 | from .neighborhood_env_v4_utils import calculate_remaining_lane_distance, calculate_traversed_lane_distance, calculate_time_to_collision, calculate_poly_distance, get_two_degree_successors 28 | 29 | from neighborhood_v4_ddqn.models import DDQNAgent, TwinDDQNAgent 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | DT = 0.1 # 10Hz 34 | FAD_COEFF = 0.3 35 | TTC_HORIZON = 20 36 | TTC_VEL_PROD = 9.2 37 | DONE_THRES = 0.5 # distance threshold for completing the task 38 | 39 | # interface for Neighborhood-v4 agents that creates bbox, v_desired, implements drive_along_path and get_lane_priority and calculate_ttc 40 | class NeighborhoodV4AgentInterface(Agent): 41 | def __init__(self, 42 | id: str, 43 | observable_state: ObservableState, 44 | agent_type: AgentType, 45 | goal: ObservableState, 46 | curr_lane_id_order: int, 47 | curr_waypoint_idx: int, 48 | v_desired: float, 49 | default_ttc: float, 50 | lane_segments: List[LaneSegment], 51 | path: list = []): # list of lane_ids 52 | 53 | super(NeighborhoodV4AgentInterface, self).__init__( 54 | id=id, 55 | observable_state=observable_state, 56 | agent_type=agent_type, 57 | goal=goal, 58 | width=2.0, # circle with r = sqrt(2) 59 | length=2.0, 60 | height=1.7, 61 | use_saved_path=True, 62 | path=path, 63 | closest_lane_id_order=curr_lane_id_order, 64 | closest_waypoint_idx=curr_waypoint_idx) 65 | 66 | self.bbox_world_frame_ = create_bbox_world_frame(self) # this function is in gym_road_interactions.utils 67 | self.v_desired_ = v_desired 68 | 69 | self.radius_ = math.sqrt(2) 70 | 71 | # things related to ttc 72 | self.default_ttc_ = default_ttc 73 | self.ttc_thres_ = TTC_VEL_PROD / self.v_desired_ 74 | self.past_ttc_ = [default_ttc, default_ttc, default_ttc] 75 | self.min_ttc_agent_id_ = '' 76 | self.ttc_horizon_ = TTC_HORIZON 77 | self.fad_coeff_ = FAD_COEFF 78 | self.future_horizon_positions_ = None 79 | 80 | # lane segments 81 | self.lane_segments_ = lane_segments 82 | 83 | # path_length 84 | self.path_length_ = self.__calculate_path_length(path) 85 | self.dist_driven_ = 0 86 | 87 | self.done_thres_ = DONE_THRES 88 | 89 | def calculate_ttc(self, agents, dt, ttc_dp) -> None: 90 | # ttc_dp is a table filled with ttc assuming current speed 91 | # calculate current-speed ttc with all other agents 92 | min_ttc = np.min(ttc_dp[0,int(self.id_),:]) 93 | min_ttc_agent_id = np.argmin(ttc_dp[0,int(self.id_),:]) 94 | 95 | self.past_ttc_ = self.past_ttc_[1:] 96 | self.past_ttc_.append(min_ttc) 97 | self.min_ttc_agent_id_ = min_ttc_agent_id 98 | 99 | # DEBUG 100 | # conditional_log(self.log_name_, logger, 101 | # f'agent {self.id_} min ttc agent is {min_ttc_agent_id}', 'info') 102 | 103 | def drive_along_path(self, dt: float, assume_vel: int = 0) -> Tuple[Position, int]: 104 | """ 105 | Returns the end position of this agent if drive along path for dt at current velocity 106 | assume_vel: 1 (assume at v_desired_), 0 (use current velocity), -1 (assume vel = 0) 107 | """ 108 | if assume_vel == 1: 109 | remain_dist = self.v_desired_ * dt 110 | elif assume_vel == 0: 111 | remain_dist = self.observable_state_.velocity_ * dt 112 | elif assume_vel == -1: 113 | remain_dist = 0 114 | elif assume_vel == -2: # special mode: drive backwards 115 | 
remain_dist = - self.v_desired_ * dt 116 | else: 117 | raise Exception(f'Invalid assume_vel value: {assume_vel}') 118 | 119 | if remain_dist == 0: # velocity = 0 120 | # print(f'agent{self.id_} dap: remain_dist <= 0: velocity: {self.observable_state_.velocity_}, remain_dist = {remain_dist}') 121 | return self.observable_state_.position_, self.closest_lane_id_order_, self.closest_waypoint_idx_ 122 | elif remain_dist > 0: # move forward 123 | curr_pos = copy.deepcopy(self.observable_state_.position_) 124 | closest_lane_id_order = self.closest_lane_id_order_ 125 | 126 | while remain_dist > 0: 127 | closest_lane_id_order, remain_dist, curr_pos = self.__drive_along_lane(closest_lane_id_order, remain_dist, curr_pos) 128 | curr_xy = np.array([[curr_pos.x_, curr_pos.y_]]) 129 | # if self.id_ == '1': 130 | # print(f'closest_lane_id_order={closest_lane_id_order}, curr_xy={curr_xy}') 131 | 132 | # find closest waypoint index on this updated lane 133 | lane_id = self.path_[closest_lane_id_order] 134 | lane_cl = self.lane_segments_[lane_id].centerline 135 | curr_xy = np.array([[curr_pos.x_, curr_pos.y_]]) 136 | curr_distances = np.linalg.norm((np.tile(curr_xy, (len(lane_cl), 1)) - lane_cl), axis=1) 137 | closest_waypoint_idx = np.argmin(curr_distances) 138 | 139 | # print(f'agent{self.id_} dap: closest_lane_id_order={closest_lane_id_order}, curr_xy={curr_xy}') 140 | return curr_pos, closest_lane_id_order, closest_waypoint_idx 141 | else: # move backward 142 | curr_pos = copy.deepcopy(self.observable_state_.position_) 143 | closest_lane_id_order = self.closest_lane_id_order_ 144 | remain_dist = - remain_dist # reverse it so it's easier to debug... just remember that we are driving backwards 145 | 146 | while remain_dist > 0: 147 | closest_lane_id_order, remain_dist, curr_pos = self.__drive_backwards_along_lane(closest_lane_id_order, remain_dist, curr_pos) 148 | curr_xy = np.array([[curr_pos.x_, curr_pos.y_]]) 149 | # if self.id_ == '1': 150 | # print(f'closest_lane_id_order={closest_lane_id_order}, curr_xy={curr_xy}') 151 | 152 | # find closest waypoint index on this updated lane 153 | lane_id = self.path_[closest_lane_id_order] 154 | lane_cl = self.lane_segments_[lane_id].centerline 155 | curr_xy = np.array([[curr_pos.x_, curr_pos.y_]]) 156 | curr_distances = np.linalg.norm((np.tile(curr_xy, (len(lane_cl), 1)) - lane_cl), axis=1) 157 | closest_waypoint_idx = np.argmin(curr_distances) 158 | 159 | # print(f'agent{self.id_} dap: closest_lane_id_order={closest_lane_id_order}, curr_xy={curr_xy}') 160 | return curr_pos, closest_lane_id_order, closest_waypoint_idx 161 | 162 | def __drive_along_lane(self, closest_lane_id_order: int, remain_dist: float, curr_pos: Position) -> Tuple[int, float, Position]: 163 | # returns (closest_lane_id_order, remain_dist, curr_pos) 164 | # 1. check for zero remain_dist 165 | if remain_dist == 0: 166 | return closest_lane_id_order, remain_dist, curr_pos 167 | # 2. 
move along current lane 168 | lane_id = self.path_[closest_lane_id_order] 169 | curr_lane_seg = self.lane_segments_[lane_id] 170 | remain_lane_distance = calculate_remaining_lane_distance(lane_id, curr_pos, self.lane_segments_) 171 | # print(f'remain_lane_distance: {remain_lane_distance}') 172 | # if remain lane distance is >= remain_dist, we end up on the same lane 173 | if remain_lane_distance >= remain_dist: 174 | return_pos = curr_pos 175 | # straight lanes 176 | if curr_lane_seg.lane_heading is not None: 177 | return_pos.x_ = curr_pos.x_ + remain_dist * math.cos(curr_lane_seg.lane_heading) 178 | return_pos.y_ = curr_pos.y_ + remain_dist * math.sin(curr_lane_seg.lane_heading) 179 | # curves 180 | else: 181 | r = curr_lane_seg.curve_radius 182 | drive_theta = remain_dist / r 183 | if curr_lane_seg.turn_direction == 'left': # counter-clockwise 184 | return_pos.heading_ = curr_pos.heading_ + drive_theta 185 | # theta on circle 186 | curr_theta = curr_pos.heading_ - math.pi / 2.0 187 | else: # clockwise 188 | return_pos.heading_ = curr_pos.heading_ - drive_theta 189 | # theta on circle 190 | curr_theta = curr_pos.heading_ + math.pi / 2.0 191 | return_pos.x_ = self.lane_segments_[lane_id].curve_center[0] + r * math.cos(curr_theta) 192 | return_pos.y_ = self.lane_segments_[lane_id].curve_center[1] + r * math.sin(curr_theta) 193 | return closest_lane_id_order, 0.0, return_pos 194 | else: 195 | remain_dist -= remain_lane_distance 196 | if closest_lane_id_order < len(self.path_) - 1: 197 | closest_lane_id_order += 1 198 | next_lane_id = self.path_[closest_lane_id_order] 199 | # update curr_pos to be the first waypoint of next lane segment 200 | next_point = self.lane_segments_[next_lane_id].centerline[0,:] 201 | else: 202 | next_lane_id = lane_id 203 | # we are on the last lane in path. move to the last point on the lane and we are done 204 | next_point = self.lane_segments_[lane_id].centerline[-1,:] 205 | remain_dist = 0.0 206 | 207 | curr_pos.x_, curr_pos.y_ = next_point[0], next_point[1] 208 | # correct heading if you come out of a curve (only curves have turn_direction) 209 | if curr_lane_seg.turn_direction is not None: 210 | r = curr_lane_seg.curve_radius 211 | drive_theta = remain_lane_distance / r # remain_lane_distance is the distance that we've driven on this past lane 212 | if curr_lane_seg.turn_direction == 'right': 213 | curr_pos.heading_ = curr_pos.heading_ - drive_theta 214 | elif curr_lane_seg.turn_direction == 'left': 215 | curr_pos.heading_ = curr_pos.heading_ + drive_theta 216 | 217 | curr_pos.heading_ = wrap_to_pi(curr_pos.heading_) 218 | return closest_lane_id_order, remain_dist, curr_pos 219 | 220 | # remain_dist is positive 221 | def __drive_backwards_along_lane(self, closest_lane_id_order: int, remain_dist: float, curr_pos: Position) -> Tuple[int, float, Position]: 222 | # returns (closest_lane_id_order, remain_dist, curr_pos) 223 | # 1. check for zero remain_dist 224 | if remain_dist == 0: 225 | return closest_lane_id_order, remain_dist, curr_pos 226 | # 2. 
move backwards along current lane 227 | lane_id = self.path_[closest_lane_id_order] 228 | curr_lane_seg = self.lane_segments_[lane_id] 229 | traversed_lane_distance = calculate_traversed_lane_distance(lane_id, curr_pos, self.lane_segments_) 230 | # if remain lane distance is >= remain_dist, we end up on the same lane 231 | if traversed_lane_distance >= remain_dist: 232 | return_pos = curr_pos 233 | # straight lanes 234 | if curr_lane_seg.lane_heading is not None: 235 | return_pos.x_ = curr_pos.x_ - remain_dist * math.cos(curr_lane_seg.lane_heading) 236 | return_pos.y_ = curr_pos.y_ - remain_dist * math.sin(curr_lane_seg.lane_heading) 237 | # curves 238 | else: 239 | r = curr_lane_seg.curve_radius 240 | drive_theta = remain_dist / r 241 | if curr_lane_seg.turn_direction == 'left': # counter-clockwise 242 | return_pos.heading_ = curr_pos.heading_ - drive_theta 243 | # theta on circle 244 | curr_theta = return_pos.heading_ - math.pi / 2.0 245 | else: # clockwise 246 | return_pos.heading_ = curr_pos.heading_ + drive_theta 247 | # theta on circle 248 | curr_theta = return_pos.heading_ + math.pi / 2.0 249 | return_pos.x_ = self.lane_segments_[lane_id].curve_center[0] + r * math.cos(curr_theta) 250 | return_pos.y_ = self.lane_segments_[lane_id].curve_center[1] + r * math.sin(curr_theta) 251 | return closest_lane_id_order, 0.0, return_pos 252 | else: 253 | remain_dist -= traversed_lane_distance 254 | if closest_lane_id_order > 0: 255 | closest_lane_id_order -= 1 256 | prev_lane_id = self.path_[closest_lane_id_order] 257 | # update curr_pos to be the last waypoint of previous lane segment 258 | prev_point = self.lane_segments_[prev_lane_id].centerline[-1,:] 259 | else: 260 | prev_lane_id = lane_id 261 | # we are on the first lane in path. move to the first point on the lane and we are done 262 | prev_point = self.lane_segments_[lane_id].centerline[0,:] 263 | remain_dist = 0.0 264 | 265 | curr_pos.x_, curr_pos.y_ = prev_point[0], prev_point[1] 266 | # correct heading if you come out of a curve (only curves have turn_direction) 267 | if curr_lane_seg.turn_direction is not None: 268 | r = curr_lane_seg.curve_radius 269 | drive_theta = traversed_lane_distance / r # traversed_lane_distance is the distance that we've driven on this past lane 270 | if curr_lane_seg.turn_direction == 'right': 271 | curr_pos.heading_ = curr_pos.heading_ + drive_theta 272 | elif curr_lane_seg.turn_direction == 'left': 273 | curr_pos.heading_ = curr_pos.heading_ - drive_theta 274 | 275 | curr_pos.heading_ = wrap_to_pi(curr_pos.heading_) 276 | return closest_lane_id_order, remain_dist, curr_pos 277 | 278 | def initialize_future_horizon_positions(self): 279 | # if v = 0 280 | curr_x = self.observable_state_.position_.x_ 281 | curr_y = self.observable_state_.position_.y_ 282 | self.future_horizon_positions_ = np.hstack((np.ones((self.ttc_horizon_,1)) * curr_x, \ 283 | np.ones((self.ttc_horizon_,1)) * curr_y)).reshape((1,self.ttc_horizon_,2)) # (1 * ttc_horizon * 2) 284 | 285 | # if v > 0 286 | if len(self.path_) > 0: # if path length > 0, give the actual v > 0 future horizon positions 287 | temp = [] 288 | for i in range(1, self.ttc_horizon_+1): # 1 to 10 289 | curr_dt = i * DT 290 | pred_pos, _, _ = self.drive_along_path(curr_dt, 1) 291 | temp.append(np.array([[pred_pos.x_, pred_pos.y_]])) 292 | temp = np.array(temp).reshape((1,self.ttc_horizon_,2)) 293 | self.future_horizon_positions_ = np.concatenate((self.future_horizon_positions_, temp), axis=0) # (2 * ttc_horizon * 2) 294 | else: # if path length == 0, we want this agent 
to stay here forever, so just repeat the v=0 fhp
295 |             self.future_horizon_positions_ = np.concatenate((self.future_horizon_positions_,
296 |                                                               self.future_horizon_positions_), axis=0) # (2 * ttc_horizon * 2)
297 | 
298 |     # advance future_horizon_positions_ forward by 1 ts
299 |     def update_future_horizon_positions(self):
300 |         if self.observable_state_.velocity_ > 0: # we only need to update the fhp if we have moved in this current step
301 |             # row 0: positions assuming v = 0 (agent stays at its current position)
302 |             curr_x = self.observable_state_.position_.x_
303 |             curr_y = self.observable_state_.position_.y_
304 |             self.future_horizon_positions_[0,:,:] = np.hstack((np.ones((self.ttc_horizon_,1)) * curr_x, \
305 |                                                                np.ones((self.ttc_horizon_,1)) * curr_y))
306 |             # row 1: positions assuming v = v_desired_ (drop the oldest step, append the new horizon point)
307 |             pred_pos, _, _ = self.drive_along_path(DT * self.ttc_horizon_, 1)
308 |             next_position = np.array([[pred_pos.x_, pred_pos.y_]])
309 |             self.future_horizon_positions_[1,:,:] = np.vstack((self.future_horizon_positions_[1,1:,:], next_position))
310 | 
311 |     def __calculate_path_length(self, path):
312 |         tmp_length = 0
313 |         for i in range(len(path)):
314 |             lane_id = path[i]
315 |             tmp_length += self.lane_segments_[lane_id].length
316 |         return tmp_length # __init__ stores the returned value as self.path_length_
317 | 
318 |     def apply_action(self, action, dt):
319 |         # action: 0 stop, 1 go, -1 go backwards
320 |         # set velocity
321 |         if action == 0:
322 |             self.observable_state_.velocity_ = 0.0
323 |             self.observable_state_.yaw_rate_ = 0.0
324 |         else:
325 |             self.observable_state_.velocity_ = self.v_desired_
326 |             if action == 1:
327 |                 # drive along path with speed
328 |                 new_pos, new_closest_lane_id_order, new_closest_point_idx = self.drive_along_path(dt)
329 |             else: # action == -1
330 |                 new_pos, new_closest_lane_id_order, new_closest_point_idx = self.drive_along_path(dt, assume_vel=-2)
331 |             self.observable_state_.yaw_rate_ = wrap_to_pi((new_pos.heading_ - self.observable_state_.position_.heading_) / dt)
332 |             # logger.debug(f'[Agent {self.id_}] yaw_rate = {self.observable_state_.yaw_rate_}')
333 |             self.observable_state_.position_ = new_pos
334 |             self.closest_lane_id_order_ = new_closest_lane_id_order
335 |             self.closest_waypoint_idx_ = new_closest_point_idx
336 |             self.dist_driven_ += self.observable_state_.velocity_ * dt
337 |             # logger.info(f'ego action: 1 closest_lane_id_order_={self.closest_lane_id_order_} closest_waypoint_idx_={self.closest_waypoint_idx_}')
338 |             # don't forget to update bbox!!
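            # bbox_world_frame_ caches the agent's oriented footprint in world coordinates;
            # rebuilding it after every pose change keeps downstream collision/observation
            # checks consistent with the new position and heading.
            # Minimal usage sketch (hypothetical `agent` instance; DT = 0.1 s as defined above):
            #   agent.apply_action(1, DT)    # advance one step along the path at v_desired_
            #   agent.apply_action(0, DT)    # hard stop: velocity and yaw rate are zeroed
            #   agent.apply_action(-1, DT)   # reverse along the path (assume_vel=-2 branch)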
339 | self.bbox_world_frame_ = create_bbox_world_frame(self) 340 | 341 | def fad_distance(self): 342 | return 1 + self.fad_coeff_ * self.ttc_thres_ * self.v_desired_ 343 | 344 | class NeighborhoodV4DefaultAgent(NeighborhoodV4AgentInterface): 345 | def __init__(self, 346 | id: str, 347 | observable_state: ObservableState, 348 | goal: ObservableState, 349 | curr_lane_id_order: int, 350 | curr_waypoint_idx: int, 351 | # agg_level: int, # [0: mild, 1: average, 2: aggressive] 352 | default_ttc: float, 353 | lane_segments: List[LaneSegment], 354 | path: list = [], # list of lane_ids 355 | stochastic_stop: bool = False, # whether this agent can choose to stop for ego with a certain prob 356 | log_name: str = None, 357 | rl_model_path: str = None, # if set, this agent follows a saved RL policy 358 | train_configs: dict = None, # used if rl_model_path is not None 359 | device = None, # used if rl_model_path is not None 360 | v_desired: float = None, # use this if we want to specify agent velocity 361 | b1: float = None 362 | ): 363 | self.log_name_ = log_name 364 | 365 | self.rl_agent_ = None 366 | self.v_desired_ = v_desired 367 | self.b1_ = b1 368 | 369 | # == DEPRECATED == 370 | # Note: agg_level is determined with b1 371 | # cf: b1 (driver_type) is a continuous float [-1,1] that determines v_desired_ 372 | # you can view agg_level as a discretization of b1 373 | # if (v_desired >= 9): 374 | # self.agg_level_ = 2 375 | # elif (v_desired >= 6.8): 376 | # self.agg_level_ = 1 377 | # else: 378 | # self.agg_level_ = 0 379 | # ==== 380 | if (self.b1_ >= 1./3.): 381 | self.agg_level_ = 2 382 | elif (v_desired >= -1./3.): 383 | self.agg_level_ = 1 384 | else: 385 | self.agg_level_ = 0 386 | 387 | self.stop_prob_ = - 0.5 * self.b1_ + 0.5 388 | 389 | # if rl_model_path is None: 390 | # # choose v_desired and stop_prob based on agg_level 391 | # if agg_level == 2: 392 | # self.v_desired_ = np.random.uniform(low=9.0, high=11.0) if (v_desired is None) else v_desired 393 | # elif agg_level == 1: 394 | # self.v_desired_ = np.random.uniform(low=6.8, high=8.8) if (v_desired is None) else v_desired 395 | # elif agg_level == 0: 396 | # self.v_desired_ = np.random.uniform(low=4.6, high=6.6) if (v_desired is None) else v_desired 397 | # else: 398 | # raise Exception(f'Invalid agg_level: {agg_level}') 399 | # # log 400 | # # conditional_log(self.log_name_, logger, 401 | # # f'agent {id} agg_level: {agg_level}, v_desired_: {self.v_desired_}', 402 | # # 'debug') 403 | if rl_model_path is not None: 404 | # self.v_desired_ = 5.6 405 | # initialize agent 406 | self.rl_model_path_ = rl_model_path 407 | if train_configs['model'] == 'DDQN': 408 | self.rl_agent_ = DDQNAgent(train_configs, device) 409 | elif train_configs['model'] == 'TwinDDQN': 410 | self.rl_agent_ = TwinDDQNAgent(train_configs, device) 411 | self.rl_agent_.load(self.rl_model_path_) 412 | self.rl_agent_.value_net.eval() 413 | # log 414 | # conditional_log(self.log_name_, logger, 415 | # f'agent {id} is RL agent with policy {rl_model_path}', 416 | # 'debug') 417 | 418 | 419 | super(NeighborhoodV4DefaultAgent, self).__init__( 420 | id=id, 421 | observable_state=observable_state, 422 | agent_type=AgentType.other_vehicle, 423 | goal=goal, 424 | curr_lane_id_order=curr_lane_id_order, 425 | curr_waypoint_idx=curr_waypoint_idx, 426 | v_desired = self.v_desired_, 427 | default_ttc = default_ttc, 428 | lane_segments = lane_segments, 429 | path=path) 430 | 431 | self.observable_state_.velocity_ = self.v_desired_ 432 | self.stochastic_stop_ = stochastic_stop 433 | 
self.on_stop_ = False # If agent is on stop, it will wait until past_ttc[2] > ttc_thres_ 434 | 435 | self.device_ = device 436 | # track the positions at future 10 ts 437 | self.initialize_future_horizon_positions() 438 | 439 | self.has_passed_goal_ = False 440 | 441 | def should_stop(self, dt, agents, ttc_dp, ttc_break_tie=None, agent_baseline=None) -> bool: 442 | # there are 3 scenarios: 443 | # 1. when following another agent, whoever that is, we will stop when dist <= 1 + 1.5 * self.ttc_thres_ * self.v_desired_ 444 | # 2. otherwise, 445 | # if min_ttc3 < default_ttc (If I go and everyone else goes): 446 | # find that agent. ttc1 is the min ttc if I stop, ttc2 is the min ttc if I go. 447 | # should_stop = (ttc1 > ttc2) [for ego: should_stop = (ttc >= ttc2) 448 | # 3. (NOT IMPLEMENTED) if the other agent is ego, and stochastic_stop = True & sample <= stop_prob, return False 449 | 450 | # look up ttc1 (if I stop), ttc2 (if I go) 451 | # if any of the 4 ttcs are within range, capture that agent 452 | # (xs, ys) 453 | ids_ttc_in_range_raw = np.array([]).astype(int) 454 | for j in range(4): 455 | # ids_ttc_in_range = np.union1d(ids_ttc_in_range, np.where(ttc_dp[j, int(self.id_), :] < self.default_ttc_)[0]) 456 | ids_ttc_in_range_raw = np.union1d(ids_ttc_in_range_raw, np.where(ttc_dp[j, int(self.id_), :] <= self.ttc_thres_)[0]) 457 | # for crossing vehicles, only consider my interaction with the head of a series of vehicles in ids_ttc_in_range 458 | ids_ttc_in_range = [] 459 | for raw_id in ids_ttc_in_range_raw: 460 | to_add = True 461 | for selected_id in ids_ttc_in_range: 462 | selected_agent_lane_id = agents[str(selected_id)].path_[agents[str(selected_id)].closest_lane_id_order_] 463 | selected_agent_waypoint_idx = agents[str(selected_id)].closest_waypoint_idx_ 464 | raw_agent_lane_id = agents[str(raw_id)].path_[agents[str(raw_id)].closest_lane_id_order_] 465 | raw_agent_waypoint_idx = agents[str(raw_id)].closest_waypoint_idx_ 466 | # if raw agent is behind selected agent, don't add 467 | if (selected_agent_lane_id == raw_agent_lane_id and raw_agent_waypoint_idx < selected_agent_waypoint_idx) or \ 468 | selected_agent_lane_id in self.lane_segments_[raw_agent_lane_id].successors: 469 | to_add = False 470 | break 471 | # if selected agent is behind raw agent, remove selected agent 472 | elif (selected_agent_lane_id == raw_agent_lane_id and raw_agent_waypoint_idx > selected_agent_waypoint_idx) or \ 473 | raw_agent_lane_id in self.lane_segments_[selected_agent_lane_id].successors: 474 | ids_ttc_in_range.remove(selected_id) 475 | if to_add: 476 | ids_ttc_in_range.append(raw_id) 477 | 478 | ids_should_stop_for = [] 479 | 480 | should_stop = False 481 | my_lane_id = self.path_[self.closest_lane_id_order_] 482 | my_waypoint_idx = self.closest_waypoint_idx_ 483 | my_two_degree_successors = get_two_degree_successors(self.lane_segments_, my_lane_id) 484 | for other_agent_id, other_agent in agents.items(): 485 | if (other_agent_id == self.id_): 486 | continue 487 | if len(other_agent.path_) == 0: # special case: other agent path length = 0 488 | continue 489 | # if other agent is in front of me 490 | other_agent_lane_id = other_agent.path_[other_agent.closest_lane_id_order_] 491 | other_agent_waypoint_idx = other_agent.closest_waypoint_idx_ 492 | dist = calculate_poly_distance(self, other_agent) 493 | 494 | # conditional_log(self.log_name_, logger, 495 | # f'agent {self.id_} dist to agent {other_agent_id} | dist={dist}', 'info') # DEBUG 496 | # == front-vehicle logic == 497 | if (agent_baseline not in 
[5]) and \ 498 | ((other_agent_lane_id == my_lane_id and my_waypoint_idx < other_agent_waypoint_idx) or \ 499 | (other_agent_lane_id in my_two_degree_successors)): 500 | if (dist <= 1 + self.fad_coeff_ * self.ttc_thres_ * self.v_desired_): # assume ttc_thres_ self stop time if the other agent suddenly stops 501 | # conditional_log(self.log_name_, logger, 502 | # f'agent {self.id_} agent_baseline={agent_baseline} | should stop for front agent {other_agent_id} | dist={dist}', 'debug') # DEBUG 503 | est_decision_ttc = dist / self.v_desired_ 504 | ids_should_stop_for.append((int(other_agent_id), est_decision_ttc)) 505 | should_stop = (should_stop or True) # we want to go through everyone for debug 506 | else: 507 | i = int(other_agent_id) 508 | if i in ids_ttc_in_range: 509 | # == back-vehicle logic == 510 | # if this ttc3 agent is behind me, don't consider whether we should stop for them 511 | if (agent_baseline not in [5]) and \ 512 | ((my_lane_id == other_agent_lane_id and other_agent_waypoint_idx < my_waypoint_idx) or \ 513 | my_lane_id in self.lane_segments_[other_agent_lane_id].successors): 514 | # conditional_log(self.log_name_, logger, f'agent {self.id_} agent_baseline={agent_baseline} | with agent {i} behind. shouldn\'t stop.', 'debug') 515 | should_stop = (should_stop or False) 516 | # == ttc logic == 517 | else: 518 | ttc3 = ttc_dp[3, int(self.id_), i] # if we both go 519 | ttc1 = ttc_dp[1, int(self.id_), i] # if I stop 520 | ttc2 = ttc_dp[2, int(self.id_), i] # if I go 521 | # max_ttc_go = max(ttc2, ttc3) 522 | if ttc1 == ttc2: 523 | # based on who's agg_level = 0 524 | if ttc_break_tie is None or ttc_break_tie == 'agg_level=0': 525 | # if I'm mild, the other is not, I should stop 526 | if self.agg_level_ == 0 and other_agent.agg_level_ != 0: 527 | criteria = True 528 | # if both are mild, break tie: larger id stops (this is to make sure agent stops for ego) 529 | elif self.agg_level_ == 0 and other_agent.agg_level_ == 0: 530 | criteria = (int(self.id_) > i) 531 | # if I'm not mild and the other is mild, I shouldn't stop 532 | elif self.agg_level_ != 0 and other_agent.agg_level_ == 0: 533 | criteria = False 534 | # if both are not mild, break tie: larger id stops 535 | elif self.agg_level_ != 0 and other_agent.agg_level_ != 0: 536 | criteria = (int(self.id_) > i) 537 | # conditional_log(self.log_name_, logger, 538 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 539 | # based on b1 540 | elif ttc_break_tie == 'b1': 541 | if self.b1_ != other_agent.b1_: 542 | criteria = (self.b1_ < other_agent.b1_) 543 | else: 544 | temp = np.random.random() 545 | if temp < 0.5: 546 | criteria = True 547 | else: 548 | criteria = False 549 | # conditional_log(self.log_name_, logger, 550 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 551 | # based on coin flip 552 | elif ttc_break_tie == 'random': 553 | ttc_sample = np.random.random() # [0,1) 554 | criteria = (ttc_sample < self.stop_prob_) 555 | # conditional_log(self.log_name_, logger, 556 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc_break_tie: {ttc_break_tie} | sample={ttc_sample} | stop_prob = {self.stop_prob_} | criteria={criteria}', 'debug') 557 | else: 558 | raise Exception(f"Invalid ttc_break_tie: {ttc_break_tie}") 559 | else: 560 | criteria = (ttc1 > ttc2) 561 | # conditional_log(self.log_name_, logger, 562 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc | 
criteria={criteria}', 'debug') 563 | # debug 564 | if criteria: 565 | ids_should_stop_for.append((i, ttc2)) 566 | 567 | should_stop = (should_stop or criteria) # we could break loop here, but calculate all so we can debug 568 | # conditional_log(self.log_name_, logger, 569 | # f'agent {self.id_} (agg={self.agg_level_}) ttc with agent {i} (agg={other_agent.agg_level_}) | ttc3 (both go) = {ttc3} | ttc1 (I stop) = {ttc1} | ttc2 (I go) = {ttc2} | should stop: {criteria}', 'debug') 570 | 571 | # if should_stop: 572 | # conditional_log(self.log_name_, logger, 573 | # f'agent {self.id_} should stop for: {str(ids_should_stop_for)}', 'debug') 574 | 575 | return should_stop, ids_should_stop_for 576 | 577 | def step(self, dt, agents, ttc_dp, agent_states, agent_action_noise=0, ttc_break_tie=None, agent_baseline=None) -> None: 578 | should_stop, ids_should_stop_for = self.should_stop(dt, agents, ttc_dp, ttc_break_tie, agent_baseline) 579 | action = 0 580 | # if not rl agent, plan as usual 581 | if self.rl_agent_ is None: 582 | # conditional_log(self.log_name_, logger, 583 | # f'agent {self.id_} ttc: {str(self.past_ttc_)}', 'info') # DEBUG 584 | # first checking whether it's possible to switch out of on_stop_ 585 | if self.on_stop_ and (not should_stop): 586 | # conditional_log(self.log_name_, logger, 587 | # f'agent {self.id_} leaving on_stop_', 'debug') 588 | self.on_stop_ = False 589 | 590 | if not self.on_stop_: 591 | # decide whether to enter on_stop_ 592 | if should_stop: 593 | # stop 594 | # conditional_log(self.log_name_, logger, 595 | # f'agent {self.id_} entering on_stop_. action: 0', 'debug') 596 | action = 0 597 | self.on_stop_ = True 598 | else: 599 | action = 1 600 | # if rl agent, use rl agent to select action 601 | else: 602 | parametric_state = agent_states[int(self.id_) - 1] 603 | parametric_state_ts = torch.from_numpy(parametric_state).unsqueeze(0).float().to(self.device_) # 1*(state_dim) 604 | rl_action = self.rl_agent_.select_action(parametric_state_ts, 0, test=True) 605 | action = rl_action 606 | 607 | # agent_action_noise 608 | if agent_action_noise > 0: 609 | agent_action_noise_sample = np.random.random() 610 | if (agent_action_noise_sample <= agent_action_noise): 611 | action = int(1 - action) 612 | conditional_log(self.log_name_, logger, 613 | f'agent {self.id_} switch action to {action} given agent_action_noise={agent_action_noise}', 'debug') # DEBUG 614 | 615 | self.apply_action(action, dt) 616 | 617 | # update self.future_horizon_positions_ 618 | self.update_future_horizon_positions() 619 | 620 | # update whether I have passed goal 621 | # dist_to_goal = self.__compute_dist_to_observable_state(self.observable_state_, self.goal_) 622 | # if (dist_to_goal < self.done_thres_): 623 | # conditional_log(self.log_name_, logger, f'Agent {self.id_} passed goal with distance to goal: {dist_to_goal}', 'debug') 624 | # self.has_passed_goal_ = True 625 | 626 | return ids_should_stop_for 627 | 628 | def force_action(self, action, dt): 629 | # force the agent to take a certain action. 
Only used for break stop_for cycles or creating test set using backwards driving 630 | self.apply_action(action, dt) 631 | # update self.future_horizon_positions_ 632 | self.update_future_horizon_positions() 633 | 634 | def __compute_dist_to_observable_state(self, 635 | this_observable_state: ObservableState, 636 | other_observable_state: ObservableState) -> float: 637 | # to cover up a bug in datasets 638 | if isinstance(this_observable_state, Position): 639 | this_pos = this_observable_state 640 | else: 641 | this_pos = this_observable_state.position_ 642 | if isinstance(other_observable_state, Position): 643 | other_pos = other_observable_state 644 | else: 645 | other_pos = other_observable_state.position_ 646 | 647 | pos_dist = this_pos.calculate_distance(other_pos) 648 | # vel_dist = abs(this_observable_state.velocity_ - other_observable_state.velocity_) 649 | # yaw_rate_dist = abs(this_observable_state.yaw_rate_ - other_observable_state.yaw_rate_) 650 | return pos_dist 651 | 652 | # Wrapper around NeighborhoodV4DefaultAgent and NeighborhoodV3DefaultAgent objects that were saved in interaction 653 | # and collision sets in order to modify agent should_stop behavior 654 | class NeighborhoodV4DefaultAgentWrapper(object): 655 | def __init__(self, 656 | saved_agent: NeighborhoodV4DefaultAgent): 657 | self.saved_agent_ = copy.deepcopy(saved_agent) 658 | self.__update_data_members() 659 | 660 | def calculate_ttc(self, agents, dt, ttc_dp) -> None: 661 | rst = self.saved_agent_.calculate_ttc(agents, dt, ttc_dp) 662 | self.__update_data_members() 663 | return rst 664 | 665 | def drive_along_path(self, dt: float, assume_vel: int = 0) -> Tuple[Position, int]: 666 | rst = self.saved_agent_.drive_along_path(dt, assume_vel) 667 | self.__update_data_members() 668 | return rst 669 | 670 | def initialize_future_horizon_positions(self): 671 | self.saved_agent_.initialize_future_horizon_positions() 672 | self.__update_data_members() 673 | 674 | def update_future_horizon_positions(self): 675 | self.saved_agent_.update_future_horizon_positions() 676 | self.__update_data_members() 677 | 678 | def apply_action(self, action, dt): 679 | self.saved_agent_.apply_action(action, dt) 680 | self.__update_data_members() 681 | 682 | def fad_distance(self): 683 | return self.saved_agent_.fad_distance() 684 | 685 | def should_stop_wrapped(self, dt, agents, ttc_dp, ttc_break_tie=None, agent_baseline=None) -> bool: 686 | # look up ttc1 (if I stop), ttc2 (if I go) 687 | # if any of the 4 ttcs are within range, capture that agent 688 | # (xs, ys) 689 | ids_ttc_in_range_raw = np.array([]).astype(int) 690 | for j in range(4): 691 | # ids_ttc_in_range = np.union1d(ids_ttc_in_range, np.where(ttc_dp[j, int(self.id_), :] < self.default_ttc_)[0]) 692 | ids_ttc_in_range_raw = np.union1d(ids_ttc_in_range_raw, np.where(ttc_dp[j, int(self.id_), :] <= self.ttc_thres_)[0]) 693 | # for crossing vehicles, only consider my interaction with the head of a series of vehicles in ids_ttc_in_range 694 | ids_ttc_in_range = [] 695 | for raw_id in ids_ttc_in_range_raw: 696 | to_add = True 697 | for selected_id in ids_ttc_in_range: 698 | selected_agent_lane_id = agents[str(selected_id)].path_[agents[str(selected_id)].closest_lane_id_order_] 699 | selected_agent_waypoint_idx = agents[str(selected_id)].closest_waypoint_idx_ 700 | raw_agent_lane_id = agents[str(raw_id)].path_[agents[str(raw_id)].closest_lane_id_order_] 701 | raw_agent_waypoint_idx = agents[str(raw_id)].closest_waypoint_idx_ 702 | # if raw agent is behind selected agent, don't add 703 
| if (selected_agent_lane_id == raw_agent_lane_id and raw_agent_waypoint_idx < selected_agent_waypoint_idx) or \ 704 | selected_agent_lane_id in self.lane_segments_[raw_agent_lane_id].successors: 705 | to_add = False 706 | break 707 | # if selected agent is behind raw agent, remove selected agent 708 | elif (selected_agent_lane_id == raw_agent_lane_id and raw_agent_waypoint_idx > selected_agent_waypoint_idx) or \ 709 | raw_agent_lane_id in self.lane_segments_[selected_agent_lane_id].successors: 710 | ids_ttc_in_range.remove(selected_id) 711 | if to_add: 712 | ids_ttc_in_range.append(raw_id) 713 | 714 | ids_should_stop_for = [] 715 | 716 | should_stop = False 717 | my_lane_id = self.path_[self.closest_lane_id_order_] 718 | my_waypoint_idx = self.closest_waypoint_idx_ 719 | my_two_degree_successors = get_two_degree_successors(self.lane_segments_, my_lane_id) 720 | for other_agent_id, other_agent in agents.items(): 721 | if (other_agent_id == self.id_): 722 | continue 723 | if len(other_agent.path_) == 0: # special case: other agent path length = 0 724 | continue 725 | 726 | # == front-vehicle logic == 727 | # if other agent is in front of me 728 | other_agent_lane_id = other_agent.path_[other_agent.closest_lane_id_order_] 729 | other_agent_waypoint_idx = other_agent.closest_waypoint_idx_ 730 | dist = calculate_poly_distance(self, other_agent) 731 | # conditional_log(self.log_name_, logger, 732 | # f'agent {self.id_} dist to agent {other_agent_id} | dist={dist}', 'info') # DEBUG 733 | if (agent_baseline not in [5]) and \ 734 | ((other_agent_lane_id == my_lane_id and my_waypoint_idx < other_agent_waypoint_idx) or \ 735 | (other_agent_lane_id in my_two_degree_successors)): 736 | if (dist <= 1 + self.fad_coeff_ * self.ttc_thres_ * self.v_desired_): # assume ttc_thres_ self stop time if the other agent suddenly stops 737 | # conditional_log(self.log_name_, logger, 738 | # f'agent {self.id_} agent_baseline={agent_baseline} | should stop for front agent {other_agent_id} | dist={dist}', 'debug') # DEBUG 739 | est_decision_ttc = dist / self.v_desired_ 740 | ids_should_stop_for.append((int(other_agent_id), est_decision_ttc)) 741 | should_stop = (should_stop or True) # we want to go through everyone for debug 742 | else: 743 | i = int(other_agent_id) 744 | if i in ids_ttc_in_range: 745 | # == back-vehicle logic == 746 | # if this ttc3 agent is behind me, don't consider whether we should stop for them 747 | if (agent_baseline not in [5]) and \ 748 | ((my_lane_id == other_agent_lane_id and other_agent_waypoint_idx < my_waypoint_idx) or \ 749 | my_lane_id in self.lane_segments_[other_agent_lane_id].successors): 750 | # conditional_log(self.log_name_, logger, f'agent {self.id_} agent_baseline={agent_baseline} | with agent {i} behind. 
shouldn\'t stop.', 'debug') 751 | should_stop = (should_stop or False) 752 | # == ttc logic == 753 | elif (agent_baseline not in [6]): 754 | ttc3 = ttc_dp[3, int(self.id_), i] # if we both go 755 | ttc1 = ttc_dp[1, int(self.id_), i] # if I stop 756 | ttc2 = ttc_dp[2, int(self.id_), i] # if I go 757 | # max_ttc_go = max(ttc2, ttc3) 758 | if ttc1 == ttc2: 759 | # based on who's agg_level = 0 760 | if ttc_break_tie is None or ttc_break_tie == 'agg_level=0': 761 | # if I'm mild, the other is not, I should stop 762 | if self.agg_level_ == 0 and other_agent.agg_level_ != 0: 763 | criteria = True 764 | # if both are mild, break tie: larger id stops (this is to make sure agent stops for ego) 765 | elif self.agg_level_ == 0 and other_agent.agg_level_ == 0: 766 | criteria = (int(self.id_) > i) 767 | # if I'm not mild and the other is mild, I shouldn't stop 768 | elif self.agg_level_ != 0 and other_agent.agg_level_ == 0: 769 | criteria = False 770 | # if both are not mild, break tie: larger id stops 771 | elif self.agg_level_ != 0 and other_agent.agg_level_ != 0: 772 | criteria = (int(self.id_) > i) 773 | # conditional_log(self.log_name_, logger, 774 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 775 | # based on b1 776 | elif ttc_break_tie == 'b1': 777 | if self.b1_ != other_agent.b1_: 778 | criteria = (self.b1_ < other_agent.b1_) 779 | else: 780 | temp = np.random.random() 781 | if temp < 0.5: 782 | criteria = True 783 | else: 784 | criteria = False 785 | # conditional_log(self.log_name_, logger, 786 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 787 | # based on coin flip 788 | elif ttc_break_tie == 'random': 789 | ttc_sample = np.random.random() # [0,1) 790 | criteria = (ttc_sample < self.stop_prob_) 791 | # conditional_log(self.log_name_, logger, 792 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc_break_tie: {ttc_break_tie} | sample={ttc_sample} | stop_prob = {self.stop_prob_} | criteria={criteria}', 'debug') 793 | elif ttc_break_tie == 'id': 794 | # between agents: larger id go 795 | if (other_agent.id_ != '0'): 796 | criteria = (int(self.id_) < int(other_agent.id_)) 797 | # between agent and ego: larger pseudo_id go 798 | else: 799 | criteria = (int(self.id_) < int(other_agent.pseudo_id_)) 800 | log_str = f'agent {self.id_} other_agent {other_agent.id_} agent_baseline={agent_baseline} | ttc_break_tie: {ttc_break_tie} | criteria={criteria}' 801 | if other_agent.id_ == '0': 802 | log_str += f' | other_agent pseudo_id: {other_agent.pseudo_id_}' 803 | conditional_log(self.log_name_, logger, log_str, 'debug') 804 | else: 805 | raise Exception(f"Invalid ttc_break_tie: {ttc_break_tie}") 806 | else: 807 | criteria = (ttc1 > ttc2) 808 | # conditional_log(self.log_name_, logger, 809 | # f'agent {self.id_} agent_baseline={agent_baseline} | ttc | criteria={criteria}', 'debug') 810 | # debug 811 | if criteria: 812 | ids_should_stop_for.append((i, ttc2)) 813 | 814 | should_stop = (should_stop or criteria) # we could break loop here, but calculate all so we can debug 815 | # conditional_log(self.log_name_, logger, 816 | # f'agent {self.id_} (agg={self.agg_level_}) ttc with agent {i} (agg={other_agent.agg_level_}) | ttc3 (both go) = {ttc3} | ttc1 (I stop) = {ttc1} | ttc2 (I go) = {ttc2} | should stop: {criteria}', 'debug') 817 | 818 | # if should_stop: 819 | # conditional_log(self.log_name_, logger, 820 | # f'agent {self.id_} should stop for: 
{str(ids_should_stop_for)}', 'debug') 821 | 822 | self.__update_data_members() 823 | 824 | return should_stop, ids_should_stop_for 825 | 826 | def step(self, dt, agents, ttc_dp, agent_states, agent_action_noise=0, ttc_break_tie=None, agent_baseline=None) -> None: 827 | should_stop, ids_should_stop_for = self.should_stop_wrapped(dt, agents, ttc_dp, ttc_break_tie, agent_baseline) 828 | action = 0 829 | # if not rl agent, plan as usual 830 | if self.rl_agent_ is None: 831 | # conditional_log(self.log_name_, logger, 832 | # f'agent {self.id_} ttc: {str(self.past_ttc_)}', 'info') # DEBUG 833 | # first checking whether it's possible to switch out of on_stop_ 834 | if self.on_stop_ and (not should_stop): 835 | # conditional_log(self.log_name_, logger, 836 | # f'agent {self.id_} leaving on_stop_', 'debug') 837 | self.on_stop_ = False 838 | 839 | if not self.on_stop_: 840 | # decide whether to enter on_stop_ 841 | if should_stop: 842 | # stop 843 | # conditional_log(self.log_name_, logger, 844 | # f'agent {self.id_} entering on_stop_. action: 0', 'debug') 845 | action = 0 846 | self.on_stop_ = True 847 | else: 848 | action = 1 849 | # if rl agent, use rl agent to select action 850 | else: 851 | parametric_state = agent_states[int(self.id_) - 1] 852 | parametric_state_ts = torch.from_numpy(parametric_state).unsqueeze(0).float().to(self.device_) # 1*(state_dim) 853 | rl_action = self.rl_agent_.select_action(parametric_state_ts, 0, test=True) 854 | action = rl_action 855 | 856 | # agent_action_noise 857 | if agent_action_noise > 0: 858 | agent_action_noise_sample = np.random.random() 859 | if (agent_action_noise_sample <= agent_action_noise): 860 | action = int(1 - action) 861 | conditional_log(self.log_name_, logger, 862 | f'agent {self.id_} switch action to {action} given agent_action_noise={agent_action_noise}', 'debug') # DEBUG 863 | 864 | self.apply_action(action, dt) 865 | 866 | # update self.future_horizon_positions_ 867 | self.update_future_horizon_positions() 868 | 869 | # update whether I have passed goal 870 | # dist_to_goal = self.__compute_dist_to_observable_state(self.observable_state_, self.goal_) 871 | # if (dist_to_goal < self.done_thres_): 872 | # conditional_log(self.log_name_, logger, f'Agent {self.id_} passed goal with distance to goal: {dist_to_goal}', 'debug') 873 | # self.has_passed_goal_ = True 874 | 875 | self.__update_data_members() 876 | 877 | return ids_should_stop_for 878 | 879 | def force_action(self, action, dt): 880 | self.saved_agent_.force_action(action, dt) 881 | self.__update_data_members() 882 | 883 | def __compute_dist_to_observable_state(self, 884 | this_observable_state: ObservableState, 885 | other_observable_state: ObservableState) -> float: 886 | rst = self.saved_agent_.__compute_dist_to_observable_state(this_observable_state, other_observable_state) 887 | self.__update_data_members() 888 | return rst 889 | 890 | def __update_data_members(self): 891 | self.__dict__.update(self.saved_agent_.__dict__) 892 | 893 | class NeighborhoodV4EgoAgent(NeighborhoodV4AgentInterface): 894 | def __init__(self, 895 | id: str, 896 | observable_state: ObservableState, 897 | goal: ObservableState, 898 | curr_lane_id_order: int, 899 | curr_waypoint_idx: int, 900 | default_ttc: float, 901 | lane_segments: List[LaneSegment], 902 | path: list = [], # list of lane_ids 903 | log_name: str = None, 904 | v_desired: float = 5.6, 905 | b1: float = -1.0): 906 | 907 | self.v_desired_ = v_desired 908 | # driver_type 909 | self.b1_ = b1 910 | # Note: agg_level is determined with the same 
velocity range as default agents, {0,1,2} 911 | # cf: b1 (driver_type) is a continuous float [-1,1] that determines v_desired_ 912 | # you can view agg_level as a discretization of b1 913 | if (v_desired >= 9): 914 | self.agg_level_ = 2 915 | elif (v_desired >= 6.8): 916 | self.agg_level_ = 1 917 | else: 918 | self.agg_level_ = 0 919 | self.stop_prob_ = - 0.5 * self.b1_ + 0.5 920 | 921 | super(NeighborhoodV4EgoAgent, self).__init__( 922 | id=id, 923 | observable_state=observable_state, 924 | agent_type=AgentType.ego_vehicle, 925 | goal=goal, 926 | curr_lane_id_order=curr_lane_id_order, 927 | curr_waypoint_idx=curr_waypoint_idx, 928 | v_desired = self.v_desired_, 929 | default_ttc = default_ttc, 930 | lane_segments = lane_segments, 931 | path=path) 932 | 933 | self.observable_state_.velocity_ = self.v_desired_ 934 | self.log_name_ = log_name 935 | 936 | # track the waypoints at future 10 ts 937 | self.initialize_future_horizon_positions() 938 | 939 | def step(self, action: int, dt: float) -> None: 940 | self.apply_action(action, dt) 941 | 942 | # update self.future_horizon_positions_ 943 | self.update_future_horizon_positions() 944 | 945 | def should_stop(self, dt, agents, ttc_dp, 946 | ego_baseline=None, include_agents_within_range=None, ttc_break_tie=None) -> bool: 947 | # look up ttc1 (if I stop), ttc2 (if I go) 948 | # if any of the 4 ttcs are within range, capture that agent 949 | # (xs, ys) 950 | ids_ttc_in_range_raw = np.array([]).astype(int) 951 | for j in range(4): 952 | # ids_ttc_in_range = np.union1d(ids_ttc_in_range, np.where(ttc_dp[j, int(self.id_), :] < self.default_ttc_)[0]) 953 | ids_ttc_in_range_raw = np.union1d(ids_ttc_in_range_raw, np.where(ttc_dp[j, int(self.id_), :] <= self.ttc_thres_)[0]) 954 | # if we enforce the same observation range as the model 955 | if include_agents_within_range is not None and include_agents_within_range > 0: 956 | indices_to_delete = [] 957 | ids_to_delete = [] # for debug 958 | for i in range(len(ids_ttc_in_range_raw)): 959 | agent_id = str(ids_ttc_in_range_raw[i]) 960 | poly_dist = calculate_poly_distance(agents['0'], agents[agent_id]) 961 | if (poly_dist > include_agents_within_range): 962 | indices_to_delete.append(i) 963 | ids_to_delete.append(agent_id) 964 | ids_ttc_in_range_raw = np.delete(ids_ttc_in_range_raw, indices_to_delete) 965 | conditional_log(self.log_name_, logger, f'ttc_in_range ids not in distance range: {str(ids_to_delete)} | ids_ttc_in_range_raw={str(ids_ttc_in_range_raw)}', 'debug') 966 | # for crossing vehicles, only consider my interaction with the head of a series of vehicles in ids_ttc_in_range 967 | ids_ttc_in_range = [] 968 | for raw_id in ids_ttc_in_range_raw: 969 | to_add = True 970 | for selected_id in ids_ttc_in_range: 971 | selected_agent_lane_id = agents[str(selected_id)].path_[agents[str(selected_id)].closest_lane_id_order_] 972 | selected_agent_waypoint_idx = agents[str(selected_id)].closest_waypoint_idx_ 973 | raw_agent_lane_id = agents[str(raw_id)].path_[agents[str(raw_id)].closest_lane_id_order_] 974 | raw_agent_waypoint_idx = agents[str(raw_id)].closest_waypoint_idx_ 975 | # if raw agent is behind selected agent, don't add 976 | if (selected_agent_lane_id == raw_agent_lane_id and raw_agent_waypoint_idx < selected_agent_waypoint_idx) or \ 977 | selected_agent_lane_id in self.lane_segments_[raw_agent_lane_id].successors: 978 | to_add = False 979 | break 980 | # if selected agent is behind raw agent, remove selected agent 981 | elif (selected_agent_lane_id == raw_agent_lane_id and raw_agent_waypoint_idx > 
selected_agent_waypoint_idx) or \ 982 | raw_agent_lane_id in self.lane_segments_[selected_agent_lane_id].successors: 983 | ids_ttc_in_range.remove(selected_id) 984 | if to_add: 985 | ids_ttc_in_range.append(raw_id) 986 | 987 | ids_should_stop_for = [] 988 | 989 | should_stop = False 990 | ego_lane_id = self.path_[self.closest_lane_id_order_] 991 | ego_waypoint_idx = self.closest_waypoint_idx_ 992 | my_two_degree_successors = get_two_degree_successors(self.lane_segments_, ego_lane_id) 993 | for other_agent_id, other_agent in agents.items(): 994 | if (other_agent_id == self.id_): 995 | continue 996 | # if other agent is in front of me 997 | other_agent_lane_id = other_agent.path_[other_agent.closest_lane_id_order_] 998 | other_agent_waypoint_idx = other_agent.closest_waypoint_idx_ 999 | dist = calculate_poly_distance(self, other_agent) 1000 | # conditional_log(self.log_name_, logger, 1001 | # f'agent {self.id_} dist to agent {other_agent_id} | dist={dist}', 'info') # DEBUG 1002 | # == front-vehicle logic == 1003 | if (ego_baseline not in [1,2,5]) and \ 1004 | ((other_agent_lane_id == ego_lane_id and ego_waypoint_idx < other_agent_waypoint_idx) or \ 1005 | other_agent_lane_id in my_two_degree_successors): 1006 | if (dist <= 1 + self.fad_coeff_ * self.ttc_thres_ * self.v_desired_): # assume ttc_thres_ self stop time if the other agent suddenly stops 1007 | conditional_log(self.log_name_, logger, 1008 | f'ego should stop for front agent {other_agent_id} | dist={dist}', 'debug') # DEBUG 1009 | est_decision_ttc = dist / self.v_desired_ 1010 | ids_should_stop_for.append((int(other_agent_id), est_decision_ttc)) 1011 | should_stop = (should_stop or True) # we want to go through everyone for debug 1012 | else: 1013 | i = int(other_agent_id) 1014 | if i in ids_ttc_in_range: 1015 | # == back-vehicle logic == 1016 | # if this ttc3 agent is behind me, don't consider whether we should stop for them 1017 | if (ego_baseline not in [5]) and \ 1018 | ((ego_lane_id == other_agent_lane_id and other_agent_waypoint_idx < ego_waypoint_idx) or \ 1019 | (ego_lane_id in self.lane_segments_[other_agent_lane_id].successors)): 1020 | conditional_log(self.log_name_, logger, f'ego with agent {i} behind. 
shouldn\'t stop.', 'debug') 1021 | should_stop = (should_stop or False) 1022 | # == ttc logic == 1023 | elif (ego_baseline not in [6]): 1024 | ttc3 = ttc_dp[3, int(self.id_), i] # if we both go 1025 | ttc1 = ttc_dp[1, int(self.id_), i] # if I stop 1026 | ttc2 = ttc_dp[2, int(self.id_), i] # if I go 1027 | # max_ttc_go = max(ttc2, ttc3) 1028 | # criteria = (ttc1 >= max_ttc_go) 1029 | if ttc1 == ttc2: 1030 | if ego_baseline in [None, 5, 4]: 1031 | # based on who's agg_level = 0 1032 | if (ego_baseline not in [4]) and (ttc_break_tie is None or ttc_break_tie == 'agg_level=0'): 1033 | # if the other is not mild, I should stop 1034 | if other_agent.agg_level_ != 0: 1035 | criteria = True 1036 | # if the other is mild, I should go 1037 | else: 1038 | criteria = False 1039 | conditional_log(self.log_name_, logger, 1040 | f'agent {self.id_} ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 1041 | # based on b1 1042 | elif (ego_baseline not in [4]) and (ttc_break_tie == 'b1'): 1043 | conditional_log(self.log_name_, logger, f'agent {self.id_} ttc_break_tie: {ttc_break_tie}', 'debug') 1044 | if self.b1_ != other_agent.b1_: 1045 | criteria = (self.b1_ < other_agent.b1_) 1046 | else: 1047 | temp = np.random.random() 1048 | if temp < 0.5: 1049 | criteria = True 1050 | else: 1051 | criteria = False 1052 | conditional_log(self.log_name_, logger, 1053 | f'agent {self.id_} ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 1054 | # based on coin flip 1055 | elif ttc_break_tie == 'random': 1056 | ttc_sample = np.random.random() # [0,1) 1057 | criteria = (ttc_sample < self.stop_prob_) 1058 | conditional_log(self.log_name_, logger, 1059 | f'agent {self.id_} ttc_break_tie: {ttc_break_tie} | sample={ttc_sample} | stop_prob = {self.stop_prob_} | criteria={criteria}', 'debug') 1060 | else: 1061 | raise Exception(f"invalid ttc_break_tie: {ttc_break_tie}") 1062 | else: 1063 | if ego_baseline in [0,1]: 1064 | temp = np.random.random() 1065 | if temp < 0.5: 1066 | criteria = True 1067 | else: 1068 | criteria = False 1069 | else: 1070 | raise Exception(f"invalid ego_baseline: {ego_baseline}") 1071 | else: 1072 | criteria = (ttc1 > ttc2) 1073 | # debug 1074 | if criteria: 1075 | ids_should_stop_for.append((i, ttc2)) 1076 | 1077 | should_stop = (should_stop or criteria) # we could break loop here, but calculate all so we can debug 1078 | conditional_log(self.log_name_, logger, 1079 | f'agent {self.id_} (agg={self.agg_level_}) ttc with agent {i} (agg={other_agent.agg_level_}) | ttc3 (both go) = {ttc3} | ttc1 (I stop) = {ttc1} | ttc2 (I go) = {ttc2}', 'debug') 1080 | if ego_baseline is None: 1081 | conditional_log(self.log_name_, logger, 1082 | f'agent {self.id_} (agg={self.agg_level_}) with agent {i} | ttc logic | should stop: {criteria}', 'debug') 1083 | else: 1084 | conditional_log(self.log_name_, logger, 1085 | f'agent {self.id_} (agg={self.agg_level_}) with agent {i} | ttc logic | ego_baseline {ego_baseline} should_stop: {criteria}', 'debug') 1086 | 1087 | if should_stop: 1088 | conditional_log(self.log_name_, logger, 1089 | f'agent {self.id_} should stop for: {str(ids_should_stop_for)}', 'debug') 1090 | 1091 | return should_stop, ids_should_stop_for 1092 | 1093 | 1094 | class NeighborhoodV4EgoAgentWrapper(NeighborhoodV4AgentInterface): 1095 | def __init__(self, 1096 | saved_agent: NeighborhoodV4EgoAgent, 1097 | max_num_other_agents: int): 1098 | self.saved_agent_ = copy.deepcopy(saved_agent) 1099 | self.pseudo_id_ = str(np.random.randint(max_num_other_agents + 1)) 1100 | 
self.__update_data_members() 1101 | 1102 | def calculate_ttc(self, agents, dt, ttc_dp) -> None: 1103 | rst = self.saved_agent_.calculate_ttc(agents, dt, ttc_dp) 1104 | self.__update_data_members() 1105 | return rst 1106 | 1107 | def drive_along_path(self, dt: float, assume_vel: int = 0) -> Tuple[Position, int]: 1108 | rst = self.saved_agent_.drive_along_path(dt, assume_vel) 1109 | self.__update_data_members() 1110 | return rst 1111 | 1112 | def initialize_future_horizon_positions(self): 1113 | self.saved_agent_.initialize_future_horizon_positions() 1114 | self.__update_data_members() 1115 | 1116 | def update_future_horizon_positions(self): 1117 | self.saved_agent_.update_future_horizon_positions() 1118 | self.__update_data_members() 1119 | 1120 | def apply_action(self, action, dt): 1121 | self.saved_agent_.apply_action(action, dt) 1122 | self.__update_data_members() 1123 | 1124 | def fad_distance(self): 1125 | return self.saved_agent_.fad_distance() 1126 | 1127 | def step(self, action: int, dt: float) -> None: 1128 | self.saved_agent_.step(action, dt) 1129 | self.__update_data_members() 1130 | 1131 | def should_stop_wrapped(self, dt, agents, ttc_dp, 1132 | ego_baseline=None, include_agents_within_range=None, 1133 | ttc_break_tie=None, return_ttc_tracker=False) -> bool: 1134 | 1135 | # look up ttc1 (if I stop), ttc2 (if I go) 1136 | # if any of the 4 ttcs are within range, capture that agent 1137 | # (xs, ys) 1138 | ids_ttc_in_range_raw = np.array([]).astype(int) 1139 | for j in range(4): 1140 | # ids_ttc_in_range = np.union1d(ids_ttc_in_range, np.where(ttc_dp[j, int(self.id_), :] < self.default_ttc_)[0]) 1141 | ids_ttc_in_range_raw = np.union1d(ids_ttc_in_range_raw, np.where(ttc_dp[j, int(self.id_), :] <= self.ttc_thres_)[0]) 1142 | # if we enforce the same observation range as the model 1143 | if include_agents_within_range is not None and include_agents_within_range > 0: 1144 | indices_to_delete = [] 1145 | ids_to_delete = [] # for debug 1146 | for i in range(len(ids_ttc_in_range_raw)): 1147 | agent_id = str(ids_ttc_in_range_raw[i]) 1148 | poly_dist = calculate_poly_distance(agents['0'], agents[agent_id]) 1149 | if (poly_dist > include_agents_within_range): 1150 | indices_to_delete.append(i) 1151 | ids_to_delete.append(agent_id) 1152 | ids_ttc_in_range_raw = np.delete(ids_ttc_in_range_raw, indices_to_delete) 1153 | conditional_log(self.log_name_, logger, f'ttc_in_range ids not in distance range: {str(ids_to_delete)} | ids_ttc_in_range_raw={str(ids_ttc_in_range_raw)}', 'debug') 1154 | # for crossing vehicles, only consider my interaction with the head of a series of vehicles in ids_ttc_in_range 1155 | ids_ttc_in_range = [] 1156 | for raw_id in ids_ttc_in_range_raw: 1157 | to_add = True 1158 | for selected_id in ids_ttc_in_range: 1159 | selected_agent_lane_id = agents[str(selected_id)].path_[agents[str(selected_id)].closest_lane_id_order_] 1160 | selected_agent_waypoint_idx = agents[str(selected_id)].closest_waypoint_idx_ 1161 | raw_agent_lane_id = agents[str(raw_id)].path_[agents[str(raw_id)].closest_lane_id_order_] 1162 | raw_agent_waypoint_idx = agents[str(raw_id)].closest_waypoint_idx_ 1163 | # if raw agent is behind selected agent, don't add 1164 | if (selected_agent_lane_id == raw_agent_lane_id and raw_agent_waypoint_idx < selected_agent_waypoint_idx) or \ 1165 | selected_agent_lane_id in self.lane_segments_[raw_agent_lane_id].successors: 1166 | to_add = False 1167 | break 1168 | # if selected agent is behind raw agent, remove selected agent 1169 | elif (selected_agent_lane_id == 
raw_agent_lane_id and raw_agent_waypoint_idx > selected_agent_waypoint_idx) or \ 1170 | raw_agent_lane_id in self.lane_segments_[selected_agent_lane_id].successors: 1171 | ids_ttc_in_range.remove(selected_id) 1172 | if to_add: 1173 | ids_ttc_in_range.append(raw_id) 1174 | 1175 | ids_should_stop_for = [] 1176 | 1177 | ttc_trackers = {'min_decision_ttc' : self.default_ttc_, # minimum decision-triggering ttc (if ttc1 > ttc2, should save ttc2) 1178 | 'min_ttc' : self.default_ttc_, # among all agents in ttc range, save their min ttc in (0,1,2,3) 1179 | 'min_ttc0' : self.default_ttc_, # among all agents in ttc range, save their min ttc0 1180 | 'min_ttc1' : self.default_ttc_, # among all agents in ttc range, save their min ttc1 1181 | 'min_ttc2' : self.default_ttc_, # among all agents in ttc range, save their min ttc2 1182 | 'min_ttc3' : self.default_ttc_ # among all agents in ttc range, save their min ttc3 1183 | } 1184 | 1185 | should_stop = False 1186 | ego_lane_id = self.path_[self.closest_lane_id_order_] 1187 | ego_waypoint_idx = self.closest_waypoint_idx_ 1188 | my_two_degree_successors = get_two_degree_successors(self.lane_segments_, ego_lane_id) 1189 | for other_agent_id, other_agent in agents.items(): 1190 | if (other_agent_id == self.id_): 1191 | continue 1192 | # if other agent is in front of me 1193 | other_agent_lane_id = other_agent.path_[other_agent.closest_lane_id_order_] 1194 | other_agent_waypoint_idx = other_agent.closest_waypoint_idx_ 1195 | dist = calculate_poly_distance(self, other_agent) 1196 | # conditional_log(self.log_name_, logger, 1197 | # f'agent {self.id_} dist to agent {other_agent_id} | dist={dist}', 'info') # DEBUG 1198 | # == front-vehicle logic == 1199 | if (ego_baseline not in [1,2,5]) and \ 1200 | ((other_agent_lane_id == ego_lane_id and ego_waypoint_idx < other_agent_waypoint_idx) or \ 1201 | other_agent_lane_id in my_two_degree_successors): 1202 | if (dist <= 1 + self.fad_coeff_ * self.ttc_thres_ * self.v_desired_): # assume ttc_thres_ self stop time if the other agent suddenly stops 1203 | conditional_log(self.log_name_, logger, 1204 | f'ego should stop for front agent {other_agent_id} | dist={dist}', 'debug') # DEBUG 1205 | est_decision_ttc = dist / self.v_desired_ 1206 | ids_should_stop_for.append((int(other_agent_id), est_decision_ttc)) 1207 | should_stop = (should_stop or True) # we want to go through everyone for debug 1208 | else: 1209 | i = int(other_agent_id) 1210 | if i in ids_ttc_in_range: 1211 | # == back-vehicle logic == 1212 | # if this ttc3 agent is behind me, don't consider whether we should stop for them 1213 | if (ego_baseline not in [5]) and \ 1214 | ((ego_lane_id == other_agent_lane_id and other_agent_waypoint_idx < ego_waypoint_idx) or \ 1215 | (ego_lane_id in self.lane_segments_[other_agent_lane_id].successors)): 1216 | conditional_log(self.log_name_, logger, f'ego with agent {i} behind. 
shouldn\'t stop.', 'debug') 1217 | should_stop = (should_stop or False) 1218 | # == ttc logic == 1219 | elif (ego_baseline not in [6]): 1220 | ttc3 = ttc_dp[3, int(self.id_), i] # if we both go 1221 | ttc1 = ttc_dp[1, int(self.id_), i] # if I stop 1222 | ttc2 = ttc_dp[2, int(self.id_), i] # if I go 1223 | ttc0 = ttc_dp[0, int(self.id_), i] 1224 | 1225 | # update ttc_trackers 1226 | min_ttc = min([ttc0, ttc1, ttc2, ttc3]) 1227 | ttc_trackers['min_ttc3'] = min(ttc3, ttc_trackers['min_ttc3']) 1228 | ttc_trackers['min_ttc2'] = min(ttc2, ttc_trackers['min_ttc2']) 1229 | ttc_trackers['min_ttc1'] = min(ttc1, ttc_trackers['min_ttc1']) 1230 | ttc_trackers['min_ttc0'] = min(ttc0, ttc_trackers['min_ttc0']) 1231 | ttc_trackers['min_ttc'] = min(min_ttc, ttc_trackers['min_ttc']) 1232 | 1233 | # max_ttc_go = max(ttc2, ttc3) 1234 | # criteria = (ttc1 >= max_ttc_go) 1235 | if ttc1 == ttc2: 1236 | ttc_trackers['min_decision_ttc'] = min(ttc1, ttc_trackers['min_decision_ttc']) 1237 | if ego_baseline in [None, 5, 4]: 1238 | # based on who's agg_level = 0 1239 | if (ego_baseline not in [4]) and (ttc_break_tie is None or ttc_break_tie == 'agg_level=0'): 1240 | # if the other is not mild, I should stop 1241 | if other_agent.agg_level_ != 0: 1242 | criteria = True 1243 | # if the other is mild, I should go 1244 | else: 1245 | criteria = False 1246 | conditional_log(self.log_name_, logger, 1247 | f'agent {self.id_} ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 1248 | # based on b1 1249 | elif (ego_baseline not in [4]) and (ttc_break_tie == 'b1'): 1250 | conditional_log(self.log_name_, logger, f'agent {self.id_} ttc_break_tie: {ttc_break_tie}', 'debug') 1251 | if self.b1_ != other_agent.b1_: 1252 | criteria = (self.b1_ < other_agent.b1_) 1253 | else: 1254 | temp = np.random.random() 1255 | if temp < 0.5: 1256 | criteria = True 1257 | else: 1258 | criteria = False 1259 | conditional_log(self.log_name_, logger, 1260 | f'agent {self.id_} ttc_break_tie: {ttc_break_tie} | criteria={criteria}', 'debug') 1261 | # based on coin flip 1262 | elif ttc_break_tie == 'random': 1263 | ttc_sample = np.random.random() # [0,1) 1264 | criteria = (ttc_sample < self.stop_prob_) 1265 | conditional_log(self.log_name_, logger, 1266 | f'agent {self.id_} other_agent {other_agent.id_} ttc_break_tie: {ttc_break_tie} | sample={ttc_sample} | stop_prob = {self.stop_prob_} | criteria={criteria}', 'debug') 1267 | elif ttc_break_tie == 'id': 1268 | criteria = (int(self.pseudo_id_) < int(other_agent.id_)) 1269 | log_str = f'agent {self.id_} other_agent {other_agent.id_} agent_baseline={ego_baseline} | ttc_break_tie: {ttc_break_tie} | criteria={criteria} | self.pseudo_id: {self.pseudo_id_}' 1270 | conditional_log(self.log_name_, logger, log_str, 'debug') 1271 | else: 1272 | raise Exception(f"invalid ttc_break_tie: {ttc_break_tie}") 1273 | else: 1274 | if ego_baseline in [0,1]: 1275 | temp = np.random.random() 1276 | if temp < 0.5: 1277 | criteria = True 1278 | else: 1279 | criteria = False 1280 | else: 1281 | raise Exception(f"invalid ego_baseline: {ego_baseline}") 1282 | else: 1283 | criteria = (ttc1 > ttc2) 1284 | ttc_trackers['min_decision_ttc'] = min(ttc2, ttc_trackers['min_decision_ttc']) 1285 | # debug 1286 | if criteria: 1287 | ids_should_stop_for.append((i, ttc2)) 1288 | 1289 | should_stop = (should_stop or criteria) # we could break loop here, but calculate all so we can debug 1290 | conditional_log(self.log_name_, logger, 1291 | f'agent {self.id_} (agg={self.agg_level_}) ttc with agent {i} (agg={other_agent.agg_level_}) | 
ttc3 (both go) = {ttc3} | ttc1 (I stop) = {ttc1} | ttc2 (I go) = {ttc2}', 'debug') 1292 | if ego_baseline is None: 1293 | conditional_log(self.log_name_, logger, 1294 | f'agent {self.id_} (agg={self.agg_level_}) with agent {i} | ttc logic | should stop: {criteria}', 'debug') 1295 | else: 1296 | conditional_log(self.log_name_, logger, 1297 | f'agent {self.id_} (agg={self.agg_level_}) with agent {i} | ttc logic | ego_baseline {ego_baseline} should_stop: {criteria}', 'debug') 1298 | 1299 | if should_stop: 1300 | conditional_log(self.log_name_, logger, 1301 | f'agent {self.id_} should stop for: {str(ids_should_stop_for)}', 'debug') 1302 | 1303 | self.__update_data_members() 1304 | 1305 | if return_ttc_tracker: 1306 | return should_stop, ids_should_stop_for, ttc_trackers 1307 | else: 1308 | return should_stop, ids_should_stop_for 1309 | 1310 | def __update_data_members(self): 1311 | self.__dict__.update(self.saved_agent_.__dict__) 1312 | -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/envs/neighborhood_v4/neighborhood_env_v4_utils.py: -------------------------------------------------------------------------------- 1 | # Closed neighborhood environment with a roundabout, 4 t-intersections 2 | # all global util functions 3 | 4 | # Python 5 | import pdb, copy, os 6 | import pickle 7 | import numpy as np 8 | import math 9 | import matplotlib.pyplot as plt 10 | from pathlib import Path 11 | import matplotlib 12 | import logging 13 | from queue import Queue 14 | from typing import Any, Dict, Tuple, List 15 | import cv2, time 16 | # Argoverse 17 | from argoverse.utils.se3 import SE3 18 | # gym 19 | import gym 20 | from gym import error, spaces, utils 21 | from gym.utils import seeding 22 | # utils 23 | from gym_road_interactions.utils import create_bbox_world_frame, wrap_to_pi 24 | from gym_road_interactions.viz_utils import visualize_agent 25 | from gym_road_interactions.core import AgentType, Position, Agent, ObservableState, Observation, LaneSegment 26 | from shapely.geometry import Point, Polygon 27 | 28 | def calculate_remaining_lane_distance(lane_id: int, curr_pos: Position, lane_segments: List[LaneSegment]) -> float: 29 | # calculates the distance between curr_pos and lane end 30 | curr_xy = np.array([curr_pos.x_, curr_pos.y_]) 31 | lane_end_xy = lane_segments[lane_id].centerline[-1,:] 32 | l = np.linalg.norm(lane_end_xy - curr_xy) # straight-line distance 33 | 34 | # straight lanes 35 | if lane_segments[lane_id].curve_center is None: 36 | remain_lane_distance = l 37 | else: 38 | r = lane_segments[lane_id].curve_radius 39 | remain_theta = math.acos((2*r*r - l**2)/(2*r*r)) 40 | remain_lane_distance = remain_theta * r 41 | 42 | return remain_lane_distance 43 | 44 | def calculate_traversed_lane_distance(lane_id: int, curr_pos: Position, lane_segments: List[LaneSegment]) -> float: 45 | # calculates the distance between curr_pos and lane start 46 | curr_xy = np.array([curr_pos.x_, curr_pos.y_]) 47 | lane_start_xy = lane_segments[lane_id].centerline[0,:] 48 | l = np.linalg.norm(lane_start_xy - curr_xy) # straight-line distance 49 | 50 | # straight lanes 51 | if lane_segments[lane_id].curve_center is None: 52 | traversed_lane_distance = l 53 | else: 54 | r = lane_segments[lane_id].curve_radius 55 | traversed_theta = math.acos((2*r*r - l**2)/(2*r*r)) 56 | traversed_lane_distance = traversed_theta * r 57 | 58 | return traversed_lane_distance 59 | 60 | def detect_future_collision(agent1 : Agent, agent2: Agent, 
future_pos1 : Position, future_pos2: Position) -> bool: 61 | dist = math.sqrt((future_pos1.x_ - future_pos2.x_)**2 + \ 62 | (future_pos1.y_ - future_pos2.y_)**2) 63 | return (dist < agent1.radius_ + agent2.radius_ + 1.0) 64 | 65 | def detect_collision(agent1: Agent, agent2: Agent) -> bool: 66 | dist = math.sqrt((agent1.observable_state_.position_.x_ - agent2.observable_state_.position_.x_)**2 + \ 67 | (agent1.observable_state_.position_.y_ - agent2.observable_state_.position_.y_)**2) 68 | return (dist < agent1.radius_ + agent2.radius_ + 1.0) 69 | 70 | def calculate_poly_distance(agent1: Agent, agent2: Agent) -> float: 71 | dist = math.sqrt((agent1.observable_state_.position_.x_ - agent2.observable_state_.position_.x_)**2 + \ 72 | (agent1.observable_state_.position_.y_ - agent2.observable_state_.position_.y_)**2) 73 | return dist - agent1.radius_ - agent2.radius_ 74 | 75 | def calculate_time_to_collision(agent1: Agent, agent2: Agent, default_ttc: float, 76 | dt: float, assume_vel1: int = 0, assume_vel2: int = 0): 77 | obs_window = 10 # number of timesteps to look forward # always set this to be the same as largest ttc thres 78 | time_to_collision = default_ttc 79 | 80 | for ts in range(1, obs_window+1): 81 | time_period = dt * ts 82 | pred_pos1,_,_ = agent1.drive_along_path(time_period, assume_vel=assume_vel1) 83 | pred_pos2,_,_ = agent2.drive_along_path(time_period, assume_vel=assume_vel2) 84 | if detect_future_collision(agent1, agent2, pred_pos1, pred_pos2): 85 | time_to_collision = time_period 86 | break 87 | 88 | # DEBUG 89 | # print(f'raw ttc: from agent{agent1.id_} to agent{agent2.id_}: {time_to_collision}') 90 | return time_to_collision 91 | 92 | def get_two_degree_successors(lane_segments, lane_id): 93 | first_degree_successors = np.array(lane_segments[lane_id].successors) 94 | second_degree_successors = np.array([]) 95 | for suc in first_degree_successors: 96 | second_degree_successors = np.union1d(second_degree_successors, np.array(lane_segments[suc].successors)) 97 | all_successors = np.union1d(first_degree_successors, second_degree_successors) 98 | return all_successors -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/utils.py: -------------------------------------------------------------------------------- 1 | # Uses Argoverse v1.1 : https://github.com/argoai/argoverse-api 2 | 3 | import math, sys 4 | import numpy as np 5 | import pickle 6 | from argoverse.utils.se3 import SE3 7 | from datetime import datetime 8 | from gym_road_interactions.core import AgentType, Position, Agent, ObservableState, Observation 9 | 10 | def wrap_to_pi(input_radian: float) -> float: 11 | """ 12 | Helper function to wrap input radian value to [-pi, pi) (+pi exclusive!) 
13 | """ 14 | return ((input_radian + math.pi) % (2 * math.pi) - math.pi) 15 | 16 | def rotation_matrix_z(angle: float) -> np.ndarray: 17 | """ 18 | Helper function to generate a rotation matrix that rotates w.r.t z-axis 19 | Args: 20 | angle (float): angle in radian 21 | """ 22 | return np.array([[math.cos(angle), -math.sin(angle), 0], 23 | [math.sin(angle), math.cos(angle), 0], 24 | [0, 0, 1]]) 25 | 26 | def rotation_matrix_x(angle: float) -> np.ndarray: 27 | """ 28 | Helper function to generate a rotation matrix that rotates w.r.t x-axis 29 | Args: 30 | angle (float): angle in radian 31 | """ 32 | return np.array([[1, 0, 0], 33 | [0, math.cos(angle), math.sin(angle)], 34 | [0, -math.sin(angle), math.cos(angle)]]) 35 | 36 | def save_obj(obj, name): 37 | with open(name, 'wb') as f: 38 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 39 | 40 | def load_obj(name): 41 | with open(name, 'rb') as f: 42 | return pickle.load(f) 43 | 44 | def create_bbox_world_frame(agent : 'Agent', pos : 'Position'=None) -> np.ndarray: 45 | """ 46 | pos (optional): if provided, use this position instead of agent position 47 | """ 48 | if pos is None: 49 | x = agent.observable_state_.position_.x_ 50 | y = agent.observable_state_.position_.y_ 51 | heading = agent.observable_state_.position_.heading_ 52 | else: 53 | x = pos.x_ 54 | y = pos.y_ 55 | heading = pos.heading_ 56 | bbox_object_frame = np.array( 57 | [ 58 | [agent.length_ / 2.0, agent.width_ / 2.0, agent.height_ / 2.0], 59 | [agent.length_ / 2.0, -agent.width_ / 2.0, agent.height_ / 2.0], 60 | [-agent.length_ / 2.0, agent.width_ / 2.0, agent.height_ / 2.0], 61 | [-agent.length_ / 2.0, -agent.width_ / 2.0, agent.height_ / 2.0], 62 | ] 63 | ) 64 | agent_to_city_frame_se3 = SE3(rotation=rotation_matrix_z(heading), 65 | translation=np.array([x, y, 0])) 66 | bbox_in_city_frame = agent_to_city_frame_se3.transform_point_cloud(bbox_object_frame) 67 | return bbox_in_city_frame 68 | 69 | # Logging function 70 | def log(fname, s): 71 | # if not os.path.isdir(os.path.dirname(fname)): 72 | # os.system(f'mkdir -p {os.path.dirname(fname)}') 73 | f = open(fname, 'a') 74 | f.write(f'{str(datetime.now())}: {s}\n') 75 | f.close() 76 | 77 | # helper to choose whether to log to a file or logger out 78 | # logger_level: [info, debug, warning, error] 79 | def conditional_log(log_name, logger, content, logger_level='debug'): 80 | if log_name is not None: 81 | log(log_name, content) 82 | else: 83 | if logger_level == 'debug': 84 | logger.debug(content) 85 | elif logger_level == 'info': 86 | logger.info(content) 87 | else: 88 | logger.warning(f'conditional_log called with invalid logger_level={logger_level}, showing content with info') 89 | logger.info(content) 90 | 91 | def remap(v, x, y): 92 | # v: value / np array 93 | # x: original range 94 | # y: target range 95 | return y[0] + (v-x[0])*(y[1]-y[0])/(x[1]-x[0]) 96 | 97 | -------------------------------------------------------------------------------- /road_interactions_environment/gym_road_interactions/viz_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from Argoverse v1.1 : https://github.com/argoai/argoverse-api 2 | # argoverse-api/demo_usage/visualize_30hz_benchmark_data_on_map.py 3 | 4 | import math 5 | import logging 6 | import pdb 7 | from typing import Tuple 8 | 9 | import imageio 10 | # all mayavi imports MUST come before matplotlib, else Tkinter exceptions 11 | # will be thrown, e.g. 
"unrecognized selector sent to instance" 12 | import mayavi 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | 16 | from argoverse.map_representation.map_api import ArgoverseMap 17 | from argoverse.utils.geometry import filter_point_cloud_to_polygon, rotate_polygon_about_pt 18 | from argoverse.utils.mpl_plotting_utils import draw_lane_polygons, plot_bbox_2D 19 | from argoverse.utils.se3 import SE3 20 | from argoverse.utils.subprocess_utils import run_command 21 | 22 | from gym_road_interactions.core import Agent, AgentType 23 | from gym_road_interactions.utils import rotation_matrix_z 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | def render_map( 28 | city_name: str, 29 | ax: plt.Axes, 30 | axis: str, 31 | local_lane_polygons: np.ndarray, 32 | local_das: np.ndarray, 33 | local_lane_centerlines: list, # list of np.ndarray 34 | city_to_egovehicle_se3: SE3, 35 | avm: ArgoverseMap) -> None: 36 | """ 37 | Helper function to draw map 38 | """ 39 | if axis is not "city_axis": 40 | # rendering instead in the egovehicle reference frame 41 | for da_idx, local_da in enumerate(local_das): 42 | local_da = city_to_egovehicle_se3.inverse_transform_point_cloud(local_da) 43 | local_das[da_idx] = rotate_polygon_about_pt(local_da, city_to_egovehicle_se3.rotation, np.zeros(3)) 44 | 45 | for lane_idx, local_lane_polygon in enumerate(local_lane_polygons): 46 | local_lane_polygon = city_to_egovehicle_se3.inverse_transform_point_cloud(local_lane_polygon) 47 | local_lane_polygons[lane_idx] = rotate_polygon_about_pt( 48 | local_lane_polygon, city_to_egovehicle_se3.rotation, np.zeros(3) 49 | ) 50 | 51 | draw_lane_polygons(ax, local_lane_polygons) 52 | draw_lane_polygons(ax, local_das, color="tab:pink") 53 | 54 | for lane_cl in local_lane_centerlines: 55 | ax.plot(lane_cl[:, 0], lane_cl[:, 1], "--", color="grey", alpha=1, linewidth=1, zorder=0) 56 | 57 | def visualize_agent( 58 | ax: plt.Axes, 59 | agent: Agent, 60 | color: Tuple[float, float, float] = None 61 | ) -> None: 62 | """ 63 | Helper function to draw bounding box and arrow of an agent 64 | """ 65 | # draw center dot 66 | # TODO: change color based on agent type 67 | x = agent.observable_state_.position_.x_ 68 | y = agent.observable_state_.position_.y_ 69 | # change color based on agent type 70 | if (agent.agent_type_ == AgentType.ego_vehicle): 71 | logger.debug(f'Rendering ego vehicle at {x}, {y}') 72 | clr_lib = {AgentType.ego_vehicle : (0,1,1), # cyan 73 | AgentType.other_vehicle : (0,0,1), # blue 74 | AgentType.pedestrian : (1,0,0), # red 75 | AgentType.cyclist : (0,1,0), # green 76 | AgentType.motorcycle : (0.196, 0.8, 0.196), # limegreen 77 | AgentType.on_road_obstacle : (0,0,0), # black 78 | AgentType.other_mover : (1,0,1)} # magenta 79 | ax.scatter(x, y, 100, color=clr_lib[agent.agent_type_], marker=".", zorder=2) 80 | # draw arrow 81 | heading = agent.observable_state_.position_.heading_ 82 | scale_factor = math.log(abs(agent.observable_state_.velocity_) + 1.0) 83 | dx = scale_factor * math.cos(heading) 84 | dy = scale_factor * math.sin(heading) 85 | ax.arrow(x, y, dx, dy, color="r", width=0.2, zorder=2) 86 | if agent.id_ == '0': 87 | ax.annotate(agent.id_, (x, y), color='k', weight='bold', fontsize=7, ha='center', va='center') 88 | else: 89 | ax.annotate(agent.id_, (x, y), color='w', weight='bold', fontsize=7, ha='center', va='center') 90 | # pdb.set_trace() 91 | # draw bounding box 92 | bbox_object_frame = np.array( 93 | [ 94 | [agent.length_ / 2.0, agent.width_ / 2.0, agent.height_ / 2.0], 95 | [agent.length_ / 2.0, 
-agent.width_ / 2.0, agent.height_ / 2.0], 96 | [-agent.length_ / 2.0, agent.width_ / 2.0, agent.height_ / 2.0], 97 | [-agent.length_ / 2.0, -agent.width_ / 2.0, agent.height_ / 2.0], 98 | ] 99 | ) 100 | agent_to_city_frame_se3 = SE3(rotation=rotation_matrix_z(heading), 101 | translation=np.array([x, y, 0])) 102 | bbox_in_city_frame = agent_to_city_frame_se3.transform_point_cloud(bbox_object_frame) 103 | if color is not None: 104 | plot_bbox_2D(ax, bbox_in_city_frame, color) 105 | else: 106 | plot_bbox_2D(ax, bbox_in_city_frame, clr_lib[agent.agent_type_]) 107 | 108 | def write_nonsequential_idx_video(img_wildcard: str, output_fpath: str, fps: int, return_cmd: bool = False) -> None: 109 | """ 110 | Args: 111 | img_wildcard: string 112 | output_fpath: string 113 | fps: integer, frames per second 114 | 115 | Returns: 116 | None 117 | """ 118 | cmd = f"ffmpeg -r {fps} -f image2 -pattern_type glob -i \'{img_wildcard}\'" 119 | cmd += " -vcodec libx264 -profile:v main" 120 | cmd += " -level 3.1 -preset medium -crf 23 -x264-params ref=4 -acodec" 121 | cmd += f" copy -movflags +faststart -pix_fmt yuv420p -vf scale=920:-2" 122 | cmd += f" {output_fpath}" 123 | print(cmd) 124 | stdout_data, stderr_data = run_command(cmd, True) 125 | print(stdout_data) 126 | print(stderr_data) 127 | 128 | if return_cmd: 129 | return cmd 130 | -------------------------------------------------------------------------------- /road_interactions_environment/neighborhood_v4_collision_set_gen.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym_road_interactions 3 | import time 4 | import logging 5 | import sys, os 6 | import pdb 7 | import glob 8 | import copy, pickle 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from datetime import datetime 12 | import torch 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | def save_obj(obj, name): 18 | with open(name + '.pkl', 'wb') as f: 19 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 20 | 21 | def load_obj(name): 22 | with open(name + '.pkl', 'rb') as f: 23 | return pickle.load(f) 24 | 25 | # Logging function 26 | def log(fname, s): 27 | f = open(fname, 'a') 28 | f.write(f'{str(datetime.now())}: {s}\n') 29 | f.close() 30 | 31 | # format: [(model_path, ratio_of_num_other_agents)] 32 | rl_agent_configs = [('',0.0)] 33 | 34 | ni_ep_len_dict = {1: 200, 2:200, 3:200} 35 | ni = 1 36 | na = 5 37 | date = '0321' 38 | env_configs = {# parametric state 39 | 'use_global_frame': False, # whether to use global frame for the state vector 40 | 'normalize': True, # whether to normalize the state vector 41 | 'include_ttc': True, 42 | 'include_future_waypoints': 10, 43 | 'use_future_horizon_positions': False, # if false, just use the 10 future waypoints (each 10 waypoints apart) 44 | 'num_history_states': 10, # number of history states included in parametric state. 45 | 'num_future_states': 10, # number of future states included in parametric state. 46 | 'time_gap': 1, # time gap between two states. 
47 | # total number of states included is: num_history_states + num_future_states + 1 48 | # t - time_gap * n_h, t - time_gap * (n_h-1), ..., t, t + time_gap, ..., t + time_gap * n_f 49 | 'stalemate_horizon': 5, # number of past ts (including current) that we track to determine stalemate 50 | 'include_polygon_dist': True, # if true, include sigmoid(polygon distance between ego and agent) 51 | 'include_agents_within_range': 10.0, # the range within which agents will be included. 52 | 'agent_state_dim': 6+4+10*2+1, # state dim at a single ts. related to 'include_ttc', 'include_polygon_dist', 'include_future_waypoints'' 53 | # num agents (change training env_config with this) 54 | 'num_other_agents': na, # this will change 55 | 'max_num_other_agents': 25, # could only be 25, 40, 60 (one of the upper limits) 56 | 'max_num_other_agents_in_range': 25, # >=6. max number of other agents in range. Must be <= max_num_other_agents. default 25. 57 | # agent behavior 58 | 'agent_stochastic_stop': False, # (check) whether the other agent can choose to stop for ego with a certain probability 59 | 'agent_shuffle_ids': True, # (check) if yes, the id of other agents will be shuffled during reset 60 | 'rl_agent_configs' : [('',0.0)], 61 | 'all_selfplay': False, # if true, all agents will be changed to the most recent rl model, whatever the episode mix is 62 | # path 63 | 'ego_num_intersections_in_path': ni, # (check) 64 | 'ego_expand_path_depth': 2, # (check) the number of extending lanes from center intersection in ego path 65 | 'expanded_lane_set_depth': 1, # (check) 66 | 'single_intersection_type': 'mix', # (check) ['t-intersection', 'roundabout', 'mix'] 67 | 'c1': 2, # (check) 0: left/right turn at t4, 1: all possible depth-1 paths at t4, 2: all possible depth-1 paths at random intersection 68 | # training settings 69 | 'gamma': 0.99, 70 | 'max_episode_timesteps': ni_ep_len_dict[ni], 71 | # NOTE: if you change velocity_coeff, you have to change the whole dataset 72 | # (interaction set, collision set, etc) bc they are based on b1 and locations are 73 | # based on v_desired 74 | 'ego_velocity_coeff': (2.7, 8.3), # (check) (w, b). v_desired = w * b1 + b # experiments change this one 75 | 'agent_velocity_coeff': (2.7, 8.3), # (check) (w, b). v_desired = w * b1 + b # this is fixed! 76 | # reward = w * b1 + b # Be careful with signs! 77 | 'reward': 'default_fad', # [default, default_ttc, simplified, default_fad] 78 | 'time_penalty_coeff': (-1./20., -3./20.), 79 | 'velocity_reward_coeff': (0.5, 1.5), 80 | 'collision_penalty_coeff': (-5.0, -45.0), 81 | 'fad_penalty_coeff': 1.0, 82 | 'timeout_penalty_coeff': (-5.0, -20.0), 83 | 'stalemate_penalty_coeff': (-0.5, -1.5), 84 | # ego config 85 | 'use_default_ego': False, 86 | 'no_ego': False, # TODO: remember to turn this off if you want ego! 87 | # action noises 88 | 'ego_action_noise': 0.0, # lambda of the poisson noise applied to ego action 89 | 'agent_action_noise': 0.0, # lambda of the poisson noise applied to agent action 90 | # added on 0414 91 | 'ego_baseline': 5, # if None, run oracle. 
If >=0, choose the corresponding baseline setup (note: baseline 4 = oracle with ttc_break_tie=random) 92 | 'ttc_break_tie': 'id', 93 | 'agent_baseline': 5, 94 | 'stalemate_breaker': True} # 'agg_level=0' or 'b1' 95 | # the dimension of an agent across all ts 96 | env_configs['num_ts_in_state'] = env_configs['num_history_states'] + env_configs['num_future_states'] + 1 97 | env_configs['agent_total_state_dim'] = env_configs['agent_state_dim'] * env_configs['num_ts_in_state'] 98 | # training configs 99 | train_configs = {# model 100 | 'model': 'TwinDDQN', # [TwinDDQN, DDQN] # TODO check TwinDDQN Version!!!! 101 | 'gamma': env_configs['gamma'], 102 | 'target_update': 100, # number of policy updates until target update 103 | 'max_buffer_size': 200000, # max size for buffer that saves state-action-reward transitions 104 | 'batch_size': 128, 105 | 'lrt': 2e-5, 106 | 'tau': 0.2, 107 | 'exploration_timesteps': 0, 108 | 'value_net': 'set_transformer', # [vanilla, deep_set, deeper_deep_set, set_transformer, social_attention] 109 | 'future_state_dropout': 0.7, # probability of dropping out future states during training 110 | 'use_lstm': True, # if true, will use LSTMValueNet along with the configured value_net 111 | # deep set 112 | 'pooling': 'mean', # ['mean', 'max'] 113 | # set transformer 114 | 'layer_norm' : True, # whether to use layer_norm in set transformer 115 | 'layers' : 2, 116 | # 'train_every_timesteps': 4, 117 | # epsilon decay for epsilon greedy 118 | 'eps_start': 1.0, 119 | 'eps_end': 0.01, 120 | 'eps_decay': 500, 121 | # training 122 | 'reward_threshold': 1000, # set impossible 123 | 'max_timesteps': 200000, 124 | 'train_every_episodes': 1, # if -1: train at every timestep. if n >= 1: train after every n timesteps 125 | 'save_every_timesteps': 5000, 126 | 'eval_every_episodes': 50, 127 | 'log_every_timesteps': 100, 128 | 'record_every_episodes': 30, 129 | 'seed': 0, 130 | 'grad_clamp': False, 131 | 'moving_avg_window': 100, 132 | 'replay_collision_episode': 0, # if 0, start every new episode fresh 133 | 'replay_episode': 0, # if 0, start every new episode fresh 134 | 'collision_episode_ratio': 0.25, # the ratio of saved collision episodes used in both training and testing. if set to 0, then don't use saved episodes 135 | 'interaction_episode_ratio': 0.5, # the ratio of saved interaction episodes used in both training and testing. 
if set to 0, then don't use interaction episodes 136 | 'buffer_agent_states': False, # if True, include agent-centric states in replay buffer 137 | # env 138 | 'agent_state_dim': env_configs['agent_state_dim'], 139 | 'agent_total_state_dim': env_configs['agent_total_state_dim'], # state_dim of each agent 140 | 'max_num_other_agents_in_range': env_configs['max_num_other_agents_in_range'], 141 | 'state_dim': env_configs['agent_total_state_dim'] * (env_configs['max_num_other_agents_in_range']+1), # total state_dim 142 | 'num_ts_in_state': env_configs['num_ts_in_state'], 143 | 'num_history_states': env_configs['num_history_states'], 144 | 'num_future_states': env_configs['num_future_states'], 145 | 'action_dim': 2, 146 | 'env_action_dim': 1, 147 | 'max_episode_timesteps': ni_ep_len_dict[ni]} 148 | 149 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 150 | print(device) 151 | env = gym.make('Neighborhood-v4') 152 | env.set_env_config(env_configs) 153 | assert env_configs['max_num_other_agents'] == 25 154 | env.set_train_config_and_device(train_configs, device) 155 | w,b = env_configs['ego_velocity_coeff'] 156 | env.log_name_ = f'out/{date}_collision-set_train_ego-vel-coeff={w},{b}_ni={ni}_na={na}.log' 157 | dataset_path = f'../policy_network/neighborhood_v4_ddqn/collision_sets/{date}_collision-set_train_ego-vel-coeff={w},{b}_ni={ni}_na={na}' 158 | 159 | N = 200 160 | action = 1 161 | episode = 0 162 | prev_ts = 0 163 | ts = 0 164 | done = False 165 | 166 | collision = 0 167 | timeout = 0 168 | last_save_collision = 0 169 | collision_ep_agents_init = [] 170 | 171 | longest_episode = 0 172 | 173 | time_start = time.time() 174 | 175 | while collision < 200: # use this when generating dataset 176 | episode += 1 177 | total_reward = 0 178 | prev_ts = ts 179 | episode_ts = 0 180 | done = False 181 | 182 | print(f"[Episode {episode}]...") 183 | log(env.log_name_, f"[Episode {episode}]...") 184 | 185 | state = None 186 | while state is None: 187 | state = env.reset() 188 | ep_agents_init = copy.deepcopy(env.agents_) 189 | 190 | while done == False and episode_ts < ni_ep_len_dict[env_configs['ego_num_intersections_in_path']]: 191 | ts += 1 192 | episode_ts += 1 193 | log(env.log_name_, f"[Step {episode_ts}]...") 194 | 195 | state, reward, done, info = env.step(action) 196 | total_reward += reward 197 | 198 | if done or episode_ts >= ni_ep_len_dict[env_configs['ego_num_intersections_in_path']]: 199 | dt = (int)(time.time() - time_start) 200 | print(f'total reward={total_reward} | length={ts - prev_ts} | dist_driven: {info[2]} | Time: {dt//3600:02}:{dt%3600//60:02}:{dt%60}') 201 | longest_episode = max(longest_episode, episode_ts) 202 | if done: 203 | if info[1] == True: 204 | collision += 1 205 | print(f'collision={collision}') 206 | collision_ep_agents_init.append(ep_agents_init) 207 | else: 208 | timeout += 1 209 | print(f'timeout={timeout}') 210 | 211 | if (collision > 0) and (collision - last_save_collision > 1): 212 | print(f'saving collision_ep_agents_init at collision={collision}') 213 | save_obj({'env_config': env_configs, 214 | 'agents_init': collision_ep_agents_init}, dataset_path) 215 | last_save_collision = collision 216 | 217 | env.close() -------------------------------------------------------------------------------- /road_interactions_environment/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='road_interactions', 4 | version='0.0.1', 5 | 
install_requires=['gym']  # and any other dependencies required 6 | ) 7 | --------------------------------------------------------------------------------
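
As a quick illustration of the stop/go heuristic implemented in `neighborhood_env_v4_agents.py` above: each agent compares the predicted time-to-collision if it stops (`ttc1`) against the time-to-collision if it goes (`ttc2`), stops whenever stopping buys more time, and otherwise falls back to one of the `ttc_break_tie` rules when the two are equal. The sketch below is a minimal, standalone restatement of that rule covering only the `b1` and `random` tie-breaks; the function name, its signature, and the `stop_prob` default are illustrative assumptions, not the repository's API.

```python
# Minimal sketch of the TTC-based stop/go rule from neighborhood_env_v4_agents.py.
# Hypothetical standalone helper: names and defaults are illustrative, not the repo's API.
import numpy as np

def should_stop_for(ttc_if_i_stop: float, ttc_if_i_go: float,
                    my_b1: float, other_b1: float,
                    break_tie: str = 'b1', stop_prob: float = 0.5) -> bool:
    """Return True if this agent should yield to the other agent."""
    if ttc_if_i_stop != ttc_if_i_go:
        # Core criterion: stop only if stopping yields a larger time-to-collision.
        return ttc_if_i_stop > ttc_if_i_go
    # Tie-breaking, mirroring two of the `ttc_break_tie` options in the source.
    if break_tie == 'b1':
        if my_b1 != other_b1:
            return my_b1 < other_b1      # the agent with the smaller b1 yields
        return np.random.random() < 0.5  # equal b1: coin flip
    elif break_tie == 'random':
        return np.random.random() < stop_prob
    raise ValueError(f'invalid break_tie: {break_tie}')

# Example: stopping delays the collision (4.0s vs 2.5s), so the agent yields.
# should_stop_for(ttc_if_i_stop=4.0, ttc_if_i_go=2.5, my_b1=0.3, other_b1=0.7) -> True
```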