├── .gitignore ├── A_star ├── A_star_traj.py └── __init__.py ├── DTA-SIM ├── data │ ├── nanshan_grid.cpg │ ├── nanshan_grid.dbf │ ├── nanshan_grid.prj │ ├── nanshan_grid.sbn │ ├── nanshan_grid.sbx │ ├── nanshan_grid.shp │ ├── nanshan_grid.shx │ ├── nanshan_network.gexf │ ├── nanshan_road.cpg │ ├── nanshan_road.dbf │ ├── nanshan_road.prj │ ├── nanshan_road.sbn │ ├── nanshan_road.sbx │ ├── nanshan_road.shp │ ├── nanshan_road.shx │ └── reward.xlsx ├── dta_grid.ipynb ├── dta_simluation.ipynb └── img │ ├── grid_dta.png │ ├── grid_dta_with_nodes.png │ └── road_dta.png ├── DTW └── dtw.py ├── LICENSE ├── README.md ├── causal_data.py ├── causal_learn_pc_detailed.py ├── causal_plot_detailed.ipynb ├── data ├── evaluate │ ├── traj_evaluate_be.npy │ ├── traj_evaluate_pemirl.npy │ ├── traj_evaluate_psl.npy │ └── traj_evaluate_rl.npy ├── nanshan_grid.cpg ├── nanshan_grid.dbf ├── nanshan_grid.prj ├── nanshan_grid.sbn ├── nanshan_grid.sbx ├── nanshan_grid.shp ├── nanshan_grid.shx ├── nanshan_tfidf_be.csv ├── ns_sf.csv └── routes_states │ ├── 0_0_states_tuple.npy │ ├── 0_1_states_tuple.npy │ ├── 0_2_states_tuple.npy │ ├── 1_0_states_tuple.npy │ ├── 1_1_states_tuple.npy │ └── 1_2_states_tuple.npy ├── deep_irl_be.py ├── deep_irl_realworld.py ├── demo_deepirl_be.py ├── demo_deepirl_realworld.py ├── demo_recursive_logit.py ├── dnn_psl.py ├── essay └── MEDIRL-IC.pdf ├── evaluate_traj_IRL.py ├── evaluate_traj_psl.py ├── img ├── MEDIRL_IC.png ├── box_comparison.png ├── causal_directed_graph.png ├── causal_strength.png ├── dta.png ├── reward_and_likelihood_plot.png ├── reward_map_based_on_IC.png └── sz_reward_map.png ├── img_utils.py ├── model ├── be │ ├── checkpoint │ ├── policy_realworld.npy │ ├── realworld.data-00000-of-00001 │ ├── realworld.index │ └── realworld.meta ├── checkpoint ├── dnn psl │ └── model_weights.pth ├── policy_realworld.npy ├── realworld.data-00000-of-00001 ├── realworld.index ├── realworld.meta └── rlogit_realworld.npy ├── nshortest_path.ipynb ├── plot ├── box plot.ipynb └── training process.ipynb ├── realGrid ├── real_grid.py └── value_iteration.py ├── recursive_logit.py ├── test_dirl_be.py ├── test_dirl_realworld.py ├── test_dnn_psl.py ├── test_recursive_logit.py ├── tf_utils.py ├── traj_contrast.py ├── traj_contrast_psl.py ├── traj_policy_logll.py ├── trajectory.ipynb ├── trajectory.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | ./data 2 | ./img 3 | ./test.py 4 | **/__pycache__ 5 | **/.vscode 6 | **/.DS_Store -------------------------------------------------------------------------------- /A_star/A_star_traj.py: -------------------------------------------------------------------------------- 1 | import math 2 | class Map(object): 3 | def __init__(self,states_list,fnid_idx,idx_fnid,cost): 4 | self.states = states_list 5 | self.fnid_idx=fnid_idx 6 | self.idx_fnid=idx_fnid 7 | self.cost=cost 8 | 9 | # Manhattan Distance 10 | class Node(object): 11 | def __init__(self,state,g,h,cost,father): 12 | self.s = state 13 | self.g = g 14 | self.h = h 15 | self.c=cost 16 | self.len=1 17 | self.father = father 18 | 19 | def calDist(self,s_state,e_state): 20 | s_x,s_y=s_state%357,s_state//357 21 | e_x,e_y=e_state%357,e_state//357 22 | dx = abs(s_x - e_x) 23 | dy = abs(s_y - e_y) 24 | return math.sqrt(dx * dx + dy * dy) 25 | 26 | def getNeighbor(self,mapState,end_state): 27 | s =self.s 28 | result = [] 29 | self.len+=1 30 | def calCost(s_state): 31 | nonlocal mapState 32 | idx=mapState.fnid_idx[s_state] 33 | return mapState.cost[idx] 34 | 
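# Indexing note: state ids (fnid) index a raster grid 357 cells wide, so fnid = y*357 + x.
# Hence s±1 moves horizontally, s±357 vertically, and s±356 / s±358 diagonally; diagonal
# steps add 1.4 (≈ sqrt(2)) to g in getNeighbor below. calDist above returns the Euclidean
# distance (not Manhattan, despite the comment on the class), which serves as the heuristic h.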
#up 35 | if (s+357) in mapState.states: 36 | start_state=s+357 37 | upNode = Node(start_state,self.g+1,self.calDist(start_state,end_state),calCost(start_state),self) 38 | result.append(upNode) 39 | #down 40 | if (s-357) in mapState.states: 41 | start_state=s-357 42 | upNode = Node(start_state,self.g+1,self.calDist(start_state,end_state),calCost(start_state),self) 43 | result.append(upNode) 44 | #left 45 | if (s-1) in mapState.states: 46 | start_state=s-1 47 | upNode = Node(start_state,self.g+1,self.calDist(start_state,end_state),calCost(start_state),self) 48 | result.append(upNode) 49 | #right 50 | if (s+1) in mapState.states: 51 | start_state=s+1 52 | upNode = Node(start_state,self.g+1,self.calDist(start_state,end_state),calCost(start_state),self) 53 | result.append(upNode) 54 | #up-left 55 | if (s+356) in mapState.states: 56 | start_state=s+356 57 | upNode = Node(start_state,self.g+1.4,self.calDist(start_state,end_state),calCost(start_state),self) 58 | result.append(upNode) 59 | #up-right 60 | if (s+358) in mapState.states: 61 | start_state=s+358 62 | upNode = Node(start_state,self.g+1.4,self.calDist(start_state,end_state),calCost(start_state),self) 63 | result.append(upNode) 64 | #down-left 65 | if (s-358) in mapState.states: 66 | start_state=s-358 67 | upNode = Node(start_state,self.g+1.4,self.calDist(start_state,end_state),calCost(start_state),self) 68 | result.append(upNode) 69 | #down-right 70 | if (s-356) in mapState.states: 71 | start_state=s-356 72 | upNode = Node(start_state,self.g+1.4,self.calDist(start_state,end_state),calCost(start_state),self) 73 | result.append(upNode) 74 | 75 | return result 76 | 77 | def hasNode(self,worklist): 78 | for i in worklist: 79 | if(i.s==self.s): 80 | return True 81 | return False 82 | 83 | # called when hasNode(openList) is True: keep the cheaper g and reroute the path 84 | def changeG(self,worklist): 85 | for i in worklist: 86 | if(i.s==self.s): 87 | if(i.g>self.g): 88 | i.g,i.father = self.g,self.father # also update the parent, otherwise path reconstruction follows the stale route 89 | 90 | def getKeyforSort(element:Node): 91 | return element.g+element.c*element.len+element.h 92 | 93 | def astar(workMap,start_state,end_state): 94 | startNode = Node(start_state, 0, 0, 0, None) 95 | openList = [] 96 | lockList = [] 97 | lockList.append(startNode) 98 | currNode = startNode 99 | 100 | while end_state != currNode.s: 101 | workList = currNode.getNeighbor(workMap,end_state) 102 | for i in workList: 103 | if not i.hasNode(lockList): # compare by state; `i not in lockList` would test object identity and never match 104 | if(i.hasNode(openList)): 105 | i.changeG(openList) 106 | else: 107 | openList.append(i) 108 | openList.sort(key=getKeyforSort) 109 | currNode = openList.pop(0) 110 | lockList.append(currNode) 111 | 112 | result = [] 113 | while(currNode.father!=None): 114 | result.append((currNode.s)) 115 | currNode = currNode.father 116 | result.append((currNode.s)) 117 | return result -------------------------------------------------------------------------------- /A_star/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/A_star/__init__.py -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_grid.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_grid.dbf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_grid.dbf -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_grid.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_grid.sbn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_grid.sbn -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_grid.sbx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_grid.sbx -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_grid.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_grid.shp -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_grid.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_grid.shx -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_network.gexf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_network.gexf -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_road.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_road.dbf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_road.dbf -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_road.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_road.sbn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_road.sbn -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_road.sbx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_road.sbx -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_road.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_road.shp -------------------------------------------------------------------------------- /DTA-SIM/data/nanshan_road.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/nanshan_road.shx -------------------------------------------------------------------------------- /DTA-SIM/data/reward.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/data/reward.xlsx -------------------------------------------------------------------------------- /DTA-SIM/img/grid_dta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/img/grid_dta.png -------------------------------------------------------------------------------- /DTA-SIM/img/grid_dta_with_nodes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/img/grid_dta_with_nodes.png -------------------------------------------------------------------------------- /DTA-SIM/img/road_dta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/DTA-SIM/img/road_dta.png -------------------------------------------------------------------------------- /DTW/dtw.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | class DTW(object): 4 | def __init__(self,traj,traj_generate): 5 | self.traj=traj 6 | self.traj_generate=traj_generate 7 | 8 | def calDistance(self,x, y): 9 | return math.sqrt(pow(x[0]-y[0],2)+pow(x[1]-y[1],2)) 10 | 11 | def dtw(self): 12 | X=self.fnidToXY(self.traj) 13 | Y=self.fnidToXY(self.traj_generate) 14 | l1 = len(X) 15 | l2 = len(Y) 16 | D = [[0 for i in range(l1 + 1)] for i in range(l2 + 1)] 17 | # D[0][0] = 0 18 | for i in range(1, l1 + 1): 19 | D[0][i] = sys.maxsize 20 | for j in range(1, l2 + 1): 21 | D[j][0] = sys.maxsize 22 | for j in range(1, l2 + 1): 23 | for i in range(1, l1 + 1): 24 | D[j][i] = self.calDistance(X[i - 1], Y[j-1]) + \ 25 | min(D[j - 1][i], D[j][i - 1], D[j - 1][i - 1]) 26 | return D 27 | 28 | def fnidToXY(self,traj): 29 | traj_list=[] 30 | for t in traj: 31 | x=t%357 32 | y=(t)//357 33 | traj_list.append((x,y)) 34 | return traj_list 35 | 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Boyang Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 
| in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MEDIRL-IC 2 | 3 | ## Article 4 | 5 | **Enhancing Pedestrian Route Choice Models through Maximum-Entropy Deep Inverse Reinforcement Learning with Individual Covariates (MEDIRL-IC):** This article has been accepted for publication in *IEEE Transactions on Intelligent Transportation Systems*. You can access it via [IEEE Xplore](https://ieeexplore.ieee.org/document/10689250). 6 | 7 | ## Introduction 8 | 9 | This project is a collection of algorithms and models dedicated to "Deep Inverse Reinforcement Learning with Individual Covariates" in the context of pedestrian route choice. Developed by Boyang Li at SUPD, Peking University. 10 | 11 | ## Directory Structure 12 | 13 | - **A_star**: Implementation of the A* algorithm. 14 | - **DTW**: Implementation of Dynamic Time Warping (DTW) for measuring trajectory similarity. 15 | - **data**: Contains datasets, data preprocessing scripts, and other data-related utilities. 16 | - **model**: Stores trained model checkpoints and weights. 17 | - **plot**: Scripts for generating visualizations and related plots. 18 | - **realGrid**: Urban grid classes and methods representing real-world scenarios. 19 | - **img**: Images utilized primarily for the project's README documentation. 20 | 21 | 22 | ### Configuration and Utility Files 23 | 24 | - `.gitignore`: Configuration file telling Git which files and directories to ignore. 25 | - `README.md`: Provides an overview and documentation for the project. 26 | - `img_utils.py`: Utility functions related to image processing. 27 | - `tf_utils.py`: Utility functions related to TensorFlow operations. 28 | - `utils.py`: General utility functions for the project. 29 | 30 | ### Causal Analysis 31 | 32 | - `causal_data.py`: Script that assembles the feature and reward dataset used for causal discovery. 33 | - `causal_learn_pc_detailed.py`: Detailed implementation of the PC algorithm for causal inference. 34 | - `causal_plot_detailed.ipynb`: Jupyter notebook for detailed plotting and visualization of causal data. 35 | 36 | ### Deep Inverse Reinforcement Learning (IRL) 37 | 38 | - `deep_irl_be.py`: Deep IRL that uses built-environment features only. 39 | - `deep_irl_realworld.py`: Implementation of deep IRL with IC for real-world scenarios. 40 | - `demo_deepirl_be.py`: Demonstration script for the built-environment-only deep IRL model. 41 | - `demo_deepirl_realworld.py`: Demonstration script for deep IRL with IC in real-world scenarios.
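As a quick reference, the demo scripts are configured through `argparse`. A typical invocation of the real-world model is sketched below (flag names as defined in `demo_deepirl_realworld.py`; the numeric values are simply its defaults, and `--restore` resumes from an existing checkpoint):

```bash
python demo_deepirl_realworld.py -g 0.9 -a 0.2 -lr 0.02 -ni 100 --restore
```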
42 | 43 | ### Recursive Logit Model 44 | 45 | - `demo_recursive_logit.py`: Demonstration script for the recursive logit model. 46 | - `recursive_logit.py`: Implementation of the recursive logit model. 47 | - `test_recursive_logit.py`: Testing script for the recursive logit model. 48 | 49 | ### DNN-PSL Model 50 | 51 | - `dnn_psl.py`: Script related to the deep neural network model. 52 | - `test_dnn_psl.py`: Testing script for the DNN model. 53 | 54 | ### Trajectory Evaluation and Analysis 55 | 56 | - `evaluate_traj_IRL.py`: Evaluation scripts for IRL trajectories. 57 | - `evaluate_traj_psl.py`: Evaluation scripts for PSL trajectories. 58 | - `traj_contrast.py`: Scripts for trajectory contrast analysis. 59 | - `traj_contrast_psl.py`: Scripts for PSL trajectory contrast analysis. 60 | - `traj_policy_logll.py`: Scripts related to policy log likelihood for trajectories. 61 | - `trajectory.ipynb`: Jupyter notebook for trajectory functions. 62 | - `trajectory.py`: Scripts related to trajectory functions. 63 | 64 | ### Miscellaneous 65 | 66 | - `nshortest_path.ipynb`: Jupyter notebook related to the choice set of path-size logit model. 67 | 68 | ## Getting Started 69 | 70 | 1. Clone the repository. 71 | 2. Install necessary dependencies. 72 | 3. Run the desired scripts or models. 73 | 74 | ## Dependencies 75 | 76 | - **Fiona**: 1.8.13 77 | - **GDAL**: 3.0.4 78 | - **geopandas**: 0.8.2 79 | - **matplotlib**: 3.0.3 80 | - **networkx**: 2.4 81 | - **numpy**: 1.14.5+mkl 82 | - **pandas**: 0.25.3 83 | - **pyproj**: 2.5.0 84 | - **Rtree**: 0.9.3 85 | - **scipy**: 1.4.1 86 | - **seaborn**: 0.9.1 87 | - **Shapely**: 1.6.4.post2 88 | - **tensorflow**: 0.12.1 89 | 90 | ## Data Availability 91 | 92 | Due to individual privacy concerns, we only provide geographical data for the training region along with a limited set of encrypted individual trajectory data. 93 | 94 | 95 | ## Results and Visualizations 96 | 97 | ### DeepIRL 98 | ![MEDIRL-IC Framework](img/MEDIRL_IC.png) 99 | 100 | ### Box Comparison 101 | ![Model Comparison](img/box_comparison.png) 102 | 103 | ### Reward and Likelihood Plot 104 | ![Reward and Likelihood Plot](img/reward_and_likelihood_plot.png) 105 | 106 | ### Dynamic Traffic Equilibrium 107 | ![Dynamic Traffic Equilibrium](img/dta.png) 108 | 109 | ### Reward Map based on IC 110 | ![Reward Map based on IC](img/reward_map_based_on_IC.png) 111 | 112 | ### SZ Reward Map 113 | ![SZ Reward Map](img/sz_reward_map.png) 114 | 115 | ### Causal Directed Graph 116 | ![Causal Directed Graph](img/causal_directed_graph.png) 117 | 118 | ### Causal Strength 119 | ![Causal Strength Graph](img/causal_strength.png) 120 | 121 | 122 | ## License and Credits 123 | 124 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 125 | ## Citation 126 | 127 | Li, B., & Zhang, W. (2024). Enhancing Pedestrian Route Choice Models Through Maximum-Entropy Deep Inverse Reinforcement Learning With Individual Covariates (MEDIRL-IC). *IEEE Transactions on Intelligent Transportation Systems*, 1-18. 
[https://doi.org/10.1109/TITS.2024.3457680](https://doi.org/10.1109/TITS.2024.3457680) 128 | 129 | ```bibtex 130 | @ARTICLE{10689250, 131 | author={Li, Boyang and Zhang, Wenjia}, 132 | journal={IEEE Transactions on Intelligent Transportation Systems}, 133 | title={Enhancing Pedestrian Route Choice Models Through Maximum-Entropy Deep Inverse Reinforcement Learning With Individual Covariates (MEDIRL-IC)}, 134 | year={2024}, 135 | pages={1-18}, 136 | keywords={Pedestrians; Analytical models; Decision making; Predictive models; Reinforcement learning; Biological system modeling; Trajectory; Deep inverse reinforcement learning; pedestrian route choice; causal discovery; cell phone signaling}, 137 | doi={10.1109/TITS.2024.3457680} 138 | } 139 | ``` 140 | -------------------------------------------------------------------------------- /causal_data.py: -------------------------------------------------------------------------------- 1 | from deep_irl_realworld import * 2 | import pandas as pd 3 | 4 | 5 | def getReward(feature_map_file, genderAge, model_file): 6 | """ 7 | get reward map from ckpt file 8 | """ 9 | lr = 0.02 10 | 11 | # get feature map of state 12 | feature_map_excel = pd.read_csv(feature_map_file) 13 | feat_map_df = feature_map_excel.iloc[:, 1:] 14 | feat_map = np.array(feat_map_df) 15 | 16 | nn_r = DeepIRLFC(feat_map.shape[1], lr, genderAge, 40, 30) 17 | model_name = 'realworld' 18 | if os.path.exists(os.path.join(model_file, model_name+'.meta')): 19 | print('restore graph from ckpt file') 20 | nn_r.restoreGraph(model_file, model_name) 21 | else: 22 | print("there isn't ckpt file") 23 | return 24 | 25 | oh = [genderAge for _ in range(feat_map.shape[0])] 26 | rewards = normalize(nn_r.get_rewards(feat_map, oh)) 27 | 28 | return rewards, feat_map_df 29 | 30 | 31 | if __name__ == "__main__": 32 | feature_map_file = './data/ns_sf.csv' 33 | gpd_file = './data/nanshan_grid.shp' 34 | model_file = './model' 35 | route_file_path = './data/routes_states' 36 | 37 | final_df = pd.DataFrame() 38 | for f in os.listdir(route_file_path): 39 | genderAge = [0]*5 40 | gender, age = int(f[0]), int(f[2]) 41 | genderAge[gender], genderAge[age+2] = 1, 1 42 | tf.reset_default_graph() 43 | reward, feat_df = getReward(feature_map_file, genderAge, model_file) 44 | reward = [r[0] for r in reward] 45 | 46 | reward_df = pd.DataFrame({'reward': reward}) 47 | df_length = len(reward_df) 48 | gender_age_df = pd.DataFrame({'male': [genderAge[0]]*df_length, 'female':[genderAge[1]]*df_length, 49 | 'young': [genderAge[2]]*df_length, 'middle': [genderAge[3]]*df_length, 'old': [genderAge[4]]*df_length}) 50 | feat_reward_df = pd.concat([feat_df, gender_age_df, reward_df], axis=1) 51 | final_df = final_df.append(feat_reward_df, ignore_index=True) 52 | final_df.to_csv('./data/nanshan_reward.csv',index=False,header=None) -------------------------------------------------------------------------------- /causal_learn_pc_detailed.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | from itertools import chain, combinations 4 | from pyexpat import model 5 | 6 | import matplotlib.pyplot as plt 7 | import networkx as nx 8 | import numpy as np 9 | import pandas as pd 10 | from scipy.stats import norm, pearsonr 11 | 12 | 13 | def dfs(graph, k: int, path: list, vis: list, labels: list): 14 | """ 15 | Depth First Search for causal relation search 16 | 17 | Args: 18 | graph : cdpag 19 | k : start index 20 | path : depth first search path 21 | vis : if visited 22 | 
labels: labels list 23 | """ 24 | flag = True # remains True only if k has no unvisited parent, i.e. the current path is maximal 25 | 26 | for i in range(len(graph)): 27 | if (graph[i][k]) and (vis[i] != True): 28 | flag = False # a longer path exists through i, so this prefix is not printed 29 | vis[i] = True 30 | path.append(labels[i]) 31 | dfs(graph, i, path, vis, labels) 32 | path.pop() 33 | vis[i] = False 34 | 35 | if flag: 36 | print(path) 37 | 38 | 39 | def causalGraphPlot(graph, labels: list, pic_path: str, model_path: str): 40 | """ 41 | visualize Bayesian network 42 | 43 | Args: 44 | graph : networkx graph 45 | labels (list): label list 46 | pic_path (str): picture save path 47 | model_path (str): model save path 48 | """ 49 | G = nx.DiGraph() 50 | 51 | for i in range(len(graph)): 52 | if i <= 29: 53 | G.add_node(labels[i], partition=1) 54 | elif 29 < i < 35: 55 | G.add_node(labels[i], partition=2) 56 | else: 57 | G.add_node(labels[i], partition=3) 58 | for j in range(len(graph[i])): 59 | if graph[i][j]: 60 | # G.add_edges_from([(labels[i], labels[j])]) 61 | G.add_weighted_edges_from( 62 | [(labels[i], labels[j], graph[i][j])], weight='weight') 63 | 64 | nx.write_gexf(G, model_path) 65 | nx.draw(G, with_labels=True) 66 | # plt.savefig(pic_path) 67 | plt.show() 68 | 69 | 70 | def subset(iterable): 71 | """ 72 | calculate all subsets 73 | subset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3) 74 | 75 | Args: 76 | iterable : Iterable variables 77 | 78 | Returns: 79 | iterable : Iterable subset variables 80 | """ 81 | xs = list(iterable) 82 | return chain.from_iterable(combinations(xs, n) for n in range(len(xs) + 1)) 83 | 84 | 85 | def skeleton(suffStat, indepTest, alpha, labels, m_max): 86 | """ 87 | construct skeleton for Bayesian network 88 | 89 | Args: 90 | suffStat : sufficient statistics {"C": correlation matrix, "n": number of samples} 91 | indepTest : gaussian independence test 92 | alpha : significance level for the conditional independence test 93 | labels : label list 94 | m_max : max subset length 95 | 96 | Returns: 97 | sk+sepset: {"sk": np.array(G), "sepset": d-separation set for each point pair} 98 | """ 99 | # complete undirected graph 100 | sepset = [[[] for i in range(len(labels))] for i in range(len(labels))] 101 | G = [[1 for i in range(len(labels))] for i in range(len(labels))] 102 | 103 | for i in range(len(labels)): 104 | G[i][i] = 0 105 | 106 | done = False # flag 107 | 108 | ord = 0 # subset length 109 | while done != True and any(G) and ord <= m_max: 110 | done = True 111 | 112 | # neighboring point set 113 | ind = [] 114 | for i in range(len(G)): 115 | for j in range(len(G[i])): 116 | if G[i][j]: 117 | ind.append((i, j)) 118 | 119 | G1 = G.copy() 120 | 121 | for x, y in ind: 122 | if G[x][y]: 123 | neighborsBool = [row[x] for row in G1] 124 | neighborsBool[y] = 0 125 | 126 | # adj(C,x) \ {y} 127 | neighbors = [i for i in range( 128 | len(neighborsBool)) if neighborsBool[i]] 129 | 130 | if len(neighbors) >= ord: 131 | 132 | # |adj(C, x) \ {y}| > ord 133 | if len(neighbors) > ord: 134 | done = False 135 | 136 | # |adj(C, x) \ {y}| = ord 137 | for neighbors_S in set(itertools.combinations(neighbors, ord)): 138 | # if x and y are d-separated by neighbors_S 139 | # conditional independence, return p-value 140 | pval, correlation_r = indepTest( 141 | suffStat, x, y, list(neighbors_S)) 142 | 143 | if pval >= alpha: 144 | # if pval>alpha, x is independent from y 145 | G[x][y] = G[y][x] = 0 146 | 147 | # add neighbors_S into separation set 148 | sepset[x][y] = list(neighbors_S) 149 | break 150 | else: 151 | G[x][y] = G[y][x] = correlation_r 152 | 153 | ord += 1 154 | return {'sk': np.array(G), 'sepset': sepset} 155 | 156 | 157
| def extendCpdag(graph): 158 | """ 159 | turn a partially directed acyclic graph into a completed partially directed acyclic graph (CPDAG) 160 | 161 | Args: 162 | graph : {"sk": np.array(G), "sepset": d-separation set for each point pair} 163 | """ 164 | 165 | def rule1(pdag): 166 | """ 167 | If there is a chain a -> b - c, and a, c are not adjacent, change b - c to b -> c 168 | 169 | Args: 170 | pdag : partially directed acyclic graph 171 | 172 | Returns: 173 | pdag: partially directed acyclic graph 174 | """ 175 | search_pdag = pdag.copy() 176 | ind = [] 177 | for i in range(len(pdag)): 178 | for j in range(len(pdag)): 179 | if pdag[i][j] and pdag[j][i] == 0: 180 | ind.append((i, j)) 181 | 182 | # for each directed edge a -> b 183 | for a, b in sorted(ind, key=lambda x: (x[1], x[0])): 184 | isC = [] 185 | 186 | for i in range(len(search_pdag)): 187 | if (search_pdag[b][i] and search_pdag[i][b]) and (search_pdag[a][i] == 0 and search_pdag[i][a] == 0): 188 | isC.append(i) 189 | 190 | if len(isC) > 0: 191 | for c in isC: 192 | if 'unfTriples' in graph.keys() and ((a, b, c) in graph['unfTriples'] or (c, b, a) in graph['unfTriples']): 193 | # if unfaithful, skip 194 | continue 195 | if pdag[b][c] and pdag[c][b]: 196 | pdag[b][c] = graph['sk'][b][c] 197 | pdag[c][b] = 0 198 | elif pdag[b][c] == 0 and pdag[c][b]: 199 | pdag[b][c] = pdag[c][b] = 2 200 | 201 | return pdag 202 | 203 | def rule2(pdag): 204 | """ 205 | If there is a chain a -> c -> b, change a - b to a -> b 206 | 207 | Args: 208 | pdag : partially directed acyclic graph 209 | 210 | Returns: 211 | pdag: partially directed acyclic graph 212 | """ 213 | search_pdag = pdag.copy() 214 | ind = [] 215 | 216 | for i in range(len(pdag)): 217 | for j in range(len(pdag)): 218 | if pdag[i][j] and pdag[j][i]: 219 | ind.append((i, j)) 220 | 221 | # for each undirected edge a - b 222 | for a, b in sorted(ind, key=lambda x: (x[1], x[0])): 223 | isC = [] 224 | for i in range(len(search_pdag)): 225 | if (search_pdag[a][i] and search_pdag[i][a] == 0) and (search_pdag[i][b] and search_pdag[b][i] == 0): 226 | isC.append(i) 227 | if len(isC) > 0: 228 | if pdag[a][b] and pdag[b][a]: 229 | pdag[a][b] = graph['sk'][a][b] 230 | pdag[b][a] = 0 231 | elif pdag[a][b] == 0 and pdag[b][a]: 232 | pdag[a][b] = pdag[b][a] = 2 233 | 234 | return pdag 235 | 236 | def rule3(pdag): 237 | """ 238 | If a - c1 -> b and a - c2 -> b, and c1, c2 are not adjacent, change a - b to a -> b 239 | 240 | Args: 241 | pdag : partially directed acyclic graph 242 | 243 | Returns: 244 | pdag: partially directed acyclic graph 245 | """ 246 | search_pdag = pdag.copy() 247 | ind = [] 248 | for i in range(len(pdag)): 249 | for j in range(len(pdag)): 250 | if pdag[i][j] and pdag[j][i]: 251 | ind.append((i, j)) 252 | 253 | # for each undirected edge a - b 254 | for a, b in sorted(ind, key=lambda x: (x[1], x[0])): 255 | isC = [] 256 | 257 | for i in range(len(search_pdag)): 258 | if (search_pdag[a][i] and search_pdag[i][a]) and (search_pdag[i][b] and search_pdag[b][i] == 0): 259 | isC.append(i) 260 | 261 | if len(isC) >= 2: 262 | for c1, c2 in combinations(isC, 2): 263 | if search_pdag[c1][c2] == 0 and search_pdag[c2][c1] == 0: 264 | # unfaithful 265 | if 'unfTriples' in graph.keys() and ((c1, a, c2) in graph['unfTriples'] or (c2, a, c1) in graph['unfTriples']): 266 | continue 267 | if search_pdag[a][b] and search_pdag[b][a]: 268 | pdag[a][b] = graph['sk'][a][b] 269 | pdag[b][a] = 0 270 | break 271 | elif search_pdag[a][b] == 0 and search_pdag[b][a]: 272 | pdag[a][b] = pdag[b][a] = 2 273 | break 274 | 275 | return pdag 276 | 277 | pdag = [[0 if graph['sk'][i][j] == 0 else graph['sk'][i][j] for
i in range( 278 | len(graph['sk']))] for j in range(len(graph['sk']))] 279 | 280 | ind = [] 281 | for i in range(len(pdag)): 282 | for j in range(len(pdag[i])): 283 | if pdag[i][j]: 284 | ind.append((i, j)) 285 | 286 | # Change x - y - z to x -> y <- z 287 | for x, y in sorted(ind, key=lambda x: (x[1], x[0])): 288 | allZ = [] 289 | for z in range(len(pdag)): 290 | if graph['sk'][y][z] and z != x: 291 | allZ.append(z) 292 | 293 | for z in allZ: 294 | if graph['sk'][x][z] == 0 and graph['sepset'][x][z] != None and graph['sepset'][z][x] != None and not (y in graph['sepset'][x][z] or y in graph['sepset'][z][x]): 295 | pdag[y][x] = pdag[y][z] = 0 296 | pdag[x][y] = graph['sk'][x][y] 297 | pdag[z][y] = graph['sk'][z][y] 298 | 299 | # apply rule1 - rule3 300 | pdag = rule1(pdag) 301 | pdag = rule2(pdag) 302 | pdag = rule3(pdag) 303 | 304 | return np.array(pdag) 305 | 306 | 307 | def pc(suffStat, alpha, labels, indepTest, skeleton_path, m_max=float("inf"), verbose=False): 308 | """ 309 | PC algorithm 310 | 311 | Args: 312 | suffStat : sufficient statistics {"C": correlation matrix, "n": number of samples} 313 | alpha : significance level for the conditional independence test 314 | labels : labels of each state 315 | indepTest : gaussian independence test 316 | m_max : max subset length 317 | verbose : log display 318 | 319 | Returns: 320 | cpdag: Completed partially directed acyclic graph 321 | """ 322 | # skeleton 323 | graphDict = skeleton(suffStat, indepTest, alpha, labels, m_max) 324 | df_graph = pd.DataFrame(graphDict['sk']) 325 | df_graph.to_csv(skeleton_path, index=False) 326 | # extend to CPDAG 327 | cpdag = extendCpdag(graphDict) 328 | # print Bayesian network matrix 329 | if verbose: 330 | print(cpdag) 331 | return cpdag 332 | 333 | 334 | def gaussCiTest(suffstat, x, y, S): 335 | """ 336 | test for conditional independence 337 | 338 | Args: 339 | suffstat : sufficient statistics {"C": correlation matrix, "n": number of samples} 340 | x : index of the first variable 341 | y : index of the second variable 342 | S : conditioning set of variable indices 343 | 344 | Returns: 345 | (p-value, r): p-value of the independence test and the partial correlation coefficient 346 | """ 347 | C = suffstat["C"] 348 | n = suffstat["n"] 349 | 350 | cut_at = 0.9999999 351 | 352 | # Zero-order partial correlation coefficient 353 | if len(S) == 0: 354 | r = C[x, y] 355 | 356 | # First-order partial correlation coefficient 357 | elif len(S) == 1: 358 | r = (C[x, y] - C[x, S] * C[y, S]) / \ 359 | math.sqrt((1 - math.pow(C[y, S], 2)) * (1 - math.pow(C[x, S], 2))) 360 | 361 | # High-order partial correlation coefficient 362 | else: 363 | m = C[np.ix_([x]+[y] + S, [x] + [y] + S)] 364 | PM = np.linalg.pinv(m) 365 | r = -1 * PM[0, 1] / math.sqrt(abs(PM[0, 0] * PM[1, 1])) 366 | 367 | r = min(cut_at, max(-1 * cut_at, r)) 368 | 369 | # Fisher’s z-transform 370 | res = math.sqrt(n - len(S) - 3) * .5 * math.log1p((2 * r) / (1 - r)) 371 | 372 | # Φ^{-1}(1-α/2) 373 | return 2 * (1 - norm.cdf(abs(res))), r 374 | 375 | 376 | if __name__ == '__main__': 377 | file_path = './data/nanshan_reward.csv' 378 | image_path = './img/causalGraph.png' 379 | model_path = './data/causalGraph.gexf' 380 | skeleton_path = './data/skeleton.csv' 381 | 382 | data = pd.read_csv(file_path) 383 | labels = [ 384 | 'Population density', 385 | 'Land Use Mix', 386 | 'Open Space Ratio', 387 | 'Intersections', 388 | 'Center', 389 | 'Airport', 390 | 'Railway', 391 | 'Dock', 392 | 'Coach', 393 | 'Expressway', 394 | 'Cycleway', 395 | 'Suburban Road', 396 | 'City Branch', 397 | 'Inner Road', 398 | 'Main Road', 399 | 'Not Built Road', 400 | 'SideWalk', 401 |
'Urban Secondary', 402 | 'Attractions', 403 | 'Food&Beverages', 404 | 'Tranportation', 405 | 'Sport', 406 | 'Public', 407 | 'Enterprises', 408 | 'Medical', 409 | 'Government', 410 | 'Finance', 411 | 'Education&Science', 412 | 'Shopping', 413 | 'Life', 414 | 'Male', 415 | 'Female', 416 | 'Young', 417 | 'Middle', 418 | 'Old', 419 | 'Reward' 420 | ] 421 | 422 | row_count = len(labels) 423 | graph = pc( 424 | suffStat={"C": data.corr().values, "n": data.values.shape[0]}, 425 | alpha=0.05, 426 | skeleton_path=skeleton_path, 427 | labels=[str(i) for i in range(row_count)], 428 | indepTest=gaussCiTest, 429 | verbose=True 430 | ) 431 | 432 | df_graph = pd.DataFrame(graph) 433 | df_graph.to_csv('./data/edge_weight.csv', index=False) 434 | 435 | start = -1 # index for 'reward' label 436 | vis = [0 for i in range(row_count)] 437 | vis[start] = True 438 | path = [] 439 | path.append(labels[start]) 440 | dfs(graph, start, path, vis, labels) 441 | 442 | causalGraphPlot(graph, labels, image_path, model_path) 443 | -------------------------------------------------------------------------------- /data/evaluate/traj_evaluate_be.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/evaluate/traj_evaluate_be.npy -------------------------------------------------------------------------------- /data/evaluate/traj_evaluate_pemirl.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/evaluate/traj_evaluate_pemirl.npy -------------------------------------------------------------------------------- /data/evaluate/traj_evaluate_psl.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/evaluate/traj_evaluate_psl.npy -------------------------------------------------------------------------------- /data/evaluate/traj_evaluate_rl.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/evaluate/traj_evaluate_rl.npy -------------------------------------------------------------------------------- /data/nanshan_grid.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 -------------------------------------------------------------------------------- /data/nanshan_grid.dbf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/nanshan_grid.dbf -------------------------------------------------------------------------------- /data/nanshan_grid.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]] -------------------------------------------------------------------------------- /data/nanshan_grid.sbn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/nanshan_grid.sbn 
-------------------------------------------------------------------------------- /data/nanshan_grid.sbx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/nanshan_grid.sbx -------------------------------------------------------------------------------- /data/nanshan_grid.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/nanshan_grid.shp -------------------------------------------------------------------------------- /data/nanshan_grid.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/nanshan_grid.shx -------------------------------------------------------------------------------- /data/routes_states/0_0_states_tuple.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/routes_states/0_0_states_tuple.npy -------------------------------------------------------------------------------- /data/routes_states/0_1_states_tuple.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/routes_states/0_1_states_tuple.npy -------------------------------------------------------------------------------- /data/routes_states/0_2_states_tuple.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/routes_states/0_2_states_tuple.npy -------------------------------------------------------------------------------- /data/routes_states/1_0_states_tuple.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/routes_states/1_0_states_tuple.npy -------------------------------------------------------------------------------- /data/routes_states/1_1_states_tuple.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/routes_states/1_1_states_tuple.npy -------------------------------------------------------------------------------- /data/routes_states/1_2_states_tuple.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/data/routes_states/1_2_states_tuple.npy -------------------------------------------------------------------------------- /deep_irl_be.py: -------------------------------------------------------------------------------- 1 | from turtle import st 2 | import os 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import tensorflow as tf 6 | import datetime 7 | import realGrid.value_iteration as value_iteration 8 | import img_utils 9 | import tf_utils 10 | from utils import * 11 | import time 12 | 13 | 14 | class DeepIRLFC: 15 | def __init__(self, n_input, lr, n_h1=400, n_h2=300, 
l2=0.5, name='deep_irl_fc'): 16 | """initialize DeepIRl, construct function between feature and reward 17 | 18 | Args: 19 | n_input (_type_): number of features 20 | lr : learning rate 21 | n_h1 (int, optional): output size of fc1. Defaults to 400. 22 | n_h2 (int, optional): output size of fc2. Defaults to 300. 23 | l2 (int, optional): l2 loss gradient. Defaults to 0.1. 24 | name (str, optional): variable scope. Defaults to 'deep_irl_fc'. 25 | """ 26 | self.n_input = n_input 27 | self.lr = lr 28 | self.n_h1 = n_h1 29 | self.n_h2 = n_h2 30 | self.name = name 31 | 32 | self.sess = tf.Session() 33 | # input_s:States*Feature 34 | self.input_s, self.reward, self.theta = self._build_network(self.name) 35 | self.optimizer = tf.train.GradientDescentOptimizer(lr) 36 | # apply l2 loss gradient 37 | self.l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.theta]) 38 | self.grad_l2 = tf.gradients(ys=self.l2_loss, xs=self.theta) 39 | self.grad_r = tf.placeholder(tf.float32, [None, 1]) 40 | self.grad_theta = tf.gradients(self.reward, self.theta, -self.grad_r) 41 | self.grad_theta = [tf.add(l2*self.grad_l2[i], self.grad_theta[i]) 42 | for i in range(len(self.grad_l2))] 43 | # Gradient Clipping 44 | self.grad_theta, _ = tf.clip_by_global_norm(self.grad_theta, 100.0) 45 | self.grad_norms = tf.global_norm(self.grad_theta) 46 | self.optimize = self.optimizer.apply_gradients( 47 | zip(self.grad_theta, self.theta)) 48 | self.sess.run(tf.global_variables_initializer()) 49 | 50 | def _build_network(self, name): 51 | """build forward netword with 3 fully connected layer 52 | 53 | Args: 54 | name (string): variable scope 55 | 56 | Returns: 57 | input_s,reward,theta: features of states,reward of states,trainable parameters 58 | """ 59 | input_s = tf.placeholder(tf.float32, [None, self.n_input]) 60 | with tf.variable_scope(name): 61 | fc1 = tf_utils.fc(input_s, self.n_h1, scope="fc1", activation_fn=tf.nn.elu, 62 | initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN")) 63 | fc2 = tf_utils.fc(fc1, self.n_h2, scope="fc2", activation_fn=tf.nn.elu, 64 | initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN")) 65 | reward = tf_utils.fc(fc2, 1, scope="reward") 66 | theta = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name) 67 | return input_s, reward, theta 68 | 69 | def get_theta(self): 70 | return self.sess.run(self.theta) 71 | 72 | def get_rewards(self, states): 73 | rewards = self.sess.run(self.reward, feed_dict={self.input_s: states}) 74 | return rewards 75 | 76 | def apply_grads(self, feat_map, grad_r): 77 | grad_r = np.reshape(grad_r, [-1, 1]) 78 | feat_map = np.reshape(feat_map, [-1, self.n_input]) 79 | _, grad_theta, l2_loss, grad_norms = self.sess.run([self.optimize, self.grad_theta, self.l2_loss, self.grad_norms], 80 | feed_dict={self.grad_r: grad_r, self.input_s: feat_map}) 81 | return grad_theta, l2_loss, grad_norms 82 | 83 | def restoreGraph(self, model_file, model_name): 84 | saver = tf.train.Saver() 85 | save_path = model_file+".\\"+model_name 86 | saver.restore(self.sess, save_path) 87 | 88 | 89 | def expectStateVisitFreq(P_a, gamma, trajs, fnid_idx, policy, deterministic=True): 90 | """ 91 | compute the expected states visition frequency p(s| theta, T) 92 | using dynamic programming 93 | 94 | inputs: 95 | P_a NxNxN_ACTIONS matrix - transition dynamics 96 | gamma float - discount factor 97 | trajs list of Steps - collected from expert 98 | fnid_idx {fnid:index} 99 | policy Nx1 vector (or NxN_ACTIONS if deterministic=False) - policy 100 | 101 | returns: 102 
| p Nx1 vector - state visitation frequencies 103 | """ 104 | N_STATES, _, N_ACTIONS = np.shape(P_a) 105 | T = [] 106 | for traj in trajs: 107 | T.append(len(traj)) 108 | 109 | avg_T = int(np.mean(T)) 110 | # mu[s, t] is the prob of visiting state s at step t,get the 111 | mu = np.zeros([N_STATES, avg_T]) 112 | 113 | for traj in trajs: 114 | index = fnid_idx[traj[0].state] 115 | mu[index, 0] += 1 116 | mu[:, 0] = mu[:, 0]/len(trajs) 117 | 118 | for s in range(N_STATES): 119 | for t in range(avg_T-1): 120 | if deterministic: 121 | mu[s, t+1] = sum([mu[pre_s, t]*P_a[pre_s, s, int(policy[pre_s])] 122 | for pre_s in range(N_STATES)]) 123 | else: 124 | mu[s, t+1] = sum([sum([mu[pre_s, t]*P_a[pre_s, s, a1]*policy[pre_s, a1] for a1 in range(N_ACTIONS)]) 125 | for pre_s in range(N_STATES)]) 126 | p = np.sum(mu, 1) 127 | return p 128 | 129 | 130 | def stateVisitFreq(trajs, fnid_idx, n_states): 131 | """ 132 | compute state visitation frequences from demonstrations 133 | 134 | input: 135 | trajs list of list of Steps - collected from expert 136 | fnid_idx {fnid:index} 137 | n_states number of states 138 | returns: 139 | p Nx1 vector - state visitation frequences 140 | """ 141 | 142 | p = np.zeros(n_states) 143 | for traj in trajs: 144 | for step in traj: 145 | if step.state not in fnid_idx: 146 | continue 147 | idx = fnid_idx[step.state] 148 | p[idx] += 1 149 | p = p/len(trajs) 150 | return p 151 | 152 | 153 | def deepMaxEntIRL(feat_map, P_a, gamma, trajs, lr, n_iters, fnid_idx, idx_fnid, gpd_file, restore=True): 154 | """ 155 | Maximum Entropy Inverse Reinforcement Learning (Maxent IRL) 156 | 157 | inputs: 158 | feat_map NxD matrix - the features for each state 159 | P_a NxNxN_ACTIONS matrix - P_a[s0, s1, a] is the transition prob of 160 | landing at state s1 when taking action 161 | a at state s0 162 | gamma float - RL discount factor 163 | trajs a list of demonstrations 164 | lr float - learning rate 165 | n_iters int - number of optimization steps 166 | fnid_idx {fnid:index} 167 | idx_fnid {index:fnid} 168 | returns 169 | rewards Nx1 vector - recoverred state rewards 170 | """ 171 | 172 | # tf.set_random_seed(1) 173 | 174 | N_STATES, _, N_ACTIONS = np.shape(P_a) 175 | 176 | # init nn model 177 | nn_r = DeepIRLFC(feat_map.shape[1], lr, 40, 30) 178 | 179 | # restor graph 180 | model_file = './model/be' 181 | model_name = 'realworld' 182 | if restore and os.path.exists(os.path.join(model_file, model_name+'.meta')): 183 | print('restore graph from ckpt file') 184 | nn_r.restoreGraph(model_file, model_name) 185 | 186 | # find state visitation frequencies using demonstrations 187 | mu_D = stateVisitFreq(trajs, fnid_idx, N_STATES) 188 | 189 | # set pre-reward 190 | pre_reward=np.zeros(feat_map.shape[0]) 191 | 192 | T0=time.time() 193 | now_time=datetime.datetime.now() 194 | print('this loop start at {}'.format(now_time)) 195 | # training 196 | for iteration in range(n_iters): 197 | T1=time.time() 198 | if iteration % (n_iters/10) == 0: 199 | print('iteration: {}'.format(iteration)) 200 | tf.train.Saver().save(nn_r.sess, './model/be/realworld') 201 | # compute the reward matrix 202 | rewards = nn_r.get_rewards(feat_map) 203 | reward_difference=np.mean(normalize(rewards)-pre_reward) 204 | print("the current reward difference is {}".format(reward_difference)) 205 | if abs(reward_difference)<=0.001: 206 | print('the difference of reward is less than 0.001, then break the loop') 207 | break 208 | 209 | # # save picture 210 | # img_utils.rewardVisual(normalize(rewards), idx_fnid, gpd_file, "{} 
iteration".format(iteration)) 211 | # plt.savefig('./img/reward_{}.png'.format(iteration)) 212 | # plt.close() 213 | 214 | # compute policy 215 | values, policy = value_iteration.value_iteration( 216 | P_a, rewards, gamma, error=0.1, deterministic=True) 217 | np.save("./model/be/policy_realworld.npy",policy) 218 | print("The calculation of value and policy is finished!") 219 | # compute expected svf 220 | mu_exp = expectStateVisitFreq( 221 | P_a, gamma, trajs, fnid_idx, policy, deterministic=True) 222 | # compute gradients on rewards: 223 | grad_r = mu_D - mu_exp 224 | print("visit frequency difference is {}".format(np.mean(grad_r))) 225 | # apply gradients to the neural network 226 | grad_theta, l2_loss, grad_norm = nn_r.apply_grads(feat_map, grad_r) 227 | # calculate time pass 228 | T2=time.time() 229 | print("this iteration lasts {:.2f},the loop lasts {:.2f}".format(T2-T1,T2-T0)) 230 | # set pre reward 231 | pre_reward=normalize(rewards) 232 | tf.train.Saver().save(nn_r.sess, './model/be/realworld') 233 | np.save('./model/be/policy_realworld.npy',policy) 234 | rewards = nn_r.get_rewards(feat_map) 235 | return normalize(rewards) 236 | -------------------------------------------------------------------------------- /deep_irl_realworld.py: -------------------------------------------------------------------------------- 1 | from turtle import st 2 | import os 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import tensorflow as tf 6 | import datetime 7 | import realGrid.value_iteration as value_iteration 8 | import img_utils 9 | import tf_utils 10 | from utils import * 11 | import time 12 | from traj_policy_logll import * 13 | 14 | Step=namedtuple('Step',['state','action']) 15 | class DeepIRLFC: 16 | def __init__(self, n_input, lr, genderAge, n_h1=400, n_h2=300, l2=0.5, name='deep_irl_fc'): 17 | """initialize DeepIRl, construct function between feature and reward 18 | 19 | Args: 20 | n_input (_type_): number of features 21 | lr : learning rate 22 | n_h1 (int, optional): output size of fc1. Defaults to 400. 23 | n_h2 (int, optional): output size of fc2. Defaults to 300. 24 | l2 (int, optional): l2 loss gradient. Defaults to 0.1. 25 | name (str, optional): variable scope. Defaults to 'deep_irl_fc'. 
26 | """ 27 | self.n_input = n_input 28 | self.lr = lr 29 | self.n_h1 = n_h1 30 | self.n_h2 = n_h2 31 | self.name = name 32 | self.embedding_dim = 32 33 | self.genderAge = genderAge 34 | 35 | self.sess = tf.Session() 36 | self.input_s, self.reward, self.theta, self.input_onehot = self._build_network( 37 | self.name) 38 | self.optimizer = tf.train.GradientDescentOptimizer(lr) 39 | # apply l2 loss gradient 40 | self.l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.theta]) 41 | self.grad_l2 = tf.gradients(ys=self.l2_loss, xs=self.theta) 42 | self.grad_reward = tf.placeholder(tf.float32, [None, 1]) 43 | self.grad_theta = tf.gradients(self.reward, self.theta, -self.grad_reward) 44 | self.grad_theta = [tf.add(l2*self.grad_l2[i], self.grad_theta[i]) 45 | for i in range(len(self.grad_l2))] 46 | # Gradient Clipping 47 | self.grad_theta, _ = tf.clip_by_global_norm(self.grad_theta, 100.0) 48 | self.grad_norms = tf.global_norm(self.grad_theta) 49 | self.optimize = self.optimizer.apply_gradients( 50 | zip(self.grad_theta, self.theta)) 51 | 52 | self.sess.run(tf.global_variables_initializer()) 53 | 54 | def _build_network(self, name): 55 | """build forward netword with 3 fully connected layer 56 | 57 | Args: 58 | name (string): variable scope 59 | 60 | Returns: 61 | input_s: features of states 62 | reward: reward of states 63 | theta: trainable parameters 64 | input_onehot= gender and age parameters 65 | """ 66 | input_s = tf.placeholder(tf.float32, [None, self.n_input]) 67 | input_onehot = tf.placeholder(tf.int32, [None, len(self.genderAge)]) 68 | 69 | embedding_matrix = tf.get_variable("embedding_matrix", [self.embedding_dim]) 70 | 71 | with tf.variable_scope(name): 72 | fc1 = tf_utils.fc(input_s, self.n_h1, scope="fc1", activation_fn=tf.nn.elu, 73 | initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN")) 74 | fc2 = tf_utils.fc(fc1, self.n_h2, scope="fc2", activation_fn=tf.nn.elu, 75 | initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN")) 76 | 77 | embedded_input_onehot = tf.nn.embedding_lookup(embedding_matrix, input_onehot) 78 | fc3 = tf_utils.fc(embedded_input_onehot, self.n_h1, scope="fc3", activation_fn=tf.nn.elu, 79 | initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN")) 80 | fc4 = tf_utils.fc(fc3, self.n_h2, scope="fc4", activation_fn=tf.nn.elu, 81 | initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN")) 82 | 83 | fc_contact = tf.concat(1, [fc2, fc4]) 84 | 85 | fc_final = tf_utils.fc(fc_contact, self.n_h2, scope="fc_final", activation_fn=tf.nn.elu, 86 | initializer=tf.contrib.layers.variance_scaling_initializer(mode="FAN_IN")) 87 | 88 | reward = tf_utils.fc(fc_final, 1, scope="reward") 89 | 90 | theta = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=name) 91 | return input_s, reward, theta, input_onehot 92 | 93 | def get_theta(self): 94 | return self.sess.run(self.theta) 95 | 96 | def get_rewards(self, states, oh): 97 | rewards = self.sess.run(self.reward, feed_dict={ 98 | self.input_s: states, self.input_onehot: oh}) 99 | return rewards 100 | 101 | def apply_grads(self, feat_map, grad_r,oh): 102 | grad_r = np.reshape(grad_r, [-1, 1]) 103 | feat_map = np.reshape(feat_map, [-1, self.n_input]) 104 | _, grad_theta, l2_loss, grad_norms = self.sess.run([self.optimize, self.grad_theta, self.l2_loss, self.grad_norms], 105 | feed_dict={self.grad_reward: grad_r, self.input_s: feat_map, self.input_onehot: oh}) 106 | return grad_theta, l2_loss, grad_norms 107 | 108 | def restoreGraph(self, model_file, 
model_name): 109 | saver = tf.train.Saver() 110 | save_path = model_file+".\\"+model_name 111 | saver.restore(self.sess, save_path) 112 | 113 | 114 | def expectStateVisitFreq(P_a, gamma, trajs, fnid_idx, policy, deterministic=True): 115 | """ 116 | compute the expected states visition frequency p(s| theta, T) 117 | using dynamic programming 118 | 119 | inputs: 120 | P_a NxNxN_ACTIONS matrix - transition dynamics 121 | gamma float - discount factor 122 | trajs list of Steps - collected from expert 123 | fnid_idx {fnid:index} 124 | policy Nx1 vector (or NxN_ACTIONS if deterministic=False) - policy 125 | 126 | returns: 127 | p Nx1 vector - state visitation frequencies 128 | """ 129 | N_STATES, _, N_ACTIONS = np.shape(P_a) 130 | T = [] 131 | for traj in trajs: 132 | T.append(len(traj)) 133 | 134 | avg_T = int(np.mean(T)) 135 | # mu[s, t] is the prob of visiting state s at step t,get the 136 | mu = np.zeros([N_STATES, avg_T]) 137 | 138 | for traj in trajs: 139 | index = fnid_idx[traj[0].state] 140 | mu[index, 0] += 1 141 | mu[:, 0] = mu[:, 0]/len(trajs) 142 | 143 | for s in range(N_STATES): 144 | for t in range(avg_T-1): 145 | if deterministic: 146 | mu[s, t+1] = sum([mu[pre_s, t]*P_a[pre_s, s, int(policy[pre_s])] 147 | for pre_s in range(N_STATES)]) 148 | else: 149 | mu[s, t+1] = sum([sum([mu[pre_s, t]*P_a[pre_s, s, a1]*policy[pre_s, a1] for a1 in range(N_ACTIONS)]) 150 | for pre_s in range(N_STATES)]) 151 | p = np.sum(mu, 1) 152 | return p 153 | 154 | 155 | def stateVisitFreq(trajs, fnid_idx, n_states): 156 | """ 157 | compute state visitation frequences from demonstrations 158 | 159 | input: 160 | trajs list of list of Steps - collected from expert 161 | fnid_idx {fnid:index} 162 | n_states number of states 163 | returns: 164 | p Nx1 vector - state visitation frequences 165 | """ 166 | 167 | p = np.zeros(n_states) 168 | for traj in trajs: 169 | for step in traj: 170 | if step.state not in fnid_idx: 171 | continue 172 | idx = fnid_idx[step.state] 173 | p[idx] += 1 174 | p = p/len(trajs) 175 | return p 176 | 177 | def deepMaxEntIRL2(nn_r,traj_file,feature_map_file, feat_map, P_a, gamma, trajs, lr, fnid_idx, idx_fnid, gpd_file, genderAge, restore): 178 | """ 179 | Maximum Entropy Inverse Reinforcement Learning (Maxent IRL), with personalized features 180 | """ 181 | N_STATES, _, N_ACTIONS = np.shape(P_a) 182 | nn_r.genderAge = genderAge 183 | 184 | # restore graph 185 | model_file = './model' 186 | model_name = 'realworld' 187 | if restore and os.path.exists(os.path.join(model_file, model_name+'.meta')): 188 | print('restore graph from ckpt file') 189 | nn_r.restoreGraph(model_file, model_name) 190 | 191 | # find state visitation frequencies using demonstrations 192 | mu_D = stateVisitFreq(trajs, fnid_idx, N_STATES) 193 | # optimize the neural network by the difference between expected svf(state visit frequency) and real svf 194 | oh = [genderAge for _ in range(feat_map.shape[0])] 195 | rewards = normalize(nn_r.get_rewards(feat_map, oh)) 196 | print("begin value iteration") 197 | values, policy = value_iteration.value_iteration( 198 | P_a, rewards, gamma, error=5, deterministic=True) 199 | np.save("./model/policy_realworld.npy", policy) 200 | print("The calculation of value and policy is finished!") 201 | # add traj log likelihood 202 | real_traj = trajFromExpert(traj_file) 203 | llrs = [trajLogLikelihood(feature_map_file, traj, trajFromPolicyFile(policy, fnid_idx, int(traj[0]), len(traj))) for traj in real_traj] 204 | print("The log-likelihood ratio of the generated trajectory to the true 
177 | def deepMaxEntIRL2(nn_r,traj_file,feature_map_file, feat_map, P_a, gamma, trajs, lr, fnid_idx, idx_fnid, gpd_file, genderAge, restore): 178 | """ 179 | Maximum Entropy Inverse Reinforcement Learning (MaxEnt IRL), with personalized features 180 | """ 181 | N_STATES, _, N_ACTIONS = np.shape(P_a) 182 | nn_r.genderAge = genderAge 183 | 184 | # restore graph 185 | model_file = './model' 186 | model_name = 'realworld' 187 | if restore and os.path.exists(os.path.join(model_file, model_name+'.meta')): 188 | print('restore graph from ckpt file') 189 | nn_r.restoreGraph(model_file, model_name) 190 | 191 | # find state visitation frequencies using demonstrations 192 | mu_D = stateVisitFreq(trajs, fnid_idx, N_STATES) 193 | # optimize the neural network by the difference between the expected SVF (state visitation frequency) and the demonstrated SVF 194 | oh = [genderAge for _ in range(feat_map.shape[0])] 195 | rewards = normalize(nn_r.get_rewards(feat_map, oh)) 196 | print("begin value iteration") 197 | values, policy = value_iteration.value_iteration( 198 | P_a, rewards, gamma, error=5, deterministic=True) 199 | np.save("./model/policy_realworld.npy", policy) 200 | print("The calculation of value and policy is finished!") 201 | # trajectory log-likelihood 202 | real_traj = trajFromExpert(traj_file) 203 | llrs = [trajLogLikelihood(feature_map_file, traj, trajFromPolicyFile(policy, fnid_idx, int(traj[0]), len(traj))) for traj in real_traj] 204 | print("The log-likelihood ratio of the generated trajectory to the true trajectory is {}".format(np.nanmean(llrs))) 205 | # compute expected svf 206 | mu_exp = expectStateVisitFreq( 207 | P_a, gamma, trajs, fnid_idx, policy, deterministic=True) 208 | # compute gradients on rewards: 209 | grad_r = mu_D - mu_exp 210 | print("visit frequency difference is {}".format(np.mean(grad_r))) 211 | # apply gradients to the neural network 212 | _,_,_ = nn_r.apply_grads(feat_map, grad_r, oh) 213 | 214 | # Store model weights 215 | tf.train.Saver().save(nn_r.sess, './model/realworld') # ckpt 216 | np.save('./model/policy_realworld.npy', policy) # npy 217 | 218 | return normalize(rewards), nn_r -------------------------------------------------------------------------------- /demo_deepirl_be.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import pandas as pd 4 | from collections import namedtuple 5 | import os 6 | from realGrid import real_grid 7 | from deep_irl_be import * 8 | 9 | PARSER = argparse.ArgumentParser( 10 | description="arguments of the deep maximum entropy inverse reinforcement learning algorithm") 11 | PARSER.add_argument('-g', '--gamma', default=0.9, 12 | type=float, help='discount factor') 13 | PARSER.add_argument('-a', '--act_random', default=0.2, 14 | type=float, help='probability of acting randomly') 15 | PARSER.add_argument('-lr', '--learning_rate', default=0.02, 16 | type=float, help='learning rate') 17 | PARSER.add_argument('-ni', '--n_iters', default=10, 18 | type=int, help='number of iterations') 19 | PARSER.add_argument('--restore', dest='restore', action='store_true', 20 | help='restore graph from an existing checkpoint file') 21 | ARGS = PARSER.parse_args() 22 | print(ARGS) 23 | 24 | GAMMA = ARGS.gamma 25 | ACT_RAND = ARGS.act_random 26 | LEARNING_RATE = ARGS.learning_rate 27 | N_ITERS = ARGS.n_iters 28 | RESTORE = ARGS.restore 29 | Step=namedtuple('Step',['state','action']) 30 | 31 | # get the transition probability between states 32 | feature_map_file = './data/nanshan_tfidf_be.csv' 33 | feature_map_excel = pd.read_csv(feature_map_file) 34 | states_list = list(feature_map_excel['fnid']) 35 | fnid_idx = {} 36 | idx_fnid = {} 37 | for i in range(len(states_list)): 38 | fnid_idx.update({states_list[i]: i}) 39 | idx_fnid.update({i: states_list[i]}) 40 | grid = real_grid.RealGridWorld(fnid_idx, idx_fnid, 1-ACT_RAND) 41 | p_a = grid.get_transition_mat() 42 | # get feature map of state 43 | feat_map = feature_map_excel.iloc[:, 1:] 44 | index_fnid = feature_map_excel['fnid'] 45 | feat_map = np.array(feat_map) 46 | 47 | # train the deep IRL model on built-environment (BE) features, without individual characteristics 48 | gpd_file = './data/nanshan_grid.shp' 49 | route_file_path = './data/routes_states' 50 | 51 | trajs = [] 52 | for f in os.listdir(route_file_path): 53 | traj = np.load(route_file_path+'/'+f,allow_pickle=True) 54 | traj = traj.tolist() 55 | trajs.extend(traj) 56 | 57 | rewards = deepMaxEntIRL(feat_map, p_a, GAMMA, trajs, 58 | LEARNING_RATE, N_ITERS, fnid_idx, idx_fnid, gpd_file, RESTORE)
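Before training, it is worth sanity-checking the transition tensor the script just built — a short sketch reusing `p_a` and `fnid_idx` from the lines above:

```python
import numpy as np

# P_a[s0, s1, a] = probability of landing in s1 after action a in s0,
# so each (state, action) slice must be a probability distribution.
assert p_a.shape == (len(fnid_idx), len(fnid_idx), 5)   # 5 actions: r, l, d, u, stay
assert np.allclose(p_a.sum(axis=1), 1.0), "each (state, action) row must sum to 1"
```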
-------------------------------------------------------------------------------- /demo_deepirl_realworld.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | import pandas as pd 5 | from collections import namedtuple 6 | 7 | from realGrid import real_grid 8 | from realGrid import value_iteration 9 | from trajectory import * 10 | from deep_irl_realworld import * 11 | 12 | PARSER = argparse.ArgumentParser( 13 | description="arguments of the deep maximum entropy inverse reinforcement learning algorithm") 14 | PARSER.add_argument('-g', '--gamma', default=0.9, 15 | type=float, help='discount factor') 16 | PARSER.add_argument('-a', '--act_random', default=0.2, 17 | type=float, help='probability of acting randomly') 18 | PARSER.add_argument('-lr', '--learning_rate', default=0.02, 19 | type=float, help='learning rate') 20 | PARSER.add_argument('-ni', '--n_epochs', default=100, 21 | type=int, help='number of epochs') 22 | PARSER.add_argument('--restore', dest='restore', action='store_true', 23 | help='restore graph from an existing checkpoint file') 24 | ARGS = PARSER.parse_args() 25 | print(ARGS) 26 | 27 | GAMMA = ARGS.gamma 28 | ACT_RAND = ARGS.act_random 29 | LEARNING_RATE = ARGS.learning_rate 30 | N_EPOCHS = ARGS.n_epochs 31 | RESTORE = ARGS.restore 32 | 33 | # get the transition probability between states 34 | feature_map_file = './data/ns_sf.csv' # lower-case filename, matching the file shipped in ./data 35 | feature_map_excel = pd.read_csv(feature_map_file) 36 | states_list = list(feature_map_excel['fnid']) 37 | fnid_idx = {} 38 | idx_fnid = {} 39 | for i in range(len(states_list)): 40 | fnid_idx.update({states_list[i]: i}) 41 | idx_fnid.update({i: states_list[i]}) 42 | grid = real_grid.RealGridWorld(fnid_idx, idx_fnid, 1-ACT_RAND) 43 | p_a = grid.get_transition_mat() 44 | # get feature map of state 45 | feat_map = feature_map_excel.iloc[:, 1:] 46 | index_fnid = feature_map_excel['fnid'] 47 | feat_map = np.array(feat_map) 48 | 49 | # train the deep-irl 50 | gpd_file = './data/nanshan_grid.shp' 51 | route_file_path = './data/routes_states' 52 | 53 | nn_r = DeepIRLFC(feat_map.shape[1], LEARNING_RATE, [ 54 | 0, 0, 0, 0, 0], 40, 30) # initialize the deep network; the placeholder genderAge is overwritten per trajectory file below 55 | 56 | pre_reward= np.zeros((2,3,feat_map.shape[0]),dtype=float) 57 | 58 | T0 = time.time() 59 | now_time = datetime.datetime.now() 60 | Step=namedtuple('Step',['state','action']) 61 | print('this training loop starts at {}'.format(now_time)) 62 | for epoch in range(N_EPOCHS): 63 | T1 = time.time() 64 | this_reward=np.zeros((2,3,feat_map.shape[0])) 65 | for f in os.listdir(route_file_path): 66 | genderAge = [0]*5 67 | gender,age=int(f[0]),int(f[2]) 68 | genderAge[gender], genderAge[age+2] = 1,1 69 | trajs = np.load(route_file_path+'/'+f,allow_pickle=True) 70 | trajs = trajs.tolist() 71 | print("load trajectory done!") 72 | rewards, nn_r = deepMaxEntIRL2(nn_r,route_file_path+'/'+f,feature_map_file, feat_map, p_a, GAMMA, trajs, 73 | LEARNING_RATE, fnid_idx, idx_fnid, gpd_file, genderAge, RESTORE) 74 | this_reward[gender,age,:]=np.array(rewards).reshape(feat_map.shape[0]) 75 | 76 | T2 = time.time() 77 | print("this epoch lasts {:.2f}s, the loop lasts {:.2f}s, the {}th epoch ends at {}".format( 78 | T2-T1, T2-T0, epoch, datetime.datetime.now())) 79 | 80 | reward_difference=np.mean(this_reward-pre_reward) 81 | print("the current reward difference is {}".format(reward_difference)) 82 | if abs(reward_difference) <= 0.001: 83 | print('the difference of reward is less than 0.001, then break the loop') 84 | break 85 | pre_reward=this_reward
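How the demographic one-hot is derived from a trajectory filename in the loop above — a worked example. The `g_a_states_tuple.npy` naming is the repo's own convention (first digit gender, second digit age group, per the encoding in trajectory.ipynb):

```python
f = '1_2_states_tuple.npy'        # gender 1 (female), age group 2 (old)
genderAge = [0] * 5
gender, age = int(f[0]), int(f[2])
genderAge[gender], genderAge[age + 2] = 1, 1   # slots 0-1: gender, slots 2-4: age
print(genderAge)                  # [0, 1, 0, 0, 1]
```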
-------------------------------------------------------------------------------- /demo_recursive_logit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | import pandas as pd 5 | import os 6 | from collections import namedtuple 7 | 8 | import img_utils 9 | from realGrid import real_grid 10 | from realGrid import value_iteration 11 | from trajectory import * 12 | from recursive_logit import * 13 | 14 | 15 | 16 | PARSER = argparse.ArgumentParser( 17 | description="arguments of the recursive logit algorithm") 18 | PARSER.add_argument('-g', '--gamma', default=0.9, 19 | type=float, help='discount factor') 20 | PARSER.add_argument('-a', '--act_random', default=0.2, 21 | type=float, help='probability of acting randomly') 22 | PARSER.add_argument('-lr', '--learning_rate', default=0.02, 23 | type=float, help='learning rate') 24 | PARSER.add_argument('-ni', '--n_iters', default=20, 25 | type=int, help='number of iterations') 26 | PARSER.add_argument('--restore', dest='restore', action='store_true', 27 | help='restore graph from an existing checkpoint file') 28 | ARGS = PARSER.parse_args() 29 | 30 | 31 | GAMMA = ARGS.gamma 32 | ACT_RAND = ARGS.act_random 33 | LEARNING_RATE = ARGS.learning_rate 34 | N_ITERS = ARGS.n_iters 35 | RESTORE = ARGS.restore 36 | 37 | # get the transition probability between states 38 | feature_map_file = './data/ns_sf.csv' # lower-case filename, matching the file shipped in ./data 39 | feature_map_excel = pd.read_csv(feature_map_file) 40 | states_list = list(feature_map_excel['fnid']) 41 | fnid_idx = {} 42 | idx_fnid = {} 43 | for i in range(len(states_list)): 44 | fnid_idx.update({states_list[i]: i}) 45 | idx_fnid.update({i: states_list[i]}) 46 | grid = real_grid.RealGridWorld(fnid_idx, idx_fnid, 1-ACT_RAND) 47 | p_a = grid.get_transition_mat() 48 | # get feature map of state 49 | feat_map = feature_map_excel.iloc[:, 1:] 50 | feat_map=np.array(feat_map) 51 | 52 | route_file_path = './data/routes_states' 53 | Step=namedtuple('Step',['state','action']) 54 | # run trajectory.py first to get the state-action pairs of real-world trajectories 55 | for f in os.listdir(route_file_path): 56 | genderAge = [0]*5 57 | gender,age=int(f[0]),int(f[2]) 58 | genderAge[gender], genderAge[age+2] = 1,1 59 | 60 | trajs = np.load(route_file_path+'/'+f,allow_pickle=True) 61 | trajs = trajs.tolist() 62 | rewards_maxent = recursiveLogit(feat_map, genderAge, p_a, GAMMA, trajs, fnid_idx,LEARNING_RATE, N_ITERS) -------------------------------------------------------------------------------- /dnn_psl.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | from collections import namedtuple 4 | import pandas as pd 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from sklearn.model_selection import train_test_split 10 | from sklearn.preprocessing import StandardScaler 11 | 12 | Step=namedtuple('Step',['state','action']) 13 | 14 | def nShortestPaths(G, origin, destination, num_paths_to_find=10): 15 | origin_node = [(node, data) for node, data in G.nodes(data=True) if data.get('fnid') == origin] 16 | destination_node = [(node, data) for node, data in G.nodes(data=True) if data.get('fnid') == destination] 17 | 18 | if not origin_node or not destination_node: 19 | print("Origin or destination nodes not found.") 20 | return 21 | 22 | origin = origin_node[0][0] 23 | destination = destination_node[0][0] 24 | 25 | shortest_paths = [] 26 | for _ in range(num_paths_to_find): 27 | shortest_path = nx.shortest_path(G, source=origin, target=destination) 28 | shortest_paths.append(shortest_path) 29 | 30 | for u, v in zip(shortest_path[:-1], shortest_path[1:]): 31 | G.remove_edge(u, v) # delete this path's edges so the next search yields a different (edge-disjoint) route 32 | 33 | paths = [] 34 | for path in shortest_paths: 35 | fnid_values = [G.nodes[node]['fnid'] for node in path] 36 | paths.append(fnid_values) 37 | 38 | return paths 39 | 
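Note that `nShortestPaths` deletes each found path's edges before searching again, so it returns up to `num_paths_to_find` *edge-disjoint* alternatives rather than the k shortest routes. If the true k shortest simple paths are wanted instead, networkx ships Yen's algorithm — a sketch:

```python
from itertools import islice
import networkx as nx

def kShortestPaths(G, source, target, k=10):
    # nx.shortest_simple_paths yields simple paths in order of increasing length
    return list(islice(nx.shortest_simple_paths(G, source, target), k))
```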
40 | class DNNPSLModel(nn.Module): 41 | def __init__(self, input_size): 42 | super(DNNPSLModel, self).__init__() 43 | self.fc1 = nn.Linear(input_size, 128) 44 | self.relu1 = nn.ReLU() 45 | self.fc2 = nn.Linear(128, 64) 46 | self.relu2 = nn.ReLU() 47 | self.fc3 = nn.Linear(64, 1) # one output neuron; forward() applies a sigmoid, so the model emits a probability rather than a raw logit 48 | 49 | def forward(self, x): 50 | x = self.fc1(x) 51 | x = self.relu1(x) 52 | x = self.fc2(x) 53 | x = self.relu2(x) 54 | x = self.fc3(x) 55 | return torch.sigmoid(x) 56 | 57 | 58 | if __name__ == "__main__": 59 | route_file_path = './data/routes_states' 60 | graphml_file = "./data/nanshan_network.graphml" 61 | 62 | # Instantiate your model and define the loss function 63 | input_size = 35 # feature dimension: 30 zone features plus the 5-dim demographic one-hot 64 | model = DNNPSLModel(input_size) 65 | criterion = nn.BCELoss() # Binary Cross-Entropy Loss 66 | optimizer = optim.SGD(model.parameters(), lr=0.01) # Define an optimizer (e.g., SGD or Adam) 67 | 68 | while True: # runs until interrupted; the break below only ends the current file's epoch loop 69 | for f in os.listdir(route_file_path): 70 | genderAge = [0]*5 71 | gender,age=int(f[0]),int(f[2]) 72 | genderAge[gender], genderAge[age+2] = 1,1 73 | state_action = np.load(route_file_path+'/'+f,allow_pickle=True) 74 | selected_traj = np.random.choice(state_action, size=1, replace=False)[0] 75 | origin, destination = int(selected_traj[0].state), int(selected_traj[-1].state) 76 | 77 | G=nx.read_graphml(graphml_file) 78 | try: 79 | paths=nShortestPaths(G, origin, destination, num_paths_to_find=10) 80 | except Exception: 81 | print("no shortest path exists between origin and destination") 82 | continue 83 | # add true trajectory 84 | selected_traj = [t.state for t in selected_traj] 85 | paths.append(selected_traj) 86 | 87 | # path features 88 | path_features=[] 89 | feature_map_file="./data/ns_sf.csv" 90 | df = pd.read_csv(feature_map_file) 91 | for path in paths: 92 | filtered_rows = df[df['fnid'].isin(path)] 93 | # column 0 is fnid, so features start at column 1 94 | path_features.append(filtered_rows.iloc[:, 1:].values.tolist()) 95 | column_means = [np.mean(np.array(inner_list).T, axis=1) for inner_list in path_features] 96 | 97 | # Convert the lists to NumPy arrays 98 | column_means = np.array(column_means) 99 | genderAge = np.array(genderAge) 100 | # Stack genderAge horizontally to each row in column_means 101 | column_means = np.hstack((column_means, np.tile(genderAge, (len(column_means), 1)))) 102 | # Convert path features to tensors 103 | X = torch.tensor(column_means, dtype=torch.float32) 104 | # Create labels (y) 105 | y = torch.zeros(len(paths)) 106 | y[-1] = 1 # True path has a label of 1, others are 0 107 | 108 | # Training loop 109 | num_epochs = 1000 110 | for epoch in range(num_epochs): 111 | outputs = model(X) 112 | loss = criterion(outputs, y.unsqueeze(1)) # Ensure the shape matches 113 | 114 | # Backpropagation and optimization 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | 119 | # Print the loss for monitoring 120 | print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}') 121 | 122 | if loss.item() < 0.001: 123 | print("Training completed, end the training loop!") 124 | torch.save(model.state_dict(), './model/dnn psl/model_weights.pth') 125 | break --------------------------------------------------------------------------------
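How a candidate path becomes one input row for `DNNPSLModel` — a toy illustration. The feature columns (`poi`, `road`) and their values are made up; the real features come from ns_sf.csv:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'fnid': [10, 11, 12], 'poi': [0.2, 0.4, 0.6], 'road': [1.0, 0.0, 1.0]})
path = [10, 12]                                        # fnids crossed by the path
cells = df[df['fnid'].isin(path)].iloc[:, 1:].values   # drop the fnid column
mean_feat = cells.mean(axis=0)                         # average over the path's cells
gender_age = np.array([1, 0, 1, 0, 0])
x = np.hstack([mean_feat, gender_age])                 # one input row for the model
print(x)   # [0.4 1.  1.  0.  1.  0.  0.]
```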
/essay/MEDIRL-IC.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/essay/MEDIRL-IC.pdf -------------------------------------------------------------------------------- /evaluate_traj_IRL.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import geopandas as gpd 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from deep_irl_realworld import * 6 | import deep_irl_be 7 | from test_dirl_realworld import readFeatureMap 8 | from A_star.A_star_traj import * 9 | from DTW.dtw import DTW 10 | 11 | def getRewardFileGA(feat_map, genderAge): 12 | """ 13 | get the reward map for a given demographic group 14 | 15 | Args: 16 | feat_map (N*M 2d matrix): feature map of the zone 17 | genderAge (list): 5-dim one-hot of gender and age group 18 | Returns: 19 | rewards (N*1): reward of each grid 20 | """ 21 | nn_r = DeepIRLFC(feat_map.shape[1], 0.02, genderAge, 40, 30) 22 | model_file = './model' 23 | model_name = 'realworld' 24 | if os.path.exists(os.path.join(model_file, model_name+'.meta')): 25 | print('restore graph from ckpt file') 26 | nn_r.restoreGraph(model_file, model_name) 27 | else: 28 | print("no ckpt file found") 29 | return 30 | 31 | oh = [genderAge for _ in range(feat_map.shape[0])] 32 | rewards = normalize(nn_r.get_rewards(feat_map, oh)) 33 | return rewards 34 | 35 | def getRewardFileBE(feat_map): 36 | nn_r = deep_irl_be.DeepIRLFC(feat_map.shape[1], 0.02, 40, 30) 37 | model_file = './model/be' 38 | model_name = 'realworld' 39 | if os.path.exists(os.path.join(model_file, model_name+'.meta')): 40 | print('restore graph from ckpt file') 41 | nn_r.restoreGraph(model_file, model_name) 42 | else: 43 | print("no ckpt file found") 44 | return 45 | 46 | rewards = normalize(nn_r.get_rewards(feat_map)) 47 | return rewards 48 | 49 | def getRewardFileRLogit(feat_map,genderAge): 50 | genderAge=np.array(genderAge) 51 | genderAge = np.tile(genderAge, (feat_map.shape[0], 1)) 52 | feat_map = np.concatenate((genderAge, feat_map), axis=1) 53 | 54 | theta = np.load('./model/rlogit_realworld.npy') 55 | rewards = normalize(np.dot(feat_map, theta)) 56 | return rewards 57 | 58 | def trajContrast(route, real_map, gpd_file,if_show): 59 | """ 60 | Comparison of a real and a generated trajectory 61 | Args: 62 | route : real trajectory as a list of fnids 63 | real_map : Map instance (A* search space) 64 | gpd_file : grid shapefile path 65 | if_show : whether to display the plot 66 | """ 67 | start_fnid = route[0] 68 | end_fnid = route[-1] 69 | 70 | route_generate = astar(real_map, start_fnid, end_fnid) 71 | 72 | gdf = gpd.read_file(gpd_file) 73 | ColNames = gdf.columns 74 | route_df = pd.DataFrame(columns=ColNames) 75 | route_generate_df = pd.DataFrame(columns=ColNames) 76 | 77 | for fnid in route: 78 | idx = gdf[(gdf['fnid'] == fnid)].index 79 | route_df = route_df.append(gdf.iloc[idx, :], ignore_index=True) # pandas<2.0 API 80 | for fnid in route_generate: 81 | idx = gdf[(gdf['fnid'] == fnid)].index 82 | route_generate_df = route_generate_df.append( 83 | gdf.iloc[idx, :], ignore_index=True) 84 | route_generate_gdf = gpd.GeoDataFrame( 85 | route_generate_df, geometry="geometry") 86 | route_gdf = gpd.GeoDataFrame( 87 | route_df, geometry="geometry") 88 | 89 | plt.rcParams["font.family"] = "Times New Roman" 90 | _, ax = plt.subplots() 91 | ax=gdf.plot(ax=ax, color='darkgray', label='map') 92 | ax=route_gdf.plot(ax=ax, color='lightsalmon',label = "real trajectory") 93 | ax=route_generate_gdf.plot(ax=ax, color='lightskyblue',label = "generated trajectory") 94 | ax.axis('off') 95 | plt.title('trajectory contrast') 96 | if if_show: 97 | plt.show() 98 | 
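The evaluation below scores a generated route by its DTW distance to the real one, normalized by the mean of the two lengths — a sketch reusing the repo's DTW class (`DTW(a, b).dtw()` returns the accumulated-cost matrix, whose last entry is the total alignment cost):

```python
from DTW.dtw import DTW

real      = [1000, 1001, 1002, 1359]   # made-up fnid sequences
generated = [1000, 1001, 1358, 1359]
D = DTW(real, generated).dtw()
normalized = D[-1][-1] / ((len(real) + len(generated)) / 2)
print(normalized)   # 0 only when the two routes align perfectly
```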
99 | def trajEvaluate(routes_file,real_map,gpd_file,save_picture=False): 100 | routes = np.load(routes_file) 101 | all_length=0 102 | list_length=[] 103 | for i,route in enumerate(routes): 104 | route = [r.state for r in route] 105 | # generate trajectory 106 | start_fnid=route[0] 107 | end_fnid = route[-1] 108 | route_generate = astar(real_map, start_fnid, end_fnid) 109 | # calculate dtw distance between trajectories 110 | traj_dtw=DTW(route,route_generate) 111 | D=traj_dtw.dtw() 112 | length=(len(route_generate)+len(route))/2 113 | dtw_length=D[-1][-1]/length 114 | # save figure 115 | if save_picture: 116 | trajContrast(route, real_map, gpd_file, False) 117 | plt.title('DTW Length is {}'.format(dtw_length)) 118 | plt.savefig('./img/traj_distance/{}.png'.format(i)) 119 | print(dtw_length) 120 | all_length+=dtw_length 121 | list_length.append(dtw_length) 122 | 123 | avg_length=all_length/len(routes) 124 | return avg_length,list_length 125 | 126 | if __name__ == "__main__": 127 | feat_map_file = './data/ns_sf.csv' 128 | gpd_file = './data/nanshan_grid.shp' 129 | routes_file = './data/routes_states/0_0_states_tuple.npy' 130 | genderAge=[1,0,1,0,0] 131 | 132 | feat_map, fnid_idx, idx_fnid, states_list = readFeatureMap(feat_map_file) 133 | 134 | # NOTE: get different reward 135 | # rewards = getRewardFileGA(feat_map, genderAge).tolist() 136 | # rewards = getRewardFileBE(feat_map).tolist() 137 | # cost = [1-r[0] for r in rewards] 138 | 139 | rewards = getRewardFileRLogit(feat_map,genderAge).tolist() 140 | cost = [1-r for r in rewards] 141 | 142 | real_map = Map(states_list, fnid_idx, idx_fnid, cost*10) # note: cost is a list, so cost*10 repeats it rather than scaling; A* only reads the first len(states_list) entries 143 | 144 | # # Trajectory comparison 145 | # routes = np.load(routes_file) 146 | # route_idx = np.random.randint(0, len(routes)-1) 147 | # route = routes[route_idx] 148 | # trajContrast(route, real_map, gpd_file, True) 149 | 150 | # trajectory evaluation 151 | avg_length,list_length=trajEvaluate(routes_file,real_map,gpd_file) 152 | 153 | np.save('./data/evaluate/traj_evaluate_rl.npy',list_length) -------------------------------------------------------------------------------- /evaluate_traj_psl.py: -------------------------------------------------------------------------------- 1 | from dnn_psl import * 2 | from A_star.A_star_traj import * 3 | from DTW.dtw import DTW 4 | 5 | def readFeatureMap(feature_map_file): 6 | """ 7 | read feature map from csv feature file 8 | 9 | Args: 10 | feature_map_file : path of feature file 11 | 12 | Returns: 13 | feat_map: numpy feature map 14 | fnid_idx:{fnid:idx} 15 | idx_fnid:{idx:fnid} 16 | """ 17 | feature_map_excel = pd.read_csv(feature_map_file) 18 | states_list = list(feature_map_excel['fnid']) 19 | fnid_idx = {} 20 | idx_fnid = {} 21 | for i in range(len(states_list)): 22 | fnid_idx.update({states_list[i]: i}) 23 | idx_fnid.update({i: states_list[i]}) 24 | states_list = list(feature_map_excel['fnid']) 25 | # get feature map of state 26 | feat_map = feature_map_excel.iloc[:, 1:] 27 | feat_map = np.array(feat_map) 28 | return feat_map, fnid_idx, idx_fnid, states_list 29 | 30 | # load model 31 | loaded_model = DNNPSLModel(35) 32 | loaded_model.load_state_dict(torch.load('./model/dnn psl/model_weights.pth')) 33 | loaded_model.eval() 34 | # load grid network 35 | graphml_file = "./data/nanshan_network.graphml" 36 | 37 | # load state-action of trajectory 38 | state_action = np.load('./data/routes_states/0_0_states_tuple.npy',allow_pickle=True) 39 | feature_map_file="./data/ns_sf.csv" 40 | df = pd.read_csv(feature_map_file) 41 | # load feature map 42 | feat_map_file = './data/ns_sf.csv' 43 | feat_map, fnid_idx, idx_fnid, states_list = readFeatureMap(feat_map_file) 44 | # result 45 | all_length = 0 46 | list_length = [] 47 | 48 | for selected_traj in state_action: 49 | origin, destination = int(selected_traj[0].state), int(selected_traj[-1].state) 50 | 
G=nx.read_graphml(graphml_file) 51 | try: 52 | paths=nShortestPaths(G, origin, destination, num_paths_to_find=10) 53 | except Exception: 54 | continue 55 | # path features 56 | path_features=[] 57 | for path in paths: 58 | filtered_rows = df[df['fnid'].isin(path)] 59 | # column 0 is fnid, so features start at column 1 60 | path_features.append(filtered_rows.iloc[:, 1:].values.tolist()) 61 | column_means = [np.mean(np.array(inner_list).T, axis=1) for inner_list in path_features] 62 | column_means = np.array(column_means) 63 | genderAge = [1,0,1,0,0] 64 | genderAge = np.array(genderAge) 65 | column_means = np.hstack((column_means, np.tile(genderAge, (len(column_means), 1)))) 66 | X = torch.tensor(column_means, dtype=torch.float32) 67 | 68 | outputs = loaded_model(X) 69 | 70 | true_traj = [t.state for t in selected_traj] 71 | generated_traj = paths[np.argmax(outputs.detach().numpy())] 72 | # trajectory evaluation 73 | traj_dtw=DTW(true_traj,generated_traj) 74 | D=traj_dtw.dtw() 75 | length=(len(true_traj)+len(generated_traj))/2 76 | dtw_length=D[-1][-1]/length 77 | print(dtw_length) 78 | all_length+=dtw_length 79 | list_length.append(dtw_length) 80 | 81 | avg_length=all_length/len(state_action) 82 | np.save('./data/evaluate/traj_evaluate_psl.npy',list_length) -------------------------------------------------------------------------------- /img/MEDIRL_IC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/MEDIRL_IC.png -------------------------------------------------------------------------------- /img/box_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/box_comparison.png -------------------------------------------------------------------------------- /img/causal_directed_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/causal_directed_graph.png -------------------------------------------------------------------------------- /img/causal_strength.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/causal_strength.png -------------------------------------------------------------------------------- /img/dta.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/dta.png -------------------------------------------------------------------------------- /img/reward_and_likelihood_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/reward_and_likelihood_plot.png -------------------------------------------------------------------------------- /img/reward_map_based_on_IC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/reward_map_based_on_IC.png -------------------------------------------------------------------------------- /img/sz_reward_map.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/img/sz_reward_map.png -------------------------------------------------------------------------------- /img_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import geopandas as gpd 4 | import seaborn as sns 5 | 6 | def show_img(img): 7 | print(img.shape, img.dtype) 8 | plt.imshow(img[:, :, 0]) 9 | plt.ion() 10 | plt.show() 11 | 12 | 13 | def heatmap2d(hm_mat, title='', block=False, fig_num=1, text=True): 14 | """ 15 | Display heatmap 16 | input: 17 | hm_mat: mxn 2d np array 18 | """ 19 | print('map shape: {}, data type: {}'.format(hm_mat.shape, hm_mat.dtype)) 20 | 21 | if block: 22 | plt.figure(fig_num) 23 | plt.clf() 24 | 25 | plt.imshow(hm_mat, interpolation='nearest') 26 | plt.title(title) 27 | plt.colorbar() 28 | 29 | if text: 30 | for y in range(hm_mat.shape[0]): 31 | for x in range(hm_mat.shape[1]): 32 | plt.text(x, y, '%.1f' % hm_mat[y, x], 33 | horizontalalignment='center', 34 | verticalalignment='center', 35 | ) 36 | 37 | if block: 38 | plt.ion() 39 | print('press enter to continue') 40 | plt.show() 41 | plt.waitforbuttonpress() 42 | 43 | 44 | def heatmap3d(hm_mat, title=''): 45 | from mpl_toolkits.mplot3d import Axes3D 46 | import matplotlib.pyplot as plt 47 | import numpy as np 48 | 49 | data_2d = hm_mat 50 | 51 | data_array = np.array(data_2d) 52 | 53 | fig = plt.figure() 54 | ax = fig.add_subplot(111, projection='3d') 55 | plt.title(title) 56 | 57 | x_data, y_data = np.meshgrid(np.arange(data_array.shape[1]), 58 | np.arange(data_array.shape[0])) 59 | x_data = x_data.flatten() 60 | y_data = y_data.flatten() 61 | z_data = data_array.flatten() 62 | ax.bar3d(x_data, 63 | y_data, 64 | np.zeros(len(z_data)), 65 | 1, 1, z_data) 66 | plt.show() 67 | plt.waitforbuttonpress() 68 | 69 | 70 | def rewardVisual(rewards, idx_fnid, gpd_file,title="", text=True): 71 | gdf = gpd.read_file(gpd_file) 72 | gdf['reward'] = 0 73 | for i in range(len(rewards)): 74 | fnid = idx_fnid[i] 75 | idx = gdf[(gdf['fnid'] == fnid)].index 76 | gdf.iloc[idx, -1] = rewards[i] 77 | gdf.plot(column='reward', cmap='viridis') 78 | plt.title(title) 79 | 80 | 81 | def histKernel(x): 82 | plt.figure(dpi=120) 83 | rc = {'font.sans-serif': 'Times New Roman', 84 | 'axes.unicode_minus': False} 85 | sns.set_style(style='dark', rc=rc) 86 | sns.set_style({"axes.facecolor": "#e9f3ea"}) 87 | g = sns.distplot(x, 88 | hist=True, 89 | kde=True, # overlay a kernel density estimate (KDE) curve 90 | kde_kws={'linestyle': '--', 'linewidth': '1', 'color': '#c72e29', # style of the KDE curve 91 | }, 92 | color='#098154', 93 | axlabel='Xlabel', # x-axis label 94 | ) 95 | plt.savefig('./kernel.png',dpi=400) 96 | plt.show() 97 | 98 | 99 | if __name__ == "__main__": 100 | import pandas as pd 101 | feature_map_excel = pd.read_excel('./data/nanshan_tfidf.xlsx') 102 | states_list = list(feature_map_excel['fnid']) 103 | fnid_idx = {} 104 | idx_fnid = {} 105 | for i in range(len(states_list)): 106 | fnid_idx.update({states_list[i]: i}) 107 | idx_fnid.update({i: states_list[i]}) 108 | rewards = np.random.randint(100, size=(len(idx_fnid))) 109 | rewardVisual(rewards, idx_fnid, './data/nanshan_grid.shp') # pass a real shapefile; gpd.read_file('') would fail 110 | --------------------------------------------------------------------------------
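`sns.distplot` has been deprecated since seaborn 0.11 and removed in newer releases; on current seaborn the same figure can be approximated with `histplot` — a sketch, keeping the original colors (the `line_kws` styling of the KDE line is an assumption about the newer API):

```python
import matplotlib.pyplot as plt
import seaborn as sns

def hist_kernel(x):
    plt.figure(dpi=120)
    sns.histplot(x, kde=True, stat='density', color='#098154',
                 line_kws={'linestyle': '--', 'linewidth': 1})
    plt.xlabel('Xlabel')
    plt.savefig('./kernel.png', dpi=400)
    plt.show()
```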
"realworld" 2 | all_model_checkpoint_paths: "realworld" 3 | -------------------------------------------------------------------------------- /model/be/policy_realworld.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/be/policy_realworld.npy -------------------------------------------------------------------------------- /model/be/realworld.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/be/realworld.data-00000-of-00001 -------------------------------------------------------------------------------- /model/be/realworld.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/be/realworld.index -------------------------------------------------------------------------------- /model/be/realworld.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/be/realworld.meta -------------------------------------------------------------------------------- /model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "realworld" 2 | all_model_checkpoint_paths: "realworld" 3 | -------------------------------------------------------------------------------- /model/dnn psl/model_weights.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/dnn psl/model_weights.pth -------------------------------------------------------------------------------- /model/policy_realworld.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/policy_realworld.npy -------------------------------------------------------------------------------- /model/realworld.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/realworld.data-00000-of-00001 -------------------------------------------------------------------------------- /model/realworld.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/realworld.index -------------------------------------------------------------------------------- /model/realworld.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/realworld.meta -------------------------------------------------------------------------------- /model/rlogit_realworld.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BoyangL1/Advanced_DeepIRL/c04a0fe4704e6b24b67ba805c99a75629f87e8b3/model/rlogit_realworld.npy 
-------------------------------------------------------------------------------- /nshortest_path.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import geopandas as gpd\n", 10 | "import networkx as nx\n", 11 | "from shapely.geometry import Polygon\n", 12 | "import numpy as np\n", 13 | "from collections import namedtuple\n", 14 | "Step=namedtuple('Step',['state','action'])" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "# nanshan graph" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "shapefile_path = \"./data/nanshan_grid.shp\"\n", 31 | "gdf = gpd.read_file(shapefile_path)\n", 32 | "\n", 33 | "G = nx.Graph()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "for index, row in gdf.iterrows():\n", 43 | " current_node = tuple(row['geometry'].exterior.coords)\n", 44 | " fnid = row['fnid'] \n", 45 | " G.add_node(current_node, fnid=fnid)\n", 46 | "\n", 47 | " for neighbor_index, neighbor_row in gdf.iterrows():\n", 48 | " neighbor_node = tuple(neighbor_row['geometry'].exterior.coords)\n", 49 | " if current_node != neighbor_node and Polygon(current_node).touches(Polygon(neighbor_node)):\n", 50 | " G.add_edge(current_node, neighbor_node)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "graphml_file = \"./data/nanshan_network.graphml\"\n", 60 | "nx.write_graphml(G, graphml_file)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "G=nx.read_graphml(graphml_file)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "# origin & destination coord" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "states_tuple_path = './data/routes_states/0_0_states_tuple.npy'\n", 86 | "state_action = np.load(states_tuple_path, allow_pickle=True)\n", 87 | "selected_traj = np.random.choice(state_action, size=1, replace=False)[0]\n", 88 | "origin, destination = int(selected_traj[0].state), int(selected_traj[-1].state)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "metadata": {}, 94 | "source": [ 95 | "# n shortest path" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "origin_node = [(node,data) for node, data in G.nodes(data=True) if data.get('fnid') == origin]\n", 105 | "destination_node = [(node, data) for node, data in G.nodes(data=True) if data.get('fnid') == destination]" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "num_paths_to_find = 10\n", 115 | "origin = origin_node[0][0]\n", 116 | "destination = destination_node[0][0]\n", 117 | "\n", 118 | "shortest_paths = []\n", 119 | "for _ in range(num_paths_to_find):\n", 120 | " shortest_path = nx.shortest_path(G, source=origin, target=destination)\n", 121 | " shortest_paths.append(shortest_path)\n", 122 | "\n", 123 | " for u, v in 
zip(shortest_path[:-1], shortest_path[1:]):\n", 124 | "        G.remove_edge(u, v)" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "for i, path in enumerate(shortest_paths):\n", 134 | "    fnid_values = [G.nodes[node]['fnid'] for node in path]\n", 135 | "    print(f\"shortest path {i+1}: {fnid_values}\")" 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "kernelspec": { 141 | "display_name": "django", 142 | "language": "python", 143 | "name": "python3" 144 | }, 145 | "language_info": { 146 | "codemirror_mode": { 147 | "name": "ipython", 148 | "version": 3 149 | }, 150 | "file_extension": ".py", 151 | "mimetype": "text/x-python", 152 | "name": "python", 153 | "nbconvert_exporter": "python", 154 | "pygments_lexer": "ipython3", 155 | "version": "3.7.11" 156 | }, 157 | "orig_nbformat": 4 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 2 161 | } 162 | -------------------------------------------------------------------------------- /realGrid/real_grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RealGridWorld(object): 5 | def __init__(self, fnid_idx, idx_fnid, trans_prob): 6 | """ 7 | modelling the real grid world 8 | Args: 9 | fnid_idx : {fnid: index}, indices starting from 0 10 | trans_prob : probability that the intended transition occurs when an action is taken 11 | """ 12 | self.fnid_idx = fnid_idx 13 | self.idx_fnid = idx_fnid 14 | self.trans_prob = trans_prob 15 | self.n_actions = 5 16 | self.actions = [0, 1, 2, 3, 4] 17 | self.neighbors = [1, -1, -357, 357, 0] # fnid offsets on the 357-column grid 18 | self.dirs = {0: 'r', 1: 'l', 2: 'd', 3: 'u', 4: 's'} 19 | # right, left, down, up, stay 20 | 21 | def get_transition_states_and_probs(self, state_fnid, action): 22 | """ 23 | get all the possible transition states and their probabilities with [action] on [state] 24 | Args: 25 | state_fnid : fnid of current state 26 | action : int 27 | Return: 28 | a list of (state, probability) pairs 29 | """ 30 | if self.trans_prob == 1: 31 | inc = self.neighbors[action] 32 | nei_s = state_fnid+inc 33 | if nei_s not in self.fnid_idx: 34 | return [(state_fnid, 1)] 35 | else: 36 | return [(nei_s, 1)] 37 | else: 38 | mov_probs = np.zeros([self.n_actions]) 39 | mov_probs[action] = self.trans_prob 40 | mov_probs += (1-self.trans_prob)/(self.n_actions-1) 41 | mov_probs[action] -= (1-self.trans_prob)/(self.n_actions-1) 42 | 43 | for a in range(len(self.actions)): 44 | inc = self.neighbors[a] 45 | nei_s = state_fnid+inc 46 | if nei_s not in self.fnid_idx: 47 | mov_probs[-1] += mov_probs[a] 48 | mov_probs[a] = 0 49 | 50 | res = [] 51 | for a in range(len(self.actions)): 52 | if mov_probs[a] != 0: 53 | inc = self.neighbors[a] 54 | nei_s = state_fnid+inc 55 | res.append((nei_s, mov_probs[a])) 56 | return res 57 | 
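A toy check of the transition model above — a sketch with two adjacent cells only, so most moves are off-grid and their probability mass is folded into staying put (action 4):

```python
from realGrid.real_grid import RealGridWorld

fnid_idx = {1000: 0, 1001: 1}
idx_fnid = {0: 1000, 1: 1001}
world = RealGridWorld(fnid_idx, idx_fnid, trans_prob=0.8)

# Action 0 ('right', fnid offset +1): 0.8 to 1001; the 0.05 shares of the
# three impossible moves (left, down, up) are re-routed to staying at 1000.
print(world.get_transition_states_and_probs(1000, 0))
# -> [(1001, 0.8), (1000, 0.2)]  (up to float rounding)
```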
58 | def get_transition_mat(self): 59 | """ 60 | get transition dynamics of the gridworld 61 | 62 | return: 63 | P_a N_STATES x N_STATES x N_ACTIONS transition probabilities matrix - 64 | P_a[s0, s1, a] is the transition prob of 65 | landing at state s1 when taking action 66 | a at state s0 67 | """ 68 | N_STATES = len(self.fnid_idx) 69 | N_ACTIONS = self.n_actions 70 | P_a = np.zeros((N_STATES, N_STATES, N_ACTIONS)) 71 | for si in range(N_STATES): 72 | posi = self.idx_fnid[si] 73 | for a in range(N_ACTIONS): 74 | probs = self.get_transition_states_and_probs(posi, a) 75 | for posj, prob in probs: 76 | sj = self.fnid_idx[posj] 77 | # probability of moving from si to sj given action a 78 | P_a[si, sj, a] = prob 79 | return P_a 80 | -------------------------------------------------------------------------------- /realGrid/value_iteration.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import numpy as np 4 | from collections import deque 5 | 6 | 7 | def value_iteration(P_a, rewards, gamma, error=0.01, deterministic=True): 8 | """ 9 | static value iteration function. Perhaps the most useful function in this repo 10 | 11 | inputs: 12 | P_a NxNxN_ACTIONS transition probabilities matrix - 13 | P_a[s0, s1, a] is the transition prob of 14 | landing at state s1 when taking action 15 | a at state s0 16 | rewards Nx1 matrix - rewards for all the states 17 | gamma float - RL discount 18 | error float - threshold for a stop 19 | deterministic bool - to return deterministic policy or stochastic policy 20 | 21 | returns: 22 | values Nx1 matrix - estimated values 23 | policy Nx1 (NxN_ACTIONS if non-det) matrix - policy 24 | """ 25 | N_STATES, _, N_ACTIONS = np.shape(P_a) 26 | 27 | values = np.zeros([N_STATES]) 28 | 29 | # estimate values 30 | while True: 31 | values_tmp = values.copy() 32 | 33 | for s in range(N_STATES): 34 | v_s = [] 35 | values[s] = max([sum([P_a[s, s1, a]*(rewards[s] + gamma*values_tmp[s1]) 36 | for s1 in range(N_STATES)]) for a in range(N_ACTIONS)]) 37 | 38 | max_diff = np.max(np.abs(values - values_tmp)) 39 | print(max_diff) 40 | if max_diff < error: 41 | break 42 | 43 | 44 | if deterministic: 45 | # generate deterministic policy 46 | policy = np.zeros([N_STATES]) 47 | for s in range(N_STATES): 48 | policy[s] = np.argmax([sum([P_a[s, s1, a]*(rewards[s]+gamma*values[s1]) 49 | for s1 in range(N_STATES)]) 50 | for a in range(N_ACTIONS)]) 51 | return values, policy 52 | else: 53 | # generate stochastic policy 54 | policy = np.zeros([N_STATES, N_ACTIONS]) 55 | for s in range(N_STATES): 56 | v_s = np.array([sum([P_a[s, s1, a]*(rewards[s] + gamma*values[s1]) 57 | for s1 in range(N_STATES)]) 58 | for a in range(N_ACTIONS)]) 59 | policy[s, :] = np.transpose(v_s/np.sum(v_s)) 60 | return values, policy 61 | 62 | 63 | def determinValIteration(end_fnid, actions, neighbors, gamma, fnid_idx, reward_map): 64 | """ 65 | deterministic value iteration 66 | 67 | Args: 68 | end_fnid : destination fnid 69 | actions : action list 70 | neighbors : neighboring fnid distance 71 | gamma : discount factor, default 0.9 72 | fnid_idx : {fnid: idx} 73 | reward_map : reward map 74 | 75 | Returns: 76 | values, policy : value map and policy map 77 | """ 78 | values = [0]*len(fnid_idx) 79 | policy = [-1]*len(fnid_idx) 80 | # values[fnid_idx[end_fnid]] = float(reward_map[fnid_idx[end_fnid]]) 81 | values[fnid_idx[end_fnid]] = reward_map[fnid_idx[end_fnid]] 82 | policy[fnid_idx[end_fnid]] = end_fnid 83 | queue = deque() 84 | queue.append(end_fnid) 85 | while queue: 86 | cur_fnid = queue.popleft() 87 | for a in actions: 88 | nei_fnid = cur_fnid+neighbors[a] 89 | if nei_fnid in fnid_idx.keys() and values[fnid_idx[nei_fnid]] == 0: 90 | queue.append(nei_fnid) 91 | nei_fnid_idx = fnid_idx[nei_fnid] 92 | 93 | max_value=float('-inf') 94 | max_fnid=0 95 | for a in actions: 96 | n = nei_fnid+neighbors[a] 97 | if n in fnid_idx.keys() and max_value= 4) & (routes['age'] <= 7) & (routes['gender'] == 1)]\n", 67 | "young_women = routes[(routes['age'] >= 4) & (routes['age'] <= 7) & (routes['gender'] == 2)]\n", 68 | "middle_men = routes[(routes['age'] >= 8) & (routes['age'] <= 13) & (routes['gender'] == 1)]\n", 69 | "middle_women =
routes[(routes['age'] >= 8) & (routes['age'] <= 13) & (routes['gender'] == 2)]\n", 70 | "old_men = routes[(routes['age'] >= 14) & (routes['gender'] == 1)]\n", 71 | "old_women = routes[(routes['age'] >= 14) & (routes['gender'] == 2)]" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "print(len(young_men),len(young_women),len(middle_men),len(middle_women),len(old_men),len(old_women))" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "# Function" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "def getFnidByPoint(first_point, second_point):\n", 97 | "    fnid_list = []\n", 98 | "\n", 99 | "    # Create LineString from first_point to second_point\n", 100 | "    line = LineString([first_point, second_point])\n", 101 | "    line_series = gpd.GeoSeries([line], crs=4326)\n", 102 | "    line_utm = line_series.to_crs(32650)\n", 103 | "\n", 104 | "    dist = line_utm.length.iloc[0]\n", 105 | "    delta = int(dist // 50)\n", 106 | "\n", 107 | "    x_diff = second_point[0] - first_point[0]\n", 108 | "    y_diff = second_point[1] - first_point[1]\n", 109 | "    x_values = np.linspace(first_point[0], second_point[0], delta + 1)\n", 110 | "    y_values = np.linspace(first_point[1], second_point[1], delta + 1)\n", 111 | "    interpolation_points = [Point(x, y) for x, y in zip(x_values, y_values)]\n", 112 | "\n", 113 | "    # Perform overlay analysis for each point\n", 114 | "    intersect = gpd.overlay(district, gpd.GeoDataFrame(geometry=interpolation_points, crs=4326), how='intersection', keep_geom_type=False)\n", 115 | "    fnid_list.extend(intersect['fnid'])\n", 116 | "\n", 117 | "    return fnid_list" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "def routeToFnid(routes):\n", 127 | "    def process_route(route):\n", 128 | "        geometry_ = route\n", 129 | "        q = deque(geometry_.coords)\n", 130 | "        first_point = q.popleft()\n", 131 | "        states = set()  # use a set to de-duplicate states\n", 132 | "        while q:\n", 133 | "            second_point = q.popleft()\n", 134 | "            fnid_list = getFnidByPoint(first_point, second_point)\n", 135 | "            if fnid_list:\n", 136 | "                states.update(fnid_list)\n", 137 | "            first_point = second_point\n", 138 | "        return list(states)\n", 139 | "\n", 140 | "    print('the length of routes is {}'.format(len(routes)))\n", 141 | "    routes_states = []\n", 142 | "    with concurrent.futures.ThreadPoolExecutor() as executor:\n", 143 | "        results = list(tqdm(executor.map(process_route, routes.geometry), total=len(routes)))\n", 144 | "    for i, states_unique in enumerate(results):\n", 145 | "        routes_states.append(states_unique)\n", 146 | "\n", 147 | "    routes_states = np.array(routes_states)\n", 148 | "    return routes_states" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "actions = [0, 1, 2, 3, 4]\n", 158 | "dirs = {0: 'r', 1: 'l', 2: 'd', 3: 'u', 4: 's'}\n", 159 | "Step=namedtuple('Step',['state','action'])\n", 160 | "\n", 161 | "def getActionOfStates(route_state):\n", 162 | "    state_action = []\n", 163 | "    length = len(route_state)\n", 164 | "\n", 165 | "    if length == 1:\n", 166 | "        step = Step(state=route_state[0], action=4)\n", 167 | "        state_action.append(step)\n", 168 | "        return state_action\n", 169 | "\n", 170 | "    diff = np.diff(route_state)  # 
compute the differences between adjacent states\n", 171 | "\n", 172 | "    def getAction(diff_value):\n", 173 | "        if diff_value == 1:\n", 174 | "            return 0\n", 175 | "        elif diff_value == -1:\n", 176 | "            return 1\n", 177 | "        elif diff_value == 357:\n", 178 | "            return 3\n", 179 | "        elif diff_value == -357:\n", 180 | "            return 2\n", 181 | "        else:\n", 182 | "            return 4  # default: stay\n", 183 | "\n", 184 | "    actions = np.vectorize(getAction)(diff)  # vectorized mapping from diffs to actions\n", 185 | "\n", 186 | "    for i, action in enumerate(actions):\n", 187 | "        step = Step(state=route_state[i], action=action)\n", 188 | "        state_action.append(step)\n", 189 | "\n", 190 | "    step = Step(state=route_state[-1], action=4)\n", 191 | "    state_action.append(step)\n", 192 | "\n", 193 | "    return state_action\n" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "import random\n", 203 | "\n", 204 | "# If the amount of data is too large, downsampling can be performed\n", 205 | "def randomSelectLines(geodataframe, percent_to_select=0.7):\n", 206 | "    total_lines = len(geodataframe)\n", 207 | "    if total_lines<1000:\n", 208 | "        num_lines_to_select = int(total_lines * 1)\n", 209 | "        random_line_indices = random.sample(range(total_lines), num_lines_to_select)\n", 210 | "        randomly_selected_lines = geodataframe.iloc[random_line_indices]\n", 211 | "        return randomly_selected_lines\n", 212 | "\n", 213 | "    num_lines_to_select = int(total_lines * (percent_to_select))\n", 214 | "    random_line_indices = random.sample(range(total_lines), num_lines_to_select)\n", 215 | "    randomly_selected_lines = geodataframe.iloc[random_line_indices]\n", 216 | "    return randomly_selected_lines" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [ 225 | "def routeToTuple(routes,save_file):\n", 226 | "    routes=randomSelectLines(routes)\n", 227 | "    routes_states=routeToFnid(routes)\n", 228 | "    state_action_tuple=[]\n", 229 | "    for route_state in tqdm(routes_states):\n", 230 | "        sta_act=getActionOfStates(route_state)\n", 231 | "        state_action_tuple.append(sta_act)\n", 232 | "    print(state_action_tuple[0])\n", 233 | "    state_action_tuple=np.array(state_action_tuple)\n", 234 | "    np.save(save_file,state_action_tuple) " 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "# main function" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "# first digit: 0 male, 1 female; second digit: 0 young, 1 middle-aged, 2 old\n", 251 | "routeToTuple(young_men,'./data/routes_states/0_0_states_tuple.npy')" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "routeToTuple(middle_men,'./data/routes_states/0_1_states_tuple.npy')" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [ 269 | "routeToTuple(old_men,'./data/routes_states/0_2_states_tuple.npy')" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "metadata": {}, 276 | "outputs": [], 277 | "source": [ 278 | "routeToTuple(young_women,'./data/routes_states/1_0_states_tuple.npy')" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": null, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | 
"routeToTuple(middle_women,'./data/routes_states/1_1_states_tuple.npy')" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "routeToTuple(old_women,'./data/routes_states/1_2_states_tuple.npy')" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "import numpy as np \n", 306 | "np.load('./data/routes_states/1_2_states_tuple.npy',allow_pickle=True)" 307 | ] 308 | } 309 | ], 310 | "metadata": { 311 | "kernelspec": { 312 | "display_name": "Python 3.7.11 ('django')", 313 | "language": "python", 314 | "name": "python3" 315 | }, 316 | "language_info": { 317 | "codemirror_mode": { 318 | "name": "ipython", 319 | "version": 3 320 | }, 321 | "file_extension": ".py", 322 | "mimetype": "text/x-python", 323 | "name": "python", 324 | "nbconvert_exporter": "python", 325 | "pygments_lexer": "ipython3", 326 | "version": "3.7.11" 327 | }, 328 | "orig_nbformat": 4, 329 | "vscode": { 330 | "interpreter": { 331 | "hash": "764d3ea85697cfb78fbfbf4297caf293f5408afbb175a5f0ffc0949cef450b37" 332 | } 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 2 337 | } 338 | -------------------------------------------------------------------------------- /trajectory.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import geopandas as gpd 3 | import matplotlib.pyplot as plot 4 | from shapely import geometry 5 | import pandas as pd 6 | import itertools 7 | from collections import deque 8 | import numpy as np 9 | 10 | actions = [0, 1, 2, 3, 4] 11 | dirs = {0: 'r', 1: 'l', 2: 'd', 3: 'u', 4: 's'} 12 | # right, left, down, up , stay 13 | Step=namedtuple('Step',['state','action']) 14 | 15 | def getActionOfStates(route_state): 16 | state_action=[] 17 | length=len(route_state) 18 | first_state=route_state[0] 19 | if length==1: 20 | step=Step(state=first_state,action=4) 21 | state_action.append(step) 22 | return state_action 23 | for i in range(1,length): 24 | second_state=route_state[i] 25 | 26 | def getAction(first_s,second_s): 27 | if second_s-first_s==1: 28 | return 0 29 | elif second_s- first_s==-1: 30 | return 1 31 | elif second_s-first_s==357: 32 | return 3 33 | elif second_s-first_s==-357: 34 | return 2 35 | 36 | idx_minux=second_state-first_state 37 | if idx_minux==358: 38 | second_state=second_state-1 39 | act=getAction(first_state,second_state) 40 | step=Step(state=first_state,action=act) 41 | state_action.append(step) 42 | first_state=second_state 43 | second_state+=1 44 | elif idx_minux==356: 45 | second_state=second_state+1 46 | act=getAction(first_state,second_state) 47 | step=Step(state=first_state,action=act) 48 | state_action.append(step) 49 | first_state=second_state 50 | second_state-=1 51 | elif idx_minux==-358: 52 | second_state=second_state+1 53 | act=getAction(first_state,second_state) 54 | step=Step(state=first_state,action=act) 55 | state_action.append(step) 56 | first_state=second_state 57 | second_state-=1 58 | elif idx_minux==-356: 59 | second_state=second_state-1 60 | act=getAction(first_state,second_state) 61 | step=Step(state=first_state,action=act) 62 | state_action.append(step) 63 | first_state=second_state 64 | second_state+=1 65 | act=getAction(first_state,second_state) 66 | step=Step(state=first_state,action=act) 67 | state_action.append(step) 68 | first_state=second_state 69 | step=Step(state=second_state,action=4) 
70 | state_action.append(step) 71 | return state_action 72 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from collections import namedtuple 4 | 5 | Step = namedtuple('Step', 'cur_state action next_state reward done') 6 | 7 | 8 | def normalize(vals): 9 | """ 10 | normalize to [0, 1] 11 | input: 12 | vals: 1d array 13 | """ 14 | min_val = np.min(vals) 15 | max_val = np.max(vals) 16 | return (vals - min_val) / (max_val - min_val) 17 | 18 | 19 | def sigmoid(xs): 20 | """ 21 | sigmoid function 22 | inputs: 23 | xs 1d array 24 | """ 25 | return [1 / (1 + math.exp(-x)) for x in xs] 26 | --------------------------------------------------------------------------------