├── GRE ├── __init__.py ├── agent.py └── recorder.py ├── HUN ├── KM.py ├── __init__.py ├── agent.py └── recorder.py ├── LTA ├── KM.py ├── __init__.py ├── agent.py ├── entity.py ├── global_var.py ├── grid.py ├── grid_id ├── grids_info ├── kdtree ├── matcher.py ├── recorder.py ├── scheduler.py └── utils.py ├── NNP ├── __init__.py ├── agent.py └── recorder.py └── README.md /GRE/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingyuan-shi/Learning-To-Dispatch/e5360630286b8831b4a6e7bdd8d522e5446ac354/GRE/__init__.py -------------------------------------------------------------------------------- /GRE/agent.py: -------------------------------------------------------------------------------- 1 | from recorder import Recorder 2 | from typing import List, Dict, Any 3 | 4 | 5 | class Agent(Recorder): 6 | """ Agent for dispatching and reposition """ 7 | 8 | def __init__(self, **kwargs): 9 | """ Load your trained model and initialize the parameters """ 10 | super().__init__() 11 | 12 | def dispatch(self, dispatch_observ: List[Dict[str, Any]], index2hash=None) -> List[Dict[str, int]]: 13 | """ Compute the assignment between drivers and passengers at each time step 14 | :param dispatch_observ: a list of dict, the key in the dict includes: 15 | order_id, int 16 | driver_id, int 17 | order_driver_distance, float 18 | order_start_location, a list as [lng, lat], float 19 | order_finish_location, a list as [lng, lat], float 20 | driver_location, a list as [lng, lat], float 21 | timestamp, int 22 | order_finish_timestamp, int 23 | day_of_week, int 24 | reward_units, float 25 | pick_up_eta, float 26 | :param index2hash: driver_id to driver_hash 27 | :return: a list of dict, the key in the dict includes: 28 | order_id and driver_id, the pair indicating the assignment 29 | """ 30 | dispatch_observ.sort(key=lambda od_info: -od_info['reward_units']) 31 | assigned_order = set() 32 | assigned_driver = set() 33 | dispatch_action = [] 34 | for od in dispatch_observ: 35 | # make sure each order is assigned to one driver, and each driver is assigned with one order 36 | if (od["order_id"] in assigned_order) or (od["driver_id"] in assigned_driver): 37 | continue 38 | assigned_order.add(od["order_id"]) 39 | assigned_driver.add(od["driver_id"]) 40 | dispatch_action.append(dict(order_id=od["order_id"], driver_id=od["driver_id"])) 41 | return dispatch_action 42 | 43 | def reposition(self, repo_observ): 44 | """ Compute the reposition action for the given drivers 45 | :param repo_observ: a dict, the key in the dict includes: 46 | timestamp: int 47 | driver_info: a list of dict, the key in the dict includes: 48 | driver_id: driver_id of the idle driver in the treatment group, int 49 | grid_id: id of the grid the driver is located at, str 50 | day_of_week: int 51 | :return: a list of dict, the key in the dict includes: 52 | driver_id: corresponding to the driver_id in the od_list 53 | destination: id of the grid the driver is repositioned to, str 54 | """ 55 | # repo_action = [] 56 | # for driver in repo_observ['driver_info']: 57 | # # the default reposition is to let drivers stay where they are 58 | # repo_action.append({'driver_id': driver['driver_id'], 'destination': driver['grid_id']}) 59 | # return repo_action 60 | return [] 61 | -------------------------------------------------------------------------------- /GRE/recorder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import defaultdict 3 | import math 4 | from typing import Dict, List, Any, Set 5 | import time 6 | 7 | 8 | def acc_dist(lng1: float, lat1: float, lng2: float, lat2: float) -> float: 9 | delta_lat = (lat1 - lat2) / 2 10 | delta_lng = (lng1 - lng2) / 2 11 | arc_pi = 3.14159265359 / 180 12 | R = 6378137 13 | return 2 * R * math.asin(math.sqrt( 14 | math.sin(arc_pi * delta_lat) ** 2 + math.cos(arc_pi * lat1) * math.cos(arc_pi * lat2) * ( 15 | math.sin(arc_pi * delta_lng) ** 2))) 16 | 17 | 18 | class Recorder: 19 | 20 | def __init__(self): 21 | self.drivers_total_income = defaultdict(float) 22 | self.drivers_online_time = defaultdict(float) 23 | self.drivers_log_on_off = defaultdict(list) 24 | self.drivers_serving_order_info = defaultdict(list) 25 | self.drivers_income_per_hour = defaultdict(lambda: [0 for i in range(25)]) 26 | self.active_drivers = set() 27 | self.median_ratio = 0. 28 | 29 | def __update_online_time(self, drivers_online_time: Dict[str, int]): 30 | """ 31 | update the driver's online time 32 | :param drivers_online_time: a dict, with key is driver's id, value is online time (in seconds) 33 | :return: None 34 | """ 35 | for driver_id in drivers_online_time: 36 | self.drivers_online_time[driver_id] = drivers_online_time[driver_id] 37 | 38 | def update_log_on(self, online_drivers_hash: Set[str], timestamp): 39 | """ 40 | update the driver log_on time 41 | :param online_drivers_hash: online drivers' hashcode at this timestamp 42 | :param timestamp: current timestamp 43 | """ 44 | self.active_drivers = self.active_drivers.union(online_drivers_hash) 45 | for driver_hash in online_drivers_hash: 46 | if len(self.drivers_log_on_off[driver_hash]) != 0: 47 | print("MULTIPLE LOG ON!!") 48 | self.drivers_log_on_off[driver_hash].append(timestamp) 49 | 50 | def update_log_off(self, offline_drivers_hash: Set[str], timestamp): 51 | """ 52 | update the driver log_off time 53 | :param offline_drivers_hash: offline drivers' hashcode at this timestamp 54 | :param timestamp: current timestamp 55 | """ 56 | self.active_drivers = self.active_drivers.difference(offline_drivers_hash) 57 | for driver_hash in offline_drivers_hash: 58 | if len(self.drivers_log_on_off[driver_hash]) != 1: 59 | print("LOG OFF BEFORE LOG ON!!") 60 | self.drivers_log_on_off[driver_hash].append(timestamp) 61 | ratios = [self.drivers_total_income[driver] / (0.1 + timestamp - self.drivers_log_on_off[driver][0]) 62 | for driver in self.active_drivers] 63 | ratios.sort() 64 | if len(ratios) > 0: 65 | self.median_ratio = ratios[len(ratios) // 2] 66 | 67 | def update_driver_income_after_rejection(self, assignment: List[Dict[str, Any]], 68 | dispatch_observ: List[Dict[str, Any]], index2hash: Dict[int, str]): 69 | """ 70 | this function update the driver's income. 71 | Should be called after the rejection process. 72 | :param assignment: a list of dicts, one dict is <'order_id': xxx, 'driver_id':xxx> 73 | :param dispatch_observ: the same as agent.matching parameter. 74 | :param index2hash: driver_id to driver_hash 75 | :return: None 76 | """ 77 | if len(dispatch_observ) == 0: 78 | return 79 | cur_hour = time.localtime(int(dispatch_observ[0]['timestamp'])).tm_hour 80 | order_price = {od['order_id']: od['reward_units'] for od in dispatch_observ} 81 | # for all recorders 82 | order_info = {od['order_id']: [od['reward_units'], 83 | od['order_driver_distance'], 84 | od['order_start_location'], 85 | od['order_finish_location'], 86 | od['timestamp'], 87 | od['order_finish_timestamp'], 88 | od['pick_up_eta']] for od in dispatch_observ} 89 | for pair in assignment: 90 | self.drivers_total_income[index2hash[pair['driver_id']]] += order_price[pair['order_id']] 91 | self.drivers_income_per_hour[index2hash[pair['driver_id']]][cur_hour] += order_price[pair['order_id']] 92 | self.drivers_serving_order_info[index2hash[pair['driver_id']]].append(order_info[pair['order_id']]) 93 | return 94 | 95 | def save_logs(self, solpath: str, city: str, date: str, notes=""): 96 | """ 97 | After one day simulation, output the driver's income and his/her online time into file 98 | :param solpath: str, the solution path 99 | :param date: str, the simulation date, eg. 20201129 100 | :param city: str, the city name, eg. chengdu 101 | :param notes: str, the parameter setting information 102 | :return: None 103 | """ 104 | bad1 = 0 105 | bad2 = 0 106 | bad3 = 0 107 | for driver_hash in self.drivers_log_on_off: 108 | if len(self.drivers_log_on_off[driver_hash]) == 1: 109 | bad1 += 1 110 | continue 111 | if len(self.drivers_log_on_off[driver_hash]) == 2: 112 | bad2 += 1 113 | log_on, log_off = self.drivers_log_on_off[driver_hash] 114 | self.drivers_online_time[driver_hash] = log_off - log_on 115 | continue 116 | if len(self.drivers_log_on_off[driver_hash]) > 2: 117 | bad3 += 1 118 | continue 119 | print("collision:", bad1, bad2, bad3, len(self.drivers_log_on_off)) 120 | driver_perhour_income = dict() 121 | for driver in self.drivers_income_per_hour: 122 | driver_perhour_income[driver] = self.drivers_income_per_hour[driver] 123 | solname = solpath.split('/')[-1] 124 | pickle.dump(self.drivers_log_on_off, open(solpath + "/" + solname + "_" + city + "_" + date + "logonoff" + notes, "wb")) 125 | pickle.dump(self.drivers_online_time, open(solpath + "/" + solname + "_" + city + "_" + date + "online_time" + notes, "wb")) 126 | pickle.dump(self.drivers_total_income, open(solpath + "/" + solname + "_" + city + "_" + date + "total_income" + notes, "wb")) 127 | pickle.dump(driver_perhour_income, open(solpath + "/" + solname + "_" + city + "_" + date + "perhourincome" + notes, "wb")) 128 | pickle.dump(self.drivers_serving_order_info, open(solpath + "/" + solname + "_" + city + "_" + date + "order_info" + notes, "wb")) 129 | 130 | -------------------------------------------------------------------------------- /HUN/KM.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | import numpy as np 4 | import random 5 | import time 6 | from collections import deque, defaultdict 7 | 8 | random.seed(0) 9 | 10 | zero_threshold = 0.0000001 11 | INF = 100000000 12 | 13 | 14 | def bfs_split(values): 15 | left_name_idx = dict() 16 | left_idx_name = [] 17 | right_name_idx = dict() 18 | right_idx_name = [] 19 | left_cnt = 0 20 | right_cnt = 0 21 | left_right = defaultdict(list) 22 | right_left = defaultdict(list) 23 | for x, y, w in values: 24 | if x not in left_name_idx: 25 | left_name_idx[x] = left_cnt 26 | left_cnt += 1 27 | left_idx_name.append(x) 28 | if y not in right_name_idx: 29 | right_name_idx[y] = right_cnt 30 | right_cnt += 1 31 | right_idx_name.append(y) 32 | left_right[left_name_idx[x]].append((right_name_idx[y], w)) 33 | right_left[right_name_idx[y]].append((left_name_idx[x], w)) 34 | left_visit = [False] * left_cnt 35 | right_visit = [False] * right_cnt 36 | 37 | blocks = [] 38 | for x in left_right: 39 | if left_visit[x]: 40 | continue 41 | block_x = [x] 42 | left_visit[x] = True 43 | q = deque([(x, 'l')]) 44 | while q: 45 | src, side = q.popleft() 46 | if side == 'l': 47 | for dst, w in left_right[src]: 48 | if right_visit[dst]: 49 | continue 50 | right_visit[dst] = True 51 | q.append((dst, 'r')) 52 | else: 53 | for dst, w in right_left[src]: 54 | if left_visit[dst]: 55 | continue 56 | block_x.append(dst) 57 | left_visit[dst] = True 58 | q.append((dst, 'l')) 59 | # convert blocks back to values 60 | values_block = [] 61 | for x_1 in block_x: 62 | for y, w in left_right[x_1]: 63 | values_block.append((left_idx_name[x_1], right_idx_name[y], w)) 64 | blocks.append(values_block) 65 | return blocks 66 | 67 | 68 | class KMNode(object): 69 | def __init__(self, idx, no, exception=0, match=None, visit=False): 70 | self.id = idx 71 | self.no = no 72 | self.exception = exception 73 | self.match = match 74 | self.visit = visit 75 | self.slack = INF 76 | 77 | def __repr__(self): 78 | return "idx:" + str(self.id) + " tag: " + str(self.exception) + " match: " + \ 79 | str(self.match) + " vis: " + str(self.visit) + " slack: " + str(self.slack) 80 | 81 | 82 | class KuhnMunkres(object): 83 | 84 | def __init__(self, quick): 85 | self.matrix = None 86 | self.x_nodes = [] 87 | self.y_nodes = [] 88 | self.x_length = 0 89 | self.y_length = 0 90 | self.index_x = 0 91 | self.index_y = 1 92 | self.quick_sol = None 93 | self.is_quick = quick 94 | 95 | def set_matrix(self, x_y_values): 96 | xs = set() 97 | ys = set() 98 | for x, y, w in x_y_values: 99 | xs.add(x) 100 | ys.add(y) 101 | 102 | # 选取较小的作为x 103 | if len(xs) <= len(ys): 104 | self.index_x = 0 105 | self.index_y = 1 106 | else: 107 | self.index_x = 1 108 | self.index_y = 0 109 | xs, ys = ys, xs 110 | 111 | x_dic = {x: i for i, x in enumerate(xs)} 112 | y_dic = {y: j for j, y in enumerate(ys)} 113 | self.x_nodes = [KMNode(x, x_dic[x]) for x in xs] 114 | self.y_nodes = [KMNode(y, y_dic[y]) for y in ys] 115 | self.x_length = len(xs) 116 | self.y_length = len(ys) 117 | 118 | self.matrix = np.zeros((self.x_length, self.y_length)) 119 | for row in x_y_values: 120 | x = row[self.index_x] 121 | y = row[self.index_y] 122 | w = row[2] 123 | x_index = x_dic[x] 124 | y_index = y_dic[y] 125 | self.matrix[x_index, y_index] = w 126 | if self.x_length == 1 and self.is_quick: 127 | best_choice = int(np.argmax(self.matrix[0])) 128 | max_val = self.matrix[0][best_choice] 129 | left_id = self.x_nodes[0].id 130 | right_id = self.y_nodes[best_choice].id 131 | if self.index_x == 1: 132 | left_id, right_id = right_id, left_id 133 | match = [(left_id, right_id, max_val)] 134 | self.quick_sol = (max_val, match) 135 | return 136 | for i in range(self.x_length): 137 | self.x_nodes[i].exception = max(self.matrix[i, :]) 138 | 139 | def km(self): 140 | if self.quick_sol is not None: 141 | return 142 | for i in range(self.x_length): 143 | for node in self.y_nodes: 144 | node.slack = INF 145 | while True: 146 | for node in self.x_nodes: 147 | node.visit = False 148 | for node in self.y_nodes: 149 | node.visit = False 150 | if self.dfs(i): 151 | break 152 | d = INF 153 | for node in self.y_nodes: 154 | if (not node.visit) and d > node.slack: 155 | d = node.slack 156 | if d == INF or d < zero_threshold: 157 | break 158 | for node in self.x_nodes: 159 | if node.visit: 160 | node.exception -= d 161 | for node in self.y_nodes: 162 | if node.visit: 163 | node.exception += d 164 | else: 165 | node.slack -= d 166 | # remain order is not matching 167 | if self.index_x == 1: 168 | remain_orders = [(self.order_price_dur[x.no][0] / self.order_price_dur[x.no][1], x.no) for x in self.x_nodes if x.match is None] 169 | if len(remain_orders) == 0: 170 | return 171 | remain_drivers = [(self.income[y.no] / (self.online_time[y.no] + 0.1), y.no) for y in self.y_nodes if y.match is None] 172 | remain_drivers.sort() 173 | remain_orders.sort(reverse=True) 174 | idx = 0 175 | for _, order_no in remain_orders: 176 | driver_ratio, driver_no = remain_drivers[idx] 177 | idx += 1 178 | self.x_nodes[order_no].match = driver_no 179 | self.y_nodes[driver_no].match = order_no 180 | return 181 | 182 | def dfs(self, x): 183 | x_node = self.x_nodes[x] 184 | x_node.visit = True 185 | for y in range(self.y_length): 186 | y_node = self.y_nodes[y] 187 | if y_node.visit: 188 | continue 189 | t = x_node.exception + y_node.exception - self.matrix[x][y] 190 | if abs(t) < zero_threshold: 191 | y_node.visit = True 192 | if y_node.match is None or self.dfs(y_node.match): 193 | y_node.match = x 194 | x_node.match = y 195 | return True 196 | elif y_node.slack > t: 197 | y_node.slack = t 198 | return False 199 | 200 | def get_connect_result(self): 201 | if self.quick_sol is not None: 202 | return self.quick_sol[1] 203 | ret = [] 204 | for i in range(self.x_length): 205 | x_node = self.x_nodes[i] 206 | j = x_node.match 207 | if j is None: 208 | continue 209 | # TODO: handle those unmatched orders 210 | y_node = self.y_nodes[j] 211 | x_id = x_node.id 212 | y_id = y_node.id 213 | w = self.matrix[i][j] 214 | if self.index_x == 1 and self.index_y == 0: 215 | x_id, y_id = y_id, x_id 216 | ret.append((x_id, y_id, w)) 217 | return ret 218 | 219 | def get_max_value_result(self): 220 | if self.quick_sol is not None: 221 | return self.quick_sol[0] 222 | ret = 0 223 | for i in range(self.x_length): 224 | j = self.x_nodes[i].match 225 | if j is None: 226 | continue 227 | ret += self.matrix[i][j] 228 | return ret 229 | 230 | 231 | def find_part_block(part_block_value, quick=False): 232 | solvers = [KuhnMunkres(quick) for _ in range(len(part_block_value))] 233 | for i, solver in enumerate(solvers): 234 | solver.set_matrix(part_block_value[i]) 235 | solver.km() 236 | val = 0 237 | for solver in solvers: 238 | val += solver.get_max_value_result() 239 | matches = [solver.get_connect_result() for solver in solvers] 240 | match_all = [match[i] for match in matches for i in range(len(match))] 241 | return val, match_all 242 | 243 | 244 | def find_max_match(x_y_values, split=True, quick=True): 245 | if (not split) or (len(x_y_values) < 1): 246 | return find_part_block([x_y_values], quick) 247 | block_values = bfs_split(x_y_values) 248 | return find_part_block(block_values, quick) 249 | 250 | 251 | if __name__ == '__main__': 252 | values = [] 253 | random.seed(0) 254 | for i in range(50): 255 | for j in range(60): 256 | if i // 100 == j // 1000: 257 | value = random.random() 258 | values.append((i, j, value)) 259 | print("begin") 260 | s_time = time.time() 261 | print(find_max_match(values, split=False)) 262 | print("time usage: %s " % str(time.time() - s_time)) 263 | -------------------------------------------------------------------------------- /HUN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingyuan-shi/Learning-To-Dispatch/e5360630286b8831b4a6e7bdd8d522e5446ac354/HUN/__init__.py -------------------------------------------------------------------------------- /HUN/agent.py: -------------------------------------------------------------------------------- 1 | from recorder import Recorder 2 | from typing import Dict, List, Set, Tuple, Any 3 | import sys 4 | if sys.platform == 'darwin': 5 | from model.KM import find_max_match 6 | else: 7 | from KM import find_max_match 8 | 9 | 10 | class Agent(Recorder): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def dispatch(self, dispatch_observ: List[Dict[str, Any]], index2hash) -> List[Dict[str, str]]: 15 | if len(dispatch_observ) == 0: 16 | return [] 17 | values = [(each['driver_id'], each['order_id'], each['reward_units']) for each in dispatch_observ] 18 | val, dispatch_tuple = find_max_match(x_y_values=values, split=False, quick=False) 19 | return [dict(driver_id=each[0], order_id=each[1]) for each in dispatch_tuple] 20 | 21 | def reposition(self, repo_observ: Dict[str, Any]) -> List[Dict[str, str]]: 22 | return [] 23 | -------------------------------------------------------------------------------- /HUN/recorder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import defaultdict 3 | import math 4 | from typing import Dict, List, Any, Set 5 | import time 6 | 7 | 8 | def acc_dist(lng1: float, lat1: float, lng2: float, lat2: float) -> float: 9 | delta_lat = (lat1 - lat2) / 2 10 | delta_lng = (lng1 - lng2) / 2 11 | arc_pi = 3.14159265359 / 180 12 | R = 6378137 13 | return 2 * R * math.asin(math.sqrt( 14 | math.sin(arc_pi * delta_lat) ** 2 + math.cos(arc_pi * lat1) * math.cos(arc_pi * lat2) * ( 15 | math.sin(arc_pi * delta_lng) ** 2))) 16 | 17 | 18 | class Recorder: 19 | 20 | def __init__(self): 21 | self.drivers_total_income = defaultdict(float) 22 | self.drivers_online_time = defaultdict(float) 23 | self.drivers_log_on_off = defaultdict(list) 24 | self.drivers_serving_order_info = defaultdict(list) 25 | self.drivers_income_per_hour = defaultdict(lambda: [0 for i in range(25)]) 26 | self.active_drivers = set() 27 | self.median_ratio = 0. 28 | 29 | def __update_online_time(self, drivers_online_time: Dict[str, int]): 30 | """ 31 | update the driver's online time 32 | :param drivers_online_time: a dict, with key is driver's id, value is online time (in seconds) 33 | :return: None 34 | """ 35 | for driver_id in drivers_online_time: 36 | self.drivers_online_time[driver_id] = drivers_online_time[driver_id] 37 | 38 | def update_log_on(self, online_drivers_hash: Set[str], timestamp): 39 | """ 40 | update the driver log_on time 41 | :param online_drivers_hash: online drivers' hashcode at this timestamp 42 | :param timestamp: current timestamp 43 | """ 44 | self.active_drivers = self.active_drivers.union(online_drivers_hash) 45 | for driver_hash in online_drivers_hash: 46 | if len(self.drivers_log_on_off[driver_hash]) != 0: 47 | print("MULTIPLE LOG ON!!") 48 | self.drivers_log_on_off[driver_hash].append(timestamp) 49 | 50 | def update_log_off(self, offline_drivers_hash: Set[str], timestamp): 51 | """ 52 | update the driver log_off time 53 | :param offline_drivers_hash: offline drivers' hashcode at this timestamp 54 | :param timestamp: current timestamp 55 | """ 56 | self.active_drivers = self.active_drivers.difference(offline_drivers_hash) 57 | for driver_hash in offline_drivers_hash: 58 | if len(self.drivers_log_on_off[driver_hash]) != 1: 59 | print("LOG OFF BEFORE LOG ON!!") 60 | self.drivers_log_on_off[driver_hash].append(timestamp) 61 | ratios = [self.drivers_total_income[driver] / (0.1 + timestamp - self.drivers_log_on_off[driver][0]) 62 | for driver in self.active_drivers] 63 | ratios.sort() 64 | if len(ratios) > 0: 65 | self.median_ratio = ratios[len(ratios) // 2] 66 | 67 | def update_driver_income_after_rejection(self, assignment: List[Dict[str, Any]], 68 | dispatch_observ: List[Dict[str, Any]], index2hash: Dict[int, str]): 69 | """ 70 | this function update the driver's income. 71 | Should be called after the rejection process. 72 | :param assignment: a list of dicts, one dict is <'order_id': xxx, 'driver_id':xxx> 73 | :param dispatch_observ: the same as agent.matching parameter. 74 | :param index2hash: driver_id to driver_hash 75 | :return: None 76 | """ 77 | if len(dispatch_observ) == 0: 78 | return 79 | cur_hour = time.localtime(int(dispatch_observ[0]['timestamp'])).tm_hour 80 | order_price = {od['order_id']: od['reward_units'] for od in dispatch_observ} 81 | # for all recorders 82 | order_info = {od['order_id']: [od['reward_units'], 83 | od['order_driver_distance'], 84 | od['order_start_location'], 85 | od['order_finish_location'], 86 | od['timestamp'], 87 | od['order_finish_timestamp'], 88 | od['pick_up_eta']] for od in dispatch_observ} 89 | for pair in assignment: 90 | self.drivers_total_income[index2hash[pair['driver_id']]] += order_price[pair['order_id']] 91 | self.drivers_income_per_hour[index2hash[pair['driver_id']]][cur_hour] += order_price[pair['order_id']] 92 | self.drivers_serving_order_info[index2hash[pair['driver_id']]].append(order_info[pair['order_id']]) 93 | return 94 | 95 | def save_logs(self, solpath: str, city: str, date: str, notes=""): 96 | """ 97 | After one day simulation, output the driver's income and his/her online time into file 98 | :param solpath: str, the solution path 99 | :param date: str, the simulation date, eg. 20201129 100 | :param city: str, the city name, eg. chengdu 101 | :param notes: str, the parameter setting information 102 | :return: None 103 | """ 104 | bad1 = 0 105 | bad2 = 0 106 | bad3 = 0 107 | for driver_hash in self.drivers_log_on_off: 108 | if len(self.drivers_log_on_off[driver_hash]) == 1: 109 | bad1 += 1 110 | continue 111 | if len(self.drivers_log_on_off[driver_hash]) == 2: 112 | bad2 += 1 113 | log_on, log_off = self.drivers_log_on_off[driver_hash] 114 | self.drivers_online_time[driver_hash] = log_off - log_on 115 | continue 116 | if len(self.drivers_log_on_off[driver_hash]) > 2: 117 | bad3 += 1 118 | continue 119 | print("collision:", bad1, bad2, bad3, len(self.drivers_log_on_off)) 120 | driver_perhour_income = dict() 121 | for driver in self.drivers_income_per_hour: 122 | driver_perhour_income[driver] = self.drivers_income_per_hour[driver] 123 | solname = solpath.split('/')[-1] 124 | pickle.dump(self.drivers_log_on_off, open(solpath + "/" + solname + "_" + city + "_" + date + "logonoff" + notes, "wb")) 125 | pickle.dump(self.drivers_online_time, open(solpath + "/" + solname + "_" + city + "_" + date + "online_time" + notes, "wb")) 126 | pickle.dump(self.drivers_total_income, open(solpath + "/" + solname + "_" + city + "_" + date + "total_income" + notes, "wb")) 127 | pickle.dump(driver_perhour_income, open(solpath + "/" + solname + "_" + city + "_" + date + "perhourincome" + notes, "wb")) 128 | pickle.dump(self.drivers_serving_order_info, open(solpath + "/" + solname + "_" + city + "_" + date + "order_info" + notes, "wb")) 129 | 130 | -------------------------------------------------------------------------------- /LTA/KM.py: -------------------------------------------------------------------------------- 1 | # encoding=utf-8 2 | 3 | import numpy as np 4 | import random 5 | import time 6 | from collections import deque, defaultdict 7 | 8 | random.seed(0) 9 | 10 | zero_threshold = 0.0000001 11 | INF = 100000000 12 | 13 | 14 | def bfs_split(values): 15 | left_name_idx = dict() 16 | left_idx_name = [] 17 | right_name_idx = dict() 18 | right_idx_name = [] 19 | left_cnt = 0 20 | right_cnt = 0 21 | left_right = defaultdict(list) 22 | right_left = defaultdict(list) 23 | for x, y, w in values: 24 | if x not in left_name_idx: 25 | left_name_idx[x] = left_cnt 26 | left_cnt += 1 27 | left_idx_name.append(x) 28 | if y not in right_name_idx: 29 | right_name_idx[y] = right_cnt 30 | right_cnt += 1 31 | right_idx_name.append(y) 32 | left_right[left_name_idx[x]].append((right_name_idx[y], w)) 33 | right_left[right_name_idx[y]].append((left_name_idx[x], w)) 34 | left_visit = [False] * left_cnt 35 | right_visit = [False] * right_cnt 36 | 37 | blocks = [] 38 | for x in left_right: 39 | if left_visit[x]: 40 | continue 41 | block_x = [x] 42 | left_visit[x] = True 43 | q = deque([(x, 'l')]) 44 | while q: 45 | src, side = q.popleft() 46 | if side == 'l': 47 | for dst, w in left_right[src]: 48 | if right_visit[dst]: 49 | continue 50 | right_visit[dst] = True 51 | q.append((dst, 'r')) 52 | else: 53 | for dst, w in right_left[src]: 54 | if left_visit[dst]: 55 | continue 56 | block_x.append(dst) 57 | left_visit[dst] = True 58 | q.append((dst, 'l')) 59 | # convert blocks back to values 60 | values_block = [] 61 | for x_1 in block_x: 62 | for y, w in left_right[x_1]: 63 | values_block.append((left_idx_name[x_1], right_idx_name[y], w)) 64 | blocks.append(values_block) 65 | return blocks 66 | 67 | 68 | class KMNode(object): 69 | def __init__(self, idx, no, exception=0, match=None, visit=False): 70 | self.id = idx 71 | self.no = no 72 | self.exception = exception 73 | self.match = match 74 | self.visit = visit 75 | self.slack = INF 76 | 77 | def __repr__(self): 78 | return "idx:" + str(self.id) + " tag: " + str(self.exception) + " match: " + \ 79 | str(self.match) + " vis: " + str(self.visit) + " slack: " + str(self.slack) 80 | 81 | 82 | class KuhnMunkres(object): 83 | 84 | def __init__(self, quick): 85 | self.matrix = None 86 | self.x_nodes = [] 87 | self.y_nodes = [] 88 | self.x_length = 0 89 | self.y_length = 0 90 | self.index_x = 0 91 | self.index_y = 1 92 | self.quick_sol = None 93 | self.is_quick = quick 94 | 95 | def set_matrix(self, x_y_values): 96 | xs = set() 97 | ys = set() 98 | for x, y, w in x_y_values: 99 | xs.add(x) 100 | ys.add(y) 101 | 102 | # 选取较小的作为x 103 | if len(xs) <= len(ys): 104 | self.index_x = 0 105 | self.index_y = 1 106 | else: 107 | self.index_x = 1 108 | self.index_y = 0 109 | xs, ys = ys, xs 110 | 111 | x_dic = {x: i for i, x in enumerate(xs)} 112 | y_dic = {y: j for j, y in enumerate(ys)} 113 | self.x_nodes = [KMNode(x, x_dic[x]) for x in xs] 114 | self.y_nodes = [KMNode(y, y_dic[y]) for y in ys] 115 | self.x_length = len(xs) 116 | self.y_length = len(ys) 117 | 118 | self.matrix = np.zeros((self.x_length, self.y_length)) 119 | for row in x_y_values: 120 | x = row[self.index_x] 121 | y = row[self.index_y] 122 | w = row[2] 123 | x_index = x_dic[x] 124 | y_index = y_dic[y] 125 | self.matrix[x_index, y_index] = w 126 | if self.x_length == 1 and self.is_quick: 127 | best_choice = int(np.argmax(self.matrix[0])) 128 | max_val = self.matrix[0][best_choice] 129 | left_id = self.x_nodes[0].id 130 | right_id = self.y_nodes[best_choice].id 131 | if self.index_x == 1: 132 | left_id, right_id = right_id, left_id 133 | match = [(left_id, right_id, max_val)] 134 | self.quick_sol = (max_val, match) 135 | return 136 | for i in range(self.x_length): 137 | self.x_nodes[i].exception = max(self.matrix[i, :]) 138 | 139 | def km(self): 140 | if self.quick_sol is not None: 141 | return 142 | for i in range(self.x_length): 143 | for node in self.y_nodes: 144 | node.slack = INF 145 | while True: 146 | for node in self.x_nodes: 147 | node.visit = False 148 | for node in self.y_nodes: 149 | node.visit = False 150 | if self.dfs(i): 151 | break 152 | d = INF 153 | for node in self.y_nodes: 154 | if (not node.visit) and d > node.slack: 155 | d = node.slack 156 | if d == INF or d < zero_threshold: 157 | break 158 | for node in self.x_nodes: 159 | if node.visit: 160 | node.exception -= d 161 | for node in self.y_nodes: 162 | if node.visit: 163 | node.exception += d 164 | else: 165 | node.slack -= d 166 | # remain order is not matching 167 | if self.index_x == 1: 168 | remain_orders = [(self.order_price_dur[x.no][0] / self.order_price_dur[x.no][1], x.no) for x in self.x_nodes if x.match is None] 169 | if len(remain_orders) == 0: 170 | return 171 | remain_drivers = [(self.income[y.no] / (self.online_time[y.no] + 0.1), y.no) for y in self.y_nodes if y.match is None] 172 | remain_drivers.sort() 173 | remain_orders.sort(reverse=True) 174 | idx = 0 175 | for _, order_no in remain_orders: 176 | driver_ratio, driver_no = remain_drivers[idx] 177 | idx += 1 178 | self.x_nodes[order_no].match = driver_no 179 | self.y_nodes[driver_no].match = order_no 180 | return 181 | 182 | def dfs(self, x): 183 | x_node = self.x_nodes[x] 184 | x_node.visit = True 185 | for y in range(self.y_length): 186 | y_node = self.y_nodes[y] 187 | if y_node.visit: 188 | continue 189 | t = x_node.exception + y_node.exception - self.matrix[x][y] 190 | if abs(t) < zero_threshold: 191 | y_node.visit = True 192 | if y_node.match is None or self.dfs(y_node.match): 193 | y_node.match = x 194 | x_node.match = y 195 | return True 196 | elif y_node.slack > t: 197 | y_node.slack = t 198 | return False 199 | 200 | def get_connect_result(self): 201 | if self.quick_sol is not None: 202 | return self.quick_sol[1] 203 | ret = [] 204 | for i in range(self.x_length): 205 | x_node = self.x_nodes[i] 206 | j = x_node.match 207 | if j is None: 208 | continue 209 | # TODO: handle those unmatched orders 210 | y_node = self.y_nodes[j] 211 | x_id = x_node.id 212 | y_id = y_node.id 213 | w = self.matrix[i][j] 214 | if self.index_x == 1 and self.index_y == 0: 215 | x_id, y_id = y_id, x_id 216 | ret.append((x_id, y_id, w)) 217 | return ret 218 | 219 | def get_max_value_result(self): 220 | if self.quick_sol is not None: 221 | return self.quick_sol[0] 222 | ret = 0 223 | for i in range(self.x_length): 224 | j = self.x_nodes[i].match 225 | if j is None: 226 | continue 227 | ret += self.matrix[i][j] 228 | return ret 229 | 230 | 231 | def find_part_block(part_block_value, quick=False): 232 | solvers = [KuhnMunkres(quick) for _ in range(len(part_block_value))] 233 | for i, solver in enumerate(solvers): 234 | solver.set_matrix(part_block_value[i]) 235 | solver.km() 236 | val = 0 237 | for solver in solvers: 238 | val += solver.get_max_value_result() 239 | matches = [solver.get_connect_result() for solver in solvers] 240 | match_all = [match[i] for match in matches for i in range(len(match))] 241 | return val, match_all 242 | 243 | 244 | def find_max_match(x_y_values, split=True, quick=True): 245 | if (not split) or (len(x_y_values) < 1): 246 | return find_part_block([x_y_values], quick) 247 | block_values = bfs_split(x_y_values) 248 | return find_part_block(block_values, quick) 249 | 250 | 251 | if __name__ == '__main__': 252 | values = [] 253 | random.seed(0) 254 | for i in range(50): 255 | for j in range(60): 256 | if i // 100 == j // 1000: 257 | value = random.random() 258 | values.append((i, j, value)) 259 | print("begin") 260 | s_time = time.time() 261 | print(find_max_match(values, split=False)) 262 | print("time usage: %s " % str(time.time() - s_time)) 263 | -------------------------------------------------------------------------------- /LTA/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingyuan-shi/Learning-To-Dispatch/e5360630286b8831b4a6e7bdd8d522e5446ac354/LTA/__init__.py -------------------------------------------------------------------------------- /LTA/agent.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Dict 2 | from matcher import Matcher 3 | from scheduler import Scheduler 4 | from global_var import alpha, gamma1, gamma2 5 | from recorder import Recorder 6 | 7 | 8 | class Agent(Recorder): 9 | def __init__(self): 10 | super().__init__() 11 | self.matcher = Matcher(alpha, gamma1) 12 | self.scheduler = Scheduler(gamma2) 13 | 14 | def dispatch(self, dispatch_observ: List[Dict[str, Any]], index2hash=None) -> List[Dict[str, str]]: 15 | return self.matcher.dispatch(dispatch_observ, index2hash) 16 | 17 | def reposition(self, repo_observ: Dict[str, Any]) -> List[Dict[str, str]]: 18 | return self.scheduler.reposition(self.matcher, repo_observ) 19 | # return [] 20 | -------------------------------------------------------------------------------- /LTA/entity.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple 2 | from grid import Grid 3 | 4 | 5 | class Driver: 6 | def __init__(self, od: Dict[str, Any]): 7 | self.driver_id = od['driver_id'] # type: str 8 | self.loc = od['driver_location'] # type: Tuple[float, float] 9 | self.grid, self.grid_no = Grid.find_grid(od['driver_location'][0], od['driver_location'][1]) 10 | 11 | 12 | class Order: 13 | def __init__(self, od: Dict[str, Any]): 14 | self.order_id = od['order_id'] # type: str 15 | self.start_time = od['timestamp'] # type: int 16 | self.start_loc = od['order_start_location'] # type: Tuple[float, float] 17 | self.start_grid, self.start_grid_no = Grid.find_grid(od['order_start_location'][0], od['order_start_location'][1]) 18 | 19 | self.finish_time = od['order_finish_timestamp'] # type: int 20 | self.finish_loc = od['order_finish_location'] # type: Tuple[float, float] 21 | self.finish_grid, self.finish_grid_no = Grid.find_grid(od['order_finish_location'][0], od['order_finish_location'][1]) 22 | 23 | self.day_of_week = od['day_of_week'] # type: int 24 | self.reward = od['reward_units'] # type: float 25 | 26 | 27 | class Pair: 28 | def __init__(self, od: Dict[str, Any]): 29 | self.driver_id = od['driver_id'] # type: str 30 | self.order_id = od['order_id'] # type: str 31 | self.od_distance = od['order_driver_distance'] # type: float 32 | self.pick_up_eta = od['pick_up_eta'] # type: float 33 | self.weight = od['reward_units'] 34 | self.duration = od['order_finish_timestamp'] - od['timestamp'] 35 | 36 | def redefine_weight(self, score): 37 | self.weight = score 38 | -------------------------------------------------------------------------------- /LTA/global_var.py: -------------------------------------------------------------------------------- 1 | MIN_PER_HOUR = 60 2 | alpha = 0.025 3 | gamma1 = 0.9 4 | residual = 0.09 5 | gamma2 = gamma1 + residual 6 | time_step = 2 7 | lng_step = 0.003 8 | lat_step = 0.003 9 | # DIRS = [[0, 0], [0, 1], [1, 0], [-1, 0], [0, -1], [1, 1], [1, -1], [-1, 1], [-1, -1]] 10 | DIRS = [[3, 1], [1, -3], [-3, -1], [-1, 3], [0, 0]] 11 | # DIRS = [[0, 0], [0, 1], [1, 0], [0, -1], [-1, 0]] 12 | DIR_NUM = len(DIRS) 13 | SPEED = 3 14 | INF = 1e12 15 | QUICK = True # True: use probability quick version 16 | INIT_VALUE = False 17 | 18 | INIT_GRID = False # initial values for grid values 19 | TIME_INTERVAL = 10 * 60 20 | 21 | IS_CLEAR = True # clear up the values when day_of_week changes 22 | 23 | -------------------------------------------------------------------------------- /LTA/grid.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | import sys 4 | from typing import Tuple 5 | if sys.platform == 'darwin': 6 | from model.utils import get_path, acc_dist 7 | from model.global_var import INF 8 | else: 9 | from utils import get_path, acc_dist 10 | from global_var import INF 11 | 12 | div_quantile = 2000 13 | 14 | 15 | class Grid: 16 | grids = pickle.load(open(get_path(__file__, "grids_info"), "rb")) 17 | grid_ids = pickle.load(open(get_path(__file__, "grid_id"), "rb")) 18 | kdtree = pickle.load(open(get_path(__file__, "kdtree"), "rb")) 19 | min_lng, min_lat, max_lng, max_lat = 200, 200, 0, 0 20 | for grid_id in grids: 21 | lng, lat = grids[grid_id] 22 | min_lng = min(min_lng, lng) 23 | max_lng = max(max_lng, lng) 24 | min_lat = min(min_lat, lat) 25 | max_lat = max(max_lat, lat) 26 | step_lng = (max_lng - min_lng) / div_quantile 27 | step_lat = (max_lat - min_lat) / div_quantile 28 | mesh = [[[] for i in range(div_quantile + 20)] for j in range(div_quantile + 20)] 29 | for idx, grid_id in enumerate(grid_ids): 30 | lng, lat = grids[grid_id] 31 | mesh[int((lng - min_lng) / step_lng) + 1][int((lat - min_lat) / step_lat) + 1].append((idx, grids[grid_id])) 32 | # for i in range(div_quantile + 20): 33 | # for j in range(div_quantile + 20): 34 | # if len(mesh[i][j]) > 1: 35 | # print(len(mesh[i][j])) 36 | # print(min_lng, max_lng, min_lat, max_lat, step_lng, step_lat) 37 | dx = [0, 0, 0, 1, 1, 1, -1, -1, -1] 38 | dy = [0, 1, -1, 1, 0, -1, 0, 1, -1] 39 | 40 | @staticmethod 41 | def get_grid_ids(): 42 | return Grid.grid_ids 43 | 44 | @staticmethod 45 | def _find_grid(lng: float, lat: float) -> Tuple[str, int]: 46 | _, i = Grid.kdtree.query([lng, lat]) 47 | return Grid.grid_ids[i], i 48 | 49 | @staticmethod 50 | def find_grid(lng: float, lat: float) -> Tuple[str, int]: 51 | i = int((lng - Grid.min_lng) / Grid.step_lng) + 1 52 | j = int((lat - Grid.min_lat) / Grid.step_lat) + 1 53 | min_dis = 10000 54 | idx = -1 55 | try: 56 | for di in range(9): 57 | for id, lng_lat in Grid.mesh[i + Grid.dx[di]][j + Grid.dy[di]]: 58 | dis = (lng - lng_lat[0]) * (lng - lng_lat[0]) + (lat - lng_lat[1]) * (lat - lng_lat[1]) 59 | # dis = acc_dist(lng, lat, lng_lat[0], lng_lat[1]) 60 | if min_dis > dis: 61 | idx = id 62 | min_dis = dis 63 | except: 64 | idx = 0 65 | return Grid.grid_ids[idx], idx 66 | 67 | @staticmethod 68 | def mahattan_distance(grid_hash0: str, grid_hash1: str) -> float: 69 | if grid_hash0 in Grid.grids and grid_hash1 in Grid.grids: 70 | lng0, lat0 = Grid.grids[grid_hash0] 71 | lng1, lat1 = Grid.grids[grid_hash1] 72 | delta_lng = 0.685 * abs(lng0 - lng1) 73 | delta_lat = abs(lat0 - lat1) 74 | return 111320 * math.sqrt(delta_lat * delta_lat + delta_lng * delta_lng) 75 | return INF 76 | -------------------------------------------------------------------------------- /LTA/grid_id: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingyuan-shi/Learning-To-Dispatch/e5360630286b8831b4a6e7bdd8d522e5446ac354/LTA/grid_id -------------------------------------------------------------------------------- /LTA/grids_info: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingyuan-shi/Learning-To-Dispatch/e5360630286b8831b4a6e7bdd8d522e5446ac354/LTA/grids_info -------------------------------------------------------------------------------- /LTA/kdtree: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingyuan-shi/Learning-To-Dispatch/e5360630286b8831b4a6e7bdd8d522e5446ac354/LTA/kdtree -------------------------------------------------------------------------------- /LTA/matcher.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Dict, List, Set, Tuple, Any 3 | import sys 4 | import math 5 | if sys.platform == 'darwin': 6 | from model.utils import finish_prob, get_layer_id, discrete_time 7 | from model.entity import Pair, Driver, Order 8 | from model.global_var import lng_step, lat_step, DIRS, DIR_NUM, TIME_INTERVAL, IS_CLEAR 9 | from model.KM import find_max_match 10 | else: 11 | from utils import finish_prob, get_layer_id, discrete_time 12 | from entity import Pair, Driver, Order 13 | from global_var import lng_step, lat_step, DIRS, DIR_NUM, TIME_INTERVAL, IS_CLEAR 14 | from KM import find_max_match 15 | 16 | 17 | class Matcher: 18 | def __init__(self, alpha, gamma): 19 | self.alpha = alpha 20 | self.gamma = gamma 21 | self.dow = -1 22 | self.cur_discrete_time = 0 23 | self.grid_values = collections.defaultdict(float) 24 | self.layer_values = collections.defaultdict(float) 25 | 26 | def dispatch(self, dispatch_observ: List[Dict[str, Any]], index2hash) -> List[Dict[str, str]]: 27 | if len(dispatch_observ) == 0: 28 | return [] 29 | self.cur_discrete_time = discrete_time(dispatch_observ[0]['timestamp']) 30 | if self.dow != dispatch_observ[0]['day_of_week']: 31 | self.dow = dispatch_observ[0]['day_of_week'] 32 | if IS_CLEAR: 33 | self.grid_values = collections.defaultdict(float) 34 | self.layer_values = collections.defaultdict(float) 35 | drivers, orders, pairs = Matcher.parse_dispatch(dispatch_observ) 36 | edges = [] # type: List[Pair] 37 | for key in pairs: # type: str 38 | pair = pairs[key] # type: Pair 39 | order = orders[pair.order_id] 40 | driver = drivers[pair.driver_id] 41 | duration = int(1. * pair.duration / TIME_INTERVAL) 42 | duration = max(1, duration) 43 | v0 = self.get_smoothed_value(driver.loc, driver.grid) 44 | v1 = self.get_smoothed_value(order.finish_loc, order.finish_grid) 45 | done_prob = finish_prob(pair.od_distance, order.start_loc, order.finish_loc, order.start_time) 46 | 47 | if done_prob > 0: 48 | gamma = math.pow(self.gamma, duration) 49 | complete_update = order.reward + gamma * v1 - v0 50 | expected_update = done_prob * complete_update 51 | pair.redefine_weight(expected_update) 52 | edges.append(pair) 53 | 54 | # Assign drivers 55 | assigned_driver_ids = set() # type: Set[str] 56 | # begin hungary 57 | driver_order_to_score = {str(each.driver_id) + '#' + str(each.order_id): each.weight for each in edges} 58 | values = [(each.driver_id, each.order_id, each.weight) for each in edges] 59 | 60 | val, dispatch_tuple = find_max_match(x_y_values=values, split=True, quick=True) 61 | dispatch = [dict(driver_id=each[0], order_id=each[1]) for each in dispatch_tuple] 62 | for each in dispatch: 63 | assigned_driver_ids.add(each['driver_id']) 64 | driver = drivers[each['driver_id']] 65 | key = str(each['driver_id']) + '#' + str(each['order_id']) 66 | if key in driver_order_to_score: 67 | score = driver_order_to_score[key] 68 | self.update_value(driver.loc, driver.grid, self.alpha * score) 69 | 70 | for driver in drivers.values(): 71 | if driver.driver_id in assigned_driver_ids: 72 | continue 73 | v0 = self.get_smoothed_value(driver.loc, driver.grid) 74 | v1 = v0 75 | update = self.gamma * v1 - v0 # no reward 76 | self.update_value(driver.loc, driver.grid, self.alpha * update) 77 | 78 | return dispatch 79 | 80 | def get_grid_ids(self) -> Set[str]: 81 | return set(self.grid_values.keys()) 82 | 83 | def get_grid_value(self, grid_id: str) -> float: 84 | return self.grid_values[grid_id] 85 | 86 | def get_smoothed_value(self, loc: Tuple[float, float], grid_id: str) -> float: 87 | value = self.grid_values[grid_id] 88 | for i, one_dir in enumerate(DIRS): 89 | layer_id = get_layer_id(loc[0] + one_dir[0] * lng_step, loc[1] + one_dir[1] * lat_step, direction=0) 90 | value += self.layer_values[layer_id] 91 | return value / (1 + DIR_NUM) 92 | 93 | def update_value(self, loc: Tuple[float, float], grid_id: str, delta: float) -> None: 94 | self.grid_values[grid_id] += delta 95 | for i, one_dir in enumerate(DIRS): 96 | layer_id = get_layer_id(loc[0] + one_dir[0] * lng_step, loc[1] + one_dir[1] * lat_step, direction=0) 97 | self.layer_values[layer_id] += delta 98 | 99 | @staticmethod 100 | def parse_dispatch(dispatch_observ: List[Dict[str, Any]]) -> (Dict[str, Driver], Dict[str, Order], Dict[str, Set[Pair]]): 101 | drivers = collections.OrderedDict() # type: collections.OrderedDict[str, Driver] 102 | orders = collections.OrderedDict() # type: collections.OrderedDict[str, Order] 103 | pairs = collections.OrderedDict() # type: collections.OrderedDict[str, Pair] 104 | for pair_raw in dispatch_observ: 105 | driver = Driver(pair_raw) 106 | drivers[driver.driver_id] = driver 107 | order = Order(pair_raw) 108 | orders[order.order_id] = order 109 | key = str(order.order_id) + "#" + str(driver.driver_id) 110 | pairs[key] = Pair(pair_raw) 111 | return drivers, orders, pairs 112 | -------------------------------------------------------------------------------- /LTA/recorder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import defaultdict 3 | import math 4 | from typing import Dict, List, Any, Set 5 | import time 6 | 7 | 8 | def acc_dist(lng1: float, lat1: float, lng2: float, lat2: float) -> float: 9 | delta_lat = (lat1 - lat2) / 2 10 | delta_lng = (lng1 - lng2) / 2 11 | arc_pi = 3.14159265359 / 180 12 | R = 6378137 13 | return 2 * R * math.asin(math.sqrt( 14 | math.sin(arc_pi * delta_lat) ** 2 + math.cos(arc_pi * lat1) * math.cos(arc_pi * lat2) * ( 15 | math.sin(arc_pi * delta_lng) ** 2))) 16 | 17 | 18 | def sec2time(sec): 19 | sec = int(sec) 20 | second = sec % 60 21 | hour = sec // 3600 22 | minute = (sec - hour * 3600) // 60 23 | if second < 10: 24 | str_second = '0' + str(second) 25 | else: 26 | str_second = str(second) 27 | if minute < 10: 28 | str_minute = '0' + str(minute) 29 | else: 30 | str_minute = str(minute) 31 | if hour < 10: 32 | str_hour = '0' + str(hour) 33 | else: 34 | str_hour = str(hour) 35 | return str_hour + ":" + str_minute + ":" + str_second 36 | 37 | 38 | class Recorder: 39 | def __init__(self): 40 | self.drivers_total_income = defaultdict(float) 41 | self.drivers_online_time = defaultdict(float) 42 | self.drivers_log_on_off = defaultdict(list) 43 | self.drivers_serving_order_info = defaultdict(list) 44 | self.drivers_income_per_hour = defaultdict(lambda: [0 for i in range(25)]) 45 | self.active_drivers = set() 46 | self.median_ratio = 0. 47 | 48 | def __update_online_time(self, drivers_online_time: Dict[str, int]): 49 | """ 50 | update the driver's online time 51 | :param drivers_online_time: a dict, with key is driver's id, value is online time (in seconds) 52 | :return: None 53 | """ 54 | for driver_id in drivers_online_time: 55 | self.drivers_online_time[driver_id] = drivers_online_time[driver_id] 56 | 57 | def update_log_on(self, online_drivers_hash: Set[str], online_drivers_loc, timestamp): 58 | """ 59 | update the driver log_on time 60 | :param online_drivers_hash: online drivers' hashcode at this timestamp 61 | :param timestamp: current timestamp 62 | """ 63 | self.active_drivers = self.active_drivers.union(online_drivers_hash) 64 | for driver_hash in online_drivers_hash: 65 | if len(self.drivers_log_on_off[driver_hash]) != 0: 66 | print("MULTIPLE LOG ON!!") 67 | self.drivers_log_on_off[driver_hash].append((timestamp, online_drivers_loc[driver_hash])) 68 | 69 | def update_log_off(self, offline_drivers_hash: Set[str], offline_drivers_loc, timestamp: int): 70 | """ 71 | update the driver log_off time 72 | :param offline_drivers_hash: offline drivers' hashcode at this timestamp 73 | :param timestamp: current timestamp 74 | """ 75 | self.active_drivers = self.active_drivers.difference(offline_drivers_hash) 76 | for driver_hash in offline_drivers_hash: 77 | if len(self.drivers_log_on_off[driver_hash]) != 1: 78 | print("LOG OFF BEFORE LOG ON!!") 79 | 80 | self.drivers_log_on_off[driver_hash].append((timestamp, offline_drivers_loc[driver_hash])) 81 | ratios = [self.drivers_total_income[driver] / (0.1 + timestamp - self.drivers_log_on_off[driver][0][0]) 82 | for driver in self.active_drivers] 83 | ratios.sort() 84 | if len(ratios) > 0: 85 | self.median_ratio = ratios[len(ratios) // 2] 86 | 87 | def update_driver_income_after_rejection(self, assignment: List[Dict[str, Any]], 88 | dispatch_observ: List[Dict[str, Any]], index2hash: Dict[int, str]): 89 | """ 90 | this function update the driver's income. 91 | Should be called after the rejection process. 92 | :param assignment: a list of dicts, one dict is <'order_id': xxx, 'driver_id':xxx> 93 | :param dispatch_observ: the same as agent.matching parameter. 94 | :param index2hash: driver_id to driver_hash 95 | :return: None 96 | """ 97 | if len(dispatch_observ) == 0: 98 | return 99 | cur_hour = time.localtime(int(dispatch_observ[0]['timestamp'])).tm_hour 100 | order_price = {od['order_id']: od['reward_units'] for od in dispatch_observ} 101 | # for all recorders 102 | order_info = {od['order_id']: [od['reward_units'], 103 | od['order_driver_distance'], 104 | od['order_start_location'], 105 | od['order_finish_location'], 106 | od['order_start_timestamp'], 107 | od['order_finish_timestamp'], 108 | od['pick_up_eta'], 109 | od['driver_location'], 110 | od['real_order_id']] for od in dispatch_observ} 111 | for pair in assignment: 112 | self.drivers_total_income[index2hash[pair['driver_id']]] += order_price[pair['order_id']] 113 | self.drivers_income_per_hour[index2hash[pair['driver_id']]][cur_hour] += order_price[pair['order_id']] 114 | self.drivers_serving_order_info[index2hash[pair['driver_id']]].append(order_info[pair['order_id']]) 115 | return 116 | 117 | def save_logs(self, solpath: str, city: str, date: str, notes="", dealine_drivers_loc=None): 118 | """ 119 | After one day simulation, output the driver's income and his/her online time into file 120 | :param solpath: str, the solution path 121 | :param date: str, the simulation date, eg. 20201129 122 | :param city: str, the city name, eg. chengdu 123 | :param notes: str, the parameter setting information 124 | :return: None 125 | """ 126 | bad1 = 0 127 | bad2 = 0 128 | bad3 = 0 129 | for driver_hash in self.drivers_log_on_off: 130 | if len(self.drivers_log_on_off[driver_hash]) == 1: 131 | bad1 += 1 132 | continue 133 | if len(self.drivers_log_on_off[driver_hash]) == 2: 134 | bad2 += 1 135 | online_ts = self.drivers_log_on_off[driver_hash][0][0] 136 | offline_ts = self.drivers_log_on_off[driver_hash][1][0] 137 | self.drivers_online_time[driver_hash] = offline_ts - online_ts 138 | continue 139 | if len(self.drivers_log_on_off[driver_hash]) > 2: 140 | bad3 += 1 141 | continue 142 | print("collision:", bad1, bad2, bad3, len(self.drivers_log_on_off)) 143 | # 上下线时间加上4小时,补齐logonoff数据 144 | drivers_log_on_off_fixed = defaultdict(list) 145 | for driver_hash in self.drivers_log_on_off: 146 | if len(self.drivers_log_on_off[driver_hash]) == 1: 147 | offline_loc = dealine_drivers_loc[driver_hash] 148 | online_ts = self.drivers_log_on_off[driver_hash][0][0] 149 | online_loc = self.drivers_log_on_off[driver_hash][0][1] 150 | if len(self.drivers_serving_order_info[driver_hash]) == 0: 151 | offline_timestamp = int(time.mktime(time.strptime(date + " " + sec2time(24 * 3600 - 1), "%Y%m%d %H:%M:%S"))) + 1 152 | else: 153 | offline_timestamp = self.drivers_serving_order_info[driver_hash][-1][-4] 154 | online_ts += 4 * 3600 155 | online_timestamp = int(time.mktime(time.strptime(date + " " + sec2time(online_ts), "%Y%m%d %H:%M:%S"))) 156 | drivers_log_on_off_fixed[driver_hash] = [(online_timestamp, online_loc), (offline_timestamp, offline_loc)] 157 | elif len(self.drivers_log_on_off[driver_hash]) == 2: 158 | online_ts = self.drivers_log_on_off[driver_hash][0][0] 159 | online_loc = self.drivers_log_on_off[driver_hash][0][1] 160 | offline_ts = self.drivers_log_on_off[driver_hash][1][0] 161 | offline_loc = self.drivers_log_on_off[driver_hash][1][1] 162 | online_ts += 4 * 3600 163 | offline_ts += 4 * 3600 164 | online_timestamp = int(time.mktime(time.strptime(date + " " + sec2time(online_ts), "%Y%m%d %H:%M:%S"))) 165 | offline_timestamp = int(time.mktime(time.strptime(date + " " + sec2time(offline_ts), "%Y%m%d %H:%M:%S"))) 166 | drivers_log_on_off_fixed[driver_hash] = [(online_timestamp, online_loc), (offline_timestamp, offline_loc)] 167 | 168 | driver_perhour_income = dict() 169 | for driver in self.drivers_income_per_hour: 170 | driver_perhour_income[driver] = self.drivers_income_per_hour[driver] 171 | solname = solpath.split('/')[-1] 172 | pickle.dump(drivers_log_on_off_fixed, open(solpath + "/" + solname + "_" + city + "_" + date + "logonoff" + notes, "wb")) 173 | # pickle.dump(self.drivers_online_time, open(solpath + "/" + solname + "_" + city + "_" + date + "online_time" + notes, "wb")) 174 | pickle.dump(self.drivers_total_income, open(solpath + "/" + solname + "_" + city + "_" + date + "total_income" + notes, "wb")) 175 | # pickle.dump(driver_perhour_income, open(solpath + "/" + solname + "_" + city + "_" + date + "perhourincome" + notes, "wb")) 176 | pickle.dump(self.drivers_serving_order_info, open(solpath + "/" + solname + "_" + city + "_" + date + "order_info" + notes, "wb")) 177 | -------------------------------------------------------------------------------- /LTA/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, List, Any, Tuple 3 | from grid import Grid 4 | from matcher import Matcher 5 | from global_var import SPEED 6 | import time 7 | 8 | 9 | class Scheduler: 10 | def __init__(self, gamma: float): 11 | self.gamma = gamma 12 | self.grid_ids = Grid.get_grid_ids() 13 | 14 | def reposition(self, matcher, repo_observ) -> List[Dict[str, str]]: 15 | if len(repo_observ['driver_info']) == 0: 16 | return [] 17 | timestamp, day_of_week, drivers = Scheduler.parse_repo(repo_observ) 18 | grid_ids = self.grid_ids 19 | reposition = [] # type: List[Dict[str, str]] 20 | for driver_id, current_grid_id in drivers: 21 | best_grid_id, best_value = current_grid_id, -100 22 | current_value = matcher.get_grid_value(current_grid_id) 23 | for grid_id in grid_ids: 24 | duration = Grid.mahattan_distance(current_grid_id, grid_id) / SPEED 25 | discount = math.pow(0.999, duration) 26 | proposed_value = matcher.get_grid_value(grid_id) 27 | incremental_value = discount * proposed_value - current_value 28 | if incremental_value > best_value: 29 | best_grid_id, best_value = grid_id, incremental_value 30 | reposition.append(dict(driver_id=driver_id, destination=best_grid_id)) 31 | return reposition 32 | 33 | @staticmethod 34 | def parse_repo(repo_observ): 35 | timestamp = repo_observ['timestamp'] # type: int 36 | cur_local = time.localtime(timestamp) 37 | cur_time = cur_local.tm_hour * 3600 + cur_local.tm_min * 60 + cur_local.tm_sec - 4 * 3600 38 | day_of_week = repo_observ['day_of_week'] # type: int 39 | drivers = [(driver['driver_id'], driver['grid_id']) 40 | for driver in repo_observ['driver_info']] 41 | return cur_time, day_of_week, drivers 42 | -------------------------------------------------------------------------------- /LTA/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pandas as pd 4 | import random 5 | import math 6 | from global_var import QUICK 7 | 8 | 9 | def acc_dist(lng1, lat1, lng2, lat2): 10 | delta_lat = (lat1 - lat2) / 2 11 | delta_lng = (lng1 - lng2) / 2 12 | arc_pi = 3.14159265359 / 180 13 | R = 6378137 14 | return 2 * R * math.asin(math.sqrt(math.sin(arc_pi * delta_lat) ** 2 + math.cos(arc_pi * lat1) * math.cos(arc_pi * lat2) * (math.sin(arc_pi * delta_lng) ** 2))) 15 | 16 | 17 | def get_cancel_prob(od_distance, start_loc, finish_loc, timestamp): 18 | if QUICK: 19 | return 0.01 * math.exp(math.log(20)/2000. * od_distance) 20 | dest_distance = acc_dist(start_loc[0], start_loc[1], finish_loc[0], finish_loc[1]) 21 | type1 = [0.01170420419107327, 0.011207019692459854, 0.011546449093435978, 0.01884863953849936, 0.02491680458581119, 22 | 0.03516195815907564, 0.04086130456994513, 0.048299280067015636, 0.05361063371897413, 0.06108357592353861] 23 | type2 = [0.08593582973402296, 0.09519422058873243, 0.1017559777282097, 0.11247952637804468, 0.1257643815269181, 24 | 0.15757382064280756, 0.18561752617526786, 0.21870050620163034, 0.25365262800365274, 0.28741045770387896] 25 | type3 = [0.20698865741088152, 0.25653585750468955, 0.28873054746716675, 0.3025450351782363, 0.3452418484052541, 26 | 0.39487368752345364, 0.4442851417448408, 0.4919876854596625, 0.5838032553470912, 0.6622435844277653] 27 | prob = 0. 28 | while timestamp < 1477944000: 29 | timestamp += 86400 30 | time_id = int(((timestamp - 1477944000) % 86400) // 1800) 31 | slot = int(od_distance / 200.) 32 | slot = slot if slot <= 9 else 9 33 | if dest_distance > 20000.: 34 | prob = type1[slot] 35 | elif 10000. < dest_distance <= 20000.: 36 | if 0 <= time_id < 8: 37 | r = random.random() * 19.5 38 | if 0. <= r < 18.: 39 | prob = type1[slot] 40 | else: 41 | prob = type2[slot] 42 | else: 43 | r = random.random() * 19. 44 | if 0. <= r < 16.: 45 | prob = type1[slot] 46 | else: 47 | prob = type2[slot] 48 | else: 49 | r = random.random() * 20. 50 | if 0 <= time_id < 8: 51 | if 0. <= r < 18.: 52 | prob = type1[slot] 53 | elif 18. <= r < 19.5: 54 | prob = type2[slot] 55 | else: 56 | prob = type3[slot] 57 | else: 58 | if 0. <= r < 16.: 59 | prob = type1[slot] 60 | elif 16. <= r < 19.: 61 | prob = type2[slot] 62 | else: 63 | prob = type3[slot] 64 | if dest_distance <= 1000.: 65 | prob = 0.5 66 | return prob 67 | 68 | 69 | def finish_prob(od_distance, start_loc, finish_loc, timestamp): 70 | return 1 - max(min(get_cancel_prob(od_distance, start_loc, finish_loc, timestamp), 1), 0) 71 | 72 | 73 | def pnpoly(testx, testy, boundary): 74 | nvert = boundary.shape[0] 75 | c = 0 76 | i = 0 77 | j = nvert - 1 78 | vertx = boundary[:, 0] 79 | verty = boundary[:, 1] 80 | while i < nvert: 81 | if (((verty[i] > testy) != (verty[j] > testy)) and 82 | (testx < (vertx[j] - vertx[i]) * (testy - verty[i]) / (verty[j] - verty[i]) + vertx[i])): 83 | c = 1 ^ c 84 | j = i 85 | i = i + 1 86 | return c 87 | 88 | 89 | def judge_area(lng, lat, boundary, fuzzy=False): 90 | boundary = np.array(boundary) 91 | [lng_max, lat_max] = np.amax(boundary, axis=0) 92 | [lng_min, lat_min] = np.amin(boundary, axis=0) 93 | if lng < lng_min or lng > lng_max or lat < lat_min or lat > lat_max: 94 | return False 95 | if fuzzy: 96 | return True 97 | else: 98 | c = pnpoly(lng, lat, boundary) 99 | if c == 1: 100 | return True 101 | else: 102 | return False 103 | 104 | 105 | def discrete_location(lng, lat, kdtree, grids): 106 | _, ids = kdtree.query([lng, lat], k=8) 107 | for one_id in ids: 108 | if judge_area(lng, lat, grids[one_id][1]): 109 | return one_id 110 | return -1 111 | 112 | 113 | def discrete_time(timestamp): 114 | tm = pd.Timestamp(timestamp, unit='s', tz='Asia/Shanghai') 115 | return int(tm.hour * 60 + tm.minute) // 10 116 | 117 | 118 | def get_path(path, file_name): 119 | return os.path.join(os.path.dirname(os.path.abspath(path)), file_name) 120 | 121 | 122 | def rehash(dispatch_observ): 123 | driver_id_orig2new = dict() 124 | order_id_orig2new = dict() 125 | driver_id_new2orig = list() 126 | order_id_new2orig = list() 127 | driver_cnt = 0 128 | order_cnt = 0 129 | for each in dispatch_observ: 130 | driver_id = each["driver_id"] 131 | if driver_id not in driver_id_orig2new: 132 | driver_id_orig2new[driver_id] = driver_cnt 133 | driver_id_new2orig.append(driver_id) 134 | driver_cnt += 1 135 | order_id = each["order_id"] 136 | if order_id not in order_id_orig2new: 137 | order_id_orig2new[order_id] = order_cnt 138 | order_id_new2orig.append(order_id) 139 | order_cnt += 1 140 | return driver_id_orig2new, order_id_orig2new, driver_id_new2orig, order_id_new2orig 141 | 142 | 143 | def get_pairs(lmate, row_is_driver, driver_id_new2orig, order_id_new2orig): 144 | dispatch_action = [] 145 | for i in range(len(lmate)): 146 | if lmate[i] != -1: 147 | if row_is_driver: 148 | dispatch_action.append(dict(order_id=order_id_new2orig[lmate[i]], driver_id=driver_id_new2orig[i])) 149 | else: 150 | dispatch_action.append(dict(order_id=order_id_new2orig[i], driver_id=driver_id_new2orig[lmate[i]])) 151 | return dispatch_action 152 | 153 | 154 | def get_topK(dispatch_observ, k=10): 155 | order_to_dis_idx = dict() 156 | dispatch_observ_after_sift = [] 157 | for i in range(len(dispatch_observ)): 158 | elem = dispatch_observ[i] 159 | if elem['order_id'] not in order_to_dis_idx: 160 | order_to_dis_idx[elem['order_id']] = [] 161 | order_to_dis_idx[elem['order_id']].append((elem['order_driver_distance'], i)) 162 | for key in order_to_dis_idx: 163 | order_to_dis_idx[key].sort() 164 | for j in range(min(k, len(order_to_dis_idx[key]))): 165 | dispatch_observ_after_sift.append(dispatch_observ[order_to_dis_idx[key][j][1]]) 166 | return dispatch_observ_after_sift 167 | 168 | 169 | def build_graph(dispatch_observ, driver_id_orig2new, order_id_orig2new): 170 | # assure row < colum 171 | driver_num = len(driver_id_orig2new) 172 | order_num = len(order_id_orig2new) 173 | row_is_driver = driver_num <= order_num 174 | if row_is_driver: 175 | costs = np.zeros([driver_num, order_num]) 176 | else: 177 | costs = np.zeros([order_num, driver_num]) 178 | for each in dispatch_observ: 179 | driver_new_id = driver_id_orig2new[each["driver_id"]] 180 | order_new_id = order_id_orig2new[each["order_id"]] 181 | if row_is_driver: 182 | costs[driver_new_id][order_new_id] = each['score'] 183 | else: 184 | costs[order_new_id][driver_new_id] = each['score'] 185 | return costs, row_is_driver 186 | 187 | 188 | def rebuild_by_score(edges): 189 | dispatch_observ = [] 190 | for pair in edges: 191 | elem = dict() 192 | elem['order_id'] = pair.order_id 193 | elem['driver_id'] = pair.driver_id 194 | elem['score'] = pair.weight 195 | elem['pick_up_eta'] = pair.pick_up_eta 196 | elem['order_driver_distance'] = pair.od_distance 197 | dispatch_observ.append(elem) 198 | return dispatch_observ 199 | 200 | 201 | def get_layer_id(lng, lat, direction = 0): 202 | return f'{direction:02d}#{lng:.2f}#{lat:.2f}' 203 | 204 | -------------------------------------------------------------------------------- /NNP/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dingyuan-shi/Learning-To-Dispatch/e5360630286b8831b4a6e7bdd8d522e5446ac354/NNP/__init__.py -------------------------------------------------------------------------------- /NNP/agent.py: -------------------------------------------------------------------------------- 1 | from recorder import Recorder 2 | from typing import List, Dict, Any 3 | from ortools.graph import pywrapgraph 4 | 5 | 6 | class Agent(Recorder): 7 | """ Agent for dispatching and reposition """ 8 | 9 | def __init__(self, **kwargs): 10 | """ Load your trained model and initialize the parameters """ 11 | super().__init__() 12 | 13 | def dispatch(self, dispatch_observ: List[Dict[str, Any]], index2hash=None) -> List[Dict[str, int]]: 14 | """ Compute the assignment between drivers and passengers at each time step 15 | :param dispatch_observ: a list of dict, the key in the dict includes: 16 | order_id, int 17 | driver_id, int 18 | order_driver_distance, float 19 | order_start_location, a list as [lng, lat], float 20 | order_finish_location, a list as [lng, lat], float 21 | driver_location, a list as [lng, lat], float 22 | timestamp, int 23 | order_finish_timestamp, int 24 | day_of_week, int 25 | reward_units, float 26 | pick_up_eta, float 27 | :param index2hash: driver_id to driver_hash 28 | :return: a list of dict, the key in the dict includes: 29 | order_id and driver_id, the pair indicating the assignment 30 | """ 31 | dispatch = [] 32 | global_num = 2 33 | order2idx = dict() 34 | idx2order = dict() 35 | driver2idx = dict() 36 | idx2driver = dict() 37 | for od in dispatch_observ: 38 | order_id = od['order_id'] 39 | driver_id = od['driver_id'] 40 | if order_id not in order2idx: 41 | order2idx[order_id] = global_num 42 | idx2order[global_num] = order_id 43 | global_num += 1 44 | if driver_id not in driver2idx: 45 | driver2idx[driver_id] = global_num 46 | idx2driver[global_num] = driver_id 47 | global_num += 1 48 | start_nodes = [] 49 | end_nodes = [] 50 | capacities = [] 51 | unit_costs = [] 52 | for od in dispatch_observ: 53 | order_idx = order2idx[od['order_id']] 54 | driver_idx = driver2idx[od['driver_id']] 55 | cost = int(od['order_driver_distance']) 56 | start_nodes.append(order_idx) 57 | end_nodes.append(driver_idx) 58 | unit_costs.append(cost) 59 | capacities.append(1) 60 | 61 | src = 0 62 | dst = 1 63 | for order_idx in idx2order: 64 | start_nodes.append(src) 65 | end_nodes.append(order_idx) 66 | unit_costs.append(0) 67 | capacities.append(1) 68 | 69 | for driver_idx in idx2driver: 70 | start_nodes.append(driver_idx) 71 | end_nodes.append(dst) 72 | unit_costs.append(0) 73 | capacities.append(1) 74 | min_cost_flow = pywrapgraph.SimpleMinCostFlow() 75 | for i in range(0, len(start_nodes)): 76 | min_cost_flow.AddArcWithCapacityAndUnitCost(start_nodes[i], end_nodes[i], capacities[i], unit_costs[i]) 77 | min_cost_flow.SetNodeSupply(0, min(len(order2idx), len(driver2idx))) 78 | min_cost_flow.SetNodeSupply(1, -min(len(order2idx), len(driver2idx))) 79 | min_cost_flow.SolveMaxFlowWithMinCost() 80 | for i in range(min_cost_flow.NumArcs()): 81 | if min_cost_flow.Flow(i) > 0.1: 82 | if min_cost_flow.Tail(i) != 0 and min_cost_flow.Head(i) != 1: 83 | dispatch.append(dict(order_id=idx2order[min_cost_flow.Tail(i)], driver_id=idx2driver[min_cost_flow.Head(i)])) 84 | return dispatch 85 | 86 | def reposition(self, repo_observ): 87 | """ Compute the reposition action for the given drivers 88 | :param repo_observ: a dict, the key in the dict includes: 89 | timestamp: int 90 | driver_info: a list of dict, the key in the dict includes: 91 | driver_id: driver_id of the idle driver in the treatment group, int 92 | grid_id: id of the grid the driver is located at, str 93 | day_of_week: int 94 | :return: a list of dict, the key in the dict includes: 95 | driver_id: corresponding to the driver_id in the od_list 96 | destination: id of the grid the driver is repositioned to, str 97 | """ 98 | # repo_action = [] 99 | # for driver in repo_observ['driver_info']: 100 | # # the default reposition is to let drivers stay where they are 101 | # repo_action.append({'driver_id': driver['driver_id'], 'destination': driver['grid_id']}) 102 | # return repo_action 103 | return [] 104 | -------------------------------------------------------------------------------- /NNP/recorder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from collections import defaultdict 3 | import math 4 | from typing import Dict, List, Any, Set 5 | import time 6 | 7 | 8 | def acc_dist(lng1: float, lat1: float, lng2: float, lat2: float) -> float: 9 | delta_lat = (lat1 - lat2) / 2 10 | delta_lng = (lng1 - lng2) / 2 11 | arc_pi = 3.14159265359 / 180 12 | R = 6378137 13 | return 2 * R * math.asin(math.sqrt( 14 | math.sin(arc_pi * delta_lat) ** 2 + math.cos(arc_pi * lat1) * math.cos(arc_pi * lat2) * ( 15 | math.sin(arc_pi * delta_lng) ** 2))) 16 | 17 | 18 | class Recorder: 19 | 20 | def __init__(self): 21 | self.drivers_total_income = defaultdict(float) 22 | self.drivers_online_time = defaultdict(float) 23 | self.drivers_log_on_off = defaultdict(list) 24 | self.drivers_serving_order_info = defaultdict(list) 25 | self.drivers_income_per_hour = defaultdict(lambda: [0 for i in range(25)]) 26 | self.active_drivers = set() 27 | self.median_ratio = 0. 28 | 29 | def __update_online_time(self, drivers_online_time: Dict[str, int]): 30 | """ 31 | update the driver's online time 32 | :param drivers_online_time: a dict, with key is driver's id, value is online time (in seconds) 33 | :return: None 34 | """ 35 | for driver_id in drivers_online_time: 36 | self.drivers_online_time[driver_id] = drivers_online_time[driver_id] 37 | 38 | def update_log_on(self, online_drivers_hash: Set[str], timestamp): 39 | """ 40 | update the driver log_on time 41 | :param online_drivers_hash: online drivers' hashcode at this timestamp 42 | :param timestamp: current timestamp 43 | """ 44 | self.active_drivers = self.active_drivers.union(online_drivers_hash) 45 | for driver_hash in online_drivers_hash: 46 | if len(self.drivers_log_on_off[driver_hash]) != 0: 47 | print("MULTIPLE LOG ON!!") 48 | self.drivers_log_on_off[driver_hash].append(timestamp) 49 | 50 | def update_log_off(self, offline_drivers_hash: Set[str], timestamp): 51 | """ 52 | update the driver log_off time 53 | :param offline_drivers_hash: offline drivers' hashcode at this timestamp 54 | :param timestamp: current timestamp 55 | """ 56 | self.active_drivers = self.active_drivers.difference(offline_drivers_hash) 57 | for driver_hash in offline_drivers_hash: 58 | if len(self.drivers_log_on_off[driver_hash]) != 1: 59 | print("LOG OFF BEFORE LOG ON!!") 60 | self.drivers_log_on_off[driver_hash].append(timestamp) 61 | ratios = [self.drivers_total_income[driver] / (0.1 + timestamp - self.drivers_log_on_off[driver][0]) 62 | for driver in self.active_drivers] 63 | ratios.sort() 64 | if len(ratios) > 0: 65 | self.median_ratio = ratios[len(ratios) // 2] 66 | 67 | def update_driver_income_after_rejection(self, assignment: List[Dict[str, Any]], 68 | dispatch_observ: List[Dict[str, Any]], index2hash: Dict[int, str]): 69 | """ 70 | this function update the driver's income. 71 | Should be called after the rejection process. 72 | :param assignment: a list of dicts, one dict is <'order_id': xxx, 'driver_id':xxx> 73 | :param dispatch_observ: the same as agent.matching parameter. 74 | :param index2hash: driver_id to driver_hash 75 | :return: None 76 | """ 77 | if len(dispatch_observ) == 0: 78 | return 79 | cur_hour = time.localtime(int(dispatch_observ[0]['timestamp'])).tm_hour 80 | order_price = {od['order_id']: od['reward_units'] for od in dispatch_observ} 81 | # for all recorders 82 | order_info = {od['order_id']: [od['reward_units'], 83 | od['order_driver_distance'], 84 | od['order_start_location'], 85 | od['order_finish_location'], 86 | od['timestamp'], 87 | od['order_finish_timestamp'], 88 | od['pick_up_eta']] for od in dispatch_observ} 89 | for pair in assignment: 90 | self.drivers_total_income[index2hash[pair['driver_id']]] += order_price[pair['order_id']] 91 | self.drivers_income_per_hour[index2hash[pair['driver_id']]][cur_hour] += order_price[pair['order_id']] 92 | self.drivers_serving_order_info[index2hash[pair['driver_id']]].append(order_info[pair['order_id']]) 93 | return 94 | 95 | def save_logs(self, solpath: str, city: str, date: str, notes=""): 96 | """ 97 | After one day simulation, output the driver's income and his/her online time into file 98 | :param solpath: str, the solution path 99 | :param date: str, the simulation date, eg. 20201129 100 | :param city: str, the city name, eg. chengdu 101 | :param notes: str, the parameter setting information 102 | :return: None 103 | """ 104 | bad1 = 0 105 | bad2 = 0 106 | bad3 = 0 107 | for driver_hash in self.drivers_log_on_off: 108 | if len(self.drivers_log_on_off[driver_hash]) == 1: 109 | bad1 += 1 110 | continue 111 | if len(self.drivers_log_on_off[driver_hash]) == 2: 112 | bad2 += 1 113 | log_on, log_off = self.drivers_log_on_off[driver_hash] 114 | self.drivers_online_time[driver_hash] = log_off - log_on 115 | continue 116 | if len(self.drivers_log_on_off[driver_hash]) > 2: 117 | bad3 += 1 118 | continue 119 | print("collision:", bad1, bad2, bad3, len(self.drivers_log_on_off)) 120 | driver_perhour_income = dict() 121 | for driver in self.drivers_income_per_hour: 122 | driver_perhour_income[driver] = self.drivers_income_per_hour[driver] 123 | solname = solpath.split('/')[-1] 124 | pickle.dump(self.drivers_log_on_off, open(solpath + "/" + solname + "_" + city + "_" + date + "logonoff" + notes, "wb")) 125 | pickle.dump(self.drivers_online_time, open(solpath + "/" + solname + "_" + city + "_" + date + "online_time" + notes, "wb")) 126 | pickle.dump(self.drivers_total_income, open(solpath + "/" + solname + "_" + city + "_" + date + "total_income" + notes, "wb")) 127 | pickle.dump(driver_perhour_income, open(solpath + "/" + solname + "_" + city + "_" + date + "perhourincome" + notes, "wb")) 128 | pickle.dump(self.drivers_serving_order_info, open(solpath + "/" + solname + "_" + city + "_" + date + "order_info" + notes, "wb")) 129 | 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Source code of paper Combinatorial Optimization Meets Reinforcement Learning: Effective Taxi Order Dispatching at Large-Scale (TKDE 2022). 2 | 3 | Proposed method is in ./LTA/ 4 | 5 | Baselines are in other directories. --------------------------------------------------------------------------------