├── LICENSE ├── README.md ├── common ├── __init__.py ├── grid.py ├── mbr.py ├── path.py ├── road_network.py ├── spatial_func.py └── trajectory.py ├── data ├── map │ ├── extra_info │ │ ├── new2raw_rid.json │ │ ├── poi20.csv │ │ ├── raw2new_rid.json │ │ ├── raw_rn_dict.json │ │ ├── rn_dict.json │ │ └── weather_dict.pkl │ └── road_network │ │ ├── edges.dbf │ │ ├── edges.shp │ │ ├── edges.shx │ │ ├── nodes.dbf │ │ ├── nodes.shp │ │ └── nodes.shx └── raw_trajectory │ └── 09_01.txt ├── doc └── KDD21_MTrajRec_Huimin.pdf ├── img ├── intro.png └── model_framework.png ├── map_matching ├── __init__.py ├── candidate_point.py ├── hmm │ ├── __init__.py │ ├── hmm_map_matcher.py │ ├── hmm_probabilities.py │ └── ti_viterbi.py ├── map_matcher.py ├── route_constructor.py └── utils.py ├── models ├── __init__.py ├── datasets.py ├── loss_fn.py ├── model_utils.py ├── models_attn_tandem.py └── multi_train.py ├── multi_main.py └── utils ├── __init__.py ├── coord_transform.py ├── imputation.py ├── noise_filter.py ├── parse_traj.py ├── save_traj.py ├── segmentation.py ├── stats.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 paperanonymous945 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MTrajRec 2 |

3 | <!-- figures: img/intro.png, img/model_framework.png --> 4 | 5 |

6 | 7 | ## About 8 | Source code of the KDD'21 paper: [MTrajRec: Map-Constrained Trajectory Recovery via Seq2Seq Multi-task Learning](doc/KDD21_MTrajRec_Huimin.pdf) 9 | ## Requirements 10 | * `python==3.6` 11 | * `pytorch==1.7.1` 12 | * `rtree==0.9.4` 13 | * `GDAL==2.3.3` 14 | * `networkx==2.3` 15 | 16 | ## Usage 17 | ### Installation 18 | #### Clone this repo: 19 | ```bash 20 | git clone git@github.com:huiminren/MTrajRec.git 21 | cd MTrajRec 22 | ``` 23 | #### Running 24 | `python multi_main.py` 25 | 26 | #### Dataset 27 | We provide sample data under `data/`. 28 | 29 | Please note that the sample data follows the same structure as the original data. For the data preprocessing, please refer to [tptk](https://github.com/sjruan/tptk). 30 | 31 | ## Acknowledgement 32 | Thanks to [Sijie](https://github.com/sjruan/) for supporting the data preprocessing. 33 | 34 | ## Citation 35 | If you find this repo useful, please cite our paper:
36 | ``` 37 | @inproceedings{ren2021mtrajrec, 38 | title={MTrajRec: Map-Constrained Trajectory Recovery via Seq2Seq Multi-task Learning}, 39 | author={Ren, Huimin and Ruan, Sijie and Li, Yanhua and Bao, Jie and Meng, Chuishi and Li, Ruiyuan and Zheng, Yu}, 40 | booktitle={Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining}, 41 | pages={1410--1419}, 42 | year={2021} 43 | } 44 | ``` 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/9/3 13:34 4 | -------------------------------------------------------------------------------- /common/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .mbr import MBR 3 | from .spatial_func import LAT_PER_METER, LNG_PER_METER 4 | 5 | 6 | class Grid: 7 | """ 8 | index order 9 | 30 31 32 33 34... 10 | 20 21 22 23 24... 11 | 10 11 12 13 14... 12 | 00 01 02 03 04... 13 | """ 14 | 15 | def __init__(self, mbr, row_num, col_num): 16 | self.mbr = mbr 17 | self.row_num = row_num 18 | self.col_num = col_num 19 | self.lat_interval = (mbr.max_lat - mbr.min_lat) / float(row_num) 20 | self.lng_interval = (mbr.max_lng - mbr.min_lng) / float(col_num) 21 | 22 | def get_row_idx(self, lat): 23 | row_idx = int((lat - self.mbr.min_lat) // self.lat_interval) 24 | if row_idx >= self.row_num or row_idx < 0: 25 | raise IndexError("lat is out of mbr") 26 | return row_idx 27 | 28 | def get_col_idx(self, lng): 29 | col_idx = int((lng - self.mbr.min_lng) // self.lng_interval) 30 | if col_idx >= self.col_num or col_idx < 0: 31 | raise IndexError("lng is out of mbr") 32 | return col_idx 33 | 34 | def safe_matrix_to_idx(self, lat, lng): 35 | try: 36 | return self.get_matrix_idx(lat, lng) 37 | except IndexError: 38 | return np.nan, np.nan 39 | 40 | def get_idx(self, lat, lng): 41 | return self.get_row_idx(lat), self.get_col_idx(lng) 42 | 43 | def get_matrix_idx(self, lat, lng): 44 | return self.row_num - 1 - self.get_row_idx(lat), self.get_col_idx(lng) 45 | 46 | def get_min_lng(self, col_idx): 47 | return self.mbr.min_lng + col_idx * self.lng_interval 48 | 49 | def get_max_lng(self, col_idx): 50 | return self.mbr.min_lng + (col_idx + 1) * self.lng_interval 51 | 52 | def get_min_lat(self, row_idx): 53 | return self.mbr.min_lat + row_idx * self.lat_interval 54 | 55 | def get_max_lat(self, row_idx): 56 | return self.mbr.min_lat + (row_idx + 1) * self.lat_interval 57 | 58 | def get_mbr_by_idx(self, row_idx, col_idx): 59 | min_lat = self.get_min_lat(row_idx) 60 | max_lat = self.get_max_lat(row_idx) 61 | min_lng = self.get_min_lng(col_idx) 62 | max_lng = self.get_max_lng(col_idx) 63 | return MBR(min_lat, min_lng, max_lat, max_lng) 64 | 65 | def get_mbr_by_matrix_idx(self, mat_row_idx, mat_col_idx): 66 | row_idx = self.row_num - 1 - mat_row_idx 67 | min_lat = self.get_min_lat(row_idx) 68 | max_lat = self.get_max_lat(row_idx) 69 | min_lng = self.get_min_lng(mat_col_idx) 70 | max_lng = self.get_max_lng(mat_col_idx) 71 | return MBR(min_lat, min_lng, max_lat, max_lng) 72 | 73 | def range_query(self, query_mbr, type): 74 | target_idx = [] 75 | # squeeze the mbr a little, since the top and right boundary are belong to the other grid 76 | delta = 1e-7 77 | min_lat = max(query_mbr.min_lat, self.mbr.min_lat) 78 | min_lng = max(query_mbr.min_lng, self.mbr.min_lng) 79 | max_lat = 
min(query_mbr.max_lat, self.mbr.max_lat) - delta 80 | max_lng = min(query_mbr.max_lng, self.mbr.max_lng) - delta 81 | if type == 'matrix': 82 | max_row_idx, min_col_idx = self.get_matrix_idx(min_lat, min_lng) 83 | min_row_idx, max_col_idx = self.get_matrix_idx(max_lat, max_lng) 84 | elif type == 'cartesian': 85 | min_row_idx, min_col_idx = self.get_idx(min_lat, min_lng) 86 | max_row_idx, max_col_idx = self.get_idx(max_lat, max_lng) 87 | else: 88 | raise Exception('unrecognized index type') 89 | for r_idx in range(min_row_idx, max_row_idx + 1): 90 | for c_idx in range(min_col_idx, max_col_idx + 1): 91 | target_idx.append((r_idx, c_idx)) 92 | return target_idx 93 | 94 | 95 | # def create_grid(min_lat, min_lng, km_per_cell_lat, km_per_cell_lng, km_lat, km_lng): 96 | # nb_rows = int(km_lat / km_per_cell_lat) 97 | # nb_cols = int(km_lng / km_per_cell_lng) 98 | # max_lat = min_lat + LAT_PER_METER * km_lat * 1000.0 99 | # max_lng = min_lng + LNG_PER_METER * km_lng * 1000.0 100 | # mbr = MBR(min_lat, min_lng, max_lat, max_lng) 101 | # return Grid(mbr, nb_rows, nb_cols) 102 | 103 | def create_grid(min_lat, min_lng, max_lat, max_lng, km_per_cell_lat, km_per_cell_lng): 104 | """ 105 | Given region and unit of each cell, return a Grid class. 106 | Update original function since it's difficult to know the length of lat and lng. 107 | """ 108 | mbr = MBR(min_lat, min_lng, max_lat, max_lng) 109 | km_lat = mbr.get_h 110 | km_lng = mbr.get_w 111 | nb_rows = int(km_lat / km_per_cell_lat) 112 | nb_cols = int(km_lng / km_per_cell_lng) 113 | return Grid(mbr, nb_rows, nb_cols) 114 | -------------------------------------------------------------------------------- /common/mbr.py: -------------------------------------------------------------------------------- 1 | from .spatial_func import distance, SPoint 2 | 3 | class MBR: 4 | """ 5 | MBR creates the minimal bounding regions for users. 6 | """ 7 | def __init__(self, min_lat, min_lng, max_lat, max_lng): 8 | self.min_lat = min_lat 9 | self.min_lng = min_lng 10 | self.max_lat = max_lat 11 | self.max_lng = max_lng 12 | 13 | def contains(self, lat, lng): 14 | # return self.min_lat <= lat <= self.max_lat and self.min_lng <= lng <= self.max_lng 15 | # remove = max.lat/max.lng, to be consist with grid index 16 | return self.min_lat <= lat < self.max_lat and self.min_lng <= lng < self.max_lng 17 | 18 | def center(self): 19 | return (self.min_lat + self.max_lat) / 2.0, (self.min_lng + self.max_lng) / 2.0 20 | 21 | def get_h(self): 22 | return distance(SPoint(self.min_lat, self.min_lng), SPoint(self.max_lat, self.min_lng)) 23 | 24 | def get_w(self): 25 | return distance(SPoint(self.min_lat, self.min_lng), SPoint(self.min_lat, self.max_lng)) 26 | 27 | def __str__(self): 28 | h = self.get_h() 29 | w = self.get_w() 30 | return '{}x{}m2'.format(h, w) 31 | 32 | def __eq__(self, other): 33 | return self.min_lat == other.min_lat and self.min_lng == other.min_lng \ 34 | and self.max_lat == other.max_lat and self.max_lng == other.max_lng 35 | 36 | def to_wkt(self): 37 | # Here providing five points is for GIS visualization 38 | # sometimes wkt cannot draw a rectangle without the last point. 39 | # (the last point should be the same as the first one) 40 | return 'POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))'.format(self.min_lng, self.min_lat, 41 | self.min_lng, self.max_lat, 42 | self.max_lng, self.max_lat, 43 | self.max_lng, self.min_lat, 44 | self.min_lng, self.min_lat) 45 | 46 | @staticmethod 47 | # staticmethod means this function will not use self attribute. 
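# Illustrative usage (hypothetical coordinates): given SPoint-like inputs,
#   MBR.cal_mbr([SPoint(39.90, 116.30), SPoint(39.95, 116.42), SPoint(39.92, 116.35)])
#   -> MBR(min_lat=39.90, min_lng=116.30, max_lat=39.95, max_lng=116.42)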
48 | def cal_mbr(coords): 49 | """ 50 | Find MBR from coordinates 51 | Args: 52 | ----- 53 | coords: 54 | list of Point() 55 | Returns: 56 | ------- 57 | MBR() 58 | """ 59 | min_lat = float('inf') 60 | min_lng = float('inf') 61 | max_lat = float('-inf') 62 | max_lng = float('-inf') 63 | for coord in coords: 64 | if coord.lat > max_lat: 65 | max_lat = coord.lat 66 | if coord.lat < min_lat: 67 | min_lat = coord.lat 68 | if coord.lng > max_lng: 69 | max_lng = coord.lng 70 | if coord.lng < min_lng: 71 | min_lng = coord.lng 72 | return MBR(min_lat, min_lng, max_lat, max_lng) 73 | 74 | @staticmethod 75 | def load_mbr(file_path): 76 | with open(file_path, 'r') as f: 77 | f.readline() 78 | attrs = f.readline()[:-1].split(';') 79 | mbr = MBR(float(attrs[1]), float(attrs[2]), float(attrs[3]), float(attrs[4])) 80 | return mbr 81 | 82 | @staticmethod 83 | def store_mbr(mbr, file_path): 84 | with open(file_path, 'w') as f: 85 | f.write('name;min_lat;min_lng;max_lat;max_lng;wkt\n') 86 | f.write('{};{};{};{};{};{}\n'.format(0, mbr.min_lat, mbr.min_lng, mbr.max_lat, mbr.max_lng, mbr.to_wkt())) 87 | -------------------------------------------------------------------------------- /common/path.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | class PathEntity: 4 | def __init__(self, enter_time, leave_time, eid): 5 | self.enter_time = enter_time 6 | self.leave_time = leave_time 7 | self.eid = eid 8 | 9 | 10 | class Path: 11 | def __init__(self, oid, pid, path_entities): 12 | self.oid = oid 13 | self.pid = pid 14 | self.path_entities = path_entities 15 | 16 | 17 | def parse_path_file(input_path): 18 | time_format = '%Y-%m-%d %H:%M:%S.%f' 19 | paths = [] 20 | with open(input_path, 'r') as f: 21 | path_entities = [] 22 | pid = None 23 | for line in f.readlines(): 24 | attrs = line.rstrip().split(',') 25 | if attrs[0] == '#': 26 | if len(path_entities) != 0: 27 | paths.append(Path(oid, pid, path_entities)) 28 | oid = attrs[2] 29 | pid = attrs[1] 30 | path_entities = [] 31 | else: 32 | enter_time = datetime.strptime(attrs[0], time_format) 33 | leave_time = datetime.strptime(attrs[1], time_format) 34 | eid = int(attrs[2]) 35 | path_entities.append(PathEntity(enter_time, leave_time, eid)) 36 | if len(path_entities) != 0: 37 | paths.append(Path(oid, pid, path_entities)) 38 | return paths 39 | 40 | 41 | def store_path_file(paths, target_path): 42 | with open(target_path, 'w') as f: 43 | for path in paths: 44 | path_entities = path.path_entities 45 | f.write('#,{},{},{},{}\n'.format(path.pid, path.oid, 46 | path_entities[0].enter_time.isoformat(sep=' ', timespec='milliseconds'), 47 | path_entities[-1].leave_time.isoformat(sep=' ', timespec='milliseconds'))) 48 | for path_entity in path_entities: 49 | f.write('{},{},{}\n'.format(path_entity.enter_time.isoformat(sep=' ', timespec='milliseconds'), 50 | path_entity.leave_time.isoformat(sep=' ', timespec='milliseconds'), 51 | path_entity.eid)) 52 | -------------------------------------------------------------------------------- /common/road_network.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from rtree import Rtree 3 | from osgeo import ogr 4 | from .spatial_func import SPoint, distance 5 | from .mbr import MBR 6 | import copy 7 | 8 | 9 | class UndirRoadNetwork(nx.Graph): 10 | def __init__(self, g, edge_spatial_idx, edge_idx): 11 | super(UndirRoadNetwork, self).__init__(g) 12 | # entry: eid 13 | self.edge_spatial_idx = 
edge_spatial_idx 14 | # eid -> edge key (start_coord, end_coord) 15 | self.edge_idx = edge_idx 16 | 17 | def to_directed(self, as_view=False): 18 | """ 19 | Convert undirected road network to directed road network 20 | new edge will have new eid, and each original edge will have two edge with reversed coords 21 | :return: 22 | """ 23 | assert as_view is False, "as_view is not supported" 24 | avail_eid = max([eid for u, v, eid in self.edges.data(data='eid')]) + 1 25 | g = nx.DiGraph() 26 | edge_spatial_idx = Rtree() 27 | edge_idx = {} 28 | # add nodes 29 | for n, data in self.nodes(data=True): 30 | # when data=True, it means will data=node's attributes 31 | new_data = copy.deepcopy(data) 32 | g.add_node(n, **new_data) 33 | # add edges 34 | for u, v, data in self.edges(data=True): 35 | mbr = MBR.cal_mbr(data['coords']) 36 | # add forward edge 37 | forward_data = copy.deepcopy(data) 38 | g.add_edge(u, v, **forward_data) 39 | edge_spatial_idx.insert(forward_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 40 | edge_idx[forward_data['eid']] = (u, v) 41 | # add backward edge 42 | backward_data = copy.deepcopy(data) 43 | backward_data['eid'] = avail_eid 44 | avail_eid += 1 45 | backward_data['coords'].reverse() 46 | g.add_edge(v, u, **backward_data) 47 | edge_spatial_idx.insert(backward_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 48 | edge_idx[backward_data['eid']] = (v, u) 49 | print('# of nodes:{}'.format(g.number_of_nodes())) 50 | print('# of edges:{}'.format(g.number_of_edges())) 51 | return RoadNetwork(g, edge_spatial_idx, edge_idx) 52 | 53 | def range_query(self, mbr): 54 | """ 55 | spatial range query. Given a mbr, return a range of edges. 56 | :param mbr: query mbr 57 | :return: qualified edge keys 58 | """ 59 | eids = self.edge_spatial_idx.intersection((mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 60 | return [self.edge_idx[eid] for eid in eids] 61 | 62 | def remove_edge(self, u, v): 63 | edge_data = self[u][v] 64 | coords = edge_data['coords'] 65 | mbr = MBR.cal_mbr(coords) 66 | # delete self.edge_idx[eid] from edge index 67 | del self.edge_idx[edge_data['eid']] 68 | # delete from spatial index 69 | self.edge_spatial_idx.delete(edge_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 70 | # delete from graph 71 | super(UndirRoadNetwork, self).remove_edge(u, v) 72 | 73 | def add_edge(self, u_of_edge, v_of_edge, **attr): 74 | coords = attr['coords'] 75 | mbr = MBR.cal_mbr(coords) 76 | attr['length'] = sum([distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)]) 77 | # add edge to edge index 78 | self.edge_idx[attr['eid']] = (u_of_edge, v_of_edge) 79 | # add edge to spatial index 80 | self.edge_spatial_idx.insert(attr['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 81 | # add edge to graph 82 | super(UndirRoadNetwork, self).add_edge(u_of_edge, v_of_edge, **attr) 83 | 84 | 85 | class RoadNetwork(nx.DiGraph): 86 | def __init__(self, g, edge_spatial_idx, edge_idx): 87 | super(RoadNetwork, self).__init__(g) 88 | # entry: eid 89 | self.edge_spatial_idx = edge_spatial_idx 90 | # eid -> edge key (start_coord, end_coord) 91 | self.edge_idx = edge_idx 92 | 93 | def range_query(self, mbr): 94 | """ 95 | spatial range query 96 | :param mbr: query mbr 97 | :return: qualified edge keys 98 | """ 99 | eids = self.edge_spatial_idx.intersection((mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 100 | return [self.edge_idx[eid] for eid in eids] 101 | 102 | def remove_edge(self, u, v): 103 | edge_data = self[u][v] 
104 | coords = edge_data['coords'] 105 | mbr = MBR.cal_mbr(coords) 106 | # delete self.edge_idx[eifrom edge index 107 | del self.edge_idx[edge_data['eid']] 108 | # delete from spatial index 109 | self.edge_spatial_idx.delete(edge_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 110 | # delete from graph 111 | super(RoadNetwork, self).remove_edge(u, v) 112 | 113 | def add_edge(self, u_of_edge, v_of_edge, **attr): 114 | coords = attr['coords'] 115 | mbr = MBR.cal_mbr(coords) 116 | attr['length'] = sum([distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)]) 117 | # add edge to edge index 118 | self.edge_idx[attr['eid']] = (u_of_edge, v_of_edge) 119 | # add edge to spatial index 120 | self.edge_spatial_idx.insert(attr['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat)) 121 | # add edge to graph 122 | super(RoadNetwork, self).add_edge(u_of_edge, v_of_edge, **attr) 123 | 124 | 125 | def load_rn_shp(path, is_directed=True): 126 | edge_spatial_idx = Rtree() 127 | edge_idx = {} 128 | # node uses coordinate as key 129 | # edge uses coordinate tuple as key 130 | g = nx.read_shp(path, simplify=True, strict=False) 131 | if not is_directed: 132 | g = g.to_undirected() 133 | # node attrs: nid, pt, ... 134 | for n, data in g.nodes(data=True): 135 | data['pt'] = SPoint(n[1], n[0]) 136 | if 'ShpName' in data: 137 | del data['ShpName'] 138 | # edge attrs: eid, length, coords, ... 139 | for u, v, data in g.edges(data=True): 140 | geom_line = ogr.CreateGeometryFromWkb(data['Wkb']) 141 | coords = [] 142 | for i in range(geom_line.GetPointCount()): 143 | geom_pt = geom_line.GetPoint(i) 144 | coords.append(SPoint(geom_pt[1], geom_pt[0])) 145 | data['coords'] = coords 146 | data['length'] = sum([distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)]) 147 | env = geom_line.GetEnvelope() 148 | edge_spatial_idx.insert(data['eid'], (env[0], env[2], env[1], env[3])) 149 | edge_idx[data['eid']] = (u, v) 150 | del data['ShpName'] 151 | del data['Json'] 152 | del data['Wkt'] 153 | del data['Wkb'] 154 | print('# of nodes:{}'.format(g.number_of_nodes())) 155 | print('# of edges:{}'.format(g.number_of_edges())) 156 | if not is_directed: 157 | return UndirRoadNetwork(g, edge_spatial_idx, edge_idx) 158 | else: 159 | return RoadNetwork(g, edge_spatial_idx, edge_idx) 160 | 161 | 162 | def store_rn_shp(rn, target_path): 163 | print('# of nodes:{}'.format(rn.number_of_nodes())) 164 | print('# of edges:{}'.format(rn.number_of_edges())) 165 | for _, data in rn.nodes(data=True): 166 | if 'pt' in data: 167 | del data['pt'] 168 | for _, _, data in rn.edges(data=True): 169 | geo_line = ogr.Geometry(ogr.wkbLineString) 170 | for coord in data['coords']: 171 | geo_line.AddPoint(coord.lng, coord.lat) 172 | data['Wkb'] = geo_line.ExportToWkb() 173 | del data['coords'] 174 | if 'length' in data: 175 | del data['length'] 176 | if not rn.is_directed(): 177 | rn = rn.to_directed() 178 | nx.write_shp(rn, target_path) 179 | -------------------------------------------------------------------------------- /common/spatial_func.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | DEGREES_TO_RADIANS = math.pi / 180 4 | RADIANS_TO_DEGREES = 1 / DEGREES_TO_RADIANS 5 | EARTH_MEAN_RADIUS_METER = 6371008.7714 6 | DEG_TO_KM = DEGREES_TO_RADIANS * EARTH_MEAN_RADIUS_METER 7 | LAT_PER_METER = 8.993203677616966e-06 8 | LNG_PER_METER = 1.1700193970443768e-05 9 | 10 | 11 | class SPoint: 12 | def __init__(self, lat, lng): 13 | self.lat = lat 14 | self.lng 
= lng 15 | 16 | def __str__(self): 17 | return '({},{})'.format(self.lat, self.lng) 18 | 19 | def __repr__(self): 20 | return self.__str__() 21 | 22 | def __eq__(self, other): 23 | # equal. Orginally is compared with reference. Here we change to value 24 | return self.lat == other.lat and self.lng == other.lng 25 | 26 | def __ne__(self, other): 27 | # not equal 28 | return not self == other 29 | 30 | 31 | def same_coords(a, b): 32 | # we can directly use == since SPoint has updated __eq__() 33 | if a == b: 34 | return True 35 | else: 36 | return False 37 | 38 | 39 | def distance(a, b): 40 | """ 41 | Calculate haversine distance between two GPS points in meters 42 | Args: 43 | ----- 44 | a,b: SPoint class 45 | Returns: 46 | -------- 47 | d: float. haversine distance in meter 48 | """ 49 | if same_coords(a, b): 50 | return 0.0 51 | delta_lat = math.radians(b.lat - a.lat) 52 | delta_lng = math.radians(b.lng - a.lng) 53 | h = math.sin(delta_lat / 2.0) * math.sin(delta_lat / 2.0) + math.cos(math.radians(a.lat)) * math.cos( 54 | math.radians(b.lat)) * math.sin(delta_lng / 2.0) * math.sin(delta_lng / 2.0) 55 | c = 2.0 * math.atan2(math.sqrt(h), math.sqrt(1 - h)) 56 | d = EARTH_MEAN_RADIUS_METER * c 57 | return d 58 | 59 | 60 | # http://www.movable-type.co.uk/scripts/latlong.html 61 | def bearing(a, b): 62 | """ 63 | Calculate the bearing of ab 64 | """ 65 | pt_a_lat_rad = math.radians(a.lat) 66 | pt_a_lng_rad = math.radians(a.lng) 67 | pt_b_lat_rad = math.radians(b.lat) 68 | pt_b_lng_rad = math.radians(b.lng) 69 | y = math.sin(pt_b_lng_rad - pt_a_lng_rad) * math.cos(pt_b_lat_rad) 70 | x = math.cos(pt_a_lat_rad) * math.sin(pt_b_lat_rad) - math.sin(pt_a_lat_rad) * math.cos(pt_b_lat_rad) * math.cos(pt_b_lng_rad - pt_a_lng_rad) 71 | bearing_rad = math.atan2(y, x) 72 | return math.fmod(math.degrees(bearing_rad) + 360.0, 360.0) 73 | 74 | 75 | def cal_loc_along_line(a, b, rate): 76 | """ 77 | convert rate to gps location 78 | """ 79 | lat = a.lat + rate * (b.lat - a.lat) 80 | lng = a.lng + rate * (b.lng - a.lng) 81 | return SPoint(lat, lng) 82 | 83 | 84 | def project_pt_to_segment(a, b, t): 85 | """ 86 | Args: 87 | ----- 88 | a,b: start/end GPS location of a road segment 89 | t: raw point 90 | Returns: 91 | ------- 92 | project: projected GPS point on road segment 93 | rate: rate of projected point location to road segment 94 | dist: haversine_distance of raw and projected point 95 | """ 96 | ab_angle = bearing(a, b) 97 | at_angle = bearing(a, t) 98 | ab_length = distance(a, b) 99 | at_length = distance(a, t) 100 | delta_angle = at_angle - ab_angle 101 | meters_along = at_length * math.cos(math.radians(delta_angle)) 102 | if ab_length == 0.0: 103 | rate = 0.0 104 | else: 105 | rate = meters_along / ab_length 106 | if rate >= 1: 107 | projection = SPoint(b.lat, b.lng) 108 | rate = 1.0 109 | elif rate <= 0: 110 | projection = SPoint(a.lat, a.lng) 111 | rate = 0.0 112 | else: 113 | projection = cal_loc_along_line(a, b, rate) 114 | dist = distance(t, projection) 115 | return projection, rate, dist 116 | -------------------------------------------------------------------------------- /common/trajectory.py: -------------------------------------------------------------------------------- 1 | from .spatial_func import distance, SPoint, cal_loc_along_line 2 | from .mbr import MBR 3 | from datetime import timedelta 4 | 5 | 6 | class STPoint(SPoint): 7 | """ 8 | STPoint creates a data type for spatio-temporal point, i.e. STPoint(). 
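Example (hypothetical values): STPoint(39.9042, 116.4074, datetime(2015, 4, 1, 8, 30)) extends SPoint with a timestamp; the optional `data` dict can carry extra attributes such as the matched candidate point.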
9 | 10 | """ 11 | 12 | def __init__(self, lat, lng, time, data=None): 13 | super(STPoint, self).__init__(lat, lng) 14 | self.time = time 15 | self.data = data # contains edge's attributes 16 | 17 | def __str__(self): 18 | """ 19 | For easily reading the output 20 | """ 21 | # __repr__() to change the print review 22 | # st = STPoint() 23 | # print(st) will not be the reference but the following format 24 | # if __repr__ is changed to str format, __str__ will be automatically change. 25 | 26 | return str(self.__dict__) # key and value of self attributes 27 | # return '({}, {}, {})'.format(self.time.strftime('%Y/%m/%d %H:%M:%S'), self.lat, self.lng) 28 | 29 | 30 | class Trajectory: 31 | """ 32 | Trajectory creates a data type for trajectory, i.e. Trajectory() 33 | """ 34 | 35 | def __init__(self, oid, tid, pt_list): 36 | """ 37 | Args: 38 | ----- 39 | oid: 40 | str. human id 41 | tid: 42 | str. trajectory id, sometimes based on start and end time. see get_tid() 43 | pt_list: 44 | list of STPoint(lat, lng, time), containing the attributes of class STPoint 45 | """ 46 | self.oid = oid 47 | self.tid = tid 48 | self.pt_list = pt_list 49 | 50 | def get_duration(self): 51 | """ 52 | Get duration of a trajectory (pt_list) 53 | last_point.time - first_point.time 54 | seconds 55 | """ 56 | return (self.pt_list[-1].time - self.pt_list[0].time).total_seconds() 57 | 58 | def get_distance(self): 59 | """ 60 | Get geographical distance of a trajectory (pt_list) 61 | sum of two adjacent points 62 | meters 63 | """ 64 | dist = 0.0 65 | pre_pt = self.pt_list[0] 66 | for pt in self.pt_list[1:]: 67 | tmp_dist = distance(pre_pt, pt) 68 | dist += tmp_dist 69 | pre_pt = pt 70 | return dist 71 | 72 | def get_avg_time_interval(self): 73 | """ 74 | Calculate average time interval between two GPS points in one trajectory (pt_list) 75 | """ 76 | point_time_interval = [] 77 | # How clever method! 
zip to get time interval 78 | 79 | for pre, cur in zip(self.pt_list[:-1], self.pt_list[1:]): 80 | point_time_interval.append((cur.time - pre.time).total_seconds()) 81 | return sum(point_time_interval) / len(point_time_interval) 82 | 83 | def get_avg_distance_interval(self): 84 | """ 85 | Calculate average distance interval between two GPS points in one trajectory (pt_list) 86 | """ 87 | point_dist_interval = [] 88 | for pre, cur in zip(self.pt_list[:-1], self.pt_list[1:]): 89 | point_dist_interval.append(distance(pre, cur)) 90 | return sum(point_dist_interval) / len(point_dist_interval) 91 | 92 | def get_mbr(self): 93 | return MBR.cal_mbr(self.pt_list) 94 | 95 | def get_start_time(self): 96 | return self.pt_list[0].time 97 | 98 | def get_end_time(self): 99 | return self.pt_list[-1].time 100 | 101 | def get_mid_time(self): 102 | return self.pt_list[0].time + (self.pt_list[-1].time - self.pt_list[0].time) / 2.0 103 | 104 | def get_centroid(self): 105 | """ 106 | Get centroid SPoint 107 | """ 108 | mean_lat = 0.0 109 | mean_lng = 0.0 110 | for pt in self.pt_list: 111 | mean_lat += pt.lat 112 | mean_lng += pt.lng 113 | mean_lat /= len(self.pt_list) 114 | mean_lng /= len(self.pt_list) 115 | return SPoint(mean_lat, mean_lng) 116 | 117 | def query_trajectory_by_temporal_range(self, start_time, end_time): 118 | # start_time <= pt.time < end_time 119 | traj_start_time = self.get_start_time() 120 | traj_end_time = self.get_end_time() 121 | if start_time > traj_end_time: 122 | return None 123 | if end_time <= traj_start_time: 124 | return None 125 | st = max(traj_start_time, start_time) 126 | et = min(traj_end_time + timedelta(seconds=1), end_time) 127 | start_idx = self.binary_search_idx(st) # pt_list[start_idx].time <= st < pt_list[start_idx+1].time 128 | if self.pt_list[start_idx].time < st: 129 | # then the start_idx is out of the range, we need to increase it 130 | start_idx += 1 131 | end_idx = self.binary_search_idx(et) # pt_list[end_idx].time <= et < pt_list[end_idx+1].time 132 | if self.pt_list[end_idx].time < et: 133 | # then the end_idx is acceptable 134 | end_idx += 1 135 | sub_pt_list = self.pt_list[start_idx:end_idx] 136 | return Trajectory(self.oid, get_tid(self.oid, sub_pt_list), sub_pt_list) 137 | 138 | def binary_search_idx(self, time): 139 | # self.pt_list[idx].time <= time < self.pt_list[idx+1].time 140 | # if time < self.pt_list[0].time, return -1 141 | # if time >= self.pt_list[len(self.pt_list)-1].time, return len(self.pt_list)-1 142 | nb_pts = len(self.pt_list) 143 | if time < self.pt_list[0].time: 144 | return -1 145 | if time >= self.pt_list[-1].time: 146 | return nb_pts - 1 147 | # the time is in the middle 148 | left_idx = 0 149 | right_idx = nb_pts - 1 150 | while left_idx <= right_idx: 151 | mid_idx = int((left_idx + right_idx) / 2) 152 | if mid_idx < nb_pts - 1 and self.pt_list[mid_idx].time <= time < self.pt_list[mid_idx + 1].time: 153 | return mid_idx 154 | elif self.pt_list[mid_idx].time < time: 155 | left_idx = mid_idx + 1 156 | else: 157 | right_idx = mid_idx - 1 158 | 159 | def query_location_by_timestamp(self, time): 160 | idx = self.binary_search_idx(time) 161 | if idx == -1 or idx == len(self.pt_list) - 1: 162 | return None 163 | if self.pt_list[idx].time == time or (self.pt_list[idx + 1].time - self.pt_list[idx].time).total_seconds() == 0: 164 | return SPoint(self.pt_list[idx].lat, self.pt_list[idx].lng) 165 | else: 166 | # interpolate location 167 | dist_ab = distance(self.pt_list[idx], self.pt_list[idx + 1]) 168 | if dist_ab == 0: 169 | return 
SPoint(self.pt_list[idx].lat, self.pt_list[idx].lng) 170 | dist_traveled = dist_ab * (time - self.pt_list[idx].time).total_seconds() / \ 171 | (self.pt_list[idx + 1].time - self.pt_list[idx].time).total_seconds() 172 | return cal_loc_along_line(self.pt_list[idx], self.pt_list[idx + 1], dist_traveled / dist_ab) 173 | 174 | def to_wkt(self): 175 | wkt = 'LINESTRING (' 176 | for pt in self.pt_list: 177 | wkt += '{} {}, '.format(pt.lng, pt.lat) 178 | wkt = wkt[:-2] + ')' 179 | return wkt 180 | 181 | def __hash__(self): 182 | return hash(self.oid + '_' + self.pt_list[0].time.strftime('%Y%m%d%H%M%S') + '_' + 183 | self.pt_list[-1].time.strftime('%Y%m%d%H%M%S')) 184 | 185 | def __eq__(self, other): 186 | return hash(self) == hash(other) 187 | 188 | def __repr__(self): 189 | return f'Trajectory(oid={self.oid},tid={self.tid})' 190 | 191 | 192 | def get_tid(oid, pt_list): 193 | return oid + '_' + pt_list[0].time.strftime('%Y%m%d%H%M%S') + '_' + pt_list[-1].time.strftime('%Y%m%d%H%M%S') 194 | 195 | # remove parse and store trajectory functions. The related functions are in utils.parse_traj, utils.store_traj 196 | -------------------------------------------------------------------------------- /data/map/extra_info/poi20.csv: -------------------------------------------------------------------------------- 1 | ,,company,shopping,food,house,viewpoint 2 | 173,73,1,1,0,0,0 3 | 215,157,1,0,1,0,0 4 | 87,154,1,1,2,0,0 5 | 181,199,1,2,0,1,0 6 | 132,141,0,2,4,0,2 7 | 68,228,0,1,0,2,0 8 | 125,82,0,0,0,1,0 9 | 127,159,0,2,0,0,0 10 | 220,283,0,1,0,0,0 11 | 42,141,3,0,0,1,0 12 | 158,205,0,26,0,0,0 13 | 220,182,5,0,0,2,0 14 | 49,239,0,2,1,0,0 15 | 111,45,1,3,1,0,0 16 | 119,334,0,0,0,1,0 17 | 25,248,0,0,0,1,0 18 | 155,199,0,0,0,1,0 19 | 172,289,0,0,1,0,0 20 | 143,24,0,1,1,0,0 21 | -------------------------------------------------------------------------------- /data/map/extra_info/weather_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/extra_info/weather_dict.pkl -------------------------------------------------------------------------------- /data/map/road_network/edges.dbf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/edges.dbf -------------------------------------------------------------------------------- /data/map/road_network/edges.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/edges.shp -------------------------------------------------------------------------------- /data/map/road_network/edges.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/edges.shx -------------------------------------------------------------------------------- /data/map/road_network/nodes.dbf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/nodes.dbf -------------------------------------------------------------------------------- /data/map/road_network/nodes.shp: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/nodes.shp -------------------------------------------------------------------------------- /data/map/road_network/nodes.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/nodes.shx -------------------------------------------------------------------------------- /doc/KDD21_MTrajRec_Huimin.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/doc/KDD21_MTrajRec_Huimin.pdf -------------------------------------------------------------------------------- /img/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/img/intro.png -------------------------------------------------------------------------------- /img/model_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/img/model_framework.png -------------------------------------------------------------------------------- /map_matching/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/9/3 13:34 -------------------------------------------------------------------------------- /map_matching/candidate_point.py: -------------------------------------------------------------------------------- 1 | from common.spatial_func import SPoint, LAT_PER_METER, LNG_PER_METER, project_pt_to_segment, distance 2 | from common.mbr import MBR 3 | 4 | 5 | class CandidatePoint(SPoint): 6 | def __init__(self, lat, lng, eid, error, offset, rate): 7 | super(CandidatePoint, self).__init__(lat, lng) 8 | self.eid = eid 9 | self.error = error 10 | self.offset = offset 11 | self.rate = rate 12 | 13 | def __str__(self): 14 | return '{},{},{},{},{},{}'.format(self.eid, self.lat, self.lng, self.error, self.offset, self.rate) 15 | 16 | def __repr__(self): 17 | return '{},{},{},{},{},{}'.format(self.eid, self.lat, self.lng, self.error, self.offset, self.rate) 18 | 19 | def __hash__(self): 20 | return hash(self.__str__()) 21 | 22 | 23 | def get_candidates(pt, rn, search_dist): 24 | """ 25 | Args: 26 | ----- 27 | pt: point STPoint() 28 | rn: road network 29 | search_dist: in meter. a parameter for HMM_mm. range of pt's potential road 30 | Returns: 31 | -------- 32 | candidates: list of potential projected points. 
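None is returned if no edge lies within search_dist of pt, or if every projection error exceeds search_dist.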
33 | """ 34 | candidates = None 35 | mbr = MBR(pt.lat - search_dist * LAT_PER_METER, 36 | pt.lng - search_dist * LNG_PER_METER, 37 | pt.lat + search_dist * LAT_PER_METER, 38 | pt.lng + search_dist * LNG_PER_METER) 39 | candidate_edges = rn.range_query(mbr) # list of edges (two nodes/points) 40 | if len(candidate_edges) > 0: 41 | candi_pt_list = [cal_candidate_point(pt, rn, candidate_edge) for candidate_edge in candidate_edges] 42 | # refinement 43 | candi_pt_list = [candi_pt for candi_pt in candi_pt_list if candi_pt.error <= search_dist] 44 | if len(candi_pt_list) > 0: 45 | candidates = candi_pt_list 46 | return candidates 47 | 48 | 49 | def cal_candidate_point(raw_pt, rn, edge): 50 | """ 51 | Get attributes of candidate point 52 | """ 53 | u, v = edge 54 | coords = rn[u][v]['coords'] # GPS points in road segment, may be larger than 2 55 | candidates = [project_pt_to_segment(coords[i], coords[i + 1], raw_pt) for i in range(len(coords) - 1)] 56 | idx, (projection, coor_rate, dist) = min(enumerate(candidates), key=lambda x: x[1][2]) 57 | # enumerate return idx and (), x[1] --> () x[1][2] --> dist. get smallest error project edge 58 | offset = 0.0 59 | for i in range(idx): 60 | offset += distance(coords[i], coords[i + 1]) # make the road distance more accurately 61 | offset += distance(coords[idx], projection) # distance of road start position and projected point 62 | if rn[u][v]['length'] == 0: 63 | rate = 0 64 | # print(u, v) 65 | else: 66 | rate = offset/rn[u][v]['length'] # rate of whole road, coor_rate is the rate of coords. 67 | return CandidatePoint(projection.lat, projection.lng, rn[u][v]['eid'], dist, offset, rate) 68 | 69 | -------------------------------------------------------------------------------- /map_matching/hmm/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/9/3 13:34 -------------------------------------------------------------------------------- /map_matching/hmm/hmm_map_matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on Newson, Paul, and John Krumm. "Hidden Markov map matching through 3 | noise and sparseness." Proceedings of the 17th ACM SIGSPATIAL International 4 | Conference on Advances in Geographic Information Systems. ACM, 2009. 5 | This is a Python translation from https://github.com/graphhopper/map-matching/tree/master/hmm-lib 6 | """ 7 | 8 | from ..hmm.hmm_probabilities import HMMProbabilities 9 | from ..hmm.ti_viterbi import ViterbiAlgorithm, SequenceState 10 | from ..map_matcher import MapMatcher 11 | from ..candidate_point import get_candidates 12 | from common.spatial_func import distance 13 | from common.trajectory import STPoint, Trajectory 14 | from ..utils import find_shortest_path 15 | from ..route_constructor import construct_path 16 | 17 | 18 | class TimeStep: 19 | """ 20 | Contains everything the hmm-lib needs to process a new time step including emission and observation probabilities. 
21 | emission probability: p(z|r), the likelihood that the measurement z would be observed if the vehicle were actually 22 | on road segment r 23 | """ 24 | def __init__(self, observation, candidates): 25 | if observation is None or candidates is None: 26 | raise Exception('observation and candidates must not be null.') 27 | self.observation = observation 28 | self.candidates = candidates 29 | self.emission_log_probabilities = {} 30 | self.transition_log_probabilities = {} 31 | # transition -> dist 32 | self.road_paths = {} 33 | 34 | def add_emission_log_probability(self, candidate, emission_log_probability): 35 | if candidate in self.emission_log_probabilities: 36 | raise Exception('Candidate has already been added.') 37 | self.emission_log_probabilities[candidate] = emission_log_probability 38 | 39 | def add_transition_log_probability(self, from_position, to_position, transition_log_probability): 40 | transition = (from_position, to_position) 41 | if transition in self.transition_log_probabilities: 42 | raise Exception('Transition has already been added.') 43 | self.transition_log_probabilities[transition] = transition_log_probability 44 | 45 | def add_road_path(self, from_position, to_position, road_path): 46 | transition = (from_position, to_position) 47 | if transition in self.road_paths: 48 | raise Exception('Transition has already been added.') 49 | self.road_paths[transition] = road_path 50 | 51 | 52 | class TIHMMMapMatcher(MapMatcher): 53 | def __init__(self, rn, search_dis=50, sigma=5.0, beta=2.0, routing_weight='length', debug=False): 54 | self.measurement_error_sigma = search_dis # search_dist, original paper = 200m. 4/50m ? 55 | self.transition_probability_beta = beta # beta is a parameter in equation(2). 56 | self.guassian_sigma = sigma 57 | # A larger beta measures the difference between great circle distances and route distances, 58 | # represents more tolerance of non-direct routes. 59 | self.debug = debug 60 | super(TIHMMMapMatcher, self).__init__(rn, routing_weight) 61 | 62 | # our implementation, no candidates or no transition will be set to None, and start a new matching 63 | def match(self, traj): 64 | """ Given original traj, return map-matched trajectory""" 65 | seq = self.compute_viterbi_sequence(traj.pt_list) 66 | assert len(traj.pt_list) == len(seq), 'pt_list and seq must have the same size' 67 | mm_pt_list = [] 68 | for ss in seq: 69 | candi_pt = None 70 | if ss.state is not None: 71 | candi_pt = ss.state 72 | data = {'candi_pt': candi_pt} 73 | mm_pt_list.append(STPoint(ss.observation.lat, ss.observation.lng, ss.observation.time, data)) 74 | mm_traj = Trajectory(traj.oid, traj.tid, mm_pt_list) 75 | return mm_traj 76 | 77 | def match_to_path(self, traj): 78 | mm_traj = self.match(traj) 79 | path = construct_path(self.rn, mm_traj, self.routing_weight) 80 | return path 81 | 82 | def create_time_step(self, pt): 83 | time_step = None 84 | candidates = get_candidates(pt, self.rn, self.measurement_error_sigma) 85 | if candidates is not None: 86 | time_step = TimeStep(pt, candidates) 87 | return time_step 88 | 89 | def compute_viterbi_sequence(self, pt_list): 90 | """ 91 | Args: 92 | ----- 93 | pt_list: observation pt_list 94 | Returns: 95 | ------- 96 | seq: ? 
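a list of SequenceState with the same length as pt_list; the state field holds the matched CandidatePoint, or None for points without a valid match (e.g. no candidate road within range).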
97 | """ 98 | seq = [] 99 | probabilities = HMMProbabilities(self.guassian_sigma, self.transition_probability_beta) 100 | viterbi = ViterbiAlgorithm(keep_message_history=self.debug) 101 | prev_time_step = None 102 | idx = 0 103 | nb_points = len(pt_list) 104 | while idx < nb_points: 105 | time_step = self.create_time_step(pt_list[idx]) 106 | # construct the sequence ended at t-1, and skip current point (no candidate error) 107 | if time_step is None: 108 | seq.extend(viterbi.compute_most_likely_sequence()) 109 | seq.append(SequenceState(None, pt_list[idx], None)) 110 | viterbi = ViterbiAlgorithm(keep_message_history=self.debug) 111 | prev_time_step = None 112 | else: 113 | self.compute_emission_probabilities(time_step, probabilities) 114 | if prev_time_step is None: 115 | viterbi.start_with_initial_observation(time_step.observation, time_step.candidates, 116 | time_step.emission_log_probabilities) 117 | else: 118 | self.compute_transition_probabilities(prev_time_step, time_step, probabilities) 119 | viterbi.next_step(time_step.observation, time_step.candidates, time_step.emission_log_probabilities, 120 | time_step.transition_log_probabilities, time_step.road_paths) 121 | if viterbi.is_broken: 122 | # construct the sequence ended at t-1, and start a new matching at t (no transition error) 123 | seq.extend(viterbi.compute_most_likely_sequence()) 124 | viterbi = ViterbiAlgorithm(keep_message_history=self.debug) 125 | viterbi.start_with_initial_observation(time_step.observation, time_step.candidates, 126 | time_step.emission_log_probabilities) 127 | prev_time_step = time_step 128 | idx += 1 129 | if len(seq) < nb_points: 130 | seq.extend(viterbi.compute_most_likely_sequence()) 131 | return seq 132 | 133 | def compute_emission_probabilities(self, time_step, probabilities): 134 | for candi_pt in time_step.candidates: 135 | dist = candi_pt.error 136 | time_step.add_emission_log_probability(candi_pt, probabilities.emission_log_probability(dist)) 137 | 138 | def compute_transition_probabilities(self, prev_time_step, time_step, probabilities): 139 | linear_dist = distance(prev_time_step.observation, time_step.observation) 140 | for prev_candi_pt in prev_time_step.candidates: 141 | for cur_candi_pt in time_step.candidates: 142 | path_dist, path = find_shortest_path(self.rn, prev_candi_pt, cur_candi_pt, self.routing_weight) 143 | # invalid transition has no transition probability 144 | if path is not None: 145 | time_step.add_road_path(prev_candi_pt, cur_candi_pt, path) 146 | time_step.add_transition_log_probability(prev_candi_pt, cur_candi_pt, 147 | probabilities.transition_log_probability(path_dist, 148 | linear_dist)) 149 | -------------------------------------------------------------------------------- /map_matching/hmm/hmm_probabilities.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class HMMProbabilities: 5 | def __init__(self, sigma, beta): 6 | self.sigma = sigma 7 | self.beta = beta 8 | 9 | def emission_log_probability(self, distance): 10 | """ 11 | Returns the logarithmic emission probability density. 12 | :param distance: Absolute distance [m] between GPS measurement and map matching candidate. 13 | :return: 14 | """ 15 | return log_normal_distribution(self.sigma, distance) 16 | 17 | def transition_log_probability(self, route_length, linear_distance): 18 | """ 19 | Returns the logarithmic transition probability density for the given transition parameters. 
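Computed as the log-density of an exponential distribution over the absolute difference between route and great-circle distance, i.e. log(1/beta) - |linear_distance - route_length| / beta.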
20 | :param route_length: Length of the shortest route [m] between two consecutive map matching candidates. 21 | :param linear_distance: Linear distance [m] between two consecutive GPS measurements. 22 | :return: 23 | """ 24 | transition_metric = math.fabs(linear_distance - route_length) 25 | return log_exponential_distribution(self.beta, transition_metric) 26 | 27 | 28 | def log_normal_distribution(sigma, x): 29 | return math.log(1.0 / (math.sqrt(2.0 * math.pi) * sigma)) + (-0.5 * pow(x / sigma, 2)) 30 | 31 | 32 | def log_exponential_distribution(beta, x): 33 | return math.log(1.0 / beta) - (x / beta) 34 | -------------------------------------------------------------------------------- /map_matching/hmm/ti_viterbi.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the Viterbi algorithm for time-inhomogeneous Markov processes, 3 | meaning that the set of states and state transition probabilities are not necessarily fixed for all time steps. 4 | For long observation sequences, back pointers usually converge to a single path after a 5 | certain number of time steps. For instance, when matching GPS coordinates to roads, the last 6 | GPS positions in the trace usually do not affect the first road matches anymore. 7 | This implementation exploits this fact by letting the Java garbage collector 8 | take care of unreachable back pointers. If back pointers converge to a single path after a 9 | constant number of time steps, only O(t) back pointers and transition descriptors need to be stored in memory. 10 | """ 11 | 12 | 13 | class ExtendedState: 14 | """ 15 | Back pointer to previous state candidate in the most likely sequence. 16 | """ 17 | def __init__(self, state, back_pointer, observation, transition_descriptor): 18 | self.state = state 19 | self.back_pointer = back_pointer 20 | self.observation = observation 21 | self.transition_descriptor = transition_descriptor 22 | 23 | 24 | class SequenceState: 25 | def __init__(self, state, observation, transition_descriptor): 26 | self.state = state 27 | self.observation = observation 28 | self.transition_descriptor = transition_descriptor 29 | 30 | 31 | class ForwardStepResult: 32 | def __init__(self): 33 | """ 34 | Includes back pointers to previous state candidates for retrieving the most likely sequence after the forward pass. 35 | :param nb_states: 36 | """ 37 | self.new_message = {} 38 | self.new_extended_states = {} 39 | 40 | 41 | class ViterbiAlgorithm: 42 | def __init__(self, keep_message_history=False): 43 | # Allows to retrieve the most likely sequence using back pointers. 44 | self.last_extended_states = None 45 | self.prev_candidates = [] 46 | # For each state s_t of the current time step t, message.get(s_t) contains the log 47 | # probability of the most likely sequence ending in state s_t with given observations o_1, ..., o_t. 48 | # 49 | # Formally, this is max log p(s_1, ..., s_t, o_1, ..., o_t) w.r.t. s_1, ..., s_{t-1}. 50 | # Note that to compute the most likely state sequence, it is sufficient and more 51 | # efficient to compute in each time step the joint probability of states and observations 52 | # instead of computing the conditional probability of states given the observations. 
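# As a recurrence (implemented in forward_step below):
#   message[s_t] = max_{s_{t-1}} ( message[s_{t-1}] + transition_log_prob(s_{t-1}, s_t) ) + emission_log_prob(o_t | s_t)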
53 | # state -> float 54 | self.message = None 55 | self.is_broken = False 56 | # list of message 57 | self.message_history = None 58 | if keep_message_history: 59 | self.message_history = [] 60 | 61 | def initialize_state_probabilities(self, observation, candidates, initial_log_probabilities): 62 | """ 63 | Use only if HMM only starts with first observation. 64 | :param observation: 65 | :param candidates: 66 | :param initial_log_probabilities: 67 | :return: 68 | """ 69 | if self.message is not None: 70 | raise Exception('Initial probabilities have already been set.') 71 | # Set initial log probability for each start state candidate based on first observation. 72 | # Do not assign initial_log_probabilities directly to message to not rely on its iteration order. 73 | initial_message = {} 74 | for candidate in candidates: 75 | if candidate not in initial_log_probabilities: 76 | raise Exception('No initial probability for {}'.format(candidate)) 77 | log_probability = initial_log_probabilities[candidate] 78 | initial_message[candidate] = log_probability 79 | self.is_broken = self.hmm_break(initial_message) 80 | if self.is_broken: 81 | return 82 | self.message = initial_message 83 | if self.message_history is not None: 84 | self.message_history.append(self.message) 85 | self.last_extended_states = {} 86 | for candidate in candidates: 87 | self.last_extended_states[candidate] = ExtendedState(candidate, None, observation, None) 88 | self.prev_candidates = [candidate for candidate in candidates] 89 | 90 | def hmm_break(self, message): 91 | """ 92 | Returns whether the specified message is either empty or only contains state candidates with zero probability and thus causes the HMM to break. 93 | :return: 94 | """ 95 | for log_probability in message.values(): 96 | if log_probability != float('-inf'): 97 | return False 98 | return True 99 | 100 | def forward_step(self, observation, prev_candidates, cur_candidates, message, emission_log_probabilities, 101 | transition_log_probabilities, transition_descriptors=None): 102 | result = ForwardStepResult() 103 | assert len(prev_candidates) != 0 104 | 105 | for cur_state in cur_candidates: 106 | max_log_probability = float('-inf') 107 | max_prev_state = None 108 | for prev_state in prev_candidates: 109 | log_probability = message[prev_state] + self.transition_log_probability(prev_state, cur_state, 110 | transition_log_probabilities) 111 | if log_probability > max_log_probability: 112 | max_log_probability = log_probability 113 | max_prev_state = prev_state 114 | # throws KeyError if cur_state not in the emission_log_probabilities 115 | result.new_message[cur_state] = max_log_probability + emission_log_probabilities[cur_state] 116 | # Note that max_prev_state == None if there is no transition with non-zero probability. 117 | # In this case cur_state has zero probability and will not be part of the most likely 118 | # sequence, so we don't need an ExtendedState. 
119 | if max_prev_state is not None: 120 | transition = (max_prev_state, cur_state) 121 | if transition_descriptors is not None: 122 | transition_descriptor = transition_descriptors[transition] 123 | else: 124 | transition_descriptor = None 125 | extended_state = ExtendedState(cur_state, self.last_extended_states[max_prev_state], observation, 126 | transition_descriptor) 127 | result.new_extended_states[cur_state] = extended_state 128 | return result 129 | 130 | def transition_log_probability(self, prev_state, cur_state, transition_log_probabilities): 131 | transition = (prev_state, cur_state) 132 | if transition not in transition_log_probabilities: 133 | return float('-inf') 134 | else: 135 | return transition_log_probabilities[transition] 136 | 137 | def most_likely_state(self): 138 | # Retrieves the first state of the current forward message with maximum probability. 139 | assert len(self.message) != 0 140 | 141 | result = None 142 | max_log_probability = float('-inf') 143 | for state in self.message: 144 | if self.message[state] > max_log_probability: 145 | result = state 146 | max_log_probability = self.message[state] 147 | # Otherwise an HMM break would have occurred. 148 | assert result is not None 149 | return result 150 | 151 | def retrieve_most_likely_sequence(self): 152 | # Otherwise an HMM break would have occurred and message would be null. 153 | assert len(self.message) != 0 154 | 155 | last_state = self.most_likely_state() 156 | # Retrieve most likely state sequence in reverse order 157 | result = [] 158 | es = self.last_extended_states[last_state] 159 | while es is not None: 160 | ss = SequenceState(es.state, es.observation, es.transition_descriptor) 161 | result.append(ss) 162 | es = es.back_pointer 163 | result.reverse() 164 | return result 165 | 166 | def start_with_initial_observation(self, observation, candidates, emission_log_probabilities): 167 | """ 168 | Lets the HMM computation start at the given first observation and uses the given emission 169 | probabilities as the initial state probability for each starting state s. 170 | :param observation: 171 | :param candidates: 172 | :param emission_log_probabilities: 173 | :return: 174 | """ 175 | self.initialize_state_probabilities(observation, candidates, emission_log_probabilities) 176 | 177 | def next_step(self, observation, candidates, emission_log_probabilities, transition_log_probabilities, transition_descriptors=None): 178 | if self.message is None: 179 | raise Exception('start_with_initial_observation() must be called first.') 180 | if self.is_broken: 181 | raise Exception('Method must not be called after an HMM break.') 182 | forward_step_result = self.forward_step(observation, self.prev_candidates, candidates, self.message, 183 | emission_log_probabilities, transition_log_probabilities, transition_descriptors) 184 | self.is_broken = self.hmm_break(forward_step_result.new_message) 185 | if self.is_broken: 186 | return 187 | if self.message_history is not None: 188 | self.message_history.append(forward_step_result.new_message) 189 | self.message = forward_step_result.new_message 190 | self.last_extended_states = forward_step_result.new_extended_states 191 | self.prev_candidates = [candidate for candidate in candidates] 192 | 193 | def compute_most_likely_sequence(self): 194 | """ 195 | Returns the most likely sequence of states for all time steps. This includes the initial 196 | states / initial observation time step. 
If an HMM break occurred in the last time step t, 197 | then the most likely sequence up to t-1 is returned. See also {@link #isBroken()}. 198 | Formally, the most likely sequence is argmax p([s_0,] s_1, ..., s_T | o_1, ..., o_T) 199 | with respect to s_1, ..., s_T, where s_t is a state candidate at time step t, 200 | o_t is the observation at time step t and T is the number of time steps. 201 | :return: 202 | """ 203 | if self.message is None: 204 | # Return empty most likely sequence if there are no time steps or if initial observations caused an HMM break. 205 | return [] 206 | else: 207 | return self.retrieve_most_likely_sequence() 208 | -------------------------------------------------------------------------------- /map_matching/map_matcher.py: -------------------------------------------------------------------------------- 1 | class MapMatcher: 2 | def __init__(self, rn, routing_weight='length'): 3 | self.rn = rn 4 | self.routing_weight = routing_weight 5 | 6 | def match(self, traj): 7 | pass 8 | 9 | def match_to_path(self, traj): 10 | pass 11 | -------------------------------------------------------------------------------- /map_matching/route_constructor.py: -------------------------------------------------------------------------------- 1 | from .utils import find_shortest_path 2 | from datetime import timedelta 3 | import networkx as nx 4 | from common.path import PathEntity, Path 5 | 6 | 7 | def construct_path(rn, mm_traj, routing_weight): 8 | """ 9 | construct the path of the map matched trajectory 10 | Note: the enter time of the first path entity & the leave time of the last path entity is not accurate 11 | :param rn: the road network 12 | :param mm_traj: the map matched trajectory 13 | :param routing_weight: the attribute name to find the routing weight 14 | :return: a list of paths (Note: in case that the route is broken) 15 | """ 16 | paths = [] 17 | path = [] 18 | mm_pt_list = mm_traj.pt_list 19 | start_idx = len(mm_pt_list) - 1 20 | # find the first matched point 21 | for i in range(len(mm_pt_list)): 22 | if mm_pt_list[i].data['candi_pt'] is not None: 23 | start_idx = i 24 | break 25 | pre_edge_enter_time = mm_pt_list[start_idx].time 26 | for i in range(start_idx + 1, len(mm_pt_list)): 27 | pre_mm_pt = mm_pt_list[i-1] 28 | cur_mm_pt = mm_pt_list[i] 29 | # unmatched -> matched 30 | if pre_mm_pt.data['candi_pt'] is None: 31 | pre_edge_enter_time = cur_mm_pt.time 32 | continue 33 | # matched -> unmatched 34 | pre_candi_pt = pre_mm_pt.data['candi_pt'] 35 | if cur_mm_pt.data['candi_pt'] is None: 36 | path.append(PathEntity(pre_edge_enter_time, pre_mm_pt.time, pre_candi_pt.eid)) 37 | if len(path) > 2: 38 | paths.append(Path(mm_traj.oid, get_pid(mm_traj.oid, path), path)) 39 | path = [] 40 | continue 41 | # matched -> matched 42 | cur_candi_pt = cur_mm_pt.data['candi_pt'] 43 | # if consecutive points are on the same road, cur_mm_pt doesn't bring new information 44 | if pre_candi_pt.eid != cur_candi_pt.eid: 45 | weight_p, p = find_shortest_path(rn, pre_candi_pt, cur_candi_pt, routing_weight) 46 | # cannot connect 47 | if p is None: 48 | path.append(PathEntity(pre_edge_enter_time, pre_mm_pt.time, pre_candi_pt.eid)) 49 | if len(path) > 2: 50 | paths.append(Path(mm_traj.oid, get_pid(mm_traj.oid, path), path)) 51 | path = [] 52 | pre_edge_enter_time = cur_mm_pt.time 53 | continue 54 | # can connect 55 | if nx.is_directed(rn): 56 | dist_to_p_entrance = rn.edges[rn.edge_idx[pre_candi_pt.eid]]['length'] - pre_candi_pt.offset 57 | dist_to_p_exit = cur_candi_pt.offset 58 | else: 59 | 
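# Undirected case: an edge may be traversed in either direction, so compare the
# shortest path's entrance/exit vertices with each matched edge's stored coords
# to decide from which end of the edge the offset should be measured.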
entrance_vertex = p[0]  # vertex through which the connecting path enters
60 |                 pre_edge_coords = rn.edges[rn.edge_idx[pre_candi_pt.eid]]['coords']
61 |                 if (pre_edge_coords[0].lng, pre_edge_coords[0].lat) == entrance_vertex:
62 |                     dist_to_p_entrance = pre_candi_pt.offset
63 |                 else:
64 |                     dist_to_p_entrance = rn.edges[rn.edge_idx[pre_candi_pt.eid]]['length'] - pre_candi_pt.offset
65 |                 exit_vertex = p[-1]  # vertex through which the connecting path exits
66 |                 cur_edge_coords = rn.edges[rn.edge_idx[cur_candi_pt.eid]]['coords']
67 |                 if (cur_edge_coords[0].lng, cur_edge_coords[0].lat) == exit_vertex:
68 |                     dist_to_p_exit = cur_candi_pt.offset
69 |                 else:
70 |                     dist_to_p_exit = rn.edges[rn.edge_idx[cur_candi_pt.eid]]['length'] - cur_candi_pt.offset
71 |             if routing_weight == 'length':
72 |                 total_dist = weight_p
73 |             else:
74 |                 dist_inner = 0.0
75 |                 for i in range(len(p) - 1):
76 |                     start, end = p[i], p[i + 1]
77 |                     dist_inner += rn[start][end]['length']
78 |                 total_dist = dist_inner + dist_to_p_entrance + dist_to_p_exit
79 |             delta_time = (cur_mm_pt.time - pre_mm_pt.time).total_seconds()
80 |             # two consecutive points matched to the same vertex
81 |             if total_dist == 0:
82 |                 pre_edge_leave_time = cur_mm_pt.time
83 |                 path.append(PathEntity(pre_edge_enter_time, pre_edge_leave_time, pre_candi_pt.eid))
84 |                 cur_edge_enter_time = cur_mm_pt.time
85 |             else:
86 |                 pre_edge_leave_time = pre_mm_pt.time + timedelta(seconds=delta_time*(dist_to_p_entrance/total_dist))
87 |                 path.append(PathEntity(pre_edge_enter_time, pre_edge_leave_time, pre_candi_pt.eid))
88 |                 cur_edge_enter_time = cur_mm_pt.time - timedelta(seconds=delta_time * (dist_to_p_exit / total_dist))
89 |                 sub_path = linear_interpolate_path(p, total_dist - dist_to_p_entrance - dist_to_p_exit,
90 |                                                    rn, pre_edge_leave_time, cur_edge_enter_time)
91 |                 path.extend(sub_path)
92 |             pre_edge_enter_time = cur_edge_enter_time
93 |     # handle the last matched point similarly to (matched -> unmatched)
94 |     if mm_pt_list[-1].data['candi_pt'] is not None:
95 |         path.append(PathEntity(pre_edge_enter_time, mm_pt_list[-1].time, mm_pt_list[-1].data['candi_pt'].eid))
96 |     if len(path) > 2:
97 |         paths.append(Path(mm_traj.oid, get_pid(mm_traj.oid, path), path))
98 |     return paths
99 | 
100 | 
101 | def linear_interpolate_path(p, dist_inner, rn, enter_time, leave_time):
102 |     path = []
103 |     edges = []
104 |     for i in range(len(p)-1):
105 |         edges.append((p[i], p[i+1]))
106 |     delta_time = (leave_time - enter_time).total_seconds()
107 |     edge_enter_time = enter_time
108 |     for i in range(len(edges)):
109 |         edge_data = rn.edges[edges[i]]
110 |         if i == len(edges) - 1:
111 |             # force the leave time of the last connecting edge to equal the path leave time,
112 |             # which may otherwise drift due to floating-point accumulation
113 |             edge_leave_time = leave_time
114 |         else:
115 |             edge_leave_time = edge_enter_time + timedelta(seconds=delta_time*(edge_data['length']/dist_inner))
116 |         path.append(PathEntity(edge_enter_time, edge_leave_time, edge_data['eid']))
117 |         edge_enter_time = edge_leave_time
118 |     return path
119 | 
120 | 
121 | def get_pid(oid, path):
122 |     return oid + '_' + path[0].enter_time.strftime('%Y%m%d%H%M') + '_' + \
123 |            path[-1].leave_time.strftime('%Y%m%d%H%M')
124 | 
--------------------------------------------------------------------------------
/map_matching/utils.py:
--------------------------------------------------------------------------------
1 | from common.spatial_func import SPoint, distance
2 | import networkx as nx
3 | import math
4 | 
5 | 
6 | def find_shortest_path(rn, prev_candi_pt, cur_candi_pt, weight='length'):
7 |     if nx.is_directed(rn):
8 |         return find_shortest_path_directed(rn, prev_candi_pt, cur_candi_pt, weight)
9 | else: 10 | return find_shortest_path_undirected(rn, prev_candi_pt, cur_candi_pt, weight) 11 | 12 | 13 | def find_shortest_path_directed(rn, prev_candi_pt, cur_candi_pt, weight): 14 | # case 1, on the same road 15 | if prev_candi_pt.eid == cur_candi_pt.eid: 16 | if prev_candi_pt.offset < cur_candi_pt.offset: 17 | return (cur_candi_pt.offset - prev_candi_pt.offset), [] 18 | else: 19 | return float('inf'), None 20 | # case 2, on different roads (including opposite roads) 21 | else: 22 | pre_u, pre_v = rn.edge_idx[prev_candi_pt.eid] 23 | cur_u, cur_v = rn.edge_idx[cur_candi_pt.eid] 24 | try: 25 | path = get_cheapest_path_with_weight(rn, pre_v, cur_u, rn[pre_u][pre_v]['length'] - prev_candi_pt.offset, 26 | cur_candi_pt.offset, heuristic, weight) 27 | return path 28 | except nx.NetworkXNoPath: 29 | return float('inf'), None 30 | 31 | 32 | def find_shortest_path_undirected(rn, prev_candi_pt, cur_candi_pt, weight): 33 | # case 1, on the same road 34 | if prev_candi_pt.eid == cur_candi_pt.eid: 35 | return math.fabs(cur_candi_pt.offset - prev_candi_pt.offset), [] 36 | # case 2, on different roads (including opposite roads) 37 | else: 38 | pre_u, pre_v = rn.edge_idx[prev_candi_pt.eid] 39 | cur_u, cur_v = rn.edge_idx[cur_candi_pt.eid] 40 | min_dist = float('inf') 41 | shortest_path = None 42 | paths = [] 43 | # prev_u -> cur_u 44 | try: 45 | paths.append(get_cheapest_path_with_weight(rn, pre_u, cur_u, prev_candi_pt.offset, 46 | cur_candi_pt.offset, heuristic, weight)) 47 | except nx.NetworkXNoPath: 48 | pass 49 | # prev_u -> cur_v 50 | try: 51 | paths.append(get_cheapest_path_with_weight(rn, pre_u, cur_v, prev_candi_pt.offset, 52 | rn[cur_u][cur_v]['length'] - cur_candi_pt.offset, 53 | heuristic, weight)) 54 | except nx.NetworkXNoPath: 55 | pass 56 | # pre_v -> cur_u 57 | try: 58 | paths.append(get_cheapest_path_with_weight(rn, pre_v, cur_u, 59 | rn[pre_u][pre_v]['length'] - prev_candi_pt.offset, 60 | cur_candi_pt.offset, heuristic, weight)) 61 | except nx.NetworkXNoPath: 62 | pass 63 | # prev_v -> cur_v: 64 | try: 65 | paths.append(get_cheapest_path_with_weight(rn, pre_v, cur_v, 66 | rn[pre_u][pre_v]['length'] - prev_candi_pt.offset, 67 | rn[cur_u][cur_v]['length'] - cur_candi_pt.offset, 68 | heuristic, weight)) 69 | except nx.NetworkXNoPath: 70 | pass 71 | if len(paths) > 0: 72 | min_dist, shortest_path = min(paths, key=lambda t: t[0]) 73 | return min_dist, shortest_path 74 | 75 | 76 | def heuristic(node1, node2): 77 | return distance(SPoint(node1[1], node1[0]), SPoint(node2[1], node2[0])) 78 | 79 | 80 | def get_cheapest_path_with_weight(rn, src, dest, dist_to_src, dist_to_dest, heuristic, weight): 81 | tot_weight = 0.0 82 | path = nx.astar_path(rn, src, dest, heuristic, weight=weight) 83 | tot_weight += dist_to_src 84 | for i in range(len(path) - 1): 85 | start = path[i] 86 | end = path[i + 1] 87 | tot_weight += rn[start][end][weight] 88 | tot_weight += dist_to_dest 89 | return tot_weight, path 90 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/10/20 18:07 4 | -------------------------------------------------------------------------------- /models/datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | from tqdm import tqdm 3 | import os 4 | from chinese_calendar import is_holiday 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from 
common.spatial_func import distance 10 | from common.trajectory import get_tid, Trajectory 11 | from utils.parse_traj import ParseMMTraj 12 | from utils.save_traj import SaveTraj2MM 13 | from utils.utils import create_dir 14 | from .model_utils import load_rid_freqs 15 | 16 | 17 | def split_data(traj_input_dir, output_dir): 18 | """ 19 | split original data to train, valid and test datasets 20 | """ 21 | create_dir(output_dir) 22 | train_data_dir = output_dir + 'train_data/' 23 | create_dir(train_data_dir) 24 | val_data_dir = output_dir + 'valid_data/' 25 | create_dir(val_data_dir) 26 | test_data_dir = output_dir + 'test_data/' 27 | create_dir(test_data_dir) 28 | 29 | trg_parser = ParseMMTraj() 30 | trg_saver = SaveTraj2MM() 31 | 32 | for file_name in tqdm(os.listdir(traj_input_dir)): 33 | traj_input_path = os.path.join(traj_input_dir, file_name) 34 | trg_trajs = np.array(trg_parser.parse(traj_input_path)) 35 | ttl_lens = len(trg_trajs) 36 | test_inds = random.sample(range(ttl_lens), int(ttl_lens * 0.1)) # 10% as test data 37 | tmp_inds = [ind for ind in range(ttl_lens) if ind not in test_inds] 38 | val_inds = random.sample(tmp_inds, int(ttl_lens * 0.2)) # 20% as validation data 39 | train_inds = [ind for ind in tmp_inds if ind not in val_inds] # 70% as training data 40 | 41 | trg_saver.store(trg_trajs[train_inds], os.path.join(train_data_dir, 'train_' + file_name)) 42 | print("target traj train len: ", len(trg_trajs[train_inds])) 43 | trg_saver.store(trg_trajs[val_inds], os.path.join(val_data_dir, 'val_' + file_name)) 44 | print("target traj val len: ", len(trg_trajs[val_inds])) 45 | trg_saver.store(trg_trajs[test_inds], os.path.join(test_data_dir, 'test_' + file_name)) 46 | print("target traj test len: ", len(trg_trajs[test_inds])) 47 | 48 | 49 | class Dataset(torch.utils.data.Dataset): 50 | """ 51 | customize a dataset for PyTorch 52 | """ 53 | 54 | def __init__(self, trajs_dir, mbr, norm_grid_poi_dict, norm_grid_rnfea_dict, weather_dict, parameters, debug=True): 55 | self.mbr = mbr # MBR of all trajectories 56 | self.grid_size = parameters.grid_size 57 | self.time_span = parameters.time_span # time interval between two consecutive points. 58 | self.online_features_flag = parameters.online_features_flag 59 | self.src_grid_seqs, self.src_gps_seqs, self.src_pro_feas = [], [], [] 60 | self.trg_gps_seqs, self.trg_rids, self.trg_rates = [], [], [] 61 | self.new_tids = [] 62 | # above should be [num_seq, len_seq(unpadded)] 63 | self.get_data(trajs_dir, norm_grid_poi_dict, norm_grid_rnfea_dict, weather_dict, parameters.win_size, 64 | parameters.ds_type, parameters.keep_ratio, debug) 65 | def __len__(self): 66 | """Denotes the total number of samples""" 67 | return len(self.src_grid_seqs) 68 | 69 | def __getitem__(self, index): 70 | """Generate one sample of data""" 71 | src_grid_seq = self.src_grid_seqs[index] 72 | src_gps_seq = self.src_gps_seqs[index] 73 | trg_gps_seq = self.trg_gps_seqs[index] 74 | trg_rid = self.trg_rids[index] 75 | trg_rate = self.trg_rates[index] 76 | 77 | src_grid_seq = self.add_token(src_grid_seq) 78 | src_gps_seq = self.add_token(src_gps_seq) 79 | trg_gps_seq = self.add_token(trg_gps_seq) 80 | trg_rid = self.add_token(trg_rid) 81 | trg_rate = self.add_token(trg_rate) 82 | src_pro_fea = torch.tensor(self.src_pro_feas[index]) 83 | 84 | return src_grid_seq, src_gps_seq, src_pro_fea, trg_gps_seq, trg_rid, trg_rate 85 | 86 | def add_token(self, sequence): 87 | """ 88 | Append start element(sos in NLP) for each sequence. And convert each list to tensor. 
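        Illustrative example: add_token([[1, 2], [3, 4]]) returns tensor([[0, 0], [1, 2], [3, 4]]).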
89 | """ 90 | new_sequence = [] 91 | dimension = len(sequence[0]) 92 | start = [0] * dimension # pad 0 as start of rate sequence 93 | new_sequence.append(start) 94 | new_sequence.extend(sequence) 95 | new_sequence = torch.tensor(new_sequence) 96 | return new_sequence 97 | 98 | def get_data(self, trajs_dir, norm_grid_poi_dict, norm_grid_rnfea_dict, 99 | weather_dict, win_size, ds_type, keep_ratio, debug): 100 | parser = ParseMMTraj() 101 | 102 | if debug: 103 | trg_paths = os.listdir(trajs_dir)[:3] 104 | num = -1 105 | else: 106 | trg_paths = os.listdir(trajs_dir) 107 | num = -1 108 | 109 | for file_name in tqdm(trg_paths): 110 | trajs = parser.parse(os.path.join(trajs_dir, file_name)) 111 | for traj in trajs[:num]: 112 | new_tid_ls, mm_gps_seq_ls, mm_eids_ls, mm_rates_ls, \ 113 | ls_grid_seq_ls, ls_gps_seq_ls, features_ls = self.parse_traj(traj, norm_grid_poi_dict, norm_grid_rnfea_dict, 114 | weather_dict, win_size, ds_type, keep_ratio) 115 | if new_tid_ls is not None: 116 | self.new_tids.extend(new_tid_ls) 117 | self.trg_gps_seqs.extend(mm_gps_seq_ls) 118 | self.trg_rids.extend(mm_eids_ls) 119 | self.trg_rates.extend(mm_rates_ls) 120 | self.src_grid_seqs.extend(ls_grid_seq_ls) 121 | self.src_gps_seqs.extend(ls_gps_seq_ls) 122 | self.src_pro_feas.extend(features_ls) 123 | assert len(new_tid_ls) == len(mm_gps_seq_ls) == len(mm_eids_ls) == len(mm_rates_ls) 124 | 125 | assert len(self.new_tids) == len(self.trg_gps_seqs) == len(self.trg_rids) == len(self.trg_rates) == \ 126 | len(self.src_gps_seqs) == len(self.src_grid_seqs) == len(self.src_pro_feas), \ 127 | 'The number of source and target sequence must be equal.' 128 | 129 | 130 | def parse_traj(self, traj, norm_grid_poi_dict, norm_grid_rnfea_dict, weather_dict, win_size, ds_type, keep_ratio): 131 | """ 132 | Split traj based on length. 133 | Preprocess ground truth (map-matched) Trajectory(), get gps sequence, rid list and rate list. 134 | Down sample original Trajectory(), get ls_gps, ls_grid sequence and profile features 135 | Args: 136 | ----- 137 | traj: 138 | Trajectory() 139 | win_size: 140 | window size of length for a single high sampling trajectory 141 | ds_type: 142 | ['uniform', 'random'] 143 | uniform: sample GPS point every down_steps element. 144 | the down_step is calculated by 1/remove_ratio 145 | random: randomly sample (1-down_ratio)*len(old_traj) points by ascending. 146 | keep_ratio: 147 | float. range in (0,1). The ratio that keep GPS points to total points. 
148 | Returns: 149 | -------- 150 | new_tid_ls, mm_gps_seq_ls, mm_eids_ls, mm_rates_ls, ls_grid_seq_ls, ls_gps_seq_ls, features_ls 151 | """ 152 | new_trajs = self.get_win_trajs(traj, win_size) 153 | 154 | new_tid_ls = [] 155 | mm_gps_seq_ls, mm_eids_ls, mm_rates_ls = [], [], [] 156 | ls_grid_seq_ls, ls_gps_seq_ls, features_ls = [], [], [] 157 | 158 | for tr in new_trajs: 159 | tmp_pt_list = tr.pt_list 160 | new_tid_ls.append(tr.tid) 161 | 162 | # get target sequence 163 | mm_gps_seq, mm_eids, mm_rates = self.get_trg_seq(tmp_pt_list) 164 | if mm_eids is None: 165 | return None, None, None, None, None, None, None 166 | 167 | # get source sequence 168 | ds_pt_list = self.downsample_traj(tmp_pt_list, ds_type, keep_ratio) 169 | ls_grid_seq, ls_gps_seq, hours, ttl_t = self.get_src_seq(ds_pt_list, norm_grid_poi_dict, norm_grid_rnfea_dict) 170 | features = self.get_pro_features(ds_pt_list, hours, weather_dict) 171 | 172 | # check if src and trg len equal, if not return none 173 | if len(mm_gps_seq) != ttl_t: 174 | return None, None, None, None, None, None, None 175 | 176 | mm_gps_seq_ls.append(mm_gps_seq) 177 | mm_eids_ls.append(mm_eids) 178 | mm_rates_ls.append(mm_rates) 179 | ls_grid_seq_ls.append(ls_grid_seq) 180 | ls_gps_seq_ls.append(ls_gps_seq) 181 | features_ls.append(features) 182 | 183 | return new_tid_ls, mm_gps_seq_ls, mm_eids_ls, mm_rates_ls, ls_grid_seq_ls, ls_gps_seq_ls, features_ls 184 | 185 | 186 | def get_win_trajs(self, traj, win_size): 187 | pt_list = traj.pt_list 188 | len_pt_list = len(pt_list) 189 | if len_pt_list < win_size: 190 | return [traj] 191 | 192 | num_win = len_pt_list // win_size 193 | last_traj_len = len_pt_list % win_size + 1 194 | new_trajs = [] 195 | for w in range(num_win): 196 | # if last window is large enough then split to a single trajectory 197 | if w == num_win and last_traj_len > 15: 198 | tmp_pt_list = pt_list[win_size * w - 1:] 199 | # elif last window is not large enough then merge to the last trajectory 200 | elif w == num_win - 1 and last_traj_len <= 15: 201 | # fix bug, when num_win = 1 202 | ind = 0 203 | if win_size * w - 1 > 0: 204 | ind = win_size * w - 1 205 | tmp_pt_list = pt_list[ind:] 206 | # else split trajectories based on the window size 207 | else: 208 | tmp_pt_list = pt_list[max(0, (win_size * w - 1)):win_size * (w + 1)] 209 | # -1 to make sure the overlap between two trajs 210 | 211 | new_traj = Trajectory(traj.oid, get_tid(traj.oid, tmp_pt_list), tmp_pt_list) 212 | new_trajs.append(new_traj) 213 | return new_trajs 214 | 215 | def get_trg_seq(self, tmp_pt_list): 216 | mm_gps_seq = [] 217 | mm_eids = [] 218 | mm_rates = [] 219 | for pt in tmp_pt_list: 220 | candi_pt = pt.data['candi_pt'] 221 | if candi_pt is None: 222 | return None, None, None 223 | else: 224 | mm_gps_seq.append([candi_pt.lat, candi_pt.lng]) 225 | mm_eids.append([candi_pt.eid]) # keep the same format as seq 226 | mm_rates.append([candi_pt.rate]) 227 | return mm_gps_seq, mm_eids, mm_rates 228 | 229 | 230 | def get_src_seq(self, ds_pt_list, norm_grid_poi_dict, norm_grid_rnfea_dict): 231 | hours = [] 232 | ls_grid_seq = [] 233 | ls_gps_seq = [] 234 | first_pt = ds_pt_list[0] 235 | last_pt = ds_pt_list[-1] 236 | time_interval = self.time_span 237 | ttl_t = self.get_noramlized_t(first_pt, last_pt, time_interval) 238 | for ds_pt in ds_pt_list: 239 | hours.append(ds_pt.time.hour) 240 | t = self.get_noramlized_t(first_pt, ds_pt, time_interval) 241 | ls_gps_seq.append([ds_pt.lat, ds_pt.lng]) 242 | locgrid_xid, locgrid_yid = self.gps2grid(ds_pt, self.mbr, self.grid_size) 243 | 
if self.online_features_flag: 244 | poi_features = norm_grid_poi_dict[(locgrid_xid, locgrid_yid)] 245 | rn_features = norm_grid_rnfea_dict[(locgrid_xid, locgrid_yid)] 246 | ls_grid_seq.append([locgrid_xid, locgrid_yid, t]+poi_features+rn_features) 247 | else: 248 | ls_grid_seq.append([locgrid_xid, locgrid_yid, t]) 249 | 250 | return ls_grid_seq, ls_gps_seq, hours, ttl_t 251 | 252 | 253 | def get_pro_features(self, ds_pt_list, hours, weather_dict): 254 | holiday = is_holiday(ds_pt_list[0].time)*1 255 | day = ds_pt_list[0].time.day 256 | hour = {'hour': np.bincount(hours).max()} # find most frequent hours as hour of the trajectory 257 | weather = {'weather': weather_dict[(day, hour['hour'])]} 258 | features = self.one_hot(hour) + self.one_hot(weather) + [holiday] 259 | return features 260 | 261 | 262 | def gps2grid(self, pt, mbr, grid_size): 263 | """ 264 | mbr: 265 | MBR class. 266 | grid size: 267 | int. in meter 268 | """ 269 | LAT_PER_METER = 8.993203677616966e-06 270 | LNG_PER_METER = 1.1700193970443768e-05 271 | lat_unit = LAT_PER_METER * grid_size 272 | lng_unit = LNG_PER_METER * grid_size 273 | 274 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1 275 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1 276 | 277 | lat = pt.lat 278 | lng = pt.lng 279 | locgrid_x = int((lat - mbr.min_lat) / lat_unit) + 1 280 | locgrid_y = int((lng - mbr.min_lng) / lng_unit) + 1 281 | 282 | return locgrid_x, locgrid_y 283 | 284 | 285 | def get_noramlized_t(self, first_pt, current_pt, time_interval): 286 | """ 287 | calculate normalized t from first and current pt 288 | return time index (normalized time) 289 | """ 290 | t = int(1+((current_pt.time - first_pt.time).seconds/time_interval)) 291 | return t 292 | 293 | @staticmethod 294 | def get_distance(pt_list): 295 | dist = 0.0 296 | pre_pt = pt_list[0] 297 | for pt in pt_list[1:]: 298 | tmp_dist = distance(pre_pt, pt) 299 | dist += tmp_dist 300 | pre_pt = pt 301 | return dist 302 | 303 | 304 | @staticmethod 305 | def downsample_traj(pt_list, ds_type, keep_ratio): 306 | """ 307 | Down sample trajectory 308 | Args: 309 | ----- 310 | pt_list: 311 | list of Point() 312 | ds_type: 313 | ['uniform', 'random'] 314 | uniform: sample GPS point every down_stepth element. 315 | the down_step is calculated by 1/remove_ratio 316 | random: randomly sample (1-down_ratio)*len(old_traj) points by ascending. 317 | keep_ratio: 318 | float. range in (0,1). The ratio that keep GPS points to total points. 
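            e.g. ds_type='uniform' with keep_ratio=0.5 keeps every 2nd point, and the end point is always retained.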
319 | Returns: 320 | ------- 321 | traj: 322 | new Trajectory() 323 | """ 324 | assert ds_type in ['uniform', 'random'], 'only `uniform` or `random` is supported' 325 | 326 | old_pt_list = pt_list.copy() 327 | start_pt = old_pt_list[0] 328 | end_pt = old_pt_list[-1] 329 | 330 | if ds_type == 'uniform': 331 | if (len(old_pt_list) - 1) % int(1 / keep_ratio) == 0: 332 | new_pt_list = old_pt_list[::int(1 / keep_ratio)] 333 | else: 334 | new_pt_list = old_pt_list[::int(1 / keep_ratio)] + [end_pt] 335 | elif ds_type == 'random': 336 | sampled_inds = sorted( 337 | random.sample(range(1, len(old_pt_list) - 1), int((len(old_pt_list) - 2) * keep_ratio))) 338 | new_pt_list = [start_pt] + list(np.array(old_pt_list)[sampled_inds]) + [end_pt] 339 | 340 | return new_pt_list 341 | 342 | 343 | 344 | 345 | @staticmethod 346 | def one_hot(data): 347 | one_hot_dict = {'hour': 24, 'weekday': 7, 'weather':5} 348 | for k, v in data.items(): 349 | encoded_data = [0] * one_hot_dict[k] 350 | encoded_data[v - 1] = 1 351 | return encoded_data 352 | 353 | 354 | # Use for DataLoader 355 | def collate_fn(data): 356 | """ 357 | Reference: https://github.com/yunjey/seq2seq-dataloader/blob/master/data_loader.py 358 | Creates mini-batch tensors from the list of tuples (src_seq, src_pro_fea, trg_seq, trg_rid, trg_rate). 359 | We should build a custom collate_fn rather than using default collate_fn, 360 | because merging sequences (including padding) is not supported in default. 361 | Sequences are padded to the maximum length of mini-batch sequences (dynamic padding). 362 | Args: 363 | ----- 364 | data: list of tuple (src_seq, src_pro_fea, trg_seq, trg_rid, trg_rate), from dataset.__getitem__(). 365 | - src_seq: torch tensor of shape (?,2); variable length. 366 | - src_pro_fea: torch tensor of shape (1,64) # concatenate all profile features 367 | - trg_seq: torch tensor of shape (??,2); variable length. 368 | - trg_rid: torch tensor of shape (??); variable length. 369 | - trg_rate: torch tensor of shape (??); variable length. 370 | Returns: 371 | -------- 372 | src_grid_seqs: 373 | torch tensor of shape (batch_size, padded_length, 3) 374 | src_gps_seqs: 375 | torch tensor of shape (batch_size, padded_length, 3). 376 | src_pro_feas: 377 | torch tensor of shape (batch_size, feature_dim) unnecessary to pad 378 | src_lengths: 379 | list of length (batch_size); valid length for each padded source sequence. 380 | trg_seqs: 381 | torch tensor of shape (batch_size, padded_length, 2). 382 | trg_rids: 383 | torch tensor of shape (batch_size, padded_length, 1). 384 | trg_rates: 385 | torch tensor of shape (batch_size, padded_length, 1). 386 | trg_lengths: 387 | list of length (batch_size); valid length for each padded target sequence. 
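    Usage (sketch):
        loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)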
388 | """ 389 | 390 | def merge(sequences): 391 | lengths = [len(seq) for seq in sequences] 392 | dim = sequences[0].size(1) # get dim for each sequence 393 | padded_seqs = torch.zeros(len(sequences), max(lengths), dim) 394 | for i, seq in enumerate(sequences): 395 | end = lengths[i] 396 | padded_seqs[i, :end] = seq[:end] 397 | return padded_seqs, lengths 398 | 399 | # sort a list by source sequence length (descending order) to use pack_padded_sequence 400 | data.sort(key=lambda x: len(x[0]), reverse=True) 401 | 402 | # seperate source and target sequences 403 | src_grid_seqs, src_gps_seqs, src_pro_feas, trg_gps_seqs, trg_rids, trg_rates = zip(*data) # unzip data 404 | 405 | # merge sequences (from tuple of 1D tensor to 2D tensor) 406 | src_grid_seqs, src_lengths = merge(src_grid_seqs) 407 | src_gps_seqs, src_lengths = merge(src_gps_seqs) 408 | src_pro_feas = torch.tensor([list(src_pro_fea) for src_pro_fea in src_pro_feas]) 409 | trg_gps_seqs, trg_lengths = merge(trg_gps_seqs) 410 | trg_rids, _ = merge(trg_rids) 411 | trg_rates, _ = merge(trg_rates) 412 | 413 | return src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths 414 | 415 | -------------------------------------------------------------------------------- /models/loss_fn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/11/25 18:06 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from map_matching.candidate_point import CandidatePoint 12 | from map_matching.utils import find_shortest_path 13 | from common.spatial_func import SPoint, distance 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | def check_dis_loss(predict, target, trg_len): 18 | """ 19 | Calculate MAE and RMSE between predicted and targeted GPS sequence. 20 | Args: 21 | ----- 22 | predict = [seq len, batch size, 2] 23 | target = [seq len, batch size, 2] 24 | trg_len = [batch size] if not considering target length, the loss will smaller than the real one. 25 | predict and target have been removed sos 26 | Returns: 27 | ------- 28 | MAE of a batch in meter. 29 | RMSE of a batch in meter. 30 | """ 31 | predict = predict.permute(1, 0, 2) # [batch size, seq len, 2] 32 | target = target.permute(1, 0, 2) # [batch size, seq len, 2] 33 | bs = predict.size(0) 34 | 35 | ls_dis = [] 36 | for bs_i in range(bs): 37 | for len_i in range(trg_len[bs_i]-1): 38 | pre = SPoint(predict[bs_i, len_i][0], predict[bs_i, len_i][1]) 39 | trg = SPoint(target[bs_i, len_i][0], target[bs_i, len_i][1]) 40 | dis = distance(pre, trg) 41 | ls_dis.append(dis) 42 | 43 | ls_dis = np.array(ls_dis) 44 | mae = ls_dis.mean() 45 | rmse = np.sqrt((ls_dis**2).mean()) 46 | return mae, rmse 47 | 48 | 49 | def check_rn_dis_loss(predict_gps, predict_id, predict_rate, target_gps, target_id, target_rate, trg_len, 50 | rn, raw_rn_dict, new2raw_rid_dict): 51 | """ 52 | Calculate road network based MAE and RMSE between predicted and targeted GPS sequence. 53 | Args: 54 | ----- 55 | predict_gps = [seq len, batch size, 2] 56 | predict_id = [seq len, batch size, id one hot output dim] 57 | predict_rates = [seq len, batch size] 58 | target_gps = [seq len, batch size, 2] 59 | target_id = [seq len, batch size] 60 | target_rates = [seq len, batch size] 61 | trg_len = [batch size] if not considering target length, the loss will smaller than the real one. 
62 | 63 | predict and target have been removed sos 64 | Returns: 65 | ------- 66 | MAE of a batch in meter. 67 | RMSE of a batch in meter. 68 | """ 69 | seq_len = target_id.size(0) 70 | batch_size = target_id.size(1) 71 | predict_gps = predict_gps.permute(1, 0, 2) 72 | predict_id = predict_id.permute(1, 0, 2) 73 | predict_rate = predict_rate.permute(1, 0) 74 | target_gps = target_gps.permute(1, 0, 2) 75 | target_id = target_id.permute(1, 0) 76 | target_rate = target_rate.permute(1, 0) 77 | 78 | ls_dis, rn_ls_dis = [], [] 79 | for bs in range(batch_size): 80 | for len_i in range(trg_len[bs]-1): # don't calculate padding points 81 | pre_rid = predict_id[bs, len_i].argmax() 82 | convert_pre_rid = new2raw_rid_dict[pre_rid.tolist()] 83 | pre_rate = predict_rate[bs, len_i] 84 | pre_offset = raw_rn_dict[convert_pre_rid]['length'] * pre_rate 85 | pre_candi_pt = CandidatePoint(predict_gps[bs,len_i][0], predict_gps[bs,len_i][1], 86 | convert_pre_rid, 0, pre_offset, pre_rate) 87 | 88 | trg_rid = target_id[bs, len_i] 89 | convert_trg_rid = new2raw_rid_dict[trg_rid.tolist()] 90 | trg_rate = target_rate[bs, len_i] 91 | trg_offset = raw_rn_dict[convert_trg_rid]['length'] * trg_rate 92 | trg_candi_pt = CandidatePoint(target_gps[bs,len_i][0], target_gps[bs,len_i][1], 93 | convert_trg_rid, 0, trg_offset, trg_rate) 94 | 95 | if pre_candi_pt.lat == trg_candi_pt.lat and pre_candi_pt.lng == trg_candi_pt.lng: 96 | rn_dis = 0 97 | dis = 0 98 | else: 99 | rn_dis, _ = min(find_shortest_path(rn, pre_candi_pt, trg_candi_pt), 100 | find_shortest_path(rn, trg_candi_pt, pre_candi_pt)) 101 | if type(rn_dis) is not float: 102 | rn_dis = rn_dis.tolist() 103 | dis = distance(pre_candi_pt, trg_candi_pt) 104 | 105 | if rn_dis == np.inf: 106 | rn_dis = 1000 107 | rn_ls_dis.append(rn_dis) 108 | ls_dis.append(dis) 109 | 110 | ls_dis = np.array(ls_dis) 111 | rn_ls_dis = np.array(rn_ls_dis) 112 | 113 | mae = ls_dis.mean() 114 | rmse = np.sqrt((ls_dis**2).mean()) 115 | rn_mae = rn_ls_dis.mean() 116 | rn_rmse = np.sqrt((rn_ls_dis**2).mean()) 117 | return mae, rmse, rn_mae, rn_rmse 118 | 119 | 120 | def shrink_seq(seq): 121 | """remove repeated ids""" 122 | s0 = seq[0] 123 | new_seq = [s0] 124 | for s in seq[1:]: 125 | if s == s0: 126 | continue 127 | else: 128 | new_seq.append(s) 129 | s0 = s 130 | 131 | return new_seq 132 | 133 | def memoize(fn): 134 | '''Return a memoized version of the input function. 135 | 136 | The returned function caches the results of previous calls. 137 | Useful if a function call is expensive, and the function 138 | is called repeatedly with the same arguments. 139 | ''' 140 | cache = dict() 141 | def wrapped(*v): 142 | key = tuple(v) # tuples are hashable, and can be used as dict keys 143 | if key not in cache: 144 | cache[key] = fn(*v) 145 | return cache[key] 146 | return wrapped 147 | 148 | def lcs(xs, ys): 149 | '''Return the longest subsequence common to xs and ys. 150 | 151 | Example 152 | >>> lcs("HUMAN", "CHIMPANZEE") 153 | ['H', 'M', 'A', 'N'] 154 | ''' 155 | @memoize 156 | def lcs_(i, j): 157 | if i and j: 158 | xe, ye = xs[i-1], ys[j-1] 159 | if xe == ye: 160 | return lcs_(i-1, j-1) + [xe] 161 | else: 162 | return max(lcs_(i, j-1), lcs_(i-1, j), key=len) 163 | else: 164 | return [] 165 | return lcs_(len(xs), len(ys)) 166 | 167 | 168 | def cal_id_acc(predict, target, trg_len): 169 | """ 170 | Calculate RID accuracy between predicted and targeted RID sequence. 171 | 1. no repeated rid for two consecutive road segments 172 | 2. 
longest common subsequence 173 | http://wordaligned.org/articles/longest-common-subsequence 174 | Args: 175 | ----- 176 | predict = [seq len, batch size, id one hot output dim] 177 | target = [seq len, batch size, 1] 178 | predict and target have been removed sos 179 | Returns: 180 | ------- 181 | mean matched RID accuracy. 182 | """ 183 | predict = predict.permute(1, 0, 2) # [batch size, seq len, id dim] 184 | target = target.permute(1, 0) # [batch size, seq len, 1] 185 | bs = predict.size(0) 186 | 187 | correct_id_num = 0 188 | ttl_trg_id_num = 0 189 | ttl_pre_id_num = 0 190 | ttl = 0 191 | cnt = 0 192 | for bs_i in range(bs): 193 | pre_ids = [] 194 | trg_ids = [] 195 | # -1 because predict and target are removed sos. 196 | for len_i in range(trg_len[bs_i] - 1): 197 | pre_id = predict[bs_i][len_i].argmax() 198 | trg_id = target[bs_i][len_i] 199 | pre_ids.append(pre_id) 200 | trg_ids.append(trg_id) 201 | if pre_id == trg_id: 202 | cnt += 1 203 | ttl += 1 204 | 205 | # compute average rid accuracy 206 | shr_trg_ids = shrink_seq(trg_ids) 207 | shr_pre_ids = shrink_seq(pre_ids) 208 | correct_id_num += len(lcs(shr_trg_ids, shr_pre_ids)) 209 | ttl_trg_id_num += len(shr_trg_ids) 210 | ttl_pre_id_num += len(shr_pre_ids) 211 | 212 | rid_acc = cnt / ttl 213 | rid_recall = correct_id_num / ttl_trg_id_num 214 | rid_precision = correct_id_num / ttl_pre_id_num 215 | return rid_acc, rid_recall, rid_precision -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/10/20 18:24 4 | 5 | 6 | import torch 7 | import math 8 | import numpy as np 9 | import pandas as pd 10 | import networkx as nx 11 | 12 | from common.spatial_func import distance, cal_loc_along_line, SPoint 13 | from map_matching.candidate_point import get_candidates, CandidatePoint 14 | 15 | from utils.utils import load_json_data 16 | 17 | 18 | ##################################################################################################### 19 | # 20 | # Load Files 21 | # 22 | ##################################################################################################### 23 | 24 | def load_rid_freqs(dir, file_name): 25 | """ 26 | load rid freqs and convert key from str to int 27 | """ 28 | rid_freqs = load_json_data(dir, file_name) 29 | rid_freqs = {int(k): int(v) for k, v in rid_freqs.items()} # convert key from str to int 30 | 31 | return rid_freqs 32 | 33 | 34 | def load_rn_dict(dir, file_name): 35 | """ 36 | This function will be use in rate2gps. 
37 | """ 38 | rn_dict = load_json_data(dir, file_name) 39 | new_rn_dict = {} 40 | for k, v in rn_dict.items(): 41 | new_rn_dict[int(k)] = {} 42 | new_rn_dict[int(k)]['coords'] = [SPoint(coord[0], coord[1]) for coord in v['coords']] 43 | # convert str to SPoint() to calculate distance 44 | new_rn_dict[int(k)]['length'] = v['length'] 45 | new_rn_dict[int(k)]['level'] = v['level'] 46 | del rn_dict 47 | return new_rn_dict 48 | 49 | def load_online_features(dir, file_name): 50 | """ 51 | load POI or road network and covert key from str to tuple 52 | """ 53 | data = load_json_data(dir, file_name) 54 | data = {} 55 | 56 | return data 57 | 58 | ##################################################################################################### 59 | # 60 | # RID + Rate 2 GPS 61 | # 62 | ##################################################################################################### 63 | def rate2gps(rn_dict, eid, rate, parameters): 64 | """ 65 | Convert road rate to GPS on the road segment. 66 | Since one road contains several coordinates, iteratively computing length can be more accurate. 67 | Args: 68 | ----- 69 | rn_dict: 70 | dictionary of road network 71 | eid,rate: 72 | single value from model prediction 73 | Returns: 74 | -------- 75 | project_pt: 76 | projected GPS point on the road segment. 77 | """ 78 | eid = eid.tolist() # convert tensor to normal value 79 | rate = rate.tolist() 80 | if eid <= 0 or rate < 0 or eid > (parameters.id_size-1) or rate > 1: 81 | # force eid and rate in the right range 82 | return SPoint(0, 0) 83 | 84 | coords = rn_dict[eid]['coords'] 85 | offset = rn_dict[eid]['length'] * rate 86 | dist = 0 # temp distance for coords 87 | pre_dist = 0 # coords distance is smaller than offset 88 | 89 | if rate == 1.0: 90 | return coords[-1] 91 | if rate == 0.0: 92 | return coords[0] 93 | 94 | for i in range(len(coords) - 1): 95 | if i > 0: 96 | pre_dist += distance(coords[i - 1], coords[i]) 97 | dist += distance(coords[i], coords[i + 1]) 98 | if dist >= offset: 99 | coor_rate = (offset - pre_dist) / distance(coords[i], coords[i + 1]) 100 | project_pt = cal_loc_along_line(coords[i], coords[i + 1], coor_rate) 101 | break 102 | 103 | return project_pt 104 | 105 | 106 | def toseq(rn_dict, rids, rates, paramters): 107 | """ 108 | Convert batched rids and rates to gps sequence. 
109 | Args: 110 | ----- 111 | rn_dict: 112 | use for rate2gps() 113 | rids: 114 | [trg len, batch size, id one hot dim] 115 | rates: 116 | [trg len, batch size] 117 | Returns: 118 | -------- 119 | seqs: 120 | [trg len, batch size, 2] 121 | """ 122 | batch_size = rids.size(1) 123 | trg_len = rids.size(0) 124 | seqs = torch.zeros(trg_len, batch_size, 2).to(paramters.device) 125 | 126 | for i in range(1, trg_len): 127 | for bs in range(batch_size): 128 | rid = rids[i][bs].argmax() 129 | rate = rates[i][bs] 130 | pt = rate2gps(rn_dict, rid, rate, paramters) 131 | seqs[i][bs][0] = pt.lat 132 | seqs[i][bs][1] = pt.lng 133 | return seqs 134 | 135 | 136 | ##################################################################################################### 137 | # 138 | # Constraint mask 139 | # 140 | ##################################################################################################### 141 | def get_rid_grid(mbr, grid_size, rn_dict): 142 | """ 143 | Create a dict {key: grid id, value: rid} 144 | """ 145 | LAT_PER_METER = 8.993203677616966e-06 146 | LNG_PER_METER = 1.1700193970443768e-05 147 | lat_unit = LAT_PER_METER * grid_size 148 | lng_unit = LNG_PER_METER * grid_size 149 | 150 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1 151 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1 152 | 153 | grid_rn_dict = {} 154 | for k, v in rn_dict.items(): 155 | pre_lat = v['coords'][0].lat 156 | pre_lng = v['coords'][0].lng 157 | pre_locgrid_x = max(1, int((pre_lat - mbr.min_lat) / lat_unit) + 1) 158 | pre_locgrid_y = max(1, int((pre_lng - mbr.min_lng) / lng_unit) + 1) 159 | 160 | 161 | if (pre_locgrid_x, pre_locgrid_y) not in grid_rn_dict.keys(): 162 | grid_rn_dict[(pre_locgrid_x, pre_locgrid_y)] = [k] 163 | else: 164 | grid_rn_dict[(pre_locgrid_x, pre_locgrid_y)].append(k) 165 | 166 | for coord in v['coords'][1:]: 167 | lat = coord.lat 168 | lng = coord.lng 169 | locgrid_x = max(1, int((lat - mbr.min_lat) / lat_unit) + 1) 170 | locgrid_y = max(1, int((lng - mbr.min_lng) / lng_unit) + 1) 171 | 172 | if (locgrid_x, locgrid_y) not in grid_rn_dict.keys(): 173 | grid_rn_dict[(locgrid_x, locgrid_y)] = [k] 174 | else: 175 | grid_rn_dict[(locgrid_x, locgrid_y)].append(k) 176 | 177 | mid_x_num = abs(locgrid_x - pre_locgrid_x) 178 | mid_y_num = abs(locgrid_y - pre_locgrid_y) 179 | 180 | if mid_x_num > 1 and mid_y_num <= 1: 181 | for mid_x in range(1, mid_x_num): 182 | if (min(pre_locgrid_x,locgrid_x)+mid_x, locgrid_y) not in grid_rn_dict.keys(): 183 | grid_rn_dict[(min(pre_locgrid_x,locgrid_x)+mid_x, locgrid_y)] = [k] 184 | else: 185 | grid_rn_dict[(min(pre_locgrid_x,locgrid_x)+mid_x, locgrid_y)].append(k) 186 | 187 | elif mid_x_num <= 1 and mid_y_num > 1: 188 | for mid_y in range(1, mid_y_num): 189 | if (locgrid_x, min(pre_locgrid_y,locgrid_y)+mid_y) not in grid_rn_dict.keys(): 190 | grid_rn_dict[(locgrid_x, min(pre_locgrid_y,locgrid_y)+mid_y)] = [k] 191 | else: 192 | grid_rn_dict[(locgrid_x, min(pre_locgrid_y,locgrid_y)+mid_y)].append(k) 193 | 194 | elif mid_x_num > 1 and mid_y_num > 1: 195 | ttl_num = mid_x_num + mid_y_num + 1 196 | for mid in range(1, ttl_num): 197 | mid_xid = min(lat, pre_lat) + mid*abs(lat - pre_lat)/ttl_num 198 | mid_yid = min(lng, pre_lng) + mid*abs(lng - pre_lng)/ttl_num 199 | 200 | pre_lat = lat 201 | pre_lng = lng 202 | pre_locgrid_x = locgrid_x 203 | pre_locgrid_y = locgrid_y 204 | 205 | for k, v in grid_rn_dict.items(): 206 | grid_rn_dict[k] = list(set(v)) 207 | 208 | return grid_rn_dict, max_xid, max_yid 209 | 210 | def exp_prob(beta, x): 211 | """ 
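    Gaussian-like kernel: exp(-x**2 / beta**2).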
212 | error distance weight. 213 | """ 214 | return math.exp(-pow(x,2)/pow(beta,2)) 215 | 216 | def get_reachable_inds(pre_grid, cur_grid, grid_rn_dict,time_diff, parameters): 217 | reachable_inds = list(range(parameters.id_size)) 218 | 219 | return reachable_inds 220 | 221 | def get_dis_prob_vec(gps, rn, raw2new_rid_dict, parameters): 222 | """ 223 | Args: 224 | ----- 225 | gps: [SPoint, tid] 226 | """ 227 | cons_vec = torch.zeros(parameters.id_size) + 1e-10 228 | candis = get_candidates(gps[0], rn, parameters.search_dist) 229 | if candis is not None: 230 | for candi_pt in candis: 231 | if candi_pt.eid in raw2new_rid_dict.keys(): 232 | new_rid = raw2new_rid_dict[candi_pt.eid] 233 | prob = exp_prob(parameters.beta, candi_pt.error) 234 | cons_vec[new_rid] = prob 235 | else: 236 | cons_vec = torch.ones(parameters.id_size) 237 | return cons_vec 238 | 239 | def get_constraint_mask(src_grid_seqs, src_gps_seqs, src_lengths, trg_lengths, grid_rn_dict, rn, raw2new_rid_dict, parameters): 240 | max_trg_len = max(trg_lengths) 241 | batch_size = src_grid_seqs.size(0) 242 | 243 | constraint_mat = torch.zeros(batch_size, max_trg_len, parameters.id_size) + 1e-10 244 | pre_grids = torch.zeros(batch_size, max_trg_len, 3) 245 | cur_grids = torch.zeros(batch_size, max_trg_len, 3) 246 | 247 | for bs in range(batch_size): 248 | # first src gps 249 | pre_t = 1 250 | pre_grid = [int(src_grid_seqs[bs][pre_t][0].tolist()), 251 | int(src_grid_seqs[bs][pre_t][1].tolist()), 252 | pre_t] 253 | pre_gps = [SPoint(src_gps_seqs[bs][pre_t][0].tolist(), 254 | src_gps_seqs[bs][pre_t][1].tolist()), 255 | pre_t] 256 | pre_grids[bs, pre_t] = torch.tensor(pre_grid) 257 | cur_grids[bs, pre_t] = torch.tensor(pre_grid) 258 | 259 | if parameters.dis_prob_mask_flag: 260 | cons_vec = get_dis_prob_vec(pre_gps, rn, raw2new_rid_dict, parameters) 261 | constraint_mat[bs][pre_t] = cons_vec 262 | else: 263 | reachable_inds = get_reachable_inds(pre_grid, pre_grid, grid_rn_dict, 0, parameters) 264 | constraint_mat[bs][pre_t][reachable_inds] = 1 265 | 266 | # missed gps 267 | for i in range(2, src_lengths[bs]): 268 | cur_t = int(src_grid_seqs[bs,i,2].tolist()) 269 | cur_grid = [int(src_grid_seqs[bs][i][0].tolist()), 270 | int(src_grid_seqs[bs][i][1].tolist()), 271 | cur_t] 272 | cur_gps = [SPoint(src_gps_seqs[bs][i][0].tolist(), 273 | src_gps_seqs[bs][i][1].tolist()), 274 | cur_t] 275 | pre_grids[bs, cur_t] = torch.tensor(cur_grid) 276 | cur_grids[bs, cur_t] = torch.tensor(cur_grid) 277 | 278 | time_diff = cur_t - pre_t 279 | reachable_inds = get_reachable_inds(pre_grid, cur_grid, grid_rn_dict, time_diff, parameters) 280 | 281 | for t in range(pre_t+1, cur_t): 282 | constraint_mat[bs][t][reachable_inds] = 1 283 | pre_grids[bs, t] = torch.tensor(pre_grid) 284 | cur_grids[bs, t] = torch.tensor(cur_grid) 285 | 286 | # middle src gps 287 | if parameters.dis_prob_mask_flag: 288 | cons_vec = get_dis_prob_vec(cur_gps, rn, raw2new_rid_dict, parameters) 289 | constraint_mat[bs][cur_t] = cons_vec 290 | else: 291 | reachable_inds = get_reachable_inds(cur_grid, cur_grid, grid_rn_dict, 0, parameters) 292 | constraint_mat[bs][cur_t][reachable_inds] = 1 293 | 294 | pre_t = cur_t 295 | pre_grid = cur_grid 296 | pre_gps = cur_gps 297 | 298 | return constraint_mat, pre_grids, cur_grids 299 | 300 | ##################################################################################################### 301 | # 302 | # Use for extracting POI features 303 | # 304 | ##################################################################################################### 
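# ---------------------------------------------------------------------------------------------------
# Illustrative aside (a sketch added for clarity, not part of the original file): the constraint
# mask built above turns map-matching candidates into a probability vector over road ids,
# weighting each candidate road by exp_prob(beta, candidate.error):
#
#     >>> exp_prob(15.0, 0.0)              # a candidate exactly at the GPS point
#     1.0
#     >>> round(exp_prob(15.0, 15.0), 3)   # a candidate beta meters away
#     0.368
#
# Road ids outside the candidate set keep the 1e-10 floor, so the masked softmax in the decoder
# can still normalize without taking log(0).
# ---------------------------------------------------------------------------------------------------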
305 | def get_poi_info(grid_poi_df, parameters): 306 | """ 307 | ['company','food', 'gym', 'education','shopping','gov', 'viewpoint','entrance','house','life', 308 | 'traffic','car','hotel','beauty','hospital','media','finance','entertainment','road','nature','landmark','address'] 309 | """ 310 | types = parameters.poi_type.split(',') 311 | norm_grid_poi_df=(grid_poi_df[types]-grid_poi_df[types].min())/(grid_poi_df[types].max()-grid_poi_df[types].min()) 312 | norm_grid_poi_df = norm_grid_poi_df.fillna(0) 313 | 314 | norm_grid_poi_dict = {} 315 | for i in range(len(norm_grid_poi_df)): 316 | k = norm_grid_poi_df.index[i] 317 | v = norm_grid_poi_df.iloc[i].values 318 | norm_grid_poi_dict[k] = list(v) 319 | 320 | for xid in range(1, parameters.max_xid+1): 321 | for yid in range(1, parameters.max_yid+1): 322 | if (xid,yid) not in norm_grid_poi_dict.keys(): 323 | norm_grid_poi_dict[(xid,yid)] = [0.] * len(types) 324 | return norm_grid_poi_dict 325 | 326 | ##################################################################################################### 327 | # 328 | # Use for extracting RN features 329 | # 330 | ##################################################################################################### 331 | def get_edge_results(eids, rn_dict): 332 | edge_results = [] 333 | for eid in eids: 334 | u = rn_dict[eid]['coords'][0] 335 | v = rn_dict[eid]['coords'][-1] 336 | edge_results.append(((u.lng,u.lat),(v.lng,v.lat))) 337 | return edge_results 338 | 339 | def extract_single_rn_features(edge_results, rn): 340 | part_g = nx.Graph() 341 | for u, v in edge_results: 342 | part_g.add_edge(u, v, **rn[u][v]) 343 | 344 | tot_length = 0.0 345 | level_2_cnt = 0 346 | level_3_cnt = 0 347 | level_4_cnt = 0 348 | for u, v, data in part_g.edges(data=True): 349 | tot_length += data['length'] 350 | if data['highway'] == 'trunk': 351 | level_2_cnt += 1 352 | elif data['highway'] == 'primary': 353 | level_3_cnt += 1 354 | elif data['highway'] == 'secondary': 355 | level_4_cnt += 1 356 | nb_intersections = 0 357 | for node, degree in part_g.degree(): 358 | if degree > 2: 359 | nb_intersections += 1 360 | 361 | rn_features = np.array([tot_length, nb_intersections, level_2_cnt, level_3_cnt, level_4_cnt]) 362 | 363 | return rn_features 364 | 365 | def get_rn_info(rn, mbr, grid_size, grid_rn_dict, rn_dict): 366 | """ 367 | rn_dict contains rn information 368 | """ 369 | LAT_PER_METER = 8.993203677616966e-06 370 | LNG_PER_METER = 1.1700193970443768e-05 371 | lat_unit = LAT_PER_METER * grid_size 372 | lng_unit = LNG_PER_METER * grid_size 373 | 374 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1 375 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1 376 | 377 | grid_rnfea_dict = {} 378 | for k,v in grid_rn_dict.items(): 379 | eids = grid_rn_dict[k] 380 | edge_results = get_edge_results(eids, rn_dict) 381 | grid_rnfea_dict[k] = extract_single_rn_features(edge_results, rn) 382 | 383 | grid_rnfea_df = pd.DataFrame(grid_rnfea_dict).T 384 | norm_grid_rnfea_df=(grid_rnfea_df-grid_rnfea_df.min())/(grid_rnfea_df.max()-grid_rnfea_df.min()) # col norm 385 | 386 | norm_grid_rnfea_dict = {} 387 | for i in range(len(norm_grid_rnfea_df)): 388 | k = norm_grid_rnfea_df.index[i] 389 | v = norm_grid_rnfea_df.iloc[i].values 390 | norm_grid_rnfea_dict[k] = list(v) 391 | 392 | for xid in range(1, max_xid+1): 393 | for yid in range(1, max_yid+1): 394 | if (xid,yid) not in norm_grid_rnfea_dict.keys(): 395 | norm_grid_rnfea_dict[(xid,yid)] = [0.] 
* len(v) 396 | 397 | return norm_grid_rnfea_dict 398 | 399 | def get_rid_rnfea_dict(rn_dict, parameters): 400 | df = pd.DataFrame(rn_dict).T 401 | 402 | # standardization length 403 | df['norm_len'] = [np.log10(l) /np.log10(df['length'].max()) for l in df['length']] 404 | # df['norm_len'] = (df['length'] - df['length'].mean())/df['length'].std() 405 | 406 | # one hot road level 407 | one_hot_df = pd.get_dummies(df.level, prefix='level') 408 | df = df.join(one_hot_df) 409 | 410 | # get number of neighbours (standardization) 411 | g = nx.Graph() 412 | edges = [] 413 | for coords in df['coords'].values: 414 | start_node = (coords[0].lat, coords[0].lng) 415 | end_node = (coords[-1].lat, coords[-1].lng) 416 | edges.append((start_node, end_node)) 417 | g.add_edges_from(edges) 418 | 419 | num_start_neighbors = [] 420 | num_end_neighbors = [] 421 | for coords in df['coords'].values: 422 | start_node = (coords[0].lat, coords[0].lng) 423 | end_node = (coords[-1].lat, coords[-1].lng) 424 | num_start_neighbors.append(len(list(g.edges(start_node)))) 425 | num_end_neighbors.append(len(list(g.edges(end_node)))) 426 | df['num_start_neighbors'] = num_start_neighbors 427 | df['num_end_neighbors'] = num_end_neighbors 428 | start = df['num_start_neighbors'] 429 | end = df['num_end_neighbors'] 430 | # distribution is like gaussian --> use min max normalization 431 | df['norm_num_start_neighbors'] = (start - start.min())/(start.max() - start.min()) 432 | df['norm_num_end_neighbors'] = (end - end.min())/(end.max() - end.min()) 433 | 434 | # convert to dict 435 | norm_rid_rnfea_dict = {} 436 | for i in range(len(df)): 437 | k = df.index[i] 438 | v = df.iloc[i][['norm_len','level_2','level_3','level_4','level_5','level_6',\ 439 | 'norm_num_start_neighbors','norm_num_end_neighbors']] 440 | norm_rid_rnfea_dict[k] = list(v) 441 | 442 | norm_rid_rnfea_dict[0] = [0.]*len(list(v)) # add soss 443 | return norm_rid_rnfea_dict 444 | 445 | ##################################################################################################### 446 | # 447 | # Use for online features 448 | # 449 | ##################################################################################################### 450 | def get_rid_grid_dict(grid_rn_dict): 451 | rid_grid_dict = {} 452 | for k, v in grid_rn_dict.items(): 453 | for rid in v: 454 | if rid not in rid_grid_dict: 455 | rid_grid_dict[rid] = [k] 456 | else: 457 | rid_grid_dict[rid].append(k) 458 | 459 | for k,v in rid_grid_dict.items(): 460 | rid_grid_dict[k] = list(set(v)) 461 | 462 | return rid_grid_dict 463 | 464 | def get_online_info_dict(grid_rn_dict, norm_grid_poi_dict, norm_grid_rnfea_dict, parameters): 465 | rid_grid_dict = get_rid_grid_dict(grid_rn_dict) 466 | online_features_dict = {} 467 | for rid in rid_grid_dict.keys(): 468 | online_feas = [] 469 | for grid in rid_grid_dict[rid]: 470 | try: 471 | poi = norm_grid_poi_dict[grid] 472 | except: 473 | poi = [0.]*5 474 | try: 475 | rnfea = norm_grid_rnfea_dict[grid] 476 | except: 477 | rnfea = [0.]*5 478 | online_feas.append(poi + rnfea) 479 | 480 | online_feas = np.array(online_feas) 481 | online_features_dict[rid] = list(online_feas.mean(axis=0)) 482 | 483 | online_features_dict[0] = [0.]*online_feas.shape[1] # add soss 484 | 485 | return online_features_dict 486 | 487 | def get_dict_info_batch(input_id, features_dict): 488 | """ 489 | batched dict info 490 | """ 491 | # input_id = [1, batch size] 492 | features = [] 493 | for rid in input_id.squeeze(1): 494 | features.append(features_dict[rid.cpu().tolist()]) 495 | 496 | 
features = torch.tensor(features).float() 497 | # features = [1, batch size, features dim] 498 | return features 499 | 500 | ##################################################################################################### 501 | # 502 | # Use for visualization 503 | # 504 | ##################################################################################################### 505 | def get_plot_seq(raw_input, predict, target, src_len, trg_len): 506 | """ 507 | Get input, prediction and ground truth GPS sequence. 508 | raw_input, predict, target = [seq len, batch size, 2] and the sos is not removed. 509 | """ 510 | raw_input = raw_input[1:].permute(1, 0, 2) 511 | predict = predict[1:].permute(1, 0, 2) # [batch size, seq len, 2] 512 | target = target[1:].permute(1, 0, 2) # [batch size, seq len, 2] 513 | 514 | bs = predict.size(0) 515 | 516 | ls_pre_seq, ls_trg_seq, ls_input_seq =[], [], [] 517 | for bs_i in range(bs): 518 | pre_seq = [] 519 | trg_seq = [] 520 | for len_i in range(trg_len[bs_i]-1): 521 | pre_seq.append([predict[bs_i, len_i][0].cpu().data.tolist(), predict[bs_i, len_i][1].cpu().data.tolist()]) 522 | trg_seq.append([target[bs_i, len_i][0].cpu().data.tolist(), target[bs_i, len_i][1].cpu().data.tolist()]) 523 | input_seq = [] 524 | for len_i in range(src_len[bs_i]-1): 525 | input_seq.append([raw_input[bs_i, len_i][0].cpu().data.tolist(), raw_input[bs_i, len_i][1].cpu().data.tolist()]) 526 | ls_pre_seq.append(pre_seq) 527 | ls_trg_seq.append(trg_seq) 528 | ls_input_seq.append(input_seq) 529 | return ls_input_seq, ls_pre_seq, ls_trg_seq 530 | 531 | 532 | ##################################################################################################### 533 | # 534 | # POIs 535 | # 536 | ##################################################################################################### 537 | def filterPOI(df, mbr): 538 | labels = ['公司企业', '美食', '运动健身', '教育培训', '购物', '政府机构', '旅游景点', '出入口', '房地产', '生活服务', 539 | '交通设施', '汽车服务', '酒店', '丽人', '医疗', '文化传媒', '金融', '休闲娱乐', '道路','自然地物', '行政地标', '门址'] 540 | eng_labels = ['company','food', 'gym', 'education','shopping','gov', 'viewpoint','entrance','house','life', 541 | 'traffic','car','hotel','beauty','hospital','media','finance','entertainment','road','nature','landmark','address'] 542 | eng_labels_dict = {} 543 | for i in range(len(labels)): 544 | eng_labels_dict[labels[i]] = eng_labels[i] 545 | 546 | new_df = {'lat':[],'lng':[],'type':[]} 547 | for i in range(len(df)): 548 | gps = df.iloc[i]['经纬度wgs编码'].split(',') 549 | lat = float(gps[0]) 550 | lng = float(gps[1]) 551 | label = df.iloc[i]['一级行业分类'] 552 | if mbr.contains(lat,lng) and (label is not np.nan): 553 | new_df['lat'].append(lat) 554 | new_df['lng'].append(lng) 555 | new_df['type'].append(eng_labels_dict[label]) 556 | new_df = pd.DataFrame(new_df) 557 | return new_df 558 | 559 | 560 | def get_poi_grid(mbr, grid_size, df): 561 | labels = ['company','food','shopping','viewpoint','house'] 562 | new_df = filterPOI(df, mbr) 563 | LAT_PER_METER = 8.993203677616966e-06 564 | LNG_PER_METER = 1.1700193970443768e-05 565 | lat_unit = LAT_PER_METER * grid_size 566 | lng_unit = LNG_PER_METER * grid_size 567 | 568 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1 569 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1 570 | 571 | grid_poi_dict = {} 572 | for i in range(len(new_df)): 573 | lat = new_df.iloc[i]['lat'] 574 | lng = new_df.iloc[i]['lng'] 575 | label = new_df.iloc[i]['type'] 576 | # only consider partial labels 577 | if label in labels: 578 | locgrid_x = 
int((lat - mbr.min_lat) / lat_unit) + 1 579 | locgrid_y = int((lng - mbr.min_lng) / lng_unit) + 1 580 | 581 | if (locgrid_x, locgrid_y) not in grid_poi_dict.keys(): 582 | grid_poi_dict[(locgrid_x, locgrid_y)] = {label:1} 583 | else: 584 | if label not in grid_poi_dict[(locgrid_x, locgrid_y)].keys(): 585 | grid_poi_dict[(locgrid_x, locgrid_y)][label] = 1 586 | else: 587 | grid_poi_dict[(locgrid_x, locgrid_y)][label]+=1 588 | 589 | # key: grid, value: [0.1,0.5,0.5] normalized POI by column 590 | grid_poi_df = pd.DataFrame(grid_poi_dict).T.fillna(0) 591 | 592 | norm_grid_poi_df=(grid_poi_df-grid_poi_df.min())/(grid_poi_df.max()-grid_poi_df.min()) 593 | # norm_grid_poi_df = grid_poi_df.div(grid_poi_df.sum(axis=1), axis=0) # row normalization 594 | 595 | norm_grid_poi_dict = {} 596 | for i in range(len(norm_grid_poi_df)): 597 | k = norm_grid_poi_df.index[i] 598 | v = norm_grid_poi_df.iloc[i].values 599 | norm_grid_poi_dict[k] = v 600 | 601 | return norm_grid_poi_dict, grid_poi_df 602 | 603 | 604 | # extra_info_dir = "../data/map/extra_info/" 605 | # poi_df = pd.read_csv(extra_info_dir+'jnPoiInfo.txt',sep='\t') 606 | # norm_grid_poi_dict, grid_poi_df = get_poi_grid(mbr, args.grid_size, poi_df) 607 | 608 | # save_pkl_data(norm_grid_poi_dict, extra_info_dir, 'poi_col_norm.pkl') 609 | # grid_poi_df = pd.to_csv(extra_info_dir+'poi.csv') 610 | 611 | ##################################################################################################### 612 | # 613 | # others 614 | # 615 | ##################################################################################################### 616 | def epoch_time(start_time, end_time): 617 | elapsed_time = end_time - start_time 618 | elapsed_mins = int(elapsed_time / 60) 619 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60)) 620 | return elapsed_mins, elapsed_secs 621 | 622 | 623 | class AttrDict(dict): 624 | def __init__(self, *args, **kwargs): 625 | super(AttrDict, self).__init__(*args, **kwargs) 626 | self.__dict__ = self -------------------------------------------------------------------------------- /models/models_attn_tandem.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/11/5 10:27 4 | 5 | import random 6 | import operator 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from models.model_utils import get_dict_info_batch 12 | 13 | 14 | def mask_log_softmax(x, mask, log_flag=True): 15 | maxes = torch.max(x, 1, keepdim=True)[0] 16 | x_exp = torch.exp(x - maxes) * mask 17 | x_exp_sum = torch.sum(x_exp, 1, keepdim=True) 18 | if log_flag: 19 | output_custom = torch.log(x_exp / x_exp_sum) 20 | else: 21 | output_custom = x_exp / x_exp_sum 22 | return output_custom 23 | 24 | class Extra_MLP(nn.Module): 25 | def __init__(self, parameters): 26 | super().__init__() 27 | self.pro_input_dim = parameters.pro_input_dim 28 | self.pro_output_dim = parameters.pro_output_dim 29 | self.fc_out = nn.Linear(self.pro_input_dim, self.pro_output_dim) 30 | 31 | def forward(self, x): 32 | out = torch.tanh(self.fc_out(x)) 33 | return out 34 | 35 | 36 | class Encoder(nn.Module): 37 | def __init__(self, parameters): 38 | super().__init__() 39 | self.hid_dim = parameters.hid_dim 40 | self.pro_output_dim = parameters.pro_output_dim 41 | self.online_features_flag = parameters.online_features_flag 42 | self.pro_features_flag = parameters.pro_features_flag 43 | 44 | input_dim = 3 45 | if self.online_features_flag: 46 | input_dim = input_dim + 
parameters.online_dim 47 | 48 | self.rnn = nn.GRU(input_dim, self.hid_dim) 49 | self.dropout = nn.Dropout(parameters.dropout) 50 | 51 | if self.pro_features_flag: 52 | self.extra = Extra_MLP(parameters) 53 | self.fc_hid = nn.Linear(self.hid_dim + self.pro_output_dim, self.hid_dim) 54 | 55 | def forward(self, src, src_len, pro_features): 56 | # src = [src len, batch size, 3] 57 | # if only input trajectory, input dim = 2; elif input trajectory + behavior feature, input dim = 2 + n 58 | # src_len = [batch size] 59 | 60 | packed_embedded = nn.utils.rnn.pack_padded_sequence(src, src_len) 61 | packed_outputs, hidden = self.rnn(packed_embedded) 62 | 63 | # packed_outputs is a packed sequence containing all hidden states 64 | # hidden is now from the final non-padded element in the batch 65 | 66 | outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs) 67 | # outputs is now a non-packed sequence, all hidden states obtained 68 | # when the input is a pad token are all zeros 69 | 70 | # outputs = [src len, batch size, hid dim * num directions] 71 | # hidden = [n layers * num directions, batch size, hid dim] 72 | 73 | # initial decoder hidden is final hidden state of the forwards and backwards 74 | # encoder RNNs fed through a linear layer 75 | 76 | # hidden = [1, batch size, hidden_dim] 77 | # outputs = [src len, batch size, hidden_dim * num directions] 78 | 79 | if self.pro_features_flag: 80 | extra_emb = self.extra(pro_features) 81 | extra_emb = extra_emb.unsqueeze(0) 82 | # extra_emb = [1, batch size, extra output dim] 83 | hidden = torch.tanh(self.fc_hid(torch.cat((extra_emb, hidden), dim=2))) 84 | # hidden = [1, batch size, hid dim] 85 | 86 | return outputs, hidden 87 | 88 | 89 | class Attention(nn.Module): 90 | # TODO update to more advanced attention layer. 91 | def __init__(self, parameters): 92 | super().__init__() 93 | self.hid_dim = parameters.hid_dim 94 | 95 | self.attn = nn.Linear(self.hid_dim * 2, self.hid_dim) 96 | self.v = nn.Linear(self.hid_dim, 1, bias=False) 97 | 98 | def forward(self, hidden, encoder_outputs, attn_mask): 99 | # hidden = [1, bath size, hid dim] 100 | # encoder_outputs = [src len, batch size, hid dim * num directions] 101 | src_len = encoder_outputs.shape[0] 102 | # repeat decoder hidden sate src_len times 103 | hidden = hidden.repeat(src_len, 1, 1) 104 | hidden = hidden.permute(1, 0, 2) 105 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 106 | # hidden = [batch size, src len, hid dim] 107 | # encoder_outputs = [batch size, src len, hid dim * num directions] 108 | 109 | energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2))) 110 | # energy = [batch size, src len, hid dim] 111 | 112 | attention = self.v(energy).squeeze(2) 113 | # attention = [batch size, src len] 114 | attention = attention.masked_fill(attn_mask == 0, -1e10) 115 | # using mask to force the attention to only be over non-padding elements. 
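        # (-1e10 is effectively -inf, so masked positions receive ~zero weight after the softmax below)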
116 | 117 | return F.softmax(attention, dim=1) 118 | 119 | 120 | class DecoderMulti(nn.Module): 121 | def __init__(self, parameters): 122 | super().__init__() 123 | 124 | self.id_size = parameters.id_size 125 | self.id_emb_dim = parameters.id_emb_dim 126 | self.hid_dim = parameters.hid_dim 127 | self.pro_output_dim = parameters.pro_output_dim 128 | self.online_dim = parameters.online_dim 129 | self.rid_fea_dim = parameters.rid_fea_dim 130 | 131 | self.attn_flag = parameters.attn_flag 132 | self.dis_prob_mask_flag = parameters.dis_prob_mask_flag # final softmax 133 | self.online_features_flag = parameters.online_features_flag 134 | self.tandem_fea_flag = parameters.tandem_fea_flag 135 | 136 | self.emb_id = nn.Embedding(self.id_size, self.id_emb_dim) 137 | 138 | rnn_input_dim = self.id_emb_dim + 1 139 | fc_id_out_input_dim = self.hid_dim 140 | fc_rate_out_input_dim = self.hid_dim 141 | 142 | type_input_dim = self.id_emb_dim + self.hid_dim 143 | self.tandem_fc = nn.Sequential( 144 | nn.Linear(type_input_dim, self.hid_dim), 145 | nn.ReLU() 146 | ) 147 | 148 | if self.attn_flag: 149 | self.attn = Attention(parameters) 150 | rnn_input_dim = rnn_input_dim + self.hid_dim 151 | 152 | if self.online_features_flag: 153 | rnn_input_dim = rnn_input_dim + self.online_dim # 5 poi and 5 road network 154 | 155 | if self.tandem_fea_flag: 156 | fc_rate_out_input_dim = self.hid_dim + self.rid_fea_dim 157 | 158 | self.rnn = nn.GRU(rnn_input_dim, self.hid_dim) 159 | self.fc_id_out = nn.Linear(fc_id_out_input_dim, self.id_size) 160 | self.fc_rate_out = nn.Linear(fc_rate_out_input_dim, 1) 161 | self.dropout = nn.Dropout(parameters.dropout) 162 | 163 | 164 | def forward(self, input_id, input_rate, hidden, encoder_outputs, attn_mask, 165 | pre_grid, next_grid, constraint_vec, pro_features, online_features, rid_features): 166 | 167 | # input_id = [batch size, 1] rid long 168 | # input_rate = [batch size, 1] rate float. 
169 | # hidden = [1, batch size, hid dim] 170 | # encoder_outputs = [src len, batch size, hid dim * num directions] 171 | # attn_mask = [batch size, src len] 172 | # pre_grid = [batch size, 3] 173 | # next_grid = [batch size, 3] 174 | # constraint_vec = [batch size, id_size], [id_size] is the vector of reachable rid 175 | # pro_features = [batch size, profile features input dim] 176 | # online_features = [batch size, online features dim] 177 | # rid_features = [batch size, rid features dim] 178 | 179 | input_id = input_id.squeeze(1).unsqueeze(0) # cannot use squeeze() bug for batch size = 1 180 | # input_id = [1, batch size] 181 | input_rate = input_rate.unsqueeze(0) 182 | # input_rate = [1, batch size, 1] 183 | embedded = self.dropout(self.emb_id(input_id)) 184 | # embedded = [1, batch size, emb dim] 185 | 186 | if self.attn_flag: 187 | a = self.attn(hidden, encoder_outputs, attn_mask) 188 | # a = [batch size, src len] 189 | a = a.unsqueeze(1) 190 | # a = [batch size, 1, src len] 191 | encoder_outputs = encoder_outputs.permute(1, 0, 2) 192 | # encoder_outputs = [batch size, src len, hid dim * num directions] 193 | weighted = torch.bmm(a, encoder_outputs) 194 | # weighted = [batch size, 1, hid dim * num directions] 195 | weighted = weighted.permute(1, 0, 2) 196 | # weighted = [1, batch size, hid dim * num directions] 197 | 198 | if self.online_features_flag: 199 | rnn_input = torch.cat((weighted, embedded, input_rate, 200 | online_features.unsqueeze(0)), dim=2) 201 | else: 202 | rnn_input = torch.cat((weighted, embedded, input_rate), dim=2) 203 | else: 204 | if self.online_features_flag: 205 | rnn_input = torch.cat((embedded, input_rate, online_features.unsqueeze(0)), dim=2) 206 | else: 207 | rnn_input = torch.cat((embedded, input_rate), dim=2) 208 | 209 | output, hidden = self.rnn(rnn_input, hidden) 210 | 211 | # output = [seq len, batch size, hid dim * n directions] 212 | # hidden = [n layers * n directions, batch size, hid dim] 213 | # seq len and n directions will always be 1 in the decoder, therefore: 214 | # output = [1, batch size, dec hid dim] 215 | # hidden = [1, batch size, dec hid dim] 216 | assert (output == hidden).all() 217 | 218 | # pre_rid 219 | if self.dis_prob_mask_flag: 220 | prediction_id = mask_log_softmax(self.fc_id_out(output.squeeze(0)), 221 | constraint_vec, log_flag=True) 222 | else: 223 | prediction_id = F.log_softmax(self.fc_id_out(output.squeeze(0)), dim=1) 224 | # then the loss function should change to nll_loss() 225 | 226 | # pre_rate 227 | max_id = prediction_id.argmax(dim=1).long() 228 | id_emb = self.dropout(self.emb_id(max_id)) 229 | rate_input = torch.cat((id_emb, hidden.squeeze(0)),dim=1) 230 | rate_input = self.tandem_fc(rate_input) # [batch size, hid dim] 231 | if self.tandem_fea_flag: 232 | prediction_rate = torch.sigmoid(self.fc_rate_out(torch.cat((rate_input, rid_features), dim=1))) 233 | else: 234 | prediction_rate = torch.sigmoid(self.fc_rate_out(rate_input)) 235 | 236 | # prediction_id = [batch size, id_size] 237 | # prediction_rate = [batch size, 1] 238 | 239 | return prediction_id, prediction_rate, hidden 240 | 241 | class Seq2SeqMulti(nn.Module): 242 | def __init__(self, encoder, decoder, device): 243 | super().__init__() 244 | 245 | self.encoder = encoder # Encoder 246 | self.decoder = decoder # DecoderMulti 247 | self.device = device 248 | 249 | def forward(self, src, src_len, trg_id, trg_rate, trg_len, 250 | pre_grids, next_grids, constraint_mat, pro_features, 251 | online_features_dict, rid_features_dict, 252 | 
teacher_forcing_ratio=0.5):
253 |         """
254 |         src = [src len, batch size, 3], x,y,t
255 |         src_len = [batch size]
256 |         trg_id = [trg len, batch size, 1]
257 |         trg_rate = [trg len, batch size, 1]
258 |         trg_len = [batch size]
259 |         pre_grids = [trg len, batch size, 3]
260 |         next_grids = [trg len, batch size, 3]
261 |         constraint_mat = [trg len, batch size, id_size]
262 |         pro_features = [batch size, profile features input dim]
263 |         online_features_dict = {rid: online_features}  # rid --> grid --> online features
264 |         rid_features_dict = {rid: rn_features}
265 |         teacher_forcing_ratio is the probability of using teacher forcing
266 |         e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
267 |         Return:
268 |         ------
269 |         outputs_id: [seq len, batch size, id_size], from greedy decoding
270 |         outputs_rate: [seq len, batch size, 1]
271 |         """
272 |         max_trg_len = trg_id.size(0)
273 |         batch_size = trg_id.size(1)
274 | 
275 |         # encoder_outputs holds all hidden states of the input sequence
276 |         # hidden is the final encoder state, used to initialize the decoder
277 |         encoder_outputs, hiddens = self.encoder(src, src_len, pro_features)
278 | 
279 |         if self.decoder.attn_flag:
280 |             attn_mask = torch.zeros(batch_size, max(src_len))  # only attend over the unpadded sequence
281 |             for i in range(len(src_len)):
282 |                 attn_mask[i][:src_len[i]] = 1.
283 |             attn_mask = attn_mask.to(self.device)
284 |         else:
285 |             attn_mask = None
286 | 
287 |         outputs_id, outputs_rate = self.normal_step(max_trg_len, batch_size, trg_id, trg_rate, trg_len,
288 |                                                     encoder_outputs, hiddens, attn_mask,
289 |                                                     online_features_dict,
290 |                                                     rid_features_dict,
291 |                                                     pre_grids, next_grids, constraint_mat, pro_features,
292 |                                                     teacher_forcing_ratio)
293 | 
294 |         return outputs_id, outputs_rate
295 | 
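    # Note: decoding in normal_step() below is greedy. When teacher forcing is
    # not applied at a step, the argmax road id and the predicted rate are fed
    # back as the next input; each step independently uses the ground-truth
    # id/rate instead with probability teacher_forcing_ratio.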
teacher forcing or not 334 | teacher_force = random.random() < teacher_forcing_ratio 335 | 336 | # get the highest predicted token from our predictions 337 | top1_id = prediction_id.argmax(1) 338 | top1_id = top1_id.unsqueeze(-1) # make sure the output has the same dimension as input 339 | 340 | # if teacher forcing, use actual next token as next input 341 | # if not, use predicted token 342 | input_id = trg_id[t] if teacher_force else top1_id 343 | input_rate = trg_rate[t] if teacher_force else prediction_rate 344 | 345 | # max_trg_len, batch_size, trg_rid_size 346 | outputs_id = outputs_id.permute(1, 0, 2) # batch size, seq len, rid size 347 | outputs_rate = outputs_rate.permute(1, 0, 2) # batch size, seq len, 1 348 | for i in range(batch_size): 349 | outputs_id[i][trg_len[i]:] = 0 350 | outputs_id[i][trg_len[i]:, 0] = 1 # make sure argmax will return eid0 351 | outputs_rate[i][trg_len[i]:] = 0 352 | outputs_id = outputs_id.permute(1, 0, 2) 353 | outputs_rate = outputs_rate.permute(1, 0, 2) 354 | 355 | return outputs_id, outputs_rate 356 | 357 | -------------------------------------------------------------------------------- /models/multi_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/10/29 17:37 4 | 5 | import numpy as np 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from models.model_utils import toseq, get_constraint_mask 12 | from models.loss_fn import cal_id_acc, check_rn_dis_loss 13 | 14 | # set random seed 15 | SEED = 20202020 16 | 17 | random.seed(SEED) 18 | np.random.seed(SEED) 19 | torch.manual_seed(SEED) 20 | torch.cuda.manual_seed(SEED) 21 | torch.backends.cudnn.deterministic = True 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | print('multi_task device', device) 24 | 25 | 26 | def init_weights(self): 27 | """ 28 | Here we reproduce Keras default initialization weights for consistency with Keras version 29 | Reference: https://github.com/vonfeng/DeepMove/blob/master/codes/model.py 30 | """ 31 | ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name) 32 | hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name) 33 | b = (param.data for name, param in self.named_parameters() if 'bias' in name) 34 | 35 | for t in ih: 36 | nn.init.xavier_uniform_(t) 37 | for t in hh: 38 | nn.init.orthogonal_(t) 39 | for t in b: 40 | nn.init.constant_(t, 0) 41 | 42 | 43 | def train(model, iterator, optimizer, log_vars, rn_dict, grid_rn_dict, rn, 44 | raw2new_rid_dict, online_features_dict, rid_features_dict, parameters): 45 | model.train() # not necessary to have this line but it's safe to use model.train() to train model 46 | 47 | criterion_reg = nn.MSELoss() 48 | criterion_ce = nn.NLLLoss() 49 | 50 | epoch_ttl_loss = 0 51 | epoch_id1_loss = 0 52 | epoch_recall_loss = 0 53 | epoch_precision_loss = 0 54 | epoch_train_id_loss = 0 55 | epoch_rate_loss = 0 56 | for i, batch in enumerate(iterator): 57 | src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths = batch 58 | if parameters.dis_prob_mask_flag: 59 | constraint_mat, pre_grids, next_grids = get_constraint_mask(src_grid_seqs, src_gps_seqs, src_lengths, 60 | trg_lengths, grid_rn_dict, rn, raw2new_rid_dict, 61 | parameters) 62 | constraint_mat = constraint_mat.permute(1, 0, 2).to(device) 63 | pre_grids = pre_grids.permute(1, 0, 2).to(device) 64 | next_grids = next_grids.permute(1, 
43 | def train(model, iterator, optimizer, log_vars, rn_dict, grid_rn_dict, rn,
44 |           raw2new_rid_dict, online_features_dict, rid_features_dict, parameters):
45 |     model.train()  # ensure training mode (enables dropout); safe even if already set
46 | 
47 |     criterion_reg = nn.MSELoss()
48 |     criterion_ce = nn.NLLLoss()
49 | 
50 |     epoch_ttl_loss = 0
51 |     epoch_id1_loss = 0
52 |     epoch_recall_loss = 0
53 |     epoch_precision_loss = 0
54 |     epoch_train_id_loss = 0
55 |     epoch_rate_loss = 0
56 |     for i, batch in enumerate(iterator):
57 |         src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths = batch
58 |         if parameters.dis_prob_mask_flag:
59 |             constraint_mat, pre_grids, next_grids = get_constraint_mask(src_grid_seqs, src_gps_seqs, src_lengths,
60 |                                                                         trg_lengths, grid_rn_dict, rn, raw2new_rid_dict,
61 |                                                                         parameters)
62 |             constraint_mat = constraint_mat.permute(1, 0, 2).to(device)
63 |             pre_grids = pre_grids.permute(1, 0, 2).to(device)
64 |             next_grids = next_grids.permute(1, 0, 2).to(device)
65 |         else:
66 |             max_trg_len = max(trg_lengths)
67 |             batch_size = src_grid_seqs.size(0)
68 |             constraint_mat = torch.zeros(max_trg_len, batch_size, parameters.id_size, device=device)
69 |             pre_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
70 |             next_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
71 | 
72 |         src_pro_feas = src_pro_feas.float().to(device)
73 | 
74 |         src_grid_seqs = src_grid_seqs.permute(1, 0, 2).to(device)
75 |         trg_gps_seqs = trg_gps_seqs.permute(1, 0, 2).to(device)
76 |         trg_rids = trg_rids.permute(1, 0, 2).long().to(device)
77 |         trg_rates = trg_rates.permute(1, 0, 2).to(device)
78 | 
79 |         # constraint_mat = [trg len, batch size, id size]
80 |         # src_grid_seqs = [src len, batch size, 2]
81 |         # src_lengths = [batch size]
82 |         # trg_gps_seqs = [trg len, batch size, 2]
83 |         # trg_rids = [trg len, batch size, 1]
84 |         # trg_rates = [trg len, batch size, 1]
85 |         # trg_lengths = [batch size]
86 | 
87 |         optimizer.zero_grad()
88 |         output_ids, output_rates = model(src_grid_seqs, src_lengths, trg_rids, trg_rates, trg_lengths,
89 |                                          pre_grids, next_grids, constraint_mat, src_pro_feas,
90 |                                          online_features_dict, rid_features_dict, parameters.tf_ratio)
91 |         output_rates = output_rates.squeeze(2)
92 |         trg_rids = trg_rids.squeeze(2)
93 |         trg_rates = trg_rates.squeeze(2)
94 | 
95 |         # output_ids = [trg len, batch size, id one hot output dim]
96 |         # output_rates = [trg len, batch size]
97 |         # trg_rids = [trg len, batch size]
98 |         # trg_rates = [trg len, batch size]
99 | 
100 |         # rid loss, reported only, not used for backpropagation
101 |         loss_ids1, recall, precision = cal_id_acc(output_ids[1:], trg_rids[1:], trg_lengths)
102 | 
103 |         # for backpropagation
104 |         output_ids_dim = output_ids.shape[-1]
105 |         output_ids = output_ids[1:].reshape(-1, output_ids_dim)  # [(trg len - 1) * batch size, output id one hot dim]
106 |         trg_rids = trg_rids[1:].reshape(-1)  # [(trg len - 1) * batch size],
107 |         # view size is not compatible with input tensor's size and stride ==> use reshape() instead
108 | 
109 |         loss_train_ids = criterion_ce(output_ids, trg_rids)
110 |         loss_rates = criterion_reg(output_rates[1:], trg_rates[1:]) * parameters.lambda1
111 |         ttl_loss = loss_train_ids + loss_rates
112 | 
113 |         ttl_loss.backward()
114 |         torch.nn.utils.clip_grad_norm_(model.parameters(), parameters.clip)  # log_vars do not need to be clipped
115 |         optimizer.step()
116 | 
117 |         epoch_ttl_loss += ttl_loss.item()
118 |         epoch_id1_loss += loss_ids1
119 |         epoch_recall_loss += recall
120 |         epoch_precision_loss += precision
121 |         epoch_train_id_loss += loss_train_ids.item()
122 |         epoch_rate_loss += loss_rates.item()
123 | 
124 |     return log_vars, epoch_ttl_loss / len(iterator), epoch_id1_loss / len(iterator), epoch_recall_loss / len(iterator), \
125 |            epoch_precision_loss / len(iterator), epoch_rate_loss / len(iterator), epoch_train_id_loss / len(iterator)
126 | 
127 | 
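# A self-contained sketch (hypothetical tensors, not real data) of the
# multi-task objective assembled in train() above: an NLL loss over the
# flattened road-id log-probabilities plus a lambda1-weighted MSE on the
# moving rates, with the first (start-token) step skipped.
if __name__ == '__main__':
    _trg_len, _batch, _id_size, _lambda1 = 5, 2, 7, 10
    _logp = torch.log_softmax(torch.randn((_trg_len - 1) * _batch, _id_size), dim=1)
    _true_ids = torch.randint(0, _id_size, ((_trg_len - 1) * _batch,))
    _pred_rate = torch.rand(_trg_len - 1, _batch)
    _true_rate = torch.rand(_trg_len - 1, _batch)
    _ttl_loss = nn.NLLLoss()(_logp, _true_ids) + nn.MSELoss()(_pred_rate, _true_rate) * _lambda1
    print('toy multi-task loss:', _ttl_loss.item())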
128 | def evaluate(model, iterator, rn_dict, grid_rn_dict, rn, raw2new_rid_dict,
129 |              online_features_dict, rid_features_dict, raw_rn_dict, new2raw_rid_dict, parameters):
130 |     model.eval()  # required: switches off dropout and fixes batch-normalization statistics
131 | 
132 |     epoch_dis_mae_loss = 0
133 |     epoch_dis_rmse_loss = 0
134 |     epoch_dis_rn_mae_loss = 0
135 |     epoch_dis_rn_rmse_loss = 0
136 |     epoch_id1_loss = 0
137 |     epoch_recall_loss = 0
138 |     epoch_precision_loss = 0
139 |     epoch_rate_loss = 0
140 |     epoch_id_loss = 0  # loss from dl model
141 |     criterion_ce = nn.NLLLoss()
142 |     criterion_reg = nn.MSELoss()
143 | 
144 |     with torch.no_grad():  # disables gradient tracking and speeds up evaluation
145 |         for i, batch in enumerate(iterator):
146 |             src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths = batch
147 | 
148 |             if parameters.dis_prob_mask_flag:
149 |                 constraint_mat, pre_grids, next_grids = get_constraint_mask(src_grid_seqs, src_gps_seqs, src_lengths,
150 |                                                                             trg_lengths, grid_rn_dict, rn,
151 |                                                                             raw2new_rid_dict, parameters)
152 |                 constraint_mat = constraint_mat.permute(1, 0, 2).to(device)
153 |                 pre_grids = pre_grids.permute(1, 0, 2).to(device)
154 |                 next_grids = next_grids.permute(1, 0, 2).to(device)
155 |             else:
156 |                 max_trg_len = max(trg_lengths)
157 |                 batch_size = src_grid_seqs.size(0)
158 |                 constraint_mat = torch.zeros(max_trg_len, batch_size, parameters.id_size).to(device)
159 |                 pre_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
160 |                 next_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
161 | 
162 |             src_pro_feas = src_pro_feas.float().to(device)
163 | 
164 |             src_grid_seqs = src_grid_seqs.permute(1, 0, 2).to(device)
165 |             trg_gps_seqs = trg_gps_seqs.permute(1, 0, 2).to(device)
166 |             trg_rids = trg_rids.permute(1, 0, 2).long().to(device)
167 |             trg_rates = trg_rates.permute(1, 0, 2).to(device)
168 | 
169 |             # constraint_mat = [trg len, batch size, id size]
170 |             # src_grid_seqs = [src len, batch size, 2]
171 |             # src_pro_feas = [batch size, feature dim]
172 |             # src_lengths = [batch size]
173 |             # trg_gps_seqs = [trg len, batch size, 2]
174 |             # trg_rids = [trg len, batch size, 1]
175 |             # trg_rates = [trg len, batch size, 1]
176 |             # trg_lengths = [batch size]
177 | 
178 |             output_ids, output_rates = model(src_grid_seqs, src_lengths, trg_rids, trg_rates, trg_lengths,
179 |                                              pre_grids, next_grids, constraint_mat,
180 |                                              src_pro_feas, online_features_dict, rid_features_dict,
181 |                                              teacher_forcing_ratio=0)
182 | 
183 |             output_rates = output_rates.squeeze(2)
184 |             output_seqs = toseq(rn_dict, output_ids, output_rates, parameters)
185 |             trg_rids = trg_rids.squeeze(2)
186 |             trg_rates = trg_rates.squeeze(2)
187 |             # output_ids = [trg len, batch size, id one hot output dim]
188 |             # output_rates = [trg len, batch size]
189 |             # trg_rids = [trg len, batch size]
190 |             # trg_rates = [trg len, batch size]
191 | 
192 |             # rid loss, reported only, not used for backpropagation
193 |             loss_ids1, recall, precision = cal_id_acc(output_ids[1:], trg_rids[1:], trg_lengths)
194 |             # distance loss
195 |             dis_mae_loss, dis_rmse_loss, dis_rn_mae_loss, dis_rn_rmse_loss = check_rn_dis_loss(output_seqs[1:],
196 |                                                                                               output_ids[1:],
197 |                                                                                               output_rates[1:],
198 |                                                                                               trg_gps_seqs[1:],
199 |                                                                                               trg_rids[1:],
200 |                                                                                               trg_rates[1:],
201 |                                                                                               trg_lengths,
202 |                                                                                               rn, raw_rn_dict,
203 |                                                                                               new2raw_rid_dict)
204 | 
205 |             # for the cross-entropy loss
206 |             output_ids_dim = output_ids.shape[-1]
207 |             output_ids = output_ids[1:].reshape(-1,
208 |                                                 output_ids_dim)  # [(trg len - 1) * batch size, output id one hot dim]
209 |             trg_rids = trg_rids[1:].reshape(-1)  # [(trg len - 1) * batch size],
210 |             loss_ids = criterion_ce(output_ids, trg_rids)
211 |             # rate loss
212 |             loss_rates = criterion_reg(output_rates[1:], trg_rates[1:]) * parameters.lambda1
213 |             # loss_rates.size = [(trg len - 1), batch size] --> [(trg len - 1) * batch size, 1]
214 | 
215 |             epoch_dis_mae_loss += dis_mae_loss
216 |             epoch_dis_rmse_loss += dis_rmse_loss
217 |             epoch_dis_rn_mae_loss += dis_rn_mae_loss
218 |             epoch_dis_rn_rmse_loss += dis_rn_rmse_loss
219 |             epoch_id1_loss += loss_ids1
220 |             epoch_recall_loss += recall
221 |             epoch_precision_loss += precision
222 |             epoch_rate_loss += loss_rates.item()
223 |             epoch_id_loss += loss_ids.item()
224 | 
225 |     return
epoch_id1_loss / len(iterator), epoch_recall_loss / len(iterator), \ 226 | epoch_precision_loss / len(iterator), \ 227 | epoch_dis_mae_loss / len(iterator), epoch_dis_rmse_loss / len(iterator), \ 228 | epoch_dis_rn_mae_loss / len(iterator), epoch_dis_rn_rmse_loss / len(iterator), \ 229 | epoch_rate_loss / len(iterator), epoch_id_loss / len(iterator) 230 | -------------------------------------------------------------------------------- /multi_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | # coding: utf-8 3 | # @Time : 2020/11/10 11:07 4 | 5 | import time 6 | from tqdm import tqdm 7 | import logging 8 | import sys 9 | import argparse 10 | import pandas as pd 11 | 12 | import torch 13 | import torch.optim as optim 14 | 15 | from utils.utils import save_json_data, create_dir, load_pkl_data 16 | from common.mbr import MBR 17 | from common.spatial_func import SPoint, distance 18 | from common.road_network import load_rn_shp 19 | 20 | from models.datasets import Dataset, collate_fn, split_data 21 | from models.model_utils import load_rn_dict, load_rid_freqs, get_rid_grid, get_poi_info, get_rn_info 22 | from models.model_utils import get_online_info_dict, epoch_time, AttrDict, get_rid_rnfea_dict 23 | from models.multi_train import evaluate, init_weights, train 24 | from models.models_attn_tandem import Encoder, DecoderMulti, Seq2SeqMulti 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser(description='Multi-task Traj Interp') 29 | parser.add_argument('--module_type', type=str, default='simple', help='module type') 30 | parser.add_argument('--keep_ratio', type=float, default=0.125, help='keep ratio in float') 31 | parser.add_argument('--lambda1', type=int, default=10, help='weight for multi task rate') 32 | parser.add_argument('--hid_dim', type=int, default=512, help='hidden dimension') 33 | parser.add_argument('--epochs', type=int, default=10, help='epochs') 34 | parser.add_argument('--grid_size', type=int, default=50, help='grid size in int') 35 | parser.add_argument('--dis_prob_mask_flag', action='store_true', help='flag of using prob mask') 36 | parser.add_argument('--pro_features_flag', action='store_true', help='flag of using profile features') 37 | parser.add_argument('--online_features_flag', action='store_true', help='flag of using online features') 38 | parser.add_argument('--tandem_fea_flag', action='store_true', help='flag of using tandem rid features') 39 | parser.add_argument('--no_attn_flag', action='store_false', help='flag of using attention') 40 | parser.add_argument('--load_pretrained_flag', action='store_true', help='flag of load pretrained model') 41 | parser.add_argument('--model_old_path', type=str, default='', help='old model path') 42 | parser.add_argument('--no_debug', action='store_false', help='flag of debug') 43 | parser.add_argument('--no_train_flag', action='store_false', help='flag of training') 44 | parser.add_argument('--test_flag', action='store_true', help='flag of testing') 45 | 46 | 47 | opts = parser.parse_args() 48 | 49 | debug = opts.no_debug 50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 51 | 52 | args = AttrDict() 53 | args_dict = { 54 | 'module_type':opts.module_type, 55 | 'debug':debug, 56 | 'device':device, 57 | 58 | # pre train 59 | 'load_pretrained_flag':opts.load_pretrained_flag, 60 | 'model_old_path':opts.model_old_path, 61 | 'train_flag':opts.no_train_flag, 62 | 'test_flag':opts.test_flag, 63 | 64 | # attention 65 | 
'attn_flag':opts.no_attn_flag,
66 | 
67 |         # constraint
68 |         'dis_prob_mask_flag':opts.dis_prob_mask_flag,
69 |         'search_dist':50,
70 |         'beta':15,
71 | 
72 |         # features
73 |         'tandem_fea_flag':opts.tandem_fea_flag,
74 |         'pro_features_flag':opts.pro_features_flag,
75 |         'online_features_flag':opts.online_features_flag,
76 | 
77 |         # extra info module
78 |         'rid_fea_dim':8,
79 |         'pro_input_dim':30,  # 24[hour] + 5[weather] + 1[holiday]
80 |         'pro_output_dim':8,
81 |         'poi_num':5,
82 |         'online_dim':5+5,  # poi/road network features dim
83 |         'poi_type':'company,food,shopping,viewpoint,house',
84 | 
85 |         # MBR
86 |         'min_lat':36.6456,
87 |         'min_lng':116.9854,
88 |         'max_lat':36.6858,
89 |         'max_lng':117.0692,
90 | 
91 |         # input data params
92 |         'keep_ratio':opts.keep_ratio,
93 |         'grid_size':opts.grid_size,
94 |         'time_span':15,
95 |         'win_size':25,
96 |         'ds_type':'random',
97 |         'split_flag':True,
98 |         'shuffle':True,
99 | 
100 |         # model params
101 |         'hid_dim':opts.hid_dim,
102 |         'id_emb_dim':128,
103 |         'dropout':0.5,
104 |         'id_size':2571+1,
105 | 
106 |         'lambda1':opts.lambda1,
107 |         'n_epochs':opts.epochs,
108 |         'batch_size':128,
109 |         'learning_rate':1e-3,
110 |         'tf_ratio':0.5,
111 |         'clip':1,
112 |         'log_step':1
113 |     }
114 |     args.update(args_dict)
115 | 
116 |     print('Preparing data...')
117 |     if args.split_flag:
118 |         traj_input_dir = "./data/raw_trajectory/"
119 |         output_dir = "./data/model_data/"
120 |         split_data(traj_input_dir, output_dir)
121 | 
122 |     extra_info_dir = "./data/map/extra_info/"
123 |     rn_dir = "./data/map/road_network/"
124 |     train_trajs_dir = "./data/model_data/train_data/"
125 |     valid_trajs_dir = "./data/model_data/valid_data/"
126 |     test_trajs_dir = "./data/model_data/test_data/"
127 |     if args.tandem_fea_flag:
128 |         fea_flag = True
129 |     else:
130 |         fea_flag = False
131 | 
132 |     if args.load_pretrained_flag:
133 |         model_save_path = args.model_old_path
134 |     else:
135 |         model_save_path = './results/'+args.module_type+'_kr_'+str(args.keep_ratio)+'_debug_'+str(args.debug)+\
136 |                           '_gs_'+str(args.grid_size)+'_lam_'+str(args.lambda1)+\
137 |                           '_attn_'+str(args.attn_flag)+'_prob_'+str(args.dis_prob_mask_flag)+\
138 |                           '_fea_'+str(fea_flag)+'_'+time.strftime("%Y%m%d_%H%M%S") + '/'
139 |     create_dir(model_save_path)
140 | 
141 |     logging.basicConfig(level=logging.DEBUG,
142 |                         format='%(asctime)s %(levelname)s %(message)s',
143 |                         filename=model_save_path + 'log.txt',
144 |                         filemode='a')
145 | 
146 |     rn = load_rn_shp(rn_dir, is_directed=True)
147 |     raw_rn_dict = load_rn_dict(extra_info_dir, file_name='raw_rn_dict.json')
148 |     new2raw_rid_dict = load_rid_freqs(extra_info_dir, file_name='new2raw_rid.json')
149 |     raw2new_rid_dict = load_rid_freqs(extra_info_dir, file_name='raw2new_rid.json')
150 |     rn_dict = load_rn_dict(extra_info_dir, file_name='rn_dict.json')
151 | 
152 |     mbr = MBR(args.min_lat, args.min_lng, args.max_lat, args.max_lng)
153 |     grid_rn_dict, max_xid, max_yid = get_rid_grid(mbr, args.grid_size, rn_dict)
154 |     args_dict['max_xid'] = max_xid
155 |     args_dict['max_yid'] = max_yid
156 |     args.update(args_dict)
157 |     print(args)
158 |     logging.info(args_dict)
159 | 
160 |     # load features
161 |     weather_dict = load_pkl_data(extra_info_dir, 'weather_dict.pkl')
162 |     if args.online_features_flag:
163 |         grid_poi_df = pd.read_csv(extra_info_dir+'poi'+str(args.grid_size)+'.csv', index_col=[0,1])
164 |         norm_grid_poi_dict = get_poi_info(grid_poi_df, args)
165 |         norm_grid_rnfea_dict = get_rn_info(rn, mbr, args.grid_size, grid_rn_dict, rn_dict)
166 |         online_features_dict = get_online_info_dict(grid_rn_dict, norm_grid_poi_dict, norm_grid_rnfea_dict, args)
167 |     else:
168 |         norm_grid_poi_dict, norm_grid_rnfea_dict, online_features_dict = None, None, None
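    # Note: online_features_dict maps a road segment to the features of its
    # grid cell, args.poi_num normalized POI features plus the road-network
    # statistics of the cell, matching 'online_dim': 5+5 above. When the flag
    # is off, the decoder falls back to zero vectors of size online_dim
    # (see Seq2SeqMulti.normal_step).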
169 |     if fea_flag:  # road-segment features are only needed for the tandem rate decoder
170 |         rid_features_dict = get_rid_rnfea_dict(rn_dict, args)
171 |     else:
172 |         rid_features_dict = None
173 | 
174 |     # load dataset
175 |     train_dataset = Dataset(train_trajs_dir, mbr=mbr, norm_grid_poi_dict=norm_grid_poi_dict,
176 |                             norm_grid_rnfea_dict=norm_grid_rnfea_dict, weather_dict=weather_dict,
177 |                             parameters=args, debug=debug)
178 |     valid_dataset = Dataset(valid_trajs_dir, mbr=mbr, norm_grid_poi_dict=norm_grid_poi_dict,
179 |                             norm_grid_rnfea_dict=norm_grid_rnfea_dict, weather_dict=weather_dict,
180 |                             parameters=args, debug=debug)
181 |     test_dataset = Dataset(test_trajs_dir, mbr=mbr, norm_grid_poi_dict=norm_grid_poi_dict,
182 |                            norm_grid_rnfea_dict=norm_grid_rnfea_dict, weather_dict=weather_dict,
183 |                            parameters=args, debug=debug)
184 |     print('training dataset shape: ' + str(len(train_dataset)))
185 |     print('validation dataset shape: ' + str(len(valid_dataset)))
186 |     print('test dataset shape: ' + str(len(test_dataset)))
187 | 
188 |     train_iterator = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
189 |                                                  shuffle=args.shuffle, collate_fn=collate_fn,
190 |                                                  num_workers=4, pin_memory=True)
191 |     valid_iterator = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size,
192 |                                                  shuffle=args.shuffle, collate_fn=collate_fn,
193 |                                                  num_workers=4, pin_memory=True)
194 |     test_iterator = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
195 |                                                 shuffle=args.shuffle, collate_fn=collate_fn,
196 |                                                 num_workers=4, pin_memory=True)
197 | 
198 |     logging.info('Finish data preparing.')
199 |     logging.info('training dataset shape: ' + str(len(train_dataset)))
200 |     logging.info('validation dataset shape: ' + str(len(valid_dataset)))
201 |     logging.info('test dataset shape: ' + str(len(test_dataset)))
202 | 
203 |     enc = Encoder(args)
204 |     dec = DecoderMulti(args)
205 |     model = Seq2SeqMulti(enc, dec, device).to(device)
206 |     model.apply(init_weights)  # apply the Keras-style initialization defined in multi_train
207 |     if args.load_pretrained_flag:
208 |         model.load_state_dict(torch.load(args.model_old_path + 'val-best-model.pt'))
209 | 
210 |     print('model', str(model))
211 |     logging.info('model' + str(model))
212 | 
213 |     if args.train_flag:
214 |         ls_train_loss, ls_train_id_acc1, ls_train_id_recall, ls_train_id_precision, \
215 |         ls_train_rate_loss, ls_train_id_loss = [], [], [], [], [], []
216 |         ls_valid_loss, ls_valid_id_acc1, ls_valid_id_recall, ls_valid_id_precision, \
217 |         ls_valid_dis_mae_loss, ls_valid_dis_rmse_loss = [], [], [], [], [], []
218 |         ls_valid_dis_rn_mae_loss, ls_valid_dis_rn_rmse_loss, ls_valid_rate_loss, ls_valid_id_loss = [], [], [], []
219 | 
220 |         dict_train_loss = {}
221 |         dict_valid_loss = {}
222 |         best_valid_loss = float('inf')  # compare id loss
223 | 
224 |         # get all parameters (model parameters + task dependent log variances)
225 |         log_vars = [torch.zeros((1,), requires_grad=True, device=device) for _ in range(2)]  # for auto-tuning multi-task weights; a comprehension avoids aliasing one tensor twice
226 |         optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
227 |         for epoch in tqdm(range(args.n_epochs)):
228 |             start_time = time.time()
229 | 
230 |             new_log_vars, train_loss, train_id_acc1, train_id_recall, train_id_precision, \
231 |             train_rate_loss, train_id_loss = train(model, train_iterator, optimizer, log_vars,
232 |                                                    rn_dict, grid_rn_dict, rn, raw2new_rid_dict,
233 |                                                    online_features_dict, rid_features_dict, args)
234 | 
235 |             valid_id_acc1, valid_id_recall,
valid_id_precision, valid_dis_mae_loss, valid_dis_rmse_loss, \ 236 | valid_dis_rn_mae_loss, valid_dis_rn_rmse_loss, \ 237 | valid_rate_loss, valid_id_loss = evaluate(model, valid_iterator, 238 | rn_dict, grid_rn_dict, rn, raw2new_rid_dict, 239 | online_features_dict, rid_features_dict, raw_rn_dict, 240 | new2raw_rid_dict, args) 241 | ls_train_loss.append(train_loss) 242 | ls_train_id_acc1.append(train_id_acc1) 243 | ls_train_id_recall.append(train_id_recall) 244 | ls_train_id_precision.append(train_id_precision) 245 | ls_train_rate_loss.append(train_rate_loss) 246 | ls_train_id_loss.append(train_id_loss) 247 | 248 | ls_valid_id_acc1.append(valid_id_acc1) 249 | ls_valid_id_recall.append(valid_id_recall) 250 | ls_valid_id_precision.append(valid_id_precision) 251 | ls_valid_dis_mae_loss.append(valid_dis_mae_loss) 252 | ls_valid_dis_rmse_loss.append(valid_dis_rmse_loss) 253 | ls_valid_dis_rn_mae_loss.append(valid_dis_rn_mae_loss) 254 | ls_valid_dis_rn_rmse_loss.append(valid_dis_rn_rmse_loss) 255 | ls_valid_rate_loss.append(valid_rate_loss) 256 | ls_valid_id_loss.append(valid_id_loss) 257 | valid_loss = valid_rate_loss + valid_id_loss 258 | ls_valid_loss.append(valid_loss) 259 | 260 | dict_train_loss['train_ttl_loss'] = ls_train_loss 261 | dict_train_loss['train_id_acc1'] = ls_train_id_acc1 262 | dict_train_loss['train_id_recall'] = ls_train_id_recall 263 | dict_train_loss['train_id_precision'] = ls_train_id_precision 264 | dict_train_loss['train_rate_loss'] = ls_train_rate_loss 265 | dict_train_loss['train_id_loss'] = ls_train_id_loss 266 | 267 | dict_valid_loss['valid_ttl_loss'] = ls_valid_loss 268 | dict_valid_loss['valid_id_acc1'] = ls_valid_id_acc1 269 | dict_valid_loss['valid_id_recall'] = ls_valid_id_recall 270 | dict_valid_loss['valid_id_precision'] = ls_valid_id_precision 271 | dict_valid_loss['valid_rate_loss'] = ls_valid_rate_loss 272 | dict_valid_loss['valid_dis_mae_loss'] = ls_valid_dis_mae_loss 273 | dict_valid_loss['valid_dis_rmse_loss'] = ls_valid_dis_rmse_loss 274 | dict_valid_loss['valid_dis_rn_mae_loss'] = ls_valid_dis_rn_mae_loss 275 | dict_valid_loss['valid_dis_rn_rmse_loss'] = ls_valid_dis_rn_rmse_loss 276 | dict_valid_loss['valid_id_loss'] = ls_valid_id_loss 277 | 278 | end_time = time.time() 279 | epoch_mins, epoch_secs = epoch_time(start_time, end_time) 280 | 281 | if valid_loss < best_valid_loss: 282 | best_valid_loss = valid_loss 283 | torch.save(model.state_dict(), model_save_path + 'val-best-model.pt') 284 | 285 | if (epoch % args.log_step == 0) or (epoch == args.n_epochs - 1): 286 | logging.info('Epoch: ' + str(epoch + 1) + ' Time: ' + str(epoch_mins) + 'm' + str(epoch_secs) + 's') 287 | weights = [torch.exp(weight) ** 0.5 for weight in new_log_vars] 288 | logging.info('log_vars:' + str(weights)) 289 | logging.info('\tTrain Loss:' + str(train_loss) + 290 | '\tTrain RID Acc1:' + str(train_id_acc1) + 291 | '\tTrain RID Recall:' + str(train_id_recall) + 292 | '\tTrain RID Precision:' + str(train_id_precision) + 293 | '\tTrain Rate Loss:' + str(train_rate_loss) + 294 | '\tTrain RID Loss:' + str(train_id_loss)) 295 | logging.info('\tValid Loss:' + str(valid_loss) + 296 | '\tValid RID Acc1:' + str(valid_id_acc1) + 297 | '\tValid RID Recall:' + str(valid_id_recall) + 298 | '\tValid RID Precision:' + str(valid_id_precision) + 299 | '\tValid Distance MAE Loss:' + str(valid_dis_mae_loss) + 300 | '\tValid Distance RMSE Loss:' + str(valid_dis_rmse_loss) + 301 | '\tValid Distance RN MAE Loss:' + str(valid_dis_rn_mae_loss) + 302 | '\tValid Distance RN RMSE Loss:' + 
str(valid_dis_rn_rmse_loss) +
303 |                              '\tValid Rate Loss:' + str(valid_rate_loss) +
304 |                              '\tValid RID Loss:' + str(valid_id_loss))
305 | 
306 |             torch.save(model.state_dict(), model_save_path + 'train-mid-model.pt')
307 |             save_json_data(dict_train_loss, model_save_path, "train_loss.json")
308 |             save_json_data(dict_valid_loss, model_save_path, "valid_loss.json")
309 | 
310 |     if args.test_flag:
311 |         model.load_state_dict(torch.load(model_save_path + 'val-best-model.pt'))
312 |         start_time = time.time()
313 |         test_id_acc1, test_id_recall, test_id_precision, test_dis_mae_loss, test_dis_rmse_loss, \
314 |         test_dis_rn_mae_loss, test_dis_rn_rmse_loss, test_rate_loss, test_id_loss = evaluate(model, test_iterator,
315 |                                                                                             rn_dict, grid_rn_dict, rn,
316 |                                                                                             raw2new_rid_dict,
317 |                                                                                             online_features_dict,
318 |                                                                                             rid_features_dict,
319 |                                                                                             raw_rn_dict, new2raw_rid_dict,
320 |                                                                                             args)
321 |         end_time = time.time()
322 |         epoch_mins, epoch_secs = epoch_time(start_time, end_time)
323 |         logging.info('Test Time: ' + str(epoch_mins) + 'm' + str(epoch_secs) + 's')
324 |         logging.info('\tTest RID Acc1:' + str(test_id_acc1) +
325 |                      '\tTest RID Recall:' + str(test_id_recall) +
326 |                      '\tTest RID Precision:' + str(test_id_precision) +
327 |                      '\tTest Distance MAE Loss:' + str(test_dis_mae_loss) +
328 |                      '\tTest Distance RMSE Loss:' + str(test_dis_rmse_loss) +
329 |                      '\tTest Distance RN MAE Loss:' + str(test_dis_rn_mae_loss) +
330 |                      '\tTest Distance RN RMSE Loss:' + str(test_dis_rn_rmse_loss) +
331 |                      '\tTest Rate Loss:' + str(test_rate_loss) +
332 |                      '\tTest RID Loss:' + str(test_id_loss))
-------------------------------------------------------------------------------- /utils/__init__.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/3 13:34
-------------------------------------------------------------------------------- /utils/coord_transform.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import urllib
4 | import math
5 | 
6 | x_pi = 3.14159265358979324 * 3000.0 / 180.0
7 | pi = 3.1415926535897932384626  # π
8 | a = 6378245.0  # semi-major axis
9 | ee = 0.00669342162296594323  # squared eccentricity
10 | 
11 | 
12 | # class Geocoding:
13 | #     def __init__(self, api_key):
14 | #         self.api_key = api_key
15 | #
16 | #     def geocode(self, address):
17 | #         """
18 | #         Use the AMap (Gaode) geocoding service to resolve an address into coordinates
19 | #         :param address: the address to resolve
20 | #         :return:
21 | #         """
22 | #         geocoding = {'s': 'rsv3',
23 | #                      'key': self.api_key,
24 | #                      'city': '全国',
25 | #                      'address': address}
26 | #         geocoding = urllib.urlencode(geocoding)
27 | #         ret = urllib.urlopen("%s?%s" % ("http://restapi.amap.com/v3/geocode/geo", geocoding))
28 | #
29 | #         if ret.getcode() == 200:
30 | #             res = ret.read()
31 | #             json_obj = json.loads(res)
32 | #             if json_obj['status'] == '1' and int(json_obj['count']) >= 1:
33 | #                 geocodes = json_obj['geocodes'][0]
34 | #                 lng = float(geocodes.get('location').split(',')[0])
35 | #                 lat = float(geocodes.get('location').split(',')[1])
36 | #                 return [lng, lat]
37 | #             else:
38 | #                 return None
39 | #         else:
40 | #             return None
41 | 
42 | class Convert():
43 |     def __init__(self):
44 |         pass
45 | 
46 |     def convert(self, lng, lat):
47 |         return lng, lat
48 | 
49 | 
50 | class GCJ02ToWGS84(Convert):
51 |     def __init__(self):
52 |         super().__init__()
53 | 
54 |     def convert(self, lng, lat):
55 |         """
56 |         Convert GCJ-02 ("Mars" coordinate system) to WGS-84
57 |         :param lng: longitude in GCJ-02
58 |         :param lat: latitude in GCJ-02
59 |         :return:
60 |         """
61 |         if out_of_china(lng, lat):
62 |             return [lng, lat]
63 |         dlat = _transformlat(lng - 105.0, lat - 35.0)
64 |         dlng = _transformlng(lng - 105.0, lat - 35.0)
65 |         radlat = lat / 180.0 * pi
66 |         magic = math.sin(radlat)
67 |         magic = 1 - ee * magic * magic
68 |         sqrtmagic = math.sqrt(magic)
69 |         dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
70 |         dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
71 |         mglat = lat + dlat
72 |         mglng = lng + dlng
73 |         return lng * 2 - mglng, lat * 2 - mglat
74 | 
75 | 
76 | class WGS84ToGCJ02(Convert):
77 |     def __init__(self):
78 |         super().__init__()
79 | 
80 |     def convert(self, lng, lat):
81 |         """
82 |         Convert WGS-84 to GCJ-02 ("Mars" coordinate system)
83 |         :param lng: longitude in WGS-84
84 |         :param lat: latitude in WGS-84
85 |         :return:
86 |         """
87 |         if out_of_china(lng, lat):  # no offset is applied outside China
88 |             return [lng, lat]
89 |         dlat = _transformlat(lng - 105.0, lat - 35.0)
90 |         dlng = _transformlng(lng - 105.0, lat - 35.0)
91 |         radlat = lat / 180.0 * pi
92 |         magic = math.sin(radlat)
93 |         magic = 1 - ee * magic * magic
94 |         sqrtmagic = math.sqrt(magic)
95 |         dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
96 |         dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
97 |         mglat = lat + dlat
98 |         mglng = lng + dlng
99 |         return mglng, mglat
100 | 
101 | 
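# Note: the GCJ-02 distortion has no closed-form inverse. GCJ02ToWGS84 above
# approximates one by computing the forward offsets (dlat, dlng) at the
# already-shifted point and returning lng * 2 - mglng == lng - dlng, i.e.
# wgs ~= gcj - (F(gcj) - gcj), which is typically metre-level accurate in China.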
102 | # def gcj02_to_bd09(lng, lat):
103 | #     """
104 | #     Convert GCJ-02 ("Mars") coordinates to Baidu BD-09
105 | #     Google/AMap --> Baidu
106 | #     :param lng: GCJ-02 longitude
107 | #     :param lat: GCJ-02 latitude
108 | #     :return:
109 | #     """
110 | #     z = math.sqrt(lng * lng + lat * lat) + 0.00002 * math.sin(lat * x_pi)
111 | #     theta = math.atan2(lat, lng) + 0.000003 * math.cos(lng * x_pi)
112 | #     bd_lng = z * math.cos(theta) + 0.0065
113 | #     bd_lat = z * math.sin(theta) + 0.006
114 | #     return [bd_lng, bd_lat]
115 | #
116 | #
117 | # def bd09_to_gcj02(bd_lon, bd_lat):
118 | #     """
119 | #     Convert Baidu BD-09 to GCJ-02 ("Mars") coordinates
120 | #     Baidu --> Google/AMap
121 | #     :param bd_lat: BD-09 latitude
122 | #     :param bd_lon: BD-09 longitude
123 | #     :return: the converted coordinates as a list
124 | #     """
125 | #     x = bd_lon - 0.0065
126 | #     y = bd_lat - 0.006
127 | #     z = math.sqrt(x * x + y * y) - 0.00002 * math.sin(y * x_pi)
128 | #     theta = math.atan2(y, x) - 0.000003 * math.cos(x * x_pi)
129 | #     gg_lng = z * math.cos(theta)
130 | #     gg_lat = z * math.sin(theta)
131 | #     return [gg_lng, gg_lat]
132 | #
133 | #
134 | # def wgs84_to_gcj02(lng, lat):
135 | #     """
136 | #     Convert WGS-84 to GCJ-02 ("Mars" coordinate system)
137 | #     :param lng: longitude in WGS-84
138 | #     :param lat: latitude in WGS-84
139 | #     :return:
140 | #     """
141 | #     if out_of_china(lng, lat):  # no offset is applied outside China
142 | #         return [lng, lat]
143 | #     dlat = _transformlat(lng - 105.0, lat - 35.0)
144 | #     dlng = _transformlng(lng - 105.0, lat - 35.0)
145 | #     radlat = lat / 180.0 * pi
146 | #     magic = math.sin(radlat)
147 | #     magic = 1 - ee * magic * magic
148 | #     sqrtmagic = math.sqrt(magic)
149 | #     dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
150 | #     dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
151 | #     mglat = lat + dlat
152 | #     mglng = lng + dlng
153 | #     return [mglng, mglat]
154 | #
155 | #
156 | # def gcj02_to_wgs84(lng, lat):
157 | #     """
158 | #     Convert GCJ-02 ("Mars" coordinate system) to WGS-84
159 | #     :param lng: longitude in GCJ-02
160 | #     :param lat: latitude in GCJ-02
161 | #     :return:
162 | #     """
163 | #     if out_of_china(lng, lat):
164 | #         return [lng, lat]
165 | #     dlat = _transformlat(lng - 105.0, lat - 35.0)
166 | #     dlng = _transformlng(lng - 105.0, lat - 35.0)
167 | #     radlat = lat / 180.0 * pi
168 | #     magic = math.sin(radlat)
169 | #     magic = 1 - ee * magic * magic
170 | #     sqrtmagic = math.sqrt(magic)
171 | #     dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
172 | #     dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
173 | #     mglat = lat + dlat
174 | #     mglng = lng + dlng
175 | #     return [lng * 2 - mglng, lat * 2 - mglat]
176 | #
177 | #
178 | # def bd09_to_wgs84(bd_lon, bd_lat):
179 | #     lon, lat = bd09_to_gcj02(bd_lon, bd_lat)
180 | #     return gcj02_to_wgs84(lon, lat)
181 | #
182 | #
183 | # def wgs84_to_bd09(lon, lat):
184 | #     lon, lat = wgs84_to_gcj02(lon, lat)
185 | #     return gcj02_to_bd09(lon, lat)
186 | #
187 | #
188 | def _transformlat(lng, lat):
189 |     ret = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \
190 |           0.1 * lng * lat + 0.2 * math.sqrt(math.fabs(lng))
191 |     ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
192 |             math.sin(2.0 * lng * pi)) * 2.0 / 3.0
193 |     ret += (20.0 * math.sin(lat * pi) + 40.0 *
194 |             math.sin(lat / 3.0 * pi)) * 2.0 / 3.0
195 |     ret += (160.0 * math.sin(lat / 12.0 * pi) + 320 *
196 |             math.sin(lat * pi / 30.0)) * 2.0 / 3.0
197 |     return ret
198 | 
199 | 
200 | def _transformlng(lng, lat):
201 |     ret = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \
202 |           0.1 * lng * lat + 0.1 * math.sqrt(math.fabs(lng))
203 |     ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
204 |             math.sin(2.0 * lng * pi)) * 2.0 / 3.0
205 |     ret += (20.0 * math.sin(lng * pi) + 40.0 *
206 |             math.sin(lng / 3.0 * pi)) * 2.0 / 3.0
207 |     ret += (150.0 * math.sin(lng / 12.0 * pi) + 300.0 *
208 |             math.sin(lng / 30.0 * pi)) * 2.0 / 3.0
209 |     return ret
210 | 
211 | 
212 | def out_of_china(lng, lat):
213 |     """
214 |     Check whether the coordinate lies outside China; no offset is applied outside China
215 |     :param lng:
216 |     :param lat:
217 |     :return:
218 |     """
219 |     return not (lng > 73.66 and lng < 135.05 and lat > 3.86 and lat < 53.55)
220 | 
221 | 
222 | # if __name__ == '__main__':
223 | #     lng, lat = 116.527559, 39.807378
224 | #     min_lat = 39.727
225 | #     min_lng = 116.490
226 | #     max_lat = 39.83
227 | #     max_lng = 116.588
228 | #     # lng = 109.642194
229 | #     # lat = 20.123355
230 | #     result1 = gcj02_to_bd09(lng, lat)
231 | #     result2 = bd09_to_gcj02(lng, lat)
232 | #     result3 = wgs84_to_gcj02(lng, lat)
233 | #     result4 = gcj02_to_wgs84(lng, lat)
234 | #     result5 = bd09_to_wgs84(lng, lat)
235 | #     result6 = wgs84_to_bd09(lng, lat)
236 | #     print(gcj02_to_wgs84(min_lng, min_lat))
237 | #     print(gcj02_to_wgs84(max_lng, max_lat))
238 | 
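# A minimal round-trip sketch of the two converter classes above. The sample
# coordinate is an arbitrary point inside China; the residual is small but
# non-zero because GCJ02ToWGS84 only approximately inverts WGS84ToGCJ02.
if __name__ == '__main__':
    wgs_lng, wgs_lat = 117.0273, 36.6657
    gcj_lng, gcj_lat = WGS84ToGCJ02().convert(wgs_lng, wgs_lat)
    back_lng, back_lat = GCJ02ToWGS84().convert(gcj_lng, gcj_lat)
    print('round-trip error (deg):', abs(back_lng - wgs_lng), abs(back_lat - wgs_lat))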
-------------------------------------------------------------------------------- /utils/imputation.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/22 16:29
4 | # @author: Huimin Ren
5 | from common.spatial_func import distance, cal_loc_along_line
6 | from common.trajectory import Trajectory, STPoint
7 | import numpy as np
8 | 
9 | from map_matching.candidate_point import CandidatePoint
10 | from map_matching.hmm.hmm_map_matcher import TIHMMMapMatcher
11 | 
12 | class Imputation:
13 |     def __init__(self):
14 |         pass
15 | 
16 |     def impute(self, traj):
17 |         pass
18 | 
19 | 
20 | class LinearImputation(Imputation):
21 |     """
22 |     Uniformly interpolate GPS points into a trajectory.
23 |     Args:
24 |     -----
25 |     time_interval:
26 |         int. unit in seconds. the time interval between two points.
27 |     """
28 |     def __init__(self, time_interval):
29 |         super(LinearImputation, self).__init__()
30 |         self.time_interval = time_interval
31 | 
32 |     def impute(self, traj):
33 |         pt_list = traj.pt_list
34 |         if len(pt_list) <= 1:
35 |             return []
36 | 
37 |         pre_pt = pt_list[0]
38 |         new_pt_list = [pre_pt]
39 |         for cur_pt in pt_list[1:]:
40 |             time_span = (cur_pt.time - pre_pt.time).total_seconds()
41 |             num_pts = 0  # 0 means no points need to be interpolated
42 |             if time_span % self.time_interval > self.time_interval / 2:
43 |                 # the remainder is larger than half of the time interval
44 |                 num_pts = time_span // self.time_interval  # quotient
45 |             elif time_span > self.time_interval:
46 |                 # the remainder is at most half of the time interval and the span exceeds one interval
47 |                 num_pts = time_span // self.time_interval - 1
48 | 
49 |             interp_list = []
50 |             unit_lat = abs(cur_pt.lat - pre_pt.lat) / (num_pts + 1)
51 |             unit_lng = abs(cur_pt.lng - pre_pt.lng) / (num_pts + 1)
52 |             unit_ts = (cur_pt.time - pre_pt.time) / (num_pts + 1)
53 |             sign_lat = np.sign(cur_pt.lat - pre_pt.lat)
54 |             sign_lng = np.sign(cur_pt.lng - pre_pt.lng)
55 |             for i in range(int(num_pts)):
56 |                 new_lat = pre_pt.lat + sign_lat * (i+1) * unit_lat
57 |                 new_lng = pre_pt.lng + sign_lng * (i+1) * unit_lng
58 |                 new_time = pre_pt.time + (i+1) * unit_ts
59 |                 interp_list.append(STPoint(new_lat, new_lng, new_time))
60 | 
61 |             new_pt_list.extend(interp_list)
62 |             new_pt_list.append(cur_pt)
63 |             pre_pt = cur_pt
64 | 
65 |         new_traj = Trajectory(traj.oid, traj.tid, new_pt_list)
66 | 
67 |         return new_traj
68 | 
69 | 
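# A worked example of LinearImputation (made-up ids and coordinates): two
# points 60 s apart with time_interval=15 give num_pts = 60 // 15 - 1 = 3
# interpolated points, i.e. the recovered trajectory is sampled every 15 s.
# Run from the repo root so that the common package is importable.
if __name__ == '__main__':
    from datetime import datetime, timedelta
    t0 = datetime(2015, 4, 1, 8, 0, 0)
    pts = [STPoint(36.66, 117.00, t0),
           STPoint(36.67, 117.01, t0 + timedelta(seconds=60))]
    toy_traj = LinearImputation(time_interval=15).impute(Trajectory('oid0', 'tid0', pts))
    print('points after imputation:', len(toy_traj.pt_list))  # 2 + 3 = 5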
70 | class MMLinearImputation(Imputation):
71 | 
72 |     def __init__(self, time_interval):
73 |         super(MMLinearImputation, self).__init__()
74 |         self.time_interval = time_interval
75 | 
76 |     def impute(self, traj, rn, rn_dict):
77 |         try:
78 |             map_matcher = TIHMMMapMatcher(rn)
79 |             mm_ls_path = map_matcher.match_to_path(traj)[0]  # find shortest path
80 |         except:
81 |             # cannot find a shortest path
82 |             return None
83 | 
84 |         path_eids = [p.eid for p in mm_ls_path.path_entities]
85 | 
86 |         pre_mm_pt = traj.pt_list[0]
87 |         new_pt_list = [pre_mm_pt]
88 |         for cur_mm_pt in traj.pt_list[1:]:
89 |             time_span = (cur_mm_pt.time - pre_mm_pt.time).total_seconds()
90 |             num_pts = 0  # 0 means no points need to be interpolated
91 |             if time_span % self.time_interval > self.time_interval / 2:
92 |                 # the remainder is larger than half of the time interval
93 |                 num_pts = time_span // self.time_interval  # quotient
94 |             elif time_span > self.time_interval:
95 |                 # the remainder is at most half of the time interval and the span exceeds one interval
96 |                 num_pts = time_span // self.time_interval - 1
97 | 
98 |             if pre_mm_pt.data['candi_pt'] is None or cur_mm_pt.data['candi_pt'] is None:
99 |                 return None
100 | 
101 |             pre_eid = pre_mm_pt.data['candi_pt'].eid
102 |             cur_eid = cur_mm_pt.data['candi_pt'].eid
103 |             two_points_coords, two_points_eids, ttl_dis = self.get_two_points_coords(path_eids, pre_eid, cur_eid,
104 |                                                                                      pre_mm_pt, cur_mm_pt, rn_dict)
105 |             interp_list = self.get_interp_list(num_pts, cur_mm_pt, pre_mm_pt, ttl_dis,
106 |                                                two_points_eids, two_points_coords, rn_dict)
107 | 
108 |             new_pt_list.extend(interp_list)
109 |             new_pt_list.append(cur_mm_pt)
110 |             pre_mm_pt = cur_mm_pt
111 |         new_traj = Trajectory(traj.oid, traj.tid, new_pt_list)
112 | 
113 |         # get all coords of the shortest path
114 |         # path_coords = []
115 |         # for eid in path_eids:
116 |         #     path_coords.extend(rn_dict[eid]['coords'])
117 |         # path_pt_list = []
118 |         # for pt in path_coords:
119 |         #     path_pt_list.append([pt.lat, pt.lng])
120 | 
121 |         return new_traj
122 | 
123 |     def get_interp_list(self, num_pts, cur_mm_pt, pre_mm_pt, ttl_dis, two_points_eids, two_points_coords, rn_dict):
124 |         interp_list = []
125 |         unit_ts = (cur_mm_pt.time - pre_mm_pt.time) / (num_pts + 1)
126 |         for n in range(int(num_pts)):
127 |             new_time = pre_mm_pt.time + (n + 1) * unit_ts
128 |             move_dis = (ttl_dis / num_pts) * n + pre_mm_pt.data['candi_pt'].offset
129 | 
130 |             # get eid and offset
131 |             pre_road_dist, road_dist = 0, 0
132 |             for i in range(len(two_points_eids)):
133 |                 if i > 0:
134 |                     pre_road_dist += rn_dict[two_points_eids[i - 1]]['length']
135 |                 road_dist += rn_dict[two_points_eids[i]]['length']
136 |                 if move_dis <= road_dist:
137 |                     insert_eid = two_points_eids[i]
138 |                     insert_offset = move_dis - pre_road_dist
139 |                     break
140 | 
141 |             # get lat and lng
142 |             dist, pre_dist = 0, 0
143 |             for i in range(len(two_points_coords) - 1):
144 |                 if i > 0:
145 |                     pre_dist += distance(two_points_coords[i - 1][0], two_points_coords[i][0])
146 |                 dist += distance(two_points_coords[i][0], two_points_coords[i + 1][0])
147 |                 if dist >= move_dis:
148 |                     coor_rate = (move_dis - pre_dist) / distance(two_points_coords[i][0],
149 |                                                                  two_points_coords[i + 1][0])
150 |                     project_pt = cal_loc_along_line(two_points_coords[i][0], two_points_coords[i + 1][0], coor_rate)
151 |                     break
152 |             data = {'candi_pt': CandidatePoint(project_pt.lat, project_pt.lng, insert_eid, 0, insert_offset, 0)}
153 |             interp_list.append(STPoint(project_pt.lat, project_pt.lng, new_time, data))
154 | 
155 |         return interp_list
156 | 
157 |     def get_two_points_coords(self, path_eids, pre_eid, cur_eid, pre_mm_pt, cur_mm_pt, rn_dict):
158 |         if pre_eid == cur_eid:
159 |             # the two points are on the same road
160 |             two_points_eids = [path_eids[path_eids.index(pre_eid)]]
161 |             two_points_coords = [[item, two_points_eids[0]] for item in rn_dict[two_points_eids[0]]['coords']]
162 |             ttl_dis = cur_mm_pt.data['candi_pt'].offset - pre_mm_pt.data['candi_pt'].offset
163 |         else:
164 |             # the two points are on different roads
165 |             start = path_eids.index(pre_eid)
166 |             end = path_eids.index(cur_eid) + 1
167 |             if start >= end:
168 |                 end = path_eids.index(cur_eid, start)  # cur_eid shows up at least twice
169 |             two_points_eids = path_eids[start: end]
170 | 
171 |             two_points_coords = []  # 2D, [gps, eid]
172 |             ttl_eids_dis = 0
173 |             for eid in two_points_eids[:-1]:
174 |                 tmp_coords = [[item, eid] for item in rn_dict[eid]['coords']]
175 |                 two_points_coords.extend(tmp_coords)
176 |                 ttl_eids_dis += rn_dict[eid]['length']
177 |             ttl_dis = ttl_eids_dis - pre_mm_pt.data['candi_pt'].offset + cur_mm_pt.data['candi_pt'].offset
178 |             tmp_coords = [[item, two_points_eids[-1]] for item in rn_dict[two_points_eids[-1]]['coords']]
179 |             two_points_coords.extend(tmp_coords)  # add coords after computing ttl_dis
180 | 
181 |         return two_points_coords, two_points_eids, ttl_dis
182 | 
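# A pure-python rendering (made-up numbers) of the distance bookkeeping in
# get_two_points_coords() for the different-roads case: the full length of
# every traversed segment except the last, minus the offset already travelled
# on the first road, plus the offset reached on the last road.
if __name__ == '__main__':
    seg_lengths = [120.0, 80.0, 150.0]    # lengths of the traversed segments (m)
    pre_offset, cur_offset = 30.0, 45.0   # offsets of the two matched points (m)
    ttl_dis = sum(seg_lengths[:-1]) - pre_offset + cur_offset
    print('distance along the path between the matched points:', ttl_dis)  # 215.0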
-------------------------------------------------------------------------------- /utils/noise_filter.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/7 17:07
4 | # Add new method
5 | from statistics import median
6 | 
7 | from common.spatial_func import distance
8 | from common.trajectory import Trajectory, STPoint
9 | 
10 | """
11 | All the methods are referred from
12 | Zheng, Y. and Zhou, X. eds., 2011.
13 | Computing with spatial trajectories. Springer Science & Business Media.
14 | Chapter 1 Trajectory Preprocessing
15 | 
16 | """
17 | 
18 | 
19 | class NoiseFilter:
20 |     def filter(self, traj):
21 |         pass
22 | 
23 |     def get_tid(self, oid, clean_pt_list):
24 |         return oid + '_' + clean_pt_list[0].time.strftime('%Y%m%d%H%M') + '_' + \
25 |                clean_pt_list[-1].time.strftime('%Y%m%d%H%M')
26 | 
27 | 
28 | # modify point
29 | class MedianFilter(NoiseFilter):
30 |     """
31 |     Smooth each point with the median value, since a mean filter is sensitive to outliers.
32 |     x_i = median(z_{i-n+1}, z_{i-n+2}, ..., z_{i-1}, z_{i})  eq.1.4
33 |     i: ith point
34 |     z: original points
35 |     n: window size
36 |     """
37 | 
38 |     def __init__(self, win_size=3):
39 |         super(MedianFilter, self).__init__()
40 |         self.win_size = win_size // 2  # window radius, so the target point sits in the middle
41 | 
42 |     def filter(self, traj):
43 |         pt_list = traj.pt_list.copy()
44 |         if len(pt_list) <= 1:
45 |             return None
46 | 
47 |         for i in range(1, len(pt_list)-1):
48 |             # boundary handling: shrink the window so that index i stays in the middle
49 |             wind_size = self.win_size
50 |             lats, lngs = [], []
51 |             if self.win_size > i:
52 |                 wind_size = i
53 |             elif self.win_size > len(pt_list) - 1 - i:
54 |                 wind_size = len(pt_list) - 1 - i
55 |             # get the smoothed location
56 |             for pt in pt_list[i-wind_size:i+wind_size+1]:
57 |                 lats.append(pt.lat)
58 |                 lngs.append(pt.lng)
59 | 
60 |             pt_list[i] = STPoint(median(lats), median(lngs), pt_list[i].time)
61 | 
62 |         if len(pt_list) > 1:
63 |             # return Trajectory(traj.oid, self.get_tid(traj.oid, pt_list), pt_list)
64 |             return Trajectory(traj.oid, traj.tid, pt_list)
65 |         else:
66 |             return None
67 | 
68 | 
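# A small sketch of the median smoothing above on a plain list of made-up
# latitudes: with win_size=3 the radius is 3 // 2 = 1, so every interior value
# becomes the median of itself and its two neighbours, suppressing the 99.0 spike.
if __name__ == '__main__':
    lats = [36.660, 36.661, 99.000, 36.663, 36.664]
    radius = 3 // 2
    smoothed = [median(lats[i - radius:i + radius + 1]) for i in range(1, len(lats) - 1)]
    print(smoothed)  # [36.661, 36.663, 36.664]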
69 | # modify point
70 | class HeuristicMeanFilter(NoiseFilter):
71 |     """
72 |     Find outliers by speed (points whose adjacent speeds exceed the max speed)
73 |     and replace them with a mean value.
74 |     Mean filter usually handles individual noise points with a dense representation.
75 |     """
76 | 
77 |     def __init__(self, max_speed, win_size=1):
78 |         """
79 |         Args:
80 |         ----
81 |         max_speed:
82 |             int. m/s. speed threshold above which a point counts as noise
83 |         win_size:
84 |             int. prefer an odd number. window size for the mean value that replaces the noise.
85 |         """
86 |         super(HeuristicMeanFilter, self).__init__()
87 | 
88 |         self.max_speed = max_speed
89 |         self.win_size = win_size
90 | 
91 |     def filter(self, traj):
92 |         """
93 |         When both the previous and the next speed exceed the max speed, the point is considered an outlier.
94 |         Outliers are replaced with a mean value; the range is defined by the window size.
95 |         Boundaries are handled by shrinking the window
96 |         so that the noisy point stays in the middle.
97 | 
98 |         Args:
99 |         -----
100 |         traj:
101 |             Trajectory(). a single trajectory
102 |         Returns:
103 |         --------
104 |         new_traj:
105 |             Trajectory(). noise replaced with mean or extrapolated values
106 |         """
107 |         pt_list = traj.pt_list.copy()
108 |         if len(pt_list) <= 1:
109 |             return None
110 |         for i in range(1, len(pt_list) - 1):
111 |             time_span_pre = (pt_list[i].time - pt_list[i - 1].time).total_seconds()
112 |             dist_pre = distance(pt_list[i - 1], pt_list[i])
113 |             time_span_next = (pt_list[i + 1].time - pt_list[i].time).total_seconds()
114 |             dist_next = distance(pt_list[i], pt_list[i + 1])
115 |             # compute current speed
116 |             speed_pre = dist_pre / time_span_pre
117 |             speed_next = dist_next / time_span_next
118 |             # if the first point is noise
119 |             if i == 1 and speed_pre > self.max_speed > speed_next:
120 |                 lat = pt_list[i].lat * 2 - pt_list[i + 1].lat
121 |                 lng = pt_list[i].lng * 2 - pt_list[i + 1].lng
122 |                 pt_list[0] = STPoint(lat, lng, pt_list[0].time)
123 |             # if the last point is noise
124 |             elif i == len(pt_list) - 2 and speed_next > self.max_speed >= speed_pre:
125 |                 lat = pt_list[i - 1].lat * 2 - pt_list[i - 2].lat
126 |                 lng = pt_list[i - 1].lng * 2 - pt_list[i - 2].lng
127 |                 pt_list[i + 1] = STPoint(lat, lng, pt_list[i].time)
128 |             # if a middle point is noise
129 |             elif speed_pre > self.max_speed and speed_next > self.max_speed:
130 |                 pt_list[i] = STPoint(0, 0, pt_list[i].time)
131 |                 lats, lngs = [], []
132 |                 # shrink the window so that index i stays in the middle
133 |                 wind_size = self.win_size
134 |                 if self.win_size > i:
135 |                     wind_size = i
136 |                 elif self.win_size > len(pt_list) - 1 - i:
137 |                     wind_size = len(pt_list) - 1 - i
138 |                 for pt in pt_list[i-wind_size:i+wind_size+1]:
139 |                     lats.append(pt.lat)
140 |                     lngs.append(pt.lng)
141 | 
142 |                 lat = sum(lats) / (len(lats) - 1)  # pt_list[i] was zeroed above, so dividing by len - 1 averages the neighbours only
143 |                 lng = sum(lngs) / (len(lngs) - 1)
144 |                 pt_list[i] = STPoint(lat, lng, pt_list[i].time)
145 | 
146 |         if len(pt_list) > 1:
147 |             # return Trajectory(traj.oid, self.get_tid(traj.oid, pt_list), pt_list)
148 |             return Trajectory(traj.oid, traj.tid, pt_list)
149 |         else:
150 |             return None
151 | 
152 | 
153 | # remove point
154 | class HeuristicFilter(NoiseFilter):
155 |     """
156 |     Remove a point if it exceeds the max speed
157 |     """
158 | 
159 |     def __init__(self, max_speed):
160 |         super(HeuristicFilter, self).__init__()
161 |         self.max_speed = max_speed
162 | 
163 |     def filter(self, traj):
164 |         pt_list = traj.pt_list
165 |         if len(pt_list) <= 1:
166 |             return None
167 | 
168 |         remove_inds = []
169 |         for i in range(1, len(pt_list) - 1):
170 |             time_span_pre = (pt_list[i].time - pt_list[i - 1].time).total_seconds()
171 |             dist_pre = distance(pt_list[i - 1], pt_list[i])
172 |             time_span_next = (pt_list[i + 1].time - pt_list[i].time).total_seconds()
173 |             dist_next = distance(pt_list[i], pt_list[i + 1])
174 |             speed_pre = dist_pre / time_span_pre
175 |             speed_next = dist_next / time_span_next
176 |             # the first point is an outlier
177 |             if i == 1 and speed_pre > self.max_speed > speed_next:
178 |                 remove_inds.append(0)
179 |             # the last point is an outlier
180 |             elif i == len(pt_list) - 2 and speed_next > self.max_speed >= speed_pre:
181 |                 remove_inds.append(len(pt_list) - 1)
182 |             # a middle point is an outlier
183 |             elif speed_pre > self.max_speed and speed_next > self.max_speed:
184 |                 remove_inds.append(i)
185 | 
186 |         clean_pt_list = []
187 |         for j in range(len(pt_list)):
188 |             if j in remove_inds:
189 |                 continue
190 |             clean_pt_list.append(pt_list[j])
191 | 
192 |         if len(clean_pt_list) > 1:
193 |             # return Trajectory(traj.oid, self.get_tid(traj.oid, clean_pt_list), clean_pt_list)
194 |             return Trajectory(traj.oid, traj.tid, clean_pt_list)
195 |         else:
196 |             return None
197 | 
198 | 
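# A sketch (made-up points and a hypothetical 33 m/s threshold) of the outlier
# test shared by the heuristic filters above: a point is dropped (or replaced)
# when the speeds of both adjacent legs exceed max_speed. distance() is the
# repo's spatial metric in metres.
if __name__ == '__main__':
    from datetime import datetime, timedelta
    t0 = datetime(2015, 4, 1, 8, 0, 0)
    p1 = STPoint(36.660, 117.000, t0)
    p2 = STPoint(36.700, 117.000, t0 + timedelta(seconds=10))  # ~4.4 km jump
    p3 = STPoint(36.661, 117.000, t0 + timedelta(seconds=20))
    speed_pre = distance(p1, p2) / 10.0
    speed_next = distance(p2, p3) / 10.0
    print(speed_pre > 33.0 and speed_next > 33.0)  # True -> p2 is an outlier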
--------------------------------------------------------------------------------
/utils/parse_traj.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
# coding: utf-8
# @Time : 2020/9/23 15:40
# Reference: https://github.com/huiminren/tptk/blob/master/common/trajectory.py

import re
from datetime import datetime
import pandas as pd

from common.trajectory import Trajectory, STPoint
from map_matching.candidate_point import CandidatePoint


class ParseTraj:
    """
    ParseTraj is an abstract class for parsing trajectories.
    It defines the parse() interface for loading data into a list of Trajectory().
    """
    def __init__(self):
        pass

    def parse(self, input_path):
        """
        Load data into a list of Trajectory()
        """
        pass


class ParseRawTraj(ParseTraj):
    """
    Parse original GPS points into a list of trajectories. No extra data preprocessing.
    """
    def __init__(self):
        super().__init__()

    def parse(self, input_path):
        """
        Args:
        -----
        input_path:
            str. path to the input file
        Returns:
        --------
        trajs:
            list. all trajectories recovered from the GPS points in input_path
        """
        time_format = '%Y/%m/%d %H:%M:%S'
        with open(input_path, 'r') as f:
            trajs = []
            pt_list = []
            for line in f.readlines():
                attrs = line.rstrip().split(',')
                # each trajectory starts with a header line: '#,tid,oid,start_time,end_time,distance km'
                if attrs[0] == '#':
                    if len(pt_list) > 1:
                        traj = Trajectory(oid, tid, pt_list)
                        trajs.append(traj)
                    oid = attrs[2]
                    tid = attrs[1]
                    pt_list = []
                else:
                    lat = float(attrs[1])
                    lng = float(attrs[2])
                    pt = STPoint(lat, lng, datetime.strptime(attrs[0], time_format))
                    # pt carries all the attributes of class STPoint
                    pt_list.append(pt)
            if len(pt_list) > 1:
                traj = Trajectory(oid, tid, pt_list)
                trajs.append(traj)
        return trajs


class ParseMMTraj(ParseTraj):
    """
    Parse map-matched GPS points into a list of trajectories. No extra data preprocessing.
    """
    def __init__(self):
        super().__init__()

    def parse(self, input_path):
        """
        Args:
        -----
        input_path:
            str. path to the input file
        Returns:
        --------
        trajs:
            list. all trajectories recovered from the map-matched points in input_path
        """
        time_format = '%Y/%m/%d %H:%M:%S'
        with open(input_path, 'r') as f:
            trajs = []
            pt_list = []
            for line in f.readlines():
                attrs = line.rstrip().split(',')
                if attrs[0] == '#':
                    if len(pt_list) > 1:
                        traj = Trajectory(oid, tid, pt_list)
                        trajs.append(traj)
                    oid = attrs[2]
                    tid = attrs[1]
                    pt_list = []
                else:
                    lat = float(attrs[1])
                    lng = float(attrs[2])
                    if attrs[3] == 'None':
                        candi_pt = None
                    else:
                        eid = int(attrs[3])
                        proj_lat = float(attrs[4])
                        proj_lng = float(attrs[5])
                        error = float(attrs[6])
                        offset = float(attrs[7])
                        rate = float(attrs[8])
                        candi_pt = CandidatePoint(proj_lat, proj_lng, eid, error, offset, rate)
                    pt = STPoint(lat, lng, datetime.strptime(attrs[0], time_format), {'candi_pt': candi_pt})
                    # pt carries all the attributes of class STPoint
                    pt_list.append(pt)
            if len(pt_list) > 1:
                traj = Trajectory(oid, tid, pt_list)
                trajs.append(traj)
        return trajs


class ParseJUSTInputTraj(ParseTraj):
    """
    Parse JUST input format into a list of Trajectory()
    """
    def __init__(self):
        super().__init__()

    def parse(self, input_path):
        time_format = '%Y-%m-%d %H:%M:%S'
        with open(input_path, 'r') as f:
            trajs = []
            pt_list = []
            pre_tid = ''
            for line in f.readlines():
                attrs = line.rstrip().split(',')
                tid = attrs[0]
                oid = attrs[1]
                time = datetime.strptime(attrs[2][:19], time_format)
                lat = float(attrs[3])
                lng = float(attrs[4])
                pt = STPoint(lat, lng, time)
                if pre_tid != tid:
                    if len(pt_list) > 1:
                        traj = Trajectory(oid, pre_tid, pt_list)
                        trajs.append(traj)
                    pt_list = []
                pt_list.append(pt)
                pre_tid = tid
            if len(pt_list) > 1:
                traj = Trajectory(oid, tid, pt_list)
                trajs.append(traj)

        return trajs


class ParseJUSTOutputTraj(ParseTraj):
    """
    Parse JUST output into a list of trajectories. The output format matches Trajectory().
    """
    def __init__(self):
        super().__init__()

    def parse(self, input_path, feature_flag=False):
        """
        Args:
        -----
        input_path:
            str. path to the input file
        feature_flag:
            boolean. whether to attach road-segment features (rid, rdis) to each point
        Returns:
        --------
        trajs:
            list of Trajectory()

        Columns of the JUST output:
        'oid': object id
        'geom': line string of the raw trajectory
        'time': trajectory start time
        'tid': trajectory id
        'time_series': line string of the map-matched trajectory with extra features:
            raw start time, raw lng, raw lat, road segment ID, index of road segment,
            distance between the raw point and the map-matched point,
            distance between the projected point and the start of the road segment
        'start_position': raw start position
        'end_position': raw end position
        'point_number': number of points in the trajectory
        'length': distance of the trajectory in km
        'speed': average speed of the trajectory in km/h
        'signature': signature for GIS
        'id': primary key
        """
        col_names = ['oid', 'geom', 'tid', 'time_series']
        df = pd.read_csv(input_path, sep='|', usecols=col_names)

        str_to_remove = '[LINESTRING()]'
        time_format = '%Y-%m-%d %H:%M:%S'
        trajs = []
        for i in range(len(df)):
            oid = df['oid'][i]
            tid = str(df['tid'][i])
            # prepare to get the map-matched lat, lng and the original datetime of each GPS point
            time_series = df['time_series'][i]
            geom = re.sub(str_to_remove, "", df['geom'][i])  # load geom and remove "LINESTRING()"
            ts_list = time_series.split(';')  # datetime and original location of each GPS point
            geom_list = geom.split(',')  # map-matched GPS points
            assert len(ts_list) == len(geom_list)

            pt_list = []
            for j in range(len(ts_list)):
                tmp_location = geom_list[j].split(" ")
                tmp_features = ts_list[j].split(",")
                lat = float(tmp_location[2])  # convert coordinate strings to floats
                lng = float(tmp_location[1])
                time = tmp_features[0][:19]

                if feature_flag:
                    rid = int(tmp_features[3])
                    rdis = float(tmp_features[-1][:-1])
                    pt = STPoint(lat, lng, datetime.strptime(time, time_format), rid=rid, rdis=rdis)
                else:
                    pt = STPoint(lat, lng, datetime.strptime(time, time_format))

                pt_list.append(pt)
            traj = Trajectory(oid, tid, pt_list)
            trajs.append(traj)

        return trajs
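

if __name__ == '__main__':
    # A minimal usage sketch. The path below is the sample file shipped under
    # data/raw_trajectory/, assuming it follows the '#' header format parsed above;
    # adjust it to your own layout.
    parser = ParseRawTraj()
    trajs = parser.parse('./data/raw_trajectory/09_01.txt')
    print('#trajectories:', len(trajs))
    if trajs:
        first = trajs[0]
        print('oid:', first.oid, 'tid:', first.tid, '#points:', len(first.pt_list))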
--------------------------------------------------------------------------------
/utils/save_traj.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
# coding: utf-8
# @Time : 2020/9/23 15:40
# Reference: https://github.com/huiminren/tptk/blob/master/common/trajectory.py
from common.trajectory import get_tid
from utils.coord_transform import GCJ02ToWGS84, WGS84ToGCJ02, Convert


class SaveTraj:
    """
    SaveTraj is an abstract class for storing trajectories.
    It defines the store() interface for writing trajectories in different formats.
    """
    def __init__(self, convert_method):
        # GCJ: Auto, Didi
        # WGS: OSM, Tiandi
        if convert_method == 'GCJ02ToWGS84':
            self.convert = GCJ02ToWGS84()
        elif convert_method == 'WGS84ToGCJ02':
            self.convert = WGS84ToGCJ02()
        elif convert_method is None:
            self.convert = Convert()  # no coordinate conversion
        else:
            raise ValueError('unknown convert_method: {}'.format(convert_method))

    def store(self, trajs, target_path):
        pass


class SaveTraj2Raw(SaveTraj):
    def __init__(self, convert_method=None):
        super().__init__(convert_method)

    def store(self, trajs, target_path):
        time_format = '%Y/%m/%d %H:%M:%S'
        with open(target_path, 'w') as f:
            for traj in trajs:
                pt_list = traj.pt_list
                tid = get_tid(traj.oid, pt_list)
                f.write('#,{},{},{},{},{} km\n'.format(tid, traj.oid, pt_list[0].time.strftime(time_format),
                                                       pt_list[-1].time.strftime(time_format),
                                                       traj.get_distance() / 1000))
                for pt in pt_list:
                    lng, lat = self.convert.convert(pt.lng, pt.lat)
                    f.write('{},{},{}\n'.format(
                        pt.time.strftime(time_format), lat, lng))


class SaveTraj2MM(SaveTraj):
    """
    Store map-matched trajectories: each point line is time, raw lat, raw lng, followed by
    the matched candidate's eid, projected lat/lng, error, offset and rate
    ('None' fields when the point is unmatched).
    """
    def __init__(self, convert_method=None):
        super().__init__(convert_method)

    def store(self, trajs, target_path):
        time_format = '%Y/%m/%d %H:%M:%S'
        with open(target_path, 'w') as f:
            for traj in trajs:
                pt_list = traj.pt_list
                tid = get_tid(traj.oid, pt_list)
                f.write('#,{},{},{},{},{} km\n'.format(tid, traj.oid, pt_list[0].time.strftime(time_format),
                                                       pt_list[-1].time.strftime(time_format),
                                                       traj.get_distance() / 1000))
                for pt in pt_list:
                    candi_pt = pt.data['candi_pt']
                    if candi_pt is not None:
                        f.write('{},{},{},{},{},{},{},{},{}\n'.format(pt.time.strftime(time_format), pt.lat, pt.lng,
                                                                      candi_pt.eid, candi_pt.lat, candi_pt.lng,
                                                                      candi_pt.error, candi_pt.offset, candi_pt.rate))
                    else:
                        f.write('{},{},{},None,None,None,None,None,None\n'.format(
                            pt.time.strftime(time_format), pt.lat, pt.lng))


class SaveTraj2JUST(SaveTraj):
    """
    Convert trajs to the JUST format:
    csv file. trajectory_id, oid, time, lat, lng
    """
    def __init__(self, convert_method=None):
        super().__init__(convert_method)

    def store(self, trajs, target_path):
        """
        Convert trajs to the JUST format:
        csv file. trajectory_id (primary key), oid, time, lat, lng
        Args:
        ----
        trajs:
            list. list of Trajectory()
        target_path:
            str. target path (directory + file_name)
        """
        with open(target_path, 'w') as f:
            for traj in trajs:
                for pt in traj.pt_list:
                    lng, lat = self.convert.convert(pt.lng, pt.lat)
                    f.write('{},{},{},{},{}\n'.format(traj.tid, traj.oid, pt.time, lat, lng))
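

if __name__ == '__main__':
    # A minimal parse -> store round-trip sketch. ParseMMTraj and SaveTraj2MM share
    # the same line format, so re-storing parsed trajectories should reproduce the
    # matched fields. Both paths are placeholders.
    from utils.parse_traj import ParseMMTraj
    trajs = ParseMMTraj().parse('./data/mm_trajectory/09_01.txt')
    SaveTraj2MM(convert_method=None).store(trajs, './data/mm_trajectory/09_01_copy.txt')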
--------------------------------------------------------------------------------
/utils/segmentation.py:
--------------------------------------------------------------------------------
from common.trajectory import Trajectory, get_tid


class Segmentation:
    def __init__(self):
        pass

    def segment(self, traj):
        pass


class TimeIntervalSegmentation(Segmentation):
    """
    Split a trajectory when the time interval between two consecutive GPS points
    exceeds max_time_interval_min (minutes).
    Keep a sub-trajectory only when it contains more than min_len points.
    """
    def __init__(self, max_time_interval_min, min_len=1):
        super(TimeIntervalSegmentation, self).__init__()
        self.max_time_interval = max_time_interval_min * 60  # minutes -> seconds
        self.min_len = min_len

    def segment(self, traj):
        segmented_traj_list = []
        pt_list = traj.pt_list
        if len(pt_list) <= 1:
            return []
        oid = traj.oid
        pre_pt = pt_list[0]
        partial_pt_list = [pre_pt]
        for cur_pt in pt_list[1:]:
            time_span = (cur_pt.time - pre_pt.time).total_seconds()
            if time_span <= self.max_time_interval:
                partial_pt_list.append(cur_pt)
            else:
                if len(partial_pt_list) > self.min_len:
                    segmented_traj = Trajectory(oid, get_tid(oid, partial_pt_list), partial_pt_list)
                    segmented_traj_list.append(segmented_traj)
                partial_pt_list = [cur_pt]  # re-initialize partial_pt_list
            pre_pt = cur_pt
        if len(partial_pt_list) > self.min_len:
            segmented_traj = Trajectory(oid, get_tid(oid, partial_pt_list), partial_pt_list)
            segmented_traj_list.append(segmented_traj)
        return segmented_traj_list
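

if __name__ == '__main__':
    # A minimal usage sketch: split on gaps longer than 5 minutes.
    # Assumes STPoint(lat, lng, time) as defined in common.trajectory; values are illustrative.
    from datetime import datetime, timedelta
    from common.trajectory import STPoint
    t0 = datetime(2015, 11, 1, 12, 0, 0)
    pts = [STPoint(39.90, 116.40, t0 + timedelta(seconds=s)) for s in (0, 15, 30)]
    pts += [STPoint(39.91, 116.41, t0 + timedelta(seconds=s)) for s in (630, 645, 660)]
    traj = Trajectory('oid_demo', 'tid_demo', pts)
    segs = TimeIntervalSegmentation(max_time_interval_min=5).segment(traj)
    print([len(s.pt_list) for s in segs])  # expect [3, 3]: the 10-minute gap splits the trajectory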
--------------------------------------------------------------------------------
/utils/stats.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
# coding: utf-8
# @Time : 2020/9/3 17:08
# Reference: https://github.com/sjruan/tptk/blob/master/statistics.py

from utils.utils import create_dir

import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

import numpy as np
import pandas as pd
import os
import json


def plot_hist(data, x_axis, save_stats_dir, pic_name):
    plt.hist(data, weights=np.ones(len(data)) / len(data))
    plt.gca().yaxis.set_major_formatter(PercentFormatter(1))
    plt.xlabel(x_axis)
    plt.ylabel('Percentage')
    plt.savefig(os.path.join(save_stats_dir, pic_name))
    plt.clf()


def statistics(trajs, save_stats_dir, stats, save_stats_name, save_plot=False):
    """
    Compute basic statistics of a list of trajectories (object/point/trajectory counts,
    sequence lengths, distances, durations and sampling intervals), accumulate them
    into `stats`, dump the result to JSON, and optionally plot histograms.
    Args:
    -----
    trajs:
        list of Trajectory(). trajectories to analyze
    save_stats_dir:
        str. directory for saving stats results
    stats:
        dict. accumulated stats; pass {} on the first call
    save_stats_name:
        str. base name of the saved stats files
    save_plot:
        boolean. whether to save histograms of the distributions
    """

    create_dir(save_stats_dir)

    oids = set()
    tot_pts = 0

    distance_data = []  # geographical distance
    duration_data = []  # time difference between end and start time of a trajectory
    seq_len_data = []  # length of each trajectory
    traj_avg_time_interval_data = []
    traj_avg_dist_interval_data = []

    if len(stats) == 0:
        # if new, initialize stats with keys
        stats['#object'], stats['#points'], stats['#trajectories'] = 0, 0, 0
        stats['seq_len_data'], stats['distance_data'], stats['duration_data'], \
            stats['traj_avg_time_interval_data'], stats['traj_avg_dist_interval_data'] = [], [], [], [], []

    for traj in trajs:
        oids.add(traj.oid)
        nb_pts = len(traj.pt_list)
        tot_pts += nb_pts

        seq_len_data.append(nb_pts)
        distance_data.append(traj.get_distance() / 1000.0)
        duration_data.append(traj.get_duration() / 60.0)
        traj_avg_time_interval_data.append(traj.get_avg_time_interval())
        traj_avg_dist_interval_data.append(traj.get_avg_distance_interval())

    print('#objects_single:{}'.format(len(oids)))
    print('#points_single:{}'.format(tot_pts))
    print('#trajectories_single:{}'.format(len(trajs)))

    # note: distinct objects are counted per call, so an object appearing in
    # several input files is counted once per file
    stats['#object'] += len(oids)
    stats['#points'] += tot_pts
    stats['#trajectories'] += len(trajs)
    stats['seq_len_data'] += seq_len_data
    stats['distance_data'] += distance_data
    stats['duration_data'] += duration_data
    stats['traj_avg_time_interval_data'] += traj_avg_time_interval_data
    stats['traj_avg_dist_interval_data'] += traj_avg_dist_interval_data

    print('#objects_total:{}'.format(stats['#object']))
    print('#points_total:{}'.format(stats['#points']))
    print('#trajectories_total:{}'.format(stats['#trajectories']))

    with open(os.path.join(save_stats_dir, save_stats_name + '.json'), 'w') as f:
        json.dump(stats, f)

    if save_plot:
        plot_hist(stats['seq_len_data'], '#Points', save_stats_dir, save_stats_name + '_nb_points_dist.png')
        plot_hist(stats['distance_data'], 'Distance (KM)', save_stats_dir, save_stats_name + '_distance_dist.png')
        plot_hist(stats['duration_data'], 'Duration (Min)', save_stats_dir, save_stats_name + '_duration_dist.png')
        plot_hist(stats['traj_avg_time_interval_data'], 'Time Interval (Sec)', save_stats_dir,
                  save_stats_name + '_time_interval_dist.png')
        plot_hist(stats['traj_avg_dist_interval_data'], 'Distance Interval (Meter)', save_stats_dir,
                  save_stats_name + '_distance_interval_dist.png')

    return stats


def stats_threshold(trajs):
    """
    Find thresholds for trajectory preprocessing, including the minimal trajectory
    length and the maximal abnormal ratio.
    Args:
    -----
    trajs:
        list of Trajectory(). sampled trajectories
    Returns:
    --------
    df_stats, thr_min_len, thr_max_abn_ratio
    """
    stats = {'oid': [], 'tid': [], 'traj_len': [], 'num_ab_pts': [], 'abn_ts': [], 'avg_ts': [], 'avg_abn_ts': []}

    for i in range(len(trajs)):
        stats['oid'].append(trajs[i].oid)
        stats['tid'].append(trajs[i].tid)
        stats['traj_len'].append(len(trajs[i].pt_list))

        pt_list = trajs[i].pt_list
        pre_pt = pt_list[0]

        time_spans = []
        abn_ts_list = []
        for cur_pt in pt_list[1:]:
            time_span = (cur_pt.time - pre_pt.time).total_seconds()
            if time_span > 4:  # intervals above 4 s are treated as abnormal (data-specific threshold)
                abn_ts_list.append(time_span)
            time_spans.append(time_span)
            pre_pt = cur_pt

        stats['num_ab_pts'].append(len(abn_ts_list))
        stats['abn_ts'].append(abn_ts_list)
        stats['avg_ts'].append(sum(time_spans) / len(time_spans))
        try:
            # in case len(abn_ts_list) = 0
            stats['avg_abn_ts'].append(sum(abn_ts_list) / len(abn_ts_list))
        except ZeroDivisionError:
            stats['avg_abn_ts'].append(0)

    df_stats = pd.DataFrame(stats)
    df_stats['abn_ratio'] = df_stats.num_ab_pts / df_stats.traj_len

    thr_min_len = df_stats.traj_len.quantile(0.1)  # if shorter than min length (number of points), remove
    thr_max_abn_ratio = df_stats.abn_ratio.quantile(0.8)  # if larger than max abnormal ratio, remove
    # thr_normal_ratio = 0  # if abn_ratio == 0, use the trajectory directly
    # thr_split_traj_ts = 60  # if time_interval is larger than 60 seconds, split trajectories
    # print("min len: {}, max abnormal ratio: {}, normal ratio: {}, split threshold: {} seconds".
    #       format(thr_min_len, thr_max_abn_ratio, thr_normal_ratio, thr_split_traj_ts))
    print("min len: {}, max abnormal ratio: {}".format(thr_min_len, thr_max_abn_ratio))

    return df_stats, thr_min_len, thr_max_abn_ratio  # , thr_normal_ratio, thr_split_traj_ts
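

if __name__ == '__main__':
    # A minimal usage sketch: accumulate stats over one sample file and inspect the
    # preprocessing thresholds. Paths and names are placeholders.
    from utils.parse_traj import ParseRawTraj
    trajs = ParseRawTraj().parse('./data/raw_trajectory/09_01.txt')
    stats = statistics(trajs, './data/stats/', {}, 'raw_09_01', save_plot=True)
    df_stats, thr_min_len, thr_max_abn_ratio = stats_threshold(trajs)
    print(df_stats[['traj_len', 'abn_ratio']].describe())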
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/python3
# coding: utf-8
# @Time : 2020/9/2 19:58


import pickle
import json
import os


def create_dir(directory):
    """
    Create a directory if it does not already exist.
    """
    if not os.path.exists(directory):
        os.makedirs(directory)


def save_pkl_data(data, dir, file_name):
    create_dir(dir)
    with open(dir + file_name, 'wb') as f:
        pickle.dump(data, f)


def load_pkl_data(dir, file_name):
    '''
    Args:
    -----
    dir: directory (including the trailing path separator)
    file_name: file name
    Returns:
    --------
    data: loaded data
    '''
    with open(dir + file_name, 'rb') as f:
        data = pickle.load(f)
    return data


def save_json_data(data, dir, file_name):
    create_dir(dir)
    with open(dir + file_name, 'w') as fp:
        json.dump(data, fp)


def load_json_data(dir, file_name):
    with open(dir + file_name, 'r') as fp:
        data = json.load(fp)
    return data
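

if __name__ == '__main__':
    # A minimal round-trip sketch; './tmp_stats/' is a placeholder directory
    # (the helpers expect `dir` to end with a path separator).
    demo = {'a': 1, 'b': [2, 3]}
    save_json_data(demo, './tmp_stats/', 'demo.json')
    assert load_json_data('./tmp_stats/', 'demo.json') == demo
    save_pkl_data(demo, './tmp_stats/', 'demo.pkl')
    assert load_pkl_data('./tmp_stats/', 'demo.pkl') == demo
--------------------------------------------------------------------------------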