├── LICENSE
├── README.md
├── data
│   └── NYC_cat.pkl
├── data_prepare.py
├── data_preprocess.py
├── evaluate.py
├── model.py
├── train_test.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Urban Mobility
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CSLSL
2 |
3 |
4 | # Performance
5 | The latest experimental results are as follows:
6 |
7 |
8 | | Dataset | Category R@1 | Category R@5 | Category R@10 | Location R@1 | Location R@5 | Location R@10 |
9 | | :-- | :-: | :-: | :-: | :-: | :-: | :-: |
10 | | NYC | 0.327 | 0.661 | 0.759 | 0.268 | 0.568 | 0.656 |
11 | | TKY | 0.448 | 0.801 | 0.875 | 0.240 | 0.488 | 0.580 |
12 | | Dallas | - | - | - | 0.126 | 0.243 | 0.297 |
52 |
53 | # Datasets
54 | - The processed data can be found in the "data" folder; it was produced by ```data_preprocess.py``` and ```data_prepare.py```.
55 | - The raw data is available from the following open sources:
56 |   - [Foursquare (NYC and TKY)](https://sites.google.com/site/yangdingqi/home/foursquare-dataset?authuser=0)
57 | - [Gowalla (Dallas)](https://snap.stanford.edu/data/loc-gowalla.html)
58 |
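The two scripts can also be driven from Python. The sketch below is a minimal, assumed workflow (not part of the repository) that uses only functions defined in ```data_preprocess.py``` and ```data_prepare.py``` with their argparse defaults; note that ```data_prepare.py``` reads ```<path_in><data_name>.txt``` (default ```../data/NYC.txt```), so the merged output of the preprocessing step is assumed to be placed or renamed accordingly.

```python
# Hedged sketch of the preparation pipeline, relying on the argparse defaults.
from data_preprocess import settings, preprocessing
from data_prepare import parse_args, DataGeneration

prep_params = settings()        # defaults: ./data/NYC.txt, min 10 records per user/location
preprocessing(prep_params)      # writes NYC_filtered.txt and NYC_merged.txt

gen_params = parse_args()       # defaults: path_in='../data/', data_name='NYC', cat_contained=True
generator = DataGeneration(gen_params)
generator.load_trajectory()
generator.filter_and_divide_sessions()
generator.generate_history_sessions('train')
generator.generate_history_sessions('test')
generator.save_variables()      # writes <data_name>_cat.pkl to path_out
```
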
59 | # Requirements
60 | - Python>=3.8
61 | - PyTorch>=1.8.1
62 | - Numpy
63 | - Pandas
64 |
65 |
66 | # Project Structure
67 | - ```/data```: folder for the processed data
68 | - ```/results```: folder for outputs such as the trained model and metrics
69 | - ```data_preprocess.py```: data preprocessing; filters sparse users and locations (fewer than 10 records) and merges consecutive records (same user and location on the same day)
70 | - ```data_prepare.py```: data preparation for CSLSL (splits trajectories into sessions and generates model input)
71 | - ```train_test.py```: the entry point to train and test a new model
72 | - ```evaluate.py```: the entry point to evaluate a pretrained model
73 | - ```model.py```: model definition
74 | - ```utils.py```: tools such as batch generation and metric calculation
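
The prepared pickle written by ```data_prepare.py``` (e.g. ```data/NYC_cat.pkl```) is a plain Python dictionary. A small sketch for inspecting it (illustrative only; the keys listed mirror ```save_variables``` in ```data_prepare.py```):

```python
# Hedged example: inspect the prepared dataset produced by data_prepare.py.
import pickle

with open('./data/NYC_cat.pkl', 'rb') as f:
    dataset = pickle.load(f)

# Keys stored by save_variables(): train_data, train_id, test_data, test_id,
# pid_dict, uid_list, pid_lat_lon, pid_lat_lon_radians, parameters,
# and (with categories) cid_dict, cid_count_dict, cid_cat_dict, pid_cid_dict.
print(sorted(dataset.keys()))
print(f"users={len(dataset['uid_list'])}, locations={len(dataset['pid_dict'])}")

# Each sample dataset['train_data'][uid][sid] holds 'loc', 'tim' (and 'cat') as
# [history_sequence, current_sequence], plus targets 'target_l', 'target_th' (and 'target_c').
uid = dataset['uid_list'][0]
sid = dataset['train_id'][uid][0]
print(dataset['train_data'][uid][sid].keys())
```
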
75 |
76 |
77 |
78 |
79 | # Usage
80 | 1. Train and test a new model
81 | > ```bash
82 | > python train_test.py --data_name NYC
83 | > ```
84 |
85 | 2. Evaluate a pretrained model
86 | > ```bash
87 | > python evaluate.py --data_name NYC --model_name model_NYC
88 | > ```
89 |
90 | For detailed parameter descriptions, see ```train_test.py``` and ```evaluate.py```.
91 |
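Both entry scripts expose ```settings(param=[])```, which accepts the command-line flags as a list, so they can also be called from Python. A minimal sketch (assuming a CUDA device, as the scripts do, and a pretrained model under ```./results/pretrained/```, the default in ```evaluate.py```):

```python
# Hedged example: run evaluation programmatically instead of via the CLI.
import pickle
import torch
from evaluate import settings, evaluate

params = settings(['--data_name', 'NYC', '--model_name', 'model_NYC'])
params.device = torch.device(f'cuda:{params.gpu}')   # the scripts assume a GPU is available

with open(f'{params.path_in}{params.data_name}_cat.pkl', 'rb') as f:
    dataset = pickle.load(f)

recall_l, recall_c, mae_t = evaluate(params, dataset)
print('Location R@1/5/10:', recall_l, 'Category R@1/5/10:', recall_c, 'Time MAE:', mae_t)
```
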
--------------------------------------------------------------------------------
/data/NYC_cat.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/urbanmobility/CSLSL/d59c882b1feb329d0d3f72067bb111e72a57cbe3/data/NYC_cat.pkl
--------------------------------------------------------------------------------
/data_prepare.py:
--------------------------------------------------------------------------------
1 | import time
2 | import datetime
3 | import argparse
4 | import numpy as np
5 | import pickle
6 | from math import pi
7 | from collections import Counter
8 |
9 |
10 |
11 | class DataGeneration(object):
12 | def __init__(self, params):
13 |
14 | self.__dict__.update(params.__dict__)
15 |
16 | self.raw_data = {} # raw user's trajectory. {uid: [[pid, tim], ...]}
17 | self.poi_count = {} # raw location counts. {pid: count}
18 | self.data_filtered = {}
19 | self.uid_list = [] # filtered user id
20 | self.pid_dict = {} # filtered location id map
21 | self.train_data = {} # train data with history, {'uid': {'sid': {'loc': [], 'tim': [], 'target': [] (, 'cat': [])}}}
22 | self.train_id = {} # train data session id list
23 | self.test_data = {}
24 | self.test_id = {}
25 | self.tim_w = set()
26 | self.tim_h = set()
27 |
28 | self.raw_lat_lon = {} # count for latitude and longitude
29 | self.new_lat_lon = {}
30 | self.lat_lon_radians = {}
31 |
32 | if self.cat_contained:
33 | self.cid_dict = {} # cid_dict
34 | self.cid_count_dict = {} # cid count
35 | self.raw_cat_dict = {} # cid-cat
36 | self.new_cat_dict = {}
37 | self.pid_cid_dict = {} # pid-cid dict
38 |
39 | # 1. read trajectory data
40 | def load_trajectory(self):
41 | with open(self.path_in + self.data_name + '.txt', 'r') as fid:
42 | for i, line in enumerate(fid):
43 | if self.data_name in ['NYC', 'TKY']:
44 | uid, pid, cid, cat, lat, lon, _, tim = line.strip().split('\t')
45 | elif self.data_name == 'Dallas':
46 | uid, tim, pid, lon, lat = line.strip().split('\t')
47 | else:
48 | uid, tim, lat, lon, pid = line.strip().split('\t')
49 |
50 | # Note: user and location is id
51 | if self.cat_contained:
52 | # count uid records
53 | if uid not in self.raw_data:
54 | self.raw_data[uid] = [[pid, tim, cid]]
55 | else:
56 | self.raw_data[uid].append([pid, tim, cid])
57 | # count raw_cid-cat
58 | if cid not in self.raw_cat_dict:
59 | self.raw_cat_dict[cid] = cat
60 | else:
61 | if uid not in self.raw_data:
62 | self.raw_data[uid] = [[pid, tim]]
63 | else:
64 | self.raw_data[uid].append([pid, tim])
65 | if pid not in self.poi_count:
66 | self.poi_count[pid] = 1
67 | else:
68 | self.poi_count[pid] += 1
69 |
70 | # count poi latitude and longitude
71 | if pid not in self.raw_lat_lon:
72 | self.raw_lat_lon[pid] = [eval(lat), eval(lon)]
73 |
74 | # 2. filter users and locations, and then split trajectory into sessions
75 | def filter_and_divide_sessions(self):
76 | POI_MIN_RECORD_FOR_USER = 1 # keep same setting with DeepMove and LSTPM
77 |
78 | # filter user and location
79 | uid_list = [x for x in self.raw_data if len(self.raw_data[x]) > self.user_record_min] # uid list
80 | pid_list = [x for x in self.poi_count if self.poi_count[x] > self.poi_record_min] # pid list
81 |
82 | # iterate each user
83 | for uid in uid_list:
84 | user_records = self.raw_data[uid] # user_records is [[pid, tim (, cid)]]
85 | topk = Counter([x[0] for x in user_records]).most_common() # most common poi, [(poi, count), ...]
86 |             topk_filter = [x[0] for x in topk if x[1] > POI_MIN_RECORD_FOR_USER] # POIs the user visits more than once
87 | sessions = {} # sessions is {'sid' : [[pid, [week, hour] (, cid)], ...]}
88 | # iterate each record
89 | for i, record in enumerate(user_records):
90 | if self.cat_contained:
91 | poi, tim, cid = record
92 | else:
93 | poi, tim = record
94 | try:
95 | # time processing
96 |                     time_struct = time.strptime(tim, "%Y-%m-%dT%H:%M:%SZ")
97 |                     calendar_date = datetime.date(time_struct.tm_year, time_struct.tm_mon, time_struct.tm_mday).isocalendar()
98 |                     current_week = f'{calendar_date[0]}-{calendar_date[1]}'  # index access keeps Python 3.8 compatibility (plain tuple)
99 |                     # Encode time: weekday (1~7) and 2-hour slot (1~12), e.g. Tuesday 15:37 -> [2, 8]
100 |                     # tim_code = [time_struct.tm_wday+1, time_struct.tm_hour*2+int(time_struct.tm_min/30)+1 ] # week(1~7), hours(1~48)
101 |                     tim_code = [time_struct.tm_wday+1, int(time_struct.tm_hour/2)+1 ] # week(1~7), hours(1~12)
102 |                     # revise record
103 |                     record[1] = tim_code
104 | except Exception as e:
105 | print('error:{}'.format(e))
106 |                     raise  # re-raise to preserve the original traceback
107 |
108 | # divide session. Rule is: same week
109 | sid = len(sessions) # session id
110 | if poi not in pid_list and poi not in topk_filter:
111 |                     # skip the poi if it is neither globally frequent (pid_list) nor frequent for this user (topk_filter)
112 | continue
113 | if i == 0 or len(sessions) == 0:
114 | sessions[sid] = [record]
115 | else:
116 | if last_week != current_week: # new session
117 | sessions[sid] = [record]
118 | else:
119 | sessions[sid - 1].append(record) # Note: data is already merged
120 | last_week = current_week
121 | sessions_filtered = {}
122 | # filter session with session_min
123 | for s in sessions:
124 | if len(sessions[s]) >= self.session_min:
125 | sessions_filtered[len(sessions_filtered)] = sessions[s]
126 | # filter user with sessions_min, that is, the user must have sessions_min's sessions.
127 | if len(sessions_filtered) < self.sessions_min:
128 | continue
129 |
130 | # ReEncode location index (may encode category)
131 | for sid in sessions_filtered: # sessions is {'sid' : [[pid, [week, hour]], ...]}
132 | for idx, record in enumerate(sessions_filtered[sid]):
133 | # reEncode location
134 | if record[0] not in self.pid_dict:
135 | self.pid_dict[record[0]] = len(self.pid_dict) + 1 # the id start from 1
136 | new_pid = self.pid_dict[record[0]]
137 | # new pid for latitude and longitude
138 | self.new_lat_lon[new_pid] = self.raw_lat_lon[record[0]]
139 | self.lat_lon_radians[new_pid] = list(np.array(self.raw_lat_lon[record[0]]) * pi / 180)
140 | # assign new pid
141 | record[0] = new_pid
142 | # time
143 | self.tim_w.add(record[1][0])
144 | self.tim_h.add(record[1][1])
145 | # category
146 | if self.cat_contained:
147 | # encode cid
148 | if record[2] not in self.cid_dict:
149 | new_cid = len(self.cid_dict) + 1 # the id start from 1
150 | self.cid_dict[record[2]] = new_cid
151 | self.new_cat_dict[new_cid] = self.raw_cat_dict[record[2]] # raw_cid-cat to new_cid-cat
152 |                             self.cid_count_dict[new_cid] = 0  # incremented below, so the first occurrence counts once
153 | # assign cid
154 | record[2] = self.cid_dict[record[2]]
155 | # count cid
156 | self.cid_count_dict[record[2]] += 1
157 | # pid-cid dict
158 | if new_pid not in self.pid_cid_dict:
159 | self.pid_cid_dict[new_pid] = record[2]
160 | # reassign record
161 | sessions_filtered[sid][idx] = record
162 |
163 | # divide train and test
164 | sessions_id = list(sessions_filtered.keys())
165 | split_id = int(np.floor(self.train_split * len(sessions_id)))
166 | train_id = sessions_id[:split_id]
167 | test_id = sessions_id[split_id:]
168 | assert len(train_id) > 0, 'train sessions have error'
169 | assert len(test_id) > 0, 'test sessions have error'
170 |         # prepare final data (re-encode user index); the id starts from 1
171 | self.data_filtered[len(self.data_filtered)+1] = {'sessions_count': len(sessions_filtered), 'sessions': sessions_filtered, 'train': train_id, 'test': test_id}
172 |
173 |
174 | # final uid list
175 | self.uid_list = list(self.data_filtered.keys())
176 |         print(f'Final users: {len(self.uid_list)}, locations: {len(self.pid_dict)}')
177 |
178 |
179 | # 3. generate data with history sessions
180 | def generate_history_sessions(self, mode):
181 | print('='*4, f'generate history sessions in mode={mode}')
182 | data_input = self.data_filtered
183 | data = {}
184 | data_id = {}
185 | user_list = data_input.keys()
186 |
187 | for uid in user_list:
188 | data[uid] = {}
189 | user_sid_list = data_input[uid][mode]
190 | data_id[uid] = user_sid_list.copy()
191 | for idx, sid in enumerate(user_sid_list):
192 |                 # in training, require at least history_session_min sessions as history before generating samples
193 | if mode == 'train' and idx < self.history_session_min:
194 |                     data_id[uid].remove(sid)  # remove by value; safe even if history_session_min > 1
195 | continue
196 | data[uid][sid] = {}
197 | loc_seq_cur = []
198 | loc_seq_his = []
199 | tim_seq_cur = []
200 | tim_seq_his = []
201 |
202 | if self.cat_contained:
203 | cat_seq_cur = []
204 | cat_seq_his = []
205 |
206 | if mode == 'test': # in test mode, append all train data as history
207 | train_sid = data_input[uid]['train']
208 | for tmp_sid in train_sid:
209 | loc_seq_his.extend([record[0] for record in data_input[uid]['sessions'][tmp_sid]])
210 | tim_seq_his.extend([record[1] for record in data_input[uid]['sessions'][tmp_sid]])
211 | if self.cat_contained:
212 | cat_seq_his.extend([record[2] for record in data_input[uid]['sessions'][tmp_sid]])
213 | # append past sessions
214 | for past_idx in range(idx):
215 | tmp_sid = user_sid_list[past_idx]
216 | loc_seq_his.extend([record[0] for record in data_input[uid]['sessions'][tmp_sid]])
217 | tim_seq_his.extend([record[1] for record in data_input[uid]['sessions'][tmp_sid]])
218 | if self.cat_contained:
219 | cat_seq_his.extend([record[2] for record in data_input[uid]['sessions'][tmp_sid]])
220 | # current session
221 | loc_seq_cur.extend([record[0] for record in data_input[uid]['sessions'][sid][:-1]]) # [[pid1], [pid2], ...]
222 | tim_seq_cur.extend([record[1] for record in data_input[uid]['sessions'][sid][:-1]])
223 | if self.cat_contained:
224 | cat_seq_cur.extend([record[2] for record in data_input[uid]['sessions'][sid][:-1]])
225 |
226 | # store sequence
227 | data[uid][sid]['target_l'] = [record[0] for record in data_input[uid]['sessions'][sid][1:]]
228 | data[uid][sid]['target_th'] = [record[1][1] for record in data_input[uid]['sessions'][sid][1:]]
229 | data[uid][sid]['loc'] = [loc_seq_his, loc_seq_cur] # list
230 | data[uid][sid]['tim'] = [tim_seq_his, tim_seq_cur] # list
231 | if self.cat_contained:
232 | data[uid][sid]['cat'] = [cat_seq_his, cat_seq_cur] # list
233 | data[uid][sid]['target_c'] = [record[2] for record in data_input[uid]['sessions'][sid][1:]]
234 |
235 | # train/test_data is {'uid': {'sid': {'loc': [pid_seq], 'tim': [[week, hour], ...] (, 'cat': [cid_seq])}}}
236 | #
237 | if mode == 'train':
238 | self.train_data = data
239 | self.train_id = data_id
240 | elif mode == 'test':
241 | self.test_data = data
242 | self.test_id = data_id
243 | print('Finish')
244 |
245 |
246 | # 4. save variables
247 | def save_variables(self):
248 | dataset = {'train_data': self.train_data, 'train_id': self.train_id,
249 | 'test_data': self.test_data, 'test_id': self.test_id,
250 | 'pid_dict': self.pid_dict, 'uid_list': self.uid_list,
251 | 'pid_lat_lon': self.new_lat_lon,
252 | 'pid_lat_lon_radians' : self.lat_lon_radians,
253 | 'parameters': self.get_parameters()}
254 | if self.cat_contained:
255 | dataset['cid_dict'] = self.cid_dict
256 | dataset['cid_count_dict'] = self.cid_count_dict
257 | dataset['cid_cat_dict'] = self.new_cat_dict
258 | dataset['pid_cid_dict'] = self.pid_cid_dict
259 | pickle.dump(dataset, open(self.path_out + self.data_name + '_cat.pkl', 'wb'))
260 | else:
261 | pickle.dump(dataset, open(self.path_out + self.data_name + '.pkl', 'wb'))
262 | def get_parameters(self):
263 | parameters = self.__dict__.copy()
264 | del parameters['raw_data']
265 | del parameters['poi_count']
266 | del parameters['data_filtered']
267 | del parameters['uid_list']
268 | del parameters['pid_dict']
269 | del parameters['train_data']
270 | del parameters['train_id']
271 | del parameters['test_data']
272 | del parameters['test_id']
273 | if self.cat_contained:
274 | del parameters['cid_dict']
275 |
276 | return parameters
277 |
278 |
279 | def parse_args():
280 | parser = argparse.ArgumentParser()
281 | parser.add_argument('--path_in', type=str, default='../data/', help="input data path")
282 | parser.add_argument('--path_out', type=str, default='./data/', help="output data path")
283 | parser.add_argument('--data_name', type=str, default='NYC', help="data name")
284 | parser.add_argument('--user_record_min', type=int, default=10, help="user record length filter threshold")
285 | parser.add_argument('--poi_record_min', type=int, default=10, help="location record length filter threshold")
286 | parser.add_argument('--session_min', type=int, default=2, help="control the length of session not too short")
287 | parser.add_argument('--sessions_min', type=int, default=5, help="the minimum amount of the user's sessions")
288 | parser.add_argument('--train_split', type=float, default=0.8, help="train/test ratio")
289 |     parser.add_argument('--cat_contained', action='store_false', default=True, help="include category information (passing this flag disables it)")
290 |     parser.add_argument('--history_session_min', type=int, default=1, help="minimum number of history sessions")
291 |
292 | if __name__ == '__main__':
293 | return parser.parse_args()
294 | else:
295 | return parser.parse_args([])
296 |
297 |
298 | if __name__ == '__main__':
299 |
300 | start_time = time.time()
301 |
302 | params = parse_args()
303 | data_generator = DataGeneration(params)
304 | parameters = data_generator.get_parameters()
305 | print('='*20 + ' Parameter settings')
306 | print(', '.join([p + '=' + str(parameters[p]) for p in parameters]))
307 | print('='*20 + ' Start processing')
308 | print('==== Load trajectory from {}'.format(data_generator.path_in))
309 | data_generator.load_trajectory()
310 |
311 | print('==== filter users')
312 | data_generator.filter_and_divide_sessions()
313 |
314 |
315 | print('==== generate history sessions')
316 | data_generator.generate_history_sessions('train')
317 | data_generator.generate_history_sessions('test')
318 |
319 | print('==== save prepared data')
320 | data_generator.save_variables()
321 |
322 |     print('==== Preparation Finished')
323 | print('Raw users:{} raw locations:{}'.format(
324 | len(data_generator.raw_data), len(data_generator.poi_count)))
325 | print(f'Final users:{len(data_generator.uid_list)}, min_id:{np.min(data_generator.uid_list)}, max_id:{np.max(data_generator.uid_list)}')
326 | pid_list = list(data_generator.pid_dict.values())
327 | print(f'Final locations:{len(pid_list)}, min_id:{np.min(pid_list)}, max_id:{np.max(pid_list)}')
328 | print(f'Final time-week:{len(data_generator.tim_w)}, min_id:{np.min(list(data_generator.tim_w))}, max_id:{np.max(list(data_generator.tim_w))}')
329 | print(f'Final time-hour:{len(data_generator.tim_h)}, min_id:{np.min(list(data_generator.tim_h))}, max_id:{np.max(list(data_generator.tim_h))}')
330 | if params.cat_contained:
331 | cid_list = list(data_generator.cid_dict.values())
332 | print(f'Final categories:{len(cid_list)}, min_id:{np.min(cid_list)}, max_id:{np.max(cid_list)}')
333 | print(f'Time cost is {time.time()-start_time:.0f}s')
334 |
335 |
--------------------------------------------------------------------------------
/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import time
4 | import argparse
5 |
6 |
7 |
8 | def settings():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--path', default='./data/', type=str,
11 | help='data path')
12 | parser.add_argument('--dataname', default='NYC', type=str,
13 | help='data name')
14 | parser.add_argument('--filetype', default='txt', type=str,
15 | help='file type')
16 | parser.add_argument('--user_po', default=0, type=int,
17 | help='Position in record for user')
18 | parser.add_argument('--loc_po', default=1, type=int,
19 | help='Position in record for location')
20 | parser.add_argument('--tim_po', default=2, type=int,
21 | help='Position in record for time')
22 |
23 | parser.add_argument('--user_record_min', default=10, type=int,
24 |                         help='Minimum record number for user')
25 |     parser.add_argument('--loc_record_min', default=10, type=int,
26 |                         help='Minimum record number for location')
27 | return parser.parse_args()
28 |
29 |
30 | def preprocessing(params):
31 | '''Preprocessing data
32 | Note:
33 | 1. Raw data is sorted by user(1) and time(2)
34 |         2. Filter sparse data with minimum record numbers.
35 | 3. Encode user and location id.
36 | 4. Merge data with same user and location in a day.
37 |
38 | '''
39 |
40 | # Loading and Filtering sparse data
41 | print('='*20, 'Loading and preprocessing sparse data')
42 | filepath = f'{params.path}{params.dataname}.{params.filetype}'
43 | print(f'Path is {filepath}')
44 | loc_count = {} # store location info with loc-num
45 | user_count = {} # store user info with user-num
46 | user_id = {}
47 | loc_id = {}
48 |
49 | # load file and count numbers
50 | print('='*20, 'Loading and Counting')
51 | if params.filetype == 'txt':
52 | with open(filepath, 'r') as f:
53 | reader = csv.reader(f, delimiter='\t')
54 | for record in reader:
55 | if '' in record:
56 | continue
57 | user = record[params.user_po]
58 | loc = record[params.loc_po]
59 | if user not in user_count:
60 | user_count[user] = 1
61 | else:
62 | user_count[user] += 1
63 | if loc not in loc_count:
64 | loc_count[loc] = 1
65 | else:
66 | loc_count[loc] += 1
67 |
68 | record_num = os.popen(f'wc -l {filepath}').readlines()[0].split()[0]
69 | print(f'Finished, records is {record_num}, all user is {len(user_count)}, all location is {len(loc_count)}')
70 |
71 |
72 | # Filter and encode user and location
73 | print('='*20, 'Filtering and encoding')
74 | for i in user_count:
75 | if user_count[i] > params.user_record_min:
76 | user_id[i] = len(user_id)
77 | for i in loc_count:
78 | if loc_count[i] > params.loc_record_min:
79 | loc_id[i] = len(loc_id)
80 |
81 | # store
82 | filter_path = f'{params.path}{params.dataname}_filtered.txt'
83 | print(f'Filter path is {filter_path}')
84 | with open(filter_path, 'w') as f_out:
85 | writer = csv.writer(f_out, delimiter='\t')
86 | with open(filepath, 'r') as f_in:
87 | reader = csv.reader(f_in, delimiter='\t')
88 | for record in reader:
89 | if '' in record:
90 | continue
91 | user = record[params.user_po]
92 | loc = record[params.loc_po]
93 | if user in user_id and loc in loc_id:
94 | record[params.user_po] = user_id[user]
95 | record[params.loc_po] = loc_id[loc]
96 | writer.writerow(record)
97 |
98 | record_num = os.popen(f'wc -l {filter_path}').readlines()[0].split()[0]
99 | print(f'Finished, records is {record_num}, user is {len(user_id)}, location is {len(loc_id)}')
100 |
101 |
102 | # Merge data
103 | print('='*20, 'Merging')
104 | merge_path = f'{params.path}{params.dataname}_merged.txt'
105 |     print(f'Merge path is {merge_path}')
106 | with open(merge_path, 'w') as f_out:
107 | writer = csv.writer(f_out, delimiter='\t')
108 |         # previous written record; None until the first record is written
109 |         pre_record = None
110 |         # all records
111 |         with open(filter_path, 'r') as f_in:
112 |             reader = csv.reader(f_in, delimiter='\t')
113 |             for record in reader:
114 |                 # same person, same location, same day -> already merged, skip
115 |                 if pre_record is not None and \
116 |                    record[params.user_po] == pre_record[params.user_po] and \
117 |                    record[params.loc_po] == pre_record[params.loc_po] and \
118 |                    record[params.tim_po].split('T')[0] == pre_record[params.tim_po].split('T')[0]:
119 |                     continue
120 |                 writer.writerow(record)
121 |                 pre_record = record
122 |
123 | record_num = os.popen(f'wc -l {merge_path}').readlines()[0].split()[0]
124 | print(f'Finished, records is {record_num}')
125 |
126 |
127 |
128 | if __name__ == '__main__':
129 |
130 | start_time = time.time()
131 | params = settings()
132 | preprocessing(params)
133 | print('Time cost:', f'{time.time()-start_time:.0f}')
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import pickle
5 | import numpy as np
6 | import pandas as pd
7 |
8 | import torch
9 |
10 | from utils import *
11 | from model import *
12 |
13 |
14 | def settings(param=[]):
15 | parser = argparse.ArgumentParser()
16 | # data params
17 | parser.add_argument('--path_in', type=str, default='./data/', help="input data path")
18 | parser.add_argument('--path_out', type=str, default='./results/', help="output data path")
19 | parser.add_argument('--data_name', type=str, default='NYC', help="data name")
20 |     parser.add_argument('--cat_contained', action='store_false', default=True, help="include category information (passing this flag disables it)")
21 | parser.add_argument('--out_filename', type=str, default='', help="output data filename")
22 | # train params
23 | parser.add_argument('--gpu', type=str, default='0', help="GPU index to choose")
24 | parser.add_argument('--run_num', type=int, default=10, help="run number")
25 | parser.add_argument('--model_path', type=str, default='./results/pretrained/', help="model path")
26 | parser.add_argument('--model_name', type=str, default='model_NYC', help="model name to load")
27 | parser.add_argument('--batch_size', type=int, default=64, help="batch size")
28 |
29 | # model params
30 | # embedding
31 | parser.add_argument('--user_embed_dim', type=int, default=20, help="user embedding dimension")
32 | parser.add_argument('--loc_embed_dim', type=int, default=200, help="loc embedding dimension")
33 | parser.add_argument('--tim_h_embed_dim', type=int, default=20, help="time hour embedding dimension")
34 | parser.add_argument('--tim_w_embed_dim', type=int, default=10, help="time week embedding dimension")
35 | parser.add_argument('--cat_embed_dim', type=int, default=100, help="category embedding dimension")
36 | # rnn
37 | parser.add_argument('--rnn_type', type=str, default='gru', help="rnn type")
38 | parser.add_argument('--rnn_layer_num', type=int, default=1, help="rnn layer number")
39 | parser.add_argument('--rnn_t_hid_dim', type=int, default=600, help="rnn hidden dimension for t")
40 | parser.add_argument('--rnn_c_hid_dim', type=int, default=600, help="rnn hidden dimension for c")
41 | parser.add_argument('--rnn_l_hid_dim', type=int, default=600, help="rnn hidden dimension for l")
42 | parser.add_argument('--dropout', type=float, default=0.1, help="drop out for rnn")
43 |
44 |
45 | if __name__ == '__main__' and param == []:
46 | params = parser.parse_args()
47 | else:
48 | params = parser.parse_args(param)
49 |
50 | if not os.path.exists(params.path_out):
51 | os.mkdir(params.path_out)
52 |
53 | return params
54 |
55 |
56 | def evaluate(params, dataset):
57 | '''Evaluate model performance
58 | '''
59 |
60 | # dataset info
61 | params.uid_size = len(dataset['uid_list'])
62 | params.pid_size = len(dataset['pid_dict'])
63 | params.cid_size = len(dataset['cid_dict']) if params.cat_contained else 0
64 | # generate input data
65 | data_test, test_id = dataset['test_data'], dataset['test_id']
66 | pid_lat_lon_radians = torch.tensor([[0, 0]] + list(dataset['pid_lat_lon_radians'].values())).to(params.device)
67 |
68 | # load model
69 | model = Model(params).to(params.device)
70 | model.load_state_dict(torch.load(params.model_path+params.model_name+'.pkl'))
71 | model.eval()
72 |
73 |
74 | l_acc_all = np.zeros(3)
75 | c_acc_all = np.zeros(3)
76 | t_mse_all = 0.
77 | valid_num_all = 0
78 | model.eval()
79 | # evaluate with batch
80 | for mask_batch, target_batch, data_batch in generate_batch_data(data_test, test_id, params.device, params.batch_size, params.cat_contained):
81 | # model forward
82 | th_pred, c_pred, l_pred, valid_num = model(data_batch, mask_batch)
83 |
84 | # calculate metrics
85 | l_acc = calculate_recall(target_batch[0], l_pred)
86 | l_acc_all += l_acc
87 | t_mse_all += torch.nn.functional.l1_loss(th_pred.squeeze(-1), target_batch[1].squeeze(-1), reduction='sum').item()
88 | valid_num_all += valid_num
89 |
90 | if params.cat_contained:
91 | c_acc = calculate_recall(target_batch[2], c_pred)
92 | c_acc_all += c_acc
93 |
94 | return l_acc_all / valid_num_all, c_acc_all / valid_num_all, t_mse_all / valid_num_all
95 |
96 |
97 | if __name__ == '__main__':
98 |
99 | print('='*20, ' Program Start')
100 | params = settings()
101 | params.device = torch.device(f"cuda:{params.gpu}")
102 | print('Parameter is\n', params.__dict__)
103 |
104 | # file name to store
105 | FILE_NAME = [params.path_out, f'{time.strftime("%Y%m%d")}_{params.data_name}_']
106 | FILE_NAME[1] += f'{params.out_filename}'
107 |
108 | # Load data
109 | print('='*20, ' Loading data')
110 | start_time = time.time()
111 | if params.cat_contained:
112 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}_cat.pkl', 'rb'))
113 | else:
114 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}.pkl', 'rb'))
115 | print(f'Finished, time cost is {time.time()-start_time:.1f}')
116 |
117 | # metrics
118 | metrics = pd.DataFrame()
119 |
120 | # start running
121 | print('='*20, "Start Evaluating")
122 | for i in range(params.run_num):
123 | print('='*20, f'Run {i}')
124 |
125 | # To Revise
126 | results, results_c, results_t = evaluate(params, dataset)
127 | metric_dict = {'Rec-l@1': results[0], 'Rec-l@5': results[1], 'Rec-l@10': results[2],
128 | 'MAE': results_t, 'Rec-c@1': results_c[0], 'Rec-c@5': results_c[1], 'Rec-c@10': results_c[2]}
129 | metric_tmp = pd.DataFrame(metric_dict, index=[i])
130 | metrics = pd.concat([metrics, metric_tmp])
131 |
132 |
133 | print('='*20, "Finished")
134 | mean = pd.DataFrame(metrics.mean()).T
135 | mean.index = ['mean']
136 | std = pd.DataFrame(metrics.std()).T
137 | std.index = ['std']
138 | metrics = pd.concat([metrics, mean, std])
139 | print(metrics)
140 |
141 | # save
142 | metrics.to_csv(f'{FILE_NAME[0]}metrics_{FILE_NAME[1]}.csv')
143 | print('='*20, f'\nMetrics saved. File name is {FILE_NAME[0]}metrics_{FILE_NAME[1]}.csv')
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
5 | import numpy as np
6 | from utils import generate_mask
7 |
8 | NUM_TASK = 3
9 | EARTHRADIUS = 6371.0
10 |
11 | class Model(nn.Module):
12 | def __init__(self, params):
13 | super().__init__()
14 | self.__dict__.update(params.__dict__)
15 | self.th_size = 12 + 1 # +1 because of the padding value 0,
16 | self.tw_size = 7 + 1
17 | self.pid_size += 1
18 | self.uid_size += 1
19 | self.cid_size += 1
20 | # Dim
21 | RNN_input_dim = self.user_embed_dim + self.loc_embed_dim + self.tim_w_embed_dim + self.tim_h_embed_dim
22 | if self.cat_contained:
23 | RNN_input_dim += self.cat_embed_dim
24 |
25 |         # Embedding (all ids start from 1)
26 | self.user_embedder = nn.Embedding(self.uid_size, self.user_embed_dim) # without padding
27 | self.loc_embedder = nn.Embedding(self.pid_size, self.loc_embed_dim, padding_idx=0)
28 | self.tim_embedder_week = nn.Embedding(self.tw_size, self.tim_w_embed_dim, padding_idx=0)
29 | self.tim_embedder_hour = nn.Embedding(self.th_size, self.tim_h_embed_dim, padding_idx=0)
31 |
32 | # Capturer
33 |         # Version: Separate
34 | self.capturer_t = SessionCapturer(RNN_input_dim, self.rnn_t_hid_dim, params)
35 | self.capturer_l = SessionCapturer(RNN_input_dim, self.rnn_l_hid_dim, params)
36 |
37 | if self.cat_contained:
38 | self.cat_embedder = nn.Embedding(self.cid_size, self.cat_embed_dim, padding_idx=0)
39 | self.capturer_c = SessionCapturer(RNN_input_dim, self.rnn_c_hid_dim, params)
40 |
41 | # CMTL
42 | if self.cat_contained:
43 | self.fc_c = nn.Linear(self.rnn_c_hid_dim, self.cid_size)
44 | self.label_trans_c = nn.Linear(self.cid_size, self.cat_embed_dim)
45 | self.fc_t = nn.Linear(self.rnn_t_hid_dim + self.cat_embed_dim, 1)
46 | self.label_trans_t = nn.Linear(1, self.tim_h_embed_dim)
47 | self.fc_l = nn.Linear(self.rnn_l_hid_dim + self.tim_h_embed_dim, self.pid_size)
48 | else:
49 | self.fc_t = nn.Linear(self.rnn_t_hid_dim, 1)
50 | self.label_trans_t = nn.Linear(1, self.tim_h_embed_dim)
51 | self.fc_l = nn.Linear(self.rnn_l_hid_dim + self.tim_h_embed_dim, self.pid_size)
52 |
53 | # Loss
54 | self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum', ignore_index=0)
55 | self.mae_loss = nn.L1Loss(reduction='sum')
56 |
57 | def forward(self, data, mask_batch):
58 |
59 | # Capture Short-term representation
60 | # 1) get data and generate mask
61 | if self.cat_contained:
62 | uid_tensor, loc_his_pad, loc_cur_pad, tim_w_his_pad, tim_w_cur_pad,\
63 | tim_h_his_pad, tim_h_cur_pad, cat_his_pad, cat_cur_pad = data
64 | else:
65 | uid_tensor, loc_his_pad, loc_cur_pad, tim_w_his_pad, tim_w_cur_pad,\
66 | tim_h_his_pad, tim_h_cur_pad = data
67 | target_mask = generate_mask(mask_batch[0]).unsqueeze(2).to(self.device)
68 | his_mask = generate_mask(mask_batch[1]).unsqueeze(2).to(self.device)
69 | cur_mask = generate_mask(mask_batch[2]).unsqueeze(2).to(self.device)
70 |
71 | # 2) embed
72 | uid_embed = self.user_embedder(uid_tensor)
73 | loc_his_embed, loc_cur_embed = self.loc_embedder(loc_his_pad), self.loc_embedder(loc_cur_pad)
74 | tim_week_his_embed, tim_week_cur_embed = self.tim_embedder_week(tim_w_his_pad), self.tim_embedder_week(tim_w_cur_pad)
75 | tim_hour_his_embed, tim_hour_cur_embed = self.tim_embedder_hour(tim_h_his_pad), self.tim_embedder_hour(tim_h_cur_pad)
76 | rnn_input_his_concat = torch.cat((uid_embed.expand(-1, loc_his_embed.shape[1], -1), loc_his_embed, tim_week_his_embed, tim_hour_his_embed), dim=-1)
77 | rnn_input_cur_concat = torch.cat((uid_embed.expand(-1, loc_cur_embed.shape[1], -1), loc_cur_embed, tim_week_cur_embed, tim_hour_cur_embed), dim=-1)
78 | if self.cat_contained:
79 | cat_his_embed, cat_cur_embed = self.cat_embedder(cat_his_pad), self.cat_embedder(cat_cur_pad)
80 | rnn_input_his_concat = torch.cat((rnn_input_his_concat, cat_his_embed), dim=-1)
81 | rnn_input_cur_concat = torch.cat((rnn_input_cur_concat, cat_cur_embed), dim=-1)
82 |
83 | # 3) rnn capturer
84 |         # Version: separate
85 | cur_t_rnn, hc_t = self.capturer_t(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:])
86 | if self.cat_contained:
87 | cur_c_rnn, hc_c = self.capturer_c(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_t)
88 | cur_l_rnn, hc_l = self.capturer_l(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_c)
89 |
90 | # 4) tower, t,c,l
91 | # CMTL
92 | hc_t, hc_c, hc_l = hc_t.squeeze(), hc_c.squeeze(), hc_l.squeeze()
93 |
94 | c_pred = self.fc_c(hc_c)
95 | c_trans = self.label_trans_c(c_pred.clone())
96 | t_pred = self.fc_t(torch.cat((hc_t, c_trans), dim=-1))
97 | t_trans = self.label_trans_t(t_pred.clone())
98 | l_pred = self.fc_l(torch.cat((hc_l, t_trans), dim=-1))
99 | else:
100 | cur_l_rnn, hc_l = self.capturer_l(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_t)
101 | # 4) tower, t,c,l
102 | # CMTL
103 | hc_t, hc_l = hc_t.squeeze(), hc_l.squeeze()
104 | t_pred = self.fc_t(hc_t)
105 | t_trans = self.label_trans_t(t_pred.clone())
106 | l_pred = self.fc_l(torch.cat((hc_l, t_trans), dim=-1))
107 |
108 | valid_num = (target_mask==0).sum().item()
109 |
110 | if self.cat_contained:
111 | return t_pred, c_pred, l_pred, valid_num
112 | else:
113 | return t_pred, 0, l_pred, valid_num
114 |
115 | def calculate_loss(self, th_pred_in, c_pred_in, l_pred_in, target_batch, valid_num, pid_lat_lon_radians):
116 |
117 | # location loss with cross entropy
118 | l_pred = l_pred_in.reshape(-1, self.pid_size)
119 | l_target = target_batch[0].reshape(-1)
120 | loss_l = self.cross_entropy_loss(l_pred, l_target) / valid_num
121 | # time loss with mae loss
122 | th_target = target_batch[1].reshape(-1)
123 | loss_t = self.mae_loss(th_pred_in.squeeze(-1), th_target.float()) / valid_num
124 |
125 | loss_geocons = self.geo_con_loss(l_pred, l_target, pid_lat_lon_radians) / valid_num
126 |
127 | if self.cat_contained:
128 | # category loss with cross entropy
129 | c_pred = c_pred_in.reshape(-1, self.cid_size)
130 | c_target = target_batch[2].reshape(-1)
131 | loss_c = self.cross_entropy_loss(c_pred, c_target) / valid_num
132 |
133 | return loss_t, loss_c, loss_l, loss_geocons
134 | else:
135 | return loss_t, torch.tensor(0), loss_l, loss_geocons
136 |
137 | def geo_con_loss(self, l_pred_in, l_target, pid_lat_lon):
138 |
139 | log_softmax = nn.functional.log_softmax(l_pred_in, dim=-1)
140 | l_pred = torch.argmax(log_softmax, dim=-1)
141 | l_coor_pred = pid_lat_lon[l_pred]
142 | l_coor_tar = pid_lat_lon[l_target]
143 |
144 | dlat = l_coor_pred[:, 0] - l_coor_tar[:, 0]
145 | dlon = l_coor_pred[:, 1] - l_coor_tar[:, 1]
146 | # a = torch.sin(dlat/2) **2 + torch.cos(l_coor_pred[:, 0]) * torch.cos(l_coor_tar[:, 0]) * (torch.sin(dlon/2))**2
147 | # dist = EARTHRADIUS * 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1-a))
148 |         # simplified version: squared coordinate differences as a cheap surrogate for the haversine distance above
149 | dist = dlat ** 2 + dlon ** 2
150 | loc_prob = log_softmax * dist.unsqueeze(-1)
151 | loss_geocons = F.nll_loss(loc_prob, l_target, ignore_index=0, reduction='sum')
152 |
153 | return loss_geocons
154 |
155 |
156 |
157 | class SessionCapturer(nn.Module):
158 | '''Expert module'''
159 | def __init__(self, RNN_input_dim, RNN_hid_dim, params):
160 | super().__init__()
161 | self.__dict__.update(params.__dict__)
162 | # RNN
163 | self.HistoryCapturer = RnnFactory(self.rnn_type).create(RNN_input_dim, RNN_hid_dim, self.rnn_layer_num, self.dropout)
164 | self.CurrentCapturer = RnnFactory(self.rnn_type).create(RNN_input_dim, RNN_hid_dim, self.rnn_layer_num, self.dropout)
165 |
166 | def forward(self, his_in, cur_in, his_mask, cur_mask, len_batch, hc=None):
167 |
168 |
169 | # 1) pack padded
170 | his_in_pack = pack_padded_sequence(his_in, len_batch[0], batch_first=True, enforce_sorted=False)
171 | cur_in_pack = pack_padded_sequence(cur_in, len_batch[1], batch_first=True, enforce_sorted=False)
172 |
173 | # 2) history capturer
174 |         if hc is None:
175 | history_pack, history_hc = self.HistoryCapturer(his_in_pack)
176 | else:
177 | history_pack, history_hc = self.HistoryCapturer(his_in_pack, hc)
178 | # 3) current capturer
179 | current_pack, current_hc = self.CurrentCapturer(cur_in_pack, history_hc)
180 |
181 | # 4) unpack
182 | history_unpack, _ = pad_packed_sequence(history_pack, batch_first=True)
183 | current_unpack, _ = pad_packed_sequence(current_pack, batch_first=True) # (B, S, BH)
184 |
185 |
186 | # Version: concat
187 | # return current_unpack, current_hc, history_unpack
188 |
189 | # Version: basic
190 | return current_unpack, current_hc
191 |
192 |
193 |
194 |
195 |
196 | class RnnFactory():
197 | ''' Creates the desired RNN unit. '''
198 |
199 | def __init__(self, rnn_type):
200 | self.rnn_type = rnn_type
201 |
202 | def create(self, input_dim, hidden_dim, num_layer, dropout=0):
203 | if self.rnn_type == 'rnn':
204 | return nn.RNN(input_dim, hidden_dim, num_layer, batch_first=True, dropout=dropout)
205 | if self.rnn_type == 'gru':
206 | return nn.GRU(input_dim, hidden_dim, num_layer, batch_first=True, dropout=dropout)
207 | if self.rnn_type == 'lstm':
208 | return nn.LSTM(input_dim, hidden_dim, num_layer, batch_first=True, dropout=dropout)
209 |
--------------------------------------------------------------------------------
/train_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import pickle
5 | import numpy as np
6 | import pandas as pd
7 |
8 | import torch
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from utils import *
12 | from model import *
13 |
14 |
15 | def settings(param=[]):
16 | parser = argparse.ArgumentParser()
17 | # data params
18 | parser.add_argument('--path_in', type=str, default='./data/', help="input data path")
19 | parser.add_argument('--path_out', type=str, default='./results/', help="output data path")
20 | parser.add_argument('--data_name', type=str, default='NYC', help="data name")
21 |     parser.add_argument('--cat_contained', action='store_false', default=True, help="include category information (passing this flag disables it)")
22 | parser.add_argument('--out_filename', type=str, default='', help="output data filename")
23 | # train params
24 | parser.add_argument('--gpu', type=str, default='0', help="GPU index to choose")
25 | parser.add_argument('--run_num', type=int, default=1, help="run number")
26 | parser.add_argument('--epoch_num', type=int, default=10, help="epoch number")
27 | parser.add_argument('--batch_size', type=int, default=128, help="batch size")
28 | parser.add_argument('--learning_rate', type=float, default=1e-4, help="learning rate")
29 | parser.add_argument('--weight_decay', type=float, default=1e-6, help="weight decay")
30 | parser.add_argument('--evaluate_step', type=int, default=1, help="evaluate step")
31 | parser.add_argument('--lam_t', type=float, default=5, help="loss lambda time")
32 | parser.add_argument('--lam_c', type=float, default=10, help="loss lambda category")
33 |     parser.add_argument('--lam_s', type=float, default=10, help="loss lambda for geographical consistency")
34 | # model params
35 | # embedding
36 | parser.add_argument('--user_embed_dim', type=int, default=20, help="user embedding dimension")
37 | parser.add_argument('--loc_embed_dim', type=int, default=200, help="loc embedding dimension")
38 | parser.add_argument('--tim_h_embed_dim', type=int, default=20, help="time hour embedding dimension")
39 | parser.add_argument('--tim_w_embed_dim', type=int, default=10, help="time week embedding dimension")
40 | parser.add_argument('--cat_embed_dim', type=int, default=100, help="category embedding dimension")
41 | # rnn
42 | parser.add_argument('--rnn_type', type=str, default='gru', help="rnn type")
43 | parser.add_argument('--rnn_layer_num', type=int, default=1, help="rnn layer number")
44 | parser.add_argument('--rnn_t_hid_dim', type=int, default=600, help="rnn hidden dimension for t")
45 | parser.add_argument('--rnn_c_hid_dim', type=int, default=600, help="rnn hidden dimension for c")
46 | parser.add_argument('--rnn_l_hid_dim', type=int, default=600, help="rnn hidden dimension for l")
47 | parser.add_argument('--dropout', type=float, default=0.1, help="drop out for rnn")
48 |
49 |
50 | if __name__ == '__main__' and param == []:
51 | params = parser.parse_args()
52 | else:
53 | params = parser.parse_args(param)
54 |
55 | if not os.path.exists(params.path_out):
56 | os.mkdir(params.path_out)
57 |
58 | return params
59 |
60 |
61 | def train(params, dataset):
62 |
63 | # dataset info
64 | params.uid_size = len(dataset['uid_list'])
65 | params.pid_size = len(dataset['pid_dict'])
66 | params.cid_size = len(dataset['cid_dict']) if params.cat_contained else 0
67 | # generate input data
68 | data_train, train_id = dataset['train_data'], dataset['train_id']
69 | data_test, test_id = dataset['test_data'], dataset['test_id']
70 | pid_lat_lon = torch.tensor([[0, 0]] + list(dataset['pid_lat_lon'].values())).to(params.device)
71 |
72 | # model and optimizer
73 | model = Model(params).to(params.device)
74 | optimizer = torch.optim.Adam(model.parameters(), lr=params.learning_rate, weight_decay=params.weight_decay)
75 | print('==== Model is \n', model)
76 | get_model_params(model)
77 | print('==== Optimizer is \n', optimizer)
78 |
79 | # iterate epoch
80 | best_info_train = {'epoch':0, 'Recall@1':0} # best metrics
81 | best_info_test = {'epoch':0, 'Recall@1':0} # best metrics
82 | print('='*10, ' Training')
83 | for epoch in range(params.epoch_num):
84 | model.train()
85 | # variable
86 | loss_l_all = 0.
87 | loss_t_all = 0.
88 | loss_c_all = 0.
89 | loss_s_all = 0.
90 | loss_all = 0.
91 | valid_all = 0
92 |
93 | # train with batch
94 | time_start = time.time()
95 | print('==== Train', end=', ')
96 | for mask_batch, target_batch, data_batch in generate_batch_data(data_train, train_id, params.device, params.batch_size, params.cat_contained):
97 | # model forward
98 | th_pred, c_pred, l_pred, valid_num = model(data_batch, mask_batch)
99 |         # calculate loss
100 | loss_t, loss_c, loss_l, loss_s = model.calculate_loss(th_pred, c_pred, l_pred, target_batch, valid_num, pid_lat_lon)
101 | loss = loss_l + params.lam_t * loss_t + params.lam_c * loss_c + params.lam_s * loss_s
102 | valid_all += valid_num
103 | loss_l_all += loss_l.item() * valid_num
104 | loss_t_all += loss_t.item() * valid_num
105 | loss_c_all += loss_c.item() * valid_num
106 | loss_s_all += loss_s.item() * valid_num
107 | loss_all += loss.item() * valid_num
108 |
109 | # backward
110 | optimizer.zero_grad()
111 | loss.backward()
112 | nn.utils.clip_grad_norm_(model.parameters(), 2.0)
113 | optimizer.step()
114 |
115 | train_time = time.time() - time_start
116 | loss_all /= valid_all
117 | loss_l_all /= valid_all
118 | loss_t_all /= valid_all
119 | loss_c_all /= valid_all
120 | loss_s_all /= valid_all
121 |
122 |
123 | # evaluation
124 | if epoch % params.evaluate_step == 0:
125 | # evaluate with train data
126 | print('==== Evaluate train data', end=', ')
127 | time_start = time.time()
128 | train_acc_l, train_acc_c, train_mse_t = evaluate(model, data_train, train_id, params)
129 | train_eval_time = time.time() - time_start
130 | # evaluate with test data
131 | print('==== Evaluate test data', end=', ')
132 | time_start = time.time()
133 | test_acc_l, test_acc_c, test_mse_t = evaluate(model, data_test, test_id, params)
134 | test_time = time.time() - time_start
135 | print(f'[Epoch={epoch+1}/{params.epoch_num}], loss={loss_all:.2f}, loss_l={loss_l_all:.2f},', end=' ')
136 | print(f'loss_t={loss_t_all:.2f}, loss_c={loss_c_all:.2f}, loss_s={loss_s_all:.2f};')
137 | print(f'Acc_loc: train_l={train_acc_l}, test_l={test_acc_l};')
138 | print(f'Acc_cat: train_c={train_acc_c}, test_c={test_acc_c};')
139 | print(f'MAE_time: train_t={train_mse_t:.2f}, test_t={test_mse_t:.2f};')
140 | print(f'Eval time cost: train={train_eval_time:.1f}s, test={test_time:.1f}s\n')
141 | # store info
142 | if best_info_train['Recall@1'] < train_acc_l[0]:
143 | best_info_train['epoch'] = epoch
144 | best_info_train['Recall@1'] = train_acc_l[0]
145 | best_info_train['Recall@all'] = train_acc_l
146 | best_info_train['MAE'] = train_mse_t
147 | best_info_train['model_params'] = model.state_dict()
148 | if best_info_test['Recall@1'] < test_acc_l[0]:
149 | best_info_test['epoch'] = epoch
150 | best_info_test['Recall@1'] = test_acc_l[0]
151 | best_info_test['Recall@all'] = test_acc_l
152 | best_info_test['MAE'] = test_mse_t
153 | best_info_test['model_params'] = model.state_dict()
154 |
155 | else:
156 | print(f'[Epoch={epoch+1}/{params.epoch_num}], loss={loss_all:.2f}, loss_l={loss_l_all:.2f},', end=' ')
157 | print(f'loss_t={loss_t_all:.2f}, loss_c={loss_c_all:.2f}, loss_s={loss_s_all:.2f};')
158 |
159 | # evaluation
160 | print('='*10, ' Testing')
161 | model.load_state_dict(best_info_test['model_params'])
162 | results_l, results_c, results_t = evaluate(model, data_test, test_id, params)
163 | print(f'Test results: loc={results_l}, cat={results_c}, tim={results_t:.2f}')
164 |
165 | # best metrics info
166 | print('='*10,' Run finished')
167 | print(f'Best train results is {best_info_train["Recall@all"]} at Epoch={best_info_train["epoch"]}')
168 | print(f'Best test results is {best_info_test["Recall@all"]} at Epoch={best_info_test["epoch"]}')
169 |
170 | return results_l, results_c, results_t, best_info_test
171 |
172 |
173 | def evaluate(model, data, data_id, params):
174 | '''Evaluate model performance
175 | '''
176 | l_acc_all = np.zeros(3)
177 | c_acc_all = np.zeros(3)
178 | t_mse_all = 0.
179 | valid_num_all = 0
180 | model.eval()
181 | # evaluate with batch
182 | for mask_batch, target_batch, data_batch in generate_batch_data(data, data_id, params.device, params.batch_size, params.cat_contained):
183 | # model forward
184 | th_pred, c_pred, l_pred, valid_num = model(data_batch, mask_batch)
185 |
186 | # calculate metrics
187 | l_acc = calculate_recall(target_batch[0], l_pred)
188 | l_acc_all += l_acc
189 | t_mse_all += torch.nn.functional.l1_loss(th_pred.squeeze(-1), target_batch[1].squeeze(-1), reduction='sum').item()
190 | valid_num_all += valid_num
191 |
192 | if params.cat_contained:
193 | c_acc = calculate_recall(target_batch[2], c_pred)
194 | c_acc_all += c_acc
195 |
196 | return l_acc_all / valid_num_all, c_acc_all / valid_num_all, t_mse_all / valid_num_all
197 |
198 |
199 | if __name__ == '__main__':
200 |
201 | print('='*20, ' Program Start')
202 | params = settings()
203 | params.device = torch.device(f"cuda:{params.gpu}")
204 | print('Parameter is\n', params.__dict__)
205 |
206 | # file name to store
207 | params.file_out = f'{params.path_out}/{time.strftime("%Y%m%d")}_{params.data_name}_{params.out_filename}/'
208 | if not os.path.exists(params.file_out):
209 | os.makedirs(params.file_out)
210 |
211 | # Load data
212 | print('='*20, ' Loading data')
213 | start_time = time.time()
214 | if params.cat_contained:
215 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}_cat.pkl', 'rb'))
216 | else:
217 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}.pkl', 'rb'))
218 | print(f'Finished, time cost is {time.time()-start_time:.1f}')
219 |
220 | # metrics
221 | metrics = pd.DataFrame()
222 | best_info_all_run = {'epoch':0, 'Recall@1':0}
223 |
224 | # start running
225 | print('='*20, "Start Training")
226 | print(time.strftime("%Y%m%d-%H:%M:%S"))
227 | for i in range(params.run_num):
228 | print('='*20, f'Run {i}')
229 |
230 | # To Revise
231 | results, results_c, results_t, best_info_one_run = train(params, dataset)
232 | metric_dict = {'Rec-l@1': results[0], 'Rec-l@5': results[1], 'Rec-l@10': results[2],
233 | 'MAE': results_t, 'Rec-c@1': results_c[0], 'Rec-c@5': results_c[1], 'Rec-c@10': results_c[2]}
234 | metric_tmp = pd.DataFrame(metric_dict, index=[i])
235 | metrics = pd.concat([metrics, metric_tmp])
236 |
237 | if best_info_all_run['Recall@1'] < best_info_one_run['Recall@1']:
238 | best_info_all_run = best_info_one_run.copy()
239 | best_info_all_run['run'] = i
240 |
241 |
242 |
243 | print('='*20, "Finished")
244 | mean = pd.DataFrame(metrics.mean()).T
245 | mean.index = ['mean']
246 | std = pd.DataFrame(metrics.std()).T
247 | std.index = ['std']
248 | metrics = pd.concat([metrics, mean, std])
249 | print(metrics)
250 |
251 | # save
252 | metrics.to_csv(f'{params.file_out}metrics.csv')
253 | print('='*20, f'\nMetrics saved.')
254 | torch.save(best_info_all_run["model_params"], f'{params.file_out}model.pkl')
255 | print(f'Model saved (Run={best_info_all_run["run"]}, Epoch={best_info_all_run["epoch"]})')
256 | print(time.strftime("%Y%m%d-%H:%M:%S"))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 | from torch.nn.utils.rnn import pad_sequence
6 | import random
7 |
8 | def generate_batch_data(data_input, data_id, device, batch_size, cat_contained):
9 | '''generate batch data'''
10 |
11 | # generate (uid, sid) queue
12 | data_queue = list()
13 | uid_list = data_id.keys()
14 | for uid in uid_list:
15 | for sid in data_id[uid]:
16 | for tar_idx in range(len(data_input[uid][sid]['target_l'])):
17 | data_queue.append((uid, sid, tar_idx))
18 |
19 | # generate batch data
20 | data_len = len(data_queue)
21 | batch_num = int(data_len/batch_size)
22 | random.shuffle(data_queue)
23 | print(f'Number of batch is {batch_num}')
24 | # iterate batch number times
25 | for i in range(batch_num):
26 | # batch data
27 | uid_batch = []
28 | loc_cur_batch = []
29 | tim_w_cur_batch = []
30 | tim_h_cur_batch = []
31 | loc_his_batch = []
32 | tim_w_his_batch = []
33 | tim_h_his_batch = []
34 | target_l_batch = []
35 | target_c_batch = []
36 | target_th_batch = []
37 | target_len_batch = []
38 | history_len_batch = []
39 | current_len_batch = []
40 | if cat_contained:
41 | cat_cur_batch = []
42 | cat_his_batch = []
43 |
44 | if i % 100 == 0:
45 | print('====', f'[Batch={i}/{batch_num}]', end=', ')
46 |
47 | batch_data = data_queue[i * batch_size : (i+1) * batch_size]
48 | # iterate batch index
49 | for one_data in batch_data:
50 | uid, sid, tar_idx = one_data
51 | uid_batch.append([uid])
52 | # current
53 | in_idx = tar_idx + 1
54 | loc_cur_batch.append(torch.LongTensor(data_input[uid][sid]['loc'][1][:in_idx]))
55 | tim_cur_ts = torch.LongTensor(data_input[uid][sid]['tim'][1][:in_idx])
56 | tim_w_cur_batch.append(tim_cur_ts[:, 0])
57 | tim_h_cur_batch.append(tim_cur_ts[:, 1])
58 | current_len_batch.append(tim_cur_ts.shape[0])
59 | # history
60 | loc_his_batch.append(torch.LongTensor(data_input[uid][sid]['loc'][0]))
61 | tim_his_ts = torch.LongTensor(data_input[uid][sid]['tim'][0])
62 | tim_w_his_batch.append(tim_his_ts[:, 0])
63 | tim_h_his_batch.append(tim_his_ts[:, 1])
64 | history_len_batch.append(tim_his_ts.shape[0])
65 | # target
66 | target_l = torch.LongTensor([data_input[uid][sid]['target_l'][tar_idx]])
67 | target_l_batch.append(target_l)
68 | target_len_batch.append(target_l.shape[0])
69 | target_th_batch.append(torch.LongTensor([data_input[uid][sid]['target_th'][tar_idx]]))
70 |             # category
71 | if cat_contained:
72 | cat_his_batch.append(torch.LongTensor(data_input[uid][sid]['cat'][0]))
73 | cat_cur_batch.append(torch.LongTensor(data_input[uid][sid]['cat'][1][:in_idx]))
74 | target_c_batch.append(torch.LongTensor([data_input[uid][sid]['target_c'][tar_idx]]))
75 |
76 |
77 |
78 | # padding
79 | uid_batch_tensor = torch.LongTensor(uid_batch).to(device)
80 | # current
81 | loc_cur_batch_pad = pad_sequence(loc_cur_batch, batch_first=True).to(device)
82 | tim_w_cur_batch_pad = pad_sequence(tim_w_cur_batch, batch_first=True).to(device)
83 | tim_h_cur_batch_pad = pad_sequence(tim_h_cur_batch, batch_first=True).to(device)
84 | # history
85 | loc_his_batch_pad = pad_sequence(loc_his_batch, batch_first=True).to(device)
86 | tim_w_his_batch_pad = pad_sequence(tim_w_his_batch, batch_first=True).to(device)
87 | tim_h_his_batch_pad = pad_sequence(tim_h_his_batch, batch_first=True).to(device)
88 | # target
89 | target_l_batch_pad = pad_sequence(target_l_batch, batch_first=True).to(device)
90 | target_th_batch_pad = pad_sequence(target_th_batch, batch_first=True).to(device)
91 |
92 | if cat_contained:
93 | cat_his_batch_pad = pad_sequence(cat_his_batch, batch_first=True).to(device)
94 | cat_cur_batch_pad = pad_sequence(cat_cur_batch, batch_first=True).to(device)
95 | target_c_batch_pad = pad_sequence(target_c_batch, batch_first=True).to(device)
96 | yield (target_len_batch, history_len_batch, current_len_batch),\
97 | (target_l_batch_pad, target_th_batch_pad, target_c_batch_pad),\
98 | (uid_batch_tensor,\
99 | loc_his_batch_pad, loc_cur_batch_pad,\
100 | tim_w_his_batch_pad, tim_w_cur_batch_pad,\
101 | tim_h_his_batch_pad, tim_h_cur_batch_pad,\
102 | cat_his_batch_pad, cat_cur_batch_pad)
103 | else:
104 | yield (target_len_batch, history_len_batch, current_len_batch),\
105 | (target_l_batch_pad, target_th_batch_pad),\
106 | (uid_batch_tensor,\
107 | loc_his_batch_pad, loc_cur_batch_pad,\
108 | tim_w_his_batch_pad, tim_w_cur_batch_pad,\
109 | tim_h_his_batch_pad, tim_h_cur_batch_pad)
110 |
111 |
112 |
113 |
114 | print('Batch Finished')
115 |
116 |
117 | def generate_mask(data_len):
118 | '''Generate mask
119 | Args:
120 | data_len : one dimension list, reflect sequence length
121 | '''
122 | mask = []
123 | for i_len in data_len:
124 | mask.append(torch.ones(i_len).bool())
125 | return ~pad_sequence(mask, batch_first=True)
126 |
127 |
128 |
129 | def calculate_recall(target_pad, pred_pad):
130 | '''Calculate recall
131 | Args:
132 | target: (batch, max_seq_len), padded target
133 | pred: (batch, max_seq_len, pred_scores), padded
134 | '''
135 | # variable
136 | acc = np.zeros(3) # 1, 5, 10
137 |
138 | # reshape and to numpy
139 | target_list = target_pad.data.reshape(-1).cpu().numpy()
140 | # topK
141 | pid_size = pred_pad.shape[-1]
142 | _, pred_list = pred_pad.data.reshape(-1, pid_size).topk(20)
143 | pred_list = pred_list.cpu().numpy()
144 |
145 | for idx, pred in enumerate(pred_list):
146 | target = target_list[idx]
147 | if target == 0: # pad
148 | continue
149 | if target in pred[:1]:
150 | acc += 1
151 | elif target in pred[:5]:
152 | acc[1:] += 1
153 | elif target in pred[:10]:
154 | acc[2:] += 1
155 |
156 | return acc
157 |
158 |
159 | def get_model_params(model):
160 | total_num = sum(param.numel() for param in model.parameters())
161 | trainable_num = sum(param.numel() for param in model.parameters() if param.requires_grad)
162 | print(f'==== Parameter numbers:\n total={total_num}, trainable={trainable_num}')
163 |
164 |
165 |
166 | class progress_supervisor(object):
167 | def __init__(self, all_num, path):
168 | self.cur_num = 1
169 | self.start_time = time.time()
170 | self.all_num = all_num
171 | self.path = path
172 |
173 | with open(self.path, 'w') as f:
174 | f.write('Start')
175 |
176 | def update(self):
177 | '''Usage:
178 |             path = f'{args.out_dir}{args.model_name}_{args.data_name}.txt'
179 |             supervisor = progress_supervisor(5 * 4 * 4, path)
180 |             main()
181 |             supervisor.update()
182 | '''
183 |
184 | past_time = time.time()-self.start_time
185 | avg_time = past_time / self.cur_num
186 | fut_time = avg_time * (self.all_num - self.cur_num)
187 |
188 |         content = '=' * 10 + ' Progress observation\n'
189 | content += f'Current time is {time.strftime("%Y-%m-%d %H:%M:%S")}\n'
190 | content += f'Current Num: {self.cur_num} / {self.all_num}\n'
191 | content += f'Past time: {past_time:.2f}s ({past_time/3600:.2f}h)\n'
192 | content += f'Average time: {avg_time:.2f}s ({avg_time/3600:.2f}h)\n'
193 | content += f'Future time: {fut_time:.2f}s ({fut_time/3600:.2f}h)\n'
194 |
195 | with open(self.path, 'w') as f:
196 | f.write(content)
197 |
198 | self.cur_num += 1
199 | return content
200 |
201 | def delete(self):
202 | if os.path.exists(self.path):
203 | os.remove(self.path)
204 |
205 | if not os.path.exists(self.path):
206 | print('Supervisor file delete success')
--------------------------------------------------------------------------------