├── .gitignore ├── INSTALL.md ├── LICENSE ├── PythonAPI └── agents │ ├── __init__.py │ ├── navigation │ ├── Town01.png │ ├── Town02.png │ ├── __init__.py │ ├── agent.py │ ├── basic_agent.py │ ├── controller.py │ ├── global_route_planner.py │ ├── global_route_planner_dao.py │ ├── local_planner.py │ ├── roaming_agent.py │ └── test_global_route_planner.py │ └── tools │ ├── __init__.py │ └── misc.py ├── README.md ├── benchmark ├── __init__.py ├── base_suite.py ├── carla100 │ ├── 084 │ │ ├── nocrash_Town01.txt │ │ └── nocrash_Town02.txt │ └── 096 │ │ ├── nocrash_Town01.txt │ │ └── nocrash_Town02.txt ├── corl2017 │ ├── 084 │ │ ├── full_Town01.txt │ │ ├── full_Town02.txt │ │ ├── straight_Town01.txt │ │ ├── straight_Town02.txt │ │ ├── turn_Town01.txt │ │ └── turn_Town02.txt │ └── 096 │ │ ├── full_Town01.txt │ │ ├── full_Town02.txt │ │ ├── straight_Town01.txt │ │ ├── straight_Town02.txt │ │ ├── turn_Town01.txt │ │ └── turn_Town02.txt ├── goal_suite.py └── run_benchmark.py ├── benchmark_agent.py ├── bird_view ├── augmenter.py ├── models │ ├── __init__.py │ ├── agent.py │ ├── baseline.py │ ├── birdview.py │ ├── common.py │ ├── controller.py │ ├── factory.py │ ├── image.py │ ├── resnet.py │ └── roaming.py ├── scripts │ ├── parse_runs.py │ └── tune_pid.py └── utils │ ├── __init__.py │ ├── bz_utils │ ├── .gitignore │ ├── __init__.py │ ├── gif_maker.py │ ├── optimizer.py │ ├── plotter.py │ ├── saver.py │ ├── test.py │ └── video_maker.py │ ├── carla_utils.py │ ├── datasets │ ├── __init__.py │ ├── birdview_lmdb.py │ └── image_lmdb.py │ ├── image_utils.py │ ├── logger.py │ ├── map_utils.py │ ├── no_rendering_mode.py │ └── train_utils.py ├── data_collector.py ├── environment.yml ├── figs ├── birdview.png ├── birdview_loss.png ├── fig1.png └── image_phase1.png ├── misc ├── ImportMaps.sh ├── automatic_control.py ├── controller.ipynb ├── dynamic_weather.py ├── find_traffic_violations.py ├── light_town1.txt ├── light_town2.txt ├── manual_control.py ├── no_rendering_mode.py ├── spawn_npc.py ├── synchronous_mode.py ├── tutorial.py └── vehicle_gallery.py ├── quick_start.sh ├── training ├── phase2_utils.py ├── train_birdview.py ├── train_image_phase0.py ├── train_image_phase1.py └── train_image_phase2.py └── view_benchmark_results.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.sw* 2 | **/__pycache__ 3 | HDMaps/* 4 | bird_view/logs/* 5 | 6 | bird_view/notebooks/* 7 | *.out 8 | CarlaUE4/* 9 | Engine/* 10 | ExportedMaps/* 11 | CarlaUE4.sh 12 | CHANGELOG 13 | Dockerfile 14 | LICENSE 15 | README 16 | python_api.md 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | bin/ 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | *.egg-info/ 36 | .installed.cfg 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | .tox/ 44 | .coverage 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | 49 | # Translations 50 | *.mo 51 | 52 | # Mr Developer 53 | .mr.developer.cfg 54 | .project 55 | .pydevproject 56 | 57 | # Rope 58 | .ropeproject 59 | 60 | # Django stuff: 61 | *.log 62 | *.pot 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Setup 2 | 3 | ## Install CARLA 4 
| - Download the [released 0.9.6 binary](http://carla-assets-internal.s3.amazonaws.com/Releases/Linux/CARLA_0.9.6.tar.gz) and use our compiled [.egg file](http://www.cs.utexas.edu/~dchen/lbc_release/egg/carla-0.9.6-py3.5-linux-x86_64.egg) if you are using Python 3.5 (the egg is built for Python 3.5). You still need to download the updated Navmesh. 5 | - Alternatively, you can compile CARLA from source. Clone CARLA 0.9.6 with our pedestrian fix at: https://github.com/dianchen96/carla/tree/0.9.6-lbc. Follow the instructions to compile and download the assets. 6 | 7 | ## Install our custom Navmesh 8 | Download the modified Navmesh for Town01 and Town02: 9 | 10 | `Town01`: http://www.cs.utexas.edu/~dchen/lbc_release/navmesh/Town01.bin 11 | 12 | `Town02`: http://www.cs.utexas.edu/~dchen/lbc_release/navmesh/Town02.bin 13 | 14 | ## Setup LBC 15 | - Clone this repo and replace all the files inside the CARLA folder. 16 | - Install the dependencies, or create a conda environment with `conda env create -f environment.yml`. 17 | - (Optional) Download the model checkpoints specified in [README](..), or train the models yourself. 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Dian Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /PythonAPI/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/PythonAPI/agents/__init__.py -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/Town01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/PythonAPI/agents/navigation/Town01.png -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/Town02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/PythonAPI/agents/navigation/Town02.png -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/PythonAPI/agents/navigation/__init__.py -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/basic_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2018 Intel Labs. 4 | # authors: German Ros (german.ros@intel.com) 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see <https://opensource.org/licenses/MIT>. 8 | 9 | """ This module implements an agent that navigates a track between given waypoints 10 | while avoiding other vehicles. 11 | The agent also responds to traffic lights. """ 12 | 13 | 14 | import carla 15 | from agents.navigation.agent import Agent, AgentState 16 | from agents.navigation.local_planner import LocalPlanner 17 | from agents.navigation.global_route_planner import GlobalRoutePlanner 18 | from agents.navigation.global_route_planner_dao import GlobalRoutePlannerDAO 19 | 20 | class BasicAgent(Agent): 21 | """ 22 | BasicAgent implements a basic agent that navigates scenes to reach a given 23 | target destination. This agent respects traffic lights and other vehicles.
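Typical usage (sketch): construct with a carla.Vehicle actor, call set_destination() with an (x, y, z) location, then call run_step() once per simulation tick and apply the returned carla.VehicleControl to the vehicle.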
24 | """ 25 | 26 | def __init__(self, vehicle, target_speed=20): 27 | """ 28 | 29 | :param vehicle: actor to apply to local planner logic onto 30 | """ 31 | super(BasicAgent, self).__init__(vehicle) 32 | 33 | self._proximity_threshold = 10.0 # meters 34 | self._state = AgentState.NAVIGATING 35 | args_lateral_dict = { 36 | 'K_P': 1, 37 | 'K_D': 0.02, 38 | 'K_I': 0, 39 | 'dt': 1.0/20.0} 40 | self._local_planner = LocalPlanner( 41 | self._vehicle, opt_dict={'target_speed' : target_speed, 42 | 'lateral_control_dict':args_lateral_dict}) 43 | self._hop_resolution = 2.0 44 | self._path_seperation_hop = 2 45 | self._path_seperation_threshold = 0.5 46 | self._target_speed = target_speed 47 | self._grp = None 48 | 49 | def set_destination(self, location): 50 | """ 51 | This method creates a list of waypoints from agent's position to destination location 52 | based on the route returned by the global router 53 | """ 54 | 55 | start_waypoint = self._map.get_waypoint(self._vehicle.get_location()) 56 | end_waypoint = self._map.get_waypoint( 57 | carla.Location(location[0], location[1], location[2])) 58 | 59 | route_trace = self._trace_route(start_waypoint, end_waypoint) 60 | assert route_trace 61 | 62 | self._local_planner.set_global_plan(route_trace) 63 | 64 | def _trace_route(self, start_waypoint, end_waypoint): 65 | """ 66 | This method sets up a global router and returns the optimal route 67 | from start_waypoint to end_waypoint 68 | """ 69 | 70 | # Setting up global router 71 | if self._grp is None: 72 | dao = GlobalRoutePlannerDAO(self._vehicle.get_world().get_map(), self._hop_resolution) 73 | grp = GlobalRoutePlanner(dao) 74 | grp.setup() 75 | self._grp = grp 76 | 77 | # Obtain route plan 78 | route = self._grp.trace_route( 79 | start_waypoint.transform.location, 80 | end_waypoint.transform.location) 81 | 82 | return route 83 | 84 | def run_step(self, debug=False): 85 | """ 86 | Execute one step of navigation. 87 | :return: carla.VehicleControl 88 | """ 89 | 90 | # is there an obstacle in front of us? 91 | hazard_detected = False 92 | 93 | # retrieve relevant elements for safe navigation, i.e.: traffic lights 94 | # and other vehicles 95 | actor_list = self._world.get_actors() 96 | vehicle_list = actor_list.filter("*vehicle*") 97 | lights_list = actor_list.filter("*traffic_light*") 98 | # import pdb; pdb.set_trace() 99 | print (actor_list) 100 | 101 | # check possible obstacles 102 | vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) 103 | if vehicle_state: 104 | if debug: 105 | print('!!! VEHICLE BLOCKING AHEAD [{}])'.format(vehicle.id)) 106 | 107 | self._state = AgentState.BLOCKED_BY_VEHICLE 108 | hazard_detected = True 109 | 110 | # check for the state of the traffic lights 111 | light_state, traffic_light = self._is_light_red(lights_list) 112 | if light_state: 113 | if debug: 114 | print('=== RED LIGHT AHEAD [{}])'.format(traffic_light.id)) 115 | 116 | self._state = AgentState.BLOCKED_RED_LIGHT 117 | hazard_detected = True 118 | 119 | if hazard_detected: 120 | control = self.emergency_stop() 121 | else: 122 | self._state = AgentState.NAVIGATING 123 | # standard local planner behavior 124 | control = self._local_planner.run_step(debug=debug) 125 | 126 | return control 127 | -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/controller.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2018 Intel Labs. 
4 | # authors: German Ros (german.ros@intel.com) 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see <https://opensource.org/licenses/MIT>. 8 | 9 | """ This module contains PID controllers to perform lateral and longitudinal control. """ 10 | 11 | from collections import deque 12 | import math 13 | 14 | import numpy as np 15 | 16 | import carla 17 | from agents.tools.misc import get_speed 18 | 19 | 20 | class VehiclePIDController(): 21 | """ 22 | VehiclePIDController is the combination of two PID controllers (lateral and longitudinal) to perform 23 | low-level control of a vehicle from the client side 24 | """ 25 | 26 | def __init__(self, vehicle, args_lateral=None, args_longitudinal=None): 27 | """ 28 | :param vehicle: actor to apply the local planner logic to 29 | :param args_lateral: dictionary of arguments to set the lateral PID controller using the following semantics: 30 | K_P -- Proportional term 31 | K_D -- Differential term 32 | K_I -- Integral term 33 | :param args_longitudinal: dictionary of arguments to set the longitudinal PID controller using the following 34 | semantics: 35 | K_P -- Proportional term 36 | K_D -- Differential term 37 | K_I -- Integral term 38 | """ 39 | if not args_lateral: 40 | args_lateral = {'K_P': 1.0, 'K_D': 0.0, 'K_I': 0.0} 41 | if not args_longitudinal: 42 | args_longitudinal = {'K_P': 1.0, 'K_D': 0.0, 'K_I': 0.0} 43 | 44 | self._vehicle = vehicle 45 | self._world = self._vehicle.get_world() 46 | self._lon_controller = PIDLongitudinalController(self._vehicle, **args_longitudinal) 47 | self._lat_controller = PIDLateralController(self._vehicle, **args_lateral) 48 | 49 | def run_step(self, target_speed, waypoint): 50 | """ 51 | Execute one step of control invoking both lateral and longitudinal PID controllers to reach a target waypoint 52 | at a given target_speed. 53 | 54 | :param target_speed: desired vehicle speed 55 | :param waypoint: target location encoded as a waypoint 56 | :return: carla.VehicleControl steering the vehicle towards the waypoint 57 | """ 58 | throttle = self._lon_controller.run_step(target_speed) 59 | steering = self._lat_controller.run_step(waypoint) 60 | 61 | control = carla.VehicleControl() 62 | control.steer = steering 63 | control.throttle = throttle 64 | control.brake = 0.0 65 | control.hand_brake = False 66 | control.manual_gear_shift = False 67 | 68 | return control 69 | 70 | 71 | class PIDLongitudinalController(): 72 | """ 73 | PIDLongitudinalController implements longitudinal control using a PID. 74 | """ 75 | 76 | def __init__(self, vehicle, K_P=1.0, K_D=0.0, K_I=0.0, dt=0.1): 77 | """ 78 | :param vehicle: actor to apply the local planner logic to 79 | :param K_P: Proportional term 80 | :param K_D: Differential term 81 | :param K_I: Integral term 82 | :param dt: time differential in seconds 83 | """ 84 | self._vehicle = vehicle 85 | self._K_P = K_P 86 | self._K_D = K_D 87 | self._K_I = K_I 88 | self._dt = dt 89 | self._e_buffer = deque(maxlen=30) 90 | 91 | def run_step(self, target_speed, debug=False): 92 | """ 93 | Execute one step of longitudinal control to reach a given target speed.
94 | 95 | :param target_speed: target speed in Km/h 96 | :return: throttle control in the range [0, 1] 97 | """ 98 | current_speed = get_speed(self._vehicle) 99 | 100 | if debug: 101 | print('Current speed = {}'.format(current_speed)) 102 | 103 | return self._pid_control(target_speed, current_speed) 104 | 105 | def _pid_control(self, target_speed, current_speed): 106 | """ 107 | Estimate the throttle of the vehicle based on the PID equations 108 | 109 | :param target_speed: target speed in Km/h 110 | :param current_speed: current speed of the vehicle in Km/h 111 | :return: throttle control in the range [0, 1] 112 | """ 113 | _e = (target_speed - current_speed) 114 | self._e_buffer.append(_e) 115 | 116 | if len(self._e_buffer) >= 2: 117 | _de = (self._e_buffer[-1] - self._e_buffer[-2]) / self._dt 118 | _ie = sum(self._e_buffer) * self._dt 119 | else: 120 | _de = 0.0 121 | _ie = 0.0 122 | 123 | return np.clip((self._K_P * _e) + (self._K_D * _de / self._dt) + (self._K_I * _ie * self._dt), 0.0, 1.0) 124 | 125 | 126 | class PIDLateralController(): 127 | """ 128 | PIDLateralController implements lateral control using a PID. 129 | """ 130 | 131 | def __init__(self, vehicle, K_P=1.0, K_D=0.0, K_I=0.0, dt=0.1): 132 | """ 133 | :param vehicle: actor to apply the local planner logic to 134 | :param K_P: Proportional term 135 | :param K_D: Differential term 136 | :param K_I: Integral term 137 | :param dt: time differential in seconds 138 | """ 139 | self._vehicle = vehicle 140 | self._K_P = K_P 141 | self._K_D = K_D 142 | self._K_I = K_I 143 | self._dt = dt 144 | self._e_buffer = deque(maxlen=10) 145 | 146 | def run_step(self, waypoint): 147 | """ 148 | Execute one step of lateral control to steer the vehicle towards a certain waypoint. 149 | 150 | :param waypoint: target waypoint 151 | :return: steering control in the range [-1, 1] where: 152 | -1 represents maximum steering to the left 153 | +1 represents maximum steering to the right 154 | """ 155 | return self._pid_control(waypoint, self._vehicle.get_transform()) 156 | 157 | def _pid_control(self, waypoint, vehicle_transform): 158 | """ 159 | Estimate the steering angle of the vehicle based on the PID equations 160 | 161 | :param waypoint: target waypoint 162 | :param vehicle_transform: current transform of the vehicle 163 | :return: steering control in the range [-1, 1] 164 | """ 165 | v_begin = vehicle_transform.location 166 | v_end = v_begin + carla.Location(x=math.cos(math.radians(vehicle_transform.rotation.yaw)), 167 | y=math.sin(math.radians(vehicle_transform.rotation.yaw))) 168 | 169 | v_vec = np.array([v_end.x - v_begin.x, v_end.y - v_begin.y, 0.0]) 170 | w_vec = np.array([waypoint.transform.location.x - 171 | v_begin.x, waypoint.transform.location.y - 172 | v_begin.y, 0.0]) 173 | _dot = math.acos(np.clip(np.dot(w_vec, v_vec) / 174 | (np.linalg.norm(w_vec) * np.linalg.norm(v_vec)), -1.0, 1.0)) 175 | 176 | _cross = np.cross(v_vec, w_vec) 177 | if _cross[2] < 0: 178 | _dot *= -1.0 179 | 180 | self._e_buffer.append(_dot) 181 | if len(self._e_buffer) >= 2: 182 | _de = (self._e_buffer[-1] - self._e_buffer[-2]) / self._dt 183 | _ie = sum(self._e_buffer) * self._dt 184 | else: 185 | _de = 0.0 186 | _ie = 0.0 187 | 188 | return np.clip((self._K_P * _dot) + (self._K_D * _de / 189 | self._dt) + (self._K_I * _ie * self._dt), -1.0, 1.0) 190 | -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/global_route_planner_dao.py: -------------------------------------------------------------------------------- 1 | # This
work is licensed under the terms of the MIT license. 2 | # For a copy, see <https://opensource.org/licenses/MIT>. 3 | 4 | """ 5 | This module provides the implementation for GlobalRoutePlannerDAO 6 | """ 7 | 8 | import numpy as np 9 | 10 | 11 | class GlobalRoutePlannerDAO(object): 12 | """ 13 | This class is the data access layer for fetching data 14 | from the carla server instance for GlobalRoutePlanner 15 | """ 16 | 17 | def __init__(self, wmap, sampling_resolution=1): 18 | """ 19 | Constructor 20 | 21 | wmap : carla world map object 22 | """ 23 | self._sampling_resolution = sampling_resolution 24 | self._wmap = wmap 25 | 26 | def get_topology(self): 27 | """ 28 | Accessor for topology. 29 | This function retrieves topology from the server as a list of 30 | road segments as pairs of waypoint objects, and processes the 31 | topology into a list of dictionary objects. 32 | 33 | return: list of dictionary objects with the following attributes 34 | entry - waypoint of entry point of road segment 35 | entryxyz - (x,y,z) of entry point of road segment 36 | exit - waypoint of exit point of road segment 37 | exitxyz - (x,y,z) of exit point of road segment 38 | path - list of waypoints separated by 1m from entry 39 | to exit 40 | """ 41 | topology = [] 42 | # Retrieving waypoints to construct a detailed topology 43 | for segment in self._wmap.get_topology(): 44 | wp1, wp2 = segment[0], segment[1] 45 | l1, l2 = wp1.transform.location, wp2.transform.location 46 | # Rounding off to avoid floating point imprecision 47 | x1, y1, z1, x2, y2, z2 = np.round([l1.x, l1.y, l1.z, l2.x, l2.y, l2.z], 0) 48 | wp1.transform.location, wp2.transform.location = l1, l2 49 | seg_dict = dict() 50 | seg_dict['entry'], seg_dict['exit'] = wp1, wp2 51 | seg_dict['entryxyz'], seg_dict['exitxyz'] = (x1, y1, z1), (x2, y2, z2) 52 | seg_dict['path'] = [] 53 | endloc = wp2.transform.location 54 | if wp1.transform.location.distance(endloc) > self._sampling_resolution: 55 | w = wp1.next(self._sampling_resolution)[0] 56 | while w.transform.location.distance(endloc) > self._sampling_resolution: 57 | seg_dict['path'].append(w) 58 | w = w.next(self._sampling_resolution)[0] 59 | else: 60 | seg_dict['path'].append(wp1.next(self._sampling_resolution/2.0)[0]) 61 | topology.append(seg_dict) 62 | return topology 63 | 64 | def get_waypoint(self, location): 65 | """ 66 | The method returns the waypoint at a given location 67 | """ 68 | waypoint = self._wmap.get_waypoint(location) 69 | return waypoint 70 | 71 | def get_resolution(self): 72 | """ Accessor for self._sampling_resolution """ 73 | return self._sampling_resolution -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/roaming_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2018 Intel Labs. 4 | # authors: German Ros (german.ros@intel.com) 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see <https://opensource.org/licenses/MIT>. 8 | 9 | """ This module implements an agent that roams around a track following random waypoints and avoiding other vehicles. 10 | The agent also responds to traffic lights. """ 11 | 12 | from agents.navigation.agent import Agent, AgentState 13 | from agents.navigation.local_planner import LocalPlanner 14 | 15 | 16 | class RoamingAgent(Agent): 17 | """ 18 | RoamingAgent implements a basic agent that navigates scenes making random 19 | choices when facing an intersection. 20 | 21 | This agent respects traffic lights and other vehicles.
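Typical usage (sketch): construct with a carla.Vehicle actor and call run_step() once per simulation tick; the returned carla.VehicleControl can be applied directly to the vehicle.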
22 | """ 23 | 24 | def __init__(self, vehicle): 25 | """ 26 | 27 | :param vehicle: actor to apply to local planner logic onto 28 | """ 29 | super(RoamingAgent, self).__init__(vehicle) 30 | 31 | self._proximity_threshold = 10.0 # meters 32 | self._state = AgentState.NAVIGATING 33 | self._local_planner = LocalPlanner(self._vehicle) 34 | 35 | def run_step(self, debug=False): 36 | """ 37 | Execute one step of navigation. 38 | :return: carla.VehicleControl 39 | """ 40 | 41 | # is there an obstacle in front of us? 42 | hazard_detected = False 43 | 44 | # retrieve relevant elements for safe navigation, i.e.: traffic lights 45 | # and other vehicles 46 | actor_list = self._world.get_actors() 47 | vehicle_list = actor_list.filter("*vehicle*") 48 | lights_list = actor_list.filter("*traffic_light*") 49 | 50 | # check possible obstacles 51 | vehicle_state, vehicle = self._is_vehicle_hazard(vehicle_list) 52 | if vehicle_state: 53 | if debug: 54 | print('!!! VEHICLE BLOCKING AHEAD [{}])'.format(vehicle.id)) 55 | 56 | self._state = AgentState.BLOCKED_BY_VEHICLE 57 | hazard_detected = True 58 | 59 | # check for the state of the traffic lights 60 | light_state, traffic_light = self._is_light_red(lights_list) 61 | if light_state: 62 | if debug: 63 | print('=== RED LIGHT AHEAD [{}])'.format(traffic_light.id)) 64 | 65 | self._state = AgentState.BLOCKED_RED_LIGHT 66 | hazard_detected = True 67 | 68 | if hazard_detected: 69 | control = self.emergency_stop() 70 | else: 71 | self._state = AgentState.NAVIGATING 72 | # standard local planner behavior 73 | control = self._local_planner.run_step() 74 | 75 | return control 76 | -------------------------------------------------------------------------------- /PythonAPI/agents/navigation/test_global_route_planner.py: -------------------------------------------------------------------------------- 1 | import math 2 | import unittest 3 | 4 | import carla 5 | 6 | from global_route_planner import GlobalRoutePlanner 7 | from global_route_planner import NavEnum 8 | from global_route_planner_dao import GlobalRoutePlannerDAO 9 | 10 | 11 | class Test_GlobalRoutePlanner(unittest.TestCase): 12 | """ 13 | Test class for GlobalRoutePlanner class 14 | """ 15 | 16 | def setUp(self): 17 | # == Utilities test instance without DAO == # 18 | self.simple_grp = GlobalRoutePlanner(None) 19 | 20 | # == Integration test instance == # 21 | client = carla.Client('localhost', 2000) 22 | world = client.get_world() 23 | integ_dao = GlobalRoutePlannerDAO(world.get_map()) 24 | self.integ_grp = GlobalRoutePlanner(integ_dao) 25 | self.integ_grp.setup() 26 | pass 27 | 28 | def tearDown(self): 29 | self.simple_grp = None 30 | self.dao_grp = None 31 | self.integ_grp = None 32 | pass 33 | 34 | def test_plan_route(self): 35 | """ 36 | Test for GlobalROutePlanner.plan_route() 37 | Run this test with carla server running Town03 38 | """ 39 | plan = self.integ_grp.plan_route((-60, -5), (-77.65, 72.72)) 40 | self.assertEqual( 41 | plan, [NavEnum.START, NavEnum.LEFT, NavEnum.LEFT, 42 | NavEnum.GO_STRAIGHT, NavEnum.LEFT, NavEnum.STOP]) 43 | 44 | def test_path_search(self): 45 | """ 46 | Test for GlobalRoutePlanner.path_search() 47 | Run this test with carla server running Town03 48 | """ 49 | self.integ_grp.path_search((191.947, -5.602), (78.730, -50.091)) 50 | self.assertEqual( 51 | self.integ_grp.path_search((196.947, -5.602), (78.730, -50.091)), 52 | [256, 157, 158, 117, 118, 59, 55, 230]) 53 | 54 | def test_localise(self): 55 | """ 56 | Test for GlobalRoutePlanner.localise() 57 | Run this test with carla 
server running Town03 58 | """ 59 | x, y = (200, -250) 60 | segment = self.integ_grp.localise(x, y) 61 | self.assertEqual(self.integ_grp._id_map[segment['entry']], 5) 62 | self.assertEqual(self.integ_grp._id_map[segment['exit']], 225) 63 | 64 | def test_unit_vector(self): 65 | """ 66 | Test for GlobalRoutePlanner.unit_vector() 67 | """ 68 | vector = self.simple_grp.unit_vector((1, 1), (2, 2)) 69 | self.assertAlmostEqual(vector[0], 1 / math.sqrt(2)) 70 | self.assertAlmostEqual(vector[1], 1 / math.sqrt(2)) 71 | 72 | def test_dot(self): 73 | """ 74 | Test for GlobalRoutePlanner.dot() 75 | """ 76 | self.assertAlmostEqual(self.simple_grp.dot((1, 0), (0, 1)), 0) 77 | self.assertAlmostEqual(self.simple_grp.dot((1, 0), (1, 0)), 1) 78 | 79 | 80 | def suite(): 81 | """ 82 | Gathering all tests 83 | """ 84 | 85 | suite = unittest.TestSuite() 86 | suite.addTest(Test_GlobalRoutePlanner('test_unit_vector')) 87 | suite.addTest(Test_GlobalRoutePlanner('test_dot')) 88 | suite.addTest(Test_GlobalRoutePlanner('test_localise')) 89 | suite.addTest(Test_GlobalRoutePlanner('test_path_search')) 90 | suite.addTest(Test_GlobalRoutePlanner('test_plan_route')) 91 | 92 | return suite 93 | 94 | 95 | if __name__ == '__main__': 96 | """ 97 | Running test suite 98 | """ 99 | my_suite = suite() 100 | runner = unittest.TextTestRunner() 101 | runner.run(my_suite) 102 | -------------------------------------------------------------------------------- /PythonAPI/agents/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/PythonAPI/agents/tools/__init__.py -------------------------------------------------------------------------------- /PythonAPI/agents/tools/misc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2018 Intel Labs. 4 | # authors: German Ros (german.ros@intel.com) 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see <https://opensource.org/licenses/MIT>. 8 | 9 | """ Module with auxiliary functions. """ 10 | 11 | import math 12 | 13 | import numpy as np 14 | 15 | import carla 16 | 17 | 18 | def draw_waypoints(world, waypoints, z=0.5): 19 | """ 20 | Draw a list of waypoints at a certain height given in z.
21 | 22 | :param world: carla.World object 23 | :param waypoints: list or iterable container with the waypoints to draw 24 | :param z: height in meters 25 | :return: 26 | """ 27 | for w in waypoints: 28 | t = w.transform 29 | begin = t.location + carla.Location(z=z) 30 | angle = math.radians(t.rotation.yaw) 31 | end = begin + carla.Location(x=math.cos(angle), y=math.sin(angle)) 32 | world.debug.draw_arrow(begin, end, arrow_size=0.3, life_time=1.0) 33 | 34 | 35 | def get_speed(vehicle): 36 | """ 37 | Compute speed of a vehicle in Km/h 38 | :param vehicle: the vehicle for which speed is calculated 39 | :return: speed as a float in Km/h 40 | """ 41 | vel = vehicle.get_velocity() 42 | return 3.6 * math.sqrt(vel.x ** 2 + vel.y ** 2 + vel.z ** 2) 43 | 44 | 45 | def compute_yaw_difference(yaw1, yaw2): 46 | u = np.array([ 47 | math.cos(math.radians(yaw1)), 48 | math.sin(math.radians(yaw1)), 49 | ]) 50 | 51 | v = np.array([ 52 | math.cos(math.radians(yaw2)), 53 | math.sin(math.radians(yaw2)), 54 | ]) 55 | 56 | 57 | angle = math.degrees(math.acos(np.clip(np.dot(u, v), -1, 1))) 58 | 59 | return angle 60 | 61 | 62 | def is_within_distance_ahead(target_location, current_location, orientation, max_distance, degree=60): 63 | """ 64 | Check if a target object is within a certain distance in front of a reference object. 65 | 66 | :param target_location: location of the target object 67 | :param current_location: location of the reference object 68 | :param orientation: orientation of the reference object 69 | :param max_distance: maximum allowed distance 70 | :return: True if target object is within max_distance ahead of the reference object 71 | """ 72 | u = np.array([ 73 | target_location.x - current_location.x, 74 | target_location.y - current_location.y]) 75 | distance = np.linalg.norm(u) 76 | 77 | if distance > max_distance: 78 | return False 79 | 80 | v = np.array([ 81 | math.cos(math.radians(orientation)), 82 | math.sin(math.radians(orientation))]) 83 | 84 | angle = math.degrees(math.acos(np.dot(u, v) / distance)) 85 | 86 | return angle < degree 87 | 88 | 89 | def compute_magnitude_angle(target_location, current_location, orientation): 90 | """ 91 | Compute relative angle and distance between a target_location and a current_location 92 | 93 | :param target_location: location of the target object 94 | :param current_location: location of the reference object 95 | :param orientation: orientation of the reference object 96 | :return: a tuple composed by the distance to the object and the angle between both objects 97 | """ 98 | target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y]) 99 | norm_target = np.linalg.norm(target_vector) 100 | 101 | forward_vector = np.array([math.cos(math.radians(orientation)), math.sin(math.radians(orientation))]) 102 | d_angle = math.degrees(math.acos(np.dot(forward_vector, target_vector) / norm_target)) 103 | 104 | return (norm_target, d_angle) 105 | 106 | 107 | def distance_vehicle(waypoint, vehicle_transform): 108 | loc = vehicle_transform.location 109 | dx = waypoint.transform.location.x - loc.x 110 | dy = waypoint.transform.location.y - loc.y 111 | 112 | return math.sqrt(dx * dx + dy * dy) 113 | 114 | def vector(location_1, location_2): 115 | """ 116 | Returns the unit vector from location_1 to location_2 117 | location_1, location_2 : carla.Location objects 118 | """ 119 | x = location_2.x - location_1.x 120 | y = location_2.y - location_1.y 121 | z = location_2.z - location_1.z 122 | norm = np.linalg.norm([x, y, z])
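# Caveat: norm is 0.0 if the two locations coincide, so the division below yields inf/nan; callers are expected to pass distinct locations.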
123 | 124 | return [x/norm, y/norm, z/norm] 125 | -------------------------------------------------------------------------------- /benchmark/__init__.py: -------------------------------------------------------------------------------- 1 | from .goal_suite import PointGoalSuite 2 | 3 | 4 | VERSION = '096' 5 | 6 | WEATHER_1 = [1, 3, 6, 8] 7 | WEATHER_2 = [4, 14] 8 | WEATHER_3 = [10, 14] 9 | WEATHER_4 = [1, 8, 14] 10 | 11 | _suites = dict() 12 | 13 | 14 | def _add(suite_name, *args, **kwargs): 15 | assert suite_name not in _suites, '%s is already registered!' % suite_name 16 | 17 | town = None 18 | 19 | if 'Town01' in suite_name: 20 | town = 'Town01' 21 | elif 'Town02' in suite_name: 22 | town = 'Town02' 23 | else: 24 | raise Exception('No town specified: %s.' % suite_name) 25 | 26 | benchmark = 'carla100' if 'NoCrash' in suite_name else 'corl2017' 27 | suite = None 28 | 29 | if 'Turn' in suite_name: 30 | suite = 'turn' 31 | elif 'Straight' in suite_name: 32 | suite = 'straight' 33 | elif 'Full' in suite_name: 34 | suite = 'full' 35 | elif 'NoCrash' in suite_name: 36 | suite = 'nocrash' 37 | else: 38 | raise Exception('No suite specified: %s.' % suite_name) 39 | 40 | kwargs['town'] = town 41 | kwargs['poses_txt'] = '%s/%s/%s_%s.txt' % (benchmark, VERSION, suite, town) 42 | kwargs['col_is_failure'] = 'NoCrash' in suite_name 43 | 44 | _suites[suite_name] = (args, kwargs) 45 | 46 | 47 | ## ============= Register Suites ============ ## 48 | # _add('DebugTown01-v0', DebugSuite, n_vehicles=10, viz_camera=True) 49 | # _add('FullTown01-v0', n_vehicles=0, viz_camera=True) 50 | # _add('FullTown02-v0', n_vehicles=0, viz_camera=True) 51 | 52 | # data collection town; no respawn to prevent missing frames 53 | _add('FullTown01-v0', n_vehicles=0, weathers=WEATHER_1, respawn_peds=False) 54 | # Train town, train weathers. 55 | _add('FullTown01-v1', n_vehicles=0, weathers=WEATHER_1) 56 | _add('StraightTown01-v1', n_vehicles=0, weathers=WEATHER_1) 57 | _add('TurnTown01-v1', n_vehicles=0, weathers=WEATHER_1) 58 | 59 | # Train town, test weathers. 60 | _add('FullTown01-v2', n_vehicles=0, weathers=WEATHER_2) 61 | _add('StraightTown01-v2', n_vehicles=0, weathers=WEATHER_2) 62 | _add('TurnTown01-v2', n_vehicles=0, weathers=WEATHER_2) 63 | 64 | # Train town, more vehicles 65 | _add('FullTown01-v3', n_vehicles=20, n_pedestrians=50, weathers=WEATHER_1) 66 | _add('FullTown01-v4', n_vehicles=20, n_pedestrians=50, weathers=WEATHER_2) 67 | # No ped versions 68 | _add('FullTown01-v3-np', n_vehicles=20, n_pedestrians=0, weathers=WEATHER_1) 69 | _add('FullTown01-v4-np', n_vehicles=20, n_pedestrians=0, weathers=WEATHER_2) 70 | 71 | # Test town, train weathers. 72 | _add('FullTown02-v1', n_vehicles=0, weathers=WEATHER_1) 73 | _add('StraightTown02-v1', n_vehicles=0, weathers=WEATHER_1) 74 | _add('TurnTown02-v1', n_vehicles=0, weathers=WEATHER_1) 75 | 76 | # Test town, test weathers. 77 | _add('FullTown02-v2', n_vehicles=0, weathers=WEATHER_2) 78 | _add('StraightTown02-v2', n_vehicles=0, weathers=WEATHER_2) 79 | _add('TurnTown02-v2', n_vehicles=0, weathers=WEATHER_2) 80 | 81 | # Test town, more vehicles. 
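# (Town02 suites use 15 vehicles instead of the 20 used in Town01, since the map is smaller.)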
82 | _add('FullTown02-v3', n_vehicles=15, n_pedestrians=50, weathers=WEATHER_1) 83 | _add('FullTown02-v4', n_vehicles=15, n_pedestrians=50, weathers=WEATHER_2) 84 | # No ped versions 85 | _add('FullTown02-v3-np', n_vehicles=15, n_pedestrians=0, weathers=WEATHER_1) 86 | _add('FullTown02-v4-np', n_vehicles=15, n_pedestrians=0, weathers=WEATHER_2) 87 | 88 | _add('NoCrashTown01-v1', n_vehicles=0, disable_two_wheels=True, weathers=WEATHER_1) 89 | _add('NoCrashTown01-v2', n_vehicles=0, disable_two_wheels=True, weathers=WEATHER_3) 90 | _add('NoCrashTown01-v3', n_vehicles=20, disable_two_wheels=True, n_pedestrians=50, weathers=WEATHER_1) 91 | _add('NoCrashTown01-v4', n_vehicles=20, disable_two_wheels=True, n_pedestrians=50, weathers=WEATHER_3) 92 | _add('NoCrashTown01-v5', n_vehicles=100, disable_two_wheels=True, n_pedestrians=250, weathers=WEATHER_1) 93 | _add('NoCrashTown01-v6', n_vehicles=100, disable_two_wheels=True, n_pedestrians=250, weathers=WEATHER_3) 94 | # No ped versions 95 | _add('NoCrashTown01-v3-np', n_vehicles=20, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_1) 96 | _add('NoCrashTown01-v4-np', n_vehicles=20, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_3) 97 | _add('NoCrashTown01-v5-np', n_vehicles=100, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_1) 98 | _add('NoCrashTown01-v6-np', n_vehicles=100, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_3) 99 | 100 | _add('NoCrashTown02-v1', n_vehicles=0, disable_two_wheels=True, weathers=WEATHER_1) 101 | _add('NoCrashTown02-v2', n_vehicles=0, disable_two_wheels=True, weathers=WEATHER_3) 102 | _add('NoCrashTown02-v3', n_vehicles=15, disable_two_wheels=True, n_pedestrians=50, weathers=WEATHER_1) 103 | _add('NoCrashTown02-v4', n_vehicles=15, disable_two_wheels=True, n_pedestrians=50, weathers=WEATHER_3) 104 | _add('NoCrashTown02-v5', n_vehicles=70, disable_two_wheels=True, n_pedestrians=150, weathers=WEATHER_1) 105 | _add('NoCrashTown02-v6', n_vehicles=70, disable_two_wheels=True, n_pedestrians=150, weathers=WEATHER_3) 106 | # No ped versions 107 | _add('NoCrashTown02-v3-np', n_vehicles=15, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_1) 108 | _add('NoCrashTown02-v4-np', n_vehicles=15, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_3) 109 | _add('NoCrashTown02-v5-np', n_vehicles=70, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_1) 110 | _add('NoCrashTown02-v6-np', n_vehicles=70, disable_two_wheels=True, n_pedestrians=0, weathers=WEATHER_3) 111 | 112 | # Demo 113 | _add('NoCrashTown01-v7', n_vehicles=100, n_pedestrians=250, weathers=WEATHER_1) 114 | _add('NoCrashTown01-v8', n_vehicles=100, n_pedestrians=250, weathers=WEATHER_2) 115 | _add('NoCrashTown02-v7', n_vehicles=70, n_pedestrians=150, weathers=WEATHER_1) 116 | _add('NoCrashTown02-v8', n_vehicles=70, n_pedestrians=150, weathers=WEATHER_2) 117 | 118 | 119 | # Weather primes. 
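# The suites below use WEATHER_4 = [1, 8, 14], mixing train weathers (1, 8) with a test weather (14); cf. WEATHER_1/WEATHER_2 above.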
120 | _add('FullTown01-v5', n_vehicles=0, weathers=WEATHER_4) 121 | _add('FullTown01-v6', n_vehicles=20, weathers=WEATHER_4) 122 | _add('StraightTown01-v3', n_vehicles=0, weathers=WEATHER_4) 123 | _add('TurnTown01-v3', n_vehicles=0, weathers=WEATHER_4) 124 | 125 | _add('FullTown02-v5', n_vehicles=0, weathers=WEATHER_4) 126 | _add('FullTown02-v6', n_vehicles=15, weathers=WEATHER_4) 127 | _add('StraightTown02-v3', n_vehicles=0, weathers=WEATHER_4) 128 | _add('TurnTown02-v3', n_vehicles=0, weathers=WEATHER_4) 129 | 130 | # Random 131 | _add('NoCrashTown01_noweather_empty', weathers=[1], n_vehicles=0) 132 | _add('NoCrashTown01_noweather_regular', weathers=[1], n_vehicles=20, n_pedestrians=50) 133 | _add('NoCrashTown01_noweather_dense', weathers=[1], n_vehicles=100, n_pedestrians=250) 134 | 135 | _add('NoCrashTown02_noweather_empty', weathers=[1], n_vehicles=0) 136 | _add('NoCrashTown02_noweather_regular', weathers=[1], n_vehicles=15, n_pedestrians=50) 137 | _add('NoCrashTown02_noweather_dense', weathers=[1], n_vehicles=70, n_pedestrians=200) 138 | 139 | _add('StraightTown01-noweather', n_vehicles=0, weathers=[1]) 140 | _add('TurnTown01-noweather', n_vehicles=0, weathers=[1]) 141 | _add('FullTown01-noweather-nav', n_vehicles=0, weathers=[1]) 142 | _add('FullTown01-noweather', n_vehicles=20, weathers=[1]) 143 | 144 | _add('StraightTown02-noweather', n_vehicles=0, weathers=[1]) 145 | _add('TurnTown02-noweather', n_vehicles=0, weathers=[1]) 146 | _add('FullTown02-noweather-nav', n_vehicles=0, weathers=[1]) 147 | _add('FullTown02-noweather', n_vehicles=15, weathers=[1]) 148 | 149 | 150 | _aliases = { 151 | 'town1': [ 152 | 'FullTown01-v1', 'FullTown01-v2', 'FullTown01-v3', 'FullTown01-v4', 153 | 'StraightTown01-v1', 'StraightTown01-v2', 154 | 'TurnTown01-v1', 'TurnTown01-v2'], 155 | 'town2': [ 156 | 'FullTown02-v1', 'FullTown02-v2', 'FullTown02-v3', 'FullTown02-v4', 157 | 'StraightTown02-v1', 'StraightTown02-v2', 158 | 'TurnTown02-v1', 'TurnTown02-v2'], 159 | 'town1p': [ 160 | 'FullTown01-v5', 'FullTown01-v6', 161 | 'StraightTown01-v3', 'TurnTown01-v3', 162 | 'FullTown01-v5', 'FullTown01-v6', 163 | ], 164 | 'town2p': [ 165 | 'FullTown02-v5', 'FullTown02-v6', 166 | 'StraightTown02-v3', 'TurnTown02-v3', 167 | 'FullTown02-v5', 'FullTown02-v6', 168 | ], 169 | 'ntown1p': [ 170 | 'NoCrashTown01-v7', 'NoCrashTown01-v8', 'NoCrashTown01-v9', 171 | ], 172 | 173 | 'ntown2p': [ 174 | 'NoCrashTown02-v7', 'NoCrashTown02-v8', 'NoCrashTown02-v9', 175 | ], 176 | 'empty': [ 177 | 'NoCrashTown01-v1', 'NoCrashTown01-v2', 178 | 'NoCrashTown02-v1', 'NoCrashTown02-v2', 179 | ], 180 | 'regular': [ 181 | 'NoCrashTown01-v3', 'NoCrashTown01-v4', 182 | 'NoCrashTown02-v3', 'NoCrashTown02-v4', 183 | ], 184 | 'regular-np': [ 185 | 'NoCrashTown01-v3-np', 'NoCrashTown01-v4-np', 186 | 'NoCrashTown02-v3-np', 'NoCrashTown02-v4-np', 187 | ], 188 | 'dense': [ 189 | 'NoCrashTown01-v5', 'NoCrashTown01-v6', 190 | 'NoCrashTown02-v5', 'NoCrashTown02-v6', 191 | ], 192 | 'dense-np': [ 193 | 'NoCrashTown01-v5-np', 'NoCrashTown01-v6-np', 194 | 'NoCrashTown02-v5-np', 'NoCrashTown02-v6-np', 195 | ] 196 | } 197 | 198 | _aliases['all'] = _aliases['town1'] + _aliases['town2'] 199 | 200 | ALL_SUITES = list(_suites.keys()) + list(_aliases.keys()) 201 | 202 | 203 | def make_suite(suite_name, port=2000, big_cam=False, planner='new', client=None): 204 | assert suite_name in _suites, '%s is not registered!'%suite_name 205 | 206 | args, kwargs = _suites[suite_name] 207 | kwargs['port'] = port 208 | kwargs['big_cam'] = big_cam 209 | kwargs['planner'] = 
planner 210 | kwargs['client'] = client 211 | 212 | return PointGoalSuite(*args, **kwargs) 213 | 214 | 215 | def get_suites(suite_name): 216 | if suite_name.lower() in _aliases: 217 | return _aliases[suite_name] 218 | 219 | return [suite_name] 220 | -------------------------------------------------------------------------------- /benchmark/base_suite.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | from bird_view.utils import carla_utils as cu 4 | 5 | 6 | class BaseSuite(cu.CarlaWrapper): 7 | def __init__(self, weathers=[0], n_vehicles=0, n_pedestrians=0, disable_two_wheels=False, **kwargs): 8 | super().__init__(**kwargs) 9 | 10 | self._weathers = weathers 11 | self.n_vehicles = n_vehicles 12 | self.n_pedestrians = n_pedestrians 13 | self.disable_two_wheels = disable_two_wheels 14 | 15 | def get_spawn_point(self, pose_num): 16 | return self._spawn_points[pose_num] 17 | 18 | def is_failure(self): 19 | raise NotImplementedError 20 | 21 | def is_success(self): 22 | raise NotImplementedError 23 | 24 | @property 25 | def pose_tasks(self): 26 | raise NotImplementedError 27 | 28 | @property 29 | def weathers(self): 30 | return self._weathers 31 | 32 | @property 33 | def all_tasks(self): 34 | for (start, target), weather in product(self.pose_tasks, self.weathers): 35 | run_name = 's%d_t%d_w%d' % (start, target, weather) 36 | 37 | yield weather, (start, target), run_name 38 | -------------------------------------------------------------------------------- /benchmark/carla100/084/nocrash_Town01.txt: -------------------------------------------------------------------------------- 1 | 105 29 2 | 27 130 3 | 102 87 4 | 132 27 5 | 25 44 6 | 4 64 7 | 34 67 8 | 54 30 9 | 140 134 10 | 105 9 11 | 148 129 12 | 65 18 13 | 21 16 14 | 147 97 15 | 134 49 16 | 30 41 17 | 81 89 18 | 69 45 19 | 102 95 20 | 18 145 21 | 111 64 22 | 79 45 23 | 84 69 24 | 73 31 25 | 37 81 26 | -------------------------------------------------------------------------------- /benchmark/carla100/084/nocrash_Town02.txt: -------------------------------------------------------------------------------- 1 | 19 66 2 | 79 14 3 | 19 57 4 | 39 53 5 | 60 26 6 | 53 76 7 | 42 13 8 | 31 71 9 | 59 35 10 | 47 16 11 | 10 61 12 | 66 3 13 | 20 79 14 | 14 56 15 | 26 69 16 | 79 19 17 | 2 29 18 | 16 14 19 | 5 57 20 | 77 68 21 | 70 73 22 | 46 67 23 | 34 77 24 | 61 49 25 | 21 12 26 | -------------------------------------------------------------------------------- /benchmark/carla100/096/nocrash_Town01.txt: -------------------------------------------------------------------------------- 1 | 79 227 2 | 105 21 3 | 129 88 4 | 19 105 5 | 231 212 6 | 252 192 7 | 222 120 8 | 202 226 9 | 11 17 10 | 79 247 11 | 3 177 12 | 191 114 13 | 235 240 14 | 4 54 15 | 17 207 16 | 223 212 17 | 154 66 18 | 187 123 19 | 129 56 20 | 114 6 21 | 40 192 22 | 176 123 23 | 121 187 24 | 238 225 25 | 219 154 26 | -------------------------------------------------------------------------------- /benchmark/carla100/096/nocrash_Town02.txt: -------------------------------------------------------------------------------- 1 | 66 19 2 | 6 71 3 | 66 28 4 | 46 32 5 | 25 59 6 | 32 9 7 | 43 72 8 | 54 14 9 | 26 50 10 | 38 69 11 | 75 24 12 | 19 82 13 | 65 6 14 | 71 29 15 | 59 16 16 | 6 66 17 | 83 56 18 | 69 71 19 | 82 28 20 | 8 17 21 | 19 12 22 | 39 18 23 | 51 8 24 | 24 36 25 | 64 73 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/084/full_Town01.txt: 
-------------------------------------------------------------------------------- 1 | 105 29 2 | 27 130 3 | 102 87 4 | 132 27 5 | 24 44 6 | 96 26 7 | 34 67 8 | 28 1 9 | 140 134 10 | 105 9 11 | 148 129 12 | 65 18 13 | 21 16 14 | 147 97 15 | 42 51 16 | 30 41 17 | 18 107 18 | 69 45 19 | 102 95 20 | 18 145 21 | 111 64 22 | 79 45 23 | 84 69 24 | 73 31 25 | 37 81 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/084/full_Town02.txt: -------------------------------------------------------------------------------- 1 | 19 66 2 | 79 14 3 | 19 57 4 | 23 1 5 | 53 76 6 | 42 13 7 | 31 71 8 | 33 5 9 | 54 30 10 | 10 61 11 | 66 3 12 | 27 12 13 | 79 19 14 | 2 29 15 | 16 14 16 | 5 57 17 | 70 73 18 | 46 67 19 | 57 50 20 | 61 49 21 | 21 12 22 | 51 81 23 | 77 68 24 | 56 65 25 | 43 54 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/084/straight_Town01.txt: -------------------------------------------------------------------------------- 1 | 36 40 2 | 39 35 3 | 110 114 4 | 7 3 5 | 0 4 6 | 68 50 7 | 61 59 8 | 47 64 9 | 147 90 10 | 33 87 11 | 26 19 12 | 80 76 13 | 45 49 14 | 55 44 15 | 29 107 16 | 95 104 17 | 84 34 18 | 53 67 19 | 22 17 20 | 91 148 21 | 20 107 22 | 78 70 23 | 95 102 24 | 68 44 25 | 45 69 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/084/straight_Town02.txt: -------------------------------------------------------------------------------- 1 | 38 34 2 | 4 2 3 | 12 10 4 | 62 55 5 | 43 47 6 | 64 66 7 | 78 76 8 | 59 57 9 | 61 18 10 | 35 39 11 | 12 8 12 | 0 18 13 | 75 68 14 | 54 60 15 | 45 49 16 | 46 42 17 | 53 46 18 | 80 29 19 | 65 63 20 | 0 81 21 | 54 63 22 | 51 42 23 | 16 19 24 | 17 26 25 | 77 68 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/084/turn_Town01.txt: -------------------------------------------------------------------------------- 1 | 138 17 2 | 47 16 3 | 26 9 4 | 42 49 5 | 140 124 6 | 85 98 7 | 65 133 8 | 137 51 9 | 76 66 10 | 46 39 11 | 40 60 12 | 0 29 13 | 4 129 14 | 121 140 15 | 2 129 16 | 78 44 17 | 68 85 18 | 41 102 19 | 95 70 20 | 68 129 21 | 84 69 22 | 47 79 23 | 110 15 24 | 130 17 25 | 0 17 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/084/turn_Town02.txt: -------------------------------------------------------------------------------- 1 | 37 76 2 | 8 24 3 | 60 69 4 | 38 10 5 | 21 1 6 | 58 71 7 | 74 32 8 | 44 0 9 | 71 16 10 | 14 24 11 | 34 11 12 | 43 14 13 | 75 16 14 | 80 21 15 | 3 23 16 | 75 59 17 | 50 47 18 | 11 19 19 | 77 34 20 | 79 25 21 | 40 63 22 | 58 76 23 | 79 55 24 | 16 61 25 | 27 11 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/096/full_Town01.txt: -------------------------------------------------------------------------------- 1 | 79 227 2 | 105 21 3 | 129 88 4 | 19 105 5 | 104 212 6 | 84 230 7 | 222 120 8 | 228 255 9 | 11 17 10 | 79 247 11 | 3 177 12 | 191 240 13 | 235 240 14 | 4 54 15 | 214 205 16 | 56 215 17 | 114 44 18 | 187 123 19 | 129 56 20 | 114 6 21 | 40 192 22 | 176 123 23 | 121 187 24 | 238 225 25 | 219 154 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/096/full_Town02.txt: -------------------------------------------------------------------------------- 1 | 66 19 2 | 6 71 3 | 66 28 4 | 62 89 5 | 32 9 6 | 43 72 7 | 54 14 8 | 52 82 9 | 31 55 10 | 75 24 11 | 19 
82 12 | 58 73 13 | 6 66 14 | 83 56 15 | 69 71 16 | 82 28 17 | 19 12 18 | 39 18 19 | 28 35 20 | 24 36 21 | 64 73 22 | 34 4 23 | 8 17 24 | 29 20 25 | 42 31 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/096/straight_Town01.txt: -------------------------------------------------------------------------------- 1 | 220 216 2 | 217 72 3 | 41 37 4 | 249 253 5 | 256 252 6 | 188 206 7 | 195 197 8 | 209 192 9 | 4 61 10 | 223 88 11 | 230 237 12 | 165 180 13 | 123 207 14 | 201 212 15 | 227 44 16 | 56 47 17 | 121 222 18 | 203 120 19 | 234 239 20 | 60 3 21 | 103 44 22 | 178 186 23 | 56 129 24 | 188 212 25 | 123 187 -------------------------------------------------------------------------------- /benchmark/corl2017/096/straight_Town02.txt: -------------------------------------------------------------------------------- 1 | 47 51 2 | 81 83 3 | 73 75 4 | 23 30 5 | 42 38 6 | 21 19 7 | 7 9 8 | 26 28 9 | 24 67 10 | 50 46 11 | 73 77 12 | 100 67 13 | 10 17 14 | 31 25 15 | 40 36 16 | 39 43 17 | 32 39 18 | 5 56 19 | 20 22 20 | 100 4 21 | 31 22 22 | 34 43 23 | 69 66 24 | 68 59 25 | 8 17 -------------------------------------------------------------------------------- /benchmark/corl2017/096/turn_Town01.txt: -------------------------------------------------------------------------------- 1 | 13 239 2 | 209 240 3 | 230 247 4 | 214 207 5 | 11 27 6 | 110 53 7 | 191 18 8 | 14 205 9 | 180 190 10 | 210 217 11 | 216 196 12 | 256 227 13 | 252 177 14 | 30 11 15 | 254 177 16 | 178 212 17 | 188 110 18 | 215 129 19 | 56 186 20 | 188 177 21 | 121 187 22 | 209 176 23 | 41 234 24 | 21 239 25 | 256 239 26 | -------------------------------------------------------------------------------- /benchmark/corl2017/096/turn_Town02.txt: -------------------------------------------------------------------------------- 1 | 48 9 2 | 77 61 3 | 25 16 4 | 47 75 5 | 64 89 6 | 27 14 7 | 11 53 8 | 41 100 9 | 14 69 10 | 71 61 11 | 51 74 12 | 42 71 13 | 10 69 14 | 5 64 15 | 82 62 16 | 10 26 17 | 35 38 18 | 74 66 19 | 8 51 20 | 6 60 21 | 45 22 22 | 27 9 23 | 6 30 24 | 69 24 25 | 58 74 26 | -------------------------------------------------------------------------------- /benchmark/goal_suite.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import queue 4 | 5 | import numpy as np 6 | 7 | import carla 8 | 9 | from agents.navigation.local_planner import RoadOption, LocalPlannerNew, LocalPlannerOld 10 | 11 | from .base_suite import BaseSuite 12 | 13 | 14 | def from_file(poses_txt): 15 | pairs_file = Path(__file__).parent / poses_txt 16 | pairs = pairs_file.read_text().strip().split('\n') 17 | pairs = [(int(x[0]), int(x[1])) for x in map(lambda y: y.split(), pairs)] 18 | 19 | return pairs 20 | 21 | 22 | class PointGoalSuite(BaseSuite): 23 | def __init__( 24 | self, success_dist=5.0, col_is_failure=False, 25 | viz_camera=False, planner='new', poses_txt='', **kwargs): 26 | super().__init__(**kwargs) 27 | 28 | self.success_dist = success_dist 29 | self.col_is_failure = col_is_failure 30 | self.planner = planner 31 | self.poses = from_file(poses_txt) 32 | 33 | self.command = RoadOption.LANEFOLLOW 34 | 35 | self.timestamp_active = 0 36 | self._timeout = float('inf') 37 | 38 | self.viz_camera = viz_camera 39 | self._viz_queue = None 40 | 41 | def init(self, target=1, **kwargs): 42 | self._target_pose = self._map.get_spawn_points()[target] 43 | 44 | super().init(**kwargs) 45 | 46 | def ready(self): 47 | # print (self.planner) 48 | if 
self.planner == 'new': 49 | self._local_planner = LocalPlannerNew(self._player, 2.5, 9.0, 1.5) 50 | else: 51 | self._local_planner = LocalPlannerOld(self._player) 52 | 53 | self._local_planner.set_route(self._start_pose.location, self._target_pose.location) 54 | self._timeout = self._local_planner.calculate_timeout() 55 | 56 | return super().ready() 57 | 58 | def tick(self): 59 | result = super().tick() 60 | 61 | self._local_planner.run_step() 62 | self.command = self._local_planner.checkpoint[1] 63 | self.node = self._local_planner.checkpoint[0].transform.location 64 | self._next = self._local_planner.target[0].transform.location 65 | 66 | return result 67 | 68 | def get_observations(self): 69 | result = dict() 70 | result.update(super().get_observations()) 71 | result['command'] = int(self.command) 72 | result['node'] = np.array([self.node.x, self.node.y]) 73 | result['next'] = np.array([self._next.x, self._next.y]) 74 | 75 | return result 76 | 77 | @property 78 | def weathers(self): 79 | return self._weathers 80 | 81 | @property 82 | def pose_tasks(self): 83 | return self.poses 84 | 85 | def clean_up(self): 86 | super().clean_up() 87 | 88 | self.timestamp_active = 0 89 | self._timeout = float('inf') 90 | self._local_planner = None 91 | 92 | # Clean-up cameras 93 | if self._viz_queue: 94 | with self._viz_queue.mutex: 95 | self._viz_queue.queue.clear() 96 | 97 | def is_failure(self): 98 | if self.timestamp_active >= self._timeout or self._tick >= 10000: 99 | return True 100 | elif self.col_is_failure and self.collided: 101 | return True 102 | 103 | return False 104 | 105 | def is_success(self): 106 | location = self._player.get_location() 107 | distance = location.distance(self._target_pose.location) 108 | 109 | return distance <= self.success_dist 110 | 111 | def apply_control(self, control): 112 | result = super().apply_control(control) 113 | 114 | # is_light_red = self._is_light_red(agent) 115 | # 116 | # if is_light_red: 117 | # self.timestamp_active -= 1 118 | 119 | self.timestamp_active += 1 120 | 121 | # Return diagnostics 122 | location = self._player.get_location() 123 | orientation = self._player.get_transform().get_forward_vector() 124 | velocity = self._player.get_velocity() 125 | speed = np.linalg.norm([velocity.x, velocity.y, velocity.z]) 126 | 127 | info = { 128 | 'x': location.x, 129 | 'y': location.y, 130 | 'z': location.z, 131 | 'ori_x': orientation.x, 132 | 'ori_y': orientation.y, 133 | 'speed': speed, 134 | 'collided': self.collided, 135 | 'invaded': self.invaded, 136 | 'distance_to_goal': self._local_planner.distance_to_goal, 137 | 'viz_img': self._viz_queue.get() if self.viz_camera else None 138 | } 139 | 140 | info.update(result) 141 | 142 | return info 143 | 144 | def _is_light_red(self, agent): 145 | lights_list = self._world.get_actors().filter('*traffic_light*') 146 | is_light_red, _ = agent._is_light_red(lights_list) 147 | 148 | return is_light_red 149 | 150 | def _setup_sensors(self): 151 | super()._setup_sensors() 152 | 153 | if self.viz_camera: 154 | viz_camera_bp = self._blueprints.find('sensor.camera.rgb') 155 | # Blueprint attributes must be set before the actor is spawned. 156 | viz_camera_bp.set_attribute('image_size_x', '640') 157 | viz_camera_bp.set_attribute('image_size_y', '480') 158 | viz_camera = self._world.spawn_actor( 159 | viz_camera_bp, 160 | carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)), 161 | attach_to=self._player) 162 | # Set camera queues 163 | self._viz_queue = queue.Queue() 164 | viz_camera.listen(self._viz_queue.put) 165 | 166 |
self._actor_dict['sensor'].append(viz_camera) 167 | 168 | def render_world(self): 169 | import matplotlib.pyplot as plt 170 | 171 | from matplotlib.patches import Circle 172 | 173 | plt.clf() 174 | plt.tight_layout() 175 | plt.axis('off') 176 | 177 | fig, ax = plt.subplots(1, 1) 178 | ax.get_xaxis().set_visible(False) 179 | ax.get_yaxis().set_visible(False) 180 | 181 | world = super().render_world() 182 | world[np.all(world == (255, 255, 255), axis=-1)] = [100, 100, 100] 183 | world[np.all(world == (0, 0, 0), axis=-1)] = [255, 255, 255] 184 | 185 | ax.imshow(world) 186 | 187 | prev_command = -1 188 | 189 | for i, (node, command) in enumerate(self._local_planner._route): 190 | command = int(command) 191 | pixel_x, pixel_y = self.world_to_pixel(node.transform.location) 192 | 193 | if command != prev_command and prev_command != -1: 194 | _command = {1: 'L', 2: 'R', 3: 'S', 4: 'F'}.get(command, '???') 195 | ax.text(pixel_x, pixel_y, _command, fontsize=8, color='black') 196 | ax.add_patch(Circle((pixel_x, pixel_y), 5, color='black')) 197 | elif i == 0 or i == len(self._local_planner._route)-1: 198 | text = 'start' if i == 0 else 'end' 199 | ax.text(pixel_x, pixel_y, text, fontsize=8, color='blue') 200 | ax.add_patch(Circle((pixel_x, pixel_y), 5, color='blue')) 201 | elif i % (len(self._local_planner._route) // 10) == 0: 202 | ax.add_patch(Circle((pixel_x, pixel_y), 3, color='red')) 203 | 204 | prev_command = int(command) 205 | 206 | fig.canvas.draw() 207 | 208 | w, h = fig.canvas.get_width_height() 209 | 210 | return np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(h, w, 3) 211 | -------------------------------------------------------------------------------- /benchmark/run_benchmark.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import tqdm 6 | import time 7 | 8 | import bird_view.utils.bz_utils as bzu 9 | import bird_view.utils.carla_utils as cu 10 | 11 | from bird_view.models.common import crop_birdview 12 | 13 | 14 | def _paint(observations, control, diagnostic, debug, env, show=False): 15 | import cv2 16 | import numpy as np 17 | 18 | 19 | WHITE = (255, 255, 255) 20 | RED = (255, 0, 0) 21 | CROP_SIZE = 192 22 | X = 176 23 | Y = 192 // 2 24 | R = 2 25 | 26 | birdview = cu.visualize_birdview(observations['birdview']) 27 | birdview = crop_birdview(birdview) 28 | 29 | if 'big_cam' in observations: 30 | canvas = np.uint8(observations['big_cam']).copy() 31 | rgb = np.uint8(observations['rgb']).copy() 32 | else: 33 | canvas = np.uint8(observations['rgb']).copy() 34 | 35 | def _stick_together(a, b, axis=1): 36 | 37 | if axis == 1: 38 | h = min(a.shape[0], b.shape[0]) 39 | 40 | r1 = h / a.shape[0] 41 | r2 = h / b.shape[0] 42 | 43 | a = cv2.resize(a, (int(r1 * a.shape[1]), int(r1 * a.shape[0]))) 44 | b = cv2.resize(b, (int(r2 * b.shape[1]), int(r2 * b.shape[0]))) 45 | 46 | return np.concatenate([a, b], 1) 47 | 48 | else: 49 | h = min(a.shape[1], b.shape[1]) 50 | 51 | r1 = h / a.shape[1] 52 | r2 = h / b.shape[1] 53 | 54 | a = cv2.resize(a, (int(r1 * a.shape[1]), int(r1 * a.shape[0]))) 55 | b = cv2.resize(b, (int(r2 * b.shape[1]), int(r2 * b.shape[0]))) 56 | 57 | return np.concatenate([a, b], 0) 58 | 59 | def _write(text, i, j, canvas=canvas, fontsize=0.4): 60 | rows = [x * (canvas.shape[0] // 10) for x in range(10+1)] 61 | cols = [x * (canvas.shape[1] // 9) for x in range(9+1)] 62 | cv2.putText( 63 | canvas, text, (cols[j], rows[i]), 64 | cv2.FONT_HERSHEY_SIMPLEX, 
fontsize, WHITE, 1) 65 | 66 | _command = { 67 | 1: 'LEFT', 68 | 2: 'RIGHT', 69 | 3: 'STRAIGHT', 70 | 4: 'FOLLOW', 71 | }.get(observations['command'], '???') 72 | 73 | if 'big_cam' in observations: 74 | fontsize = 0.8 75 | else: 76 | fontsize = 0.4 77 | 78 | _write('Command: ' + _command, 1, 0, fontsize=fontsize) 79 | _write('Velocity: %.1f' % np.linalg.norm(observations['velocity']), 2, 0, fontsize=fontsize) 80 | 81 | _write('Steer: %.2f' % control.steer, 4, 0, fontsize=fontsize) 82 | _write('Throttle: %.2f' % control.throttle, 5, 0, fontsize=fontsize) 83 | _write('Brake: %.1f' % control.brake, 6, 0, fontsize=fontsize) 84 | 85 | _write('Collided: %s' % diagnostic['collided'], 1, 6, fontsize=fontsize) 86 | _write('Invaded: %s' % diagnostic['invaded'], 2, 6, fontsize=fontsize) 87 | _write('Lights Ran: %d/%d' % (env.traffic_tracker.total_lights_ran, env.traffic_tracker.total_lights), 3, 6, fontsize=fontsize) 88 | _write('Goal: %.1f' % diagnostic['distance_to_goal'], 4, 6, fontsize=fontsize) 89 | 90 | _write('Time: %d' % env._tick, 5, 6, fontsize=fontsize) 91 | _write('FPS: %.2f' % (env._tick / (diagnostic['wall'])), 6, 6, fontsize=fontsize) 92 | 93 | for x, y in debug.get('locations', []): 94 | x = int(X - x / 2.0 * CROP_SIZE) 95 | y = int(Y + y / 2.0 * CROP_SIZE) 96 | 97 | S = R // 2 98 | birdview[x-S:x+S+1,y-S:y+S+1] = RED 99 | 100 | for x, y in debug.get('locations_world', []): 101 | x = int(X - x * 4) 102 | y = int(Y + y * 4) 103 | 104 | S = R // 2 105 | birdview[x-S:x+S+1,y-S:y+S+1] = RED 106 | 107 | for x, y in debug.get('locations_birdview', []): 108 | S = R // 2 109 | birdview[x-S:x+S+1,y-S:y+S+1] = RED 110 | 111 | for x, y in debug.get('locations_pixel', []): 112 | S = R // 2 113 | if 'big_cam' in observations: 114 | rgb[y-S:y+S+1,x-S:x+S+1] = RED 115 | else: 116 | canvas[y-S:y+S+1,x-S:x+S+1] = RED 117 | 118 | for x, y in debug.get('curve', []): 119 | x = int(X - x * 4) 120 | y = int(Y + y * 4) 121 | 122 | try: 123 | birdview[x,y] = [155, 0, 155] 124 | except: 125 | pass 126 | 127 | if 'target' in debug: 128 | x, y = debug['target'][:2] 129 | x = int(X - x * 4) 130 | y = int(Y + y * 4) 131 | birdview[x-R:x+R+1,y-R:y+R+1] = [0, 155, 155] 132 | 133 | ox, oy = observations['orientation'] 134 | rot = np.array([ 135 | [ox, oy], 136 | [-oy, ox]]) 137 | u = observations['node'] - observations['position'][:2] 138 | v = observations['next'] - observations['position'][:2] 139 | u = rot.dot(u) 140 | x, y = u 141 | x = int(X - x * 4) 142 | y = int(Y + y * 4) 143 | v = rot.dot(v) 144 | x, y = v 145 | x = int(X - x * 4) 146 | y = int(Y + y * 4) 147 | 148 | if 'big_cam' in observations: 149 | _write('Network input/output', 1, 0, canvas=rgb) 150 | _write('Projected output', 1, 0, canvas=birdview) 151 | full = _stick_together(rgb, birdview) 152 | else: 153 | full = _stick_together(canvas, birdview) 154 | 155 | if 'image' in debug: 156 | full = _stick_together(full, cu.visualize_predicted_birdview(debug['image'], 0.01)) 157 | 158 | if 'big_cam' in observations: 159 | full = _stick_together(canvas, full, axis=0) 160 | 161 | if show: 162 | bzu.show_image('canvas', full) 163 | bzu.add_to_video(full) 164 | 165 | 166 | def run_single(env, weather, start, target, agent_maker, seed, autopilot, show=False): 167 | # HACK: deterministic vehicle spawns. 
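    # Seeding the env before init() keeps the NPC spawn draw reproducible,
    # so the same (weather, start, target, seed) task can be replayed.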
168 | env.seed = seed 169 | env.init(start=start, target=target, weather=cu.PRESET_WEATHERS[weather]) 170 | 171 | if not autopilot: 172 | agent = agent_maker() 173 | else: 174 | agent = agent_maker(env._player, resolution=1, threshold=7.5) 175 | agent.set_route(env._start_pose.location, env._target_pose.location) 176 | 177 | diagnostics = list() 178 | result = { 179 | 'weather': weather, 180 | 'start': start, 'target': target, 181 | 'success': None, 't': None, 182 | 'total_lights_ran': None, 183 | 'total_lights': None, 184 | 'collided': None, 185 | } 186 | 187 | while env.tick(): 188 | observations = env.get_observations() 189 | control = agent.run_step(observations) 190 | diagnostic = env.apply_control(control) 191 | 192 | _paint(observations, control, diagnostic, agent.debug, env, show=show) 193 | 194 | diagnostic.pop('viz_img') 195 | diagnostics.append(diagnostic) 196 | 197 | if env.is_failure() or env.is_success(): 198 | result['success'] = env.is_success() 199 | result['total_lights_ran'] = env.traffic_tracker.total_lights_ran 200 | result['total_lights'] = env.traffic_tracker.total_lights 201 | result['collided'] = env.collided 202 | result['t'] = env._tick 203 | break 204 | 205 | return result, diagnostics 206 | 207 | 208 | def run_benchmark(agent_maker, env, benchmark_dir, seed, autopilot, resume, max_run=5, show=False): 209 | """ 210 | benchmark_dir must be an instance of pathlib.Path 211 | """ 212 | summary_csv = benchmark_dir / 'summary.csv' 213 | diagnostics_dir = benchmark_dir / 'diagnostics' 214 | diagnostics_dir.mkdir(parents=True, exist_ok=True) 215 | 216 | summary = list() 217 | total = len(list(env.all_tasks)) 218 | 219 | if summary_csv.exists() and resume: 220 | summary = pd.read_csv(summary_csv) 221 | else: 222 | summary = pd.DataFrame() 223 | 224 | num_run = 0 225 | 226 | for weather, (start, target), run_name in tqdm.tqdm(env.all_tasks, total=total): 227 | if resume and len(summary) > 0 and ((summary['start'] == start) \ 228 | & (summary['target'] == target) \ 229 | & (summary['weather'] == weather)).any(): 230 | print (weather, start, target) 231 | continue 232 | 233 | 234 | diagnostics_csv = str(diagnostics_dir / ('%s.csv' % run_name)) 235 | 236 | bzu.init_video(save_dir=str(benchmark_dir / 'videos'), save_path=run_name) 237 | 238 | result, diagnostics = run_single(env, weather, start, target, agent_maker, seed, autopilot, show=show) 239 | 240 | summary = summary.append(result, ignore_index=True) 241 | 242 | # Do this every timestep just in case. 243 | pd.DataFrame(summary).to_csv(summary_csv, index=False) 244 | pd.DataFrame(diagnostics).to_csv(diagnostics_csv, index=False) 245 | 246 | num_run += 1 247 | 248 | if num_run >= max_run: 249 | break 250 | -------------------------------------------------------------------------------- /benchmark_agent.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | from pathlib import Path 5 | 6 | from benchmark import make_suite, get_suites, ALL_SUITES 7 | from benchmark.run_benchmark import run_benchmark 8 | 9 | import bird_view.utils.bz_utils as bzu 10 | 11 | 12 | def _agent_factory_hack(model_path, config, autopilot): 13 | """ 14 | These imports before carla.Client() cause seg faults... 
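    Deferring the torch and model imports until after the client exists
    avoids that, which is why this factory returns a constructor
    (a class or a lambda) rather than a built agent.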
15 | """ 16 | from bird_view.models.roaming import RoamingAgentMine 17 | 18 | if autopilot: 19 | return RoamingAgentMine 20 | 21 | import torch 22 | 23 | from bird_view.models import baseline 24 | from bird_view.models import birdview 25 | from bird_view.models import image 26 | 27 | model_args = config['model_args'] 28 | model_name = model_args['model'] 29 | model_to_class = { 30 | 'birdview_dian': (birdview.BirdViewPolicyModelSS, birdview.BirdViewAgent), 31 | 'image_ss': (image.ImagePolicyModelSS, image.ImageAgent), 32 | } 33 | 34 | model_class, agent_class = model_to_class[model_name] 35 | 36 | model = model_class(**config['model_args']) 37 | model.load_state_dict(torch.load(str(model_path))) 38 | model.eval() 39 | 40 | agent_args = config.get('agent_args', dict()) 41 | agent_args['model'] = model 42 | 43 | return lambda: agent_class(**agent_args) 44 | 45 | 46 | def run(model_path, port, suite, big_cam, seed, autopilot, resume, max_run=10, show=False): 47 | log_dir = model_path.parent 48 | config = bzu.load_json(str(log_dir / 'config.json')) 49 | 50 | total_time = 0.0 51 | 52 | for suite_name in get_suites(suite): 53 | tick = time.time() 54 | 55 | benchmark_dir = log_dir / 'benchmark' / model_path.stem / ('%s_seed%d' % (suite_name, seed)) 56 | benchmark_dir.mkdir(parents=True, exist_ok=True) 57 | 58 | with make_suite(suite_name, port=port, big_cam=big_cam) as env: 59 | agent_maker = _agent_factory_hack(model_path, config, autopilot) 60 | 61 | run_benchmark(agent_maker, env, benchmark_dir, seed, autopilot, resume, max_run=max_run, show=show) 62 | 63 | elapsed = time.time() - tick 64 | total_time += elapsed 65 | 66 | print('%s: %.3f hours.' % (suite_name, elapsed / 3600)) 67 | 68 | print('Total time: %.3f hours.' % (total_time / 3600)) 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--model-path', required=True) 74 | parser.add_argument('--port', type=int, default=2000) 75 | parser.add_argument('--suite', choices=ALL_SUITES, default='town1') 76 | parser.add_argument('--big_cam', action='store_true') 77 | parser.add_argument('--seed', type=int, default=2019) 78 | parser.add_argument('--autopilot', action='store_true', default=False) 79 | parser.add_argument('--show', action='store_true', default=False) 80 | parser.add_argument('--resume', action='store_true') 81 | parser.add_argument('--max-run', type=int, default=3) 82 | 83 | args = parser.parse_args() 84 | 85 | run(Path(args.model_path), args.port, args.suite, args.big_cam, args.seed, args.autopilot, args.resume, max_run=args.max_run, show=args.show) 86 | -------------------------------------------------------------------------------- /bird_view/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/bird_view/models/__init__.py -------------------------------------------------------------------------------- /bird_view/models/agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as transforms 4 | 5 | import carla 6 | 7 | 8 | class Agent(object): 9 | def __init__(self, model=None, **kwargs): 10 | assert model is not None 11 | 12 | if len(kwargs) > 0: 13 | print('Unused kwargs: %s' % kwargs) 14 | 15 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | self.transform = transforms.ToTensor() 17 | 18 | 
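        # Rows of the 4x4 identity act as one-hot encodings of the high-level
        # commands (1=LEFT, 2=RIGHT, 3=STRAIGHT, 4=FOLLOW); subclasses index
        # this with command - 1.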
self.one_hot = torch.FloatTensor(torch.eye(4)) 19 | 20 | self.model = model.to(self.device) 21 | self.model.eval() 22 | 23 | self.debug = dict() 24 | 25 | def postprocess(self, steer, throttle, brake): 26 | control = carla.VehicleControl() 27 | control.steer = np.clip(steer, -1.0, 1.0) 28 | control.throttle = np.clip(throttle, 0.0, 1.0) 29 | control.brake = np.clip(brake, 0.0, 1.0) 30 | control.manual_gear_shift = False 31 | 32 | return control 33 | -------------------------------------------------------------------------------- /bird_view/models/baseline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torchvision import transforms 6 | 7 | import carla 8 | 9 | from .resnet import get_resnet 10 | from .common import select_branch, Normalize 11 | from .agent import Agent 12 | 13 | 14 | def BaselineBranch(p): 15 | return nn.Sequential( 16 | nn.Linear(512, 256), 17 | nn.ReLU(True), 18 | nn.Dropout(p), 19 | 20 | nn.Linear(256, 256), 21 | nn.ReLU(True), 22 | nn.Dropout(p), 23 | 24 | nn.Linear(256, 3)) 25 | 26 | 27 | class Baseline(nn.Module): 28 | def __init__(self, backbone='resnet18', dropout=0.5, **kwargs): 29 | super().__init__() 30 | 31 | conv, c = get_resnet(backbone, input_channel=3) 32 | 33 | self.conv = conv 34 | self.c = c 35 | self.global_avg_pool = nn.AvgPool2d((40, 96)) 36 | 37 | self.rgb_transform = Normalize( 38 | mean=[0.31, 0.33, 0.36], 39 | std=[0.18, 0.18, 0.19], 40 | ) 41 | 42 | self.speed_encoder = nn.Sequential( 43 | nn.Linear(1, 128), 44 | nn.ReLU(True), 45 | nn.Dropout(p=dropout), 46 | 47 | nn.Linear(128, 128), 48 | nn.ReLU(True), 49 | nn.Dropout(p=dropout), 50 | 51 | nn.Linear(128, 128), 52 | nn.ReLU(True), 53 | nn.Dropout(p=dropout), 54 | ) 55 | 56 | self.joint = nn.Sequential( 57 | nn.Linear(c+128, 512), 58 | nn.ReLU(True), 59 | nn.Dropout(p=dropout), 60 | ) 61 | 62 | self.speed = nn.Sequential( 63 | nn.Linear(512, 256), 64 | nn.ReLU(True), 65 | nn.Dropout(p=dropout), 66 | 67 | nn.Linear(256, 256), 68 | nn.ReLU(True), 69 | nn.Dropout(p=dropout), 70 | 71 | nn.Linear(256, 1), 72 | ) 73 | 74 | self.branches = nn.ModuleList([BaselineBranch(p=dropout) for i in range(4)]) 75 | 76 | def forward(self, image, velocity, command): 77 | h = self.conv(self.rgb_transform(image)) 78 | h = self.global_avg_pool(h).view(-1, self.c) 79 | v = self.speed_encoder(velocity[...,None]) 80 | 81 | h = torch.cat([h, v], dim=1) 82 | h = self.joint(h) 83 | 84 | branch_outputs = [control(h) for control in self.branches] 85 | branch_outputs = torch.stack(branch_outputs, dim=1) 86 | 87 | control = select_branch(branch_outputs, command) 88 | speed = self.speed(h) 89 | 90 | return control, speed 91 | 92 | 93 | class BaselineAgent(Agent): 94 | def run_step(self, observations): 95 | rgb = observations['rgb'].copy() 96 | speed = np.linalg.norm(observations['velocity']) 97 | command = self.one_hot[int(observations['command']) - 1] 98 | 99 | with torch.no_grad(): 100 | _rgb = (self.transform(rgb)[None]).to(self.device) 101 | _speed = torch.FloatTensor([speed]).to(self.device) 102 | _command = command.to(self.device).unsqueeze(0) 103 | 104 | _control, _ = self.model(_rgb, _speed, _command) 105 | steer, throttle, brake = map(float, _control.cpu().numpy().squeeze()) 106 | 107 | if not hasattr(self, 'hack'): 108 | self.hack = 0 109 | 110 | if self.hack < 20: # Startup hack: force mild throttle for the first 20 steps so the car gets moving. 111 | speed = 2 112 | throttle = 0.5 113 | brake = 0 114 | 115 | self.hack += 1 116 | 117 | control = carla.VehicleControl()
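        # Note: unlike Agent.postprocess(), the raw network outputs are applied
        # here without clipping to CARLA's [-1, 1] / [0, 1] control ranges.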
118 | control.steer = steer 119 | control.throttle = throttle 120 | control.brake = brake 121 | 122 | return control 123 | -------------------------------------------------------------------------------- /bird_view/models/birdview.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from . import common 8 | from .agent import Agent 9 | from .controller import PIDController, CustomController 10 | from .controller import ls_circle 11 | 12 | 13 | STEPS = 5 14 | SPEED_STEPS = 3 15 | COMMANDS = 4 16 | DT = 0.1 17 | CROP_SIZE = 192 18 | PIXELS_PER_METER = 5 19 | 20 | 21 | def regression_base(): 22 | return nn.Sequential( 23 | nn.ConvTranspose2d(640,256,4,2,1,0), 24 | nn.BatchNorm2d(256), 25 | nn.ReLU(True), 26 | nn.ConvTranspose2d(256,128,4,2,1,0), 27 | nn.BatchNorm2d(128), 28 | nn.ReLU(True), 29 | nn.ConvTranspose2d(128,64,4,2,1,0), 30 | nn.BatchNorm2d(64), 31 | nn.ReLU(True)) 32 | 33 | 34 | def spatial_softmax_base(): 35 | return nn.Sequential( 36 | nn.BatchNorm2d(640), 37 | nn.ConvTranspose2d(640,256,3,2,1,1), 38 | nn.ReLU(True), 39 | nn.BatchNorm2d(256), 40 | nn.ConvTranspose2d(256,128,3,2,1,1), 41 | nn.ReLU(True), 42 | nn.BatchNorm2d(128), 43 | nn.ConvTranspose2d(128,64,3,2,1,1), 44 | nn.ReLU(True)) 45 | 46 | 47 | class BirdViewPolicyModelSS(common.ResnetBase): 48 | def __init__(self, backbone='resnet18', input_channel=7, n_step=5, all_branch=False, **kwargs): 49 | super().__init__(backbone=backbone, input_channel=input_channel, bias_first=False) 50 | 51 | self.deconv = spatial_softmax_base() 52 | self.location_pred = nn.ModuleList([ 53 | nn.Sequential( 54 | nn.BatchNorm2d(64), 55 | nn.Conv2d(64,STEPS,1,1,0), 56 | common.SpatialSoftmax(48,48,STEPS)) for i in range(COMMANDS) 57 | ]) 58 | 59 | self.all_branch = all_branch 60 | 61 | def forward(self, bird_view, velocity, command): 62 | h = self.conv(bird_view) 63 | b, c, kh, kw = h.size() 64 | 65 | # Late fusion for velocity 66 | velocity = velocity[...,None,None,None].repeat((1,128,kh,kw)) 67 | 68 | h = torch.cat((h, velocity), dim=1) 69 | h = self.deconv(h) 70 | 71 | location_preds = [location_pred(h) for location_pred in self.location_pred] 72 | location_preds = torch.stack(location_preds, dim=1) 73 | 74 | location_pred = common.select_branch(location_preds, command) 75 | 76 | if self.all_branch: 77 | return location_pred, location_preds 78 | 79 | return location_pred 80 | 81 | 82 | class BirdViewAgent(Agent): 83 | def __init__(self, steer_points=None, pid=None, gap=5, **kwargs): 84 | super().__init__(**kwargs) 85 | 86 | self.speed_control = PIDController(K_P=1.0, K_I=0.1, K_D=2.5) 87 | 88 | if steer_points is None: 89 | steer_points = {"1": 3, "2": 2, "3": 2, "4": 2} 90 | 91 | if pid is None: 92 | pid = { 93 | "1" : {"Kp": 1.0, "Ki": 0.1, "Kd":0}, # Left 94 | "2" : {"Kp": 1.0, "Ki": 0.1, "Kd":0}, # Right 95 | "3" : {"Kp": 0.8, "Ki": 0.1, "Kd":0}, # Straight 96 | "4" : {"Kp": 0.8, "Ki": 0.1, "Kd":0}, # Follow 97 | } 98 | 99 | self.turn_control = CustomController(pid) 100 | self.steer_points = steer_points 101 | 102 | self.gap = gap 103 | 104 | def run_step(self, observations, teaching=False): 105 | birdview = common.crop_birdview(observations['birdview'], dx=-10) 106 | speed = np.linalg.norm(observations['velocity']) 107 | command = self.one_hot[int(observations['command']) - 1] 108 | 109 | with torch.no_grad(): 110 | _birdview = self.transform(birdview).to(self.device).unsqueeze(0) 111 | _speed = 
torch.FloatTensor([speed]).to(self.device) 112 | _command = command.to(self.device).unsqueeze(0) 113 | 114 | if self.model.all_branch: 115 | _locations, _ = self.model(_birdview, _speed, _command) 116 | else: 117 | _locations = self.model(_birdview, _speed, _command) 118 | _locations = _locations.squeeze().detach().cpu().numpy() 119 | 120 | _map_locations = _locations 121 | # Pixel coordinates. 122 | _locations = (_locations + 1) / 2 * CROP_SIZE 123 | 124 | targets = list() 125 | 126 | for i in range(STEPS): 127 | pixel_dx, pixel_dy = _locations[i] 128 | pixel_dx = pixel_dx - CROP_SIZE / 2 129 | pixel_dy = CROP_SIZE - pixel_dy 130 | 131 | angle = np.arctan2(pixel_dx, pixel_dy) 132 | dist = np.linalg.norm([pixel_dx, pixel_dy]) / PIXELS_PER_METER 133 | 134 | targets.append([dist * np.cos(angle), dist * np.sin(angle)]) 135 | 136 | target_speed = 0.0 137 | 138 | for i in range(1, SPEED_STEPS): 139 | pixel_dx, pixel_dy = _locations[i] 140 | prev_dx, prev_dy = _locations[i-1] 141 | 142 | dx = pixel_dx - prev_dx 143 | dy = pixel_dy - prev_dy 144 | delta = np.linalg.norm([dx, dy]) 145 | 146 | target_speed += delta / (PIXELS_PER_METER * self.gap * DT) / (SPEED_STEPS-1) 147 | 148 | _cmd = int(observations['command']) 149 | n = self.steer_points.get(str(_cmd), 1) 150 | targets = np.concatenate([[[0, 0]], targets], 0) 151 | c, r = ls_circle(targets) 152 | closest = common.project_point_to_circle(targets[n], c, r) 153 | 154 | v = [1.0, 0.0, 0.0] 155 | w = [closest[0], closest[1], 0.0] 156 | alpha = common.signed_angle(v, w) 157 | steer = self.turn_control.run_step(alpha, _cmd) 158 | throttle = self.speed_control.step(target_speed - speed) 159 | brake = 0.0 160 | 161 | if target_speed < 1.0: 162 | steer = 0.0 163 | throttle = 0.0 164 | brake = 1.0 165 | 166 | self.debug['locations_birdview'] = _locations[:,::-1].astype(int) 167 | self.debug['target'] = closest 168 | self.debug['target_speed'] = target_speed 169 | 170 | control = self.postprocess(steer, throttle, brake) 171 | if teaching: 172 | return control, _map_locations 173 | else: 174 | return control 175 | -------------------------------------------------------------------------------- /bird_view/models/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import torchvision.transforms as transforms 10 | 11 | from .resnet import get_resnet 12 | 13 | 14 | CROP_SIZE = 192 15 | MAP_SIZE = 320 16 | 17 | 18 | def crop_birdview(birdview, dx=0, dy=0): 19 | x = 260 - CROP_SIZE // 2 + dx 20 | y = MAP_SIZE // 2 + dy 21 | 22 | birdview = birdview[ 23 | x-CROP_SIZE//2:x+CROP_SIZE//2, 24 | y-CROP_SIZE//2:y+CROP_SIZE//2] 25 | 26 | return birdview 27 | 28 | 29 | def select_branch(branches, one_hot): 30 | shape = branches.size() 31 | 32 | for i, s in enumerate(shape[2:]): 33 | one_hot = torch.stack([one_hot for _ in range(s)], dim=i+2) 34 | 35 | return torch.sum(one_hot * branches, dim=1) 36 | 37 | 38 | def signed_angle(u, v): 39 | theta = math.acos(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))) 40 | 41 | if np.cross(u, v)[2] < 0: 42 | theta *= -1.0 43 | 44 | return theta 45 | 46 | 47 | def project_point_to_circle(point, c, r): 48 | direction = point - c 49 | closest = c + (direction / np.linalg.norm(direction)) * r 50 | 51 | return closest 52 | 53 | 54 | def make_arc(points, c, r): 55 | point_min = project_point_to_circle(points[0], c, r) 56 | point_max = 
project_point_to_circle(points[-1], c, r) 57 | 58 | theta_min = np.arctan2(point_min[1], point_min[0]) 59 | theta_max = np.arctan2(point_max[1], point_max[0]) 60 | 61 | # Probably a bug here. 62 | theta = np.linspace(theta_min, theta_max, 100) 63 | x1 = r * np.cos(theta) + c[0] 64 | x2 = r * np.sin(theta) + c[1] 65 | 66 | return np.stack([x1, x2], 1) 67 | 68 | 69 | class ResnetBase(nn.Module): 70 | def __init__(self, backbone, input_channel=3, bias_first=True, pretrained=False): 71 | super().__init__() 72 | 73 | 74 | conv, c = get_resnet( 75 | backbone, input_channel=input_channel, 76 | bias_first=bias_first, pretrained=pretrained) 77 | 78 | self.conv = conv 79 | self.c = c 80 | 81 | self.backbone = backbone 82 | self.input_channel = input_channel 83 | self.bias_first = bias_first 84 | 85 | 86 | class Normalize(nn.Module): 87 | def __init__(self, mean, std): 88 | super().__init__() 89 | 90 | self.mean = nn.Parameter(torch.FloatTensor(mean).reshape(1, 3, 1, 1), requires_grad=False) 91 | self.std = nn.Parameter(torch.FloatTensor(std).reshape(1, 3, 1, 1), requires_grad=False) 92 | 93 | def cuda(self): 94 | self.mean = self.mean.cuda() 95 | self.std = self.std.cuda() 96 | 97 | def forward(self, x): 98 | return (x - self.mean) / self.std 99 | 100 | 101 | class NormalizeV2(nn.Module): 102 | def __init__(self, mean, std): 103 | super().__init__() 104 | 105 | self.mean = torch.FloatTensor(mean).reshape(1, 3, 1, 1).cuda() 106 | self.std = torch.FloatTensor(std).reshape(1, 3, 1, 1).cuda() 107 | 108 | def forward(self, x): 109 | return (x - self.mean) / self.std 110 | 111 | 112 | class SpatialSoftmax(nn.Module): 113 | # Source: https://gist.github.com/jeasinema/1cba9b40451236ba2cfb507687e08834 114 | def __init__(self, height, width, channel, temperature=None, data_format='NCHW'): 115 | super().__init__() 116 | 117 | self.data_format = data_format 118 | self.height = height 119 | self.width = width 120 | self.channel = channel 121 | 122 | if temperature: 123 | self.temperature = Parameter(torch.ones(1)*temperature) 124 | else: 125 | self.temperature = 1. 126 | 127 | pos_x, pos_y = np.meshgrid( 128 | np.linspace(-1., 1., self.height), 129 | np.linspace(-1., 1., self.width) 130 | ) 131 | pos_x = torch.from_numpy(pos_x.reshape(self.height*self.width)).float() 132 | pos_y = torch.from_numpy(pos_y.reshape(self.height*self.width)).float() 133 | self.register_buffer('pos_x', pos_x) 134 | self.register_buffer('pos_y', pos_y) 135 | 136 | def forward(self, feature): 137 | # Output: 138 | # (N, C*2) x_0 y_0 ... 
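        #   Each channel is softmaxed into a spatial distribution over the grid,
        #   and the layer returns the expected (x, y) position in [-1, 1] per
        #   channel, reshaped to (N, C, 2) -- one soft keypoint per channel.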
139 | 140 | if self.data_format == 'NHWC': 141 | feature = feature.transpose(1, 3).transpose(2, 3).view(-1, self.height*self.width) 142 | else: 143 | feature = feature.view(-1, self.height*self.width) 144 | 145 | weight = F.softmax(feature/self.temperature, dim=-1) 146 | expected_x = torch.sum(torch.autograd.Variable(self.pos_x)*weight, dim=1, keepdim=True) 147 | expected_y = torch.sum(torch.autograd.Variable(self.pos_y)*weight, dim=1, keepdim=True) 148 | expected_xy = torch.cat([expected_x, expected_y], 1) 149 | # feature_keypoints = expected_xy.view(-1, self.channel*2) 150 | feature_keypoints = expected_xy.view(-1, self.channel, 2) 151 | 152 | return feature_keypoints 153 | 154 | 155 | class SpatialSoftmaxBZ(torch.nn.Module): 156 | """ 157 | IMPORTANT: 158 | i in [0, 1], where 0 is at the bottom, 1 is at the top 159 | j in [-1, 1] 160 | """ 161 | def __init__(self, height, width): 162 | super().__init__() 163 | 164 | self.height = height 165 | self.width = width 166 | 167 | pos_x, pos_y = np.meshgrid( 168 | np.linspace(-1.0, 1.0, self.height), 169 | np.linspace(-1.0, 1.0, self.width) 170 | ) 171 | 172 | self.pos_x = torch.from_numpy(pos_x).reshape(-1).float() 173 | self.pos_x = torch.nn.Parameter(self.pos_x, requires_grad=False) 174 | 175 | self.pos_y = torch.from_numpy(pos_y).reshape(-1).float() 176 | self.pos_y = torch.nn.Parameter(self.pos_y, requires_grad=False) 177 | 178 | def forward(self, feature): 179 | flattened = feature.view(feature.shape[0], feature.shape[1], -1) 180 | softmax = F.softmax(flattened, dim=-1) 181 | 182 | # This is not a bug: pos_x/pos_y come from np.meshgrid, so their roles are deliberately swapped here. 183 | expected_x = torch.sum(self.pos_y * softmax, dim=-1) 184 | expected_x = (-expected_x + 1) / 2.0 185 | expected_y = torch.sum(self.pos_x * softmax, dim=-1) 186 | 187 | expected_xy = torch.stack([expected_x, expected_y], dim=2) 188 | 189 | return expected_xy 190 | 191 | 192 | # tmp = SpatialSoftmax(48, 48) 193 | # 194 | # check = [(47, 0), (47, 24), (47, 47), (0, 24)] 195 | # 196 | # for i, j in check: 197 | # feature = np.zeros((48, 48)) 198 | # feature[i,j] = 100 199 | # feature = torch.FloatTensor(feature).unsqueeze(0).unsqueeze(0) 200 | # 201 | # print(i, j, tmp(feature)) 202 | -------------------------------------------------------------------------------- /bird_view/models/controller.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | 3 | import numpy as np 4 | 5 | from scipy.special import comb 6 | from scipy import interpolate 7 | 8 | def ls_circle(points): 9 | ''' 10 | Input: Nx2 points 11 | Output: cx, cy, r 12 | ''' 13 | xs = points[:,0] 14 | ys = points[:,1] 15 | 16 | us = xs - np.mean(xs) 17 | vs = ys - np.mean(ys) 18 | 19 | Suu = np.sum(us**2) 20 | Suv = np.sum(us*vs) 21 | Svv = np.sum(vs**2) 22 | Suuu = np.sum(us**3) 23 | Suvv = np.sum(us*vs*vs) 24 | Svvv = np.sum(vs**3) 25 | Svuu = np.sum(vs*us*us) 26 | 27 | A = np.array([ 28 | [Suu, Suv], 29 | [Suv, Svv] 30 | ]) 31 | 32 | b = np.array([1/2.*Suuu+1/2.*Suvv, 1/2.*Svvv+1/2.*Svuu]) 33 | 34 | cx, cy = np.linalg.solve(A, b) 35 | r = np.sqrt(cx*cx+cy*cy+(Suu+Svv)/len(xs)) 36 | 37 | cx += np.mean(xs) 38 | cy += np.mean(ys) 39 | 40 | return np.array([cx, cy]), r 41 | 42 | 43 | class PIDController(object): 44 | def __init__(self, K_P=1.0, K_I=0.0, K_D=0.0, fps=10, n=30, **kwargs): 45 | self._K_P = K_P 46 | self._K_I = K_I 47 | self._K_D = K_D 48 | 49 | self._dt = 1.0 / fps 50 | self._n = n 51 | self._window = deque(maxlen=self._n) 52 | 53 | def step(self, error): 54 | self._window.append(error) 55 | 56 | if
len(self._window) >= 2: 57 | integral = sum(self._window) * self._dt 58 | derivative = (self._window[-1] - self._window[-2]) / self._dt 59 | else: 60 | integral = 0.0 61 | derivative = 0.0 62 | 63 | control = 0.0 64 | control += self._K_P * error 65 | control += self._K_I * integral 66 | control += self._K_D * derivative 67 | 68 | return control 69 | 70 | 71 | class CustomController(): 72 | def __init__(self, controller_args, k=0.5, n=2, wheelbase=2.89, dt=0.1): 73 | self._wheelbase = wheelbase 74 | self._k = k 75 | 76 | self._n = n 77 | self._t = 0 78 | 79 | self._dt = dt 80 | self._controller_args = controller_args 81 | 82 | self._e_buffer = deque(maxlen=10) 83 | 84 | 85 | def run_step(self, alpha, cmd): 86 | self._e_buffer.append(alpha) 87 | 88 | if len(self._e_buffer) >= 2: 89 | _de = (self._e_buffer[-1] - self._e_buffer[-2]) / self._dt 90 | _ie = sum(self._e_buffer) * self._dt 91 | else: 92 | _de = 0.0 93 | _ie = 0.0 94 | 95 | Kp = self._controller_args[str(cmd)]["Kp"] 96 | Ki = self._controller_args[str(cmd)]["Ki"] 97 | Kd = self._controller_args[str(cmd)]["Kd"] 98 | 99 | return (Kp * alpha) + (Kd * _de) + (Ki * _ie) 100 | -------------------------------------------------------------------------------- /bird_view/models/factory.py: -------------------------------------------------------------------------------- 1 | def get_model(backbone, n_step=5, ss_loss=True): 2 | if ss_loss: 3 | return BirdViewPolicyModelSS(backbone, n_step=n_step) 4 | else: 5 | return BirdViewPolicyModel(backbone, n_step=n_step) 6 | 7 | -------------------------------------------------------------------------------- /bird_view/models/image.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from . 
import common 9 | from .agent import Agent 10 | from .controller import CustomController, PIDController 11 | from .controller import ls_circle 12 | 13 | 14 | CROP_SIZE = 192 15 | STEPS = 5 16 | COMMANDS = 4 17 | DT = 0.1 18 | CROP_SIZE = 192 19 | PIXELS_PER_METER = 5 20 | 21 | 22 | class ImagePolicyModelSS(common.ResnetBase): 23 | def __init__(self, backbone, warp=False, pretrained=False, all_branch=False, **kwargs): 24 | super().__init__(backbone, pretrained=pretrained, input_channel=3, bias_first=False) 25 | 26 | self.c = { 27 | 'resnet18': 512, 28 | 'resnet34': 512, 29 | 'resnet50': 2048 30 | }[backbone] 31 | self.warp = warp 32 | self.rgb_transform = common.NormalizeV2( 33 | mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225] 35 | ) 36 | 37 | self.deconv = nn.Sequential( 38 | nn.BatchNorm2d(self.c + 128), 39 | nn.ConvTranspose2d(self.c + 128,256,3,2,1,1), 40 | nn.ReLU(True), 41 | nn.BatchNorm2d(256), 42 | nn.ConvTranspose2d(256,128,3,2,1,1), 43 | nn.ReLU(True), 44 | nn.BatchNorm2d(128), 45 | nn.ConvTranspose2d(128,64,3,2,1,1), 46 | nn.ReLU(True), 47 | ) 48 | 49 | if warp: 50 | ow,oh = 48,48 51 | else: 52 | ow,oh = 96,40 53 | 54 | self.location_pred = nn.ModuleList([ 55 | nn.Sequential( 56 | nn.BatchNorm2d(64), 57 | nn.Conv2d(64,STEPS,1,1,0), 58 | common.SpatialSoftmax(ow,oh,STEPS), 59 | ) for i in range(4) 60 | ]) 61 | 62 | self.all_branch = all_branch 63 | 64 | def forward(self, image, velocity, command): 65 | if self.warp: 66 | warped_image = tgm.warp_perspective(image, self.M, dsize=(192, 192)) 67 | resized_image = resize_images(image) 68 | image = torch.cat([warped_image, resized_image], 1) 69 | 70 | 71 | image = self.rgb_transform(image) 72 | 73 | h = self.conv(image) 74 | b, c, kh, kw = h.size() 75 | 76 | # Late fusion for velocity 77 | velocity = velocity[...,None,None,None].repeat((1,128,kh,kw)) 78 | 79 | h = torch.cat((h, velocity), dim=1) 80 | h = self.deconv(h) 81 | 82 | location_preds = [location_pred(h) for location_pred in self.location_pred] 83 | location_preds = torch.stack(location_preds, dim=1) 84 | location_pred = common.select_branch(location_preds, command) 85 | 86 | if self.all_branch: 87 | return location_pred, location_preds 88 | 89 | return location_pred 90 | 91 | 92 | 93 | class ImageAgent(Agent): 94 | def __init__(self, steer_points=None, pid=None, gap=5, camera_args={'x':384,'h':160,'fov':90,'world_y':1.4,'fixed_offset':4.0}, **kwargs): 95 | super().__init__(**kwargs) 96 | 97 | self.fixed_offset = float(camera_args['fixed_offset']) 98 | print ("Offset: ", self.fixed_offset) 99 | w = float(camera_args['w']) 100 | h = float(camera_args['h']) 101 | self.img_size = np.array([w,h]) 102 | self.gap = gap 103 | 104 | if steer_points is None: 105 | steer_points = {"1": 4, "2": 3, "3": 2, "4": 2} 106 | 107 | if pid is None: 108 | pid = { 109 | "1" : {"Kp": 0.5, "Ki": 0.20, "Kd":0.0}, # Left 110 | "2" : {"Kp": 0.7, "Ki": 0.10, "Kd":0.0}, # Right 111 | "3" : {"Kp": 1.0, "Ki": 0.10, "Kd":0.0}, # Straight 112 | "4" : {"Kp": 1.0, "Ki": 0.50, "Kd":0.0}, # Follow 113 | } 114 | 115 | self.steer_points = steer_points 116 | self.turn_control = CustomController(pid) 117 | self.speed_control = PIDController(K_P=.8, K_I=.08, K_D=0.) 
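        # Control is split: turn_control runs a per-command PID on the heading
        # error, while this single PID turns the speed error into throttle.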
118 | 119 | self.engine_brake_threshold = 2.0 120 | self.brake_threshold = 2.0 121 | 122 | self.last_brake = -1 123 | 124 | def run_step(self, observations, teaching=False): 125 | rgb = observations['rgb'].copy() 126 | speed = np.linalg.norm(observations['velocity']) 127 | _cmd = int(observations['command']) 128 | command = self.one_hot[int(observations['command']) - 1] 129 | 130 | with torch.no_grad(): 131 | _rgb = self.transform(rgb).to(self.device).unsqueeze(0) 132 | _speed = torch.FloatTensor([speed]).to(self.device) 133 | _command = command.to(self.device).unsqueeze(0) 134 | if self.model.all_branch: 135 | model_pred, _ = self.model(_rgb, _speed, _command) 136 | else: 137 | model_pred = self.model(_rgb, _speed, _command) 138 | 139 | model_pred = model_pred.squeeze().detach().cpu().numpy() 140 | 141 | pixel_pred = model_pred 142 | 143 | # Project back to world coordinate 144 | model_pred = (model_pred+1)*self.img_size/2 145 | 146 | world_pred = self.unproject(model_pred) 147 | 148 | targets = [(0, 0)] 149 | 150 | for i in range(STEPS): 151 | pixel_dx, pixel_dy = world_pred[i] 152 | angle = np.arctan2(pixel_dx, pixel_dy) 153 | dist = np.linalg.norm([pixel_dx, pixel_dy]) 154 | 155 | targets.append([dist * np.cos(angle), dist * np.sin(angle)]) 156 | 157 | targets = np.array(targets) 158 | 159 | target_speed = np.linalg.norm(targets[:-1] - targets[1:], axis=1).mean() / (self.gap * DT) 160 | 161 | c, r = ls_circle(targets) 162 | n = self.steer_points.get(str(_cmd), 1) 163 | closest = common.project_point_to_circle(targets[n], c, r) 164 | 165 | acceleration = target_speed - speed 166 | 167 | v = [1.0, 0.0, 0.0] 168 | w = [closest[0], closest[1], 0.0] 169 | alpha = common.signed_angle(v, w) 170 | 171 | steer = self.turn_control.run_step(alpha, _cmd) 172 | throttle = self.speed_control.step(acceleration) 173 | brake = 0.0 174 | 175 | # Slow or stop. 
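        # Coast (zero steer/throttle) once the predicted speed drops below the
        # engine-brake threshold, and brake fully at or below the brake threshold.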
176 | 177 | if target_speed <= self.engine_brake_threshold: 178 | steer = 0.0 179 | throttle = 0.0 180 | 181 | if target_speed <= self.brake_threshold: 182 | brake = 1.0 183 | 184 | self.debug = { 185 | # 'curve': curve, 186 | 'target_speed': target_speed, 187 | 'target': closest, 188 | 'locations_world': targets, 189 | 'locations_pixel': model_pred.astype(int), 190 | } 191 | 192 | control = self.postprocess(steer, throttle, brake) 193 | if teaching: 194 | return control, pixel_pred 195 | else: 196 | return control 197 | 198 | def unproject(self, output, world_y=1.4, fov=90): 199 | 200 | cx, cy = self.img_size / 2 201 | 202 | w, h = self.img_size 203 | 204 | f = w /(2 * np.tan(fov * np.pi / 360)) 205 | 206 | xt = (output[...,0:1] - cx) / f 207 | yt = (output[...,1:2] - cy) / f 208 | 209 | world_z = world_y / yt 210 | world_x = world_z * xt 211 | 212 | world_output = np.stack([world_x, world_z],axis=-1) 213 | 214 | if self.fixed_offset: 215 | world_output[...,1] -= self.fixed_offset 216 | 217 | world_output = world_output.squeeze() 218 | 219 | return world_output 220 | -------------------------------------------------------------------------------- /bird_view/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | model_urls = { 7 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 8 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 11 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 12 | } 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | identity = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = conv1x1(inplanes, planes) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = conv3x3(planes, planes, stride) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = conv1x1(planes, planes * self.expansion) 67 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | 
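        # Save the block input for the residual shortcut; downsample() (1x1 conv
        # + BN) matches shapes when the stride or channel width changes.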
identity = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | identity = self.downsample(x) 88 | 89 | out += identity 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | def __init__( 97 | self, block, layers, 98 | input_channel=7, num_classes=1000, zero_init_residual=False, bias_first=True): 99 | super(ResNet, self).__init__() 100 | 101 | self.inplanes = 64 102 | self.conv1 = nn.Conv2d( 103 | input_channel, 64, kernel_size=7, stride=2, padding=3, bias=bias_first) 104 | self.bn1 = nn.BatchNorm2d(64) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, 64, layers[0]) 108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 111 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 112 | self.fc = nn.Linear(512 * block.expansion, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 117 | elif isinstance(m, nn.BatchNorm2d): 118 | nn.init.constant_(m.weight, 1) 119 | nn.init.constant_(m.bias, 0) 120 | 121 | # Zero-initialize the last BN in each residual branch, 122 | # so that the residual branch starts with zeros, and each residual block behaves 123 | # like an identity. 124 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 125 | if zero_init_residual: 126 | for m in self.modules(): 127 | if isinstance(m, Bottleneck): 128 | nn.init.constant_(m.bn3.weight, 0) 129 | elif isinstance(m, BasicBlock): 130 | nn.init.constant_(m.bn2.weight, 0) 131 | 132 | def _make_layer(self, block, planes, blocks, stride=1): 133 | downsample = None 134 | if stride != 1 or self.inplanes != planes * block.expansion: 135 | downsample = nn.Sequential( 136 | conv1x1(self.inplanes, planes * block.expansion, stride), 137 | nn.BatchNorm2d(planes * block.expansion), 138 | ) 139 | 140 | layers = [] 141 | layers.append(block(self.inplanes, planes, stride, downsample)) 142 | self.inplanes = planes * block.expansion 143 | for _ in range(1, blocks): 144 | layers.append(block(self.inplanes, planes)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | x = self.conv1(x) 150 | x = self.bn1(x) 151 | x = self.relu(x) 152 | x = self.maxpool(x) 153 | 154 | x = self.layer1(x) 155 | x = self.layer2(x) 156 | x = self.layer3(x) 157 | x = self.layer4(x) 158 | 159 | return x 160 | 161 | 162 | model_funcs = { 163 | 'resnet18': (BasicBlock, [2, 2, 2, 2], -1), 164 | 'resnet34': (BasicBlock, [3, 4, 6, 3], 512), 165 | 'resnet50': (Bottleneck, [3, 4, 6, 3], -1), 166 | 'resnet101': (Bottleneck, [3, 4, 23, 3], -1), 167 | 'resnet152': (Bottleneck, [3, 8, 36, 3], -1), 168 | } 169 | 170 | 171 | def get_resnet(model_name='resnet18', pretrained=False, **kwargs): 172 | block, layers, c_out = model_funcs[model_name] 173 | model = ResNet(block, layers, **kwargs) 174 | 175 | if pretrained and kwargs.get('input_channel', 3) == 3: 176 | url = model_urls[model_name] 177 | print ("Loading ResNet weights from : %s" % url) 178 | model.load_state_dict(model_zoo.load_url(url)) 179 | 180 | return 
model, c_out 181 | -------------------------------------------------------------------------------- /bird_view/models/roaming.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from agents.navigation.agent import Agent 4 | from agents.navigation.local_planner import LocalPlannerNew 5 | 6 | from .controller import PIDController 7 | 8 | import carla 9 | 10 | 11 | TURNING_PID = { 12 | 'K_P': 1.5, 13 | 'K_I': 0.5, 14 | 'K_D': 0.0, 15 | 'fps': 10 16 | } 17 | 18 | 19 | class RoamingAgentMine(Agent): 20 | def __init__(self, vehicle, resolution, threshold_before, threshold_after): 21 | super().__init__(vehicle) 22 | 23 | self._proximity_threshold = 9.5 24 | self.speed_control = PIDController(K_P=1.0) 25 | self.turn_control = PIDController(**TURNING_PID) 26 | 27 | self._local_planner = LocalPlannerNew(self._vehicle, resolution, threshold_before, threshold_after) 28 | self.set_route = self._local_planner.set_route 29 | 30 | self.debug = dict() 31 | 32 | def run_step(self, inputs=None, debug=False, debug_info=None): 33 | self._local_planner.run_step() 34 | 35 | ox = self._vehicle.get_transform().get_forward_vector().x 36 | oy = self._vehicle.get_transform().get_forward_vector().y 37 | rot = np.array([ 38 | [ox, oy], 39 | [-oy, ox]]) 40 | 41 | target = self._local_planner.target[0].transform.location 42 | target = np.array([target.x, target.y]) 43 | pos = self._vehicle.get_location() 44 | pos = np.array([pos.x, pos.y]) 45 | diff = rot.dot(target - pos) 46 | 47 | speed = self._vehicle.get_velocity() 48 | speed = np.linalg.norm([speed.x, speed.y]) 49 | 50 | u = np.array([diff[0], diff[1], 0.0]) 51 | v = np.array([1.0, 0.0, 0.0]) 52 | theta = np.arccos(np.dot(u, v) / np.linalg.norm(u)) 53 | theta = theta if np.cross(u, v)[2] < 0 else -theta 54 | steer = self.turn_control.step(theta) 55 | 56 | target_speed = 6.0 57 | 58 | if int(self._local_planner.target[1]) not in [3, 4]: 59 | target_speed *= 0.75 60 | 61 | delta = target_speed - speed 62 | throttle = self.speed_control.step(delta) 63 | 64 | control = carla.VehicleControl() 65 | control.steer = np.clip(steer, -1.0, 1.0) 66 | control.throttle = np.clip(throttle, 0.0, 1.0) 67 | control.brake = 0.0 68 | control.manual_gear_shift = False 69 | 70 | self.vehicle = self._vehicle.get_location() 71 | self.road_option = self._local_planner.target[1] 72 | 73 | actor_list = self._world.get_actors() 74 | vehicle_list = actor_list.filter('*vehicle*') 75 | lights_list = actor_list.filter('*traffic_light*') 76 | walkers_list = actor_list.filter('*walker*') 77 | 78 | blocking_vehicle, vehicle = self._is_vehicle_hazard(vehicle_list) 79 | blocking_light, traffic_light = self._is_light_red(lights_list) 80 | blocking_walker, walker = self._is_walker_hazard(walkers_list) 81 | hazard_detected = blocking_vehicle or blocking_light or blocking_walker 82 | 83 | if blocking_vehicle: 84 | self.waypoint = vehicle.get_location() 85 | elif traffic_light: 86 | self.waypoint = traffic_light.get_location() 87 | elif blocking_walker: 88 | self.waypoint = walker.get_location() 89 | else: 90 | self.waypoint = self._local_planner.target[0].transform.location 91 | 92 | if hazard_detected: 93 | control = self.emergency_stop() 94 | control.manual_gear_shift = False 95 | 96 | return control 97 | 98 | self.debug['target'] = (self.waypoint.x, self.waypoint.y) 99 | 100 | return control 101 | 102 | # x = self.scale * self._pixels_per_meter * (location.x - self._world_offset[0]) 103 | # y = self.scale * self._pixels_per_meter * 
(location.y - self._world_offset[1]) 104 | # return [int(x - offset[0]), int(y - offset[1])] 105 | -------------------------------------------------------------------------------- /bird_view/scripts/parse_runs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | 7 | 8 | log_dir = sys.argv[1] 9 | 10 | for model_name in Path(log_dir).glob('*'): 11 | print(model_name.stem) 12 | 13 | for run_path in sorted(model_name.glob('*/*.csv')): 14 | run_name = run_path.parent.stem 15 | csv = pd.read_csv(run_path) 16 | 17 | print(run_name, '%.4f' % csv['success'].mean(), len(csv)) 18 | 19 | print() 20 | -------------------------------------------------------------------------------- /bird_view/scripts/tune_pid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tqdm 3 | import carla 4 | 5 | from agents.navigation.roaming_agent import RoamingAgent 6 | 7 | from bird_view.utils import carla_utils as cu 8 | from bird_view.utils import bz_utils as bzu 9 | 10 | 11 | TOWN = 'Town01' 12 | TRAIN = [(25,29), (28,24), (99,103), (144,148), (151,147)] 13 | VAL = [(57,39), (50,48), (36,53), (136,79), (22,76)] 14 | PORT = 3000 15 | 16 | 17 | def world_loop(opts_dict): 18 | params = { 19 | 'spawn': 15, 20 | 'weather': 'clear_noon', 21 | 'n_vehicles': 0 22 | } 23 | 24 | with cu.CarlaWrapper(TOWN, cu.VEHICLE_NAME, PORT) as env: 25 | env.init(**params) 26 | agent = RoamingAgent(env._player, False, opts_dict) 27 | 28 | # Hack: fill up controller experience. 29 | for _ in range(30): 30 | env.tick() 31 | env.apply_control(agent.run_step()[0]) 32 | 33 | for _ in tqdm.tqdm(range(125)): 34 | env.tick() 35 | 36 | observations = env.get_observations() 37 | inputs = cu.get_inputs(observations) 38 | 39 | debug = dict() 40 | control, command = agent.run_step(inputs, debug_info=debug) 41 | env.apply_control(control) 42 | 43 | observations.update({'control': control, 'command': command}) 44 | 45 | processed = cu.process(observations) 46 | 47 | yield debug 48 | 49 | bzu.show_image('rgb', processed['rgb']) 50 | bzu.show_image('birdview', cu.visualize_birdview(processed['birdview'])) 51 | 52 | 53 | def main(): 54 | import matplotlib.pyplot as plt; plt.ion() 55 | 56 | np.random.seed(0) 57 | 58 | for _ in tqdm.tqdm(range(10000), desc='Trials'): 59 | desired = list() 60 | current = list() 61 | output = list() 62 | e = list() 63 | 64 | K_P = np.random.uniform(0.5, 2.0) 65 | K_I = np.random.uniform(0.0, 2.0) 66 | K_D = np.random.uniform(0.0, 0.05) 67 | 68 | # Best so far. 
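        # (Gains found by an earlier random search; uncomment the three lines
        # below to pin them instead of sampling.)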
69 | # K_P = 1.0 70 | # K_I = 0.5 71 | # K_D = 0.0 72 | 73 | opts_dict = { 74 | 'lateral_control_dict': { 75 | 'K_P': K_P, 76 | 'K_I': K_I, 77 | 'K_D': K_D, 78 | 'dt': 0.1 79 | } 80 | } 81 | 82 | for debug in world_loop(opts_dict): 83 | for x in [desired, current, output]: 84 | if len(x) > 500: 85 | x.pop(0) 86 | 87 | desired.append(debug['desired']) 88 | current.append(debug['current']) 89 | output.append(debug['output']) 90 | e.append(debug['e'] ** 2) 91 | 92 | name = '%.1f_%.3f_%.3f_%.3f' % (sum(e), K_P, K_I, K_D) 93 | 94 | plt.cla() 95 | plt.plot(list(range(len(desired))), desired, 'b-') 96 | plt.plot(list(range(len(current))), current, 'r-') 97 | plt.plot(list(range(len(output))), output, 'c-') 98 | plt.savefig('/home/bradyzho/hd_data/images/%s.png' % name) 99 | 100 | 101 | if __name__ == '__main__': 102 | main() 103 | -------------------------------------------------------------------------------- /bird_view/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/bird_view/utils/__init__.py -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.sw* 3 | -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from collections import defaultdict 4 | 5 | from . import video_maker 6 | from . import gif_maker 7 | from . import saver 8 | 9 | show_image = video_maker.show 10 | 11 | init_video = video_maker.init 12 | add_to_video = video_maker.add 13 | 14 | add_to_gif = gif_maker.add 15 | save_gif = gif_maker.save 16 | clear_gif = gif_maker.clear 17 | 18 | dictlist = lambda: defaultdict(list) 19 | 20 | log = saver.Experiment() 21 | 22 | 23 | 24 | def load_json(path): 25 | with open(path, 'r') as f: 26 | data = json.load(f) 27 | 28 | return data 29 | -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/gif_maker.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import imageio 4 | 5 | 6 | DEFAULT_DIR = str(Path.home().joinpath('debug')) 7 | DEFAULT_PATH = 'test.gif' 8 | 9 | 10 | class Dummy(object): 11 | images = dict() 12 | 13 | @classmethod 14 | def add(cls, key, image): 15 | if key not in cls.images: 16 | cls.images[key] = list() 17 | 18 | cls.images[key].append(image.copy()) 19 | 20 | @classmethod 21 | def save(cls, key, save_dir=None, save_path=None, duration=0.1): 22 | save_dir = Path(save_dir or DEFAULT_DIR).resolve() 23 | save_path = save_path or DEFAULT_PATH 24 | 25 | save_dir.mkdir(exist_ok=True, parents=True) 26 | 27 | imageio.mimsave( 28 | str(save_dir.joinpath(save_path)), cls.images[key], 29 | 'GIF', duration=duration) 30 | 31 | cls.clear(key) 32 | 33 | @classmethod 34 | def clear(cls, key=None): 35 | if key in cls.images: 36 | cls.images.pop(key) 37 | else: 38 | cls.images.clear() 39 | 40 | 41 | add = Dummy.add 42 | save = Dummy.save 43 | clear = Dummy.clear 44 | -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
3 | 4 | 5 | -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/plotter.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | print(123) 5 | -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/saver.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | 4 | from pathlib import Path 5 | from collections import OrderedDict 6 | 7 | from loguru import logger 8 | from tensorboardX import SummaryWriter 9 | 10 | import numpy as np 11 | import torch 12 | import torchvision.utils as tv_utils 13 | 14 | 15 | def _preprocess_image(x): 16 | """ 17 | Takes - 18 | list of (h, w, 3) 19 | tensor of (n, h, 3) 20 | """ 21 | if isinstance(x, list): 22 | x = np.stack(x, 0).transpose(0, 3, 1, 2) 23 | 24 | x = torch.Tensor(x) 25 | 26 | if x.requires_grad: 27 | x = x.detach() 28 | 29 | if x.dim() == 3: 30 | x = x.unsqueeze(1) 31 | 32 | # x = torch.nn.functional.interpolate(x, 128, mode='nearest') 33 | x = tv_utils.make_grid(x, padding=2, normalize=True, nrow=4) 34 | x = x.cpu().numpy() 35 | 36 | return x 37 | 38 | 39 | def _format(**kwargs): 40 | result = list() 41 | 42 | for k, v in kwargs.items(): 43 | if isinstance(v, float) or isinstance(v, np.float32): 44 | result.append('%s: %.2f' % (k, v)) 45 | else: 46 | result.append('%s: %s' % (k, v)) 47 | 48 | return '\t'.join(result) 49 | 50 | 51 | class Experiment(object): 52 | def init(self, log_dir): 53 | """ 54 | This MUST be called. 55 | """ 56 | self._log = logger 57 | self.epoch = 0 58 | self.scalars = OrderedDict() 59 | 60 | self.log_dir = Path(log_dir).resolve() 61 | self.log_dir.mkdir(parents=True, exist_ok=True) 62 | 63 | for i in self._log._handlers: 64 | self._log.remove(i) 65 | 66 | self._writer_train = SummaryWriter(str(self.log_dir / 'train')) 67 | self._writer_val = SummaryWriter(str(self.log_dir / 'val')) 68 | self._log.add( 69 | str(self.log_dir / 'log.txt'), 70 | format='{time:MM/DD/YY HH:mm:ss} {level}\t{message}') 71 | 72 | # Functions. 
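        # Thin logging wrappers: debug() forwards to loguru directly, while
        # info() renders keyword arguments through _format() into a single
        # tab-separated line.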
73 | self.debug = self._log.debug 74 | self.info = lambda **kwargs: self._log.info(_format(**kwargs)) 75 | 76 | def load_config(self, model_path): 77 | log_dir = Path(model_path).parent 78 | 79 | with open(str(log_dir / 'config.json'), 'r') as f: 80 | return json.load(f) 81 | 82 | def save_config(self, config_dict): 83 | def _process(x): 84 | for key, val in x.items(): 85 | if isinstance(val, dict): 86 | _process(val) 87 | elif not isinstance(val, float) and not isinstance(val, int): 88 | x[key] = str(val) 89 | 90 | config = copy.deepcopy(config_dict) 91 | 92 | _process(config) 93 | 94 | with open(str(self.log_dir / 'config.json'), 'w+') as f: 95 | json.dump(config, f, indent=4, sort_keys=True) 96 | 97 | def scalar(self, is_train=True, **kwargs): 98 | for k, v in sorted(kwargs.items()): 99 | key = (is_train, k) 100 | 101 | if key not in self.scalars: 102 | self.scalars[key] = list() 103 | 104 | self.scalars[key].append(v) 105 | 106 | def image(self, is_train=True, **kwargs): 107 | writer = self._writer_train if is_train else self._writer_val 108 | 109 | for k, v in sorted(kwargs.items()): 110 | writer.add_image(k, _preprocess_image(v), self.epoch) 111 | 112 | def end_epoch(self, net=None): 113 | for (is_train, k), v in self.scalars.items(): 114 | info = OrderedDict() 115 | info['%s_%s' % ('train' if is_train else 'val', k)] = np.mean(v) 116 | info['std'] = np.std(v, dtype=np.float32) 117 | info['min'] = np.min(v) 118 | info['max'] = np.max(v) 119 | info['n'] = len(v) 120 | 121 | self.info(**info) 122 | 123 | if is_train: 124 | self._writer_train.add_scalar(k, np.mean(v), self.epoch) 125 | else: 126 | self._writer_val.add_scalar(k, np.mean(v), self.epoch) 127 | 128 | self.scalars.clear() 129 | 130 | if net is not None: 131 | if self.epoch % 10 == 0: 132 | torch.save(net.state_dict(), str(self.log_dir / ('model_%03d.t7' % self.epoch))) 133 | 134 | torch.save(net.state_dict(), str(self.log_dir / 'latest.t7')) 135 | 136 | self.epoch += 1 137 | -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/test.py: -------------------------------------------------------------------------------- 1 | import bz_utils.video_maker as video_maker 2 | 3 | 4 | import time 5 | import numpy as np 6 | 7 | 8 | tmp = np.zeros((256, 128, 3), dtype=np.uint8) 9 | video_maker.init() 10 | 11 | 12 | for i in range(256): 13 | tmp[:,:,0] += 1 14 | video_maker.add(tmp) 15 | 16 | if i == 100: 17 | break 18 | -------------------------------------------------------------------------------- /bird_view/utils/bz_utils/video_maker.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import cv2 5 | 6 | 7 | DEFAULT_DIR = str(Path.home().joinpath('debug')) 8 | DEFAULT_PATH = 'video' 9 | 10 | 11 | def _create_writer(video_path, height, width, fps=20): 12 | return cv2.VideoWriter( 13 | '%s.avi' % video_path, cv2.VideoWriter_fourcc(*'XVID'), fps, (width, height)) 14 | 15 | 16 | def show(name, image): 17 | if image.ndim == 3: 18 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 19 | 20 | cv2.imshow(name, image) 21 | cv2.waitKey(1) 22 | 23 | 24 | class Dummy(object): 25 | video = None 26 | video_path = None 27 | 28 | @classmethod 29 | def init(cls, save_dir=None, save_path=None): 30 | if cls.video is not None: 31 | cls.video.release() 32 | 33 | save_dir = Path(save_dir or DEFAULT_DIR) 34 | save_dir.mkdir(exist_ok=True, parents=True) 35 | save_path = save_path or DEFAULT_PATH 36 | 37 | cls.video = 
None 38 | cls.video_path = str(save_dir.joinpath(save_path)) 39 | 40 | cv2.destroyAllWindows() 41 | 42 | 43 | @classmethod 44 | def add(cls, image): 45 | if image.ndim == 3: 46 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 47 | 48 | if cls.video is None: 49 | cls.video = _create_writer(cls.video_path, image.shape[0], image.shape[1]) 50 | 51 | cls.video.write(image) 52 | 53 | 54 | init = Dummy.init 55 | add = Dummy.add 56 | -------------------------------------------------------------------------------- /bird_view/utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .birdview_lmdb import BirdViewDataset, load_birdview_data 2 | from .image_lmdb import ImageDataset, load_image_data 3 | # from .birdview_lmdb -------------------------------------------------------------------------------- /bird_view/utils/datasets/birdview_lmdb.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import lmdb 5 | import os 6 | import glob 7 | import numpy as np 8 | import cv2 9 | 10 | from torch.utils.data import Dataset, DataLoader 11 | from torchvision import transforms 12 | from utils.image_utils import draw_msra_gaussian, gaussian_radius 13 | from utils.carla_utils import visualize_birdview 14 | 15 | import math 16 | import random 17 | 18 | PIXEL_OFFSET = 10 19 | 20 | 21 | def world_to_pixel( 22 | x,y,ox,oy,ori_ox, ori_oy, 23 | pixels_per_meter=5, offset=(-80,160), size=320, angle_jitter=15): 24 | pixel_dx, pixel_dy = (x-ox)*pixels_per_meter, (y-oy)*pixels_per_meter 25 | 26 | pixel_x = pixel_dx*ori_ox+pixel_dy*ori_oy 27 | pixel_y = -pixel_dx*ori_oy+pixel_dy*ori_ox 28 | 29 | pixel_x = 320-pixel_x 30 | 31 | return np.array([pixel_x, pixel_y]) + offset 32 | 33 | 34 | class BirdViewDataset(Dataset): 35 | def __init__( 36 | self, dataset_path, 37 | img_size=320, crop_size=192, gap=5, n_step=5, 38 | crop_x_jitter=5, crop_y_jitter=5, angle_jitter=5, 39 | down_ratio=4, gaussian_radius=1.0, max_frames=None): 40 | 41 | # These typically don't change. 42 | self.img_size = img_size 43 | self.crop_size = crop_size 44 | self.down_ratio = down_ratio 45 | self.gap = gap 46 | self.n_step = n_step 47 | 48 | self.max_frames = max_frames 49 | 50 | self.crop_x_jitter = crop_x_jitter 51 | self.crop_y_jitter = crop_y_jitter 52 | self.angle_jitter = angle_jitter 53 | 54 | self.gaussian_radius = gaussian_radius 55 | 56 | self._name_map = {} 57 | self.file_map = {} 58 | self.idx_map = {} 59 | 60 | self.bird_view_transform = transforms.ToTensor() 61 | 62 | n_episodes = 0 63 | 64 | for full_path in sorted(glob.glob('%s/**' % dataset_path), reverse=True): 65 | txn = lmdb.open( 66 | full_path, 67 | max_readers=1, readonly=True, 68 | lock=False, readahead=False, meminit=False).begin(write=False) 69 | 70 | n = int(txn.get('len'.encode())) - self.gap * self.n_step 71 | offset = len(self._name_map) 72 | 73 | for i in range(n): 74 | if max_frames and len(self) >= max_frames: 75 | break 76 | 77 | self._name_map[offset+i] = full_path 78 | self.file_map[offset+i] = txn 79 | self.idx_map[offset+i] = i 80 | 81 | n_episodes += 1 82 | 83 | if max_frames and len(self) >= max_frames: 84 | break 85 | 86 | print('%s: %d frames, %d episodes.' 
% (dataset_path, len(self), n_episodes)) 87 | 88 | def __len__(self): 89 | return len(self.file_map) 90 | 91 | def __getitem__(self, idx): 92 | lmdb_txn = self.file_map[idx] 93 | index = self.idx_map[idx] 94 | 95 | bird_view = np.frombuffer(lmdb_txn.get(('birdview_%04d'%index).encode()), np.uint8).reshape(320,320,7) 96 | measurement = np.frombuffer(lmdb_txn.get(('measurements_%04d'%index).encode()), np.float32) 97 | rgb_image = None 98 | 99 | ox, oy, oz, ori_ox, ori_oy, vx, vy, vz, ax, ay, az, cmd, steer, throttle, brake, manual, gear = measurement 100 | speed = np.linalg.norm([vx,vy,vz]) 101 | 102 | oangle = np.arctan2(ori_oy, ori_ox) 103 | delta_angle = np.random.randint(-self.angle_jitter,self.angle_jitter+1) 104 | dx = np.random.randint(-self.crop_x_jitter,self.crop_x_jitter+1) 105 | dy = np.random.randint(0,self.crop_y_jitter+1) - PIXEL_OFFSET 106 | 107 | o_camx = ox + ori_ox*2 108 | o_camy = oy + ori_oy*2 109 | 110 | pixel_ox = 160 111 | pixel_oy = 260 112 | 113 | bird_view = cv2.warpAffine( 114 | bird_view, 115 | cv2.getRotationMatrix2D((pixel_ox,pixel_oy), delta_angle, 1.0), 116 | bird_view.shape[1::-1], flags=cv2.INTER_LINEAR) 117 | 118 | # random cropping 119 | center_x, center_y = 160, 260-self.crop_size//2 120 | bird_view = bird_view[ 121 | dy+center_y-self.crop_size//2:dy+center_y+self.crop_size//2, 122 | dx+center_x-self.crop_size//2:dx+center_x+self.crop_size//2] 123 | 124 | angle = np.arctan2(ori_oy, ori_ox) + np.deg2rad(delta_angle) 125 | ori_ox, ori_oy = np.cos(angle), np.sin(angle) 126 | 127 | locations = [] 128 | orientations = [] 129 | 130 | for dt in range(self.gap, self.gap*(self.n_step+1), self.gap): 131 | lmdb_txn = self.file_map[idx] 132 | index =self.idx_map[idx]+dt 133 | 134 | f_measurement = np.frombuffer(lmdb_txn.get(("measurements_%04d"%index).encode()), np.float32) 135 | x, y, z, ori_x, ori_y = f_measurement[:5] 136 | 137 | pixel_y, pixel_x = world_to_pixel(x,y,ox,oy,ori_ox,ori_oy,size=self.img_size) 138 | pixel_x = pixel_x - (self.img_size-self.crop_size)//2 139 | pixel_y = self.crop_size - (self.img_size-pixel_y)+70 140 | 141 | pixel_x -= dx 142 | pixel_y -= dy 143 | 144 | angle = np.arctan2(ori_y, ori_x) - np.arctan2(ori_oy, ori_ox) 145 | ori_dx, ori_dy = np.cos(angle), np.sin(angle) 146 | 147 | locations.append([pixel_x, pixel_y]) 148 | orientations.append([ori_dx, ori_dy]) 149 | 150 | bird_view = self.bird_view_transform(bird_view) 151 | 152 | # Create mask 153 | output_size = self.crop_size // self.down_ratio 154 | heatmap_mask = np.zeros((self.n_step, output_size, output_size), dtype=np.float32) 155 | regression_offset = np.zeros((self.n_step,2), np.float32) 156 | indices = np.zeros((self.n_step), dtype=np.int64) 157 | 158 | for i, (pixel_x, pixel_y) in enumerate(locations): 159 | center = np.array( 160 | [pixel_x / self.down_ratio, pixel_y / self.down_ratio], 161 | dtype=np.float32) 162 | center = np.clip(center, 0, output_size-1) 163 | center_int = np.rint(center) 164 | 165 | draw_msra_gaussian(heatmap_mask[i], center_int, self.gaussian_radius) 166 | regression_offset[i] = center - center_int 167 | indices[i] = center_int[1] * output_size + center_int[0] 168 | 169 | return bird_view, np.array(locations), cmd, speed 170 | 171 | 172 | 173 | class BiasedBirdViewDataset(BirdViewDataset): 174 | def __init__(self, dataset_path, left_ratio=0.25, right_ratio=0.25, straight_ratio=0.25, **kwargs): 175 | super().__init__(dataset_path, **kwargs) 176 | 177 | print ("Doing biased: %.2f,%.2f,%.2f"%(left_ratio, right_ratio, straight_ratio)) 178 | 179 | 
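# Commands are encoded 1: LEFT, 2: RIGHT, 3: STRAIGHT, 4: FOLLOW (see the
# _command map in data_collector.py); the choices and weights below re-balance
# sampling so turning frames are not drowned out by lane-following frames.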
self._choices = [1,2,3,4] 180 | self._weights = [left_ratio,right_ratio,straight_ratio,1-left_ratio-right_ratio-straight_ratio] 181 | # Separately save data on different cmd 182 | self.cmd_map = { i : set([]) for i in range(1,5)} 183 | 184 | for idx in range(len(self.file_map)): 185 | lmdb_txn = self.file_map[idx] 186 | index = self.idx_map[idx] 187 | 188 | measurement = np.frombuffer(lmdb_txn.get(('measurements_%04d'%index).encode()), np.float32) 189 | ox, oy, oz, ori_ox, ori_oy, vx, vy, vz, ax, ay, az, cmd, steer, throttle, brake, manual, gear = measurement 190 | speed = np.linalg.norm([vx,vy,vz]) 191 | 192 | if cmd != 4 and speed > 1.0: 193 | self.cmd_map[cmd].add(idx) 194 | else: 195 | self.cmd_map[4].add(idx) 196 | 197 | for cmd, nums in self.cmd_map.items(): 198 | print (cmd, len(nums)) 199 | 200 | def __getitem__(self, idx): 201 | cmd = np.random.choice(self._choices, p=self._weights) 202 | [_idx] = random.sample(self.cmd_map[cmd], 1) 203 | return super(BiasedBirdViewDataset, self).__getitem__(_idx) 204 | 205 | 206 | 207 | def load_birdview_data( 208 | dataset_dir, 209 | batch_size=32, num_workers=0, shuffle=True, 210 | crop_x_jitter=0, crop_y_jitter=0, angle_jitter=0, n_step=5, gap=5, 211 | max_frames=None, cmd_biased=False): 212 | if cmd_biased: 213 | dataset_cls = BiasedBirdViewDataset 214 | else: 215 | dataset_cls = BirdViewDataset 216 | 217 | dataset = dataset_cls( 218 | dataset_dir, 219 | crop_x_jitter=crop_x_jitter, 220 | crop_y_jitter=crop_y_jitter, 221 | angle_jitter=angle_jitter, 222 | n_step=n_step, 223 | gap=gap, 224 | max_frames=max_frames, 225 | ) 226 | 227 | return DataLoader( 228 | dataset, 229 | batch_size=batch_size, num_workers=num_workers, 230 | shuffle=shuffle, drop_last=True, pin_memory=True) 231 | 232 | 233 | class Wrap(Dataset): 234 | def __init__(self, data, batch_size, samples): 235 | self.data = data 236 | self.batch_size = batch_size 237 | self.samples = samples 238 | 239 | def __len__(self): 240 | return self.batch_size * self.samples 241 | 242 | def __getitem__(self, i): 243 | return self.data[np.random.randint(len(self.data))] 244 | 245 | 246 | def _dataloader(data, batch_size, num_workers): 247 | return DataLoader( 248 | data, batch_size=batch_size, num_workers=num_workers, 249 | shuffle=True, drop_last=True, pin_memory=True) 250 | 251 | 252 | def get_birdview( 253 | dataset_dir, 254 | batch_size=32, num_workers=8, shuffle=True, 255 | crop_x_jitter=0, crop_y_jitter=0, angle_jitter=0, n_step=5, gap=5, 256 | max_frames=None, cmd_biased=False): 257 | 258 | def make_dataset(dir_name, is_train): 259 | _dataset_dir = str(Path(dataset_dir) / dir_name) 260 | _samples = 1000 if is_train else 10 261 | _crop_x_jitter = crop_x_jitter if is_train else 0 262 | _crop_y_jitter = crop_y_jitter if is_train else 0 263 | _angle_jitter = angle_jitter if is_train else 0 264 | _max_frames = max_frames if is_train else None 265 | _num_workers = num_workers if is_train else 0 266 | 267 | if is_train and cmd_biased: 268 | dataset_cls = BiasedBirdViewDataset 269 | else: 270 | dataset_cls = BirdViewDataset 271 | 272 | data = dataset_cls( 273 | _dataset_dir, gap=gap, n_step=n_step, 274 | crop_x_jitter=_crop_x_jitter, crop_y_jitter=_crop_y_jitter, 275 | angle_jitter=_angle_jitter, 276 | max_frames=_max_frames) 277 | data = Wrap(data, batch_size, _samples) 278 | data = _dataloader(data, batch_size, _num_workers) 279 | 280 | return data 281 | 282 | train = make_dataset('train', True) 283 | val = make_dataset('val', False) 284 | 285 | return train, val 286 | 
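A minimal usage sketch for the loaders above (a sketch, not part of the original file; the dataset root path is illustrative, and the import assumes bird_view/ is on sys.path, matching this file's own `from utils.image_utils import ...` style):

# Sketch: build train/val birdview loaders and inspect one batch.
from utils.datasets.birdview_lmdb import get_birdview

train_loader, val_loader = get_birdview(
    '/path/to/birdview_data',   # illustrative; expects train/ and val/ subdirs of LMDB episodes
    batch_size=32, num_workers=8,
    crop_x_jitter=5, crop_y_jitter=5, angle_jitter=5)

for bird_view, locations, cmd, speed in train_loader:
    # bird_view: (B, 7, 192, 192) float tensor of birdview map channels;
    # locations: (B, 5, 2) future waypoints in crop pixels;
    # cmd, speed: per-frame command and speed scalars.
    break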
-------------------------------------------------------------------------------- /bird_view/utils/datasets/image_lmdb.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import lmdb 5 | import os 6 | import glob 7 | import numpy as np 8 | import cv2 9 | 10 | from torch.utils.data import Dataset, DataLoader 11 | from torchvision import transforms 12 | 13 | import math 14 | import random 15 | 16 | import augmenter 17 | 18 | PIXEL_OFFSET = 10 19 | PIXELS_PER_METER = 5 20 | 21 | def world_to_pixel(x,y,ox,oy,ori_ox, ori_oy, offset=(-80,160), size=320, angle_jitter=15): 22 | pixel_dx, pixel_dy = (x-ox)*PIXELS_PER_METER, (y-oy)*PIXELS_PER_METER 23 | 24 | pixel_x = pixel_dx*ori_ox+pixel_dy*ori_oy 25 | pixel_y = -pixel_dx*ori_oy+pixel_dy*ori_ox 26 | 27 | pixel_x = 320-pixel_x 28 | 29 | return np.array([pixel_x, pixel_y]) + offset 30 | 31 | 32 | def project_to_image(pixel_x, pixel_y, tran=[0.,0.,0.], rot=[0.,0.,0.], fov=90, w=384, h=160, camera_world_z=1.4, crop_size=192): 33 | # Apply fixed offset tp pixel_y 34 | pixel_y -= 2*PIXELS_PER_METER 35 | 36 | pixel_y = crop_size - pixel_y 37 | pixel_x = pixel_x - crop_size/2 38 | 39 | world_x = pixel_x / PIXELS_PER_METER 40 | world_y = pixel_y / PIXELS_PER_METER 41 | 42 | xyz = np.zeros((1,3)) 43 | xyz[0,0] = world_x 44 | xyz[0,1] = camera_world_z 45 | xyz[0,2] = world_y 46 | 47 | f = w /(2 * np.tan(fov * np.pi / 360)) 48 | A = np.array([ 49 | [f, 0., w/2], 50 | [0, f, h/2], 51 | [0., 0., 1.] 52 | ]) 53 | image_xy, _ = cv2.projectPoints(xyz, np.array(tran), np.array(rot), A, None) 54 | image_xy[...,0] = np.clip(image_xy[...,0], 0, w) 55 | image_xy[...,1] = np.clip(image_xy[...,1], 0, h) 56 | 57 | return image_xy[0,0] 58 | 59 | class ImageDataset(Dataset): 60 | def __init__(self, 61 | dataset_path, 62 | rgb_shape=(160,384,3), 63 | img_size=320, 64 | crop_size=192, 65 | gap=5, 66 | n_step=5, 67 | gaussian_radius=1., 68 | down_ratio=4, 69 | # rgb_mean=[0.29813555, 0.31239682, 0.33620676], 70 | # rgb_std=[0.0668446, 0.06680295, 0.07329721], 71 | augment_strategy=None, 72 | batch_read_number=819200, 73 | batch_aug=1, 74 | ): 75 | self._name_map = {} 76 | 77 | self.file_map = {} 78 | self.idx_map = {} 79 | 80 | self.bird_view_transform = transforms.ToTensor() 81 | self.rgb_transform = transforms.ToTensor() 82 | 83 | self.rgb_shape = rgb_shape 84 | self.img_size = img_size 85 | self.crop_size = crop_size 86 | 87 | self.gap = gap 88 | self.n_step = n_step 89 | self.down_ratio = down_ratio 90 | self.batch_aug = batch_aug 91 | 92 | self.gaussian_radius = gaussian_radius 93 | 94 | print ("augment with ", augment_strategy) 95 | if augment_strategy is not None and augment_strategy != 'None': 96 | self.augmenter = getattr(augmenter, augment_strategy) 97 | else: 98 | self.augmenter = None 99 | 100 | count = 0 101 | for full_path in glob.glob('%s/**'%dataset_path): 102 | # hdf5_file = h5py.File(full_path, 'r', libver='latest', swmr=True) 103 | lmdb_file = lmdb.open(full_path, 104 | max_readers=1, 105 | readonly=True, 106 | lock=False, 107 | readahead=False, 108 | meminit=False 109 | ) 110 | 111 | txn = lmdb_file.begin(write=False) 112 | 113 | N = int(txn.get('len'.encode())) - self.gap*self.n_step 114 | 115 | for _ in range(N): 116 | self._name_map[_+count] = full_path 117 | self.file_map[_+count] = txn 118 | self.idx_map[_+count] = _ 119 | 120 | count += N 121 | 122 | print ("Finished loading %s. 
Length: %d"%(dataset_path, count)) 123 | self.batch_read_number = batch_read_number 124 | 125 | def __len__(self): 126 | return len(self.file_map) 127 | 128 | def __getitem__(self, idx): 129 | 130 | lmdb_txn = self.file_map[idx] 131 | index = self.idx_map[idx] 132 | 133 | bird_view = np.frombuffer(lmdb_txn.get(('birdview_%04d'%index).encode()), np.uint8).reshape(320,320,7) 134 | measurement = np.frombuffer(lmdb_txn.get(('measurements_%04d'%index).encode()), np.float32) 135 | rgb_image = np.fromstring(lmdb_txn.get(('rgb_%04d'%index).encode()), np.uint8).reshape(160,384,3) 136 | 137 | if self.augmenter: 138 | rgb_images = [self.augmenter(self.batch_read_number).augment_image(rgb_image) for i in range(self.batch_aug)] 139 | else: 140 | rgb_images = [rgb_image for i in range(self.batch_aug)] 141 | 142 | if self.batch_aug == 1: 143 | rgb_images = rgb_images[0] 144 | 145 | ox, oy, oz, ori_ox, ori_oy, vx, vy, vz, ax, ay, az, cmd, steer, throttle, brake, manual, gear = measurement 146 | speed = np.linalg.norm([vx,vy,vz]) 147 | 148 | oangle = np.arctan2(ori_oy, ori_ox) 149 | delta_angle = 0 150 | dx = 0 151 | dy = -PIXEL_OFFSET 152 | 153 | pixel_ox = 160 154 | pixel_oy = 260 155 | 156 | rot_mat = cv2.getRotationMatrix2D((pixel_ox,pixel_oy), delta_angle, 1.0) 157 | bird_view = cv2.warpAffine(bird_view, rot_mat, bird_view.shape[1::-1], flags=cv2.INTER_LINEAR) 158 | 159 | # random cropping 160 | center_x, center_y = 160, 260-self.crop_size//2 161 | 162 | 163 | bird_view = bird_view[dy+center_y-self.crop_size//2:dy+center_y+self.crop_size//2,dx+center_x-self.crop_size//2:dx+center_x+self.crop_size//2] 164 | 165 | angle = np.arctan2(ori_oy, ori_ox) + np.deg2rad(delta_angle) 166 | ori_ox, ori_oy = np.cos(angle), np.sin(angle) 167 | 168 | locations = [] 169 | 170 | for dt in range(self.gap, self.gap*(self.n_step+1), self.gap): 171 | 172 | lmdb_txn = self.file_map[idx] 173 | index =self.idx_map[idx]+dt 174 | 175 | f_measurement = np.frombuffer(lmdb_txn.get(("measurements_%04d"%index).encode()), np.float32) 176 | x, y, z, ori_x, ori_y = f_measurement[:5] 177 | 178 | pixel_y, pixel_x = world_to_pixel(x,y,ox,oy,ori_ox,ori_oy,size=self.img_size) 179 | pixel_x = pixel_x - (self.img_size-self.crop_size)//2 180 | pixel_y = self.crop_size - (self.img_size-pixel_y)+70 181 | 182 | pixel_x -= dx 183 | pixel_y -= dy 184 | 185 | # Coordinate transform 186 | 187 | locations.append([pixel_x, pixel_y]) 188 | 189 | if self.batch_aug == 1: 190 | rgb_images = self.rgb_transform(rgb_images) 191 | else: 192 | # if len() 193 | # import pdb; pdb.set_trace() 194 | rgb_images = torch.stack([self.rgb_transform(img) for img in rgb_images]) 195 | bird_view = self.bird_view_transform(bird_view) 196 | 197 | # Create mask 198 | # output_h = self.rgb_shape[0] // self.down_ratio 199 | # output_w = self.rgb_shape[1] // self.down_ratio 200 | # heatmap_mask = np.zeros((self.n_step, output_h, output_w), dtype=np.float32) 201 | # regression_offset = np.zeros((self.n_step,2), np.float32) 202 | # indices = np.zeros((self.n_step), dtype=np.int64) 203 | 204 | # image_locations = [] 205 | 206 | # for i, (pixel_x, pixel_y) in enumerate(locations): 207 | # image_pixel_x, image_pixel_y = project_to_image(pixel_x, pixel_y) 208 | 209 | # image_locations.append([image_pixel_x, image_pixel_y]) 210 | 211 | # center = np.array([image_pixel_x / self.down_ratio, image_pixel_y / self.down_ratio], dtype=np.float32) 212 | # center = np.clip(center, (0,0), (output_w-1, output_h-1)) 213 | 214 | # center_int = np.rint(center) 215 | 216 | # # 
draw_msra_gaussian(heatmap_mask[i], center_int, self.gaussian_radius) 217 | # regression_offset[i] = center - center_int 218 | # indices[i] = center_int[1] * output_w + center_int[0] 219 | 220 | self.batch_read_number += 1 221 | 222 | return rgb_images, bird_view, np.array(locations), cmd, speed 223 | 224 | 225 | def load_image_data(dataset_path, 226 | batch_size=32, 227 | num_workers=8, 228 | shuffle=True, 229 | n_step=5, 230 | gap=10, 231 | augment=None, 232 | **kwargs 233 | # rgb_mean=[0.29813555, 0.31239682, 0.33620676], 234 | # rgb_std=[0.0668446, 0.06680295, 0.07329721], 235 | ): 236 | 237 | dataset = ImageDataset( 238 | dataset_path, 239 | n_step=n_step, 240 | gap=gap, 241 | augment_strategy=augment, 242 | **kwargs, 243 | # rgb_mean=rgb_mean, 244 | # rgb_std=rgb_std, 245 | ) 246 | 247 | return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle, drop_last=True, pin_memory=True) 248 | 249 | 250 | class Wrap(Dataset): 251 | def __init__(self, data, batch_size, samples): 252 | self.data = data 253 | self.batch_size = batch_size 254 | self.samples = samples 255 | 256 | def __len__(self): 257 | return self.batch_size * self.samples 258 | 259 | def __getitem__(self, i): 260 | return self.data[np.random.randint(len(self.data))] 261 | 262 | 263 | def _dataloader(data, batch_size, num_workers): 264 | return DataLoader( 265 | data, batch_size=batch_size, num_workers=num_workers, 266 | shuffle=True, drop_last=True, pin_memory=True) 267 | 268 | 269 | def get_image( 270 | dataset_dir, 271 | batch_size=32, num_workers=0, shuffle=True, augment=None, 272 | n_step=5, gap=5, batch_aug=1): 273 | 274 | # import pdb; pdb.set_trace() 275 | 276 | def make_dataset(dir_name, is_train): 277 | _dataset_dir = str(Path(dataset_dir) / dir_name) 278 | _samples = 1000 if is_train else 10 279 | _num_workers = num_workers if is_train else 0 280 | _batch_aug = batch_aug if is_train else 1 281 | _augment = augment if is_train else None 282 | 283 | data = ImageDataset( 284 | _dataset_dir, gap=gap, n_step=n_step, augment_strategy=_augment, batch_aug=_batch_aug) 285 | data = Wrap(data, batch_size, _samples) 286 | data = _dataloader(data, batch_size, _num_workers) 287 | 288 | return data 289 | 290 | train = make_dataset('train', True) 291 | val = make_dataset('val', False) 292 | 293 | return train, val 294 | 295 | 296 | if __name__ == '__main__': 297 | batch_size = 256 298 | import tqdm 299 | dataset = ImageDataset('/raid0/dian/carla_0.9.6_data/train') 300 | loader = _dataloader(dataset, batch_size=batch_size, num_workers=16) 301 | mean = [] 302 | for rgb_img, bird_view, locations, cmd, speed in tqdm.tqdm(loader): 303 | mean.append(rgb_img.mean(dim=(0,2,3)).numpy()) 304 | 305 | print ("Mean: ", np.mean(mean, axis=0)) 306 | print ("Std: ", np.std(mean, axis=0)*np.sqrt(batch_size)) 307 | -------------------------------------------------------------------------------- /bird_view/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def draw_msra_gaussian(heatmap, center, sigma): 5 | tmp_size = sigma * 3 6 | mu_x = int(center[0] + 0.5) 7 | mu_y = int(center[1] + 0.5) 8 | w, h = heatmap.shape[0], heatmap.shape[1] 9 | ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)] 10 | br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)] 11 | if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0: 12 | return heatmap 13 | size = 2 * tmp_size + 1 14 | x = np.arange(0, size, 1, np.float32) 15 | y = x[:, np.newaxis] 16 | x0 = 
y0 = size // 2 17 | g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2)) 18 | g_x = max(0, -ul[0]), min(br[0], h) - ul[0] 19 | g_y = max(0, -ul[1]), min(br[1], w) - ul[1] 20 | img_x = max(0, ul[0]), min(br[0], h) 21 | img_y = max(0, ul[1]), min(br[1], w) 22 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum( 23 | heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]], 24 | g[g_y[0]:g_y[1], g_x[0]:g_x[1]]) 25 | return heatmap 26 | 27 | def gaussian_radius(det_size, min_overlap=0.7): 28 | height, width = det_size 29 | 30 | a1 = 1 31 | b1 = (height + width) 32 | c1 = width * height * (1 - min_overlap) / (1 + min_overlap) 33 | sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1) 34 | r1 = (b1 + sq1) / 2 35 | 36 | a2 = 4 37 | b2 = 2 * (height + width) 38 | c2 = (1 - min_overlap) * width * height 39 | sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2) 40 | r2 = (b2 + sq2) / 2 41 | 42 | a3 = 4 * min_overlap 43 | b3 = -2 * min_overlap * (height + width) 44 | c3 = (min_overlap - 1) * width * height 45 | sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3) 46 | r3 = (b3 + sq3) / 2 47 | return min(r1, r2, r3) 48 | -------------------------------------------------------------------------------- /bird_view/utils/logger.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from loguru import logger 4 | from tensorboardX import SummaryWriter 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision.utils as tv_utils 9 | 10 | 11 | def _preprocess_image(x): 12 | if isinstance(x, torch.Tensor): 13 | if x.requires_grad: 14 | x = x.detach() 15 | 16 | if x.dim() == 3: 17 | if x.shape[0] == 3: 18 | x = x.unsqueeze(0) 19 | else: 20 | x = x.unsqueeze(1) 21 | 22 | # x = torch.nn.functional.interpolate(x, 128, mode='nearest') 23 | x = tv_utils.make_grid(x, padding=16, normalize=True, nrow=4) 24 | x = x.cpu().numpy() 25 | 26 | return x 27 | 28 | 29 | def _format(**kwargs): 30 | result = list() 31 | 32 | for k, v in kwargs.items(): 33 | if isinstance(v, float): 34 | result.append('%s: %.2f' % (k, v)) 35 | else: 36 | result.append('%s: %s' % (k, v)) 37 | 38 | return '\t'.join(result) 39 | 40 | 41 | class Wrapper(object): 42 | def __init__(self, log): 43 | self.epoch = 0 44 | self._log = log 45 | self._writer = None 46 | self.scalars = OrderedDict() 47 | 48 | self.info = lambda **kwargs: self._log.info(_format(**kwargs)) 49 | self.debug = self._log.debug 50 | 51 | def init(self, log_path): 52 | for i in self._log._handlers: 53 | self._log.remove(i) 54 | 55 | self._writer = SummaryWriter(log_path) 56 | self._log.add( 57 | '%s/log.txt' % log_path, 58 | format='{time:MM/DD/YY HH:mm:ss} {level}\t{message}') 59 | 60 | def scalar(self, **kwargs): 61 | for k, v in sorted(kwargs.items()): 62 | if k not in self.scalars: 63 | self.scalars[k] = list() 64 | 65 | self.scalars[k].append(v) 66 | 67 | def image(self, **kwargs): 68 | for k, v in sorted(kwargs.items()): 69 | self._writer.add_image(k, _preprocess_image(v), self.epoch) 70 | 71 | def end_epoch(self): 72 | for k, v in self.scalars.items(): 73 | info = OrderedDict() 74 | info[k] = np.mean(v) 75 | info['std'] = float(np.std(v, dtype=np.float32)) 76 | info['min'] = np.min(v) 77 | info['max'] = np.max(v) 78 | info['n'] = len(v) 79 | 80 | self.info(**info) 81 | self._writer.add_scalar(k, np.mean(v), self.epoch) 82 | 83 | self.epoch = self.epoch + 1 84 | self.scalars = OrderedDict() 85 | 86 | self.info(epoch=self.epoch) 87 | 88 | 89 | log = Wrapper(logger) 90 | 
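A minimal sketch of driving the Wrapper above from a training script (a sketch, not part of the original file; the log directory is illustrative, and the import assumes bird_view/utils is importable as `utils`, as elsewhere in this package):

# Sketch: initialize once, accumulate scalars per batch, flush once per epoch.
from utils.logger import log

log.init('/tmp/example_run')   # illustrative log_path; creates log.txt and a tensorboard writer
for loss in [0.9, 0.7, 0.5]:   # stand-in for per-batch losses
    log.scalar(loss=loss)      # buffered until end_epoch()
log.end_epoch()                # logs mean/std/min/max of each scalar and advances the epoch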
-------------------------------------------------------------------------------- /bird_view/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class SummaryWriter: 4 | def __init__(self, *args, **kwargs): 5 | print("tensorboardX not found. You need to install it to use the SummaryWriter.") 6 | print("try: pip3 install tensorboardX") 7 | raise ImportError 8 | 9 | class UnNormalize(object): 10 | def __init__(self, mean=[0.2929, 0.3123, 0.3292], std=[0.0762, 0.0726, 0.0801]): 11 | self.mean = mean 12 | self.std = std 13 | 14 | def __call__(self, tensor): 15 | """ 16 | Args: 17 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 18 | Returns: 19 | Tensor: Normalized image. 20 | """ 21 | new_tensor = tensor.new(*tensor.size()) 22 | new_tensor[:, 0, :, :] = tensor[:, 0, :, :] * self.std[0] + self.mean[0] 23 | new_tensor[:, 1, :, :] = tensor[:, 1, :, :] * self.std[1] + self.mean[1] 24 | new_tensor[:, 2, :, :] = tensor[:, 2, :, :] * self.std[2] + self.mean[2] 25 | 26 | return new_tensor 27 | 28 | try: 29 | from tensorboardX import SummaryWriter 30 | except ImportError: 31 | pass 32 | 33 | def one_hot(x, num_digits=4, start=1): 34 | N = x.size()[0] 35 | x = x.long()[:,None]-start 36 | x = torch.clamp(x, 0, num_digits-1) 37 | y = torch.FloatTensor(N, num_digits) 38 | y.zero_() 39 | y.scatter_(1, x, 1) 40 | return y 41 | 42 | def viz_image_pred(rgb, pred_locations, gt_locations, dot_radius=2, unnormalizer=None): 43 | if unnormalizer: 44 | rgb_viz = unnormalizer(rgb.clone()) 45 | else: 46 | rgb_viz = rgb.clone() 47 | for i, step_locations in enumerate(gt_locations.int()): 48 | for x, y in step_locations: 49 | rgb_viz[i,0,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 50 | rgb_viz[i,1,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 51 | rgb_viz[i,2,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 1 52 | 53 | for i, step_locations in enumerate(pred_locations.int()): 54 | for x, y in step_locations: 55 | rgb_viz[i,0,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 1 56 | rgb_viz[i,1,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 57 | rgb_viz[i,2,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 58 | 59 | return rgb_viz 60 | 61 | def viz_birdview_pred(bird_view_viz, pred_locations, gt_locations, dot_radius=2): 62 | for i, step_locations in enumerate(gt_locations.int()): 63 | for x, y in step_locations: 64 | bird_view_viz[i,0,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 65 | bird_view_viz[i,1,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 66 | bird_view_viz[i,2,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 1 67 | 68 | for i, step_locations in enumerate(pred_locations.int()): 69 | for x, y in step_locations: 70 | bird_view_viz[i,0,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 1 71 | bird_view_viz[i,1,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 72 | bird_view_viz[i,2,y-dot_radius:y+dot_radius+1,x-dot_radius:x+dot_radius+1] = 0 73 | 74 | return bird_view_viz -------------------------------------------------------------------------------- /data_collector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Stores tuples of (birdview, measurements, rgb). 3 | 4 | Run from top level directory. 
5 | Sample usage - 6 | 7 | python3 data_collector.py \ 8 | --dataset_path $PWD/data \ 9 | --frame_skip 10 \ 10 | --frames_per_episode 1000 \ 11 | --n_episodes 100 \ 12 | --port 3000 \ 13 | --n_vehicles 0 \ 14 | --n_pedestrians 0 15 | """ 16 | import argparse 17 | 18 | from pathlib import Path 19 | 20 | import numpy as np 21 | import tqdm 22 | import lmdb 23 | 24 | import carla 25 | 26 | from benchmark import make_suite 27 | from bird_view.utils import carla_utils as cu 28 | from bird_view.utils import bz_utils as bu 29 | 30 | from bird_view.models.common import crop_birdview 31 | from bird_view.models.controller import PIDController 32 | from bird_view.models.roaming import RoamingAgentMine 33 | 34 | 35 | def _debug(observations, agent_debug): 36 | import cv2 37 | 38 | processed = cu.process(observations) 39 | 40 | control = observations['control'] 41 | control = [control.steer, control.throttle, control.brake] 42 | control = ' '.join(str('%.2f' % x).rjust(5, ' ') for x in control) 43 | real_control = observations['real_control'] 44 | real_control = [real_control.steer, real_control.throttle, real_control.brake] 45 | real_control = ' '.join(str('%.2f' % x).rjust(5, ' ') for x in real_control) 46 | 47 | canvas = np.uint8(observations['rgb']).copy() 48 | rows = [x * (canvas.shape[0] // 10) for x in range(10+1)] 49 | cols = [x * (canvas.shape[1] // 10) for x in range(10+1)] 50 | 51 | WHITE = (255, 255, 255) 52 | CROP_SIZE = 192 53 | X = 176 54 | Y = 192 // 2 55 | R = 2 56 | 57 | def _write(text, i, j): 58 | cv2.putText( 59 | canvas, text, (cols[j], rows[i]), 60 | cv2.FONT_HERSHEY_SIMPLEX, 0.4, WHITE, 1) 61 | 62 | _command = { 63 | 1: 'LEFT', 64 | 2: 'RIGHT', 65 | 3: 'STRAIGHT', 66 | 4: 'FOLLOW', 67 | }.get(int(observations['command']), '???') 68 | 69 | _write('Command: ' + _command, 1, 0) 70 | _write('Velocity: %.1f' % np.linalg.norm(observations['velocity']), 2, 0) 71 | _write('Real: %s' % real_control, -5, 0) 72 | _write('Control: %s' % control, -4, 0) 73 | 74 | r = 2 75 | birdview = cu.visualize_birdview(crop_birdview(processed['birdview'])) 76 | 77 | def _dot(x, y, color): 78 | x = int(x) 79 | y = int(y) 80 | birdview[176-r-x:176+r+1-x,96-r+y:96+r+1+y] = color 81 | 82 | _dot(0, 0, [255, 255, 255]) 83 | 84 | ox, oy = observations['orientation'] 85 | R = np.array([ 86 | [ox, oy], 87 | [-oy, ox]]) 88 | 89 | u = np.array(agent_debug['waypoint']) - np.array(agent_debug['vehicle']) 90 | u = R.dot(u[:2]) 91 | u = u * 4 92 | 93 | _dot(u[0], u[1], [255, 255, 255]) 94 | 95 | def _stick_together(a, b): 96 | h = min(a.shape[0], b.shape[0]) 97 | 98 | r1 = h / a.shape[0] 99 | r2 = h / b.shape[0] 100 | 101 | a = cv2.resize(a, (int(r1 * a.shape[1]), int(r1 * a.shape[0]))) 102 | b = cv2.resize(b, (int(r2 * b.shape[1]), int(r2 * b.shape[0]))) 103 | 104 | return np.concatenate([a, b], 1) 105 | 106 | full = _stick_together(canvas, birdview) 107 | 108 | bu.show_image('full', full) 109 | 110 | 111 | 112 | class NoisyAgent(RoamingAgentMine): 113 | """ 114 | Each parameter is in units of frames. 115 | State can be "drive" or "noise". 116 | """ 117 | def __init__(self, env, noise=None): 118 | super().__init__(env._player, resolution=1, threshold_before=7.5, threshold_after=5.) 
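# self.params maps state -> (duration in frames, next state). The active
# mapping keeps the agent in 'drive' permanently; the commented-out mapping
# below alternates 100 driving frames with 10 frames of random steering noise.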
119 | 120 | # self.params = {'drive': (100, 'noise'), 'noise': (10, 'drive')} 121 | self.params = {'drive': (100, 'drive')} 122 | 123 | self.steps = 0 124 | self.state = 'drive' 125 | self.noise_steer = 0 126 | self.last_throttle = 0 127 | self.noise_func = noise if noise else lambda: np.random.uniform(-0.25, 0.25) 128 | 129 | self.speed_control = PIDController(K_P=0.5, K_I=0.5/20, K_D=0.1) 130 | self.turn_control = PIDController(K_P=0.75, K_I=1.0/20, K_D=0.0) 131 | 132 | def run_step(self, observations): 133 | self.steps += 1 134 | 135 | last_status = self.state 136 | num_steps, next_state = self.params[self.state] 137 | real_control = super().run_step(observations) 138 | real_control.throttle *= max((1.0 - abs(real_control.steer)), 0.25) 139 | 140 | control = carla.VehicleControl() 141 | control.manual_gear_shift = False 142 | 143 | if self.state == 'noise': 144 | control.steer = self.noise_steer 145 | control.throttle = self.last_throttle 146 | else: 147 | control.steer = real_control.steer 148 | control.throttle = real_control.throttle 149 | control.brake = real_control.brake 150 | 151 | if self.steps == num_steps: 152 | self.steps = 0 153 | self.state = next_state 154 | self.noise_steer = self.noise_func() 155 | self.last_throttle = control.throttle 156 | 157 | self.debug = { 158 | 'waypoint': (self.waypoint.x, self.waypoint.y, self.waypoint.z), 159 | 'vehicle': (self.vehicle.x, self.vehicle.y, self.vehicle.z) 160 | } 161 | 162 | return control, self.road_option, last_status, real_control 163 | 164 | 165 | def get_episode(env, params): 166 | data = list() 167 | progress = tqdm.tqdm(range(params.frames_per_episode), desc='Frame') 168 | start, target = env.pose_tasks[np.random.randint(len(env.pose_tasks))] 169 | env_params = { 170 | 'weather': np.random.choice(list(cu.TRAIN_WEATHERS.keys())), 171 | 'start': start, 172 | 'target': target, 173 | 'n_pedestrians': params.n_pedestrians, 174 | 'n_vehicles': params.n_vehicles, 175 | } 176 | 177 | env.init(**env_params) 178 | env.success_dist = 5.0 179 | 180 | agent = NoisyAgent(env) 181 | agent.set_route(env._start_pose.location, env._target_pose.location) 182 | 183 | # Real loop. 
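# Each iteration advances the simulator frame_skip ticks, reads observations,
# lets the agent act, and stores the processed frame; the loop ends on success,
# collision, or after frames_per_episode frames.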
184 | while len(data) < params.frames_per_episode and not env.is_success() and not env.collided: 185 | for _ in range(params.frame_skip): 186 | env.tick() 187 | 188 | observations = env.get_observations() 189 | control, command, last_status, real_control = agent.run_step(observations) 190 | agent_debug = agent.debug 191 | env.apply_control(control) 192 | 193 | observations['command'] = command 194 | observations['control'] = control 195 | observations['real_control'] = real_control 196 | 197 | if not params.nodisplay: 198 | _debug(observations, agent_debug) 199 | 200 | observations['control'] = real_control 201 | processed = cu.process(observations) 202 | 203 | data.append(processed) 204 | 205 | progress.update(1) 206 | 207 | progress.close() 208 | 209 | if (not env.is_success() and not env.collided) or len(data) < 500: 210 | return None 211 | 212 | return data 213 | 214 | 215 | def main(params): 216 | 217 | save_dir = Path(params.dataset_path) 218 | save_dir.mkdir(parents=True, exist_ok=True) 219 | 220 | total = 0 221 | 222 | for i in tqdm.tqdm(range(params.n_episodes), desc='Episode'): 223 | with make_suite('FullTown01-v1', port=params.port, planner=params.planner) as env: 224 | filepath = save_dir.joinpath('%03d' % i) 225 | 226 | if filepath.exists(): 227 | continue 228 | 229 | data = None 230 | 231 | while data is None: 232 | data = get_episode(env, params) 233 | 234 | lmdb_env = lmdb.open(str(filepath), map_size=int(1e10)) 235 | n = len(data) 236 | 237 | with lmdb_env.begin(write=True) as txn: 238 | txn.put('len'.encode(), str(n).encode()) 239 | 240 | for i, x in enumerate(data): 241 | txn.put( 242 | ('rgb_%04d' % i).encode(), 243 | np.ascontiguousarray(x['rgb']).astype(np.uint8)) 244 | txn.put( 245 | ('birdview_%04d' % i).encode(), 246 | np.ascontiguousarray(x['birdview']).astype(np.uint8)) 247 | txn.put( 248 | ('measurements_%04d' % i).encode(), 249 | np.ascontiguousarray(x['measurements']).astype(np.float32)) 250 | txn.put( 251 | ('control_%04d' % i).encode(), 252 | np.ascontiguousarray(x['control']).astype(np.float32)) 253 | 254 | total += len(data) 255 | 256 | print('Total frames: %d' % total) 257 | 258 | 259 | if __name__ == '__main__': 260 | parser = argparse.ArgumentParser() 261 | parser.add_argument('--planner', type=str, choices=['old', 'new'], default='new') 262 | parser.add_argument('--dataset_path', required=True) 263 | parser.add_argument('--n_vehicles', type=int, default=100) 264 | parser.add_argument('--n_pedestrians', type=int, default=250) 265 | parser.add_argument('--n_episodes', type=int, default=50) 266 | parser.add_argument('--frames_per_episode', type=int, default=4000) 267 | parser.add_argument('--frame_skip', type=int, default=1) 268 | parser.add_argument('--nodisplay', action='store_true', default=False) 269 | parser.add_argument('--port', type=int, default=2000) 270 | 271 | params = parser.parse_args() 272 | 273 | main(params) 274 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: carla 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - blas=1.0=mkl 9 | - intel-openmp=2019.1=144 10 | - libgfortran-ng=7.3.0=hdf63c60_0 11 | - libpng=1.6.36=hbc83047_0 12 | - pandas=0.23.4=py35h04863e7_0 13 | - astor=0.7.1=py_0 14 | - c-ares=1.15.0=h14c3975_1001 15 | - cloudpickle=0.8.0=py_0 16 | - cycler=0.10.0=py_1 17 | - dask-core=1.1.3=py_0 18 | - dbus>=1.13.0 19 | - decorator=4.3.2=py_0 
20 | - expat=2.2.5=hf484d3e_1002 21 | - fontconfig=2.13.1=h2176d3f_1000 22 | - freetype>=2.9.1 23 | - gast=0.2.2=py_0 24 | - gettext=0.19.8.1=h9745a5d_1001 25 | - glew=2.0.0=hf484d3e_1002 26 | - glib=2.56.2=had28632_1001 27 | - gst-plugins-base>=1.12.5 28 | - gstreamer>=1.12.5 29 | - icu=58.2=hf484d3e_1000 30 | - imageio=2.3.0=py_1 31 | - jpeg=9c=h14c3975_1001 32 | - keras-applications=1.0.7=py_1 33 | - keras-preprocessing=1.0.9=py_1 34 | - kiwisolver=1.0.1=py35h2d50403_2 35 | - libffi=3.2.1=he1b5a44_1006 36 | - libglu=9.0.0=hf484d3e_1000 37 | - libgpuarray=0.7.6=h14c3975_1003 38 | - libiconv=1.15=h14c3975_1004 39 | - libtiff>=4.0.10 40 | - libuuid=2.32.1=h14c3975_1000 41 | - libxcb=1.13=h14c3975_1002 42 | - libxml2=2.9.8=h143f9aa_1005 43 | - mako=1.0.10=py_0 44 | - markdown=3.1.1=py_0 45 | - matplotlib>=2.2.3 46 | - ncurses=6.1=hf484d3e_1002 47 | - networkx=2.2=py_1 48 | - olefile=0.46=py_0 49 | - pcre>=8.41 50 | - pillow>=5.3.0 51 | - pthread-stubs=0.4=h14c3975_1001 52 | - pyparsing=2.3.1=py_0 53 | - pyqt>=5.6.0 54 | - python-dateutil=2.8.0=py_0 55 | - pytz=2018.9=py_0 56 | - pywavelets>=1.0.1 57 | - qt>=5.6.2 58 | - readline=7.0=hf8c457e_1001 59 | - scikit-image>=0.14.0 60 | - sip>=4.18.1 61 | - sqlite=3.28.0=h8b20d00_0 62 | - termcolor=1.1.0=py_2 63 | - tk=8.6.9=hed695b0_1002 64 | - toolz=0.9.0=py_1 65 | - tornado=5.1.1=py35h470a237_0 66 | - tqdm=4.32.1=py_0 67 | - werkzeug=0.15.4=py_0 68 | - xorg-kbproto=1.0.7=h14c3975_1002 69 | - xorg-libx11=1.6.7=h14c3975_1000 70 | - xorg-libxau=1.0.9=h14c3975_0 71 | - xorg-libxdmcp=1.1.2=h14c3975_1007 72 | - xorg-libxext=1.3.4=h516909a_0 73 | - xorg-xextproto=7.3.0=h14c3975_1002 74 | - xorg-xproto=7.0.31=h14c3975_1007 75 | - xz=5.2.4=h14c3975_1001 76 | - yaml=0.1.7=h14c3975_1001 77 | - zlib=1.2.11=h14c3975_1004 78 | - _libgcc_mutex=0.1=main 79 | - _tflow_select=2.1.0=gpu 80 | - absl-py=0.4.1=py35_0 81 | - backcall=0.1.0=py35_0 82 | - binutils_impl_linux-64=2.31.1=h6176602_1 83 | - binutils_linux-64=2.31.1=h6176602_7 84 | - ca-certificates=2019.5.15=0 85 | - certifi=2018.8.24=py35_1 86 | - cffi=1.11.5=py35he75722e_1 87 | - cudatoolkit=8.0=3 88 | - cudnn=7.1.3=cuda8.0_0 89 | - cupti=8.0.61=0 90 | - fastcache=1.0.2=py35h14c3975_2 91 | - gcc_impl_linux-64=7.3.0=habb00fd_1 92 | - gcc_linux-64=7.3.0=h553295d_7 93 | - gmp=6.1.2=h6c8ec71_1 94 | - gmpy2=2.0.8=py35hd0a1c9a_2 95 | - grpcio=1.12.1=py35hdbcaa40_0 96 | - gxx_impl_linux-64=7.3.0=hdf63c60_1 97 | - gxx_linux-64=7.3.0=h553295d_7 98 | - h5py=2.8.0=py35h989c5e5_3 99 | - hdf5=1.10.2=hba1933b_1 100 | - ipykernel=4.10.0=py35_0 101 | - ipython=6.5.0=py35_0 102 | - ipython_genutils=0.2.0=py35hc9e07d0_0 103 | - jedi=0.12.1=py35_0 104 | - jupyter_client=5.2.3=py35_0 105 | - jupyter_core=4.4.0=py35_0 106 | - libedit=3.1.20181209=hc058e9b_0 107 | - libgcc-ng=9.1.0=hdf63c60_0 108 | - libprotobuf=3.6.0=hdbcaa40_0 109 | - libsodium=1.0.16=h1bed415_0 110 | - libstdcxx-ng=8.2.0=hdf63c60_1 111 | - markupsafe=1.0=py35h14c3975_1 112 | - mkl=2018.0.3=1 113 | - mkl-service=1.1.2=py35h90e4bf4_5 114 | - mkl_fft=1.0.6=py35h7dd41cf_0 115 | - mkl_random=1.0.1=py35h4414c95_1 116 | - mpc=1.0.3=hec55b23_5 117 | - mpfr=3.1.5=h11a74b3_2 118 | - mpmath=1.0.0=py35_2 119 | - ninja=1.8.2=py35h6bb024c_1 120 | - numpy=1.15.2=py35h1d66e8a_0 121 | - numpy-base=1.15.2=py35h81de0dd_0 122 | - openssl=1.0.2s=h7b6447c_0 123 | - parso=0.3.1=py35_0 124 | - pexpect=4.6.0=py35_0 125 | - pickleshare=0.7.4=py35hd57304d_0 126 | - pip=10.0.1=py35_0 127 | - prompt_toolkit=1.0.15=py35hc09de7a_0 128 | - protobuf>=3.6.0 129 | - ptyprocess=0.6.0=py35_0 
130 | - pycparser=2.19=py35_0 131 | - pygments=2.2.0=py35h0f41973_0 132 | - pygpu=0.7.6=py35h3010b51_0 133 | - python=3.5.6=hc3d631a_0 134 | - pyyaml=3.13=py35h14c3975_0 135 | - pyzmq>=17.1.2 136 | - scipy=1.1.0=py35hfa4b5c9_1 137 | - setuptools=40.2.0=py35_0 138 | - simplegeneric=0.8.1=py35_2 139 | - six=1.11.0=py35_1 140 | - sympy=1.2=py35_0 141 | - tensorboard=1.10.0=py35hf484d3e_0 142 | - tensorflow=1.10.0=gpu_py35ha6119f3_0 143 | - tensorflow-base=1.10.0=gpu_py35h3435052_0 144 | - tensorflow-gpu=1.10.0=hf154084_0 145 | - theano=1.0.2=py35h6bb024c_0 146 | - traitlets=4.3.2=py35ha522a97_0 147 | - wcwidth=0.1.7=py35hcd08066_0 148 | - wheel=0.31.1=py35_0 149 | - zeromq=4.2.5=hf484d3e_1 150 | - cuda80=1.0=h205658b_0 151 | - pytorch=1.0.0=py3.5_cuda8.0.61_cudnn7.1.2_1 152 | - torchvision=0.2.1=py_2 153 | - pip: 154 | - atari-py==0.1.7 155 | - atomicwrites==1.3.0 156 | - attrs==19.1.0 157 | - bcolz==1.2.1 158 | - bleach==1.5.0 159 | - chardet==3.0.4 160 | - click==7.0 161 | - colorama==0.4.1 162 | - cvxpy==1.0.22 163 | - cython==0.29.6 164 | - dask==1.1.3 165 | - dill==0.2.9 166 | - dlib==19.17.0 167 | - ecos==2.0.7.post1 168 | - enum34==1.1.6 169 | - filelock==3.0.10 170 | - flatbuffers==1.10 171 | - funcsigs==1.0.2 172 | - future==0.17.1 173 | - futures==3.1.1 174 | - glfw==1.7.1 175 | - gym==0.12.1 176 | - html5lib==0.9999999 177 | - idna==2.8 178 | - imgaug==0.2.8 179 | - keras==2.2.2 180 | - lmdb==0.94 181 | - lockfile==0.12.2 182 | - loguru==0.3.0 183 | - mock==2.0.0 184 | - more-itertools==7.0.0 185 | - multiprocess==0.70.7 186 | - opencv-python==4.0.0.21 187 | - osqp==0.5.0 188 | - pathlib2==2.3.3 189 | - pathos==0.2.3 190 | - pbr==5.1.3 191 | - pluggy==0.9.0 192 | - pox==0.2.5 193 | - ppft==1.6.4.9 194 | - py==1.8.0 195 | - pygame==1.9.4 196 | - pyglet==1.3.2 197 | - pyopengl==3.1.0 198 | - pytest==4.4.1 199 | - ray==0.6.6 200 | - redis==3.2.1 201 | - requests==2.21.0 202 | - scs==2.1.0 203 | - shapely==1.6.4.post2 204 | - tensorboardx==1.6 205 | - tensorflow-tensorboard==0.4.0 206 | - terminaltables==3.1.0 207 | - torch==1.0.0 208 | - typing==3.6.6 209 | - urllib3==1.24.1 210 | -------------------------------------------------------------------------------- /figs/birdview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/figs/birdview.png -------------------------------------------------------------------------------- /figs/birdview_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/figs/birdview_loss.png -------------------------------------------------------------------------------- /figs/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/figs/fig1.png -------------------------------------------------------------------------------- /figs/image_phase1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dotchen/LearningByCheating/4145d33f74c9a8f27061a0f94840f3e458ecc60e/figs/image_phase1.png -------------------------------------------------------------------------------- /misc/ImportMaps.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | # ============================================================================== 4 | # -- Parse arguments ----------------------------------------------------------- 5 | # ============================================================================== 6 | 7 | DOC_STRING="Unpack and copy over CarlaUE4's Exported Maps" 8 | 9 | USAGE_STRING="Usage: $0 [-h|--help] [-d|--dir] " 10 | 11 | OUTPUT_DIRECTORY="" 12 | 13 | OPTS=`getopt -o h,d:: --long help,dir:: -n 'parse-options' -- "$@"` 14 | 15 | if [ $? != 0 ] ; then echo "$USAGE_STRING" ; exit 2; fi 16 | 17 | eval set -- "$OPTS" 18 | 19 | while true; do 20 | case "$1" in 21 | --dir ) 22 | OUTPUT_DIRECTORY="$2" 23 | shift ;; 24 | -h | --help ) 25 | echo "$DOC_STRING" 26 | echo "$USAGE_STRING" 27 | exit 1 28 | ;; 29 | * ) 30 | break ;; 31 | esac 32 | done 33 | 34 | #Tar.gz the stuff 35 | for filepath in `find ExportedMaps/ -type f -name "*.tar.gz"`; do 36 | tar --keep-newer-files -xvf ${filepath} 37 | done 38 | 39 | -------------------------------------------------------------------------------- /misc/dynamic_weather.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2017 Computer Vision Center (CVC) at the Universitat Autonoma de 4 | # Barcelona (UAB). 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see . 8 | 9 | """ 10 | CARLA Dynamic Weather: 11 | 12 | Connect to a CARLA Simulator instance and control the weather. Change Sun 13 | position smoothly with time and generate storms occasionally. 14 | """ 15 | 16 | import glob 17 | import os 18 | import sys 19 | 20 | try: 21 | sys.path.append(glob.glob('**/*%d.%d-%s.egg' % ( 22 | sys.version_info.major, 23 | sys.version_info.minor, 24 | 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) 25 | except IndexError: 26 | pass 27 | 28 | import carla 29 | 30 | import argparse 31 | import math 32 | 33 | 34 | def clamp(value, minimum=0.0, maximum=100.0): 35 | return max(minimum, min(value, maximum)) 36 | 37 | 38 | class Sun(object): 39 | def __init__(self, azimuth, altitude): 40 | self.azimuth = azimuth 41 | self.altitude = altitude 42 | self._t = 0.0 43 | 44 | def tick(self, delta_seconds): 45 | self._t += 0.008 * delta_seconds 46 | self._t %= 2.0 * math.pi 47 | self.azimuth += 0.25 * delta_seconds 48 | self.azimuth %= 360.0 49 | self.altitude = 35.0 * (math.sin(self._t) + 1.0) 50 | 51 | def __str__(self): 52 | return 'Sun(%.2f, %.2f)' % (self.azimuth, self.altitude) 53 | 54 | 55 | class Storm(object): 56 | def __init__(self, precipitation): 57 | self._t = precipitation if precipitation > 0.0 else -50.0 58 | self._increasing = True 59 | self.clouds = 0.0 60 | self.rain = 0.0 61 | self.puddles = 0.0 62 | self.wind = 0.0 63 | 64 | def tick(self, delta_seconds): 65 | delta = (1.3 if self._increasing else -1.3) * delta_seconds 66 | self._t = clamp(delta + self._t, -250.0, 100.0) 67 | self.clouds = clamp(self._t + 40.0, 0.0, 90.0) 68 | self.rain = clamp(self._t, 0.0, 80.0) 69 | delay = -10.0 if self._increasing else 90.0 70 | self.puddles = clamp(self._t + delay, 0.0, 75.0) 71 | self.wind = clamp(self._t - delay, 0.0, 80.0) 72 | if self._t == -250.0: 73 | self._increasing = True 74 | if self._t == 100.0: 75 | self._increasing = False 76 | 77 | def __str__(self): 78 | return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind) 79 | 80 | 81 | class Weather(object): 82 | def __init__(self, weather): 83 | self.weather = weather 84 | self._sun = 
Sun(weather.sun_azimuth_angle, weather.sun_altitude_angle) 85 | self._storm = Storm(weather.precipitation) 86 | 87 | def tick(self, delta_seconds): 88 | self._sun.tick(delta_seconds) 89 | self._storm.tick(delta_seconds) 90 | self.weather.cloudyness = self._storm.clouds 91 | self.weather.precipitation = self._storm.rain 92 | self.weather.precipitation_deposits = self._storm.puddles 93 | self.weather.wind_intensity = self._storm.wind 94 | self.weather.sun_azimuth_angle = self._sun.azimuth 95 | self.weather.sun_altitude_angle = self._sun.altitude 96 | 97 | def __str__(self): 98 | return '%s %s' % (self._sun, self._storm) 99 | 100 | 101 | def main(): 102 | argparser = argparse.ArgumentParser( 103 | description=__doc__) 104 | argparser.add_argument( 105 | '--host', 106 | metavar='H', 107 | default='127.0.0.1', 108 | help='IP of the host server (default: 127.0.0.1)') 109 | argparser.add_argument( 110 | '-p', '--port', 111 | metavar='P', 112 | default=2000, 113 | type=int, 114 | help='TCP port to listen to (default: 2000)') 115 | argparser.add_argument( 116 | '-s', '--speed', 117 | metavar='FACTOR', 118 | default=1.0, 119 | type=float, 120 | help='rate at which the weather changes (default: 1.0)') 121 | args = argparser.parse_args() 122 | 123 | speed_factor = args.speed 124 | update_freq = 0.1 / speed_factor 125 | 126 | client = carla.Client(args.host, args.port) 127 | client.set_timeout(2.0) 128 | world = client.get_world() 129 | 130 | weather = Weather(world.get_weather()) 131 | 132 | elapsed_time = 0.0 133 | 134 | while True: 135 | timestamp = world.wait_for_tick(seconds=30.0) 136 | elapsed_time += timestamp.delta_seconds 137 | if elapsed_time > update_freq: 138 | weather.tick(speed_factor * elapsed_time) 139 | world.set_weather(weather.weather) 140 | sys.stdout.write('\r' + str(weather) + 12 * ' ') 141 | sys.stdout.flush() 142 | elapsed_time = 0.0 143 | 144 | 145 | if __name__ == '__main__': 146 | 147 | main() 148 | -------------------------------------------------------------------------------- /misc/find_traffic_violations.py: -------------------------------------------------------------------------------- 1 | """ 2 | usage - 3 | 4 | python3 parse.py --run_dir . --threshold 1.0 --town 1 5 | 6 | or 7 | 8 | python3 parse.py --run_dir . 
--threshold 1.0 --debug 9 | """ 10 | import argparse 11 | 12 | from pathlib import Path 13 | 14 | import pandas as pd 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import tqdm 18 | 19 | 20 | class Vector2(object): 21 | def __init__(self, x, y): 22 | self.x = x 23 | self.y = y 24 | 25 | def __truediv__(self, c): 26 | return Vector2(self.x / c, self.y / c) 27 | 28 | def __add__(self, v): 29 | return Vector2(self.x + v.x, self.y + v.y) 30 | 31 | def __sub__(self, v): 32 | return Vector2(self.x - v.x, self.y - v.y) 33 | 34 | def dot(self, v): 35 | return self.x * v.x + self.y * v.y 36 | 37 | def cross(self, v): 38 | return self.x * v.y - self.y * v.x 39 | 40 | def norm(self): 41 | return np.sqrt(self.x * self.x + self.y * self.y) 42 | 43 | def normalize(self): 44 | return self / (self.norm() + 1e-8) 45 | 46 | 47 | def get_collision(p1, p2, lines): 48 | """ 49 | line 1: p + t r 50 | line 2: q + u s 51 | """ 52 | p = p1 53 | r = p2 - p1 54 | 55 | for a, b in lines: 56 | q = a 57 | s = b - a 58 | 59 | r_cross_s = r.cross(s) 60 | q_minus_p = q - p 61 | 62 | if abs(r_cross_s) < 1e-3: 63 | continue 64 | 65 | t = q_minus_p.cross(s) / r_cross_s 66 | u = q_minus_p.cross(r) / r_cross_s 67 | 68 | if t >= 0.0 and t <= 1.0 and u >= 0.0 and u <= 1.0: 69 | return True 70 | 71 | return False 72 | 73 | 74 | def parse(df, lights): 75 | n = len(df) 76 | t = np.array(list(range(n))) 77 | traveled = 0.0 78 | 79 | broken_t = list() 80 | broken = list() 81 | 82 | for i in range(1, n): 83 | a = Vector2(df['x'][i-1], df['y'][i-1]) 84 | b = Vector2(df['x'][i], df['y'][i]) 85 | traveled += (a - b).norm() 86 | 87 | if not df['is_light_red'][i-1]: 88 | continue 89 | 90 | if get_collision(a, b, lights): 91 | broken_t.append(i) 92 | broken.append(df['is_light_red'][i-1]) 93 | 94 | if args.debug: 95 | plt.plot(t, is_light_red) 96 | plt.plot(t, speed) 97 | plt.plot(broken_t, broken, 'r.') 98 | plt.show() 99 | 100 | return broken, traveled 101 | 102 | 103 | def get_town(town): 104 | lights = Path('light_town%s.txt' % town).read_text().strip().split('\n') 105 | lights = [tuple(map(float, x.split())) for x in lights] 106 | lights = np.array(lights) 107 | 108 | alpha = 10.0 109 | lines = list() 110 | 111 | for x, y in lights: 112 | for dx, dy in [(1, 0), (0, 1), (-1, 0), (0, -1)]: 113 | nx = x + alpha * dx 114 | ny = y + alpha * dy 115 | 116 | lines.append((Vector2(x, y), Vector2(nx, ny))) 117 | 118 | return lines 119 | 120 | 121 | def main(run_dir): 122 | result = list() 123 | 124 | for path in sorted(run_dir.glob('*/summary.csv')): 125 | summary_csv = pd.read_csv(str(path)) 126 | total = len(summary_csv) 127 | lights = get_town(1 if 'Town01' in str(path) else 2) 128 | 129 | n_infractions = list() 130 | distances = list() 131 | 132 | for _, row in summary_csv.iterrows(): 133 | weather = row['weather'] 134 | start = row['start'] 135 | target = row['target'] 136 | run_csv = 'w%s_s%s_t%s.csv' % (weather, start, target) 137 | 138 | diag = pd.read_csv(str(path.parent / 'diagnostics' / run_csv)) 139 | 140 | crosses, dist = parse(diag, lights) 141 | 142 | n_infractions.append(sum(crosses)) 143 | distances.append(dist) 144 | 145 | print(path, total) 146 | print('%s infractions total.' 
% np.sum(n_infractions)) 147 | print('%s total dist' % np.sum(distances)) 148 | 149 | result.append({ 150 | 'suite': path.parent.stem, 151 | 'infractions': sum(n_infractions), 152 | 'per_10km': sum(n_infractions) / (np.sum(distances) / 10000), 153 | 'distances': sum(distances)}) 154 | 155 | pd.DataFrame(result).to_csv('%s/lights.csv' % path.parent.parent, index=False) 156 | 157 | 158 | if __name__ == '__main__': 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument('--run_dir', required=True, type=str) 161 | parser.add_argument('--debug', action='store_true', default=False) 162 | 163 | args = parser.parse_args() 164 | 165 | main(Path(args.run_dir)) 166 | -------------------------------------------------------------------------------- /misc/light_town1.txt: -------------------------------------------------------------------------------- 1 | 94.990 70.410 2 | 102.720 52.850 3 | 85.730 45.780 4 | 94.990 144.380 5 | 102.720 126.820 6 | 85.730 119.750 7 | 94.990 209.920 8 | 102.720 192.550 9 | 85.730 185.290 10 | 106.120 324.080 11 | 85.630 316.350 12 | 75.710 333.340 13 | 349.600 324.080 14 | 332.040 316.350 15 | 324.970 333.340 16 | 168.870 53.030 17 | 151.310 45.300 18 | 144.240 62.290 19 | 331.990 44.880 20 | 321.760 62.440 21 | 341.360 12.710 22 | 348.430 -4.480 23 | 341.250 69.510 24 | 331.990 118.930 25 | 321.680 136.132 26 | 341.250 143.560 27 | 331.990 184.470 28 | 323.410 202.030 29 | 341.250 209.100 30 | 143.120 4.830 31 | 160.680 12.560 32 | 167.750 -4.430 33 | 77.480 4.830 34 | 95.040 12.560 35 | 102.110 -4.430 36 | 323.800 4.980 37 | -------------------------------------------------------------------------------- /misc/light_town2.txt: -------------------------------------------------------------------------------- 1 | 147.110 234.222 2 | 129.550 226.560 3 | 122.480 243.817 4 | 121.370 194.090 5 | 138.930 201.820 6 | 146.000 184.830 7 | 30.960 194.290 8 | 48.520 203.280 9 | 58.050 185.030 10 | 186.950 176.770 11 | 179.220 194.330 12 | 196.210 201.400 13 | 186.950 226.180 14 | 179.220 243.740 15 | 196.210 250.810 16 | 6.910 184.960 17 | -10.080 177.890 18 | 48.530 251.760 19 | 56.260 234.200 20 | 39.270 227.130 21 | 56.690 299.989 22 | 39.130 289.530 23 | 32.060 309.410 24 | -0.820 202.520 25 | -------------------------------------------------------------------------------- /misc/spawn_npc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2017 Computer Vision Center (CVC) at the Universitat Autonoma de 4 | # Barcelona (UAB). 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see . 
8 | 9 | """Spawn NPCs into the simulation""" 10 | 11 | import glob 12 | import os 13 | import sys 14 | 15 | try: 16 | sys.path.append(glob.glob('**/*%d.%d-%s.egg' % ( 17 | sys.version_info.major, 18 | sys.version_info.minor, 19 | 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) 20 | except IndexError: 21 | pass 22 | 23 | import carla 24 | 25 | import argparse 26 | import random 27 | import time 28 | 29 | 30 | def main(): 31 | argparser = argparse.ArgumentParser( 32 | description=__doc__) 33 | argparser.add_argument( 34 | '--host', 35 | metavar='H', 36 | default='127.0.0.1', 37 | help='IP of the host server (default: 127.0.0.1)') 38 | argparser.add_argument( 39 | '-p', '--port', 40 | metavar='P', 41 | default=2000, 42 | type=int, 43 | help='TCP port to listen to (default: 2000)') 44 | argparser.add_argument( 45 | '-n', '--number-of-vehicles', 46 | metavar='N', 47 | default=10, 48 | type=int, 49 | help='number of vehicles (default: 10)') 50 | argparser.add_argument( 51 | '-d', '--delay', 52 | metavar='D', 53 | default=2.0, 54 | type=float, 55 | help='delay in seconds between spawns (default: 2.0)') 56 | argparser.add_argument( 57 | '--safe', 58 | action='store_true', 59 | help='avoid spawning vehicles prone to accidents') 60 | args = argparser.parse_args() 61 | 62 | actor_list = [] 63 | client = carla.Client(args.host, args.port) 64 | client.set_timeout(2.0) 65 | 66 | try: 67 | 68 | world = client.get_world() 69 | blueprints = world.get_blueprint_library().filter('vehicle.*') 70 | 71 | if args.safe: 72 | blueprints = [x for x in blueprints if int(x.get_attribute('number_of_wheels')) == 4] 73 | blueprints = [x for x in blueprints if not x.id.endswith('isetta')] 74 | 75 | def try_spawn_random_vehicle_at(transform): 76 | blueprint = random.choice(blueprints) 77 | if blueprint.has_attribute('color'): 78 | color = random.choice(blueprint.get_attribute('color').recommended_values) 79 | blueprint.set_attribute('color', color) 80 | blueprint.set_attribute('role_name', 'autopilot') 81 | vehicle = world.try_spawn_actor(blueprint, transform) 82 | if vehicle is not None: 83 | actor_list.append(vehicle) 84 | vehicle.set_autopilot() 85 | print('spawned %r at %s' % (vehicle.type_id, transform.location)) 86 | return True 87 | return False 88 | 89 | # @todo Needs to be converted to list to be shuffled. 90 | spawn_points = list(world.get_map().get_spawn_points()) 91 | random.shuffle(spawn_points) 92 | 93 | print('found %d spawn points.' % len(spawn_points)) 94 | 95 | count = args.number_of_vehicles 96 | 97 | for spawn_point in spawn_points: 98 | if try_spawn_random_vehicle_at(spawn_point): 99 | count -= 1 100 | if count <= 0: 101 | break 102 | 103 | while count > 0: 104 | time.sleep(args.delay) 105 | if try_spawn_random_vehicle_at(random.choice(spawn_points)): 106 | count -= 1 107 | 108 | print('spawned %d vehicles, press Ctrl+C to exit.' 
% args.number_of_vehicles) 109 | 110 | while True: 111 | time.sleep(10) 112 | 113 | finally: 114 | 115 | print('\ndestroying %d actors' % len(actor_list)) 116 | client.apply_batch([carla.command.DestroyActor(x.id) for x in actor_list]) 117 | 118 | 119 | if __name__ == '__main__': 120 | 121 | try: 122 | main() 123 | except KeyboardInterrupt: 124 | pass 125 | finally: 126 | print('\ndone.') 127 | -------------------------------------------------------------------------------- /misc/synchronous_mode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2017 Computer Vision Center (CVC) at the Universitat Autonoma de 4 | # Barcelona (UAB). 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see . 8 | 9 | import glob 10 | import os 11 | import sys 12 | 13 | try: 14 | sys.path.append(glob.glob('**/*%d.%d-%s.egg' % ( 15 | sys.version_info.major, 16 | sys.version_info.minor, 17 | 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) 18 | except IndexError: 19 | pass 20 | 21 | import carla 22 | 23 | import logging 24 | import random 25 | 26 | try: 27 | import pygame 28 | except ImportError: 29 | raise RuntimeError('cannot import pygame, make sure pygame package is installed') 30 | 31 | try: 32 | import numpy as np 33 | except ImportError: 34 | raise RuntimeError('cannot import numpy, make sure numpy package is installed') 35 | 36 | try: 37 | import queue 38 | except ImportError: 39 | import Queue as queue 40 | 41 | 42 | def draw_image(surface, image): 43 | array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8")) 44 | array = np.reshape(array, (image.height, image.width, 4)) 45 | array = array[:, :, :3] 46 | array = array[:, :, ::-1] 47 | image_surface = pygame.surfarray.make_surface(array.swapaxes(0, 1)) 48 | surface.blit(image_surface, (0, 0)) 49 | 50 | 51 | def get_font(): 52 | fonts = [x for x in pygame.font.get_fonts()] 53 | default_font = 'ubuntumono' 54 | font = default_font if default_font in fonts else fonts[0] 55 | font = pygame.font.match_font(font) 56 | return pygame.font.Font(font, 14) 57 | 58 | 59 | def should_quit(): 60 | for event in pygame.event.get(): 61 | if event.type == pygame.QUIT: 62 | return True 63 | elif event.type == pygame.KEYUP: 64 | if event.key == pygame.K_ESCAPE: 65 | return True 66 | return False 67 | 68 | 69 | def main(): 70 | actor_list = [] 71 | pygame.init() 72 | 73 | client = carla.Client('localhost', 2000) 74 | client.set_timeout(2.0) 75 | 76 | world = client.get_world() 77 | 78 | print('enabling synchronous mode.') 79 | settings = world.get_settings() 80 | settings.synchronous_mode = True 81 | world.apply_settings(settings) 82 | 83 | try: 84 | m = world.get_map() 85 | start_pose = random.choice(m.get_spawn_points()) 86 | waypoint = m.get_waypoint(start_pose.location) 87 | 88 | blueprint_library = world.get_blueprint_library() 89 | 90 | vehicle = world.spawn_actor( 91 | random.choice(blueprint_library.filter('vehicle.*')), 92 | start_pose) 93 | actor_list.append(vehicle) 94 | vehicle.set_simulate_physics(False) 95 | 96 | camera = world.spawn_actor( 97 | blueprint_library.find('sensor.camera.rgb'), 98 | carla.Transform(carla.Location(x=-5.5, z=2.8), carla.Rotation(pitch=-15)), 99 | attach_to=vehicle) 100 | actor_list.append(camera) 101 | 102 | # Make sync queue for sensor data. 
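# The camera callback only enqueues frames; the loop below pops images until
# the frame number matches the current world tick, keeping sensor data in
# lockstep with the synchronous simulation.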
103 | image_queue = queue.Queue() 104 | camera.listen(image_queue.put) 105 | 106 | frame = None 107 | 108 | # display = pygame.display.set_mode( 109 | # (800, 600), 110 | # pygame.HWSURFACE | pygame.DOUBLEBUF) 111 | font = get_font() 112 | 113 | clock = pygame.time.Clock() 114 | 115 | while True: 116 | # if should_quit(): 117 | # return 118 | 119 | clock.tick() 120 | world.tick() 121 | ts = world.wait_for_tick() 122 | 123 | if frame is not None: 124 | if ts.frame_count != frame + 1: 125 | print('frame skip!') 126 | 127 | frame = ts.frame_count 128 | 129 | while True: 130 | image = image_queue.get() 131 | if image.frame_number == ts.frame_count: 132 | break 133 | print( 134 | 'wrong image timestamp: frame=%d, image.frame=%d' % ( 135 | ts.frame_count, 136 | image.frame_number)) 137 | 138 | waypoint = random.choice(waypoint.next(2)) 139 | vehicle.set_transform(waypoint.transform) 140 | 141 | # draw_image(display, image) 142 | 143 | text_surface = font.render('% 5d FPS' % clock.get_fps(), True, (255, 255, 255)) 144 | # display.blit(text_surface, (8, 10)) 145 | 146 | # pygame.display.flip() 147 | 148 | finally: 149 | print('\ndisabling synchronous mode.') 150 | settings = world.get_settings() 151 | settings.synchronous_mode = False 152 | world.apply_settings(settings) 153 | 154 | print('destroying actors.') 155 | for actor in actor_list: 156 | actor.destroy() 157 | 158 | pygame.quit() 159 | print('done.') 160 | 161 | 162 | if __name__ == '__main__': 163 | 164 | main() 165 | -------------------------------------------------------------------------------- /misc/tutorial.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2017 Computer Vision Center (CVC) at the Universitat Autonoma de 4 | # Barcelona (UAB). 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see . 8 | 9 | import glob 10 | import os 11 | import sys 12 | 13 | try: 14 | sys.path.append(glob.glob('**/*%d.%d-%s.egg' % ( 15 | sys.version_info.major, 16 | sys.version_info.minor, 17 | 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) 18 | except IndexError: 19 | pass 20 | 21 | import carla 22 | 23 | import random 24 | import time 25 | 26 | 27 | def main(): 28 | actor_list = [] 29 | 30 | # In this tutorial script, we are going to add a vehicle to the simulation 31 | # and let it drive in autopilot. We will also create a camera attached to 32 | # that vehicle, and save all the images generated by the camera to disk. 33 | 34 | try: 35 | # First of all, we need to create the client that will send the requests 36 | # to the simulator. Here we'll assume the simulator is accepting 37 | # requests on localhost at port 2000. 38 | client = carla.Client('localhost', 2000) 39 | client.set_timeout(2.0) 40 | 41 | # Once we have a client we can retrieve the world that is currently 42 | # running. 43 | world = client.get_world() 44 | 45 | # The world contains the list of blueprints that we can use for adding new 46 | # actors into the simulation. 47 | blueprint_library = world.get_blueprint_library() 48 | 49 | # Now let's filter all the blueprints of type 'vehicle' and choose one 50 | # at random. 51 | bp = random.choice(blueprint_library.filter('vehicle')) 52 | 53 | # A blueprint contains the list of attributes that define a vehicle 54 | # instance, we can read them and modify some of them. For instance, 55 | # let's randomize its color.
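        # (Note: not every blueprint exposes every attribute; spawn_npc.py
        # guards this same lookup with blueprint.has_attribute('color').)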
56 | color = random.choice(bp.get_attribute('color').recommended_values) 57 | bp.set_attribute('color', color) 58 | 59 | # Now we need to give an initial transform to the vehicle. We choose a 60 | # random transform from the list of recommended spawn points of the map. 61 | transform = random.choice(world.get_map().get_spawn_points()) 62 | 63 | # So let's tell the world to spawn the vehicle. 64 | vehicle = world.spawn_actor(bp, transform) 65 | 66 | # It is important to note that the actors we create won't be destroyed 67 | # unless we call their "destroy" function. If we fail to call "destroy" 68 | # they will stay in the simulation even after we quit the Python script. 69 | # For that reason, we are storing all the actors we create so we can 70 | # destroy them afterwards. 71 | actor_list.append(vehicle) 72 | print('created %s' % vehicle.type_id) 73 | 74 | # Let's put the vehicle to drive around. 75 | vehicle.set_autopilot(True) 76 | 77 | # Let's add now a "depth" camera attached to the vehicle. Note that the 78 | # transform we give here is now relative to the vehicle. 79 | camera_bp = blueprint_library.find('sensor.camera.depth') 80 | camera_transform = carla.Transform(carla.Location(x=1.5, z=2.4)) 81 | camera = world.spawn_actor(camera_bp, camera_transform, attach_to=vehicle) 82 | actor_list.append(camera) 83 | print('created %s' % camera.type_id) 84 | 85 | # Now we register the function that will be called each time the sensor 86 | # receives an image. In this example we are saving the image to disk 87 | # converting the pixels to gray-scale. 88 | cc = carla.ColorConverter.LogarithmicDepth 89 | camera.listen(lambda image: image.save_to_disk('_out/%06d.png' % image.frame_number, cc)) 90 | 91 | # Oh wait, I don't like the location we gave to the vehicle, I'm going 92 | # to move it a bit forward. 93 | location = vehicle.get_location() 94 | location.x += 40 95 | vehicle.set_location(location) 96 | print('moved vehicle to %s' % location) 97 | 98 | # But the city now is probably quite empty, let's add a few more 99 | # vehicles. 100 | transform.location += carla.Location(x=40, y=-3.2) 101 | transform.rotation.yaw = -180.0 102 | for x in range(0, 10): 103 | transform.location.x += 8.0 104 | 105 | bp = random.choice(blueprint_library.filter('vehicle')) 106 | 107 | # This time we are using try_spawn_actor. If the spot is already 108 | # occupied by another object, the function will return None. 109 | npc = world.try_spawn_actor(bp, transform) 110 | if npc is not None: 111 | actor_list.append(npc) 112 | npc.set_autopilot() 113 | print('created %s' % npc.type_id) 114 | 115 | time.sleep(5) 116 | 117 | finally: 118 | 119 | print('destroying actors') 120 | for actor in actor_list: 121 | actor.destroy() 122 | print('done.') 123 | 124 | 125 | if __name__ == '__main__': 126 | 127 | main() 128 | -------------------------------------------------------------------------------- /misc/vehicle_gallery.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2017 Computer Vision Center (CVC) at the Universitat Autonoma de 4 | # Barcelona (UAB). 5 | # 6 | # This work is licensed under the terms of the MIT license. 7 | # For a copy, see . 
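"""Cycle the spectator camera around each vehicle blueprint, one at a time."""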
8 | 9 | import glob 10 | import os 11 | import sys 12 | 13 | try: 14 | sys.path.append(glob.glob('**/*%d.%d-%s.egg' % ( 15 | sys.version_info.major, 16 | sys.version_info.minor, 17 | 'win-amd64' if os.name == 'nt' else 'linux-x86_64'))[0]) 18 | except IndexError: 19 | pass 20 | 21 | import carla 22 | 23 | import math 24 | import random 25 | import time 26 | 27 | 28 | def get_transform(vehicle_location, angle, d=6.4): 29 | a = math.radians(angle) 30 | location = carla.Location(d * math.cos(a), d * math.sin(a), 2.0) + vehicle_location 31 | return carla.Transform(location, carla.Rotation(yaw=180 + angle, pitch=-15)) 32 | 33 | 34 | def main(): 35 | client = carla.Client('localhost', 2000) 36 | client.set_timeout(2.0) 37 | world = client.get_world() 38 | spectator = world.get_spectator() 39 | vehicle_blueprints = world.get_blueprint_library().filter('vehicle') 40 | 41 | location = random.choice(world.get_map().get_spawn_points()).location 42 | 43 | for blueprint in vehicle_blueprints: 44 | transform = carla.Transform(location, carla.Rotation(yaw=-45.0)) 45 | vehicle = world.spawn_actor(blueprint, transform) 46 | 47 | try: 48 | 49 | print(vehicle.type_id) 50 | 51 | angle = 0 52 | while angle < 356: 53 | timestamp = world.wait_for_tick() 54 | angle += timestamp.delta_seconds * 60.0 55 | spectator.set_transform(get_transform(vehicle.get_location(), angle - 90)) 56 | 57 | finally: 58 | 59 | vehicle.destroy() 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | main() 65 | -------------------------------------------------------------------------------- /quick_start.sh: -------------------------------------------------------------------------------- 1 | # Download CARLA 0.9.6 2 | # wget http://carla-assets-internal.s3.amazonaws.com/Releases/Linux/CARLA_0.9.6.tar.gz 3 | mkdir carla_lbc 4 | tar -xvzf CARLA_0.9.6.tar.gz -C carla_lbc 5 | cd carla_lbc 6 | 7 | # Download LBC 8 | git init 9 | git remote add origin https://github.com/dianchen96/LearningByCheating.git 10 | # rename the LICENSE file to avoid conflicts during the pull 11 | mv LICENSE CARLA_LICENSE 12 | git pull origin release-0.9.6 13 | wget http://www.cs.utexas.edu/~dchen/lbc_release/navmesh/Town01.bin 14 | wget http://www.cs.utexas.edu/~dchen/lbc_release/navmesh/Town02.bin 15 | mv Town*.bin CarlaUE4/Content/Carla/Maps/Nav/ 16 | 17 | # Create conda environment 18 | conda env create -f environment.yml 19 | conda activate carla 20 | 21 | # Install carla client 22 | cd PythonAPI/carla/dist 23 | rm carla-0.9.6-py3.5-linux-x86_64.egg 24 | wget http://www.cs.utexas.edu/~dchen/lbc_release/egg/carla-0.9.6-py3.5-linux-x86_64.egg 25 | easy_install carla-0.9.6-py3.5-linux-x86_64.egg 26 | 27 | # Download model checkpoints 28 | cd ../../.. 29 | mkdir -p ckpts/image 30 | cd ckpts/image 31 | wget http://www.cs.utexas.edu/~dchen/lbc_release/ckpts/image/model-10.th 32 | wget http://www.cs.utexas.edu/~dchen/lbc_release/ckpts/image/config.json 33 | cd ../.. 34 | mkdir -p ckpts/privileged 35 | cd ckpts/privileged 36 | wget http://www.cs.utexas.edu/~dchen/lbc_release/ckpts/privileged/model-128.th 37 | wget http://www.cs.utexas.edu/~dchen/lbc_release/ckpts/privileged/config.json 38 | cd ../..
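# Note: 'conda activate' above only works in shells where conda has been
# initialized; when running this file non-interactively, source the hook first:
#   source "$(conda info --base)/etc/profile.d/conda.sh" && conda activate carla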
39 | -------------------------------------------------------------------------------- /training/phase2_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import augmenter 5 | from torchvision import transforms 6 | import torchvision.transforms.functional as TF 7 | 8 | import sys 9 | import glob 10 | try: 11 | sys.path.append(glob.glob('../PythonAPI')[0]) 12 | sys.path.append(glob.glob('../bird_view')[0]) 13 | except IndexError as e: 14 | pass 15 | 16 | import utils.carla_utils as cu 17 | from models.image import ImagePolicyModelSS 18 | from models.birdview import BirdViewPolicyModelSS 19 | 20 | CROP_SIZE = 192 21 | PIXELS_PER_METER = 5 22 | 23 | 24 | def repeat(a, repeats, dim=0): 25 | """ 26 | Substitute for numpy's repeat function. Taken from https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/2 27 | torch.repeat([1,2,3], 2) --> [1, 2, 3, 1, 2, 3] 28 | np.repeat([1,2,3], repeats=2, axis=0) --> [1, 1, 2, 2, 3, 3] 29 | 30 | :param a: tensor 31 | :param repeats: number of repeats 32 | :param dim: dimension where to repeat 33 | :return: tensor with repetitions 34 | """ 35 | 36 | init_dim = a.size(dim) 37 | repeat_idx = [1] * a.dim() 38 | repeat_idx[dim] = repeats 39 | a = a.repeat(*(repeat_idx)) 40 | if a.is_cuda: # use cuda-device if input was on cuda device already 41 | order_index = torch.cuda.LongTensor( 42 | torch.cat([init_dim * torch.arange(repeats, device=a.device) + i for i in range(init_dim)])) 43 | else: 44 | order_index = torch.LongTensor( 45 | torch.cat([init_dim * torch.arange(repeats) + i for i in range(init_dim)])) 46 | 47 | return torch.index_select(a, dim, order_index) 48 | 49 | 50 | def get_weight(learner_points, teacher_points): 51 | decay = torch.FloatTensor([0.7**i for i in range(5)]).to(learner_points.device) 52 | xy_bias = torch.FloatTensor([0.7,0.3]).to(learner_points.device) 53 | loss_weight = torch.mean((torch.abs(learner_points - teacher_points)*xy_bias).sum(dim=-1)*decay, dim=-1) 54 | x_weight = torch.max(  # NOTE: computed but never used 55 | torch.mean(teacher_points[...,0],dim=-1), 56 | torch.mean(teacher_points[...,0]*-1.4,dim=-1), 57 | ) 58 | 59 | return loss_weight 60 | 61 | def weighted_random_choice(weights): 62 | t = np.cumsum(weights) 63 | s = np.sum(weights) 64 | return np.searchsorted(t, random.uniform(0,s)) 65 | 66 | def get_optimizer(parameters, lr=1e-4): 67 | optimizer = torch.optim.Adam(parameters, lr=lr) 68 | return optimizer 69 | 70 | def load_image_model(backbone, ckpt, device='cuda'): 71 | net = ImagePolicyModelSS( 72 | backbone, 73 | all_branch=True 74 | ).to(device) 75 | 76 | net.load_state_dict(torch.load(ckpt)) 77 | return net 78 | 79 | def _log_visuals(rgb_image, birdview, speed, command, loss, pred_locations, _pred_locations, _teac_locations, size=16): 80 | import cv2 81 | import numpy as np 82 | import utils.carla_utils as cu 83 | 84 | WHITE = [255, 255, 255] 85 | BLUE = [0, 0, 255] 86 | RED = [255, 0, 0] 87 | _numpy = lambda x: x.detach().cpu().numpy().copy() 88 | 89 | images = list() 90 | 91 | for i in range(min(birdview.shape[0], size)): 92 | loss_i = loss[i].sum() 93 | canvas = np.uint8(_numpy(birdview[i]).transpose(1, 2, 0) * 255).copy() 94 | canvas = cu.visualize_birdview(canvas) 95 | rgb = np.uint8(_numpy(rgb_image[i]).transpose(1, 2, 0) * 255).copy() 96 | rows = [x * (canvas.shape[0] // 10) for x in range(10+1)] 97 | cols = [x * (canvas.shape[1] // 10) for x in range(10+1)] 98 | 99 | def _write(text, i, j): 100 | cv2.putText( 101 | canvas, text, (cols[j],
rows[i]), 102 | cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255,255,255), 1) 103 | 104 | def _dot(_canvas, i, j, color, radius=2): 105 | x, y = int(j), int(i) 106 | _canvas[x-radius:x+radius+1, y-radius:y+radius+1] = color 107 | 108 | def _stick_together(a, b): 109 | h = min(a.shape[0], b.shape[0]) 110 | 111 | r1 = h / a.shape[0] 112 | r2 = h / b.shape[0] 113 | 114 | a = cv2.resize(a, (int(r1 * a.shape[1]), int(r1 * a.shape[0]))) 115 | b = cv2.resize(b, (int(r2 * b.shape[1]), int(r2 * b.shape[0]))) 116 | 117 | return np.concatenate([a, b], 1) 118 | 119 | _command = { 120 | 1: 'LEFT', 2: 'RIGHT', 121 | 3: 'STRAIGHT', 4: 'FOLLOW'}.get(torch.argmax(command[i]).item()+1, '???') 122 | 123 | _dot(canvas, 0, 0, WHITE) 124 | 125 | for x, y in (_teac_locations[i] + 1) * (0.5 * CROP_SIZE): _dot(canvas, x, y, BLUE) 126 | for x, y in _pred_locations[i]: _dot(rgb, x, y, RED) 127 | for x, y in pred_locations[i]: _dot(canvas, x, y, RED) 128 | 129 | _write('Command: %s' % _command, 1, 0) 130 | _write('Loss: %.2f' % loss[i].item(), 2, 0) 131 | 132 | 133 | images.append((loss[i].item(), _stick_together(rgb, canvas))) 134 | 135 | return [x[1] for x in sorted(images, reverse=True, key=lambda x: x[0])] 136 | 137 | def load_birdview_model(backbone, ckpt, device='cuda'): 138 | teacher_net = BirdViewPolicyModelSS(backbone, all_branch=True).to(device) 139 | teacher_net.load_state_dict(torch.load(ckpt)) 140 | 141 | return teacher_net 142 | 143 | class CoordConverter(): 144 | def __init__(self, w=384, h=160, fov=90, world_y=1.4, fixed_offset=4.0, device='cuda'): 145 | self._img_size = torch.FloatTensor([w,h]).to(device) 146 | 147 | self._fov = fov 148 | self._world_y = world_y 149 | self._fixed_offset = fixed_offset 150 | print ("Fixed offset", fixed_offset) 151 | 152 | def __call__(self, camera_locations): 153 | if isinstance(camera_locations, torch.Tensor): 154 | camera_locations = (camera_locations + 1) * self._img_size/2 155 | else: 156 | camera_locations = (camera_locations + 1) * self._img_size.cpu().numpy()/2 157 | 158 | w, h = self._img_size 159 | w = int(w) 160 | h = int(h) 161 | 162 | cx, cy = w/2, h/2 163 | 164 | f = w /(2 * np.tan(self._fov * np.pi / 360)) 165 | 166 | xt = (camera_locations[...,0] - cx) / f 167 | yt = (camera_locations[...,1] - cy) / f 168 | 169 | world_z = self._world_y / yt 170 | world_x = world_z * xt 171 | 172 | if isinstance(camera_locations, torch.Tensor): 173 | map_output = torch.stack([world_x, world_z],dim=-1) 174 | else: 175 | map_output = np.stack([world_x,world_z],axis=-1) 176 | 177 | map_output *= PIXELS_PER_METER 178 | map_output[...,1] = CROP_SIZE - map_output[...,1] 179 | map_output[...,0] += CROP_SIZE/2 180 | map_output[...,1] += self._fixed_offset*PIXELS_PER_METER 181 | 182 | return map_output 183 | 184 | class LocationLoss(torch.nn.Module): 185 | def forward(self, pred_locations, teac_locations): 186 | pred_locations = pred_locations/(0.5*CROP_SIZE) - 1 187 | 188 | return torch.mean(torch.abs(pred_locations - teac_locations), dim=(1,2,3)) 189 | 190 | class ReplayBuffer(torch.utils.data.Dataset): 191 | def __init__(self, buffer_limit=100000, augment=None, sampling=True, aug_fix_iter=1000000, batch_aug=4): 192 | self.buffer_limit = buffer_limit 193 | self._data = [] 194 | self._weights = [] 195 | self.rgb_transform = transforms.ToTensor() 196 | 197 | self.birdview_transform = transforms.Compose([ 198 | transforms.ToTensor(), 199 | ]) 200 | 201 | if augment and augment != 'None': 202 | self.augmenter = getattr(augmenter, augment) 203 | else: 204 | self.augmenter = None 205 | 206 | 
self.normalized = False 207 | self._sampling = sampling 208 | self.aug_fix_iter = aug_fix_iter 209 | self.batch_aug = batch_aug 210 | 211 | def __len__(self): 212 | return len(self._data) 213 | 214 | def __getitem__(self, _idx): 215 | if self._sampling and self.normalized: 216 | while True: 217 | idx = weighted_random_choice(self._weights) 218 | if idx < len(self._data): 219 | break 220 | print("sampled stale index %d, resampling" % idx) 221 | else: 222 | idx = _idx 223 | 224 | rgb_img, cmd, speed, target, birdview_img = self._data[idx] 225 | if self.augmenter: 226 | rgb_imgs = [self.augmenter(self.aug_fix_iter).augment_image(rgb_img) for i in range(self.batch_aug)] 227 | else: 228 | rgb_imgs = [rgb_img for i in range(self.batch_aug)] 229 | 230 | rgb_imgs = [self.rgb_transform(img) for img in rgb_imgs] 231 | if self.batch_aug == 1: 232 | rgb_imgs = rgb_imgs[0] 233 | else: 234 | rgb_imgs = torch.stack(rgb_imgs) 235 | 236 | birdview_img = self.birdview_transform(birdview_img) 237 | 238 | return idx, rgb_imgs, cmd, speed, target, birdview_img 239 | 240 | def update_weights(self, idxes, losses): 241 | idxes = idxes.numpy() 242 | losses = losses.detach().cpu().numpy() 243 | for idx, loss in zip(idxes, losses): 244 | if idx >= len(self._data): 245 | continue 246 | 247 | self._new_weights[idx] = loss 248 | 249 | def init_new_weights(self): 250 | self._new_weights = self._weights.copy() 251 | 252 | def normalize_weights(self): 253 | self._weights = self._new_weights 254 | self.normalized = True 255 | 256 | def add_data(self, rgb_img, cmd, speed, target, birdview_img, weight): 257 | self.normalized = False 258 | self._data.append((rgb_img, cmd, speed, target, birdview_img)) 259 | self._weights.append(weight) 260 | 261 | if len(self._data) > self.buffer_limit: 262 | # Pop the one with lowest loss 263 | idx = np.argsort(self._weights)[0] 264 | self._data.pop(idx) 265 | self._weights.pop(idx) 266 | 267 | 268 | def remove_data(self, idx): 269 | self._weights.pop(idx) 270 | self._data.pop(idx) 271 | 272 | def get_highest_k(self, k): 273 | top_idxes = np.argsort(self._weights)[-k:] 274 | rgb_images = [] 275 | bird_views = [] 276 | targets = [] 277 | cmds = [] 278 | speeds = [] 279 | 280 | for idx in top_idxes: 281 | if idx < len(self._data): 282 | rgb_img, cmd, speed, target, birdview_img = self._data[idx] 283 | rgb_images.append(TF.to_tensor(np.ascontiguousarray(rgb_img))) 284 | bird_views.append(TF.to_tensor(birdview_img)) 285 | cmds.append(cmd) 286 | speeds.append(speed) 287 | targets.append(target) 288 | 289 | return torch.stack(rgb_images), torch.stack(bird_views), torch.FloatTensor(cmds), torch.FloatTensor(speeds), torch.FloatTensor(targets) 290 | -------------------------------------------------------------------------------- /training/train_birdview.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import tqdm 9 | 10 | import glob 11 | import os 12 | import sys 13 | 14 | try: 15 | sys.path.append(glob.glob('../PythonAPI')[0]) 16 | sys.path.append(glob.glob('../bird_view')[0]) 17 | except IndexError as e: 18 | pass 19 | 20 | import utils.bz_utils as bzu 21 | 22 | from models.birdview import BirdViewPolicyModelSS 23 | from utils.train_utils import one_hot 24 | from utils.datasets.birdview_lmdb import get_birdview as load_data 25 | 26 | 27 | # Maybe experiment with this eventually...
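# For reference, the one_hot helper imported above turns the integer high-level
# command (1=LEFT, 2=RIGHT, 3=STRAIGHT, 4=FOLLOW) into a one-hot float tensor.
# A minimal sketch of such a helper (assumed signature; see
# bird_view/utils/train_utils.py for the repository's actual implementation):
#
#     def one_hot(x, num_digits=4, start=1):
#         N = x.size(0)
#         x = x.long()[:, None] - start          # shift commands to 0-based
#         x = torch.clamp(x, 0, num_digits - 1)  # guard against bad labels
#         y = torch.zeros(N, num_digits)
#         y.scatter_(1, x, 1)                    # set the command column to 1
#         return y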
28 | BACKBONE = 'resnet18' 29 | GAP = 5 30 | N_STEP = 5 31 | SAVE_EPOCHS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 384, 512, 768, 1000] 32 | 33 | class LocationLoss(torch.nn.Module): 34 | def __init__(self, w=192, h=192, choice='l2'): 35 | super(LocationLoss, self).__init__() 36 | 37 | # IMPORTANT(bradyz): loss per sample. 38 | if choice == 'l1': 39 | self.loss = lambda a, b: torch.mean(torch.abs(a - b), dim=(1,2)) 40 | elif choice == 'l2': 41 | self.loss = torch.nn.MSELoss()  # NOTE: reduces to a scalar, unlike the per-sample 'l1' branch 42 | else: 43 | raise NotImplementedError("Unknown loss: %s" % choice) 44 | 45 | self.img_size = torch.FloatTensor([w,h]).cuda() 46 | 47 | def forward(self, pred_location, gt_location): 48 | ''' 49 | Note that ground-truth location is [0,img_size] 50 | and pred_location is [-1,1] 51 | ''' 52 | gt_location = gt_location / (0.5 * self.img_size) - 1.0 53 | 54 | return self.loss(pred_location, gt_location) 55 | 56 | 57 | def _log_visuals(birdview, speed, command, loss, locations, _locations, size=16): 58 | import cv2 59 | import numpy as np 60 | import utils.carla_utils as cu 61 | 62 | WHITE = [255, 255, 255] 63 | BLUE = [0, 0, 255] 64 | RED = [255, 0, 0] 65 | _numpy = lambda x: x.detach().cpu().numpy().copy() 66 | 67 | images = list() 68 | 69 | for i in range(min(birdview.shape[0], size)): 70 | loss_i = loss[i].sum() 71 | canvas = np.uint8(_numpy(birdview[i]).transpose(1, 2, 0) * 255).copy() 72 | canvas = cu.visualize_birdview(canvas) 73 | rows = [x * (canvas.shape[0] // 10) for x in range(10+1)] 74 | cols = [x * (canvas.shape[1] // 10) for x in range(10+1)] 75 | 76 | def _write(text, i, j): 77 | cv2.putText( 78 | canvas, text, (cols[j], rows[i]), 79 | cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255,255,255), 1) 80 | 81 | def _dot(i, j, color, radius=2): 82 | x, y = int(j), int(i) 83 | canvas[x-radius:x+radius+1, y-radius:y+radius+1] = color 84 | 85 | _command = { 86 | 1: 'LEFT', 2: 'RIGHT', 87 | 3: 'STRAIGHT', 4: 'FOLLOW'}.get(torch.argmax(command[i]).item()+1, '???') 88 | 89 | _dot(0, 0, WHITE) 90 | 91 | for x, y in locations[i]: _dot(x, y, BLUE) 92 | for x, y in (_locations[i] + 1) * (0.5 * 192): _dot(x, y, RED) 93 | 94 | _write('Command: %s' % _command, 1, 0) 95 | _write('Loss: %.2f' % loss[i].item(), 2, 0) 96 | 97 | images.append((loss[i].item(), canvas)) 98 | 99 | return [x[1] for x in sorted(images, reverse=True, key=lambda x: x[0])] 100 | 101 | 102 | def train_or_eval(criterion, net, data, optim, is_train, config, is_first_epoch): 103 | if is_train: 104 | desc = 'Train' 105 | net.train() 106 | else: 107 | desc = 'Val' 108 | net.eval() 109 | 110 | total = 10 if is_first_epoch else len(data) 111 | iterator_tqdm = tqdm.tqdm(data, desc=desc, total=total) 112 | iterator = enumerate(iterator_tqdm) 113 | 114 | tick = time.time() 115 | 116 | for i, (birdview, location, command, speed) in iterator: 117 | birdview = birdview.to(config['device']) 118 | command = one_hot(command).to(config['device']) 119 | speed = speed.to(config['device']) 120 | location = location.float().to(config['device']) 121 | 122 | pred_location = net(birdview, speed, command) 123 | loss = criterion(pred_location, location) 124 | loss_mean = loss.mean() 125 | 126 | if is_train and not is_first_epoch: 127 | optim.zero_grad() 128 | loss_mean.backward() 129 | optim.step() 130 | 131 | should_log = False 132 | should_log |= i % config['log_iterations'] == 0 133 | should_log |= not is_train 134 | should_log |= is_first_epoch 135 | 136 | if should_log: 137 | metrics = dict() 138 | metrics['loss'] = loss_mean.item() 139 | 140 | images = _log_visuals( 141 | birdview, speed,
command, loss, 142 | location, pred_location) 143 | 144 | bzu.log.scalar(is_train=is_train, loss_mean=loss_mean.item()) 145 | bzu.log.image(is_train=is_train, birdview=images) 146 | 147 | bzu.log.scalar(is_train=is_train, fps=1.0/(time.time() - tick)) 148 | 149 | tick = time.time() 150 | 151 | if is_first_epoch and i == 10: 152 | iterator_tqdm.close() 153 | break 154 | 155 | 156 | def train(config): 157 | bzu.log.init(config['log_dir']) 158 | bzu.log.save_config(config) 159 | 160 | data_train, data_val = load_data(**config['data_args']) 161 | criterion = LocationLoss(w=192, h=192, choice='l1') 162 | net = BirdViewPolicyModelSS(config['model_args']['backbone']).to(config['device']) 163 | 164 | if config['resume']: 165 | log_dir = Path(config['log_dir']) 166 | # Sort numerically so that e.g. model-128.th beats model-64.th on resume. 167 | checkpoints = sorted(log_dir.glob('model-*.th'), key=lambda p: int(p.stem.split('-')[1])) 168 | checkpoint = str(checkpoints[-1]) 169 | print("load %s" % checkpoint) 170 | net.load_state_dict(torch.load(checkpoint)) 171 | 172 | optim = torch.optim.Adam(net.parameters(), lr=config['optimizer_args']['lr']) 173 | 174 | for epoch in tqdm.tqdm(range(config['max_epoch']+1), desc='Epoch'): 175 | train_or_eval(criterion, net, data_train, optim, True, config, epoch == 0) 176 | train_or_eval(criterion, net, data_val, None, False, config, epoch == 0) 177 | 178 | if epoch in SAVE_EPOCHS: 179 | torch.save( 180 | net.state_dict(), 181 | str(Path(config['log_dir']) / ('model-%d.th' % epoch))) 182 | 183 | bzu.log.end_epoch() 184 | 185 | 186 | if __name__ == '__main__': 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument('--log_dir', required=True) 189 | parser.add_argument('--log_iterations', type=int, default=1000) 190 | parser.add_argument('--max_epoch', type=int, default=1000) 191 | 192 | # Dataset. 193 | parser.add_argument('--dataset_dir', default='/raid0/dian/carla_0.9.6_data') 194 | parser.add_argument('--batch_size', type=int, default=256) 195 | parser.add_argument('--x_jitter', type=int, default=5) 196 | parser.add_argument('--y_jitter', type=int, default=0) 197 | parser.add_argument('--angle_jitter', type=int, default=5) 198 | parser.add_argument('--gap', type=int, default=5) 199 | parser.add_argument('--max_frames', type=int, default=None) 200 | parser.add_argument('--cmd-biased', action='store_true', default=False) 201 | parser.add_argument('--resume', action='store_true') 202 | 203 | # Optimizer.
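    # (Adam is used with default betas; the learning rate below is the only
    # optimizer hyperparameter exposed on the command line.)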
203 | parser.add_argument('--lr', type=float, default=1e-4) 204 | 205 | parsed = parser.parse_args() 206 | 207 | config = { 208 | 'log_dir': parsed.log_dir, 209 | 'resume': parsed.resume, 210 | 'log_iterations': parsed.log_iterations, 211 | 'max_epoch': parsed.max_epoch, 212 | 'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 213 | 'optimizer_args': {'lr': parsed.lr}, 214 | 'data_args': { 215 | 'dataset_dir': parsed.dataset_dir, 216 | 'batch_size': parsed.batch_size, 217 | 'n_step': N_STEP, 218 | 'gap': GAP, 219 | 'crop_x_jitter': parsed.x_jitter, 220 | 'crop_y_jitter': parsed.y_jitter, 221 | 'angle_jitter': parsed.angle_jitter, 222 | 'max_frames': parsed.max_frames, 223 | 'cmd_biased': parsed.cmd_biased, 224 | }, 225 | 'model_args': { 226 | 'model': 'birdview_dian', 227 | 'input_channel': 7, 228 | 'backbone': BACKBONE, 229 | }, 230 | } 231 | 232 | train(config) 233 | -------------------------------------------------------------------------------- /view_benchmark_results.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import pandas as pd 4 | from terminaltables import DoubleTable 5 | from pathlib import Path 6 | 7 | 8 | def main(path_name): 9 | 10 | performance = dict() 11 | 12 | path = Path(path_name) 13 | for summary_path in path.glob('*/summary.csv'): 14 | name = summary_path.parent.name 15 | match = re.search('^(?P.*Town.*-v[0-9]+.*)_seed(?P[0-9]+)', name) 16 | suite_name = match.group('suite_name') 17 | seed = match.group('seed') 18 | 19 | summary = pd.read_csv(summary_path) 20 | 21 | if suite_name not in performance: 22 | performance[suite_name] = dict() 23 | 24 | performance[suite_name][seed] = (summary['success'].sum(), len(summary)) 25 | 26 | table_data = [] 27 | for suite_name, seeds in performance.items(): 28 | 29 | successes, totals = np.array(list(zip(*seeds.values()))) 30 | rates = successes / totals * 100 31 | 32 | if len(seeds) > 1: 33 | table_data.append([suite_name, "%.1f ± %.1f"%(np.mean(rates), np.std(rates, ddof=1)), "%d/%d"%(sum(successes),sum(totals)), ','.join(sorted(seeds.keys()))]) 34 | else: 35 | table_data.append([suite_name, "%d"%np.mean(rates), "%d/%d"%(sum(successes),sum(totals)), ','.join(sorted(seeds.keys()))]) 36 | 37 | table_data = sorted(table_data, key=lambda row: row[0]) 38 | table_data = [('Suite Name', 'Success Rate', 'Total', 'Seeds')] + table_data 39 | table = DoubleTable(table_data, "Performance of %s"%path.name) 40 | print(table.table) 41 | 42 | 43 | 44 | if __name__ == '__main__': 45 | import argparse 46 | 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument('path', help='path of benchmark folder') 50 | 51 | args = parser.parse_args() 52 | main(args.path) --------------------------------------------------------------------------------
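A closing note on the prioritized sampling in training/phase2_utils.py: weighted_random_choice draws an index with probability proportional to its weight by inverting the cumulative sum (inverse-CDF sampling). A minimal, self-contained sanity check of that behavior (a hypothetical demo script, not part of the repository):

    import random
    import numpy as np

    def weighted_random_choice(weights):
        # u ~ U[0, sum) lands in bucket i with probability weights[i] / sum
        t = np.cumsum(weights)
        s = np.sum(weights)
        return np.searchsorted(t, random.uniform(0, s))

    counts = [0, 0, 0]
    for _ in range(10000):
        counts[weighted_random_choice([1.0, 2.0, 7.0])] += 1
    print(counts)  # expect roughly [1000, 2000, 7000]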