├── LICENSE
├── README.md
├── common
│   ├── __init__.py
│   ├── grid.py
│   ├── mbr.py
│   ├── path.py
│   ├── road_network.py
│   ├── spatial_func.py
│   └── trajectory.py
├── data
│   ├── map
│   │   ├── extra_info
│   │   │   ├── new2raw_rid.json
│   │   │   ├── poi20.csv
│   │   │   ├── raw2new_rid.json
│   │   │   ├── raw_rn_dict.json
│   │   │   ├── rn_dict.json
│   │   │   └── weather_dict.pkl
│   │   └── road_network
│   │       ├── edges.dbf
│   │       ├── edges.shp
│   │       ├── edges.shx
│   │       ├── nodes.dbf
│   │       ├── nodes.shp
│   │       └── nodes.shx
│   └── raw_trajectory
│       └── 09_01.txt
├── doc
│   └── KDD21_MTrajRec_Huimin.pdf
├── img
│   ├── intro.png
│   └── model_framework.png
├── map_matching
│   ├── __init__.py
│   ├── candidate_point.py
│   ├── hmm
│   │   ├── __init__.py
│   │   ├── hmm_map_matcher.py
│   │   ├── hmm_probabilities.py
│   │   └── ti_viterbi.py
│   ├── map_matcher.py
│   ├── route_constructor.py
│   └── utils.py
├── models
│   ├── __init__.py
│   ├── datasets.py
│   ├── loss_fn.py
│   ├── model_utils.py
│   ├── models_attn_tandem.py
│   └── multi_train.py
├── multi_main.py
└── utils
    ├── __init__.py
    ├── coord_transform.py
    ├── imputation.py
    ├── noise_filter.py
    ├── parse_traj.py
    ├── save_traj.py
    ├── segmentation.py
    ├── stats.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 paperanonymous945
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MTrajRec
2 |
3 |
4 |
5 |
6 |
7 | ## About
8 | Source code of the KDD'21: [MTrajRec: Map-Constrained Trajectory Recovery via Seq2Seq Multi-task Learning](doc/KDD21_MTrajRec_Huimin.pdf)
9 | ## Requirements
10 | * `python==3.6`
11 | * `pytorch==1.7.1`
12 | * `rtree==0.9.4`
13 | * `GDAL==2.3.3`
14 | * `networkx==2.3`
15 |
16 | ## Usage
17 | ### Installation
18 | #### Clone this repo:
19 | ```bash
20 | git clone git@github.com:huiminren/MTrajRec.git
21 | cd MTrajRec
22 | ```
23 | #### Running
24 | `python multi_main.py`
25 |
26 | #### Dataset
27 | We provide sample data under data/.
28 |
29 | Please note that the sample data is generated with the same structure as the original data. For the data preprocessing, please refer to [tptk](https://github.com/sjruan/tptk).
30 |
31 | ## Acknowledgements
32 | Thanks to [Sijie](https://github.com/sjruan/) for supporting the data preprocessing.
33 |
34 | ## Citation
35 | If you find this repo useful, please cite our paper:
36 | ```
37 | @inproceedings{ren2021mtrajrec,
38 | title={MTrajRec: Map-Constrained Trajectory Recovery via Seq2Seq Multi-task Learning},
39 | author={Ren, Huimin and Ruan, Sijie and Li, Yanhua and Bao, Jie and Meng, Chuishi and Li, Ruiyuan and Zheng, Yu},
40 | booktitle={Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining},
41 | pages={1410--1419},
42 | year={2021}
43 | }
44 | ```
45 |
46 |
47 |
48 |
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/3 13:34
4 |
--------------------------------------------------------------------------------
/common/grid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .mbr import MBR
3 | from .spatial_func import LAT_PER_METER, LNG_PER_METER
4 |
5 |
6 | class Grid:
7 | """
8 | index order
9 | 30 31 32 33 34...
10 | 20 21 22 23 24...
11 | 10 11 12 13 14...
12 | 00 01 02 03 04...
13 | """
14 |
15 | def __init__(self, mbr, row_num, col_num):
16 | self.mbr = mbr
17 | self.row_num = row_num
18 | self.col_num = col_num
19 | self.lat_interval = (mbr.max_lat - mbr.min_lat) / float(row_num)
20 | self.lng_interval = (mbr.max_lng - mbr.min_lng) / float(col_num)
21 |
22 | def get_row_idx(self, lat):
23 | row_idx = int((lat - self.mbr.min_lat) // self.lat_interval)
24 | if row_idx >= self.row_num or row_idx < 0:
25 | raise IndexError("lat is out of mbr")
26 | return row_idx
27 |
28 | def get_col_idx(self, lng):
29 | col_idx = int((lng - self.mbr.min_lng) // self.lng_interval)
30 | if col_idx >= self.col_num or col_idx < 0:
31 | raise IndexError("lng is out of mbr")
32 | return col_idx
33 |
34 | def safe_matrix_to_idx(self, lat, lng):
35 | try:
36 | return self.get_matrix_idx(lat, lng)
37 | except IndexError:
38 | return np.nan, np.nan
39 |
40 | def get_idx(self, lat, lng):
41 | return self.get_row_idx(lat), self.get_col_idx(lng)
42 |
43 | def get_matrix_idx(self, lat, lng):
44 | return self.row_num - 1 - self.get_row_idx(lat), self.get_col_idx(lng)
45 |
46 | def get_min_lng(self, col_idx):
47 | return self.mbr.min_lng + col_idx * self.lng_interval
48 |
49 | def get_max_lng(self, col_idx):
50 | return self.mbr.min_lng + (col_idx + 1) * self.lng_interval
51 |
52 | def get_min_lat(self, row_idx):
53 | return self.mbr.min_lat + row_idx * self.lat_interval
54 |
55 | def get_max_lat(self, row_idx):
56 | return self.mbr.min_lat + (row_idx + 1) * self.lat_interval
57 |
58 | def get_mbr_by_idx(self, row_idx, col_idx):
59 | min_lat = self.get_min_lat(row_idx)
60 | max_lat = self.get_max_lat(row_idx)
61 | min_lng = self.get_min_lng(col_idx)
62 | max_lng = self.get_max_lng(col_idx)
63 | return MBR(min_lat, min_lng, max_lat, max_lng)
64 |
65 | def get_mbr_by_matrix_idx(self, mat_row_idx, mat_col_idx):
66 | row_idx = self.row_num - 1 - mat_row_idx
67 | min_lat = self.get_min_lat(row_idx)
68 | max_lat = self.get_max_lat(row_idx)
69 | min_lng = self.get_min_lng(mat_col_idx)
70 | max_lng = self.get_max_lng(mat_col_idx)
71 | return MBR(min_lat, min_lng, max_lat, max_lng)
72 |
73 | def range_query(self, query_mbr, type):
74 | target_idx = []
76 |         # shrink the mbr slightly, since the top and right boundaries belong to the neighboring cells
76 | delta = 1e-7
77 | min_lat = max(query_mbr.min_lat, self.mbr.min_lat)
78 | min_lng = max(query_mbr.min_lng, self.mbr.min_lng)
79 | max_lat = min(query_mbr.max_lat, self.mbr.max_lat) - delta
80 | max_lng = min(query_mbr.max_lng, self.mbr.max_lng) - delta
81 | if type == 'matrix':
82 | max_row_idx, min_col_idx = self.get_matrix_idx(min_lat, min_lng)
83 | min_row_idx, max_col_idx = self.get_matrix_idx(max_lat, max_lng)
84 | elif type == 'cartesian':
85 | min_row_idx, min_col_idx = self.get_idx(min_lat, min_lng)
86 | max_row_idx, max_col_idx = self.get_idx(max_lat, max_lng)
87 | else:
88 | raise Exception('unrecognized index type')
89 | for r_idx in range(min_row_idx, max_row_idx + 1):
90 | for c_idx in range(min_col_idx, max_col_idx + 1):
91 | target_idx.append((r_idx, c_idx))
92 | return target_idx
93 |
94 |
95 | # def create_grid(min_lat, min_lng, km_per_cell_lat, km_per_cell_lng, km_lat, km_lng):
96 | # nb_rows = int(km_lat / km_per_cell_lat)
97 | # nb_cols = int(km_lng / km_per_cell_lng)
98 | # max_lat = min_lat + LAT_PER_METER * km_lat * 1000.0
99 | # max_lng = min_lng + LNG_PER_METER * km_lng * 1000.0
100 | # mbr = MBR(min_lat, min_lng, max_lat, max_lng)
101 | # return Grid(mbr, nb_rows, nb_cols)
102 |
103 | def create_grid(min_lat, min_lng, max_lat, max_lng, km_per_cell_lat, km_per_cell_lng):
104 | """
105 |     Given a region and the size of each cell, return a Grid object.
106 |     Replaces the original function above, since the lat/lng extents of a region are hard to know in advance.
107 |     """
108 |     mbr = MBR(min_lat, min_lng, max_lat, max_lng)
109 |     km_lat = mbr.get_h() / 1000.0  # get_h() returns meters; convert to km
110 |     km_lng = mbr.get_w() / 1000.0  # get_w() returns meters; convert to km
111 | nb_rows = int(km_lat / km_per_cell_lat)
112 | nb_cols = int(km_lng / km_per_cell_lng)
113 | return Grid(mbr, nb_rows, nb_cols)
114 |
--------------------------------------------------------------------------------
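
A minimal usage sketch for `common/grid.py` (not from the repo; the bounding box and cell counts are illustrative, and the repo root is assumed to be on `PYTHONPATH`):

```python
from common.mbr import MBR
from common.grid import Grid

# illustrative region (min_lat, min_lng, max_lat, max_lng) and a 100 x 100 grid
mbr = MBR(39.90, 116.30, 40.00, 116.40)
grid = Grid(mbr, row_num=100, col_num=100)

print(grid.get_idx(39.95, 116.35))         # cartesian (row, col): origin at the bottom-left
print(grid.get_matrix_idx(39.95, 116.35))  # matrix (row, col): origin at the top-left
print(grid.safe_matrix_to_idx(0.0, 0.0))   # out-of-range point -> (nan, nan)
```
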
/common/mbr.py:
--------------------------------------------------------------------------------
1 | from .spatial_func import distance, SPoint
2 |
3 | class MBR:
4 | """
5 |     MBR represents the minimum bounding rectangle of a spatial region.
6 | """
7 | def __init__(self, min_lat, min_lng, max_lat, max_lng):
8 | self.min_lat = min_lat
9 | self.min_lng = min_lng
10 | self.max_lat = max_lat
11 | self.max_lng = max_lng
12 |
13 | def contains(self, lat, lng):
14 | # return self.min_lat <= lat <= self.max_lat and self.min_lng <= lng <= self.max_lng
15 |         # exclude the max_lat/max_lng boundary (use < instead of <=) to be consistent with the grid index
16 | return self.min_lat <= lat < self.max_lat and self.min_lng <= lng < self.max_lng
17 |
18 | def center(self):
19 | return (self.min_lat + self.max_lat) / 2.0, (self.min_lng + self.max_lng) / 2.0
20 |
21 | def get_h(self):
22 | return distance(SPoint(self.min_lat, self.min_lng), SPoint(self.max_lat, self.min_lng))
23 |
24 | def get_w(self):
25 | return distance(SPoint(self.min_lat, self.min_lng), SPoint(self.min_lat, self.max_lng))
26 |
27 | def __str__(self):
28 | h = self.get_h()
29 | w = self.get_w()
30 | return '{}x{}m2'.format(h, w)
31 |
32 | def __eq__(self, other):
33 | return self.min_lat == other.min_lat and self.min_lng == other.min_lng \
34 | and self.max_lat == other.max_lat and self.max_lng == other.max_lng
35 |
36 | def to_wkt(self):
37 |         # Five points are provided for GIS visualization:
38 |         # some tools cannot draw a rectangle from an unclosed WKT ring,
39 |         # so the last point repeats the first one.
40 | return 'POLYGON (({} {}, {} {}, {} {}, {} {}, {} {}))'.format(self.min_lng, self.min_lat,
41 | self.min_lng, self.max_lat,
42 | self.max_lng, self.max_lat,
43 | self.max_lng, self.min_lat,
44 | self.min_lng, self.min_lat)
45 |
46 | @staticmethod
47 |     # a staticmethod does not access instance (self) attributes.
48 | def cal_mbr(coords):
49 | """
50 | Find MBR from coordinates
51 | Args:
52 | -----
53 | coords:
54 | list of Point()
55 | Returns:
56 | -------
57 | MBR()
58 | """
59 | min_lat = float('inf')
60 | min_lng = float('inf')
61 | max_lat = float('-inf')
62 | max_lng = float('-inf')
63 | for coord in coords:
64 | if coord.lat > max_lat:
65 | max_lat = coord.lat
66 | if coord.lat < min_lat:
67 | min_lat = coord.lat
68 | if coord.lng > max_lng:
69 | max_lng = coord.lng
70 | if coord.lng < min_lng:
71 | min_lng = coord.lng
72 | return MBR(min_lat, min_lng, max_lat, max_lng)
73 |
74 | @staticmethod
75 | def load_mbr(file_path):
76 | with open(file_path, 'r') as f:
77 | f.readline()
78 | attrs = f.readline()[:-1].split(';')
79 | mbr = MBR(float(attrs[1]), float(attrs[2]), float(attrs[3]), float(attrs[4]))
80 | return mbr
81 |
82 | @staticmethod
83 | def store_mbr(mbr, file_path):
84 | with open(file_path, 'w') as f:
85 | f.write('name;min_lat;min_lng;max_lat;max_lng;wkt\n')
86 | f.write('{};{};{};{};{};{}\n'.format(0, mbr.min_lat, mbr.min_lng, mbr.max_lat, mbr.max_lng, mbr.to_wkt()))
87 |
--------------------------------------------------------------------------------
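
A short sketch of `MBR.cal_mbr` and the half-open `contains` semantics (coordinates are made up):

```python
from common.spatial_func import SPoint
from common.mbr import MBR

pts = [SPoint(39.91, 116.32), SPoint(39.97, 116.38)]
mbr = MBR.cal_mbr(pts)

print(mbr.contains(39.95, 116.35))  # True: inside [min, max) on both axes
print(mbr.contains(39.97, 116.38))  # False: the max boundary is excluded
print(mbr.to_wkt())                 # closed 5-point POLYGON for GIS tools
print(mbr)                          # height x width in meters
```
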
/common/path.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 |
3 | class PathEntity:
4 | def __init__(self, enter_time, leave_time, eid):
5 | self.enter_time = enter_time
6 | self.leave_time = leave_time
7 | self.eid = eid
8 |
9 |
10 | class Path:
11 | def __init__(self, oid, pid, path_entities):
12 | self.oid = oid
13 | self.pid = pid
14 | self.path_entities = path_entities
15 |
16 |
17 | def parse_path_file(input_path):
18 | time_format = '%Y-%m-%d %H:%M:%S.%f'
19 | paths = []
20 | with open(input_path, 'r') as f:
21 | path_entities = []
22 | pid = None
23 | for line in f.readlines():
24 | attrs = line.rstrip().split(',')
25 | if attrs[0] == '#':
26 | if len(path_entities) != 0:
27 | paths.append(Path(oid, pid, path_entities))
28 | oid = attrs[2]
29 | pid = attrs[1]
30 | path_entities = []
31 | else:
32 | enter_time = datetime.strptime(attrs[0], time_format)
33 | leave_time = datetime.strptime(attrs[1], time_format)
34 | eid = int(attrs[2])
35 | path_entities.append(PathEntity(enter_time, leave_time, eid))
36 | if len(path_entities) != 0:
37 | paths.append(Path(oid, pid, path_entities))
38 | return paths
39 |
40 |
41 | def store_path_file(paths, target_path):
42 | with open(target_path, 'w') as f:
43 | for path in paths:
44 | path_entities = path.path_entities
45 | f.write('#,{},{},{},{}\n'.format(path.pid, path.oid,
46 | path_entities[0].enter_time.isoformat(sep=' ', timespec='milliseconds'),
47 | path_entities[-1].leave_time.isoformat(sep=' ', timespec='milliseconds')))
48 | for path_entity in path_entities:
49 | f.write('{},{},{}\n'.format(path_entity.enter_time.isoformat(sep=' ', timespec='milliseconds'),
50 | path_entity.leave_time.isoformat(sep=' ', timespec='milliseconds'),
51 | path_entity.eid))
52 |
--------------------------------------------------------------------------------
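
A round-trip sketch of the path file format, i.e. a `#,pid,oid,start,end` header row followed by `enter_time,leave_time,eid` rows (the ids, times, and the `demo_path.txt` file name are made up):

```python
from datetime import datetime
from common.path import PathEntity, Path, parse_path_file, store_path_file

t0 = datetime(2021, 9, 1, 8, 0, 0)
t1 = datetime(2021, 9, 1, 8, 0, 15)
path = Path(oid='taxi_1', pid='taxi_1_0', path_entities=[PathEntity(t0, t1, eid=42)])

store_path_file([path], 'demo_path.txt')
loaded = parse_path_file('demo_path.txt')
print(loaded[0].pid, loaded[0].path_entities[0].eid)  # taxi_1_0 42
```
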
/common/road_network.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | from rtree import Rtree
3 | from osgeo import ogr
4 | from .spatial_func import SPoint, distance
5 | from .mbr import MBR
6 | import copy
7 |
8 |
9 | class UndirRoadNetwork(nx.Graph):
10 | def __init__(self, g, edge_spatial_idx, edge_idx):
11 | super(UndirRoadNetwork, self).__init__(g)
12 | # entry: eid
13 | self.edge_spatial_idx = edge_spatial_idx
14 | # eid -> edge key (start_coord, end_coord)
15 | self.edge_idx = edge_idx
16 |
17 | def to_directed(self, as_view=False):
18 | """
19 | Convert undirected road network to directed road network
20 |         Each original edge becomes two directed edges; the backward edge gets a new eid and reversed coords.
21 | :return:
22 | """
23 | assert as_view is False, "as_view is not supported"
24 | avail_eid = max([eid for u, v, eid in self.edges.data(data='eid')]) + 1
25 | g = nx.DiGraph()
26 | edge_spatial_idx = Rtree()
27 | edge_idx = {}
28 | # add nodes
29 | for n, data in self.nodes(data=True):
30 |             # with data=True, each node is returned together with its attribute dict
31 | new_data = copy.deepcopy(data)
32 | g.add_node(n, **new_data)
33 | # add edges
34 | for u, v, data in self.edges(data=True):
35 | mbr = MBR.cal_mbr(data['coords'])
36 | # add forward edge
37 | forward_data = copy.deepcopy(data)
38 | g.add_edge(u, v, **forward_data)
39 | edge_spatial_idx.insert(forward_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
40 | edge_idx[forward_data['eid']] = (u, v)
41 | # add backward edge
42 | backward_data = copy.deepcopy(data)
43 | backward_data['eid'] = avail_eid
44 | avail_eid += 1
45 | backward_data['coords'].reverse()
46 | g.add_edge(v, u, **backward_data)
47 | edge_spatial_idx.insert(backward_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
48 | edge_idx[backward_data['eid']] = (v, u)
49 | print('# of nodes:{}'.format(g.number_of_nodes()))
50 | print('# of edges:{}'.format(g.number_of_edges()))
51 | return RoadNetwork(g, edge_spatial_idx, edge_idx)
52 |
53 | def range_query(self, mbr):
54 | """
55 | spatial range query. Given a mbr, return a range of edges.
56 | :param mbr: query mbr
57 | :return: qualified edge keys
58 | """
59 | eids = self.edge_spatial_idx.intersection((mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
60 | return [self.edge_idx[eid] for eid in eids]
61 |
62 | def remove_edge(self, u, v):
63 | edge_data = self[u][v]
64 | coords = edge_data['coords']
65 | mbr = MBR.cal_mbr(coords)
66 | # delete self.edge_idx[eid] from edge index
67 | del self.edge_idx[edge_data['eid']]
68 | # delete from spatial index
69 | self.edge_spatial_idx.delete(edge_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
70 | # delete from graph
71 | super(UndirRoadNetwork, self).remove_edge(u, v)
72 |
73 | def add_edge(self, u_of_edge, v_of_edge, **attr):
74 | coords = attr['coords']
75 | mbr = MBR.cal_mbr(coords)
76 | attr['length'] = sum([distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)])
77 | # add edge to edge index
78 | self.edge_idx[attr['eid']] = (u_of_edge, v_of_edge)
79 | # add edge to spatial index
80 | self.edge_spatial_idx.insert(attr['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
81 | # add edge to graph
82 | super(UndirRoadNetwork, self).add_edge(u_of_edge, v_of_edge, **attr)
83 |
84 |
85 | class RoadNetwork(nx.DiGraph):
86 | def __init__(self, g, edge_spatial_idx, edge_idx):
87 | super(RoadNetwork, self).__init__(g)
88 | # entry: eid
89 | self.edge_spatial_idx = edge_spatial_idx
90 | # eid -> edge key (start_coord, end_coord)
91 | self.edge_idx = edge_idx
92 |
93 | def range_query(self, mbr):
94 | """
95 | spatial range query
96 | :param mbr: query mbr
97 | :return: qualified edge keys
98 | """
99 | eids = self.edge_spatial_idx.intersection((mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
100 | return [self.edge_idx[eid] for eid in eids]
101 |
102 | def remove_edge(self, u, v):
103 | edge_data = self[u][v]
104 | coords = edge_data['coords']
105 | mbr = MBR.cal_mbr(coords)
106 |         # delete self.edge_idx[eid] from edge index
107 | del self.edge_idx[edge_data['eid']]
108 | # delete from spatial index
109 | self.edge_spatial_idx.delete(edge_data['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
110 | # delete from graph
111 | super(RoadNetwork, self).remove_edge(u, v)
112 |
113 | def add_edge(self, u_of_edge, v_of_edge, **attr):
114 | coords = attr['coords']
115 | mbr = MBR.cal_mbr(coords)
116 | attr['length'] = sum([distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)])
117 | # add edge to edge index
118 | self.edge_idx[attr['eid']] = (u_of_edge, v_of_edge)
119 | # add edge to spatial index
120 | self.edge_spatial_idx.insert(attr['eid'], (mbr.min_lng, mbr.min_lat, mbr.max_lng, mbr.max_lat))
121 | # add edge to graph
122 | super(RoadNetwork, self).add_edge(u_of_edge, v_of_edge, **attr)
123 |
124 |
125 | def load_rn_shp(path, is_directed=True):
126 | edge_spatial_idx = Rtree()
127 | edge_idx = {}
128 | # node uses coordinate as key
129 | # edge uses coordinate tuple as key
130 | g = nx.read_shp(path, simplify=True, strict=False)
131 | if not is_directed:
132 | g = g.to_undirected()
133 | # node attrs: nid, pt, ...
134 | for n, data in g.nodes(data=True):
135 | data['pt'] = SPoint(n[1], n[0])
136 | if 'ShpName' in data:
137 | del data['ShpName']
138 | # edge attrs: eid, length, coords, ...
139 | for u, v, data in g.edges(data=True):
140 | geom_line = ogr.CreateGeometryFromWkb(data['Wkb'])
141 | coords = []
142 | for i in range(geom_line.GetPointCount()):
143 | geom_pt = geom_line.GetPoint(i)
144 | coords.append(SPoint(geom_pt[1], geom_pt[0]))
145 | data['coords'] = coords
146 | data['length'] = sum([distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)])
147 | env = geom_line.GetEnvelope()
148 | edge_spatial_idx.insert(data['eid'], (env[0], env[2], env[1], env[3]))
149 | edge_idx[data['eid']] = (u, v)
150 | del data['ShpName']
151 | del data['Json']
152 | del data['Wkt']
153 | del data['Wkb']
154 | print('# of nodes:{}'.format(g.number_of_nodes()))
155 | print('# of edges:{}'.format(g.number_of_edges()))
156 | if not is_directed:
157 | return UndirRoadNetwork(g, edge_spatial_idx, edge_idx)
158 | else:
159 | return RoadNetwork(g, edge_spatial_idx, edge_idx)
160 |
161 |
162 | def store_rn_shp(rn, target_path):
163 | print('# of nodes:{}'.format(rn.number_of_nodes()))
164 | print('# of edges:{}'.format(rn.number_of_edges()))
165 | for _, data in rn.nodes(data=True):
166 | if 'pt' in data:
167 | del data['pt']
168 | for _, _, data in rn.edges(data=True):
169 | geo_line = ogr.Geometry(ogr.wkbLineString)
170 | for coord in data['coords']:
171 | geo_line.AddPoint(coord.lng, coord.lat)
172 | data['Wkb'] = geo_line.ExportToWkb()
173 | del data['coords']
174 | if 'length' in data:
175 | del data['length']
176 | if not rn.is_directed():
177 | rn = rn.to_directed()
178 | nx.write_shp(rn, target_path)
179 |
--------------------------------------------------------------------------------
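
A loading sketch under the README's pinned dependencies (`networkx==2.3` still ships `nx.read_shp`; GDAL and rtree must be installed). It assumes the sample shapefiles under `data/map/road_network`, and the query window is illustrative:

```python
from common.road_network import load_rn_shp
from common.mbr import MBR

rn = load_rn_shp('./data/map/road_network', is_directed=True)

query_mbr = MBR(39.90, 116.30, 40.00, 116.40)
for u, v in rn.range_query(query_mbr):
    data = rn[u][v]
    print(data['eid'], round(data['length'], 1))  # eid and edge length in meters
```
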
/common/spatial_func.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | DEGREES_TO_RADIANS = math.pi / 180
4 | RADIANS_TO_DEGREES = 1 / DEGREES_TO_RADIANS
5 | EARTH_MEAN_RADIUS_METER = 6371008.7714
6 | DEG_TO_KM = DEGREES_TO_RADIANS * EARTH_MEAN_RADIUS_METER
7 | LAT_PER_METER = 8.993203677616966e-06
8 | LNG_PER_METER = 1.1700193970443768e-05
9 |
10 |
11 | class SPoint:
12 | def __init__(self, lat, lng):
13 | self.lat = lat
14 | self.lng = lng
15 |
16 | def __str__(self):
17 | return '({},{})'.format(self.lat, self.lng)
18 |
19 | def __repr__(self):
20 | return self.__str__()
21 |
22 | def __eq__(self, other):
23 |         # originally objects are compared by reference; here we compare by value
24 | return self.lat == other.lat and self.lng == other.lng
25 |
26 | def __ne__(self, other):
27 | # not equal
28 | return not self == other
29 |
30 |
31 | def same_coords(a, b):
32 | # we can directly use == since SPoint has updated __eq__()
33 | if a == b:
34 | return True
35 | else:
36 | return False
37 |
38 |
39 | def distance(a, b):
40 | """
41 | Calculate haversine distance between two GPS points in meters
42 | Args:
43 | -----
44 | a,b: SPoint class
45 | Returns:
46 | --------
47 | d: float. haversine distance in meter
48 | """
49 | if same_coords(a, b):
50 | return 0.0
51 | delta_lat = math.radians(b.lat - a.lat)
52 | delta_lng = math.radians(b.lng - a.lng)
53 | h = math.sin(delta_lat / 2.0) * math.sin(delta_lat / 2.0) + math.cos(math.radians(a.lat)) * math.cos(
54 | math.radians(b.lat)) * math.sin(delta_lng / 2.0) * math.sin(delta_lng / 2.0)
55 | c = 2.0 * math.atan2(math.sqrt(h), math.sqrt(1 - h))
56 | d = EARTH_MEAN_RADIUS_METER * c
57 | return d
58 |
59 |
60 | # http://www.movable-type.co.uk/scripts/latlong.html
61 | def bearing(a, b):
62 | """
63 | Calculate the bearing of ab
64 | """
65 | pt_a_lat_rad = math.radians(a.lat)
66 | pt_a_lng_rad = math.radians(a.lng)
67 | pt_b_lat_rad = math.radians(b.lat)
68 | pt_b_lng_rad = math.radians(b.lng)
69 | y = math.sin(pt_b_lng_rad - pt_a_lng_rad) * math.cos(pt_b_lat_rad)
70 | x = math.cos(pt_a_lat_rad) * math.sin(pt_b_lat_rad) - math.sin(pt_a_lat_rad) * math.cos(pt_b_lat_rad) * math.cos(pt_b_lng_rad - pt_a_lng_rad)
71 | bearing_rad = math.atan2(y, x)
72 | return math.fmod(math.degrees(bearing_rad) + 360.0, 360.0)
73 |
74 |
75 | def cal_loc_along_line(a, b, rate):
76 | """
77 | convert rate to gps location
78 | """
79 | lat = a.lat + rate * (b.lat - a.lat)
80 | lng = a.lng + rate * (b.lng - a.lng)
81 | return SPoint(lat, lng)
82 |
83 |
84 | def project_pt_to_segment(a, b, t):
85 | """
86 | Args:
87 | -----
88 | a,b: start/end GPS location of a road segment
89 | t: raw point
90 | Returns:
91 | -------
92 | project: projected GPS point on road segment
93 | rate: rate of projected point location to road segment
94 | dist: haversine_distance of raw and projected point
95 | """
96 | ab_angle = bearing(a, b)
97 | at_angle = bearing(a, t)
98 | ab_length = distance(a, b)
99 | at_length = distance(a, t)
100 | delta_angle = at_angle - ab_angle
101 | meters_along = at_length * math.cos(math.radians(delta_angle))
102 | if ab_length == 0.0:
103 | rate = 0.0
104 | else:
105 | rate = meters_along / ab_length
106 | if rate >= 1:
107 | projection = SPoint(b.lat, b.lng)
108 | rate = 1.0
109 | elif rate <= 0:
110 | projection = SPoint(a.lat, a.lng)
111 | rate = 0.0
112 | else:
113 | projection = cal_loc_along_line(a, b, rate)
114 | dist = distance(t, projection)
115 | return projection, rate, dist
116 |
--------------------------------------------------------------------------------
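
A worked sketch of `project_pt_to_segment` with made-up coordinates: an east-west segment of roughly 850 m and a raw point about 110 m north of its midpoint:

```python
from common.spatial_func import SPoint, project_pt_to_segment

a = SPoint(39.900, 116.300)  # segment start
b = SPoint(39.900, 116.310)  # segment end, ~850 m due east at this latitude
t = SPoint(39.901, 116.305)  # raw GPS point north of the midpoint

projection, rate, dist = project_pt_to_segment(a, b, t)
print(projection)            # roughly (39.900, 116.305)
print(round(rate, 2))        # ~0.5: projected near the midpoint
print(round(dist, 1))        # ~111 m: haversine offset from the segment
```
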
/common/trajectory.py:
--------------------------------------------------------------------------------
1 | from .spatial_func import distance, SPoint, cal_loc_along_line
2 | from .mbr import MBR
3 | from datetime import timedelta
4 |
5 |
6 | class STPoint(SPoint):
7 | """
8 | STPoint creates a data type for spatio-temporal point, i.e. STPoint().
9 |
10 | """
11 |
12 | def __init__(self, lat, lng, time, data=None):
13 | super(STPoint, self).__init__(lat, lng)
14 | self.time = time
15 | self.data = data # contains edge's attributes
16 |
17 | def __str__(self):
18 | """
19 |         For easy reading of the output
20 |         """
21 |         # overriding __str__ changes what print() shows:
22 |         # st = STPoint(...)
23 |         # print(st) prints the attribute dict below instead of an object reference
24 |         # (if only __repr__ were defined, __str__ would automatically fall back to it)
25 |
26 | return str(self.__dict__) # key and value of self attributes
27 | # return '({}, {}, {})'.format(self.time.strftime('%Y/%m/%d %H:%M:%S'), self.lat, self.lng)
28 |
29 |
30 | class Trajectory:
31 | """
32 | Trajectory creates a data type for trajectory, i.e. Trajectory()
33 | """
34 |
35 | def __init__(self, oid, tid, pt_list):
36 | """
37 | Args:
38 | -----
39 | oid:
40 | str. human id
41 | tid:
42 | str. trajectory id, sometimes based on start and end time. see get_tid()
43 | pt_list:
44 | list of STPoint(lat, lng, time), containing the attributes of class STPoint
45 | """
46 | self.oid = oid
47 | self.tid = tid
48 | self.pt_list = pt_list
49 |
50 | def get_duration(self):
51 | """
52 | Get duration of a trajectory (pt_list)
53 | last_point.time - first_point.time
54 | seconds
55 | """
56 | return (self.pt_list[-1].time - self.pt_list[0].time).total_seconds()
57 |
58 | def get_distance(self):
59 | """
60 | Get geographical distance of a trajectory (pt_list)
61 | sum of two adjacent points
62 | meters
63 | """
64 | dist = 0.0
65 | pre_pt = self.pt_list[0]
66 | for pt in self.pt_list[1:]:
67 | tmp_dist = distance(pre_pt, pt)
68 | dist += tmp_dist
69 | pre_pt = pt
70 | return dist
71 |
72 | def get_avg_time_interval(self):
73 | """
74 | Calculate average time interval between two GPS points in one trajectory (pt_list)
75 | """
76 | point_time_interval = []
77 |         # zip adjacent points to get the time interval of each consecutive pair
78 |
79 | for pre, cur in zip(self.pt_list[:-1], self.pt_list[1:]):
80 | point_time_interval.append((cur.time - pre.time).total_seconds())
81 | return sum(point_time_interval) / len(point_time_interval)
82 |
83 | def get_avg_distance_interval(self):
84 | """
85 | Calculate average distance interval between two GPS points in one trajectory (pt_list)
86 | """
87 | point_dist_interval = []
88 | for pre, cur in zip(self.pt_list[:-1], self.pt_list[1:]):
89 | point_dist_interval.append(distance(pre, cur))
90 | return sum(point_dist_interval) / len(point_dist_interval)
91 |
92 | def get_mbr(self):
93 | return MBR.cal_mbr(self.pt_list)
94 |
95 | def get_start_time(self):
96 | return self.pt_list[0].time
97 |
98 | def get_end_time(self):
99 | return self.pt_list[-1].time
100 |
101 | def get_mid_time(self):
102 | return self.pt_list[0].time + (self.pt_list[-1].time - self.pt_list[0].time) / 2.0
103 |
104 | def get_centroid(self):
105 | """
106 | Get centroid SPoint
107 | """
108 | mean_lat = 0.0
109 | mean_lng = 0.0
110 | for pt in self.pt_list:
111 | mean_lat += pt.lat
112 | mean_lng += pt.lng
113 | mean_lat /= len(self.pt_list)
114 | mean_lng /= len(self.pt_list)
115 | return SPoint(mean_lat, mean_lng)
116 |
117 | def query_trajectory_by_temporal_range(self, start_time, end_time):
118 | # start_time <= pt.time < end_time
119 | traj_start_time = self.get_start_time()
120 | traj_end_time = self.get_end_time()
121 | if start_time > traj_end_time:
122 | return None
123 | if end_time <= traj_start_time:
124 | return None
125 | st = max(traj_start_time, start_time)
126 | et = min(traj_end_time + timedelta(seconds=1), end_time)
127 | start_idx = self.binary_search_idx(st) # pt_list[start_idx].time <= st < pt_list[start_idx+1].time
128 | if self.pt_list[start_idx].time < st:
129 | # then the start_idx is out of the range, we need to increase it
130 | start_idx += 1
131 | end_idx = self.binary_search_idx(et) # pt_list[end_idx].time <= et < pt_list[end_idx+1].time
132 | if self.pt_list[end_idx].time < et:
133 |             # pt_list[end_idx] is within range; the slice end is exclusive, so advance past it
134 | end_idx += 1
135 | sub_pt_list = self.pt_list[start_idx:end_idx]
136 | return Trajectory(self.oid, get_tid(self.oid, sub_pt_list), sub_pt_list)
137 |
138 | def binary_search_idx(self, time):
139 | # self.pt_list[idx].time <= time < self.pt_list[idx+1].time
140 | # if time < self.pt_list[0].time, return -1
141 | # if time >= self.pt_list[len(self.pt_list)-1].time, return len(self.pt_list)-1
142 | nb_pts = len(self.pt_list)
143 | if time < self.pt_list[0].time:
144 | return -1
145 | if time >= self.pt_list[-1].time:
146 | return nb_pts - 1
147 | # the time is in the middle
148 | left_idx = 0
149 | right_idx = nb_pts - 1
150 | while left_idx <= right_idx:
151 | mid_idx = int((left_idx + right_idx) / 2)
152 | if mid_idx < nb_pts - 1 and self.pt_list[mid_idx].time <= time < self.pt_list[mid_idx + 1].time:
153 | return mid_idx
154 | elif self.pt_list[mid_idx].time < time:
155 | left_idx = mid_idx + 1
156 | else:
157 | right_idx = mid_idx - 1
158 |
159 | def query_location_by_timestamp(self, time):
160 | idx = self.binary_search_idx(time)
161 | if idx == -1 or idx == len(self.pt_list) - 1:
162 | return None
163 | if self.pt_list[idx].time == time or (self.pt_list[idx + 1].time - self.pt_list[idx].time).total_seconds() == 0:
164 | return SPoint(self.pt_list[idx].lat, self.pt_list[idx].lng)
165 | else:
166 | # interpolate location
167 | dist_ab = distance(self.pt_list[idx], self.pt_list[idx + 1])
168 | if dist_ab == 0:
169 | return SPoint(self.pt_list[idx].lat, self.pt_list[idx].lng)
170 | dist_traveled = dist_ab * (time - self.pt_list[idx].time).total_seconds() / \
171 | (self.pt_list[idx + 1].time - self.pt_list[idx].time).total_seconds()
172 | return cal_loc_along_line(self.pt_list[idx], self.pt_list[idx + 1], dist_traveled / dist_ab)
173 |
174 | def to_wkt(self):
175 | wkt = 'LINESTRING ('
176 | for pt in self.pt_list:
177 | wkt += '{} {}, '.format(pt.lng, pt.lat)
178 | wkt = wkt[:-2] + ')'
179 | return wkt
180 |
181 | def __hash__(self):
182 | return hash(self.oid + '_' + self.pt_list[0].time.strftime('%Y%m%d%H%M%S') + '_' +
183 | self.pt_list[-1].time.strftime('%Y%m%d%H%M%S'))
184 |
185 | def __eq__(self, other):
186 | return hash(self) == hash(other)
187 |
188 | def __repr__(self):
189 | return f'Trajectory(oid={self.oid},tid={self.tid})'
190 |
191 |
192 | def get_tid(oid, pt_list):
193 | return oid + '_' + pt_list[0].time.strftime('%Y%m%d%H%M%S') + '_' + pt_list[-1].time.strftime('%Y%m%d%H%M%S')
194 |
195 | # remove parse and store trajectory functions. The related functions are in utils.parse_traj, utils.store_traj
196 |
--------------------------------------------------------------------------------
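
A short sketch of `Trajectory` with two synthetic points 20 seconds apart, showing duration, length, and timestamp interpolation:

```python
from datetime import datetime
from common.trajectory import STPoint, Trajectory, get_tid

pts = [STPoint(39.900, 116.300, datetime(2021, 9, 1, 8, 0, 0)),
       STPoint(39.902, 116.302, datetime(2021, 9, 1, 8, 0, 20))]
traj = Trajectory('taxi_1', get_tid('taxi_1', pts), pts)

print(traj.get_duration())   # 20.0 seconds
print(traj.get_distance())   # ~280 m of haversine length
print(traj.query_location_by_timestamp(datetime(2021, 9, 1, 8, 0, 10)))  # interpolated midpoint
```
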
/data/map/extra_info/poi20.csv:
--------------------------------------------------------------------------------
1 | ,,company,shopping,food,house,viewpoint
2 | 173,73,1,1,0,0,0
3 | 215,157,1,0,1,0,0
4 | 87,154,1,1,2,0,0
5 | 181,199,1,2,0,1,0
6 | 132,141,0,2,4,0,2
7 | 68,228,0,1,0,2,0
8 | 125,82,0,0,0,1,0
9 | 127,159,0,2,0,0,0
10 | 220,283,0,1,0,0,0
11 | 42,141,3,0,0,1,0
12 | 158,205,0,26,0,0,0
13 | 220,182,5,0,0,2,0
14 | 49,239,0,2,1,0,0
15 | 111,45,1,3,1,0,0
16 | 119,334,0,0,0,1,0
17 | 25,248,0,0,0,1,0
18 | 155,199,0,0,0,1,0
19 | 172,289,0,0,1,0,0
20 | 143,24,0,1,1,0,0
21 |
--------------------------------------------------------------------------------
/data/map/extra_info/weather_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/extra_info/weather_dict.pkl
--------------------------------------------------------------------------------
/data/map/road_network/edges.dbf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/edges.dbf
--------------------------------------------------------------------------------
/data/map/road_network/edges.shp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/edges.shp
--------------------------------------------------------------------------------
/data/map/road_network/edges.shx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/edges.shx
--------------------------------------------------------------------------------
/data/map/road_network/nodes.dbf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/nodes.dbf
--------------------------------------------------------------------------------
/data/map/road_network/nodes.shp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/nodes.shp
--------------------------------------------------------------------------------
/data/map/road_network/nodes.shx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/data/map/road_network/nodes.shx
--------------------------------------------------------------------------------
/doc/KDD21_MTrajRec_Huimin.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/doc/KDD21_MTrajRec_Huimin.pdf
--------------------------------------------------------------------------------
/img/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/img/intro.png
--------------------------------------------------------------------------------
/img/model_framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huiminren/MTrajRec/f9c1f24311d77c68b1000bc6bb62e26d09841390/img/model_framework.png
--------------------------------------------------------------------------------
/map_matching/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/3 13:34
--------------------------------------------------------------------------------
/map_matching/candidate_point.py:
--------------------------------------------------------------------------------
1 | from common.spatial_func import SPoint, LAT_PER_METER, LNG_PER_METER, project_pt_to_segment, distance
2 | from common.mbr import MBR
3 |
4 |
5 | class CandidatePoint(SPoint):
6 | def __init__(self, lat, lng, eid, error, offset, rate):
7 | super(CandidatePoint, self).__init__(lat, lng)
8 | self.eid = eid
9 | self.error = error
10 | self.offset = offset
11 | self.rate = rate
12 |
13 | def __str__(self):
14 | return '{},{},{},{},{},{}'.format(self.eid, self.lat, self.lng, self.error, self.offset, self.rate)
15 |
16 | def __repr__(self):
17 | return '{},{},{},{},{},{}'.format(self.eid, self.lat, self.lng, self.error, self.offset, self.rate)
18 |
19 | def __hash__(self):
20 | return hash(self.__str__())
21 |
22 |
23 | def get_candidates(pt, rn, search_dist):
24 | """
25 | Args:
26 | -----
27 | pt: point STPoint()
28 | rn: road network
29 |     search_dist: search radius in meters. A parameter of HMM map matching bounding the range of pt's candidate roads.
30 | Returns:
31 | --------
32 | candidates: list of potential projected points.
33 | """
34 | candidates = None
35 | mbr = MBR(pt.lat - search_dist * LAT_PER_METER,
36 | pt.lng - search_dist * LNG_PER_METER,
37 | pt.lat + search_dist * LAT_PER_METER,
38 | pt.lng + search_dist * LNG_PER_METER)
39 | candidate_edges = rn.range_query(mbr) # list of edges (two nodes/points)
40 | if len(candidate_edges) > 0:
41 | candi_pt_list = [cal_candidate_point(pt, rn, candidate_edge) for candidate_edge in candidate_edges]
42 | # refinement
43 | candi_pt_list = [candi_pt for candi_pt in candi_pt_list if candi_pt.error <= search_dist]
44 | if len(candi_pt_list) > 0:
45 | candidates = candi_pt_list
46 | return candidates
47 |
48 |
49 | def cal_candidate_point(raw_pt, rn, edge):
50 | """
51 | Get attributes of candidate point
52 | """
53 | u, v = edge
54 | coords = rn[u][v]['coords'] # GPS points in road segment, may be larger than 2
55 | candidates = [project_pt_to_segment(coords[i], coords[i + 1], raw_pt) for i in range(len(coords) - 1)]
56 | idx, (projection, coor_rate, dist) = min(enumerate(candidates), key=lambda x: x[1][2])
57 |     # enumerate yields (idx, (projection, rate, dist)); the key x[1][2] picks the projection with the smallest error
58 |     offset = 0.0
59 |     for i in range(idx):
60 |         offset += distance(coords[i], coords[i + 1])  # accumulate the length of sub-segments before the projection
61 |     offset += distance(coords[idx], projection)  # plus the distance from the sub-segment start to the projected point
62 | if rn[u][v]['length'] == 0:
63 | rate = 0
64 | # print(u, v)
65 | else:
66 |         rate = offset/rn[u][v]['length']  # rate along the whole edge; coor_rate is only within one sub-segment
67 | return CandidatePoint(projection.lat, projection.lng, rn[u][v]['eid'], dist, offset, rate)
68 |
69 |
--------------------------------------------------------------------------------
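
A candidate-search sketch, assuming a road network loaded from the sample shapefiles (the GPS fix below is illustrative; points far from the sample map yield no candidates):

```python
from datetime import datetime
from common.road_network import load_rn_shp
from common.trajectory import STPoint
from map_matching.candidate_point import get_candidates

rn = load_rn_shp('./data/map/road_network', is_directed=True)
pt = STPoint(39.95, 116.35, datetime(2021, 9, 1, 8, 0, 0))

candidates = get_candidates(pt, rn, search_dist=50)  # 50 m search radius
if candidates is not None:
    best = min(candidates, key=lambda c: c.error)
    print(best.eid, round(best.offset, 1), round(best.rate, 2))
else:
    print('no road segment within 50 m')
```
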
/map_matching/hmm/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/3 13:34
--------------------------------------------------------------------------------
/map_matching/hmm/hmm_map_matcher.py:
--------------------------------------------------------------------------------
1 | """
2 | Based on Newson, Paul, and John Krumm. "Hidden Markov map matching through
3 | noise and sparseness." Proceedings of the 17th ACM SIGSPATIAL International
4 | Conference on Advances in Geographic Information Systems. ACM, 2009.
5 | This is a Python translation from https://github.com/graphhopper/map-matching/tree/master/hmm-lib
6 | """
7 |
8 | from ..hmm.hmm_probabilities import HMMProbabilities
9 | from ..hmm.ti_viterbi import ViterbiAlgorithm, SequenceState
10 | from ..map_matcher import MapMatcher
11 | from ..candidate_point import get_candidates
12 | from common.spatial_func import distance
13 | from common.trajectory import STPoint, Trajectory
14 | from ..utils import find_shortest_path
15 | from ..route_constructor import construct_path
16 |
17 |
18 | class TimeStep:
19 | """
20 | Contains everything the hmm-lib needs to process a new time step including emission and observation probabilities.
21 | emission probability: p(z|r), the likelihood that the measurement z would be observed if the vehicle were actually
22 | on road segment r
23 | """
24 | def __init__(self, observation, candidates):
25 | if observation is None or candidates is None:
26 | raise Exception('observation and candidates must not be null.')
27 | self.observation = observation
28 | self.candidates = candidates
29 | self.emission_log_probabilities = {}
30 | self.transition_log_probabilities = {}
31 | # transition -> dist
32 | self.road_paths = {}
33 |
34 | def add_emission_log_probability(self, candidate, emission_log_probability):
35 | if candidate in self.emission_log_probabilities:
36 | raise Exception('Candidate has already been added.')
37 | self.emission_log_probabilities[candidate] = emission_log_probability
38 |
39 | def add_transition_log_probability(self, from_position, to_position, transition_log_probability):
40 | transition = (from_position, to_position)
41 | if transition in self.transition_log_probabilities:
42 | raise Exception('Transition has already been added.')
43 | self.transition_log_probabilities[transition] = transition_log_probability
44 |
45 | def add_road_path(self, from_position, to_position, road_path):
46 | transition = (from_position, to_position)
47 | if transition in self.road_paths:
48 | raise Exception('Transition has already been added.')
49 | self.road_paths[transition] = road_path
50 |
51 |
52 | class TIHMMMapMatcher(MapMatcher):
53 | def __init__(self, rn, search_dis=50, sigma=5.0, beta=2.0, routing_weight='length', debug=False):
54 |         self.measurement_error_sigma = search_dis  # search radius in meters (200 m in the original paper)
55 |         self.transition_probability_beta = beta  # beta is a parameter in equation (2)
56 |         self.gaussian_sigma = sigma
57 |         # beta scales the difference between great-circle and route distances;
58 |         # a larger beta represents more tolerance of non-direct routes.
59 |         self.debug = debug
60 | super(TIHMMMapMatcher, self).__init__(rn, routing_weight)
61 |
62 |     # in our implementation, points with no candidates or no valid transition are set to None, and a new matching starts
63 | def match(self, traj):
64 | """ Given original traj, return map-matched trajectory"""
65 | seq = self.compute_viterbi_sequence(traj.pt_list)
66 | assert len(traj.pt_list) == len(seq), 'pt_list and seq must have the same size'
67 | mm_pt_list = []
68 | for ss in seq:
69 | candi_pt = None
70 | if ss.state is not None:
71 | candi_pt = ss.state
72 | data = {'candi_pt': candi_pt}
73 | mm_pt_list.append(STPoint(ss.observation.lat, ss.observation.lng, ss.observation.time, data))
74 | mm_traj = Trajectory(traj.oid, traj.tid, mm_pt_list)
75 | return mm_traj
76 |
77 | def match_to_path(self, traj):
78 | mm_traj = self.match(traj)
79 | path = construct_path(self.rn, mm_traj, self.routing_weight)
80 | return path
81 |
82 | def create_time_step(self, pt):
83 | time_step = None
84 | candidates = get_candidates(pt, self.rn, self.measurement_error_sigma)
85 | if candidates is not None:
86 | time_step = TimeStep(pt, candidates)
87 | return time_step
88 |
89 | def compute_viterbi_sequence(self, pt_list):
90 | """
91 | Args:
92 | -----
93 | pt_list: observation pt_list
94 | Returns:
95 | -------
96 |         seq: list of SequenceState, one per observation point
97 | """
98 | seq = []
99 |         probabilities = HMMProbabilities(self.gaussian_sigma, self.transition_probability_beta)
100 | viterbi = ViterbiAlgorithm(keep_message_history=self.debug)
101 | prev_time_step = None
102 | idx = 0
103 | nb_points = len(pt_list)
104 | while idx < nb_points:
105 | time_step = self.create_time_step(pt_list[idx])
106 |             # construct the sequence ending at t-1, and skip the current point (no-candidate error)
107 | if time_step is None:
108 | seq.extend(viterbi.compute_most_likely_sequence())
109 | seq.append(SequenceState(None, pt_list[idx], None))
110 | viterbi = ViterbiAlgorithm(keep_message_history=self.debug)
111 | prev_time_step = None
112 | else:
113 | self.compute_emission_probabilities(time_step, probabilities)
114 | if prev_time_step is None:
115 | viterbi.start_with_initial_observation(time_step.observation, time_step.candidates,
116 | time_step.emission_log_probabilities)
117 | else:
118 | self.compute_transition_probabilities(prev_time_step, time_step, probabilities)
119 | viterbi.next_step(time_step.observation, time_step.candidates, time_step.emission_log_probabilities,
120 | time_step.transition_log_probabilities, time_step.road_paths)
121 | if viterbi.is_broken:
122 |                     # construct the sequence ending at t-1, and start a new matching at t (no-transition error)
123 | seq.extend(viterbi.compute_most_likely_sequence())
124 | viterbi = ViterbiAlgorithm(keep_message_history=self.debug)
125 | viterbi.start_with_initial_observation(time_step.observation, time_step.candidates,
126 | time_step.emission_log_probabilities)
127 | prev_time_step = time_step
128 | idx += 1
129 | if len(seq) < nb_points:
130 | seq.extend(viterbi.compute_most_likely_sequence())
131 | return seq
132 |
133 | def compute_emission_probabilities(self, time_step, probabilities):
134 | for candi_pt in time_step.candidates:
135 | dist = candi_pt.error
136 | time_step.add_emission_log_probability(candi_pt, probabilities.emission_log_probability(dist))
137 |
138 | def compute_transition_probabilities(self, prev_time_step, time_step, probabilities):
139 | linear_dist = distance(prev_time_step.observation, time_step.observation)
140 | for prev_candi_pt in prev_time_step.candidates:
141 | for cur_candi_pt in time_step.candidates:
142 | path_dist, path = find_shortest_path(self.rn, prev_candi_pt, cur_candi_pt, self.routing_weight)
143 | # invalid transition has no transition probability
144 | if path is not None:
145 | time_step.add_road_path(prev_candi_pt, cur_candi_pt, path)
146 | time_step.add_transition_log_probability(prev_candi_pt, cur_candi_pt,
147 | probabilities.transition_log_probability(path_dist,
148 | linear_dist))
149 |
--------------------------------------------------------------------------------
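
A hedged end-to-end sketch of `TIHMMMapMatcher` (coordinates are illustrative; points far from the sample map simply match to `None`):

```python
from datetime import datetime
from common.road_network import load_rn_shp
from common.trajectory import STPoint, Trajectory, get_tid
from map_matching.hmm.hmm_map_matcher import TIHMMMapMatcher

rn = load_rn_shp('./data/map/road_network', is_directed=True)
pts = [STPoint(39.950, 116.350, datetime(2021, 9, 1, 8, 0, 0)),
       STPoint(39.951, 116.352, datetime(2021, 9, 1, 8, 0, 15))]
traj = Trajectory('taxi_1', get_tid('taxi_1', pts), pts)

matcher = TIHMMMapMatcher(rn, search_dis=50, sigma=5.0, beta=2.0)
mm_traj = matcher.match(traj)
for pt in mm_traj.pt_list:
    candi = pt.data['candi_pt']  # None where no candidate/transition was found
    print(pt.time, candi.eid if candi else None)
```
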
/map_matching/hmm/hmm_probabilities.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | class HMMProbabilities:
5 | def __init__(self, sigma, beta):
6 | self.sigma = sigma
7 | self.beta = beta
8 |
9 | def emission_log_probability(self, distance):
10 | """
11 | Returns the logarithmic emission probability density.
12 | :param distance: Absolute distance [m] between GPS measurement and map matching candidate.
13 | :return:
14 | """
15 | return log_normal_distribution(self.sigma, distance)
16 |
17 | def transition_log_probability(self, route_length, linear_distance):
18 | """
19 | Returns the logarithmic transition probability density for the given transition parameters.
20 | :param route_length: Length of the shortest route [m] between two consecutive map matching candidates.
21 | :param linear_distance: Linear distance [m] between two consecutive GPS measurements.
22 | :return:
23 | """
24 | transition_metric = math.fabs(linear_distance - route_length)
25 | return log_exponential_distribution(self.beta, transition_metric)
26 |
27 |
28 | def log_normal_distribution(sigma, x):
29 | return math.log(1.0 / (math.sqrt(2.0 * math.pi) * sigma)) + (-0.5 * pow(x / sigma, 2))
30 |
31 |
32 | def log_exponential_distribution(beta, x):
33 | return math.log(1.0 / beta) - (x / beta)
34 |
--------------------------------------------------------------------------------
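
A quick numeric sketch of the two densities with the defaults used by `TIHMMMapMatcher` (sigma = 5 m, beta = 2); the distances are made up:

```python
from map_matching.hmm.hmm_probabilities import HMMProbabilities

probs = HMMProbabilities(sigma=5.0, beta=2.0)
print(probs.emission_log_probability(0.0))    # density peak: a perfect GPS fix
print(probs.emission_log_probability(10.0))   # a 2-sigma GPS error is far less likely
print(probs.transition_log_probability(105.0, 100.0))  # route close to straight line: likely
print(probs.transition_log_probability(300.0, 100.0))  # long detour: unlikely
```
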
/map_matching/hmm/ti_viterbi.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of the Viterbi algorithm for time-inhomogeneous Markov processes,
3 | meaning that the set of states and state transition probabilities are not necessarily fixed for all time steps.
4 | For long observation sequences, back pointers usually converge to a single path after a
5 | certain number of time steps. For instance, when matching GPS coordinates to roads, the last
6 | GPS positions in the trace usually do not affect the first road matches anymore.
7 | This implementation exploits this fact by letting the garbage collector
8 | take care of unreachable back pointers. If back pointers converge to a single path after a
9 | constant number of time steps, only O(t) back pointers and transition descriptors need to be stored in memory.
10 | """
11 |
12 |
13 | class ExtendedState:
14 | """
15 | Back pointer to previous state candidate in the most likely sequence.
16 | """
17 | def __init__(self, state, back_pointer, observation, transition_descriptor):
18 | self.state = state
19 | self.back_pointer = back_pointer
20 | self.observation = observation
21 | self.transition_descriptor = transition_descriptor
22 |
23 |
24 | class SequenceState:
25 | def __init__(self, state, observation, transition_descriptor):
26 | self.state = state
27 | self.observation = observation
28 | self.transition_descriptor = transition_descriptor
29 |
30 |
31 | class ForwardStepResult:
32 | def __init__(self):
33 | """
34 |         Includes back pointers to previous state candidates
35 |         for retrieving the most likely sequence after the forward pass.
36 | """
37 | self.new_message = {}
38 | self.new_extended_states = {}
39 |
40 |
41 | class ViterbiAlgorithm:
42 | def __init__(self, keep_message_history=False):
43 | # Allows to retrieve the most likely sequence using back pointers.
44 | self.last_extended_states = None
45 | self.prev_candidates = []
46 | # For each state s_t of the current time step t, message.get(s_t) contains the log
47 | # probability of the most likely sequence ending in state s_t with given observations o_1, ..., o_t.
48 | #
49 | # Formally, this is max log p(s_1, ..., s_t, o_1, ..., o_t) w.r.t. s_1, ..., s_{t-1}.
50 | # Note that to compute the most likely state sequence, it is sufficient and more
51 | # efficient to compute in each time step the joint probability of states and observations
52 | # instead of computing the conditional probability of states given the observations.
53 | # state -> float
54 | self.message = None
55 | self.is_broken = False
56 | # list of message
57 | self.message_history = None
58 | if keep_message_history:
59 | self.message_history = []
60 |
61 | def initialize_state_probabilities(self, observation, candidates, initial_log_probabilities):
62 | """
63 | Use only if HMM only starts with first observation.
64 | :param observation:
65 | :param candidates:
66 | :param initial_log_probabilities:
67 | :return:
68 | """
69 | if self.message is not None:
70 | raise Exception('Initial probabilities have already been set.')
71 | # Set initial log probability for each start state candidate based on first observation.
72 | # Do not assign initial_log_probabilities directly to message to not rely on its iteration order.
73 | initial_message = {}
74 | for candidate in candidates:
75 | if candidate not in initial_log_probabilities:
76 | raise Exception('No initial probability for {}'.format(candidate))
77 | log_probability = initial_log_probabilities[candidate]
78 | initial_message[candidate] = log_probability
79 | self.is_broken = self.hmm_break(initial_message)
80 | if self.is_broken:
81 | return
82 | self.message = initial_message
83 | if self.message_history is not None:
84 | self.message_history.append(self.message)
85 | self.last_extended_states = {}
86 | for candidate in candidates:
87 | self.last_extended_states[candidate] = ExtendedState(candidate, None, observation, None)
88 | self.prev_candidates = [candidate for candidate in candidates]
89 |
90 | def hmm_break(self, message):
91 | """
92 | Returns whether the specified message is either empty or only contains state candidates with zero probability and thus causes the HMM to break.
93 | :return:
94 | """
95 | for log_probability in message.values():
96 | if log_probability != float('-inf'):
97 | return False
98 | return True
99 |
100 | def forward_step(self, observation, prev_candidates, cur_candidates, message, emission_log_probabilities,
101 | transition_log_probabilities, transition_descriptors=None):
102 | result = ForwardStepResult()
103 | assert len(prev_candidates) != 0
104 |
105 | for cur_state in cur_candidates:
106 | max_log_probability = float('-inf')
107 | max_prev_state = None
108 | for prev_state in prev_candidates:
109 | log_probability = message[prev_state] + self.transition_log_probability(prev_state, cur_state,
110 | transition_log_probabilities)
111 | if log_probability > max_log_probability:
112 | max_log_probability = log_probability
113 | max_prev_state = prev_state
114 | # throws KeyError if cur_state not in the emission_log_probabilities
115 | result.new_message[cur_state] = max_log_probability + emission_log_probabilities[cur_state]
116 | # Note that max_prev_state == None if there is no transition with non-zero probability.
117 | # In this case cur_state has zero probability and will not be part of the most likely
118 | # sequence, so we don't need an ExtendedState.
119 | if max_prev_state is not None:
120 | transition = (max_prev_state, cur_state)
121 | if transition_descriptors is not None:
122 | transition_descriptor = transition_descriptors[transition]
123 | else:
124 | transition_descriptor = None
125 | extended_state = ExtendedState(cur_state, self.last_extended_states[max_prev_state], observation,
126 | transition_descriptor)
127 | result.new_extended_states[cur_state] = extended_state
128 | return result
129 |
130 | def transition_log_probability(self, prev_state, cur_state, transition_log_probabilities):
131 | transition = (prev_state, cur_state)
132 | if transition not in transition_log_probabilities:
133 | return float('-inf')
134 | else:
135 | return transition_log_probabilities[transition]
136 |
137 | def most_likely_state(self):
138 | # Retrieves the first state of the current forward message with maximum probability.
139 | assert len(self.message) != 0
140 |
141 | result = None
142 | max_log_probability = float('-inf')
143 | for state in self.message:
144 | if self.message[state] > max_log_probability:
145 | result = state
146 | max_log_probability = self.message[state]
147 | # Otherwise an HMM break would have occurred.
148 | assert result is not None
149 | return result
150 |
151 | def retrieve_most_likely_sequence(self):
152 | # Otherwise an HMM break would have occurred and message would be null.
153 | assert len(self.message) != 0
154 |
155 | last_state = self.most_likely_state()
156 | # Retrieve most likely state sequence in reverse order
157 | result = []
158 | es = self.last_extended_states[last_state]
159 | while es is not None:
160 | ss = SequenceState(es.state, es.observation, es.transition_descriptor)
161 | result.append(ss)
162 | es = es.back_pointer
163 | result.reverse()
164 | return result
165 |
166 | def start_with_initial_observation(self, observation, candidates, emission_log_probabilities):
167 | """
168 | Lets the HMM computation start at the given first observation and uses the given emission
169 | probabilities as the initial state probability for each starting state s.
170 | :param observation:
171 | :param candidates:
172 | :param emission_log_probabilities:
173 | :return:
174 | """
175 | self.initialize_state_probabilities(observation, candidates, emission_log_probabilities)
176 |
177 | def next_step(self, observation, candidates, emission_log_probabilities, transition_log_probabilities, transition_descriptors=None):
178 | if self.message is None:
179 | raise Exception('start_with_initial_observation() must be called first.')
180 | if self.is_broken:
181 | raise Exception('Method must not be called after an HMM break.')
182 | forward_step_result = self.forward_step(observation, self.prev_candidates, candidates, self.message,
183 | emission_log_probabilities, transition_log_probabilities, transition_descriptors)
184 | self.is_broken = self.hmm_break(forward_step_result.new_message)
185 | if self.is_broken:
186 | return
187 | if self.message_history is not None:
188 | self.message_history.append(forward_step_result.new_message)
189 | self.message = forward_step_result.new_message
190 | self.last_extended_states = forward_step_result.new_extended_states
191 | self.prev_candidates = [candidate for candidate in candidates]
192 |
193 | def compute_most_likely_sequence(self):
194 | """
195 | Returns the most likely sequence of states for all time steps. This includes the initial
196 | states / initial observation time step. If an HMM break occurred in the last time step t,
197 | then the most likely sequence up to t-1 is returned. See also {@link #isBroken()}.
198 | Formally, the most likely sequence is argmax p([s_0,] s_1, ..., s_T | o_1, ..., o_T)
199 | with respect to s_1, ..., s_T, where s_t is a state candidate at time step t,
200 | o_t is the observation at time step t and T is the number of time steps.
201 | :return:
202 | """
203 | if self.message is None:
204 | # Return empty most likely sequence if there are no time steps or if initial observations caused an HMM break.
205 | return []
206 | else:
207 | return self.retrieve_most_likely_sequence()
208 |
--------------------------------------------------------------------------------
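
A self-contained toy run of `ViterbiAlgorithm` with string states and hand-picked log probabilities (no road network needed; states and numbers are made up):

```python
from map_matching.hmm.ti_viterbi import ViterbiAlgorithm

v = ViterbiAlgorithm()
# step 1: two candidate states with their initial (emission) log probabilities
v.start_with_initial_observation('o1', ['a', 'b'], {'a': -1.0, 'b': -2.0})
# step 2: emissions for the new candidates plus transition log probabilities;
# transitions absent from the dict (e.g. a->d) default to -inf
v.next_step('o2', ['c', 'd'],
            {'c': -1.0, 'd': -1.0},
            {('a', 'c'): -0.5, ('b', 'd'): -0.1})
print([ss.state for ss in v.compute_most_likely_sequence()])  # ['a', 'c']
```
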
/map_matching/map_matcher.py:
--------------------------------------------------------------------------------
1 | class MapMatcher:
2 | def __init__(self, rn, routing_weight='length'):
3 | self.rn = rn
4 | self.routing_weight = routing_weight
5 |
6 | def match(self, traj):
7 | pass
8 |
9 | def match_to_path(self, traj):
10 | pass
11 |
--------------------------------------------------------------------------------
/map_matching/route_constructor.py:
--------------------------------------------------------------------------------
1 | from .utils import find_shortest_path
2 | from datetime import timedelta
3 | import networkx as nx
4 | from common.path import PathEntity, Path
5 |
6 |
7 | def construct_path(rn, mm_traj, routing_weight):
8 | """
9 | construct the path of the map matched trajectory
10 |     Note: the enter time of the first path entity & the leave time of the last path entity are not accurate
11 | :param rn: the road network
12 | :param mm_traj: the map matched trajectory
13 | :param routing_weight: the attribute name to find the routing weight
14 |     :return: a list of paths (more than one in case the route is broken)
15 | """
16 | paths = []
17 | path = []
18 | mm_pt_list = mm_traj.pt_list
19 | start_idx = len(mm_pt_list) - 1
20 | # find the first matched point
21 | for i in range(len(mm_pt_list)):
22 | if mm_pt_list[i].data['candi_pt'] is not None:
23 | start_idx = i
24 | break
25 | pre_edge_enter_time = mm_pt_list[start_idx].time
26 | for i in range(start_idx + 1, len(mm_pt_list)):
27 | pre_mm_pt = mm_pt_list[i-1]
28 | cur_mm_pt = mm_pt_list[i]
29 | # unmatched -> matched
30 | if pre_mm_pt.data['candi_pt'] is None:
31 | pre_edge_enter_time = cur_mm_pt.time
32 | continue
33 | # matched -> unmatched
34 | pre_candi_pt = pre_mm_pt.data['candi_pt']
35 | if cur_mm_pt.data['candi_pt'] is None:
36 | path.append(PathEntity(pre_edge_enter_time, pre_mm_pt.time, pre_candi_pt.eid))
37 | if len(path) > 2:
38 | paths.append(Path(mm_traj.oid, get_pid(mm_traj.oid, path), path))
39 | path = []
40 | continue
41 | # matched -> matched
42 | cur_candi_pt = cur_mm_pt.data['candi_pt']
43 | # if consecutive points are on the same road, cur_mm_pt doesn't bring new information
44 | if pre_candi_pt.eid != cur_candi_pt.eid:
45 | weight_p, p = find_shortest_path(rn, pre_candi_pt, cur_candi_pt, routing_weight)
46 | # cannot connect
47 | if p is None:
48 | path.append(PathEntity(pre_edge_enter_time, pre_mm_pt.time, pre_candi_pt.eid))
49 | if len(path) > 2:
50 | paths.append(Path(mm_traj.oid, get_pid(mm_traj.oid, path), path))
51 | path = []
52 | pre_edge_enter_time = cur_mm_pt.time
53 | continue
54 | # can connect
55 | if nx.is_directed(rn):
56 | dist_to_p_entrance = rn.edges[rn.edge_idx[pre_candi_pt.eid]]['length'] - pre_candi_pt.offset
57 | dist_to_p_exit = cur_candi_pt.offset
58 | else:
59 | entrance_vertex = p[0]
60 | pre_edge_coords = rn.edges[rn.edge_idx[pre_candi_pt.eid]]['coords']
61 | if (pre_edge_coords[0].lng, pre_edge_coords[0].lat) == entrance_vertex:
62 | dist_to_p_entrance = pre_candi_pt.offset
63 | else:
64 | dist_to_p_entrance = rn.edges[rn.edge_idx[pre_candi_pt.eid]]['length'] - pre_candi_pt.offset
65 | exit_vertex = p[-1]
66 | cur_edge_coords = rn.edges[rn.edge_idx[cur_candi_pt.eid]]['coords']
67 | if (cur_edge_coords[0].lng, cur_edge_coords[0].lat) == exit_vertex:
68 | dist_to_p_exit = cur_candi_pt.offset
69 | else:
70 |                         dist_to_p_exit = rn.edges[rn.edge_idx[cur_candi_pt.eid]]['length'] - cur_candi_pt.offset
71 | if routing_weight == 'length':
72 | total_dist = weight_p
73 | else:
74 | dist_inner = 0.0
75 | for i in range(len(p) - 1):
76 | start, end = p[i], p[i + 1]
77 | dist_inner += rn[start][end]['length']
78 | total_dist = dist_inner + dist_to_p_entrance + dist_to_p_exit
79 | delta_time = (cur_mm_pt.time - pre_mm_pt.time).total_seconds()
80 | # two consecutive points matched to the same vertex
81 | if total_dist == 0:
82 | pre_edge_leave_time = cur_mm_pt.time
83 | path.append(PathEntity(pre_edge_enter_time, pre_edge_leave_time, pre_candi_pt.eid))
84 | cur_edge_enter_time = cur_mm_pt.time
85 | else:
86 | pre_edge_leave_time = pre_mm_pt.time + timedelta(seconds=delta_time*(dist_to_p_entrance/total_dist))
87 | path.append(PathEntity(pre_edge_enter_time, pre_edge_leave_time, pre_candi_pt.eid))
88 | cur_edge_enter_time = cur_mm_pt.time - timedelta(seconds=delta_time * (dist_to_p_exit / total_dist))
89 | sub_path = linear_interpolate_path(p, total_dist - dist_to_p_entrance - dist_to_p_exit,
90 | rn, pre_edge_leave_time, cur_edge_enter_time)
91 | path.extend(sub_path)
92 | pre_edge_enter_time = cur_edge_enter_time
93 | # handle last matched similar to (matched -> unmatched)
94 | if mm_pt_list[-1].data['candi_pt'] is not None:
95 | path.append(PathEntity(pre_edge_enter_time, mm_pt_list[-1].time, mm_pt_list[-1].data['candi_pt'].eid))
96 | if len(path) > 2:
97 | paths.append(Path(mm_traj.oid, get_pid(mm_traj.oid, path), path))
98 | return paths
99 |
100 |
101 | def linear_interpolate_path(p, dist_inner, rn, enter_time, leave_time):
102 | path = []
103 | edges = []
104 | for i in range(len(p)-1):
105 | edges.append((p[i], p[i+1]))
106 | delta_time = (leave_time - enter_time).total_seconds()
107 | edge_enter_time = enter_time
108 | for i in range(len(edges)):
109 | edge_data = rn.edges[edges[i]]
110 | if i == len(edges) - 1:
111 |             # pin the last edge's leave time to the path leave time, so that
112 |             # floating-point error cannot push it past the path boundary
113 | edge_leave_time = leave_time
114 | else:
115 | edge_leave_time = edge_enter_time + timedelta(seconds=delta_time*(edge_data['length']/dist_inner))
116 | path.append(PathEntity(edge_enter_time, edge_leave_time, edge_data['eid']))
117 | edge_enter_time = edge_leave_time
118 | return path
119 |
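# Worked example (illustrative, not from the original code): linear_interpolate_path
# splits the travel time between two matched points across the intermediate edges in
# proportion to edge length. With two inner edges of 100 m and 300 m (dist_inner = 400)
# traversed between 10:00:00 and 10:00:40, the first edge gets 40 * (100 / 400) = 10 s
# (leave time 10:00:10) and the last edge is pinned to the path leave time 10:00:40.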
120 |
121 | def get_pid(oid, path):
122 | return oid + '_' + path[0].enter_time.strftime('%Y%m%d%H%M') + '_' + \
123 | path[-1].leave_time.strftime('%Y%m%d%H%M')
124 |
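# Example (illustrative object id): a path entered at 2020-09-01 08:00 and left at
# 2020-09-01 08:15 by object 'taxi_1' gets pid 'taxi_1_202009010800_202009010815'.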
--------------------------------------------------------------------------------
/map_matching/utils.py:
--------------------------------------------------------------------------------
1 | from common.spatial_func import SPoint, distance
2 | import networkx as nx
3 | import math
4 |
5 |
6 | def find_shortest_path(rn, prev_candi_pt, cur_candi_pt, weight='length'):
7 | if nx.is_directed(rn):
8 | return find_shortest_path_directed(rn, prev_candi_pt, cur_candi_pt, weight)
9 | else:
10 | return find_shortest_path_undirected(rn, prev_candi_pt, cur_candi_pt, weight)
11 |
12 |
13 | def find_shortest_path_directed(rn, prev_candi_pt, cur_candi_pt, weight):
14 | # case 1, on the same road
15 | if prev_candi_pt.eid == cur_candi_pt.eid:
16 | if prev_candi_pt.offset < cur_candi_pt.offset:
17 | return (cur_candi_pt.offset - prev_candi_pt.offset), []
18 | else:
19 | return float('inf'), None
20 | # case 2, on different roads (including opposite roads)
21 | else:
22 | pre_u, pre_v = rn.edge_idx[prev_candi_pt.eid]
23 | cur_u, cur_v = rn.edge_idx[cur_candi_pt.eid]
24 | try:
25 | path = get_cheapest_path_with_weight(rn, pre_v, cur_u, rn[pre_u][pre_v]['length'] - prev_candi_pt.offset,
26 | cur_candi_pt.offset, heuristic, weight)
27 | return path
28 | except nx.NetworkXNoPath:
29 | return float('inf'), None
30 |
31 |
32 | def find_shortest_path_undirected(rn, prev_candi_pt, cur_candi_pt, weight):
33 | # case 1, on the same road
34 | if prev_candi_pt.eid == cur_candi_pt.eid:
35 | return math.fabs(cur_candi_pt.offset - prev_candi_pt.offset), []
36 | # case 2, on different roads (including opposite roads)
37 | else:
38 | pre_u, pre_v = rn.edge_idx[prev_candi_pt.eid]
39 | cur_u, cur_v = rn.edge_idx[cur_candi_pt.eid]
40 | min_dist = float('inf')
41 | shortest_path = None
42 | paths = []
43 | # prev_u -> cur_u
44 | try:
45 | paths.append(get_cheapest_path_with_weight(rn, pre_u, cur_u, prev_candi_pt.offset,
46 | cur_candi_pt.offset, heuristic, weight))
47 | except nx.NetworkXNoPath:
48 | pass
49 | # prev_u -> cur_v
50 | try:
51 | paths.append(get_cheapest_path_with_weight(rn, pre_u, cur_v, prev_candi_pt.offset,
52 | rn[cur_u][cur_v]['length'] - cur_candi_pt.offset,
53 | heuristic, weight))
54 | except nx.NetworkXNoPath:
55 | pass
56 | # pre_v -> cur_u
57 | try:
58 | paths.append(get_cheapest_path_with_weight(rn, pre_v, cur_u,
59 | rn[pre_u][pre_v]['length'] - prev_candi_pt.offset,
60 | cur_candi_pt.offset, heuristic, weight))
61 | except nx.NetworkXNoPath:
62 | pass
63 | # prev_v -> cur_v:
64 | try:
65 | paths.append(get_cheapest_path_with_weight(rn, pre_v, cur_v,
66 | rn[pre_u][pre_v]['length'] - prev_candi_pt.offset,
67 | rn[cur_u][cur_v]['length'] - cur_candi_pt.offset,
68 | heuristic, weight))
69 | except nx.NetworkXNoPath:
70 | pass
71 | if len(paths) > 0:
72 | min_dist, shortest_path = min(paths, key=lambda t: t[0])
73 | return min_dist, shortest_path
74 |
75 |
76 | def heuristic(node1, node2):
77 | return distance(SPoint(node1[1], node1[0]), SPoint(node2[1], node2[0]))
78 |
79 |
80 | def get_cheapest_path_with_weight(rn, src, dest, dist_to_src, dist_to_dest, heuristic, weight):
81 | tot_weight = 0.0
82 | path = nx.astar_path(rn, src, dest, heuristic, weight=weight)
83 | tot_weight += dist_to_src
84 | for i in range(len(path) - 1):
85 | start = path[i]
86 | end = path[i + 1]
87 | tot_weight += rn[start][end][weight]
88 | tot_weight += dist_to_dest
89 | return tot_weight, path
90 |
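# Usage sketch (assumed, not from the original repo): the returned weight is the network
# distance between the two candidate points, i.e. the remaining length on the previous
# edge + the A*-cheapest path between the connecting vertices + the offset on the
# current edge. float('inf') and None are returned when the points cannot be connected:
# >>> dist, node_path = find_shortest_path(rn, prev_candi_pt, cur_candi_pt, weight='length')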
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/10/20 18:07
4 |
--------------------------------------------------------------------------------
/models/datasets.py:
--------------------------------------------------------------------------------
1 | import random
2 | from tqdm import tqdm
3 | import os
4 | from chinese_calendar import is_holiday
5 |
6 | import numpy as np
7 | import torch
8 |
9 | from common.spatial_func import distance
10 | from common.trajectory import get_tid, Trajectory
11 | from utils.parse_traj import ParseMMTraj
12 | from utils.save_traj import SaveTraj2MM
13 | from utils.utils import create_dir
14 | from .model_utils import load_rid_freqs
15 |
16 |
17 | def split_data(traj_input_dir, output_dir):
18 | """
19 | split original data to train, valid and test datasets
20 | """
21 | create_dir(output_dir)
22 | train_data_dir = output_dir + 'train_data/'
23 | create_dir(train_data_dir)
24 | val_data_dir = output_dir + 'valid_data/'
25 | create_dir(val_data_dir)
26 | test_data_dir = output_dir + 'test_data/'
27 | create_dir(test_data_dir)
28 |
29 | trg_parser = ParseMMTraj()
30 | trg_saver = SaveTraj2MM()
31 |
32 | for file_name in tqdm(os.listdir(traj_input_dir)):
33 | traj_input_path = os.path.join(traj_input_dir, file_name)
34 | trg_trajs = np.array(trg_parser.parse(traj_input_path))
35 | ttl_lens = len(trg_trajs)
36 | test_inds = random.sample(range(ttl_lens), int(ttl_lens * 0.1)) # 10% as test data
37 | tmp_inds = [ind for ind in range(ttl_lens) if ind not in test_inds]
38 | val_inds = random.sample(tmp_inds, int(ttl_lens * 0.2)) # 20% as validation data
39 | train_inds = [ind for ind in tmp_inds if ind not in val_inds] # 70% as training data
40 |
41 | trg_saver.store(trg_trajs[train_inds], os.path.join(train_data_dir, 'train_' + file_name))
42 | print("target traj train len: ", len(trg_trajs[train_inds]))
43 | trg_saver.store(trg_trajs[val_inds], os.path.join(val_data_dir, 'val_' + file_name))
44 | print("target traj val len: ", len(trg_trajs[val_inds]))
45 | trg_saver.store(trg_trajs[test_inds], os.path.join(test_data_dir, 'test_' + file_name))
46 | print("target traj test len: ", len(trg_trajs[test_inds]))
47 |
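# Usage sketch (directory names are illustrative): split every file under the raw
# trajectory directory into 70% train / 20% validation / 10% test files:
# >>> split_data('./data/raw_trajectory/', './data/model_data/')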
48 |
49 | class Dataset(torch.utils.data.Dataset):
50 | """
51 | customize a dataset for PyTorch
52 | """
53 |
54 | def __init__(self, trajs_dir, mbr, norm_grid_poi_dict, norm_grid_rnfea_dict, weather_dict, parameters, debug=True):
55 | self.mbr = mbr # MBR of all trajectories
56 | self.grid_size = parameters.grid_size
57 | self.time_span = parameters.time_span # time interval between two consecutive points.
58 | self.online_features_flag = parameters.online_features_flag
59 | self.src_grid_seqs, self.src_gps_seqs, self.src_pro_feas = [], [], []
60 | self.trg_gps_seqs, self.trg_rids, self.trg_rates = [], [], []
61 | self.new_tids = []
62 | # above should be [num_seq, len_seq(unpadded)]
63 | self.get_data(trajs_dir, norm_grid_poi_dict, norm_grid_rnfea_dict, weather_dict, parameters.win_size,
64 | parameters.ds_type, parameters.keep_ratio, debug)
65 | def __len__(self):
66 | """Denotes the total number of samples"""
67 | return len(self.src_grid_seqs)
68 |
69 | def __getitem__(self, index):
70 | """Generate one sample of data"""
71 | src_grid_seq = self.src_grid_seqs[index]
72 | src_gps_seq = self.src_gps_seqs[index]
73 | trg_gps_seq = self.trg_gps_seqs[index]
74 | trg_rid = self.trg_rids[index]
75 | trg_rate = self.trg_rates[index]
76 |
77 | src_grid_seq = self.add_token(src_grid_seq)
78 | src_gps_seq = self.add_token(src_gps_seq)
79 | trg_gps_seq = self.add_token(trg_gps_seq)
80 | trg_rid = self.add_token(trg_rid)
81 | trg_rate = self.add_token(trg_rate)
82 | src_pro_fea = torch.tensor(self.src_pro_feas[index])
83 |
84 | return src_grid_seq, src_gps_seq, src_pro_fea, trg_gps_seq, trg_rid, trg_rate
85 |
86 | def add_token(self, sequence):
87 | """
88 | Append start element(sos in NLP) for each sequence. And convert each list to tensor.
89 | """
90 | new_sequence = []
91 | dimension = len(sequence[0])
92 | start = [0] * dimension # pad 0 as start of rate sequence
93 | new_sequence.append(start)
94 | new_sequence.extend(sequence)
95 | new_sequence = torch.tensor(new_sequence)
96 | return new_sequence
97 |
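    # Example (illustrative): add_token prepends an all-zero start token, so a GPS
    # sequence [[36.65, 117.02], [36.66, 117.03]] becomes the tensor
    # [[0., 0.], [36.65, 117.02], [36.66, 117.03]] of shape (3, 2).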
98 | def get_data(self, trajs_dir, norm_grid_poi_dict, norm_grid_rnfea_dict,
99 | weather_dict, win_size, ds_type, keep_ratio, debug):
100 | parser = ParseMMTraj()
101 |
102 | if debug:
103 | trg_paths = os.listdir(trajs_dir)[:3]
104 | num = -1
105 | else:
106 | trg_paths = os.listdir(trajs_dir)
107 | num = -1
108 |
109 | for file_name in tqdm(trg_paths):
110 | trajs = parser.parse(os.path.join(trajs_dir, file_name))
111 | for traj in trajs[:num]:
112 | new_tid_ls, mm_gps_seq_ls, mm_eids_ls, mm_rates_ls, \
113 | ls_grid_seq_ls, ls_gps_seq_ls, features_ls = self.parse_traj(traj, norm_grid_poi_dict, norm_grid_rnfea_dict,
114 | weather_dict, win_size, ds_type, keep_ratio)
115 | if new_tid_ls is not None:
116 | self.new_tids.extend(new_tid_ls)
117 | self.trg_gps_seqs.extend(mm_gps_seq_ls)
118 | self.trg_rids.extend(mm_eids_ls)
119 | self.trg_rates.extend(mm_rates_ls)
120 | self.src_grid_seqs.extend(ls_grid_seq_ls)
121 | self.src_gps_seqs.extend(ls_gps_seq_ls)
122 | self.src_pro_feas.extend(features_ls)
123 | assert len(new_tid_ls) == len(mm_gps_seq_ls) == len(mm_eids_ls) == len(mm_rates_ls)
124 |
125 | assert len(self.new_tids) == len(self.trg_gps_seqs) == len(self.trg_rids) == len(self.trg_rates) == \
126 | len(self.src_gps_seqs) == len(self.src_grid_seqs) == len(self.src_pro_feas), \
127 |                'The number of source and target sequences must be equal.'
128 |
129 |
130 | def parse_traj(self, traj, norm_grid_poi_dict, norm_grid_rnfea_dict, weather_dict, win_size, ds_type, keep_ratio):
131 | """
132 | Split traj based on length.
133 | Preprocess ground truth (map-matched) Trajectory(), get gps sequence, rid list and rate list.
134 | Down sample original Trajectory(), get ls_gps, ls_grid sequence and profile features
135 | Args:
136 | -----
137 | traj:
138 | Trajectory()
139 | win_size:
140 |             window size (number of points) used to split a single high-sampling-rate trajectory
141 | ds_type:
142 | ['uniform', 'random']
143 |                 uniform: keep one GPS point every down_step elements,
144 |                 where down_step = 1 / keep_ratio
145 |                 random: randomly keep keep_ratio * len(old_traj) interior points in ascending order (endpoints always kept)
146 | keep_ratio:
147 | float. range in (0,1). The ratio that keep GPS points to total points.
148 | Returns:
149 | --------
150 | new_tid_ls, mm_gps_seq_ls, mm_eids_ls, mm_rates_ls, ls_grid_seq_ls, ls_gps_seq_ls, features_ls
151 | """
152 | new_trajs = self.get_win_trajs(traj, win_size)
153 |
154 | new_tid_ls = []
155 | mm_gps_seq_ls, mm_eids_ls, mm_rates_ls = [], [], []
156 | ls_grid_seq_ls, ls_gps_seq_ls, features_ls = [], [], []
157 |
158 | for tr in new_trajs:
159 | tmp_pt_list = tr.pt_list
160 | new_tid_ls.append(tr.tid)
161 |
162 | # get target sequence
163 | mm_gps_seq, mm_eids, mm_rates = self.get_trg_seq(tmp_pt_list)
164 | if mm_eids is None:
165 | return None, None, None, None, None, None, None
166 |
167 | # get source sequence
168 | ds_pt_list = self.downsample_traj(tmp_pt_list, ds_type, keep_ratio)
169 | ls_grid_seq, ls_gps_seq, hours, ttl_t = self.get_src_seq(ds_pt_list, norm_grid_poi_dict, norm_grid_rnfea_dict)
170 | features = self.get_pro_features(ds_pt_list, hours, weather_dict)
171 |
172 | # check if src and trg len equal, if not return none
173 | if len(mm_gps_seq) != ttl_t:
174 | return None, None, None, None, None, None, None
175 |
176 | mm_gps_seq_ls.append(mm_gps_seq)
177 | mm_eids_ls.append(mm_eids)
178 | mm_rates_ls.append(mm_rates)
179 | ls_grid_seq_ls.append(ls_grid_seq)
180 | ls_gps_seq_ls.append(ls_gps_seq)
181 | features_ls.append(features)
182 |
183 | return new_tid_ls, mm_gps_seq_ls, mm_eids_ls, mm_rates_ls, ls_grid_seq_ls, ls_gps_seq_ls, features_ls
184 |
185 |
186 | def get_win_trajs(self, traj, win_size):
187 | pt_list = traj.pt_list
188 | len_pt_list = len(pt_list)
189 | if len_pt_list < win_size:
190 | return [traj]
191 |
192 | num_win = len_pt_list // win_size
193 | last_traj_len = len_pt_list % win_size + 1
194 | new_trajs = []
195 |         for w in range(num_win + 1):
196 |             # if the last window is large enough then split it into its own trajectory
197 |             if w == num_win:
198 |                 if last_traj_len <= 15:
199 |                     break  # a short tail has already been merged into the previous window
200 |                 tmp_pt_list = pt_list[win_size * w - 1:]
201 |             # elif the last window is not large enough then merge it into the last trajectory
202 |             elif w == num_win - 1 and last_traj_len <= 15:
203 |                 # guard against a negative index when num_win = 1
204 |                 ind = max(0, win_size * w - 1)
205 |                 tmp_pt_list = pt_list[ind:]
206 |             # else split trajectories based on the window size
207 |             else:
208 |                 tmp_pt_list = pt_list[max(0, (win_size * w - 1)):win_size * (w + 1)]
209 | # -1 to make sure the overlap between two trajs
210 |
211 | new_traj = Trajectory(traj.oid, get_tid(traj.oid, tmp_pt_list), tmp_pt_list)
212 | new_trajs.append(new_traj)
213 | return new_trajs
214 |
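    # Example (illustrative): with win_size = 100 and a 430-point trajectory,
    # num_win = 4 and last_traj_len = 31 > 15, so the split yields windows
    # [0:100], [99:200], [199:300], [299:400] plus a tail trajectory [399:430];
    # consecutive windows overlap by one point.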
215 | def get_trg_seq(self, tmp_pt_list):
216 | mm_gps_seq = []
217 | mm_eids = []
218 | mm_rates = []
219 | for pt in tmp_pt_list:
220 | candi_pt = pt.data['candi_pt']
221 | if candi_pt is None:
222 | return None, None, None
223 | else:
224 | mm_gps_seq.append([candi_pt.lat, candi_pt.lng])
225 | mm_eids.append([candi_pt.eid]) # keep the same format as seq
226 | mm_rates.append([candi_pt.rate])
227 | return mm_gps_seq, mm_eids, mm_rates
228 |
229 |
230 | def get_src_seq(self, ds_pt_list, norm_grid_poi_dict, norm_grid_rnfea_dict):
231 | hours = []
232 | ls_grid_seq = []
233 | ls_gps_seq = []
234 | first_pt = ds_pt_list[0]
235 | last_pt = ds_pt_list[-1]
236 | time_interval = self.time_span
237 |         ttl_t = self.get_normalized_t(first_pt, last_pt, time_interval)
238 | for ds_pt in ds_pt_list:
239 | hours.append(ds_pt.time.hour)
240 |             t = self.get_normalized_t(first_pt, ds_pt, time_interval)
241 | ls_gps_seq.append([ds_pt.lat, ds_pt.lng])
242 | locgrid_xid, locgrid_yid = self.gps2grid(ds_pt, self.mbr, self.grid_size)
243 | if self.online_features_flag:
244 | poi_features = norm_grid_poi_dict[(locgrid_xid, locgrid_yid)]
245 | rn_features = norm_grid_rnfea_dict[(locgrid_xid, locgrid_yid)]
246 | ls_grid_seq.append([locgrid_xid, locgrid_yid, t]+poi_features+rn_features)
247 | else:
248 | ls_grid_seq.append([locgrid_xid, locgrid_yid, t])
249 |
250 | return ls_grid_seq, ls_gps_seq, hours, ttl_t
251 |
252 |
253 | def get_pro_features(self, ds_pt_list, hours, weather_dict):
254 | holiday = is_holiday(ds_pt_list[0].time)*1
255 | day = ds_pt_list[0].time.day
256 |         hour = {'hour': np.bincount(hours).argmax()}  # use the most frequent hour as the hour of the trajectory
257 | weather = {'weather': weather_dict[(day, hour['hour'])]}
258 | features = self.one_hot(hour) + self.one_hot(weather) + [holiday]
259 | return features
260 |
261 |
262 | def gps2grid(self, pt, mbr, grid_size):
263 | """
264 | mbr:
265 | MBR class.
266 | grid size:
267 | int. in meter
268 | """
269 | LAT_PER_METER = 8.993203677616966e-06
270 | LNG_PER_METER = 1.1700193970443768e-05
271 | lat_unit = LAT_PER_METER * grid_size
272 | lng_unit = LNG_PER_METER * grid_size
273 |
274 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1
275 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1
276 |
277 | lat = pt.lat
278 | lng = pt.lng
279 | locgrid_x = int((lat - mbr.min_lat) / lat_unit) + 1
280 | locgrid_y = int((lng - mbr.min_lng) / lng_unit) + 1
281 |
282 | return locgrid_x, locgrid_y
283 |
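    # Example (illustrative): with grid_size = 50 m, a point 120 m north and 260 m east
    # of the MBR's south-west corner falls into cell
    # (int(120 / 50) + 1, int(260 / 50) + 1) = (3, 6).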
284 |
285 |     def get_normalized_t(self, first_pt, current_pt, time_interval):
286 | """
287 | calculate normalized t from first and current pt
288 | return time index (normalized time)
289 | """
290 | t = int(1+((current_pt.time - first_pt.time).seconds/time_interval))
291 | return t
292 |
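    # Example (illustrative): with time_interval = 15 s, a point observed 45 s after
    # the first point gets time index t = int(1 + 45 / 15) = 4.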
293 | @staticmethod
294 | def get_distance(pt_list):
295 | dist = 0.0
296 | pre_pt = pt_list[0]
297 | for pt in pt_list[1:]:
298 | tmp_dist = distance(pre_pt, pt)
299 | dist += tmp_dist
300 | pre_pt = pt
301 | return dist
302 |
303 |
304 | @staticmethod
305 | def downsample_traj(pt_list, ds_type, keep_ratio):
306 | """
307 | Down sample trajectory
308 | Args:
309 | -----
310 | pt_list:
311 | list of Point()
312 | ds_type:
313 | ['uniform', 'random']
314 |                 uniform: keep one GPS point every down_step elements,
315 |                 where down_step = 1 / keep_ratio
316 |                 random: randomly keep keep_ratio * len(old_traj) interior points in ascending order (endpoints always kept)
317 | keep_ratio:
318 | float. range in (0,1). The ratio that keep GPS points to total points.
319 | Returns:
320 | -------
321 | traj:
322 | new Trajectory()
323 | """
324 | assert ds_type in ['uniform', 'random'], 'only `uniform` or `random` is supported'
325 |
326 | old_pt_list = pt_list.copy()
327 | start_pt = old_pt_list[0]
328 | end_pt = old_pt_list[-1]
329 |
330 | if ds_type == 'uniform':
331 | if (len(old_pt_list) - 1) % int(1 / keep_ratio) == 0:
332 | new_pt_list = old_pt_list[::int(1 / keep_ratio)]
333 | else:
334 | new_pt_list = old_pt_list[::int(1 / keep_ratio)] + [end_pt]
335 | elif ds_type == 'random':
336 | sampled_inds = sorted(
337 | random.sample(range(1, len(old_pt_list) - 1), int((len(old_pt_list) - 2) * keep_ratio)))
338 | new_pt_list = [start_pt] + list(np.array(old_pt_list)[sampled_inds]) + [end_pt]
339 |
340 | return new_pt_list
341 |
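    # Example (illustrative): with keep_ratio = 0.25, uniform downsampling keeps every
    # 4th point and appends the end point when it is not hit exactly:
    # >>> pts = list(range(10))
    # >>> pts[::int(1 / 0.25)] + [pts[-1]]
    # [0, 4, 8, 9]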
342 |
343 |
344 |
345 | @staticmethod
346 | def one_hot(data):
347 | one_hot_dict = {'hour': 24, 'weekday': 7, 'weather':5}
348 | for k, v in data.items():
349 | encoded_data = [0] * one_hot_dict[k]
350 | encoded_data[v - 1] = 1
351 | return encoded_data
352 |
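    # Example (illustrative): one_hot({'hour': 13}) returns a 24-dim vector with index 12
    # set to 1 (the encoding is shifted by one: encoded_data[v - 1] = 1).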
353 |
354 | # Use for DataLoader
355 | def collate_fn(data):
356 | """
357 | Reference: https://github.com/yunjey/seq2seq-dataloader/blob/master/data_loader.py
358 |     Creates mini-batch tensors from a list of tuples (src_grid_seq, src_gps_seq, src_pro_fea, trg_gps_seq, trg_rid, trg_rate).
359 | We should build a custom collate_fn rather than using default collate_fn,
360 | because merging sequences (including padding) is not supported in default.
361 | Sequences are padded to the maximum length of mini-batch sequences (dynamic padding).
362 | Args:
363 | -----
364 |     data: list of tuples (src_grid_seq, src_gps_seq, src_pro_fea, trg_gps_seq, trg_rid, trg_rate), from dataset.__getitem__().
365 |         - src_grid_seq: torch tensor of shape (?, 3); variable length.
366 |         - src_gps_seq: torch tensor of shape (?, 2); variable length.
367 |         - src_pro_fea: torch tensor of shape (feature_dim,), concatenation of all profile features.
368 |         - trg_gps_seq: torch tensor of shape (??, 2); variable length.
369 |         - trg_rid / trg_rate: torch tensors of shape (??, 1); variable length.
370 | Returns:
371 | --------
372 | src_grid_seqs:
373 | torch tensor of shape (batch_size, padded_length, 3)
374 | src_gps_seqs:
375 |             torch tensor of shape (batch_size, padded_length, 2).
376 | src_pro_feas:
377 |             torch tensor of shape (batch_size, feature_dim); no padding needed.
378 | src_lengths:
379 | list of length (batch_size); valid length for each padded source sequence.
380 | trg_seqs:
381 | torch tensor of shape (batch_size, padded_length, 2).
382 | trg_rids:
383 | torch tensor of shape (batch_size, padded_length, 1).
384 | trg_rates:
385 | torch tensor of shape (batch_size, padded_length, 1).
386 | trg_lengths:
387 | list of length (batch_size); valid length for each padded target sequence.
388 | """
389 |
390 | def merge(sequences):
391 | lengths = [len(seq) for seq in sequences]
392 | dim = sequences[0].size(1) # get dim for each sequence
393 | padded_seqs = torch.zeros(len(sequences), max(lengths), dim)
394 | for i, seq in enumerate(sequences):
395 | end = lengths[i]
396 | padded_seqs[i, :end] = seq[:end]
397 | return padded_seqs, lengths
398 |
399 | # sort a list by source sequence length (descending order) to use pack_padded_sequence
400 | data.sort(key=lambda x: len(x[0]), reverse=True)
401 |
402 |     # separate source and target sequences
403 | src_grid_seqs, src_gps_seqs, src_pro_feas, trg_gps_seqs, trg_rids, trg_rates = zip(*data) # unzip data
404 |
405 | # merge sequences (from tuple of 1D tensor to 2D tensor)
406 | src_grid_seqs, src_lengths = merge(src_grid_seqs)
407 | src_gps_seqs, src_lengths = merge(src_gps_seqs)
408 | src_pro_feas = torch.tensor([list(src_pro_fea) for src_pro_fea in src_pro_feas])
409 | trg_gps_seqs, trg_lengths = merge(trg_gps_seqs)
410 | trg_rids, _ = merge(trg_rids)
411 | trg_rates, _ = merge(trg_rates)
412 |
413 | return src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths
414 |
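# Usage sketch (assumed, mirrors the standard PyTorch pattern): pass collate_fn to the
# DataLoader so that variable-length sequences are padded per mini-batch:
# >>> from torch.utils.data import DataLoader
# >>> loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=collate_fn)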
415 |
--------------------------------------------------------------------------------
/models/loss_fn.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/11/25 18:06
4 |
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from map_matching.candidate_point import CandidatePoint
12 | from map_matching.utils import find_shortest_path
13 | from common.spatial_func import SPoint, distance
14 |
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 |
17 | def check_dis_loss(predict, target, trg_len):
18 | """
19 |     Calculate MAE and RMSE between the predicted and target GPS sequences.
20 | Args:
21 | -----
22 | predict = [seq len, batch size, 2]
23 | target = [seq len, batch size, 2]
24 |         trg_len = [batch size]; if target length is not considered, the loss will be smaller than the real one.
25 |         sos has been removed from both predict and target.
26 | Returns:
27 | -------
28 | MAE of a batch in meter.
29 | RMSE of a batch in meter.
30 | """
31 | predict = predict.permute(1, 0, 2) # [batch size, seq len, 2]
32 | target = target.permute(1, 0, 2) # [batch size, seq len, 2]
33 | bs = predict.size(0)
34 |
35 | ls_dis = []
36 | for bs_i in range(bs):
37 | for len_i in range(trg_len[bs_i]-1):
38 | pre = SPoint(predict[bs_i, len_i][0], predict[bs_i, len_i][1])
39 | trg = SPoint(target[bs_i, len_i][0], target[bs_i, len_i][1])
40 | dis = distance(pre, trg)
41 | ls_dis.append(dis)
42 |
43 | ls_dis = np.array(ls_dis)
44 | mae = ls_dis.mean()
45 | rmse = np.sqrt((ls_dis**2).mean())
46 | return mae, rmse
47 |
48 |
49 | def check_rn_dis_loss(predict_gps, predict_id, predict_rate, target_gps, target_id, target_rate, trg_len,
50 | rn, raw_rn_dict, new2raw_rid_dict):
51 | """
52 |     Calculate road network based MAE and RMSE between the predicted and target GPS sequences.
53 | Args:
54 | -----
55 | predict_gps = [seq len, batch size, 2]
56 | predict_id = [seq len, batch size, id one hot output dim]
57 | predict_rates = [seq len, batch size]
58 | target_gps = [seq len, batch size, 2]
59 | target_id = [seq len, batch size]
60 | target_rates = [seq len, batch size]
61 |         trg_len = [batch size]; if target length is not considered, the loss will be smaller than the real one.
62 |
63 |         sos has been removed from both predict and target.
64 | Returns:
65 | -------
66 | MAE of a batch in meter.
67 | RMSE of a batch in meter.
68 | """
69 | seq_len = target_id.size(0)
70 | batch_size = target_id.size(1)
71 | predict_gps = predict_gps.permute(1, 0, 2)
72 | predict_id = predict_id.permute(1, 0, 2)
73 | predict_rate = predict_rate.permute(1, 0)
74 | target_gps = target_gps.permute(1, 0, 2)
75 | target_id = target_id.permute(1, 0)
76 | target_rate = target_rate.permute(1, 0)
77 |
78 | ls_dis, rn_ls_dis = [], []
79 | for bs in range(batch_size):
80 | for len_i in range(trg_len[bs]-1): # don't calculate padding points
81 | pre_rid = predict_id[bs, len_i].argmax()
82 | convert_pre_rid = new2raw_rid_dict[pre_rid.tolist()]
83 | pre_rate = predict_rate[bs, len_i]
84 | pre_offset = raw_rn_dict[convert_pre_rid]['length'] * pre_rate
85 | pre_candi_pt = CandidatePoint(predict_gps[bs,len_i][0], predict_gps[bs,len_i][1],
86 | convert_pre_rid, 0, pre_offset, pre_rate)
87 |
88 | trg_rid = target_id[bs, len_i]
89 | convert_trg_rid = new2raw_rid_dict[trg_rid.tolist()]
90 | trg_rate = target_rate[bs, len_i]
91 | trg_offset = raw_rn_dict[convert_trg_rid]['length'] * trg_rate
92 | trg_candi_pt = CandidatePoint(target_gps[bs,len_i][0], target_gps[bs,len_i][1],
93 | convert_trg_rid, 0, trg_offset, trg_rate)
94 |
95 | if pre_candi_pt.lat == trg_candi_pt.lat and pre_candi_pt.lng == trg_candi_pt.lng:
96 | rn_dis = 0
97 | dis = 0
98 | else:
99 |                     rn_dis, _ = min(find_shortest_path(rn, pre_candi_pt, trg_candi_pt),
100 |                                     find_shortest_path(rn, trg_candi_pt, pre_candi_pt),
101 |                                     key=lambda t: t[0])  # compare by distance only; paths (list vs None) are not comparable
101 | if type(rn_dis) is not float:
102 | rn_dis = rn_dis.tolist()
103 | dis = distance(pre_candi_pt, trg_candi_pt)
104 |
105 | if rn_dis == np.inf:
106 | rn_dis = 1000
107 | rn_ls_dis.append(rn_dis)
108 | ls_dis.append(dis)
109 |
110 | ls_dis = np.array(ls_dis)
111 | rn_ls_dis = np.array(rn_ls_dis)
112 |
113 | mae = ls_dis.mean()
114 | rmse = np.sqrt((ls_dis**2).mean())
115 | rn_mae = rn_ls_dis.mean()
116 | rn_rmse = np.sqrt((rn_ls_dis**2).mean())
117 | return mae, rmse, rn_mae, rn_rmse
118 |
119 |
120 | def shrink_seq(seq):
121 | """remove repeated ids"""
122 | s0 = seq[0]
123 | new_seq = [s0]
124 | for s in seq[1:]:
125 | if s == s0:
126 | continue
127 | else:
128 | new_seq.append(s)
129 | s0 = s
130 |
131 | return new_seq
132 |
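# Example (illustrative): consecutive duplicate road ids are collapsed before the
# LCS-based metrics below are computed:
# >>> shrink_seq([3, 3, 5, 5, 5, 2, 3])
# [3, 5, 2, 3]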
133 | def memoize(fn):
134 | '''Return a memoized version of the input function.
135 |
136 | The returned function caches the results of previous calls.
137 | Useful if a function call is expensive, and the function
138 | is called repeatedly with the same arguments.
139 | '''
140 | cache = dict()
141 | def wrapped(*v):
142 | key = tuple(v) # tuples are hashable, and can be used as dict keys
143 | if key not in cache:
144 | cache[key] = fn(*v)
145 | return cache[key]
146 | return wrapped
147 |
148 | def lcs(xs, ys):
149 | '''Return the longest subsequence common to xs and ys.
150 |
151 | Example
152 | >>> lcs("HUMAN", "CHIMPANZEE")
153 | ['H', 'M', 'A', 'N']
154 | '''
155 | @memoize
156 | def lcs_(i, j):
157 | if i and j:
158 | xe, ye = xs[i-1], ys[j-1]
159 | if xe == ye:
160 | return lcs_(i-1, j-1) + [xe]
161 | else:
162 | return max(lcs_(i, j-1), lcs_(i-1, j), key=len)
163 | else:
164 | return []
165 | return lcs_(len(xs), len(ys))
166 |
167 |
168 | def cal_id_acc(predict, target, trg_len):
169 | """
170 | Calculate RID accuracy between predicted and targeted RID sequence.
171 | 1. no repeated rid for two consecutive road segments
172 | 2. longest common subsequence
173 | http://wordaligned.org/articles/longest-common-subsequence
174 | Args:
175 | -----
176 | predict = [seq len, batch size, id one hot output dim]
177 | target = [seq len, batch size, 1]
178 |     sos has been removed from both predict and target.
179 | Returns:
180 | -------
181 | mean matched RID accuracy.
182 | """
183 | predict = predict.permute(1, 0, 2) # [batch size, seq len, id dim]
184 | target = target.permute(1, 0) # [batch size, seq len, 1]
185 | bs = predict.size(0)
186 |
187 | correct_id_num = 0
188 | ttl_trg_id_num = 0
189 | ttl_pre_id_num = 0
190 | ttl = 0
191 | cnt = 0
192 | for bs_i in range(bs):
193 | pre_ids = []
194 | trg_ids = []
195 |         # -1 because sos has been removed from predict and target.
196 | for len_i in range(trg_len[bs_i] - 1):
197 | pre_id = predict[bs_i][len_i].argmax()
198 | trg_id = target[bs_i][len_i]
199 | pre_ids.append(pre_id)
200 | trg_ids.append(trg_id)
201 | if pre_id == trg_id:
202 | cnt += 1
203 | ttl += 1
204 |
205 | # compute average rid accuracy
206 | shr_trg_ids = shrink_seq(trg_ids)
207 | shr_pre_ids = shrink_seq(pre_ids)
208 | correct_id_num += len(lcs(shr_trg_ids, shr_pre_ids))
209 | ttl_trg_id_num += len(shr_trg_ids)
210 | ttl_pre_id_num += len(shr_pre_ids)
211 |
212 | rid_acc = cnt / ttl
213 | rid_recall = correct_id_num / ttl_trg_id_num
214 | rid_precision = correct_id_num / ttl_pre_id_num
215 | return rid_acc, rid_recall, rid_precision
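# Note (illustrative): rid_acc is point-wise accuracy, while rid_recall and rid_precision
# compare the de-duplicated sequences via their longest common subsequence, e.g.
# lcs([3, 5, 2], [3, 2]) = [3, 2] gives recall 2/3 and precision 2/2.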
--------------------------------------------------------------------------------
/models/model_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/10/20 18:24
4 |
5 |
6 | import torch
7 | import math
8 | import numpy as np
9 | import pandas as pd
10 | import networkx as nx
11 |
12 | from common.spatial_func import distance, cal_loc_along_line, SPoint
13 | from map_matching.candidate_point import get_candidates, CandidatePoint
14 |
15 | from utils.utils import load_json_data
16 |
17 |
18 | #####################################################################################################
19 | #
20 | # Load Files
21 | #
22 | #####################################################################################################
23 |
24 | def load_rid_freqs(dir, file_name):
25 | """
26 | load rid freqs and convert key from str to int
27 | """
28 | rid_freqs = load_json_data(dir, file_name)
29 | rid_freqs = {int(k): int(v) for k, v in rid_freqs.items()} # convert key from str to int
30 |
31 | return rid_freqs
32 |
33 |
34 | def load_rn_dict(dir, file_name):
35 | """
36 |     This function will be used in rate2gps().
37 | """
38 | rn_dict = load_json_data(dir, file_name)
39 | new_rn_dict = {}
40 | for k, v in rn_dict.items():
41 | new_rn_dict[int(k)] = {}
42 | new_rn_dict[int(k)]['coords'] = [SPoint(coord[0], coord[1]) for coord in v['coords']]
43 | # convert str to SPoint() to calculate distance
44 | new_rn_dict[int(k)]['length'] = v['length']
45 | new_rn_dict[int(k)]['level'] = v['level']
46 | del rn_dict
47 | return new_rn_dict
48 |
49 | def load_online_features(dir, file_name):
50 | """
51 |     load POI or road network features and convert each key from str to tuple
52 | """
53 | data = load_json_data(dir, file_name)
54 |     data = {eval(k): v for k, v in data.items()}  # assumed fix: convert str keys like "(3, 6)" back to tuples
55 |
56 | return data
57 |
58 | #####################################################################################################
59 | #
60 | # RID + Rate 2 GPS
61 | #
62 | #####################################################################################################
63 | def rate2gps(rn_dict, eid, rate, parameters):
64 | """
65 | Convert road rate to GPS on the road segment.
66 | Since one road contains several coordinates, iteratively computing length can be more accurate.
67 | Args:
68 | -----
69 | rn_dict:
70 | dictionary of road network
71 | eid,rate:
72 | single value from model prediction
73 | Returns:
74 | --------
75 | project_pt:
76 | projected GPS point on the road segment.
77 | """
78 | eid = eid.tolist() # convert tensor to normal value
79 | rate = rate.tolist()
80 | if eid <= 0 or rate < 0 or eid > (parameters.id_size-1) or rate > 1:
81 | # force eid and rate in the right range
82 | return SPoint(0, 0)
83 |
84 | coords = rn_dict[eid]['coords']
85 | offset = rn_dict[eid]['length'] * rate
86 | dist = 0 # temp distance for coords
87 | pre_dist = 0 # coords distance is smaller than offset
88 |
89 | if rate == 1.0:
90 | return coords[-1]
91 | if rate == 0.0:
92 | return coords[0]
93 |     project_pt = coords[-1]  # fallback in case floating-point error keeps dist below offset
94 | for i in range(len(coords) - 1):
95 | if i > 0:
96 | pre_dist += distance(coords[i - 1], coords[i])
97 | dist += distance(coords[i], coords[i + 1])
98 | if dist >= offset:
99 | coor_rate = (offset - pre_dist) / distance(coords[i], coords[i + 1])
100 | project_pt = cal_loc_along_line(coords[i], coords[i + 1], coor_rate)
101 | break
102 |
103 | return project_pt
104 |
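# Example (illustrative): for a straight 200 m segment with coords [A, B] and
# rate = 0.25, offset = 50 m, so rate2gps returns the point a quarter of the way
# from A to B, i.e. cal_loc_along_line(A, B, 50 / 200).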
105 |
106 | def toseq(rn_dict, rids, rates, parameters):
107 | """
108 | Convert batched rids and rates to gps sequence.
109 | Args:
110 | -----
111 | rn_dict:
112 | use for rate2gps()
113 | rids:
114 | [trg len, batch size, id one hot dim]
115 | rates:
116 | [trg len, batch size]
117 | Returns:
118 | --------
119 | seqs:
120 | [trg len, batch size, 2]
121 | """
122 | batch_size = rids.size(1)
123 | trg_len = rids.size(0)
124 |     seqs = torch.zeros(trg_len, batch_size, 2).to(parameters.device)
125 |
126 | for i in range(1, trg_len):
127 | for bs in range(batch_size):
128 | rid = rids[i][bs].argmax()
129 | rate = rates[i][bs]
130 |             pt = rate2gps(rn_dict, rid, rate, parameters)
131 | seqs[i][bs][0] = pt.lat
132 | seqs[i][bs][1] = pt.lng
133 | return seqs
134 |
135 |
136 | #####################################################################################################
137 | #
138 | # Constraint mask
139 | #
140 | #####################################################################################################
141 | def get_rid_grid(mbr, grid_size, rn_dict):
142 | """
143 |     Create a dict {key: grid id, value: list of rids that cover the grid}
144 | """
145 | LAT_PER_METER = 8.993203677616966e-06
146 | LNG_PER_METER = 1.1700193970443768e-05
147 | lat_unit = LAT_PER_METER * grid_size
148 | lng_unit = LNG_PER_METER * grid_size
149 |
150 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1
151 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1
152 |
153 | grid_rn_dict = {}
154 | for k, v in rn_dict.items():
155 | pre_lat = v['coords'][0].lat
156 | pre_lng = v['coords'][0].lng
157 | pre_locgrid_x = max(1, int((pre_lat - mbr.min_lat) / lat_unit) + 1)
158 | pre_locgrid_y = max(1, int((pre_lng - mbr.min_lng) / lng_unit) + 1)
159 |
160 |
161 | if (pre_locgrid_x, pre_locgrid_y) not in grid_rn_dict.keys():
162 | grid_rn_dict[(pre_locgrid_x, pre_locgrid_y)] = [k]
163 | else:
164 | grid_rn_dict[(pre_locgrid_x, pre_locgrid_y)].append(k)
165 |
166 | for coord in v['coords'][1:]:
167 | lat = coord.lat
168 | lng = coord.lng
169 | locgrid_x = max(1, int((lat - mbr.min_lat) / lat_unit) + 1)
170 | locgrid_y = max(1, int((lng - mbr.min_lng) / lng_unit) + 1)
171 |
172 | if (locgrid_x, locgrid_y) not in grid_rn_dict.keys():
173 | grid_rn_dict[(locgrid_x, locgrid_y)] = [k]
174 | else:
175 | grid_rn_dict[(locgrid_x, locgrid_y)].append(k)
176 |
177 | mid_x_num = abs(locgrid_x - pre_locgrid_x)
178 | mid_y_num = abs(locgrid_y - pre_locgrid_y)
179 |
180 | if mid_x_num > 1 and mid_y_num <= 1:
181 | for mid_x in range(1, mid_x_num):
182 | if (min(pre_locgrid_x,locgrid_x)+mid_x, locgrid_y) not in grid_rn_dict.keys():
183 | grid_rn_dict[(min(pre_locgrid_x,locgrid_x)+mid_x, locgrid_y)] = [k]
184 | else:
185 | grid_rn_dict[(min(pre_locgrid_x,locgrid_x)+mid_x, locgrid_y)].append(k)
186 |
187 | elif mid_x_num <= 1 and mid_y_num > 1:
188 | for mid_y in range(1, mid_y_num):
189 | if (locgrid_x, min(pre_locgrid_y,locgrid_y)+mid_y) not in grid_rn_dict.keys():
190 | grid_rn_dict[(locgrid_x, min(pre_locgrid_y,locgrid_y)+mid_y)] = [k]
191 | else:
192 | grid_rn_dict[(locgrid_x, min(pre_locgrid_y,locgrid_y)+mid_y)].append(k)
193 |
194 |             elif mid_x_num > 1 and mid_y_num > 1:
195 |                 # assumed completion of an unfinished branch: register the grid cells a diagonal segment crosses
196 |                 ttl_num = mid_x_num + mid_y_num + 1
197 |                 for mid in range(1, ttl_num):
198 |                     mid_lat, mid_lng = min(lat, pre_lat) + mid*abs(lat - pre_lat)/ttl_num, min(lng, pre_lng) + mid*abs(lng - pre_lng)/ttl_num
199 |                     grid_rn_dict.setdefault((max(1, int((mid_lat - mbr.min_lat)/lat_unit) + 1), max(1, int((mid_lng - mbr.min_lng)/lng_unit) + 1)), []).append(k)
200 | pre_lat = lat
201 | pre_lng = lng
202 | pre_locgrid_x = locgrid_x
203 | pre_locgrid_y = locgrid_y
204 |
205 | for k, v in grid_rn_dict.items():
206 | grid_rn_dict[k] = list(set(v))
207 |
208 | return grid_rn_dict, max_xid, max_yid
209 |
210 | def exp_prob(beta, x):
211 | """
212 | error distance weight.
213 | """
214 | return math.exp(-pow(x,2)/pow(beta,2))
215 |
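# Example (illustrative): exp_prob is a Gaussian-like kernel exp(-x^2 / beta^2);
# with beta = 15, an error distance of 15 m gets weight exp(-1) ~ 0.37 and
# 30 m gets exp(-4) ~ 0.018.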
216 | def get_reachable_inds(pre_grid, cur_grid, grid_rn_dict,time_diff, parameters):
217 | reachable_inds = list(range(parameters.id_size))
218 |
219 | return reachable_inds
220 |
221 | def get_dis_prob_vec(gps, rn, raw2new_rid_dict, parameters):
222 | """
223 | Args:
224 | -----
225 | gps: [SPoint, tid]
226 | """
227 | cons_vec = torch.zeros(parameters.id_size) + 1e-10
228 | candis = get_candidates(gps[0], rn, parameters.search_dist)
229 | if candis is not None:
230 | for candi_pt in candis:
231 | if candi_pt.eid in raw2new_rid_dict.keys():
232 | new_rid = raw2new_rid_dict[candi_pt.eid]
233 | prob = exp_prob(parameters.beta, candi_pt.error)
234 | cons_vec[new_rid] = prob
235 | else:
236 | cons_vec = torch.ones(parameters.id_size)
237 | return cons_vec
238 |
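# Note (illustrative): get_dis_prob_vec turns the candidate road segments around one GPS
# point into a soft constraint vector over all rid classes: each candidate gets weight
# exp_prob(beta, error) and every other rid keeps a small epsilon (1e-10).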
239 | def get_constraint_mask(src_grid_seqs, src_gps_seqs, src_lengths, trg_lengths, grid_rn_dict, rn, raw2new_rid_dict, parameters):
240 | max_trg_len = max(trg_lengths)
241 | batch_size = src_grid_seqs.size(0)
242 |
243 | constraint_mat = torch.zeros(batch_size, max_trg_len, parameters.id_size) + 1e-10
244 | pre_grids = torch.zeros(batch_size, max_trg_len, 3)
245 | cur_grids = torch.zeros(batch_size, max_trg_len, 3)
246 |
247 | for bs in range(batch_size):
248 | # first src gps
249 | pre_t = 1
250 | pre_grid = [int(src_grid_seqs[bs][pre_t][0].tolist()),
251 | int(src_grid_seqs[bs][pre_t][1].tolist()),
252 | pre_t]
253 | pre_gps = [SPoint(src_gps_seqs[bs][pre_t][0].tolist(),
254 | src_gps_seqs[bs][pre_t][1].tolist()),
255 | pre_t]
256 | pre_grids[bs, pre_t] = torch.tensor(pre_grid)
257 | cur_grids[bs, pre_t] = torch.tensor(pre_grid)
258 |
259 | if parameters.dis_prob_mask_flag:
260 | cons_vec = get_dis_prob_vec(pre_gps, rn, raw2new_rid_dict, parameters)
261 | constraint_mat[bs][pre_t] = cons_vec
262 | else:
263 | reachable_inds = get_reachable_inds(pre_grid, pre_grid, grid_rn_dict, 0, parameters)
264 | constraint_mat[bs][pre_t][reachable_inds] = 1
265 |
266 | # missed gps
267 | for i in range(2, src_lengths[bs]):
268 | cur_t = int(src_grid_seqs[bs,i,2].tolist())
269 | cur_grid = [int(src_grid_seqs[bs][i][0].tolist()),
270 | int(src_grid_seqs[bs][i][1].tolist()),
271 | cur_t]
272 | cur_gps = [SPoint(src_gps_seqs[bs][i][0].tolist(),
273 | src_gps_seqs[bs][i][1].tolist()),
274 | cur_t]
275 | pre_grids[bs, cur_t] = torch.tensor(cur_grid)
276 | cur_grids[bs, cur_t] = torch.tensor(cur_grid)
277 |
278 | time_diff = cur_t - pre_t
279 | reachable_inds = get_reachable_inds(pre_grid, cur_grid, grid_rn_dict, time_diff, parameters)
280 |
281 | for t in range(pre_t+1, cur_t):
282 | constraint_mat[bs][t][reachable_inds] = 1
283 | pre_grids[bs, t] = torch.tensor(pre_grid)
284 | cur_grids[bs, t] = torch.tensor(cur_grid)
285 |
286 | # middle src gps
287 | if parameters.dis_prob_mask_flag:
288 | cons_vec = get_dis_prob_vec(cur_gps, rn, raw2new_rid_dict, parameters)
289 | constraint_mat[bs][cur_t] = cons_vec
290 | else:
291 | reachable_inds = get_reachable_inds(cur_grid, cur_grid, grid_rn_dict, 0, parameters)
292 | constraint_mat[bs][cur_t][reachable_inds] = 1
293 |
294 | pre_t = cur_t
295 | pre_grid = cur_grid
296 | pre_gps = cur_gps
297 |
298 | return constraint_mat, pre_grids, cur_grids
299 |
300 | #####################################################################################################
301 | #
302 | # Use for extracting POI features
303 | #
304 | #####################################################################################################
305 | def get_poi_info(grid_poi_df, parameters):
306 | """
307 | ['company','food', 'gym', 'education','shopping','gov', 'viewpoint','entrance','house','life',
308 | 'traffic','car','hotel','beauty','hospital','media','finance','entertainment','road','nature','landmark','address']
309 | """
310 | types = parameters.poi_type.split(',')
311 | norm_grid_poi_df=(grid_poi_df[types]-grid_poi_df[types].min())/(grid_poi_df[types].max()-grid_poi_df[types].min())
312 | norm_grid_poi_df = norm_grid_poi_df.fillna(0)
313 |
314 | norm_grid_poi_dict = {}
315 | for i in range(len(norm_grid_poi_df)):
316 | k = norm_grid_poi_df.index[i]
317 | v = norm_grid_poi_df.iloc[i].values
318 | norm_grid_poi_dict[k] = list(v)
319 |
320 | for xid in range(1, parameters.max_xid+1):
321 | for yid in range(1, parameters.max_yid+1):
322 | if (xid,yid) not in norm_grid_poi_dict.keys():
323 | norm_grid_poi_dict[(xid,yid)] = [0.] * len(types)
324 | return norm_grid_poi_dict
325 |
326 | #####################################################################################################
327 | #
328 | # Use for extracting RN features
329 | #
330 | #####################################################################################################
331 | def get_edge_results(eids, rn_dict):
332 | edge_results = []
333 | for eid in eids:
334 | u = rn_dict[eid]['coords'][0]
335 | v = rn_dict[eid]['coords'][-1]
336 | edge_results.append(((u.lng,u.lat),(v.lng,v.lat)))
337 | return edge_results
338 |
339 | def extract_single_rn_features(edge_results, rn):
340 | part_g = nx.Graph()
341 | for u, v in edge_results:
342 | part_g.add_edge(u, v, **rn[u][v])
343 |
344 | tot_length = 0.0
345 | level_2_cnt = 0
346 | level_3_cnt = 0
347 | level_4_cnt = 0
348 | for u, v, data in part_g.edges(data=True):
349 | tot_length += data['length']
350 | if data['highway'] == 'trunk':
351 | level_2_cnt += 1
352 | elif data['highway'] == 'primary':
353 | level_3_cnt += 1
354 | elif data['highway'] == 'secondary':
355 | level_4_cnt += 1
356 | nb_intersections = 0
357 | for node, degree in part_g.degree():
358 | if degree > 2:
359 | nb_intersections += 1
360 |
361 | rn_features = np.array([tot_length, nb_intersections, level_2_cnt, level_3_cnt, level_4_cnt])
362 |
363 | return rn_features
364 |
365 | def get_rn_info(rn, mbr, grid_size, grid_rn_dict, rn_dict):
366 | """
367 | rn_dict contains rn information
368 | """
369 | LAT_PER_METER = 8.993203677616966e-06
370 | LNG_PER_METER = 1.1700193970443768e-05
371 | lat_unit = LAT_PER_METER * grid_size
372 | lng_unit = LNG_PER_METER * grid_size
373 |
374 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1
375 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1
376 |
377 | grid_rnfea_dict = {}
378 | for k,v in grid_rn_dict.items():
379 | eids = grid_rn_dict[k]
380 | edge_results = get_edge_results(eids, rn_dict)
381 | grid_rnfea_dict[k] = extract_single_rn_features(edge_results, rn)
382 |
383 | grid_rnfea_df = pd.DataFrame(grid_rnfea_dict).T
384 | norm_grid_rnfea_df=(grid_rnfea_df-grid_rnfea_df.min())/(grid_rnfea_df.max()-grid_rnfea_df.min()) # col norm
385 |
386 | norm_grid_rnfea_dict = {}
387 | for i in range(len(norm_grid_rnfea_df)):
388 | k = norm_grid_rnfea_df.index[i]
389 | v = norm_grid_rnfea_df.iloc[i].values
390 | norm_grid_rnfea_dict[k] = list(v)
391 |
392 | for xid in range(1, max_xid+1):
393 | for yid in range(1, max_yid+1):
394 | if (xid,yid) not in norm_grid_rnfea_dict.keys():
395 | norm_grid_rnfea_dict[(xid,yid)] = [0.] * len(v)
396 |
397 | return norm_grid_rnfea_dict
398 |
399 | def get_rid_rnfea_dict(rn_dict, parameters):
400 | df = pd.DataFrame(rn_dict).T
401 |
402 | # standardization length
403 | df['norm_len'] = [np.log10(l) /np.log10(df['length'].max()) for l in df['length']]
404 | # df['norm_len'] = (df['length'] - df['length'].mean())/df['length'].std()
405 |
406 | # one hot road level
407 | one_hot_df = pd.get_dummies(df.level, prefix='level')
408 | df = df.join(one_hot_df)
409 |
410 | # get number of neighbours (standardization)
411 | g = nx.Graph()
412 | edges = []
413 | for coords in df['coords'].values:
414 | start_node = (coords[0].lat, coords[0].lng)
415 | end_node = (coords[-1].lat, coords[-1].lng)
416 | edges.append((start_node, end_node))
417 | g.add_edges_from(edges)
418 |
419 | num_start_neighbors = []
420 | num_end_neighbors = []
421 | for coords in df['coords'].values:
422 | start_node = (coords[0].lat, coords[0].lng)
423 | end_node = (coords[-1].lat, coords[-1].lng)
424 | num_start_neighbors.append(len(list(g.edges(start_node))))
425 | num_end_neighbors.append(len(list(g.edges(end_node))))
426 | df['num_start_neighbors'] = num_start_neighbors
427 | df['num_end_neighbors'] = num_end_neighbors
428 | start = df['num_start_neighbors']
429 | end = df['num_end_neighbors']
430 | # distribution is like gaussian --> use min max normalization
431 | df['norm_num_start_neighbors'] = (start - start.min())/(start.max() - start.min())
432 | df['norm_num_end_neighbors'] = (end - end.min())/(end.max() - end.min())
433 |
434 | # convert to dict
435 | norm_rid_rnfea_dict = {}
436 | for i in range(len(df)):
437 | k = df.index[i]
438 | v = df.iloc[i][['norm_len','level_2','level_3','level_4','level_5','level_6',\
439 | 'norm_num_start_neighbors','norm_num_end_neighbors']]
440 | norm_rid_rnfea_dict[k] = list(v)
441 |
442 |     norm_rid_rnfea_dict[0] = [0.]*len(list(v))  # add sos (padding id 0)
443 | return norm_rid_rnfea_dict
444 |
445 | #####################################################################################################
446 | #
447 | # Use for online features
448 | #
449 | #####################################################################################################
450 | def get_rid_grid_dict(grid_rn_dict):
451 | rid_grid_dict = {}
452 | for k, v in grid_rn_dict.items():
453 | for rid in v:
454 | if rid not in rid_grid_dict:
455 | rid_grid_dict[rid] = [k]
456 | else:
457 | rid_grid_dict[rid].append(k)
458 |
459 | for k,v in rid_grid_dict.items():
460 | rid_grid_dict[k] = list(set(v))
461 |
462 | return rid_grid_dict
463 |
464 | def get_online_info_dict(grid_rn_dict, norm_grid_poi_dict, norm_grid_rnfea_dict, parameters):
465 | rid_grid_dict = get_rid_grid_dict(grid_rn_dict)
466 | online_features_dict = {}
467 | for rid in rid_grid_dict.keys():
468 | online_feas = []
469 | for grid in rid_grid_dict[rid]:
470 |             try:
471 |                 poi = norm_grid_poi_dict[grid]
472 |             except KeyError:
473 |                 poi = [0.]*5
474 |             try:
475 |                 rnfea = norm_grid_rnfea_dict[grid]
476 |             except KeyError:
477 |                 rnfea = [0.]*5
478 | online_feas.append(poi + rnfea)
479 |
480 | online_feas = np.array(online_feas)
481 | online_features_dict[rid] = list(online_feas.mean(axis=0))
482 |
483 |     online_features_dict[0] = [0.]*online_feas.shape[1]  # add sos (padding id 0)
484 |
485 | return online_features_dict
486 |
487 | def get_dict_info_batch(input_id, features_dict):
488 | """
489 | batched dict info
490 | """
491 | # input_id = [1, batch size]
492 | features = []
493 | for rid in input_id.squeeze(1):
494 | features.append(features_dict[rid.cpu().tolist()])
495 |
496 | features = torch.tensor(features).float()
497 | # features = [1, batch size, features dim]
498 | return features
499 |
500 | #####################################################################################################
501 | #
502 | # Use for visualization
503 | #
504 | #####################################################################################################
505 | def get_plot_seq(raw_input, predict, target, src_len, trg_len):
506 | """
507 | Get input, prediction and ground truth GPS sequence.
508 | raw_input, predict, target = [seq len, batch size, 2] and the sos is not removed.
509 | """
510 | raw_input = raw_input[1:].permute(1, 0, 2)
511 | predict = predict[1:].permute(1, 0, 2) # [batch size, seq len, 2]
512 | target = target[1:].permute(1, 0, 2) # [batch size, seq len, 2]
513 |
514 | bs = predict.size(0)
515 |
516 | ls_pre_seq, ls_trg_seq, ls_input_seq =[], [], []
517 | for bs_i in range(bs):
518 | pre_seq = []
519 | trg_seq = []
520 | for len_i in range(trg_len[bs_i]-1):
521 | pre_seq.append([predict[bs_i, len_i][0].cpu().data.tolist(), predict[bs_i, len_i][1].cpu().data.tolist()])
522 | trg_seq.append([target[bs_i, len_i][0].cpu().data.tolist(), target[bs_i, len_i][1].cpu().data.tolist()])
523 | input_seq = []
524 | for len_i in range(src_len[bs_i]-1):
525 | input_seq.append([raw_input[bs_i, len_i][0].cpu().data.tolist(), raw_input[bs_i, len_i][1].cpu().data.tolist()])
526 | ls_pre_seq.append(pre_seq)
527 | ls_trg_seq.append(trg_seq)
528 | ls_input_seq.append(input_seq)
529 | return ls_input_seq, ls_pre_seq, ls_trg_seq
530 |
531 |
532 | #####################################################################################################
533 | #
534 | # POIs
535 | #
536 | #####################################################################################################
537 | def filterPOI(df, mbr):
538 | labels = ['公司企业', '美食', '运动健身', '教育培训', '购物', '政府机构', '旅游景点', '出入口', '房地产', '生活服务',
539 |               '交通设施', '汽车服务', '酒店', '丽人', '医疗', '文化传媒', '金融', '休闲娱乐', '道路','自然地物', '行政地标', '门址']  # raw Chinese POI categories, mapped 1:1 to eng_labels below
540 | eng_labels = ['company','food', 'gym', 'education','shopping','gov', 'viewpoint','entrance','house','life',
541 | 'traffic','car','hotel','beauty','hospital','media','finance','entertainment','road','nature','landmark','address']
542 | eng_labels_dict = {}
543 | for i in range(len(labels)):
544 | eng_labels_dict[labels[i]] = eng_labels[i]
545 |
546 | new_df = {'lat':[],'lng':[],'type':[]}
547 | for i in range(len(df)):
548 |         gps = df.iloc[i]['经纬度wgs编码'].split(',')  # column: WGS-encoded "lat,lng"
549 | lat = float(gps[0])
550 | lng = float(gps[1])
551 |         label = df.iloc[i]['一级行业分类']  # column: top-level POI category
552 | if mbr.contains(lat,lng) and (label is not np.nan):
553 | new_df['lat'].append(lat)
554 | new_df['lng'].append(lng)
555 | new_df['type'].append(eng_labels_dict[label])
556 | new_df = pd.DataFrame(new_df)
557 | return new_df
558 |
559 |
560 | def get_poi_grid(mbr, grid_size, df):
561 | labels = ['company','food','shopping','viewpoint','house']
562 | new_df = filterPOI(df, mbr)
563 | LAT_PER_METER = 8.993203677616966e-06
564 | LNG_PER_METER = 1.1700193970443768e-05
565 | lat_unit = LAT_PER_METER * grid_size
566 | lng_unit = LNG_PER_METER * grid_size
567 |
568 | max_xid = int((mbr.max_lat - mbr.min_lat) / lat_unit) + 1
569 | max_yid = int((mbr.max_lng - mbr.min_lng) / lng_unit) + 1
570 |
571 | grid_poi_dict = {}
572 | for i in range(len(new_df)):
573 | lat = new_df.iloc[i]['lat']
574 | lng = new_df.iloc[i]['lng']
575 | label = new_df.iloc[i]['type']
576 | # only consider partial labels
577 | if label in labels:
578 | locgrid_x = int((lat - mbr.min_lat) / lat_unit) + 1
579 | locgrid_y = int((lng - mbr.min_lng) / lng_unit) + 1
580 |
581 | if (locgrid_x, locgrid_y) not in grid_poi_dict.keys():
582 | grid_poi_dict[(locgrid_x, locgrid_y)] = {label:1}
583 | else:
584 | if label not in grid_poi_dict[(locgrid_x, locgrid_y)].keys():
585 | grid_poi_dict[(locgrid_x, locgrid_y)][label] = 1
586 | else:
587 | grid_poi_dict[(locgrid_x, locgrid_y)][label]+=1
588 |
589 | # key: grid, value: [0.1,0.5,0.5] normalized POI by column
590 | grid_poi_df = pd.DataFrame(grid_poi_dict).T.fillna(0)
591 |
592 | norm_grid_poi_df=(grid_poi_df-grid_poi_df.min())/(grid_poi_df.max()-grid_poi_df.min())
593 | # norm_grid_poi_df = grid_poi_df.div(grid_poi_df.sum(axis=1), axis=0) # row normalization
594 |
595 | norm_grid_poi_dict = {}
596 | for i in range(len(norm_grid_poi_df)):
597 | k = norm_grid_poi_df.index[i]
598 | v = norm_grid_poi_df.iloc[i].values
599 | norm_grid_poi_dict[k] = v
600 |
601 | return norm_grid_poi_dict, grid_poi_df
602 |
603 |
604 | # extra_info_dir = "../data/map/extra_info/"
605 | # poi_df = pd.read_csv(extra_info_dir+'jnPoiInfo.txt',sep='\t')
606 | # norm_grid_poi_dict, grid_poi_df = get_poi_grid(mbr, args.grid_size, poi_df)
607 |
608 | # save_pkl_data(norm_grid_poi_dict, extra_info_dir, 'poi_col_norm.pkl')
609 | # grid_poi_df = pd.to_csv(extra_info_dir+'poi.csv')
610 |
611 | #####################################################################################################
612 | #
613 | # others
614 | #
615 | #####################################################################################################
616 | def epoch_time(start_time, end_time):
617 | elapsed_time = end_time - start_time
618 | elapsed_mins = int(elapsed_time / 60)
619 | elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
620 | return elapsed_mins, elapsed_secs
621 |
622 |
623 | class AttrDict(dict):
624 | def __init__(self, *args, **kwargs):
625 | super(AttrDict, self).__init__(*args, **kwargs)
626 | self.__dict__ = self
--------------------------------------------------------------------------------
/models/models_attn_tandem.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/11/5 10:27
4 |
5 | import random
6 | import operator
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from models.model_utils import get_dict_info_batch
12 |
13 |
14 | def mask_log_softmax(x, mask, log_flag=True):
15 | maxes = torch.max(x, 1, keepdim=True)[0]
16 | x_exp = torch.exp(x - maxes) * mask
17 | x_exp_sum = torch.sum(x_exp, 1, keepdim=True)
18 | if log_flag:
19 | output_custom = torch.log(x_exp / x_exp_sum)
20 | else:
21 | output_custom = x_exp / x_exp_sum
22 | return output_custom
23 |
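# --- Illustrative sketch (added for exposition; not in the original file) ---
# mask_log_softmax computes a (log-)softmax restricted to positions where mask == 1,
# which is how the decoder rules out unreachable road segments. For example:
#
#   x = torch.tensor([[1.0, 2.0, 3.0]])
#   mask = torch.tensor([[1.0, 1.0, 0.0]])
#   mask_log_softmax(x, mask, log_flag=False)
#   # -> tensor([[0.2689, 0.7311, 0.0000]]): probability mass is renormalized
#   #    over the two unmasked positions; the masked one gets exactly 0.
#   # With log_flag=True a masked position becomes log(0) = -inf, so targets are
#   # expected to lie inside the mask when this feeds NLLLoss.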
24 | class Extra_MLP(nn.Module):
25 | def __init__(self, parameters):
26 | super().__init__()
27 | self.pro_input_dim = parameters.pro_input_dim
28 | self.pro_output_dim = parameters.pro_output_dim
29 | self.fc_out = nn.Linear(self.pro_input_dim, self.pro_output_dim)
30 |
31 | def forward(self, x):
32 | out = torch.tanh(self.fc_out(x))
33 | return out
34 |
35 |
36 | class Encoder(nn.Module):
37 | def __init__(self, parameters):
38 | super().__init__()
39 | self.hid_dim = parameters.hid_dim
40 | self.pro_output_dim = parameters.pro_output_dim
41 | self.online_features_flag = parameters.online_features_flag
42 | self.pro_features_flag = parameters.pro_features_flag
43 |
44 | input_dim = 3
45 | if self.online_features_flag:
46 | input_dim = input_dim + parameters.online_dim
47 |
48 | self.rnn = nn.GRU(input_dim, self.hid_dim)
49 | self.dropout = nn.Dropout(parameters.dropout)
50 |
51 | if self.pro_features_flag:
52 | self.extra = Extra_MLP(parameters)
53 | self.fc_hid = nn.Linear(self.hid_dim + self.pro_output_dim, self.hid_dim)
54 |
55 | def forward(self, src, src_len, pro_features):
56 | # src = [src len, batch size, 3]
57 | # input dim = 3 (x, y, t); with online features appended it becomes 3 + online_dim
58 | # src_len = [batch size]
59 |
60 | packed_embedded = nn.utils.rnn.pack_padded_sequence(src, src_len)
61 | packed_outputs, hidden = self.rnn(packed_embedded)
62 |
63 | # packed_outputs is a packed sequence containing all hidden states
64 | # hidden is now from the final non-padded element in the batch
65 |
66 | outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_outputs)
67 | # outputs is now a non-packed sequence, all hidden states obtained
68 | # when the input is a pad token are all zeros
69 |
70 | # outputs = [src len, batch size, hid dim * num directions]
71 | # hidden = [n layers * num directions, batch size, hid dim]
72 |
73 | # the encoder GRU is unidirectional here, so the initial decoder hidden state
74 | # is simply the encoder's final hidden state (optionally fused with profile features below)
75 |
76 | # hidden = [1, batch size, hidden_dim]
77 | # outputs = [src len, batch size, hidden_dim * num directions]
78 |
79 | if self.pro_features_flag:
80 | extra_emb = self.extra(pro_features)
81 | extra_emb = extra_emb.unsqueeze(0)
82 | # extra_emb = [1, batch size, extra output dim]
83 | hidden = torch.tanh(self.fc_hid(torch.cat((extra_emb, hidden), dim=2)))
84 | # hidden = [1, batch size, hid dim]
85 |
86 | return outputs, hidden
87 |
88 |
89 | class Attention(nn.Module):
90 | # TODO update to more advanced attention layer.
91 | def __init__(self, parameters):
92 | super().__init__()
93 | self.hid_dim = parameters.hid_dim
94 |
95 | self.attn = nn.Linear(self.hid_dim * 2, self.hid_dim)
96 | self.v = nn.Linear(self.hid_dim, 1, bias=False)
97 |
98 | def forward(self, hidden, encoder_outputs, attn_mask):
99 | # hidden = [1, batch size, hid dim]
100 | # encoder_outputs = [src len, batch size, hid dim * num directions]
101 | src_len = encoder_outputs.shape[0]
102 | # repeat decoder hidden state src_len times
103 | hidden = hidden.repeat(src_len, 1, 1)
104 | hidden = hidden.permute(1, 0, 2)
105 | encoder_outputs = encoder_outputs.permute(1, 0, 2)
106 | # hidden = [batch size, src len, hid dim]
107 | # encoder_outputs = [batch size, src len, hid dim * num directions]
108 |
109 | energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
110 | # energy = [batch size, src len, hid dim]
111 |
112 | attention = self.v(energy).squeeze(2)
113 | # attention = [batch size, src len]
114 | attention = attention.masked_fill(attn_mask == 0, -1e10)
115 | # using mask to force the attention to only be over non-padding elements.
116 |
117 | return F.softmax(attention, dim=1)
118 |
119 |
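# --- Illustrative sketch (added for exposition; not in the original file) ---
# Shape walk-through of Attention.forward, using e.g. hid_dim=512, batch=128, src_len=10:
#   hidden:          [1, 128, 512] -> repeat + permute -> [128, 10, 512]
#   encoder_outputs: [10, 128, 512] -> permute         -> [128, 10, 512]
#   cat on dim 2 -> [128, 10, 1024] -> attn (Linear 1024->512) + tanh -> energy [128, 10, 512]
#   v (Linear 512->1) -> [128, 10, 1] -> squeeze -> attention scores [128, 10]
#   masked_fill writes -1e10 at padded positions so softmax gives them ~0 weight.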
120 | class DecoderMulti(nn.Module):
121 | def __init__(self, parameters):
122 | super().__init__()
123 |
124 | self.id_size = parameters.id_size
125 | self.id_emb_dim = parameters.id_emb_dim
126 | self.hid_dim = parameters.hid_dim
127 | self.pro_output_dim = parameters.pro_output_dim
128 | self.online_dim = parameters.online_dim
129 | self.rid_fea_dim = parameters.rid_fea_dim
130 |
131 | self.attn_flag = parameters.attn_flag
132 | self.dis_prob_mask_flag = parameters.dis_prob_mask_flag # final softmax
133 | self.online_features_flag = parameters.online_features_flag
134 | self.tandem_fea_flag = parameters.tandem_fea_flag
135 |
136 | self.emb_id = nn.Embedding(self.id_size, self.id_emb_dim)
137 |
138 | rnn_input_dim = self.id_emb_dim + 1
139 | fc_id_out_input_dim = self.hid_dim
140 | fc_rate_out_input_dim = self.hid_dim
141 |
142 | type_input_dim = self.id_emb_dim + self.hid_dim
143 | self.tandem_fc = nn.Sequential(
144 | nn.Linear(type_input_dim, self.hid_dim),
145 | nn.ReLU()
146 | )
147 |
148 | if self.attn_flag:
149 | self.attn = Attention(parameters)
150 | rnn_input_dim = rnn_input_dim + self.hid_dim
151 |
152 | if self.online_features_flag:
153 | rnn_input_dim = rnn_input_dim + self.online_dim # 5 poi and 5 road network
154 |
155 | if self.tandem_fea_flag:
156 | fc_rate_out_input_dim = self.hid_dim + self.rid_fea_dim
157 |
158 | self.rnn = nn.GRU(rnn_input_dim, self.hid_dim)
159 | self.fc_id_out = nn.Linear(fc_id_out_input_dim, self.id_size)
160 | self.fc_rate_out = nn.Linear(fc_rate_out_input_dim, 1)
161 | self.dropout = nn.Dropout(parameters.dropout)
162 |
163 |
164 | def forward(self, input_id, input_rate, hidden, encoder_outputs, attn_mask,
165 | pre_grid, next_grid, constraint_vec, pro_features, online_features, rid_features):
166 |
167 | # input_id = [batch size, 1] rid long
168 | # input_rate = [batch size, 1] rate float.
169 | # hidden = [1, batch size, hid dim]
170 | # encoder_outputs = [src len, batch size, hid dim * num directions]
171 | # attn_mask = [batch size, src len]
172 | # pre_grid = [batch size, 3]
173 | # next_grid = [batch size, 3]
174 | # constraint_vec = [batch size, id_size], [id_size] is the vector of reachable rid
175 | # pro_features = [batch size, profile features input dim]
176 | # online_features = [batch size, online features dim]
177 | # rid_features = [batch size, rid features dim]
178 |
179 | input_id = input_id.squeeze(1).unsqueeze(0)  # squeeze() without a dim would also drop the batch dimension when batch size = 1
180 | # input_id = [1, batch size]
181 | input_rate = input_rate.unsqueeze(0)
182 | # input_rate = [1, batch size, 1]
183 | embedded = self.dropout(self.emb_id(input_id))
184 | # embedded = [1, batch size, emb dim]
185 |
186 | if self.attn_flag:
187 | a = self.attn(hidden, encoder_outputs, attn_mask)
188 | # a = [batch size, src len]
189 | a = a.unsqueeze(1)
190 | # a = [batch size, 1, src len]
191 | encoder_outputs = encoder_outputs.permute(1, 0, 2)
192 | # encoder_outputs = [batch size, src len, hid dim * num directions]
193 | weighted = torch.bmm(a, encoder_outputs)
194 | # weighted = [batch size, 1, hid dim * num directions]
195 | weighted = weighted.permute(1, 0, 2)
196 | # weighted = [1, batch size, hid dim * num directions]
197 |
198 | if self.online_features_flag:
199 | rnn_input = torch.cat((weighted, embedded, input_rate,
200 | online_features.unsqueeze(0)), dim=2)
201 | else:
202 | rnn_input = torch.cat((weighted, embedded, input_rate), dim=2)
203 | else:
204 | if self.online_features_flag:
205 | rnn_input = torch.cat((embedded, input_rate, online_features.unsqueeze(0)), dim=2)
206 | else:
207 | rnn_input = torch.cat((embedded, input_rate), dim=2)
208 |
209 | output, hidden = self.rnn(rnn_input, hidden)
210 |
211 | # output = [seq len, batch size, hid dim * n directions]
212 | # hidden = [n layers * n directions, batch size, hid dim]
213 | # seq len and n directions will always be 1 in the decoder, therefore:
214 | # output = [1, batch size, dec hid dim]
215 | # hidden = [1, batch size, dec hid dim]
216 | assert (output == hidden).all()
217 |
218 | # pre_rid
219 | if self.dis_prob_mask_flag:
220 | prediction_id = mask_log_softmax(self.fc_id_out(output.squeeze(0)),
221 | constraint_vec, log_flag=True)
222 | else:
223 | prediction_id = F.log_softmax(self.fc_id_out(output.squeeze(0)), dim=1)
224 | # then the loss function should change to nll_loss()
225 |
226 | # pre_rate
227 | max_id = prediction_id.argmax(dim=1).long()
228 | id_emb = self.dropout(self.emb_id(max_id))
229 | rate_input = torch.cat((id_emb, hidden.squeeze(0)), dim=1)
230 | rate_input = self.tandem_fc(rate_input) # [batch size, hid dim]
231 | if self.tandem_fea_flag:
232 | prediction_rate = torch.sigmoid(self.fc_rate_out(torch.cat((rate_input, rid_features), dim=1)))
233 | else:
234 | prediction_rate = torch.sigmoid(self.fc_rate_out(rate_input))
235 |
236 | # prediction_id = [batch size, id_size]
237 | # prediction_rate = [batch size, 1]
238 |
239 | return prediction_id, prediction_rate, hidden
240 |
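# --- Illustrative sketch (added for exposition; not in the original file) ---
# The "tandem" design above: the rate head is conditioned on the id head's output.
# One decoder step, in pseudo-shapes (batch=B, id_size=N, hid_dim=H):
#
#   prediction_id   = (masked) log-softmax over N road segments   # [B, N]
#   max_id          = prediction_id.argmax(dim=1)                 # [B], predicted segment
#   rate_input      = cat(emb(max_id), hidden) -> tandem_fc       # [B, H]
#   prediction_rate = sigmoid(fc_rate_out(rate_input))            # [B, 1], moving rate in [0, 1]
#
# so the moving-rate regression always sees which segment was just predicted.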
241 | class Seq2SeqMulti(nn.Module):
242 | def __init__(self, encoder, decoder, device):
243 | super().__init__()
244 |
245 | self.encoder = encoder # Encoder
246 | self.decoder = decoder # DecoderMulti
247 | self.device = device
248 |
249 | def forward(self, src, src_len, trg_id, trg_rate, trg_len,
250 | pre_grids, next_grids, constraint_mat, pro_features,
251 | online_features_dict, rid_features_dict,
252 | teacher_forcing_ratio=0.5):
253 | """
254 | src = [src len, batch size, 3], x,y,t
255 | src_len = [batch size]
256 | trg_id = [trg len, batch size, 1]
257 | trg_rate = [trg len, batch size, 1]
258 | trg_len = [batch size]
259 | pre_grids = [trg len, batch size, 3]
260 | next_grids = [trg len, batch size, 3]
261 | constraint_mat = [trg len, batch size, id_size]
262 | pro_features = [batch size, profile features input dim]
263 | online_features_dict = {rid: online_features} # rid --> grid --> online features
264 | rid_features_dict = {rid: rn_features}
265 | teacher_forcing_ratio is probability to use teacher forcing
266 | e.g. if teacher_forcing_ratio is 0.75 we use ground-truth inputs 75% of the time
267 | Return:
268 | ------
269 | outputs_id: [seq len, batch size, id_size], from greedy decoding (argmax), not beam search
270 | outputs_rate: [seq len, batch size, 1]
271 | """
272 | max_trg_len = trg_id.size(0)
273 | batch_size = trg_id.size(1)
274 |
275 | # encoder_outputs is all hidden states of the input sequence, back and forwards
276 | # hidden is the final forward and backward hidden states, passed through a linear layer
277 | encoder_outputs, hiddens = self.encoder(src, src_len, pro_features)
278 |
279 | if self.decoder.attn_flag:
280 | attn_mask = torch.zeros(batch_size, max(src_len)) # only attend on unpadded sequence
281 | for i in range(len(src_len)):
282 | attn_mask[i][:src_len[i]] = 1.
283 | attn_mask = attn_mask.to(self.device)
284 | else:
285 | attn_mask = None
286 |
287 | outputs_id, outputs_rate = self.normal_step(max_trg_len, batch_size, trg_id, trg_rate, trg_len,
288 | encoder_outputs, hiddens, attn_mask,
289 | online_features_dict,
290 | rid_features_dict,
291 | pre_grids, next_grids, constraint_mat, pro_features,
292 | teacher_forcing_ratio)
293 |
294 | return outputs_id, outputs_rate
295 |
296 | def normal_step(self, max_trg_len, batch_size, trg_id, trg_rate, trg_len, encoder_outputs, hidden,
297 | attn_mask, online_features_dict, rid_features_dict,
298 | pre_grids, next_grids, constraint_mat, pro_features, teacher_forcing_ratio):
299 | """
300 | Returns:
301 | -------
302 | outputs_id: [seq len, batch size, id size]
303 | outputs_rate: [seq len, batch size, 1]
304 | """
305 | # tensor to store decoder outputs
306 | outputs_id = torch.zeros(max_trg_len, batch_size, self.decoder.id_size).to(self.device)
307 | outputs_rate = torch.zeros(trg_rate.size()).to(self.device)
308 |
309 | # first input to the decoder is the start token (trg_id[0], trg_rate[0])
310 | input_id = trg_id[0, :]
311 | input_rate = trg_rate[0, :]
312 | for t in range(1, max_trg_len):
313 | # insert input token embedding, previous hidden state, all encoder hidden states
314 | # and attn_mask
315 | # receive output tensor (predictions) and new hidden state
316 | if self.decoder.online_features_flag:
317 | online_features = get_dict_info_batch(input_id, online_features_dict).to(self.device)
318 | else:
319 | online_features = torch.zeros((1, batch_size, self.decoder.online_dim))
320 | if self.decoder.tandem_fea_flag:
321 | rid_features = get_dict_info_batch(input_id, rid_features_dict).to(self.device)
322 | else:
323 | rid_features = None
324 | prediction_id, prediction_rate, hidden = self.decoder(input_id, input_rate, hidden, encoder_outputs,
325 | attn_mask, pre_grids[t], next_grids[t],
326 | constraint_mat[t], pro_features, online_features,
327 | rid_features)
328 |
329 | # place predictions in a tensor holding predictions for each token
330 | outputs_id[t] = prediction_id
331 | outputs_rate[t] = prediction_rate
332 |
333 | # decide if we are going to use teacher forcing or not
334 | teacher_force = random.random() < teacher_forcing_ratio
335 |
336 | # get the highest predicted token from our predictions
337 | top1_id = prediction_id.argmax(1)
338 | top1_id = top1_id.unsqueeze(-1) # make sure the output has the same dimension as input
339 |
340 | # if teacher forcing, use actual next token as next input
341 | # if not, use predicted token
342 | input_id = trg_id[t] if teacher_force else top1_id
343 | input_rate = trg_rate[t] if teacher_force else prediction_rate
344 |
345 | # max_trg_len, batch_size, trg_rid_size
346 | outputs_id = outputs_id.permute(1, 0, 2) # batch size, seq len, rid size
347 | outputs_rate = outputs_rate.permute(1, 0, 2) # batch size, seq len, 1
348 | for i in range(batch_size):
349 | outputs_id[i][trg_len[i]:] = 0
350 | outputs_id[i][trg_len[i]:, 0] = 1 # make sure argmax will return eid0
351 | outputs_rate[i][trg_len[i]:] = 0
352 | outputs_id = outputs_id.permute(1, 0, 2)
353 | outputs_rate = outputs_rate.permute(1, 0, 2)
354 |
355 | return outputs_id, outputs_rate
356 |
357 |
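# --- Illustrative sketch (added for exposition; not in the original file) ---
# Minimal assembly of the model, assuming an AttrDict of hyper-parameters like the
# one built in multi_main.py (field names below match what this file reads, and
# `device` is assumed to be a torch.device):
#
#   params = AttrDict(hid_dim=512, id_size=2572, id_emb_dim=128, dropout=0.5,
#                     pro_input_dim=30, pro_output_dim=8, online_dim=10, rid_fea_dim=8,
#                     attn_flag=True, dis_prob_mask_flag=True,
#                     online_features_flag=False, pro_features_flag=False,
#                     tandem_fea_flag=False)
#   model = Seq2SeqMulti(Encoder(params), DecoderMulti(params), device).to(device)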
--------------------------------------------------------------------------------
/models/multi_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/10/29 17:37
4 |
5 | import numpy as np
6 | import random
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from models.model_utils import toseq, get_constraint_mask
12 | from models.loss_fn import cal_id_acc, check_rn_dis_loss
13 |
14 | # set random seed
15 | SEED = 20202020
16 |
17 | random.seed(SEED)
18 | np.random.seed(SEED)
19 | torch.manual_seed(SEED)
20 | torch.cuda.manual_seed(SEED)
21 | torch.backends.cudnn.deterministic = True
22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 | print('multi_task device', device)
24 |
25 |
26 | def init_weights(self):
27 | """
28 | Reproduce the Keras default weight initialization for consistency with the Keras version.
29 | Reference: https://github.com/vonfeng/DeepMove/blob/master/codes/model.py
30 | """
31 | ih = (param.data for name, param in self.named_parameters() if 'weight_ih' in name)
32 | hh = (param.data for name, param in self.named_parameters() if 'weight_hh' in name)
33 | b = (param.data for name, param in self.named_parameters() if 'bias' in name)
34 |
35 | for t in ih:
36 | nn.init.xavier_uniform_(t)
37 | for t in hh:
38 | nn.init.orthogonal_(t)
39 | for t in b:
40 | nn.init.constant_(t, 0)
41 |
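# --- Note (added for exposition; not in the original file) ---
# init_weights is written for model.apply(init_weights): apply() calls it once per
# submodule, so `self` is each nn.Module in turn. Input-to-hidden GRU weights get
# Xavier-uniform init, hidden-to-hidden weights get orthogonal init, biases get zeros.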
42 |
43 | def train(model, iterator, optimizer, log_vars, rn_dict, grid_rn_dict, rn,
44 | raw2new_rid_dict, online_features_dict, rid_features_dict, parameters):
45 | model.train()  # switch back to training mode (re-enables dropout after evaluate() puts the model in eval mode)
46 |
47 | criterion_reg = nn.MSELoss()
48 | criterion_ce = nn.NLLLoss()
49 |
50 | epoch_ttl_loss = 0
51 | epoch_id1_loss = 0
52 | epoch_recall_loss = 0
53 | epoch_precision_loss = 0
54 | epoch_train_id_loss = 0
55 | epoch_rate_loss = 0
56 | for i, batch in enumerate(iterator):
57 | src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths = batch
58 | if parameters.dis_prob_mask_flag:
59 | constraint_mat, pre_grids, next_grids = get_constraint_mask(src_grid_seqs, src_gps_seqs, src_lengths,
60 | trg_lengths, grid_rn_dict, rn, raw2new_rid_dict,
61 | parameters)
62 | constraint_mat = constraint_mat.permute(1, 0, 2).to(device)
63 | pre_grids = pre_grids.permute(1, 0, 2).to(device)
64 | next_grids = next_grids.permute(1, 0, 2).to(device)
65 | else:
66 | max_trg_len = max(trg_lengths)
67 | batch_size = src_grid_seqs.size(0)
68 | constraint_mat = torch.zeros(max_trg_len, batch_size, parameters.id_size, device=device)
69 | pre_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
70 | next_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
71 |
72 | src_pro_feas = src_pro_feas.float().to(device)
73 |
74 | src_grid_seqs = src_grid_seqs.permute(1, 0, 2).to(device)
75 | trg_gps_seqs = trg_gps_seqs.permute(1, 0, 2).to(device)
76 | trg_rids = trg_rids.permute(1, 0, 2).long().to(device)
77 | trg_rates = trg_rates.permute(1, 0, 2).to(device)
78 |
79 | # constraint_mat = [trg len, batch size, id size]
80 | # src_grid_seqs = [src len, batch size, 2]
81 | # src_lengths = [batch size]
82 | # trg_gps_seqs = [trg len, batch size, 2]
83 | # trg_rids = [trg len, batch size, 1]
84 | # trg_rates = [trg len, batch size, 1]
85 | # trg_lengths = [batch size]
86 |
87 | optimizer.zero_grad()
88 | output_ids, output_rates = model(src_grid_seqs, src_lengths, trg_rids, trg_rates, trg_lengths,
89 | pre_grids, next_grids, constraint_mat, src_pro_feas,
90 | online_features_dict, rid_features_dict, parameters.tf_ratio)
91 | output_rates = output_rates.squeeze(2)
92 | trg_rids = trg_rids.squeeze(2)
93 | trg_rates = trg_rates.squeeze(2)
94 |
95 | # output_ids = [trg len, batch size, id one hot output dim]
96 | # output_rates = [trg len, batch size]
97 | # trg_rids = [trg len, batch size]
98 | # trg_rates = [trg len, batch size]
99 |
100 | # rid accuracy metrics: logged only, not backpropagated
101 | loss_ids1, recall, precision = cal_id_acc(output_ids[1:], trg_rids[1:], trg_lengths)
102 |
103 | # for backpropagation
104 | output_ids_dim = output_ids.shape[-1]
105 | output_ids = output_ids[1:].reshape(-1, output_ids_dim) # [(trg len - 1)* batch size, output id one hot dim]
106 | trg_rids = trg_rids[1:].reshape(-1) # [(trg len - 1) * batch size],
107 | # view size is not compatible with input tensor's size and stride ==> use reshape() instead
108 |
109 | loss_train_ids = criterion_ce(output_ids, trg_rids)
110 | loss_rates = criterion_reg(output_rates[1:], trg_rates[1:]) * parameters.lambda1
111 | ttl_loss = loss_train_ids + loss_rates
112 |
113 | ttl_loss.backward()
114 | torch.nn.utils.clip_grad_norm_(model.parameters(), parameters.clip) # log_vars are not necessary to clip
115 | optimizer.step()
116 |
117 | epoch_ttl_loss += ttl_loss.item()
118 | epoch_id1_loss += loss_ids1
119 | epoch_recall_loss += recall
120 | epoch_precision_loss += precision
121 | epoch_train_id_loss += loss_train_ids.item()
122 | epoch_rate_loss += loss_rates.item()
123 |
124 | return log_vars, epoch_ttl_loss / len(iterator), epoch_id1_loss / len(iterator), epoch_recall_loss / len(iterator), \
125 | epoch_precision_loss / len(iterator), epoch_rate_loss / len(iterator), epoch_train_id_loss / len(iterator)
126 |
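# --- Note (added for exposition; not in the original file) ---
# The training objective above is a fixed-weight multi-task loss:
#
#   ttl_loss = NLLLoss(output_ids, trg_rids) + lambda1 * MSELoss(output_rates, trg_rates)
#
# The `log_vars` argument (task-uncertainty weights) is threaded through and
# returned, but is not used in the loss here; the weighting is the static lambda1.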
127 |
128 | def evaluate(model, iterator, rn_dict, grid_rn_dict, rn, raw2new_rid_dict,
129 | online_features_dict, rid_features_dict, raw_rn_dict, new2raw_rid_dict, parameters):
130 | model.eval()  # required: switches off dropout (and batch-norm updates) during evaluation
131 |
132 | epoch_dis_mae_loss = 0
133 | epoch_dis_rmse_loss = 0
134 | epoch_dis_rn_mae_loss = 0
135 | epoch_dis_rn_rmse_loss = 0
136 | epoch_id1_loss = 0
137 | epoch_recall_loss = 0
138 | epoch_precision_loss = 0
139 | epoch_rate_loss = 0
140 | epoch_id_loss = 0 # loss from dl model
141 | criterion_ce = nn.NLLLoss()
142 | criterion_reg = nn.MSELoss()
143 |
144 | with torch.no_grad():  # no gradient tracking: saves memory and speeds up evaluation
145 | for i, batch in enumerate(iterator):
146 | src_grid_seqs, src_gps_seqs, src_pro_feas, src_lengths, trg_gps_seqs, trg_rids, trg_rates, trg_lengths = batch
147 |
148 | if parameters.dis_prob_mask_flag:
149 | constraint_mat, pre_grids, next_grids = get_constraint_mask(src_grid_seqs, src_gps_seqs, src_lengths,
150 | trg_lengths, grid_rn_dict, rn,
151 | raw2new_rid_dict, parameters)
152 | constraint_mat = constraint_mat.permute(1, 0, 2).to(device)
153 | pre_grids = pre_grids.permute(1, 0, 2).to(device)
154 | next_grids = next_grids.permute(1, 0, 2).to(device)
155 | else:
156 | max_trg_len = max(trg_lengths)
157 | batch_size = src_grid_seqs.size(0)
158 | constraint_mat = torch.zeros(max_trg_len, batch_size, parameters.id_size).to(device)
159 | pre_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
160 | next_grids = torch.zeros(max_trg_len, batch_size, 3).to(device)
161 |
162 | src_pro_feas = src_pro_feas.float().to(device)
163 |
164 | src_grid_seqs = src_grid_seqs.permute(1, 0, 2).to(device)
165 | trg_gps_seqs = trg_gps_seqs.permute(1, 0, 2).to(device)
166 | trg_rids = trg_rids.permute(1, 0, 2).long().to(device)
167 | trg_rates = trg_rates.permute(1, 0, 2).to(device)
168 |
169 | # constraint_mat = [trg len, batch size, id size]
170 | # src_grid_seqs = [src len, batch size, 2]
171 | # src_pro_feas = [batch size, feature dim]
172 | # src_lengths = [batch size]
173 | # trg_gps_seqs = [trg len, batch size, 2]
174 | # trg_rids = [trg len, batch size, 1]
175 | # trg_rates = [trg len, batch size, 1]
176 | # trg_lengths = [batch size]
177 |
178 | output_ids, output_rates = model(src_grid_seqs, src_lengths, trg_rids, trg_rates, trg_lengths,
179 | pre_grids, next_grids, constraint_mat,
180 | src_pro_feas, online_features_dict, rid_features_dict,
181 | teacher_forcing_ratio=0)
182 |
183 | output_rates = output_rates.squeeze(2)
184 | output_seqs = toseq(rn_dict, output_ids, output_rates, parameters)
185 | trg_rids = trg_rids.squeeze(2)
186 | trg_rates = trg_rates.squeeze(2)
187 | # output_ids = [trg len, batch size, id one hot output dim]
188 | # output_rates = [trg len, batch size]
189 | # trg_rids = [trg len, batch size]
190 | # trg_rates = [trg len, batch size]
191 |
192 | # rid accuracy metrics: logged only, not backpropagated
193 | loss_ids1, recall, precision = cal_id_acc(output_ids[1:], trg_rids[1:], trg_lengths)
194 | # distance loss
195 | dis_mae_loss, dis_rmse_loss, dis_rn_mae_loss, dis_rn_rmse_loss = check_rn_dis_loss(output_seqs[1:],
196 | output_ids[1:],
197 | output_rates[1:],
198 | trg_gps_seqs[1:],
199 | trg_rids[1:],
200 | trg_rates[1:],
201 | trg_lengths,
202 | rn, raw_rn_dict,
203 | new2raw_rid_dict)
204 |
205 | # for backpropagation
206 | output_ids_dim = output_ids.shape[-1]
207 | output_ids = output_ids[1:].reshape(-1,
208 | output_ids_dim) # [(trg len - 1)* batch size, output id one hot dim]
209 | trg_rids = trg_rids[1:].reshape(-1) # [(trg len - 1) * batch size],
210 | loss_ids = criterion_ce(output_ids, trg_rids)
211 | # rate loss
212 | loss_rates = criterion_reg(output_rates[1:], trg_rates[1:]) * parameters.lambda1
213 | # loss_rates.size = [(trg len - 1), batch size], --> [(trg len - 1)* batch size,1]
214 |
215 | epoch_dis_mae_loss += dis_mae_loss
216 | epoch_dis_rmse_loss += dis_rmse_loss
217 | epoch_dis_rn_mae_loss += dis_rn_mae_loss
218 | epoch_dis_rn_rmse_loss += dis_rn_rmse_loss
219 | epoch_id1_loss += loss_ids1
220 | epoch_recall_loss += recall
221 | epoch_precision_loss += precision
222 | epoch_rate_loss += loss_rates.item()
223 | epoch_id_loss += loss_ids.item()
224 |
225 | return epoch_id1_loss / len(iterator), epoch_recall_loss / len(iterator), \
226 | epoch_precision_loss / len(iterator), \
227 | epoch_dis_mae_loss / len(iterator), epoch_dis_rmse_loss / len(iterator), \
228 | epoch_dis_rn_mae_loss / len(iterator), epoch_dis_rn_rmse_loss / len(iterator), \
229 | epoch_rate_loss / len(iterator), epoch_id_loss / len(iterator)
230 |
--------------------------------------------------------------------------------
/multi_main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/11/10 11:07
4 |
5 | import time
6 | from tqdm import tqdm
7 | import logging
8 | import sys
9 | import argparse
10 | import pandas as pd
11 |
12 | import torch
13 | import torch.optim as optim
14 |
15 | from utils.utils import save_json_data, create_dir, load_pkl_data
16 | from common.mbr import MBR
17 | from common.spatial_func import SPoint, distance
18 | from common.road_network import load_rn_shp
19 |
20 | from models.datasets import Dataset, collate_fn, split_data
21 | from models.model_utils import load_rn_dict, load_rid_freqs, get_rid_grid, get_poi_info, get_rn_info
22 | from models.model_utils import get_online_info_dict, epoch_time, AttrDict, get_rid_rnfea_dict
23 | from models.multi_train import evaluate, init_weights, train
24 | from models.models_attn_tandem import Encoder, DecoderMulti, Seq2SeqMulti
25 |
26 |
27 | if __name__ == '__main__':
28 | parser = argparse.ArgumentParser(description='Multi-task Traj Interp')
29 | parser.add_argument('--module_type', type=str, default='simple', help='module type')
30 | parser.add_argument('--keep_ratio', type=float, default=0.125, help='keep ratio in float')
31 | parser.add_argument('--lambda1', type=int, default=10, help='weight for multi task rate')
32 | parser.add_argument('--hid_dim', type=int, default=512, help='hidden dimension')
33 | parser.add_argument('--epochs', type=int, default=10, help='epochs')
34 | parser.add_argument('--grid_size', type=int, default=50, help='grid size in int')
35 | parser.add_argument('--dis_prob_mask_flag', action='store_true', help='flag of using prob mask')
36 | parser.add_argument('--pro_features_flag', action='store_true', help='flag of using profile features')
37 | parser.add_argument('--online_features_flag', action='store_true', help='flag of using online features')
38 | parser.add_argument('--tandem_fea_flag', action='store_true', help='flag of using tandem rid features')
39 | parser.add_argument('--no_attn_flag', action='store_false', help='flag of using attention')
40 | parser.add_argument('--load_pretrained_flag', action='store_true', help='flag of load pretrained model')
41 | parser.add_argument('--model_old_path', type=str, default='', help='old model path')
42 | parser.add_argument('--no_debug', action='store_false', help='flag of debug')
43 | parser.add_argument('--no_train_flag', action='store_false', help='flag of training')
44 | parser.add_argument('--test_flag', action='store_true', help='flag of testing')
45 |
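# --- Illustrative usage (added for exposition; not in the original file) ---
# Example invocations using the flags defined above, e.g.:
#   python multi_main.py
#   python multi_main.py --keep_ratio 0.0625 --dis_prob_mask_flag --pro_features_flag
#   python multi_main.py --load_pretrained_flag --model_old_path ./results/<run_dir>/ --test_flag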
46 |
47 | opts = parser.parse_args()
48 |
49 | debug = opts.no_debug
50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51 |
52 | args = AttrDict()
53 | args_dict = {
54 | 'module_type':opts.module_type,
55 | 'debug':debug,
56 | 'device':device,
57 |
58 | # pre train
59 | 'load_pretrained_flag':opts.load_pretrained_flag,
60 | 'model_old_path':opts.model_old_path,
61 | 'train_flag':opts.no_train_flag,
62 | 'test_flag':opts.test_flag,
63 |
64 | # attention
65 | 'attn_flag':opts.no_attn_flag,
66 |
67 | # constraint
68 | 'dis_prob_mask_flag':opts.dis_prob_mask_flag,
69 | 'search_dist':50,
70 | 'beta':15,
71 |
72 | # features
73 | 'tandem_fea_flag':opts.tandem_fea_flag,
74 | 'pro_features_flag':opts.pro_features_flag,
75 | 'online_features_flag':opts.online_features_flag,
76 |
77 | # extra info module
78 | 'rid_fea_dim':8,
79 | 'pro_input_dim':30, # 24[hour] + 5[weather] + 1[holiday]
80 | 'pro_output_dim':8,
81 | 'poi_num':5,
82 | 'online_dim':5+5, # poi/roadnetwork features dim
83 | 'poi_type':'company,food,shopping,viewpoint,house',
84 |
85 | # MBR
86 | 'min_lat':36.6456,
87 | 'min_lng':116.9854,
88 | 'max_lat':36.6858,
89 | 'max_lng':117.0692,
90 |
91 | # input data params
92 | 'keep_ratio':opts.keep_ratio,
93 | 'grid_size':opts.grid_size,
94 | 'time_span':15,
95 | 'win_size':25,
96 | 'ds_type':'random',
97 | 'split_flag':True,
98 | 'shuffle':True,
99 |
100 | # model params
101 | 'hid_dim':opts.hid_dim,
102 | 'id_emb_dim':128,
103 | 'dropout':0.5,
104 | 'id_size':2571+1,
105 |
106 | 'lambda1':opts.lambda1,
107 | 'n_epochs':opts.epochs,
108 | 'batch_size':128,
109 | 'learning_rate':1e-3,
110 | 'tf_ratio':0.5,
111 | 'clip':1,
112 | 'log_step':1
113 | }
114 | args.update(args_dict)
115 |
116 | print('Preparing data...')
117 | if args.split_flag:
118 | traj_input_dir = "./data/raw_trajectory/"
119 | output_dir = "./data/model_data/"
120 | split_data(traj_input_dir, output_dir)
121 |
122 | extra_info_dir = "./data/map/extra_info/"
123 | rn_dir = "./data/map/road_network/"
124 | train_trajs_dir = "./data/model_data/train_data/"
125 | valid_trajs_dir = "./data/model_data/valid_data/"
126 | test_trajs_dir = "./data/model_data/test_data/"
127 | if args.tandem_fea_flag:
128 | fea_flag = True
129 | else:
130 | fea_flag = False
131 |
132 | if args.load_pretrained_flag:
133 | model_save_path = args.model_old_path
134 | else:
135 | model_save_path = './results/'+args.module_type+'_kr_'+str(args.keep_ratio)+'_debug_'+str(args.debug)+\
136 | '_gs_'+str(args.grid_size)+'_lam_'+str(args.lambda1)+\
137 | '_attn_'+str(args.attn_flag)+'_prob_'+str(args.dis_prob_mask_flag)+\
138 | '_fea_'+str(fea_flag)+'_'+time.strftime("%Y%m%d_%H%M%S") + '/'
139 | create_dir(model_save_path)
140 |
141 | logging.basicConfig(level=logging.DEBUG,
142 | format='%(asctime)s %(levelname)s %(message)s',
143 | filename=model_save_path + 'log.txt',
144 | filemode='a')
145 |
146 | rn = load_rn_shp(rn_dir, is_directed=True)
147 | raw_rn_dict = load_rn_dict(extra_info_dir, file_name='raw_rn_dict.json')
148 | new2raw_rid_dict = load_rid_freqs(extra_info_dir, file_name='new2raw_rid.json')
149 | raw2new_rid_dict = load_rid_freqs(extra_info_dir, file_name='raw2new_rid.json')
150 | rn_dict = load_rn_dict(extra_info_dir, file_name='rn_dict.json')
151 |
152 | mbr = MBR(args.min_lat, args.min_lng, args.max_lat, args.max_lng)
153 | grid_rn_dict, max_xid, max_yid = get_rid_grid(mbr, args.grid_size, rn_dict)
154 | args_dict['max_xid'] = max_xid
155 | args_dict['max_yid'] = max_yid
156 | args.update(args_dict)
157 | print(args)
158 | logging.info(args_dict)
159 |
160 | # load features
161 | weather_dict = load_pkl_data(extra_info_dir, 'weather_dict.pkl')
162 | if args.online_features_flag:
163 | grid_poi_df = pd.read_csv(extra_info_dir+'poi'+str(args.grid_size)+'.csv',index_col=[0,1])
164 | norm_grid_poi_dict = get_poi_info(grid_poi_df, args)
165 | norm_grid_rnfea_dict = get_rn_info(rn, mbr, args.grid_size, grid_rn_dict, rn_dict)
166 | online_features_dict = get_online_info_dict(grid_rn_dict, norm_grid_poi_dict, norm_grid_rnfea_dict, args)
167 | else:
168 | norm_grid_poi_dict, norm_grid_rnfea_dict, online_features_dict = None, None, None
169 | if args.tandem_fea_flag:
170 | rid_features_dict = get_rid_rnfea_dict(rn_dict, args)
171 | else:
172 | rid_features_dict = None
173 |
174 | # load dataset
175 | train_dataset = Dataset(train_trajs_dir, mbr=mbr, norm_grid_poi_dict=norm_grid_poi_dict,
176 | norm_grid_rnfea_dict=norm_grid_rnfea_dict, weather_dict=weather_dict,
177 | parameters=args, debug=debug)
178 | valid_dataset = Dataset(valid_trajs_dir, mbr=mbr, norm_grid_poi_dict=norm_grid_poi_dict,
179 | norm_grid_rnfea_dict=norm_grid_rnfea_dict, weather_dict=weather_dict,
180 | parameters=args, debug=debug)
181 | test_dataset = Dataset(test_trajs_dir, mbr=mbr, norm_grid_poi_dict=norm_grid_poi_dict,
182 | norm_grid_rnfea_dict=norm_grid_rnfea_dict, weather_dict=weather_dict,
183 | parameters=args, debug=debug)
184 | print('training dataset shape: ' + str(len(train_dataset)))
185 | print('validation dataset shape: ' + str(len(valid_dataset)))
186 | print('test dataset shape: ' + str(len(test_dataset)))
187 |
188 | train_iterator = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
189 | shuffle=args.shuffle, collate_fn=collate_fn,
190 | num_workers=4, pin_memory=True)
191 | valid_iterator = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size,
192 | shuffle=args.shuffle, collate_fn=collate_fn,
193 | num_workers=4, pin_memory=True)
194 | test_iterator = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
195 | shuffle=args.shuffle, collate_fn=collate_fn,
196 | num_workers=4, pin_memory=True)
197 |
198 | logging.info('Finish data preparing.')
199 | logging.info('training dataset shape: ' + str(len(train_dataset)))
200 | logging.info('validation dataset shape: ' + str(len(valid_dataset)))
201 | logging.info('test dataset shape: ' + str(len(test_dataset)))
202 |
203 | enc = Encoder(args)
204 | dec = DecoderMulti(args)
205 | model = Seq2SeqMulti(enc, dec, device).to(device)
206 | model.apply(init_weights) # learn how to init weights
207 | if args.load_pretrained_flag:
208 | model.load_state_dict(torch.load(args.model_old_path + 'val-best-model.pt'))
209 |
210 | print('model', str(model))
211 | logging.info('model' + str(model))
212 |
213 | if args.train_flag:
214 | ls_train_loss, ls_train_id_acc1, ls_train_id_recall, ls_train_id_precision, \
215 | ls_train_rate_loss, ls_train_id_loss = [], [], [], [], [], []
216 | ls_valid_loss, ls_valid_id_acc1, ls_valid_id_recall, ls_valid_id_precision, \
217 | ls_valid_dis_mae_loss, ls_valid_dis_rmse_loss = [], [], [], [], [], []
218 | ls_valid_dis_rn_mae_loss, ls_valid_dis_rn_rmse_loss, ls_valid_rate_loss, ls_valid_id_loss = [], [], [], []
219 |
220 | dict_train_loss = {}
221 | dict_valid_loss = {}
222 | best_valid_loss = float('inf') # compare id loss
223 |
224 | # get all parameters (model parameters + task dependent log variances)
225 | log_vars = [torch.zeros((1,), requires_grad=True, device=device) for _ in range(2)]  # auto-tune multi-task params; a comprehension avoids aliasing one tensor twice
226 | optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
227 | for epoch in tqdm(range(args.n_epochs)):
228 | start_time = time.time()
229 |
230 | new_log_vars, train_loss, train_id_acc1, train_id_recall, train_id_precision, \
231 | train_rate_loss, train_id_loss = train(model, train_iterator, optimizer, log_vars,
232 | rn_dict, grid_rn_dict, rn, raw2new_rid_dict,
233 | online_features_dict, rid_features_dict, args)
234 |
235 | valid_id_acc1, valid_id_recall, valid_id_precision, valid_dis_mae_loss, valid_dis_rmse_loss, \
236 | valid_dis_rn_mae_loss, valid_dis_rn_rmse_loss, \
237 | valid_rate_loss, valid_id_loss = evaluate(model, valid_iterator,
238 | rn_dict, grid_rn_dict, rn, raw2new_rid_dict,
239 | online_features_dict, rid_features_dict, raw_rn_dict,
240 | new2raw_rid_dict, args)
241 | ls_train_loss.append(train_loss)
242 | ls_train_id_acc1.append(train_id_acc1)
243 | ls_train_id_recall.append(train_id_recall)
244 | ls_train_id_precision.append(train_id_precision)
245 | ls_train_rate_loss.append(train_rate_loss)
246 | ls_train_id_loss.append(train_id_loss)
247 |
248 | ls_valid_id_acc1.append(valid_id_acc1)
249 | ls_valid_id_recall.append(valid_id_recall)
250 | ls_valid_id_precision.append(valid_id_precision)
251 | ls_valid_dis_mae_loss.append(valid_dis_mae_loss)
252 | ls_valid_dis_rmse_loss.append(valid_dis_rmse_loss)
253 | ls_valid_dis_rn_mae_loss.append(valid_dis_rn_mae_loss)
254 | ls_valid_dis_rn_rmse_loss.append(valid_dis_rn_rmse_loss)
255 | ls_valid_rate_loss.append(valid_rate_loss)
256 | ls_valid_id_loss.append(valid_id_loss)
257 | valid_loss = valid_rate_loss + valid_id_loss
258 | ls_valid_loss.append(valid_loss)
259 |
260 | dict_train_loss['train_ttl_loss'] = ls_train_loss
261 | dict_train_loss['train_id_acc1'] = ls_train_id_acc1
262 | dict_train_loss['train_id_recall'] = ls_train_id_recall
263 | dict_train_loss['train_id_precision'] = ls_train_id_precision
264 | dict_train_loss['train_rate_loss'] = ls_train_rate_loss
265 | dict_train_loss['train_id_loss'] = ls_train_id_loss
266 |
267 | dict_valid_loss['valid_ttl_loss'] = ls_valid_loss
268 | dict_valid_loss['valid_id_acc1'] = ls_valid_id_acc1
269 | dict_valid_loss['valid_id_recall'] = ls_valid_id_recall
270 | dict_valid_loss['valid_id_precision'] = ls_valid_id_precision
271 | dict_valid_loss['valid_rate_loss'] = ls_valid_rate_loss
272 | dict_valid_loss['valid_dis_mae_loss'] = ls_valid_dis_mae_loss
273 | dict_valid_loss['valid_dis_rmse_loss'] = ls_valid_dis_rmse_loss
274 | dict_valid_loss['valid_dis_rn_mae_loss'] = ls_valid_dis_rn_mae_loss
275 | dict_valid_loss['valid_dis_rn_rmse_loss'] = ls_valid_dis_rn_rmse_loss
276 | dict_valid_loss['valid_id_loss'] = ls_valid_id_loss
277 |
278 | end_time = time.time()
279 | epoch_mins, epoch_secs = epoch_time(start_time, end_time)
280 |
281 | if valid_loss < best_valid_loss:
282 | best_valid_loss = valid_loss
283 | torch.save(model.state_dict(), model_save_path + 'val-best-model.pt')
284 |
285 | if (epoch % args.log_step == 0) or (epoch == args.n_epochs - 1):
286 | logging.info('Epoch: ' + str(epoch + 1) + ' Time: ' + str(epoch_mins) + 'm' + str(epoch_secs) + 's')
287 | weights = [torch.exp(weight) ** 0.5 for weight in new_log_vars]
288 | logging.info('log_vars:' + str(weights))
289 | logging.info('\tTrain Loss:' + str(train_loss) +
290 | '\tTrain RID Acc1:' + str(train_id_acc1) +
291 | '\tTrain RID Recall:' + str(train_id_recall) +
292 | '\tTrain RID Precision:' + str(train_id_precision) +
293 | '\tTrain Rate Loss:' + str(train_rate_loss) +
294 | '\tTrain RID Loss:' + str(train_id_loss))
295 | logging.info('\tValid Loss:' + str(valid_loss) +
296 | '\tValid RID Acc1:' + str(valid_id_acc1) +
297 | '\tValid RID Recall:' + str(valid_id_recall) +
298 | '\tValid RID Precision:' + str(valid_id_precision) +
299 | '\tValid Distance MAE Loss:' + str(valid_dis_mae_loss) +
300 | '\tValid Distance RMSE Loss:' + str(valid_dis_rmse_loss) +
301 | '\tValid Distance RN MAE Loss:' + str(valid_dis_rn_mae_loss) +
302 | '\tValid Distance RN RMSE Loss:' + str(valid_dis_rn_rmse_loss) +
303 | '\tValid Rate Loss:' + str(valid_rate_loss) +
304 | '\tValid RID Loss:' + str(valid_id_loss))
305 |
306 | torch.save(model.state_dict(), model_save_path + 'train-mid-model.pt')
307 | save_json_data(dict_train_loss, model_save_path, "train_loss.json")
308 | save_json_data(dict_valid_loss, model_save_path, "valid_loss.json")
309 |
310 | if args.test_flag:
311 | model.load_state_dict(torch.load(model_save_path + 'val-best-model.pt'))
312 | start_time = time.time()
313 | test_id_acc1, test_id_recall, test_id_precision, test_dis_mae_loss, test_dis_rmse_loss, \
314 | test_dis_rn_mae_loss, test_dis_rn_rmse_loss, test_rate_loss, test_id_loss = evaluate(model, test_iterator,
315 | rn_dict, grid_rn_dict, rn,
316 | raw2new_rid_dict,
317 | online_features_dict,
318 | rid_features_dict,
319 | raw_rn_dict, new2raw_rid_dict,
320 | args)
321 | end_time = time.time()
322 | epoch_mins, epoch_secs = epoch_time(start_time, end_time)
323 | logging.info('Test Time: ' + str(epoch_mins) + 'm' + str(epoch_secs) + 's')
324 | logging.info('\tTest RID Acc1:' + str(test_id_acc1) +
325 | '\tTest RID Recall:' + str(test_id_recall) +
326 | '\tTest RID Precision:' + str(test_id_precision) +
327 | '\tTest Distance MAE Loss:' + str(test_dis_mae_loss) +
328 | '\tTest Distance RMSE Loss:' + str(test_dis_rmse_loss) +
329 | '\tTest Distance RN MAE Loss:' + str(test_dis_rn_mae_loss) +
330 | '\tTest Distance RN RMSE Loss:' + str(test_dis_rn_rmse_loss) +
331 | '\tTest Rate Loss:' + str(test_rate_loss) +
332 | '\tTest RID Loss:' + str(test_id_loss))
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/3 13:34
--------------------------------------------------------------------------------
/utils/coord_transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import json
3 | import urllib
4 | import math
5 |
6 | x_pi = 3.14159265358979324 * 3000.0 / 180.0
7 | pi = 3.1415926535897932384626 # π
8 | a = 6378245.0  # semi-major axis of the reference ellipsoid (metres)
9 | ee = 0.00669342162296594323  # squared eccentricity of the ellipsoid
10 |
11 |
12 | # class Geocoding:
13 | # def __init__(self, api_key):
14 | # self.api_key = api_key
15 | #
16 | # def geocode(self, address):
17 | # """
18 | #         Resolve an address into coordinates via the AMap (Gaode) geocoding service
19 | #         :param address: the address to resolve
20 | # :return:
21 | # """
22 | # geocoding = {'s': 'rsv3',
23 | # 'key': self.api_key,
24 | # 'city': '全国',
25 | # 'address': address}
26 | # geocoding = urllib.urlencode(geocoding)
27 | # ret = urllib.urlopen("%s?%s" % ("http://restapi.amap.com/v3/geocode/geo", geocoding))
28 | #
29 | # if ret.getcode() == 200:
30 | # res = ret.read()
31 | # json_obj = json.loads(res)
32 | # if json_obj['status'] == '1' and int(json_obj['count']) >= 1:
33 | # geocodes = json_obj['geocodes'][0]
34 | # lng = float(geocodes.get('location').split(',')[0])
35 | # lat = float(geocodes.get('location').split(',')[1])
36 | # return [lng, lat]
37 | # else:
38 | # return None
39 | # else:
40 | # return None
41 |
42 | class Convert():
43 | def __init__(self):
44 | pass
45 |
46 | def convert(self, lng, lat):
47 | return lng, lat
48 |
49 |
50 | class GCJ02ToWGS84(Convert):
51 | def __init__(self):
52 | super().__init__()
53 |
54 | def convert(self, lng, lat):
55 | """
56 |         Convert GCJ-02 ("Mars coordinates") to WGS-84
57 |         :param lng: longitude in GCJ-02
58 |         :param lat: latitude in GCJ-02
59 | :return:
60 | """
61 | if out_of_china(lng, lat):
62 | return [lng, lat]
63 | dlat = _transformlat(lng - 105.0, lat - 35.0)
64 | dlng = _transformlng(lng - 105.0, lat - 35.0)
65 | radlat = lat / 180.0 * pi
66 | magic = math.sin(radlat)
67 | magic = 1 - ee * magic * magic
68 | sqrtmagic = math.sqrt(magic)
69 | dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
70 | dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
71 | mglat = lat + dlat
72 | mglng = lng + dlng
73 | return lng * 2 - mglng, lat * 2 - mglat
74 |
75 |
76 | class WGS84ToGCJ02(Convert):
77 | def __init__(self):
78 | super().__init__()
79 |
80 | def convert(self, lng, lat):
81 | """
82 |         Convert WGS-84 to GCJ-02 ("Mars coordinates")
83 |         :param lng: longitude in WGS-84
84 |         :param lat: latitude in WGS-84
85 | :return:
86 | """
87 | if out_of_china(lng, lat):  # no offset is applied outside China
88 | return [lng, lat]
89 | dlat = _transformlat(lng - 105.0, lat - 35.0)
90 | dlng = _transformlng(lng - 105.0, lat - 35.0)
91 | radlat = lat / 180.0 * pi
92 | magic = math.sin(radlat)
93 | magic = 1 - ee * magic * magic
94 | sqrtmagic = math.sqrt(magic)
95 | dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
96 | dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
97 | mglat = lat + dlat
98 | mglng = lng + dlng
99 | return mglng, mglat
100 |
101 |
102 | # def gcj02_to_bd09(lng, lat):
103 | # """
104 | # Convert GCJ-02 (Mars coordinates) to BD-09 (Baidu coordinates)
105 | # Google / AMap --> Baidu
106 | # :param lng: longitude in GCJ-02
107 | # :param lat: latitude in GCJ-02
108 | # :return:
109 | # """
110 | # z = math.sqrt(lng * lng + lat * lat) + 0.00002 * math.sin(lat * x_pi)
111 | # theta = math.atan2(lat, lng) + 0.000003 * math.cos(lng * x_pi)
112 | # bd_lng = z * math.cos(theta) + 0.0065
113 | # bd_lat = z * math.sin(theta) + 0.006
114 | # return [bd_lng, bd_lat]
115 | #
116 | #
117 | # def bd09_to_gcj02(bd_lon, bd_lat):
118 | # """
119 | # Convert BD-09 (Baidu coordinates) to GCJ-02 (Mars coordinates)
120 | # Baidu --> Google / AMap
121 | # :param bd_lat: latitude in BD-09
122 | # :param bd_lon: longitude in BD-09
123 | # :return: the converted coordinates as a list
124 | # """
125 | # x = bd_lon - 0.0065
126 | # y = bd_lat - 0.006
127 | # z = math.sqrt(x * x + y * y) - 0.00002 * math.sin(y * x_pi)
128 | # theta = math.atan2(y, x) - 0.000003 * math.cos(x * x_pi)
129 | # gg_lng = z * math.cos(theta)
130 | # gg_lat = z * math.sin(theta)
131 | # return [gg_lng, gg_lat]
132 | #
133 | #
134 | # def wgs84_to_gcj02(lng, lat):
135 | # """
136 | # Convert WGS-84 to GCJ-02 (Mars coordinates)
137 | # :param lng: longitude in WGS-84
138 | # :param lat: latitude in WGS-84
139 | # :return:
140 | # """
141 | # if out_of_china(lng, lat):  # no offset is applied outside China
142 | # return [lng, lat]
143 | # dlat = _transformlat(lng - 105.0, lat - 35.0)
144 | # dlng = _transformlng(lng - 105.0, lat - 35.0)
145 | # radlat = lat / 180.0 * pi
146 | # magic = math.sin(radlat)
147 | # magic = 1 - ee * magic * magic
148 | # sqrtmagic = math.sqrt(magic)
149 | # dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
150 | # dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
151 | # mglat = lat + dlat
152 | # mglng = lng + dlng
153 | # return [mglng, mglat]
154 | #
155 | #
156 | # def gcj02_to_wgs84(lng, lat):
157 | # """
158 | # Convert GCJ-02 (Mars coordinates) to WGS-84
159 | # :param lng: longitude in GCJ-02
160 | # :param lat: latitude in GCJ-02
161 | # :return:
162 | # """
163 | # if out_of_china(lng, lat):
164 | # return [lng, lat]
165 | # dlat = _transformlat(lng - 105.0, lat - 35.0)
166 | # dlng = _transformlng(lng - 105.0, lat - 35.0)
167 | # radlat = lat / 180.0 * pi
168 | # magic = math.sin(radlat)
169 | # magic = 1 - ee * magic * magic
170 | # sqrtmagic = math.sqrt(magic)
171 | # dlat = (dlat * 180.0) / ((a * (1 - ee)) / (magic * sqrtmagic) * pi)
172 | # dlng = (dlng * 180.0) / (a / sqrtmagic * math.cos(radlat) * pi)
173 | # mglat = lat + dlat
174 | # mglng = lng + dlng
175 | # return [lng * 2 - mglng, lat * 2 - mglat]
176 | #
177 | #
178 | # def bd09_to_wgs84(bd_lon, bd_lat):
179 | # lon, lat = bd09_to_gcj02(bd_lon, bd_lat)
180 | # return gcj02_to_wgs84(lon, lat)
181 | #
182 | #
183 | # def wgs84_to_bd09(lon, lat):
184 | # lon, lat = wgs84_to_gcj02(lon, lat)
185 | # return gcj02_to_bd09(lon, lat)
186 | #
187 | #
188 | def _transformlat(lng, lat):
189 | ret = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \
190 | 0.1 * lng * lat + 0.2 * math.sqrt(math.fabs(lng))
191 | ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
192 | math.sin(2.0 * lng * pi)) * 2.0 / 3.0
193 | ret += (20.0 * math.sin(lat * pi) + 40.0 *
194 | math.sin(lat / 3.0 * pi)) * 2.0 / 3.0
195 | ret += (160.0 * math.sin(lat / 12.0 * pi) + 320 *
196 | math.sin(lat * pi / 30.0)) * 2.0 / 3.0
197 | return ret
198 |
199 |
200 | def _transformlng(lng, lat):
201 | ret = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \
202 | 0.1 * lng * lat + 0.1 * math.sqrt(math.fabs(lng))
203 | ret += (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
204 | math.sin(2.0 * lng * pi)) * 2.0 / 3.0
205 | ret += (20.0 * math.sin(lng * pi) + 40.0 *
206 | math.sin(lng / 3.0 * pi)) * 2.0 / 3.0
207 | ret += (150.0 * math.sin(lng / 12.0 * pi) + 300.0 *
208 | math.sin(lng / 30.0 * pi)) * 2.0 / 3.0
209 | return ret
210 |
211 |
212 | def out_of_china(lng, lat):
213 | """
214 | Return True if the point lies outside China (in which case no offset is applied)
215 | :param lng:
216 | :param lat:
217 | :return:
218 | """
219 | return not (lng > 73.66 and lng < 135.05 and lat > 3.86 and lat < 53.55)
220 |
221 |
222 | # if __name__ == '__main__':
223 | # lng, lat = 116.527559, 39.807378
224 | # min_lat = 39.727
225 | # min_lng = 116.490
226 | # max_lat = 39.83
227 | # max_lng = 116.588
228 | # # lng = 109.642194
229 | # # lat = 20.123355
230 | # result1 = gcj02_to_bd09(lng, lat)
231 | # result2 = bd09_to_gcj02(lng, lat)
232 | # result3 = wgs84_to_gcj02(lng, lat)
233 | # result4 = gcj02_to_wgs84(lng, lat)
234 | # result5 = bd09_to_wgs84(lng, lat)
235 | # result6 = wgs84_to_bd09(lng, lat)
236 | # print(gcj02_to_wgs84(min_lng, min_lat))
237 | # print(gcj02_to_wgs84(max_lng, max_lat))
238 |
--------------------------------------------------------------------------------
/utils/imputation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/22 16:29
4 | # @authoer: Huimin Ren
5 | from common.spatial_func import distance, cal_loc_along_line
6 | from common.trajectory import Trajectory, STPoint
7 | import numpy as np
8 |
9 | from map_matching.candidate_point import CandidatePoint
10 | from map_matching.hmm.hmm_map_matcher import TIHMMMapMatcher
11 |
12 | class Imputation:
13 | def __init__(self):
14 | pass
15 |
16 | def impute(self, traj):
17 | pass
18 |
19 |
20 | class LinearImputation(Imputation):
21 | """
22 | Uniformly interpolate GPS points into a trajectory.
23 | Args:
24 | -----
25 | time_interval:
26 | int. unit in seconds. the time interval between two points.
27 | """
28 | def __init__(self, time_interval):
29 | super(LinearImputation, self).__init__()
30 | self.time_interval = time_interval
31 |
32 | def impute(self, traj):
33 | pt_list = traj.pt_list
34 | if len(pt_list) <= 1:
35 | return []
36 |
37 | pre_pt = pt_list[0]
38 | new_pt_list = [pre_pt]
39 | for cur_pt in pt_list[1:]:
40 | time_span = (cur_pt.time - pre_pt.time).total_seconds()
41 | num_pts = 0 # do not have to interpolate points
42 | if time_span % self.time_interval > self.time_interval / 2:
43 | # the remainder is larger than half of the time interval: round the count up
44 | num_pts = time_span // self.time_interval # quotient
45 | elif time_span > self.time_interval:
46 | # the remainder is at most half of the interval (and the span exceeds one interval): round the count down
47 | num_pts = time_span // self.time_interval - 1
48 |
49 | interp_list = []
50 | unit_lat = abs(cur_pt.lat - pre_pt.lat) / (num_pts + 1)
51 | unit_lng = abs(cur_pt.lng - pre_pt.lng) / (num_pts + 1)
52 | unit_ts = (cur_pt.time - pre_pt.time) / (num_pts + 1)
53 | sign_lat = np.sign(cur_pt.lat - pre_pt.lat)
54 | sign_lng = np.sign(cur_pt.lng - pre_pt.lng)
55 | for i in range(int(num_pts)):
56 | new_lat = pre_pt.lat + sign_lat * (i+1) * unit_lat
57 | new_lng = pre_pt.lng + sign_lng * (i+1) * unit_lng
58 | new_time = pre_pt.time + (i+1) * unit_ts
59 | interp_list.append(STPoint(new_lat, new_lng, new_time))
60 |
61 | new_pt_list.extend(interp_list)
62 | new_pt_list.append(cur_pt)
63 | pre_pt = cur_pt
64 |
65 | new_traj = Trajectory(traj.oid, traj.tid, new_pt_list)
66 |
67 | return new_traj
68 |
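# --- Illustrative sketch (added for exposition; not in the original file) ---
# How LinearImputation decides the number of points to insert between two fixes,
# for time_interval = 15 s:
#
#   time_span = 40 s: 40 % 15 = 10 > 7.5            -> num_pts = 40 // 15 = 2
#                     (three 13.3 s gaps, the spacing closest to 15 s)
#   time_span = 65 s: 65 % 15 = 5 <= 7.5 and 65 > 15 -> num_pts = 65 // 15 - 1 = 3
#                     (four 16.25 s gaps)
#
# i.e. the remainder decides whether to round the count up or down so the
# resulting spacing stays as close to time_interval as possible.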
69 |
70 | class MMLinearImputation(Imputation):
71 |
72 | def __init__(self, time_interval):
73 | super(MMLinearImputation, self).__init__()
74 | self.time_interval = time_interval
75 |
76 | def impute(self, traj, rn, rn_dict):
77 | try:
78 | map_matcher = TIHMMMapMatcher(rn)
79 | mm_ls_path = map_matcher.match_to_path(traj)[0] # find shortest path
80 | except Exception:
81 | # cannot find shortest path
82 | return None
83 |
84 | path_eids = [p.eid for p in mm_ls_path.path_entities]
85 |
86 | pre_mm_pt = traj.pt_list[0]
87 | new_pt_list = [pre_mm_pt]
88 | for cur_mm_pt in traj.pt_list[1:]:
89 | time_span = (cur_mm_pt.time - pre_mm_pt.time).total_seconds()
90 | num_pts = 0 # do not have to interpolate points
91 | if time_span % self.time_interval > self.time_interval / 2:
92 | # the remainder is larger than half of the time interval: round the count up
93 | num_pts = time_span // self.time_interval # quotient
94 | elif time_span > self.time_interval:
95 | # the remainder is at most half of the interval (and the span exceeds one interval): round the count down
96 | num_pts = time_span // self.time_interval - 1
97 |
98 | if pre_mm_pt.data['candi_pt'] is None or cur_mm_pt.data['candi_pt'] is None:
99 | return None
100 |
101 | pre_eid = pre_mm_pt.data['candi_pt'].eid
102 | cur_eid = cur_mm_pt.data['candi_pt'].eid
103 | two_points_coords, two_points_eids, ttl_dis = self.get_two_points_coords(path_eids, pre_eid, cur_eid,
104 | pre_mm_pt, cur_mm_pt, rn_dict)
105 | interp_list = self.get_interp_list(num_pts, cur_mm_pt, pre_mm_pt, ttl_dis,
106 | two_points_eids, two_points_coords, rn_dict)
107 |
108 | new_pt_list.extend(interp_list)
109 | new_pt_list.append(cur_mm_pt)
110 | pre_mm_pt = cur_mm_pt
111 | new_traj = Trajectory(traj.oid, traj.tid, new_pt_list)
112 |
113 | # get all coords of shortest path
114 | # path_coords = []
115 | # for eid in path_eids:
116 | # path_coords.extend(rn_dict[eid]['coords'])
117 | # path_pt_list = []
118 | # for pt in path_coords:
119 | # path_pt_list.append([pt.lat, pt.lng])
120 |
121 | return new_traj
122 |
123 | def get_interp_list(self, num_pts, cur_mm_pt, pre_mm_pt, ttl_dis, two_points_eids, two_points_coords, rn_dict):
124 | interp_list = []
125 | unit_ts = (cur_mm_pt.time - pre_mm_pt.time) / (num_pts + 1)
126 | for n in range(int(num_pts)):
127 | new_time = pre_mm_pt.time + (n + 1) * unit_ts
128 | move_dis = (ttl_dis / num_pts) * n + pre_mm_pt.data['candi_pt'].offset
129 |
130 | # get eid and offset
131 | pre_road_dist, road_dist = 0, 0
132 | for i in range(len(two_points_eids)):
133 | if i > 0:
134 | pre_road_dist += rn_dict[two_points_eids[i - 1]]['length']
135 | road_dist += rn_dict[two_points_eids[i]]['length']
136 | if move_dis <= road_dist:
137 | insert_eid = two_points_eids[i]
138 | insert_offset = move_dis - pre_road_dist
139 | break
140 |
141 | # get lat and lng
142 | dist, pre_dist = 0, 0
143 | for i in range(len(two_points_coords) - 1):
144 | if i > 0:
145 | pre_dist += distance(two_points_coords[i - 1][0], two_points_coords[i][0])
146 | dist += distance(two_points_coords[i][0], two_points_coords[i + 1][0])
147 | if dist >= move_dis:
148 | coor_rate = (move_dis - pre_dist) / distance(two_points_coords[i][0],
149 | two_points_coords[i + 1][0])
150 | project_pt = cal_loc_along_line(two_points_coords[i][0], two_points_coords[i + 1][0], coor_rate)
151 | break
152 | data = {'candi_pt': CandidatePoint(project_pt.lat, project_pt.lng, insert_eid, 0, insert_offset, 0)}
153 | interp_list.append(STPoint(project_pt.lat, project_pt.lng, new_time, data))
154 |
155 | return interp_list
156 |
157 | def get_two_points_coords(self, path_eids, pre_eid, cur_eid, pre_mm_pt, cur_mm_pt, rn_dict):
158 | if pre_eid == cur_eid:
159 | # if in the same road
160 | two_points_eids = [path_eids[path_eids.index(pre_eid)]]
161 | two_points_coords = [[item, two_points_eids[0]] for item in rn_dict[two_points_eids[0]]['coords']]
162 | ttl_dis = cur_mm_pt.data['candi_pt'].offset - pre_mm_pt.data['candi_pt'].offset
163 | else:
164 | # if in different road
165 | start = path_eids.index(pre_eid)
166 | end = path_eids.index(cur_eid) + 1
167 | if start >= end:
168 | end = path_eids.index(cur_eid, start) # cur_eid shows at least twice
169 | two_points_eids = path_eids[start: end]
170 |
171 | two_points_coords = [] # 2D, [gps, eid]
172 | ttl_eids_dis = 0
173 | for eid in two_points_eids[:-1]:
174 | tmp_coords = [[item, eid] for item in rn_dict[eid]['coords']]
175 | two_points_coords.extend(tmp_coords)
176 | ttl_eids_dis += rn_dict[eid]['length']
177 | ttl_dis = ttl_eids_dis - pre_mm_pt.data['candi_pt'].offset + cur_mm_pt.data['candi_pt'].offset
178 | tmp_coords = [[item, two_points_eids[-1]] for item in rn_dict[two_points_eids[-1]]['coords']]
179 | two_points_coords.extend(tmp_coords) # add coords after computing ttl_dis
180 |
181 | return two_points_coords, two_points_eids, ttl_dis
182 |
--------------------------------------------------------------------------------
/utils/noise_filter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/7 17:07
4 | # Add new method
5 | from statistics import median
6 |
7 | from common.spatial_func import distance
8 | from common.trajectory import Trajectory, STPoint
9 |
10 | """
11 | All the methods are adapted from
12 | Zheng, Y. and Zhou, X. eds., 2011.
13 | Computing with spatial trajectories. Springer Science & Business Media.
14 | Chapter 1, Trajectory Preprocessing
15 |
16 | """
17 |
18 |
19 | class NoiseFilter:
20 | def filter(self, traj):
21 | pass
22 |
23 | def get_tid(self, oid, clean_pt_list):
24 | return oid + '_' + clean_pt_list[0].time.strftime('%Y%m%d%H%M') + '_' + \
25 | clean_pt_list[-1].time.strftime('%Y%m%d%H%M')
26 |
27 |
28 | # modify point
29 | class MedianFilter(NoiseFilter):
30 | """
31 | Smooth each point with median value, since mean filter is sensitive to outliers.
32 | x_i = median(z_{i-n+1}, z_{i-n+2}, ..., z_{i-1}, z_{i}) eq.1.4
33 | i: ith point
34 | z: original points
35 | n: window size
36 | """
37 |
38 | def __init__(self, win_size=3):
39 | super(MedianFilter, self).__init__()
40 | self.win_size = win_size // 2 # put the target value in middle
41 |
42 | def filter(self, traj):
43 | pt_list = traj.pt_list.copy()
44 | if len(pt_list) <= 1:
45 | return None
46 |
47 | for i in range(1, len(pt_list)-1):
48 | # shrink the window near the boundaries so that point i stays in the middle
49 | wind_size = self.win_size
50 | lats, lngs = [], []
51 | if self.win_size > i:
52 | wind_size = i
53 | elif self.win_size > len(pt_list) - 1 - i:
54 | wind_size = len(pt_list) - 1 - i
55 | # get smoothed location
56 | for pt in pt_list[i-wind_size:i+wind_size+1]:
57 | lats.append(pt.lat)
58 | lngs.append(pt.lng)
59 |
60 | pt_list[i] = STPoint(median(lats), median(lngs), pt_list[i].time)
61 |
62 | if len(pt_list) > 1:
63 | # return Trajectory(traj.oid, self.get_tid(traj.oid, pt_list), pt_list)
64 | return Trajectory(traj.oid, traj.tid, pt_list)
65 | else:
66 | return None
67 |
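# --- Illustrative sketch (added for exposition; not in the original file) ---
# With win_size=3 the effective half-window is 1, so each interior point is
# replaced by the median of itself and its two neighbours. A single GPS spike:
#
#   lats = [36.6601, 36.6602, 36.7900, 36.6604, 36.6605]
#                              ^ outlier
#   median(36.6602, 36.7900, 36.6604) = 36.6604  -> the spike is flattened,
#   while the surrounding points are barely changed.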
68 |
69 | # modify point
70 | class HeuristicMeanFilter(NoiseFilter):
71 | """
72 | Find outlier by speed (if the current speed is out of max speed)
73 | Replace outlier with mean
74 | Mean filter usually handles individual noise points with a dense representation.
75 | """
76 |
77 | def __init__(self, max_speed, win_size=1):
78 | """
79 | Args:
80 | ----
81 | max_speed:
82 | int. m/s. threshold of noise speed
83 | win_size:
84 | int. prefer odd number. window size of calculating mean value to replace noise.
85 | """
86 | super(HeuristicMeanFilter, self).__init__()
87 |
88 | self.max_speed = max_speed
89 | self.win_size = win_size
90 |
91 | def filter(self, traj):
92 | """
93 | When both the previous and the next speed exceed max_speed, the point is
94 | considered an outlier and replaced with the mean value over a window defined by win_size.
95 | Boundary cases are handled explicitly,
96 | so that the noise point stays in the middle of the window.
97 |
98 | Args:
99 | -----
100 | traj:
101 | Trajectory(). a single trajectory
102 | Returns:
103 | --------
104 | new_traj:
105 | Trajectory(). replace noise with mean or median
106 | """
107 | pt_list = traj.pt_list.copy()
108 | if len(pt_list) <= 1:
109 | return None
110 | for i in range(1, len(pt_list) - 1):
111 | time_span_pre = (pt_list[i].time - pt_list[i - 1].time).total_seconds()
112 | dist_pre = distance(pt_list[i - 1], pt_list[i])
113 | time_span_next = (pt_list[i + 1].time - pt_list[i].time).total_seconds()
114 | dist_next = distance(pt_list[i], pt_list[i + 1])
115 | # compute current speed
116 | speed_pre = dist_pre / time_span_pre
117 | speed_next = dist_next / time_span_next
118 | # if the first point is noise
119 | if i == 1 and speed_pre > self.max_speed > speed_next:
120 | lat = pt_list[i].lat * 2 - pt_list[i + 1].lat  # linearly extrapolate point 0 from points i and i+1
121 | lng = pt_list[i].lng * 2 - pt_list[i + 1].lng
122 | pt_list[0] = STPoint(lat, lng, pt_list[0].time)
123 | # if the last point is noise
124 | elif i == len(pt_list) - 2 and speed_next > self.max_speed >= speed_pre:
125 | lat = pt_list[i].lat * 2 - pt_list[i - 1].lat  # linearly extrapolate the last point from points i-1 and i
126 | lng = pt_list[i].lng * 2 - pt_list[i - 1].lng
127 | pt_list[i + 1] = STPoint(lat, lng, pt_list[i + 1].time)
128 | # if the middle point is noise
129 | elif speed_pre > self.max_speed and speed_next > self.max_speed:
130 | pt_list[i] = STPoint(0, 0, pt_list[i].time)  # zero the noise point so it adds nothing to the window sums below
131 | lats, lngs = [], []
132 | # shrink the window near the boundaries so that point i stays in the middle
133 | wind_size = self.win_size
134 | if self.win_size > i:
135 | wind_size = i
136 | elif self.win_size > len(pt_list) - 1 - i:
137 | wind_size = len(pt_list) - 1 - i
138 | for pt in pt_list[i-wind_size:i+wind_size+1]:
139 | lats.append(pt.lat)
140 | lngs.append(pt.lng)
141 |
142 | lat = sum(lats) / (len(lats) - 1)  # the zeroed noise point contributes 0 to the sum, hence len - 1
143 | lng = sum(lngs) / (len(lngs) - 1)
144 | pt_list[i] = STPoint(lat, lng, pt_list[i].time)
145 |
146 | if len(pt_list) > 1:
147 | # return Trajectory(traj.oid, self.get_tid(traj.oid, pt_list), pt_list)
148 | return Trajectory(traj.oid, traj.tid, pt_list)
149 | else:
150 | return None
151 |
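The divide-by-`len(lats) - 1` step only works because the outlier was zeroed first; a minimal numeric sketch (coordinates are made up):

```python
# The noise point was replaced by (0, 0), so it adds nothing to the sums;
# dividing by len - 1 therefore averages only the genuine neighbors.
lats = [39.91, 0.0, 39.93]  # middle value is the zeroed noise point
print(sum(lats) / (len(lats) - 1))  # 39.92, the mean of the two neighbors
```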
152 |
153 | # remove point
154 | class HeuristicFilter(NoiseFilter):
155 | """
156 | Remove a point if the speeds implied by both of its neighbors exceed max_speed
157 | """
158 |
159 | def __init__(self, max_speed):
160 | super(HeuristicFilter, self).__init__()
161 | self.max_speed = max_speed
162 |
163 | def filter(self, traj):
164 | pt_list = traj.pt_list
165 | if len(pt_list) <= 1:
166 | return None
167 |
168 | remove_inds = []
169 | for i in range(1, len(pt_list) - 1):
170 | time_span_pre = (pt_list[i].time - pt_list[i - 1].time).total_seconds()
171 | dist_pre = distance(pt_list[i - 1], pt_list[i])
172 | time_span_next = (pt_list[i + 1].time - pt_list[i].time).total_seconds()
173 | dist_next = distance(pt_list[i], pt_list[i + 1])
174 | speed_pre = dist_pre / time_span_pre
175 | speed_next = dist_next / time_span_next
176 | # the first point is outlier
177 | if i == 1 and speed_pre > self.max_speed > speed_next:
178 | remove_inds.append(0)
179 | # the last point is outlier
180 | elif i == len(pt_list) - 2 and speed_next > self.max_speed >= speed_pre:
181 | remove_inds.append(len(pt_list) - 1)
182 | # middle point is outlier
183 | elif speed_pre > self.max_speed and speed_next > self.max_speed:
184 | remove_inds.append(i)
185 |
186 | clean_pt_list = []
187 | for j in range(len(pt_list)):
188 | if j in remove_inds:
189 | continue
190 | clean_pt_list.append(pt_list[j])
191 |
192 | if len(clean_pt_list) > 1:
193 | # return Trajectory(traj.oid, self.get_tid(traj.oid, pt_list), pt_list)
194 | return Trajectory(traj.oid, traj.tid, clean_pt_list)
195 | else:
196 | return None
197 |
198 |
199 | # remove point
200 | class STFilter(NoiseFilter):
201 | """
202 | Remove points that fall outside the given MBR or the [start_time, end_time) window
203 | """
204 | def __init__(self, mbr, start_time, end_time):
205 | super(STFilter, self).__init__()
206 | self.mbr = mbr
207 | self.start_time = start_time
208 | self.end_time = end_time
209 |
210 | def filter(self, traj):
211 | pt_list = traj.pt_list
212 | if len(pt_list) <= 1:
213 | return None
214 | clean_pt_list = []
215 | for pt in pt_list:
216 | if self.start_time <= pt.time < self.end_time and self.mbr.contains(pt.lat, pt.lng):
217 | clean_pt_list.append(pt)
218 | if len(clean_pt_list) > 1:
219 | # return Trajectory(traj.oid, self.get_tid(traj.oid, pt_list), pt_list)
220 | return Trajectory(traj.oid, traj.tid, clean_pt_list)
221 | else:
222 | return None
223 |
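A minimal usage sketch of these filters, assuming `trajs` is a list of `Trajectory()` objects (e.g. produced by `ParseRawTraj` below); the speed threshold and window size are illustrative:

```python
from utils.noise_filter import HeuristicFilter, HeuristicMeanFilter

remove_filter = HeuristicFilter(max_speed=35)                  # 35 m/s, illustrative
smooth_filter = HeuristicMeanFilter(max_speed=35, win_size=3)

clean_trajs = []
for traj in trajs:
    filtered = remove_filter.filter(traj)      # returns None if too few points remain
    if filtered is not None:
        smoothed = smooth_filter.filter(filtered)
        if smoothed is not None:
            clean_trajs.append(smoothed)
```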
--------------------------------------------------------------------------------
/utils/parse_traj.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/23 15:40
4 | # Reference: https://github.com/huiminren/tptk/blob/master/common/trajectory.py
5 |
6 | import re
7 | from datetime import datetime
8 | import pandas as pd
9 |
10 | from common.trajectory import Trajectory, STPoint
11 | from map_matching.candidate_point import CandidatePoint
12 |
13 |
14 | class ParseTraj:
15 | """
16 | ParseTraj is an abstract class for parsing trajectory.
17 | It defines parse() function for parsing trajectory.
18 | """
19 | def __init__(self):
20 | pass
21 |
22 | def parse(self, input_path):
23 | """
24 | The parse() function is to load data to a list of Trajectory()
25 | """
26 | pass
27 |
28 |
29 | class ParseRawTraj(ParseTraj):
30 | """
31 | Parse original GPS points to trajectories list. No extra data preprocessing
32 | """
33 | def __init__(self):
34 | super().__init__()
35 |
36 | def parse(self, input_path):
37 | """
38 | Args:
39 | -----
40 | input_path:
41 | str. path to the input file
42 | Returns:
43 | --------
44 | trajs:
45 | list. list of trajectories. trajs contain input_path file's all gps points
46 | """
47 | time_format = '%Y/%m/%d %H:%M:%S'
49 | with open(input_path, 'r') as f:
50 | trajs = []
51 | pt_list = []
52 | for line in f.readlines():
53 | attrs = line.rstrip().split(',')
54 | if attrs[0] == '#':
55 | if len(pt_list) > 1:
56 | traj = Trajectory(oid, tid, pt_list)
57 | trajs.append(traj)
58 | oid = attrs[2]
59 | tid = attrs[1]
60 | pt_list = []
61 | else:
62 | lat = float(attrs[1])
63 | lng = float(attrs[2])
64 | pt = STPoint(lat, lng, datetime.strptime(attrs[0], time_format))
65 | # pt contains all the attributes of class STPoint
66 | pt_list.append(pt)
67 | if len(pt_list) > 1:
68 | traj = Trajectory(oid, tid, pt_list)
69 | trajs.append(traj)
70 | return trajs
71 |
72 |
73 | class ParseMMTraj(ParseTraj):
74 | """
75 | Parse map matched GPS points to trajectories list. No extra data preprocessing
76 | """
77 | def __init__(self):
78 | super().__init__()
79 |
80 | def parse(self, input_path):
81 | """
82 | Args:
83 | -----
84 | input_path:
85 | str. path to the input file
86 | Returns:
87 | --------
88 | trajs:
89 | list. list of trajectories. trajs contain input_path file's all gps points
90 | """
91 | time_format = '%Y/%m/%d %H:%M:%S'
93 | with open(input_path, 'r') as f:
94 | trajs = []
95 | pt_list = []
96 | for line in f.readlines():
97 | attrs = line.rstrip().split(',')
98 | if attrs[0] == '#':
99 | if len(pt_list) > 1:
100 | traj = Trajectory(oid, tid, pt_list)
101 | trajs.append(traj)
102 | oid = attrs[2]
103 | tid = attrs[1]
104 | pt_list = []
105 | else:
106 | lat = float(attrs[1])
107 | lng = float(attrs[2])
108 | if attrs[3] == 'None':
109 | candi_pt = None
110 | else:
111 | eid = int(attrs[3])
112 | proj_lat = float(attrs[4])
113 | proj_lng = float(attrs[5])
114 | error = float(attrs[6])
115 | offset = float(attrs[7])
116 | rate = float(attrs[8])
117 | candi_pt = CandidatePoint(proj_lat, proj_lng, eid, error, offset, rate)
118 | pt = STPoint(lat, lng, datetime.strptime(attrs[0], time_format), {'candi_pt': candi_pt})
119 | # pt contains all the attributes of class STPoint
120 | pt_list.append(pt)
121 | if len(pt_list) > 1:
122 | traj = Trajectory(oid, tid, pt_list)
123 | trajs.append(traj)
124 | return trajs
125 |
126 |
127 | class ParseJUSTInputTraj(ParseTraj):
128 | """
129 | Parse JUST input format to list of Trajectory()
130 | """
131 | def __init__(self):
132 | super().__init__()
133 |
134 | def parse(self, input_path):
135 | time_format = '%Y-%m-%d %H:%M:%S'
136 | with open(input_path, 'r') as f:
137 | trajs = []
138 | pt_list = []
139 | pre_tid = ''
140 | for line in f.readlines():
141 | attrs = line.rstrip().split(',')
142 | tid = attrs[0]
143 | oid = attrs[1]
144 | time = datetime.strptime(attrs[2][:19], time_format)
145 | lat = float(attrs[3])
146 | lng = float(attrs[4])
147 | pt = STPoint(lat, lng, time)
148 | if pre_tid != tid:
149 | if len(pt_list) > 1:
150 | traj = Trajectory(oid, pre_tid, pt_list)
151 | trajs.append(traj)
152 | pt_list = []
153 | pt_list.append(pt)
154 | pre_tid = tid
155 | if len(pt_list) > 1:
156 | traj = Trajectory(oid, tid, pt_list)
157 | trajs.append(traj)
158 |
159 | return trajs
160 |
161 |
162 | class ParseJUSTOutputTraj(ParseTraj):
163 | """
164 | Parse JUST output to trajectories list. The output format will be the same as Trajectory()
165 | """
166 | def __init__(self):
167 | super().__init__()
168 |
169 | def parse(self, input_path, feature_flag=False):
170 | """
171 | Args:
172 | -----
173 | input_path:
174 | str. path to the input file
175 | Returns:
176 | --------
177 | trajs:
178 | list of Trajectory()
179 |
180 | 'oid': object_id
181 | 'geom': line string of raw trajectory
182 | 'time': trajectory start time
183 | 'tid': trajectory id
184 | 'time_series': line string of map matched trajectory and containing other features
185 | raw start time, raw lng, raw lat, road segment ID, index of road segment,
186 | distance from the raw point to the map-matched point, distance between the projected point and the start of the road segment.
187 | 'start_position': raw start position
188 | 'end_position': raw end position
189 | 'point_number': number of points in the trajectory
190 | 'length': distance of the trajectory in km
191 | 'speed': average speed of the trajectory in km/h
192 | 'signature': signature for GIS
193 | 'id': primary key
194 | """
195 | col_names = ['oid', 'geom', 'tid', 'time_series']
196 | df = pd.read_csv(input_path, sep='|', usecols=col_names)
197 |
198 | str_to_remove = '[LINESTRING()]'
199 | time_format = '%Y-%m-%d %H:%M:%S'
200 | trajs = []
201 | for i in range(len(df)):
202 | oid = df['oid'][i]
203 | tid = str(df['tid'][i])
204 | # prepare to get map matched lat, lng and original datetime for each GPS point
205 | time_series = df['time_series'][i]
206 | geom = re.sub(str_to_remove, "", df['geom'][i]) # load geom and remove "LINESTRING()"
207 | ts_list = time_series.split(';') # contain datetime of each gps point and original gps location
208 | geom_list = geom.split(',') # contain map matched gps points
209 | assert len(ts_list) == len(geom_list)
210 |
211 | pt_list = []
212 | for j in range(len(ts_list)):
213 | tmp_location = geom_list[j].split(" ")
214 | tmp_features = ts_list[j].split(",")
215 | lat = float(tmp_location[2])
216 | lng = float(tmp_location[1])
217 | time = tmp_features[0][:19] # ts_list[j][:19]
218 |
219 | if feature_flag:
220 | rid = int(tmp_features[3])
221 | rdis = float(tmp_features[-1][:-1])
222 | pt = STPoint(lat, lng, datetime.strptime(time, time_format), {'rid': rid, 'rdis': rdis})
223 | else:
224 | pt = STPoint(lat, lng, datetime.strptime(time, time_format))
225 |
226 | pt_list.append(pt)
227 | traj = Trajectory(oid, tid, pt_list)
228 | trajs.append(traj)
229 |
230 | return trajs
231 |
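A minimal usage sketch; the raw path points to the sample data shipped under data/, while the map-matched path is hypothetical:

```python
from utils.parse_traj import ParseRawTraj, ParseMMTraj

raw_trajs = ParseRawTraj().parse('./data/raw_trajectory/09_01.txt')
print(len(raw_trajs), 'raw trajectories')

mm_trajs = ParseMMTraj().parse('./data/mm_trajectory/09_01.txt')  # hypothetical path
for pt in mm_trajs[0].pt_list:
    candi_pt = pt.data['candi_pt']            # CandidatePoint or None
    if candi_pt is not None:
        print(pt.time, candi_pt.eid, candi_pt.offset)
```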
--------------------------------------------------------------------------------
/utils/save_traj.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/23 15:40
4 | # Reference: https://github.com/huiminren/tptk/blob/master/common/trajectory.py
5 | from common.trajectory import get_tid
6 | from utils.coord_transform import GCJ02ToWGS84, WGS84ToGCJ02, Convert
7 |
8 |
9 | class SaveTraj:
10 | """
11 | SaveTraj is an abstract class for storing trajectory.
12 | It defines store() function for storing trajectory to different format.
13 | """
14 | def __init__(self, convert_method):
15 | # GCJ: Auto, Didi
16 | # WGS: OSM, Tiandi
17 | if convert_method == 'GCJ02ToWGS84':
18 | self.convert = GCJ02ToWGS84()
19 | elif convert_method == 'WGS84ToGCJ02':
20 | self.convert = WGS84ToGCJ02()
21 | else:  # None (or an unrecognized method): keep coordinates unchanged
22 | self.convert = Convert()
23 |
24 | def store(self, trajs, target_path):
25 | pass
26 |
27 |
28 | class SaveTraj2Raw(SaveTraj):
29 | def __init__(self, convert_method=None):
30 | super().__init__(convert_method)
31 |
32 | def store(self, trajs, target_path):
33 | time_format = '%Y/%m/%d %H:%M:%S'
34 | with open(target_path, 'w') as f:
35 | for traj in trajs:
36 | pt_list = traj.pt_list
37 | tid = get_tid(traj.oid, pt_list)
38 | f.write('#,{},{},{},{},{} km\n'.format(tid, traj.oid, pt_list[0].time.strftime(time_format),
39 | pt_list[-1].time.strftime(time_format),
40 | traj.get_distance() / 1000))
41 | for pt in pt_list:
42 | lng, lat = self.convert.convert(pt.lng, pt.lat)
43 | f.write('{},{},{}\n'.format(
44 | pt.time.strftime(time_format), lat, lng))
45 |
46 |
47 | class SaveTraj2MM(SaveTraj):
48 | """
49 | """
50 | def __init__(self, convert_method=None):
51 | super().__init__(convert_method)
52 |
53 | def store(self, trajs, target_path):
54 | time_format = '%Y/%m/%d %H:%M:%S'
55 | with open(target_path, 'w') as f:
56 | for traj in trajs:
57 | pt_list = traj.pt_list
58 | tid = get_tid(traj.oid, pt_list)
59 | f.write('#,{},{},{},{},{} km\n'.format(tid, traj.oid, pt_list[0].time.strftime(time_format),
60 | pt_list[-1].time.strftime(time_format),
61 | traj.get_distance() / 1000))
62 | for pt in pt_list:
63 | candi_pt = pt.data['candi_pt']
64 | if candi_pt is not None:
65 | f.write('{},{},{},{},{},{},{},{},{}\n'.format(pt.time.strftime(time_format), pt.lat, pt.lng,
66 | candi_pt.eid, candi_pt.lat, candi_pt.lng,
67 | candi_pt.error, candi_pt.offset, candi_pt.rate))
68 | else:
69 | f.write('{},{},{},None,None,None,None,None,None\n'.format(
70 | pt.time.strftime(time_format), pt.lat, pt.lng))
71 |
72 |
73 | class SaveTraj2JUST(SaveTraj):
74 | """
75 | Convert trajs to JUST format.
76 | csv file. trajectory_id, oid, time, lat, lng
77 | """
78 | def __init__(self, convert_method=None):
79 | super().__init__(convert_method)
80 |
81 | def store(self, trajs, target_path):
82 | """
83 | Convert trajs to JUST format.
84 | csv file. trajectory_id (primary key), oid, time, lat, lng
85 | Args:
86 | ----
87 | trajs:
88 | list. list of Trajectory()
89 | target_path:
90 | str. target path (directory + file_name)
91 | """
92 | with open(target_path, 'w') as f:
93 | for traj in trajs:
94 | for pt in traj.pt_list:
95 | lng, lat = self.convert.convert(pt.lng, pt.lat)
96 | f.write('{},{},{},{},{}\n'.format(traj.tid, traj.oid, pt.time, lat, lng))
97 |
98 |
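A round-trip sketch pairing the savers with the parsers above; `mm_trajs` is assumed to come from `ParseMMTraj` and the output path is illustrative:

```python
from utils.save_traj import SaveTraj2MM

saver = SaveTraj2MM()  # convert_method=None keeps coordinates unchanged
saver.store(mm_trajs, './data/mm_trajectory/09_01_out.txt')  # hypothetical path
```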
--------------------------------------------------------------------------------
/utils/segmentation.py:
--------------------------------------------------------------------------------
1 | from common.trajectory import Trajectory, get_tid
2 |
3 |
4 | class Segmentation:
5 | def __init__(self):
6 | pass
7 |
8 | def segment(self, traj):
9 | pass
10 |
11 |
12 | class TimeIntervalSegmentation(Segmentation):
13 | """
14 | Split trajectory if the time interval between two GPS points is larger than max_time_interval_min
15 | Store sub-trajectory when its length is larger than min_len
16 | """
17 | def __init__(self, max_time_interval_min, min_len=1):
18 | super(TimeIntervalSegmentation, self).__init__()
19 | self.max_time_interval = max_time_interval_min * 60
20 | self.min_len = min_len
21 |
22 | def segment(self, traj):
23 | segmented_traj_list = []
24 | pt_list = traj.pt_list
25 | if len(pt_list) <= 1:
26 | return []
27 | oid = traj.oid
28 | pre_pt = pt_list[0]
29 | partial_pt_list = [pre_pt]
30 | for cur_pt in pt_list[1:]:
31 | time_span = (cur_pt.time - pre_pt.time).total_seconds()
32 | if time_span <= self.max_time_interval:
33 | partial_pt_list.append(cur_pt)
34 | else:
35 | if len(partial_pt_list) > self.min_len:
36 | segmented_traj = Trajectory(oid, get_tid(oid, partial_pt_list), partial_pt_list)
37 | segmented_traj_list.append(segmented_traj)
38 | partial_pt_list = [cur_pt] # re-initialize partial_pt_list
39 | pre_pt = cur_pt
40 | if len(partial_pt_list) > self.min_len:
41 | segmented_traj = Trajectory(oid, get_tid(oid, partial_pt_list), partial_pt_list)
42 | segmented_traj_list.append(segmented_traj)
43 | return segmented_traj_list
44 |
45 |
46 |
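A minimal usage sketch, with an illustrative one-minute gap threshold and minimum length; `trajs` is assumed to be a list of `Trajectory()`:

```python
from utils.segmentation import TimeIntervalSegmentation

seg = TimeIntervalSegmentation(max_time_interval_min=1, min_len=10)
sub_trajs = []
for traj in trajs:
    sub_trajs.extend(seg.segment(traj))
```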
--------------------------------------------------------------------------------
/utils/stats.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/3 17:08
4 | # Reference: https://github.com/sjruan/tptk/blob/master/statistics.py
5 |
6 | from utils.utils import create_dir
7 |
8 | import matplotlib.pyplot as plt
9 | from matplotlib.ticker import PercentFormatter
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import os
14 | import json
15 |
16 |
17 | def plot_hist(data, x_axis, save_stats_dir, pic_name):
18 | plt.hist(data, weights=np.ones(len(data)) / len(data))
19 | plt.gca().yaxis.set_major_formatter(PercentFormatter(1))
20 | plt.xlabel(x_axis)
21 | plt.ylabel('Percentage')
22 | plt.savefig(os.path.join(save_stats_dir, pic_name))
23 | plt.clf()
24 |
25 |
26 | def statistics(trajs, save_stats_dir, stats, save_stats_name, save_plot=False):
27 | """
28 | Compute basic statistics of the given trajectories and optionally plot histograms.
29 | Args:
30 | -----
31 | trajs:
32 | list of Trajectory(). trajectories to analyze
33 | save_stats_dir:
34 | str. directory of saving stats results
35 | stats:
36 | dict. dictionary of stats, accumulated across calls
37 | save_stats_name:
38 | str. name of saving stats.
39 | save_plot:
40 | boolean. whether to plot the histograms
41 | """
42 |
43 | create_dir(save_stats_dir)
44 |
45 | oids = set()
46 | tot_pts = 0
47 |
48 | distance_data = [] # geographical distance
49 | duration_data = [] # time difference between end and start time of a trajectory
50 | seq_len_data = [] # length of each trajectory
51 | traj_avg_time_interval_data = []
52 | traj_avg_dist_interval_data = []
53 |
54 | if len(stats) == 0:
55 | # if new, initialize stats with keys
56 | stats['#object'], stats['#points'], stats['#trajectories'] = 0, 0, 0
57 | stats['seq_len_data'], stats['distance_data'], stats['duration_data'], \
58 | stats['traj_avg_time_interval_data'], stats['traj_avg_dist_interval_data'] = [], [], [], [], []
59 |
60 | for traj in trajs:
61 | oids.add(traj.oid)
62 | nb_pts = len(traj.pt_list)
63 | tot_pts += nb_pts
64 |
65 | seq_len_data.append(nb_pts)
66 | distance_data.append(traj.get_distance() / 1000.0)
67 | duration_data.append(traj.get_duration() / 60.0)
68 | traj_avg_time_interval_data.append(traj.get_avg_time_interval())
69 | traj_avg_dist_interval_data.append(traj.get_avg_distance_interval())
70 |
71 | print('#objects_single:{}'.format(len(oids)))
72 | print('#points_single:{}'.format(tot_pts))
73 | print('#trajectories_single:{}'.format(len(trajs)))
74 |
75 | stats['#object'] += len(oids)
76 | stats['#points'] += tot_pts
77 | stats['#trajectories'] += len(trajs)
78 | stats['seq_len_data'] += seq_len_data
79 | stats['distance_data'] += distance_data
80 | stats['duration_data'] += duration_data
81 | stats['traj_avg_time_interval_data'] += traj_avg_time_interval_data
82 | stats['traj_avg_dist_interval_data'] += traj_avg_dist_interval_data
83 |
84 | print('#objects_total:{}'.format(stats['#object']))
85 | print('#points_total:{}'.format(stats['#points']))
86 | print('#trajectories_total:{}'.format(stats['#trajectories']))
87 |
88 | with open(os.path.join(save_stats_dir, save_stats_name + '.json'), 'w') as f:
89 | json.dump(stats, f)
90 |
91 | if save_plot:
92 | plot_hist(stats['seq_len_data'], '#Points', save_stats_dir, save_stats_name + '_nb_points_dist.png')
93 | plot_hist(stats['distance_data'], 'Distance (KM)', save_stats_dir, save_stats_name + '_distance_dist.png')
94 | plot_hist(stats['duration_data'], 'Duration (Min)', save_stats_dir, save_stats_name + '_duration_dist.png')
95 | plot_hist(stats['traj_avg_time_interval_data'], 'Time Interval (Sec)', save_stats_dir,
96 | save_stats_name + '_time_interval_dist.png')
97 | plot_hist(stats['traj_avg_dist_interval_data'], 'Distance Interval (Meter)', save_stats_dir,
98 | save_stats_name + '_distance_interval_dist.png')
99 |
100 | return stats
101 |
102 |
103 | def stats_threshold(trajs):
104 | """
105 | Find threshold for trajectory preprocessing,
106 | including time interval for splitting; minimal length and maximal abnormal ratio.
107 | Args:
108 | -----
109 | trajs:
110 | list of Trajectory(). Sampled trajectories.
111 | Returns:
112 | --------
113 | thr_min_len, thr_max_abn_ratio, thr_normal_ratio, thr_split_traj_ts
114 | """
115 | stats = {'oid': [], 'tid': [], 'traj_len': [], 'num_ab_pts': [], 'abn_ts': [], 'avg_ts': [], 'avg_abn_ts': []}
116 |
117 | for i in range(len(trajs)):
118 | stats['oid'].append(trajs[i].oid)
119 | stats['tid'].append(trajs[i].tid)
120 | stats['traj_len'].append(len(trajs[i].pt_list))
121 |
122 | pt_list = trajs[i].pt_list
123 | pre_pt = pt_list[0]
124 |
125 | time_spans = []
126 | abn_ts_list = []
127 | for cur_pt in pt_list[1:]:
128 | time_span = (cur_pt.time - pre_pt.time).total_seconds()
129 | if time_span > 4:  # sampling gaps longer than 4 seconds are counted as abnormal
130 | abn_ts_list.append(time_span)
131 | time_spans.append(time_span)
132 | pre_pt = cur_pt
133 |
134 | stats['num_ab_pts'].append(len(abn_ts_list))
135 | stats['abn_ts'].append(abn_ts_list)
136 | stats['avg_ts'].append(sum(time_spans) / len(time_spans))
137 | try:
138 | # in case len(abn_ts_list) = 0
139 | stats['avg_abn_ts'].append(sum(abn_ts_list) / len(abn_ts_list))
140 | except ZeroDivisionError:
141 | stats['avg_abn_ts'].append(0)
142 |
143 | df_stats = pd.DataFrame(stats)
144 | df_stats['abn_ratio'] = df_stats.num_ab_pts / df_stats.traj_len
145 |
146 | thr_min_len = df_stats.traj_len.quantile(0.1) # if less than min length(number of points), remove
147 | thr_max_abn_ratio = df_stats.abn_ratio.quantile(0.8) # if larger than max abnormal ratio, remove
148 | # thr_normal_ratio = 0 # if abn_ratio == 0, use the trajectory directly
149 | # thr_split_traj_ts = 60 # if time_interval is larger than 60 seconds, split trajectories.
150 | # print("min len: {}, max abnormal ratio: {}, normal ratio: {}, split threshold: {} seconds".
151 | # format(thr_min_len, thr_max_abn_ratio, thr_normal_ratio, thr_split_traj_ts))
152 | print("min len: {}, max abnormal ratio: {}".format(thr_min_len, thr_max_abn_ratio))
153 |
154 | return df_stats, thr_min_len, thr_max_abn_ratio #, thr_normal_ratio, thr_split_traj_ts
155 |
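A sketch of accumulating statistics over several daily files; file names and the output directory are illustrative:

```python
from utils.parse_traj import ParseRawTraj
from utils.stats import statistics

stats = {}                                    # initialized with keys on the first call
for name in ['09_01', '09_02']:               # hypothetical file names
    trajs = ParseRawTraj().parse('./data/raw_trajectory/{}.txt'.format(name))
    stats = statistics(trajs, './stats/', stats, 'raw', save_plot=True)
```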
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 | # coding: utf-8
3 | # @Time : 2020/9/2 19:58
4 |
5 |
6 | import pickle
7 | import json
8 | import os
11 |
12 |
13 | def create_dir(directory):
14 | """
15 | Creates a directory if it does not already exist.
16 | """
17 | if not os.path.exists(directory):
18 | os.makedirs(directory)
19 |
20 |
21 | def save_pkl_data(data, dir, file_name):
22 | create_dir(dir)
23 | with open(dir + file_name, 'wb') as f: pickle.dump(data, f)  # ensure the file handle is closed
24 |
25 |
26 | def load_pkl_data(dir, file_name):
27 | '''
28 | Args:
29 | -----
30 | dir: directory of the file
31 | file_name: file name
32 | Returns:
33 | --------
34 | data: loaded data
35 | '''
36 | with open(dir + file_name, 'rb') as file:
37 | data = pickle.load(file)
39 | return data
40 |
41 |
42 | def save_json_data(data, dir, file_name):
43 | create_dir(dir)
44 | with open(dir+file_name, 'w') as fp:
45 | json.dump(data, fp)
46 |
47 |
48 | def load_json_data(dir, file_name):
49 | with open(dir+file_name, 'r') as fp:
50 | data = json.load(fp)
51 | return data
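A quick usage sketch; note that `dir` and `file_name` are concatenated directly, so the directory must end with a slash:

```python
from utils.utils import save_json_data, load_json_data

save_json_data({'a': 1}, './tmp/', 'demo.json')
print(load_json_data('./tmp/', 'demo.json'))  # {'a': 1}
```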
--------------------------------------------------------------------------------