├── Auto_Driving_Highway ├── agent_train.py ├── analysis_obs.py ├── ask_llm.py ├── baseClass.py ├── customTools.py ├── pre_prompt.py ├── requirements.txt ├── scenario.py ├── test_DQN.py ├── test_chat.py └── train_logger.py ├── LICENSE ├── README.md └── Results ├── Casestudy.png ├── Framework.png ├── Prompt_CaseStudy.jpg ├── aggressive.gif ├── conservative.gif └── prompt_Structure.png /Auto_Driving_Highway/agent_train.py: -------------------------------------------------------------------------------- 1 | import highway_env 2 | import numpy as np 3 | import gymnasium as gym 4 | from gym import spaces 5 | from stable_baselines3 import DQN 6 | from stable_baselines3.common.vec_env import DummyVecEnv # Import DummyVecEnv 7 | import random 8 | from scenario import Scenario 9 | from customTools import ( 10 | getAvailableActions, 11 | getAvailableLanes, 12 | getLaneInvolvedCar, 13 | isChangeLaneConflictWithCar, 14 | isAccelerationConflictWithCar, 15 | isKeepSpeedConflictWithCar, 16 | isDecelerationSafe, 17 | isActionSafe 18 | ) 19 | from analysis_obs import available_action, get_available_lanes, get_involved_cars, extract_lanes_info, extract_lane_and_car_ids, assess_lane_change_safety, check_safety_in_current_lane, format_training_info 20 | import ask_llm 21 | 22 | ACTIONS_ALL = { 23 | 0: 'LANE_LEFT', 24 | 1: 'IDLE', 25 | 2: 'LANE_RIGHT', 26 | 3: 'FASTER', 27 | 4: 'SLOWER' 28 | } 29 | 30 | class MyHighwayEnv(gym.Env): 31 | def __init__(self, vehicleCount=15): 32 | super(MyHighwayEnv, self).__init__() 33 | # base setting 34 | self.vehicleCount = vehicleCount 35 | # environment setting 36 | self.config = { 37 | "observation": { 38 | "type": "Kinematics", 39 | "features": ["presence", "x", "y", "vx", "vy"], 40 | "absolute": True, 41 | "normalize": False, 42 | "vehicles_count": vehicleCount, 43 | "see_behind": True, 44 | }, 45 | "action": { 46 | "type": "DiscreteMetaAction", 47 | "target_speeds": np.linspace(0, 32, 9), 48 | }, 49 | "duration": 40, 50 | "vehicles_density": 2, 51 | "show_trajectories": True, 52 | "render_agent": True, 53 | } 54 | self.env = gym.make("highway-v0") 55 | self.env.configure(self.config) 56 | self.action_space = self.env.action_space 57 | self.observation_space = self.env.observation_space 58 | # self.observation_space = spaces.Box( 59 | # low=-np.inf, high=np.inf, shape=(vehicleCount, 5), dtype=np.float32 60 | # ) 61 | 62 | 63 | def step(self, action): 64 | # Step the wrapped environment and capture all returned values 65 | obs, reward, done, truncated, info = self.env.step(action) 66 | custom_reward = self.calculate_custom_reward(action) 67 | return obs, custom_reward, done, truncated, info 68 | def set_llm_suggested_action(self, action): 69 | self.llm_suggested_action = action 70 | def calculate_custom_reward(self, action): 71 | if action == self.llm_suggested_action: 72 | return 1 # Reward for matching action 73 | else: 74 | return 0 75 | 76 | def reset(self, **kwargs): 77 | obs = self.env.reset(**kwargs) 78 | return obs # Make sure to return the observation 79 | 80 | def get_available_actions(self): 81 | """Get the list of available actions from the underlying Highway environment.""" 82 | if hasattr(self.env, 'get_available_actions'): 83 | return self.env.get_available_actions() 84 | else: 85 | raise NotImplementedError( 86 | "The method get_available_actions is not implemented in the underlying environment.") 87 | def main(): 88 | env = MyHighwayEnv(vehicleCount=5) 89 | observation = env.reset() 90 | print("Initial Observation:", observation) 91 | print("Observation space:", env.observation_space) 92 | # print("Action space:", env.action_space) 93 | 94 | # Wrap the environment in a DummyVecEnv for SB3 95 | env = DummyVecEnv([lambda: env]) # Add this line 96 | available_actions = env.envs[0].get_available_actions() 97 | model = DQN( 98 | "MlpPolicy", 99 | env, 100 | verbose=0, 101 | train_freq=2, 102 | learning_starts=20, 103 | exploration_fraction=0.5, 104 | learning_rate=0.0001, 105 | ) 106 | # Initialize scenario and tools 107 | sce = Scenario(vehicleCount=5) 108 | toolModels = [ 109 | getAvailableActions(env.envs[0]), 110 | getAvailableLanes(sce), 111 | getLaneInvolvedCar(sce), 112 | isChangeLaneConflictWithCar(sce), 113 | isAccelerationConflictWithCar(sce), 114 | isKeepSpeedConflictWithCar(sce), 115 | isDecelerationSafe(sce), 116 | # isActionSafe() 117 | ] 118 | frame = 0 119 | for _ in range(10): 120 | obs = env.reset() 121 | done = False 122 | while not done: 123 | sce.updateVehicles(obs, frame) 124 | # Observation translation 125 | msg0 = available_action(toolModels) 126 | msg1 = get_available_lanes(toolModels) 127 | msg2 = get_involved_cars((toolModels)) 128 | msg1_info = next(iter(msg1.values())) 129 | lanes_info = extract_lanes_info(msg1_info) 130 | 131 | lane_car_ids = extract_lane_and_car_ids(lanes_info, msg2) 132 | safety_assessment = assess_lane_change_safety(toolModels, lane_car_ids) 133 | safety_msg = check_safety_in_current_lane(toolModels, lane_car_ids) 134 | formatted_info = format_training_info(msg0, msg1, msg2, lanes_info, lane_car_ids, safety_assessment, 135 | safety_msg) 136 | 137 | action, _ = model.predict(obs) 138 | action_id = int(action[0]) 139 | action_name = ACTIONS_ALL.get(action_id, "Unknown Action") 140 | print(f"DQN action: {action_name}") 141 | 142 | llm_response = ask_llm.send_to_chatgpt(action, formatted_info, sce) 143 | decision_content = llm_response.content 144 | print(llm_response) 145 | llm_suggested_action = extract_decision(decision_content) 146 | print(f"llm action: {llm_suggested_action}") 147 | 148 | env.env_method('set_llm_suggested_action', llm_suggested_action) 149 | # print(f"Action: {action}") 150 | # print(f"Observation: {next_obs}") 151 | 152 | obs, custom_reward, done, info = env.step(action) 153 | print(f"Reward: {custom_reward}\n") 154 | frame += 1 155 | model.learn(total_timesteps=10, reset_num_timesteps=False) 156 | 157 | obs = env.reset() 158 | for step in range(1000): 159 | action, _states = model.predict(obs, deterministic=True) 160 | obs, rewards, dones, info = env.step(action) 161 | 162 | print(f"Reward: {rewards}\n") 163 | 164 | env.close() 165 | 166 | # utils.py 167 | def extract_decision(response_content): 168 | try: 169 | start = response_content.find('"decision": {') + len('"decision": {') 170 | end = response_content.find('}', start) 171 | decision = response_content[start:end].strip('"') 172 | return decision 173 | except Exception as e: 174 | print(f"Error in extracting decision: {e}") 175 | return None 176 | 177 | 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /Auto_Driving_Highway/analysis_obs.py: -------------------------------------------------------------------------------- 1 | from customTools import ( 2 | getAvailableActions, 3 | getAvailableLanes, 4 | getLaneInvolvedCar, 5 | isChangeLaneConflictWithCar, 6 | isAccelerationConflictWithCar, 7 | isKeepSpeedConflictWithCar, 8 | isDecelerationSafe, 9 | isActionSafe 10 | ) 11 | def available_action(toolModels): 12 | 13 | available_action_tool = next((tool for tool in toolModels if isinstance(tool, getAvailableActions)), None) 14 | # Use tools to analyze the situation 15 | available_action = {} 16 | ego_vehicle_id = 'ego' 17 | available_lanes_analysis = available_action_tool.inference(ego_vehicle_id) 18 | available_action[available_action_tool] = available_lanes_analysis 19 | 20 | return available_action 21 | 22 | def get_available_lanes(toolModels): 23 | 24 | available_lanes_tool = next((tool for tool in toolModels if isinstance(tool, getAvailableLanes)), None) 25 | # Use tools to analyze the situation 26 | situation_analysis = {} 27 | ego_vehicle_id = 'ego' 28 | available_lanes_analysis = available_lanes_tool.inference(ego_vehicle_id) 29 | situation_analysis[available_lanes_tool] = available_lanes_analysis 30 | 31 | return situation_analysis 32 | 33 | def get_involved_cars(toolModels): 34 | lane_cars_info = {} 35 | lane_involved_car_tool = next((tool for tool in toolModels if isinstance(tool, getLaneInvolvedCar)), None) 36 | lane_ids=['lane_0', 'lane_1', 'lane_2'] 37 | for lane_id in lane_ids: 38 | cars_in_lane_info = lane_involved_car_tool.inference(lane_id) 39 | lane_cars_info[lane_id] = cars_in_lane_info 40 | 41 | return lane_cars_info 42 | 43 | # function get info from available lanes 44 | def extract_lanes_info(available_lanes_info): 45 | lanes = { 46 | 'current': None, 47 | 'left': None, 48 | 'right': None 49 | } 50 | 51 | parts = available_lanes_info.split(". ") 52 | for part in parts: 53 | if "is the current lane" in part: 54 | lanes['current'] = part.split("`")[1] # Extract the current lane 55 | elif "to the left of the current lane" in part: 56 | lanes['left'] = part.split("`")[1] # Extract the left adjacent lane 57 | elif "to the right of the current lane" in part: 58 | lanes['right'] = part.split("`")[1] # Extract the right adjacent lane 59 | 60 | return lanes 61 | 62 | # These two functions get info from get_involved_cars 63 | 64 | def extract_car_id_from_info(lane_info): 65 | # Extracts the car ID from the lane information string 66 | if "is driving" in lane_info: 67 | parts = lane_info.split() 68 | car_id_index = parts.index("is") - 1 69 | return parts[car_id_index] 70 | return None 71 | 72 | def extract_lane_and_car_ids(lanes_info, lane_cars_info): 73 | lane_car_ids = { 74 | 'current_lane': {'lane_id': None, 'car_id': None}, 75 | 'left_lane': {'lane_id': None, 'car_id': None}, 76 | 'right_lane': {'lane_id': None, 'car_id': None} 77 | } 78 | current_lane_id = lanes_info['current'] 79 | left_lane_id = lanes_info['left'] 80 | right_lane_id = lanes_info['right'] 81 | 82 | # Extract car ID for the left adjacent lane, if it exists 83 | if current_lane_id and current_lane_id in lane_cars_info: 84 | current_lane_info = lane_cars_info[current_lane_id] 85 | current_car_id = extract_car_id_from_info(current_lane_info) 86 | lane_car_ids['current_lane'] = {'lane_id': current_lane_id, 'car_id': current_car_id} 87 | 88 | # Extract car ID for the left adjacent lane, if it exists 89 | if left_lane_id and left_lane_id in lane_cars_info: 90 | left_lane_info = lane_cars_info[left_lane_id] 91 | left_car_id = extract_car_id_from_info(left_lane_info) 92 | lane_car_ids['left_lane'] = {'lane_id': left_lane_id, 'car_id': left_car_id} 93 | 94 | # Extract car ID for the right adjacent lane, if it exists 95 | if right_lane_id and right_lane_id in lane_cars_info: 96 | right_lane_info = lane_cars_info[right_lane_id] 97 | right_car_id = extract_car_id_from_info(right_lane_info) 98 | lane_car_ids['right_lane'] = {'lane_id': right_lane_id, 'car_id': right_car_id} 99 | 100 | return lane_car_ids 101 | 102 | # F 103 | def assess_lane_change_safety(toolModels, lane_car_ids): 104 | lane_change_tool = next((tool for tool in toolModels if isinstance(tool, isChangeLaneConflictWithCar)), None) 105 | safety_assessment = { 106 | 'left_lane_change_safe': True, 107 | 'right_lane_change_safe': True 108 | } 109 | 110 | # Check if changing to the left lane is safe 111 | if lane_car_ids['left_lane']['lane_id'] and lane_car_ids['left_lane']['car_id']: 112 | left_lane_id = lane_car_ids['left_lane']['lane_id'] 113 | left_car_id = lane_car_ids['left_lane']['car_id'] 114 | input_str = f"{left_lane_id},{left_car_id}" 115 | left_lane_safety = lane_change_tool.inference(input_str) 116 | safety_assessment['left_lane_change_safe'] = 'safe' in left_lane_safety 117 | else: 118 | # If no car is in the left lane, consider it safe to change 119 | safety_assessment['left_lane_change_safe'] = True 120 | 121 | # Check if changing to the right lane is safe 122 | if lane_car_ids['right_lane']['lane_id'] and lane_car_ids['right_lane']['car_id']: 123 | right_lane_id = lane_car_ids['right_lane']['lane_id'] 124 | right_car_id = lane_car_ids['right_lane']['car_id'] 125 | input_str = f"{right_lane_id},{right_car_id}" 126 | right_lane_safety = lane_change_tool.inference(input_str) 127 | safety_assessment['right_lane_change_safe'] = 'safe' in right_lane_safety 128 | else: 129 | # If no car is in the right lane, consider it safe to change 130 | safety_assessment['right_lane_change_safe'] = True 131 | 132 | return safety_assessment 133 | 134 | 135 | def check_safety_in_current_lane(toolModels, lane_and_car_ids): 136 | safety_analysis = { 137 | 'acceleration_conflict': None, 138 | 'keep_speed_conflict': None, 139 | 'deceleration_conflict': None 140 | } 141 | 142 | # Extract tools from toolModels 143 | acceleration_tool = next((tool for tool in toolModels if isinstance(tool, isAccelerationConflictWithCar)), None) 144 | keep_speed_tool = next((tool for tool in toolModels if isinstance(tool, isKeepSpeedConflictWithCar)), None) 145 | deceleration_tool = next((tool for tool in toolModels if isinstance(tool, isDecelerationSafe)), None) 146 | 147 | current_lane_car_id = lane_and_car_ids['current_lane']['car_id'] 148 | 149 | if current_lane_car_id: 150 | # Check for conflicts if there is a car in the current lane 151 | if acceleration_tool: 152 | safety_analysis['acceleration_conflict'] = acceleration_tool.inference(current_lane_car_id) 153 | if keep_speed_tool: 154 | safety_analysis['keep_speed_conflict'] = keep_speed_tool.inference(current_lane_car_id) 155 | if deceleration_tool: 156 | safety_analysis['deceleration_conflict'] = deceleration_tool.inference(current_lane_car_id) 157 | 158 | return safety_analysis 159 | 160 | def get_current_speed(toolModels): 161 | lane_cars_info = {} 162 | lane_involved_car_tool = next((tool for tool in toolModels if isinstance(tool, getLaneInvolvedCar)), None) 163 | ego_speed = round(lane_involved_car_tool.sce.vehicles['ego'].speed, 1) 164 | speed_info = f"Your current speed is {ego_speed} m/s.\n\n" 165 | return speed_info 166 | 167 | def format_training_info(available_actions_msg, lanes_info_msg, speed_info, all_lane_info_msg, lanes_adjacent_info, cars_near_lane, lane_change_safety, current_lane_safety): 168 | formatted_message = "" 169 | 170 | # Add available actions information 171 | formatted_message += "Available Actions:\n" 172 | for tool, action_info in available_actions_msg.items(): 173 | formatted_message += f"- {action_info}\n" 174 | 175 | # Add information about lanes 176 | formatted_message += "\nLane Information:\n" 177 | formatted_message += f"- Current Lane: {lanes_adjacent_info['current']}\n" 178 | formatted_message += f"- Left Adjacent Lane: {lanes_adjacent_info['left'] or 'None'}\n" 179 | formatted_message += f"- Right Adjacent Lane: {lanes_adjacent_info['right'] or 'None'}\n" 180 | 181 | # Add details about vehicles in each lane 182 | formatted_message += f"{speed_info}\n" 183 | formatted_message += "\nOther Vehicles in all the Lanes:\n" 184 | for lane_id, car_info in all_lane_info_msg.items(): 185 | formatted_message += f"- {lane_id}: {car_info}\n" 186 | 187 | # Safety assessment for lane changes 188 | formatted_message += "\nSafety Assessment for Lane Changes:\n" 189 | formatted_message += f"- Left Lane Change: {'Safe' if lane_change_safety['left_lane_change_safe'] else 'Not Safe'}\n" 190 | formatted_message += f"- Right Lane Change: {'Safe' if lane_change_safety['right_lane_change_safe'] else 'Not Safe'}\n" 191 | 192 | # Safety assessment in the current lane 193 | formatted_message += "\nSafety Assessment in Current Lane:\n" 194 | for action, safety in current_lane_safety.items(): 195 | formatted_message += f"- {action.capitalize().replace('_', ' ')}: {safety}\n" 196 | 197 | return formatted_message -------------------------------------------------------------------------------- /Auto_Driving_Highway/ask_llm.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | import pre_prompt 3 | from logging import lastResort 4 | 5 | with open('MY_KEY.txt', 'r') as f: 6 | api_key = f.read() 7 | API_KEY = api_key 8 | 9 | 10 | ACTIONS_ALL = { 11 | 0: 'LANE_LEFT', 12 | 1: 'IDLE', 13 | 2: 'LANE_RIGHT', 14 | 3: 'FASTER', 15 | 4: 'SLOWER' 16 | } 17 | 18 | ACTIONS_DESCRIPTION = { 19 | 0: 'change lane to the left of the current lane,', 20 | 1: 'remain in the current lane with current speed', 21 | 2: 'change lane to the right of the current lane', 22 | 3: 'accelerate the vehicle', 23 | 4: 'decelerate the vehicle' 24 | } 25 | 26 | def send_to_chatgpt(last_action, current_scenario, sce): 27 | client = OpenAI(api_key=API_KEY) 28 | print("=========================",type(last_action),"=========================") 29 | action_id = int(last_action) # Convert to integer 30 | message_prefix = pre_prompt.SYSTEM_MESSAGE_PREFIX 31 | traffic_rules = pre_prompt.get_traffic_rules() 32 | decision_cautions = pre_prompt.get_decision_cautions() 33 | action_name = ACTIONS_ALL.get(action_id, "Unknown Action") 34 | action_description = ACTIONS_DESCRIPTION.get(action_id, "No description available") 35 | 36 | prompt = (f"{message_prefix}" 37 | f"You, the 'ego' car, are now driving on a highway. You have already driven for {sce.frame} seconds.\n" 38 | "There are several rules you need to follow when you drive on a highway:\n" 39 | f"{traffic_rules}\n\n" 40 | "Here are your attention points:\n" 41 | f"{decision_cautions}\n\n" 42 | "Once you make a final decision, output it in the following format:\n" 43 | "```\n" 44 | "Final Answer: \n" 45 | " \"decision\": {\"\"},\n" 46 | "```\n") 47 | user_prompt = (f"The decision made by the agent LAST time step was `{action_name}` ({action_description}).\n\n" 48 | "Here is the current scenario:\n" 49 | f"{current_scenario}\n\n") 50 | completion = client.chat.completions.create( 51 | model="gpt-3.5-turbo", 52 | messages=[ 53 | {"role": "system", "content": prompt}, 54 | {"role": "user", "content": user_prompt} 55 | ] 56 | ) 57 | 58 | return completion.choices[0].message -------------------------------------------------------------------------------- /Auto_Driving_Highway/baseClass.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Tuple, Union 2 | from dataclasses import dataclass, field 3 | from math import sqrt 4 | 5 | 6 | @dataclass 7 | class Lane: 8 | id: str 9 | laneIdx: int 10 | left_lanes: List[str] = field(default_factory=list) 11 | right_lanes: List[str] = field(default_factory=list) 12 | 13 | def export2json(self): 14 | return { 15 | 'id': self.id, 16 | 'lane index': self.laneIdx, 17 | 'left_lanes': self.left_lanes, 18 | 'right_lanes': self.right_lanes, 19 | } 20 | 21 | 22 | @dataclass 23 | class Vehicle: 24 | id: str 25 | lane_id: str = '' 26 | x: float = 0.0 27 | y: float = 0.0 28 | speedx: float = 0.0 29 | speedy: float = 0.0 30 | presence: bool = False 31 | 32 | def clear(self) -> None: 33 | self.lane_id = '' 34 | self.x = 0.0 35 | self.y = 0.0 36 | self.speedx = 0.0 37 | self.speedy = 0.0 38 | self.presence = False 39 | 40 | def updateProperty( 41 | self, x: float, y: float, vx: float, vy: float 42 | ) -> None: 43 | self.x = x 44 | self.y = y 45 | self.speedx = vx 46 | self.speedy = vy 47 | laneIdx = round(y/4.0) 48 | self.lane_id = 'lane_' + str(laneIdx) 49 | 50 | @property 51 | def speed(self) -> float: 52 | return sqrt(pow(self.speedx, 2) + pow(self.speedy, 2)) 53 | 54 | @property 55 | def lanePosition(self) -> float: 56 | return self.x 57 | 58 | def export2json(self) -> Dict: 59 | return { 60 | 'id': self.id, 61 | 'current lane': self.lane_id, 62 | # float() is used to transfer np.float32 to float, since np.float32 63 | # can not be serialized by JSON 64 | 'lane position': round(float(self.x), 2), 65 | 'speed': round(float(self.speed), 2), 66 | } 67 | -------------------------------------------------------------------------------- /Auto_Driving_Highway/customTools.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from scenario import Scenario 3 | 4 | def prompts(name, description): 5 | def decorator(func): 6 | func.name = name 7 | func.description = description 8 | return func 9 | 10 | return decorator 11 | 12 | 13 | ACTIONS_ALL = { 14 | 0: 'LANE_LEFT', 15 | 1: 'IDLE', 16 | 2: 'LANE_RIGHT', 17 | 3: 'FASTER', 18 | 4: 'SLOWER' 19 | } 20 | 21 | ACTIONS_DESCRIPTION = { 22 | 0: 'change lane to the left of the current lane,', 23 | 1: 'remain in the current lane with current speed', 24 | 2: 'change lane to the right of the current lane', 25 | 3: 'accelerate the vehicle', 26 | 4: 'decelerate the vehicle' 27 | } 28 | 29 | 30 | class getAvailableActions: 31 | def __init__(self, env: Any) -> None: 32 | self.env = env 33 | 34 | @prompts(name='Get Available Actions', 35 | description="""Useful before you make decisions, this tool let you know what are your available actions in this situation. The input to this tool should be 'ego'.""") 36 | def inference(self, input: str) -> str: 37 | outputPrefix = 'You can ONLY use one of the following actions: \n ' 38 | availableActions = self.env.get_available_actions() 39 | for action in availableActions: 40 | outputPrefix += ACTIONS_ALL[action] + \ 41 | '--' + ACTIONS_DESCRIPTION[action] + '; \n' 42 | if 1 in availableActions: 43 | outputPrefix += 'You should check idle and faster action as FIRST priority. ' 44 | 45 | if 0 in availableActions or 2 in availableActions: 46 | outputPrefix += 'For change lane action, CAREFULLY CHECK the safety of vehicles on target lane. ' 47 | if 3 in availableActions: 48 | outputPrefix += 'Consider acceleration action carefully. ' 49 | if 4 in availableActions: 50 | outputPrefix += 'The deceleration action is LAST priority. ' 51 | outputPrefix += """\nTo check decision safety you should follow steps: 52 | Step 1: Get the vehicles in this lane that you may affect. Acceleration, deceleration and idle action affect the Current lane, while left and right lane changes affect the Adjacent lane. 53 | Step 2: If there are vehicles, check safety between ego and all vehicles in the action lane ONE by ONE. 54 | Step 3: If you find There is no car driving on your "current lane" you can drive faster ! but not too fast to follow the traffic rules. 55 | Step 4: If you want to make lane change consider :"Safety Assessment for Lane Changes:" Safe means it is safe to change ,If you want to do IDLE, FASTER, SLOWER, you should consider "Safety Assessment in Current Lane:" 56 | """ 57 | return outputPrefix 58 | 59 | 60 | class isActionSafe: 61 | def __init__(self) -> None: 62 | pass 63 | 64 | # @prompts(name='Check Action Safety', 65 | # description="""Use this tool when you want to check the proposed action's safety. The input to this tool should be a string, which is ONLY the action name.""") 66 | @prompts(name='Decision-making Instructions', 67 | description="""This tool gives you a brief intruduction about how to ensure that the action you make is safe. The input to this tool should be a string, which is ONLY the action name.""") 68 | def inference(self, action: str) -> str: 69 | return f"""To check action safety you should follow three steps: 70 | Step 1: Identify the lanes affected by this action. Acceleration, deceleration and idle affect the current lane, while left and right lane changes affect the corresponding lane. 71 | Step 2:(Optional) Get the vehicles in this lane that you may affect, ONLY when you don't know. 72 | Step 3: If there are vehicles, check safety between ego and all vehicles in the action lane ONE by ONE. 73 | Follow the instructions and remember to use the proper tools mentioned in the tool list once a time. 74 | """ 75 | 76 | 77 | class getAvailableLanes: 78 | def __init__(self, sce: Scenario) -> None: 79 | self.sce = sce 80 | 81 | @prompts(name='Get Available Lanes', 82 | description="""useful when you want to know the available lanes of the vehicles. like: I want to know the available lanes of the vehicle `ego`. The input to this tool should be a string, representing the id of the vehicle.""") 83 | def inference(self, vid: str) -> str: 84 | veh = self.sce.vehicles[vid] 85 | currentLaneID = veh.lane_id 86 | laneIdx = self.sce.lanes[currentLaneID].laneIdx 87 | if laneIdx == 2: 88 | leftLane = 'lane_1' 89 | return f"""The availabel lane of `{vid}` is `{leftLane}` and `{currentLaneID}`. `{leftLane}` is to the left of the current lane. `{currentLaneID}` is the current lane.""" 90 | elif laneIdx == 0: 91 | rightLane = 'lane_1' 92 | return f"""The availabel lane of `{vid}` is `{currentLaneID}` and `{rightLane}`. `{currentLaneID}` is the current lane. `{rightLane}` is to the right of the current lane.""" 93 | else: 94 | leftLane = 'lane_' + str(laneIdx-1) 95 | rightLane = 'lane_' + str(laneIdx+1) 96 | return f"""The availabel lane of `{vid}` is `{currentLaneID}`, `{rightLane}` and {leftLane}. `{currentLaneID}` is the current lane. `{rightLane}` is to the right of the current lane. `{leftLane}` is to the left of the current lane.""" 97 | 98 | 99 | class getLaneInvolvedCar: 100 | def __init__(self, sce: Scenario) -> None: 101 | self.sce = sce 102 | 103 | @prompts(name='Get Lane Involved Car', 104 | description="""useful whent want to know the cars may affect your action in the certain lane. Make sure you have use tool `Get Available Lanes` first. The input is a string, representing the id of the specific lane you want to drive on, DONNOT input multiple lane_id once.""") 105 | def inference(self, laneID: str) -> str: 106 | if laneID not in {'lane_0', 'lane_1', 'lane_2', 'lane_3'}: 107 | return "Not a valid lane id! Make sure you have use tool `Get Available Lanes` first." 108 | ego = self.sce.vehicles['ego'] 109 | laneVehicles = [] 110 | for vk, vv in self.sce.vehicles.items(): 111 | if vk != 'ego': 112 | if vv.lane_id == laneID: 113 | laneVehicles.append((vv.id, vv.lanePosition)) 114 | laneVehicles.sort(key=lambda x: x[1]) 115 | leadingCarIdx = -1 116 | for i in range(len(laneVehicles)): 117 | vp = laneVehicles[i] 118 | if vp[1] >= ego.lanePosition: 119 | leadingCarIdx = i 120 | break 121 | if leadingCarIdx == -1: 122 | try: 123 | rearingCar = laneVehicles[-1][0] 124 | except IndexError: 125 | return f'There is no car driving on {laneID}, This lane is safe, you donot need to check for any vehicle for safety! you can drive on this lane as fast as you can.' 126 | return f"{rearingCar} is driving on {laneID}, and it's driving behind ego car. You need to make sure that your actions do not conflict with each of the vehicles mentioned." 127 | elif leadingCarIdx == 0: 128 | leadingCar = laneVehicles[0][0] 129 | distance = round(laneVehicles[0][1] - ego.lanePosition, 2) 130 | leading_car_vel = round(self.sce.vehicles[leadingCar].speed,1) 131 | return f"{leadingCar} is driving at {leading_car_vel}m/s on {laneID}, and it's driving in front of ego car for {distance} meters. You need to make sure that your actions do not conflict with each of the vehicles mentioned." 132 | else: 133 | leadingCar = laneVehicles[leadingCarIdx][0] 134 | rearingCar = laneVehicles[leadingCarIdx-1][0] 135 | distance = round(laneVehicles[leadingCarIdx][1] - ego.lanePosition, 2) 136 | leading_car_vel = round(self.sce.vehicles[leadingCar].speed,1) 137 | return f"{leadingCar} and {rearingCar} is driving on {laneID}, and {leadingCar} is driving at {leading_car_vel}m/s in front of ego car for {distance} meters, while {rearingCar} is driving behind ego car. You need to make sure that your actions do not conflict with each of the vehicles mentioned." 138 | 139 | 140 | class isChangeLaneConflictWithCar: 141 | def __init__(self, sce: Scenario) -> None: 142 | self.sce = sce 143 | self.TIME_HEAD_WAY = 3.0 144 | self.VEHICLE_LENGTH = 5.0 145 | 146 | @prompts(name='Is Change Lane Confict With Car', 147 | description="""useful when you want to know whether change lane to a specific lane is confict with a specific car, ONLY when your decision is change_lane_left or change_lane_right. The input to this tool should be a string of a comma separated string of two, representing the id of the lane you want to change to and the id of the car you want to check.""") 148 | def inference(self, inputs: str) -> str: 149 | laneID, vid = inputs.replace(' ', '').split(',') 150 | if vid not in self.sce.vehicles: 151 | return "Your input is not a valid vehicle id, make sure you use `Get Lane Involved Car` tool first!" 152 | veh = self.sce.vehicles[vid] 153 | ego = self.sce.vehicles['ego'] 154 | if veh.lanePosition >= ego.lanePosition: 155 | relativeSpeed = ego.speed - veh.speed 156 | if veh.lanePosition - ego.lanePosition - self.VEHICLE_LENGTH > self.TIME_HEAD_WAY * relativeSpeed: 157 | return f"change lane to `{laneID}` is safe with `{vid}`." 158 | else: 159 | return f"change lane to `{laneID}` may be conflict with `{vid}`, which is unacceptable." 160 | else: 161 | relativeSpeed = veh.speed - ego.speed 162 | if ego.lanePosition - veh.lanePosition - self.VEHICLE_LENGTH > self.TIME_HEAD_WAY * relativeSpeed: 163 | return f"change lane to `{laneID}` is safe with `{vid}`." 164 | else: 165 | return f"change lane to `{laneID}` may be conflict with `{vid}`, which is unacceptable." 166 | 167 | 168 | class isAccelerationConflictWithCar: 169 | def __init__(self, sce: Scenario) -> None: 170 | self.sce = sce 171 | self.TIME_HEAD_WAY = 5.0 172 | self.VEHICLE_LENGTH = 5.0 173 | self.acceleration = 4.0 174 | 175 | @prompts(name='Is Acceleration Conflict With Car', 176 | description="""useful when you want to know whether acceleration is safe with a specific car, ONLY when your decision is accelerate. The input to this tool should be a string, representing the id of the car you want to check.""") 177 | def inference(self, vid: str) -> str: 178 | if vid not in self.sce.vehicles: 179 | return "Your input is not a valid vehicle id, make sure you use `Get Lane Involved Car` tool first!" 180 | if vid == 'ego': 181 | return "You are checking the acceleration of ego car, which is meaningless, input a valid vehicle id please!" 182 | veh = self.sce.vehicles[vid] 183 | ego = self.sce.vehicles['ego'] 184 | if veh.lane_id != ego.lane_id: 185 | return f'{vid} is not in the same lane with ego, please call `Get Lane Involved Car` and rethink your input.' 186 | if veh.lane_id == ego.lane_id: 187 | if veh.lanePosition >= ego.lanePosition: 188 | relativeSpeed = ego.speed + self.acceleration - veh.speed 189 | distance = veh.lanePosition - ego.lanePosition - self.VEHICLE_LENGTH * 2 190 | if distance > self.TIME_HEAD_WAY * relativeSpeed: 191 | return f"acceleration is safe with `{vid}`." 192 | else: 193 | return f"acceleration may be conflict with `{vid}`, which is unacceptable." 194 | else: 195 | return f"acceleration is safe with {vid}" 196 | else: 197 | return f"acceleration is safe with {vid}" 198 | 199 | 200 | class isKeepSpeedConflictWithCar: 201 | def __init__(self, sce: Scenario) -> None: 202 | self.sce = sce 203 | self.TIME_HEAD_WAY = 5.0 204 | self.VEHICLE_LENGTH = 5.0 205 | 206 | @prompts(name='Is Keep Speed Conflict With Car', 207 | description="""useful when you want to know whether keep speed is safe with a specific car, ONLY when your decision is keep_speed. The input to this tool should be a string, representing the id of the car you want to check.""") 208 | def inference(self, vid: str) -> str: 209 | if vid not in self.sce.vehicles: 210 | return "Your input is not a valid vehicle id, make sure you use `Get Lane Involved Car` tool first!" 211 | if vid == 'ego': 212 | return "You are checking the acceleration of ego car, which is meaningless, input a valid vehicle id please!" 213 | veh = self.sce.vehicles[vid] 214 | ego = self.sce.vehicles['ego'] 215 | if veh.lane_id != ego.lane_id: 216 | return f'{vid} is not in the same lane with ego, please call `Get Lane Involved Car` and rethink your input.' 217 | if veh.lane_id == ego.lane_id: 218 | if veh.lanePosition >= ego.lanePosition: 219 | relativeSpeed = ego.speed - veh.speed 220 | distance = veh.lanePosition - ego.lanePosition - self.VEHICLE_LENGTH * 2 221 | if distance > self.TIME_HEAD_WAY * relativeSpeed: 222 | return f"keep lane with current speed is safe with {vid}" 223 | else: 224 | return f"keep lane with current speed may be conflict with {vid}, you need consider decelerate" 225 | else: 226 | return f"keep lane with current speed is safe with {vid}" 227 | else: 228 | return f"keep lane with current speed is safe with {vid}" 229 | 230 | 231 | class isDecelerationSafe: 232 | def __init__(self, sce: Scenario) -> None: 233 | self.sce = sce 234 | self.TIME_HEAD_WAY = 3.0 235 | self.VEHICLE_LENGTH = 5.0 236 | self.deceleration = 3.0 237 | 238 | @prompts(name='Is Deceleration Safe', 239 | description="""useful when you want to know whether deceleration is safe, ONLY when your decision is decelerate.The input to this tool should be a string, representing the id of the car you want to check.""") 240 | def inference(self, vid: str) -> str: 241 | if vid not in self.sce.vehicles: 242 | return "Your input is not a valid vehicle id, make sure you use `Get Lane Involved Car` tool first!" 243 | if vid == 'ego': 244 | return "You are checking the acceleration of ego car, which is meaningless, input a valid vehicle id please!" 245 | veh = self.sce.vehicles[vid] 246 | ego = self.sce.vehicles['ego'] 247 | if veh.lane_id != ego.lane_id: 248 | return f'{vid} is not in the same lane with ego, please call `Get Lane Involved Car` and rethink your input.' 249 | if veh.lane_id == ego.lane_id: 250 | if veh.lanePosition >= ego.lanePosition: 251 | relativeSpeed = ego.speed - veh.speed - self.deceleration 252 | distance = veh.lanePosition - ego.lanePosition - self.VEHICLE_LENGTH 253 | if distance > self.TIME_HEAD_WAY * relativeSpeed: 254 | return f"deceleration with current speed is safe with {vid}" 255 | else: 256 | return f"deceleration with current speed may be conflict with {vid}, if you have no other choice, slow down as much as possible" 257 | else: 258 | return f"deceleration with current speed is safe with {vid}" 259 | else: 260 | return f"deceleration with current speed is safe with {vid}" -------------------------------------------------------------------------------- /Auto_Driving_Highway/pre_prompt.py: -------------------------------------------------------------------------------- 1 | SYSTEM_MESSAGE_PREFIX = """ 2 | You are ChatGPT, a large language model trained by OpenAI. 3 | You are now act as a mature driving assistant, who can give accurate and correct advice for human driver in complex urban driving scenarios. 4 | The information in 'current scenario' : 5 | 6 | """ 7 | 8 | TRAFFIC_RULES = """ 9 | 1. Try to keep a safe distance to the car in front of you. 10 | 2. If there is no safe decision, just slowing down. 11 | 3. DONOT change lane frequently. If you want to change lane, double-check the safety of vehicles on target lane. 12 | """ 13 | 14 | 15 | DECISION_CAUTIONS = """ 16 | 1. You must output a decision when you finish this task. Your final output decision must be unique and not ambiguous. For example you cannot say "I can either keep lane or accelerate at current time". 17 | 2. You need to always remember your current lane ID, your available actions and available lanes before you make any decision. 18 | 3. Once you have a decision, you should check the safety with all the vehicles affected by your decision. 19 | 4. If you verify a decision is unsafe, you should start a new one and verify its safety again from scratch. 20 | """ 21 | 22 | def get_traffic_rules(): 23 | return TRAFFIC_RULES 24 | 25 | def get_decision_cautions(): 26 | return DECISION_CAUTIONS 27 | -------------------------------------------------------------------------------- /Auto_Driving_Highway/requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.21.0 2 | gymnasium==0.29.1 3 | highway_env==1.8.2 4 | numpy==1.26.2 5 | openai==1.3.6 6 | stable_baselines3==2.2.1 7 | -------------------------------------------------------------------------------- /Auto_Driving_Highway/scenario.py: -------------------------------------------------------------------------------- 1 | from baseClass import Lane, Vehicle 2 | from typing import List, Dict 3 | from datetime import datetime 4 | import sqlite3 5 | import json 6 | import os 7 | 8 | 9 | class Scenario: 10 | def __init__(self, vehicleCount: int, database: str = None) -> None: 11 | self.lanes: Dict[str, Lane] = {} 12 | self.getRoadgraph() 13 | self.vehicles: Dict[str, Vehicle] = {} 14 | self.vehicleCount = vehicleCount 15 | self.initVehicles() 16 | 17 | if database: 18 | self.database = database 19 | else: 20 | self.database = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.db' 21 | 22 | if os.path.exists(self.database): 23 | os.remove(self.database) 24 | 25 | conn = sqlite3.connect(self.database) 26 | cur = conn.cursor() 27 | cur.execute( 28 | """CREATE TABLE IF NOT EXISTS vehINFO( 29 | frame INT, 30 | id TEXT, 31 | x REAL, 32 | y REAL, 33 | lane_id TEXT, 34 | speedx REAL, 35 | speedy REAL, 36 | PRIMARY KEY (frame, id));""" 37 | ) 38 | cur.execute( 39 | """CREATE TABLE IF NOT EXISTS decisionINFO( 40 | frame INT PRIMARY KEY, 41 | scenario TEXT, 42 | thoughtsAndActions TEXT, 43 | finalAnswer TEXT, 44 | outputParser TEXT);""" 45 | ) 46 | conn.commit() 47 | conn.close() 48 | 49 | self.frame = 0 50 | 51 | def getRoadgraph(self): 52 | for i in range(3): 53 | lid = 'lane_' + str(i) 54 | leftLanes = [] 55 | rightLanes = [] 56 | for j in range(i+1, 3): 57 | rightLanes.append('lane_' + str(j)) 58 | for k in range(0, i): 59 | leftLanes.append('lane_' + str(k)) 60 | self.lanes[lid] = Lane( 61 | id=lid, laneIdx=i, 62 | left_lanes=leftLanes, 63 | right_lanes=rightLanes 64 | ) 65 | 66 | def initVehicles(self): 67 | for i in range(self.vehicleCount): 68 | if i == 0: 69 | vid = 'ego' 70 | else: 71 | vid = 'veh' + str(i) 72 | self.vehicles[vid] = Vehicle(id=vid) 73 | 74 | def updateVehicles(self, observation, frame): 75 | self.frame = frame 76 | conn = sqlite3.connect(self.database) 77 | cur = conn.cursor() 78 | for i, vehicle_obs in enumerate(observation): 79 | presence = vehicle_obs[0] # Assuming the first element of each vehicle observation is 'presence' 80 | x, y, vx, vy = vehicle_obs[1:] 81 | 82 | if presence == 1: # Check if the vehicle is present 83 | vid = f'veh{i}' if i != 0 else 'ego' 84 | veh = self.vehicles[vid] 85 | veh.updateProperty(x, y, vx, vy) 86 | cur.execute( 87 | '''INSERT INTO vehINFO VALUES (?,?,?,?,?,?,?);''', 88 | (self.frame, vid, x, y, veh.lane_id, vx, vy) 89 | ) 90 | else: 91 | if i != 0: # Skip 'ego' when it's not present 92 | vid = f'veh{i}' 93 | self.vehicles[vid].clear() 94 | 95 | conn.commit() 96 | conn.close() 97 | 98 | def export2json(self): 99 | scenario = {} 100 | scenario['lanes'] = [] 101 | scenario['vehicles'] = [] 102 | for lv in self.lanes.values(): 103 | scenario['lanes'].append(lv.export2json()) 104 | scenario['ego_info'] = self.vehicles['ego'].export2json() 105 | 106 | for vv in self.vehicles.values(): 107 | if vv.presence: 108 | scenario['vehicles'].append(vv.export2json()) 109 | 110 | return json.dumps(scenario) 111 | 112 | def clear_vehicle_info(self): 113 | """Clears the vehINFO table in the database for a new episode.""" 114 | conn = sqlite3.connect(self.database) 115 | cur = conn.cursor() 116 | cur.execute("DELETE FROM vehINFO") 117 | conn.commit() 118 | conn.close() -------------------------------------------------------------------------------- /Auto_Driving_Highway/test_DQN.py: -------------------------------------------------------------------------------- 1 | import highway_env 2 | import gymnasium as gym 3 | from stable_baselines3 import DQN 4 | from stable_baselines3.common.vec_env import DummyVecEnv # Import DummyVecEnv 5 | import random 6 | from agent_train import MyHighwayEnv 7 | from scenario import Scenario 8 | from customTools import getAvailableActions, getAvailableLanes, getLaneInvolvedCar,isChangeLaneConflictWithCar,isAccelerationConflictWithCar,isKeepSpeedConflictWithCar,isDecelerationSafe 9 | from analysis_obs import available_action, get_available_lanes, get_involved_cars, extract_lanes_info, extract_lane_and_car_ids, assess_lane_change_safety, check_safety_in_current_lane, format_training_info 10 | from customTools import ACTIONS_ALL 11 | import ask_llm 12 | 13 | def main(): 14 | env = MyHighwayEnv(vehicleCount=5) 15 | observation = env.reset() 16 | print("Initial Observation:", observation) 17 | print("Observation space:", env.observation_space) 18 | # print("Action space:", env.action_space) 19 | 20 | # Wrap the environment in a DummyVecEnv for SB3 21 | env = DummyVecEnv([lambda: env]) # Add this line 22 | available_actions = env.envs[0].get_available_actions() 23 | model = DQN( 24 | "MlpPolicy", 25 | env, 26 | verbose=0, 27 | train_freq=2, 28 | learning_starts=20, 29 | exploration_fraction=0.5, 30 | learning_rate=0.0001, 31 | ) 32 | # Initialize scenario and tools 33 | sce = Scenario(vehicleCount=5) 34 | toolModels = [ 35 | getAvailableActions(env.envs[0]), 36 | getAvailableLanes(sce), 37 | getLaneInvolvedCar(sce), 38 | isChangeLaneConflictWithCar(sce), 39 | isAccelerationConflictWithCar(sce), 40 | isKeepSpeedConflictWithCar(sce), 41 | isDecelerationSafe(sce), 42 | # isActionSafe() 43 | ] 44 | frame = 0 45 | for _ in range(10): 46 | obs = env.reset() 47 | done = False 48 | while not done: 49 | sce.updateVehicles(obs, frame) 50 | # Observation translation 51 | msg0 = available_action(toolModels) 52 | msg1 = get_available_lanes(toolModels) 53 | msg2 = get_involved_cars((toolModels)) 54 | msg1_info = next(iter(msg1.values())) 55 | lanes_info = extract_lanes_info(msg1_info) 56 | 57 | lane_car_ids = extract_lane_and_car_ids(lanes_info, msg2) 58 | safety_assessment = assess_lane_change_safety(toolModels, lane_car_ids) 59 | safety_msg = check_safety_in_current_lane(toolModels, lane_car_ids) 60 | formatted_info = format_training_info(msg0, msg1, msg2, lanes_info, lane_car_ids, safety_assessment, 61 | safety_msg) 62 | 63 | action, _ = model.predict(obs) 64 | action_id = int(action[0]) 65 | action_name = ACTIONS_ALL.get(action_id, "Unknown Action") 66 | print(f"DQN action: {action_name}") 67 | 68 | llm_response = ask_llm.send_to_chatgpt(action, formatted_info, sce) 69 | decision_content = llm_response.content 70 | print(llm_response) 71 | llm_suggested_action = extract_decision(decision_content) 72 | print(f"llm action: {llm_suggested_action}") 73 | 74 | env.env_method('set_llm_suggested_action', llm_suggested_action) 75 | # print(f"Action: {action}") 76 | # print(f"Observation: {next_obs}") 77 | 78 | obs, custom_reward, done, info = env.step(action) 79 | print(f"Reward: {custom_reward}\n") 80 | frame += 1 81 | model.learn(total_timesteps=10, reset_num_timesteps=False) 82 | 83 | obs = env.reset() 84 | for step in range(1000): 85 | action, _states = model.predict(obs, deterministic=True) 86 | obs, rewards, dones, info = env.step(action) 87 | 88 | print(f"Reward: {rewards}\n") 89 | 90 | env.close() 91 | 92 | # utils.py 93 | def extract_decision(response_content): 94 | try: 95 | start = response_content.find('"decision": {') + len('"decision": {') 96 | end = response_content.find('}', start) 97 | decision = response_content[start:end].strip('"') 98 | return decision 99 | except Exception as e: 100 | print(f"Error in extracting decision: {e}") 101 | return None 102 | 103 | 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /Auto_Driving_Highway/test_chat.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | 4 | def send_to_chatgpt(api_key, structured_message): 5 | client = OpenAI(api_key=api_key) 6 | 7 | # Replace the messages below with your structured_message 8 | completion = client.chat.completions.create( 9 | model="gpt-3.5-turbo-1106", 10 | messages=[ 11 | {"role": "system", "content": "You are an assistant, skilled in analyzing driving scenarios."}, 12 | {"role": "user", "content": structured_message} 13 | ] 14 | ) 15 | 16 | return completion.choices[0].message 17 | 18 | def extract_decision(response): 19 | try: 20 | start = response.find('"decision": {') + len('"decision": {') 21 | end = response.find('}', start) 22 | decision = response[start:end].strip('"') 23 | return decision 24 | except Exception as e: 25 | print(f"Error in extracting decision: {e}") 26 | return None 27 | 28 | 29 | def main(): 30 | # Replace with your actual API Key 31 | with open('MY_KEY.txt', 'r') as f: 32 | api_key = f.read().strip() 33 | 34 | # Example structured message (replace with actual message) 35 | structured_message = "Your analysis of the driving situation..." 36 | 37 | # Send message and get response 38 | chatgpt_response = send_to_chatgpt(api_key, structured_message) 39 | print("ChatGPT Response:", chatgpt_response) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /Auto_Driving_Highway/train_logger.py: -------------------------------------------------------------------------------- 1 | from stable_baselines3.common.callbacks import BaseCallback 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | 5 | class TrainingLogger(BaseCallback): 6 | def __init__(self, check_freq, verbose=0): 7 | super(TrainingLogger, self).__init__(verbose) 8 | self.check_freq = check_freq 9 | # Metrics for collision and speed 10 | self.collision_count = 0 11 | self.total_speed = 0 12 | self.steps = 0 13 | self.collision_probabilities = [] 14 | self.average_speeds = [] 15 | # Metrics for matching percentage 16 | self.match_count = 0 17 | self.total_decisions = 0 # Adjusted to total_decisions for clarity 18 | self.match_percentages = [] 19 | 20 | def _on_step(self) -> bool: 21 | self.steps += 1 22 | self.total_decisions += 1 # Increment total decisions 23 | 24 | # Extract collision, speed, and LLM match from 'info' 25 | collision = self.locals['infos'][0].get('crashed', False) 26 | average_speed = self.locals['infos'][0].get('speed', 0) 27 | llm_reward = self.locals['infos'][0].get('llm_reward', 0) 28 | 29 | # Update metrics 30 | self.collision_count += int(collision) 31 | self.total_speed += average_speed 32 | self.match_count += int(llm_reward == 1) 33 | 34 | if self.steps % self.check_freq == 0: 35 | # Calculate and store metrics 36 | collision_probability = self.collision_count / self.check_freq 37 | avg_speed = self.total_speed / self.check_freq 38 | match_percentage = (self.match_count / self.total_decisions) * 100 39 | self.collision_probabilities.append(collision_probability) 40 | self.average_speeds.append(avg_speed) 41 | self.match_percentages.append(match_percentage) 42 | print(f"Step: {self.steps}, Match Percentage: {match_percentage}%, Collision Probability: {collision_probability}, Average Speed: {avg_speed}") 43 | # Reset counters 44 | self.collision_count = 0 45 | self.total_speed = 0 46 | self.match_count = 0 47 | self.total_decisions = 0 48 | 49 | return True 50 | 51 | def _on_training_end(self): 52 | # Plot all metrics 53 | plt.figure(figsize=(18, 6)) 54 | 55 | # Subplot for collision probabilities 56 | plt.subplot(1, 3, 1) 57 | plt.plot(self.collision_probabilities, label='Collision Probability', color='red') 58 | plt.xlabel('Checkpoints') 59 | plt.ylabel('Collision Probability') 60 | plt.title('Collision Probability Over Time') 61 | 62 | # Subplot for average speeds 63 | plt.subplot(1, 3, 2) 64 | plt.plot(self.average_speeds, label='Average Speed', color='orange') 65 | plt.xlabel('Checkpoints') 66 | plt.ylabel('Average Speed') 67 | plt.title('Average Speed Over Time') 68 | 69 | # Subplot for matching percentages 70 | plt.subplot(1, 3, 3) 71 | plt.plot(self.match_percentages, label='Matching Percentage', color='blue') 72 | plt.xlabel('Checkpoints') 73 | plt.ylabel('Matching Percentage (%)') 74 | plt.title('Matching Percentage Over Time') 75 | 76 | plt.tight_layout() 77 | plt.legend() 78 | plt.show() 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 JingYue2000 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # In-context_Learning_for_Automated_Driving 🤔 2 | 3 | The code in this repository is based on the paper Reward Design with Language Models. 4 | 5 | This repository contains the prompts that we used for each domain as well as code to train an RL agent with an LLM in the loop using those prompts. Each domain (Ultimatum Game, Matrix Games, DealOrNoDeal) has a separate directory and will need a separate conda/virtual environment. 6 | 7 | Please check out the READMEs in each directory for more information on how to run things. 8 | 9 | # Description 10 | 11 | We implemented a framework that utilizes LLMs to assist Reinforcement Learning (RL) agents in better simulating human driving behavior. First, we define the desired driving style and output format and input them into the agent. Then, we describe the current scenario the ego car faces to the LLM. Except for the first round, in each subsequent round, we provide the LLM with the decision output from the previous round, enabling better context understanding. By comparing the LLM's output decision with the decision made by the RL agent, we generate a reward function to train the LLM-RL agent. 12 | 13 | # Model Setup 14 | 15 | - We use GPT3 for our experiments. You will need an API key from them saved in a file named MY_KEY.txt. 16 | - Use requirement.txt for model environment setup: 17 | ```pip install -r requirements.txt``` 18 | 19 | # Model Work Flow 20 | 21 | ![image](https://github.com/JingYue2000/In-context_Learning_for_Automated_Driving/blob/main/Results/Framework.png) 22 | Our framework integrates a Language Model (LM) to process textual prompts as input and generate a reward signal. The input prompt consists of three primary components: the Task Description, the User Objective, and the Last Outcome. Specifically, the LM functions by accepting a concatenated version of these components and returns a textual string. This output is then analyzed by a parser, which translates the text into a binary reward signal. The binary reward signal is subsequently utilized to train a Reinforcement Learning (RL) agent. By employing the LM as a proxy for the traditional reward function, our framework is compatible with various RL training algorithms, enhancing adaptability and potential applications. 23 | 24 | # Prompt Structure 25 | 26 | ![image](https://github.com/JingYue2000/In-context_Learning_for_Automated_Driving/blob/main/Results/prompt_Structure.png) 27 | 28 | A prompt for one interaction consists of the last action, the predefined prompt, and the scene description. This description includes available actions, a general overview, and a safety evaluation. 29 | 30 | # Conservative Model and Aggressive Model 31 | 32 | ![image](https://github.com/JingYue2000/In-context_Learning_for_Automated_Driving/blob/main/Results/Prompt_CaseStudy.png) 33 | 34 | An example of conservative model: 35 | 36 | ![Conservative Model](Results/conservative.gif) 37 | 38 | 39 | An example of aggressive model: 40 | 41 | ![Aggressive Model](Results/aggressive.gif) 42 | -------------------------------------------------------------------------------- /Results/Casestudy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingYue2000/In-context_Learning_for_Automated_Driving/73b0ad467ded93ef38c2be0c06714d33c5184b00/Results/Casestudy.png -------------------------------------------------------------------------------- /Results/Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingYue2000/In-context_Learning_for_Automated_Driving/73b0ad467ded93ef38c2be0c06714d33c5184b00/Results/Framework.png -------------------------------------------------------------------------------- /Results/Prompt_CaseStudy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingYue2000/In-context_Learning_for_Automated_Driving/73b0ad467ded93ef38c2be0c06714d33c5184b00/Results/Prompt_CaseStudy.jpg -------------------------------------------------------------------------------- /Results/aggressive.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingYue2000/In-context_Learning_for_Automated_Driving/73b0ad467ded93ef38c2be0c06714d33c5184b00/Results/aggressive.gif -------------------------------------------------------------------------------- /Results/conservative.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingYue2000/In-context_Learning_for_Automated_Driving/73b0ad467ded93ef38c2be0c06714d33c5184b00/Results/conservative.gif -------------------------------------------------------------------------------- /Results/prompt_Structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingYue2000/In-context_Learning_for_Automated_Driving/73b0ad467ded93ef38c2be0c06714d33c5184b00/Results/prompt_Structure.png --------------------------------------------------------------------------------