├── LICENSE ├── README.md ├── data └── NYC_cat.pkl ├── data_prepare.py ├── data_preprocess.py ├── evaluate.py ├── model.py ├── train_test.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Urban Mobility 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CSLSL 2 | 3 | 4 | # Performances 5 | The latest experimental results are as follows: 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 |
Category Location
R@1 R@5 R@10 R@1 R@5 R@10
NYC 0.327 0.661 0.759 0.268 0.568 0.656
TKY 0.448 0.801 0.875 0.240 0.488 0.580
Dallas - - - 0.126 0.243 0.297
50 | 51 | 52 | 53 | # Datasets 54 | - The processed data can be found in the "data" folder, which was processed by ```data_preproess.py``` and ```data_prepare.py```. 55 | - The raw data can be found at the following open source. 56 | - [Foursqure (NYC and TKY)](https://sites.google.com/site/yangdingqi/home/foursquare-dataset?authuser=0) 57 | - [Gowalla (Dallas)](https://snap.stanford.edu/data/loc-gowalla.html) 58 | 59 | # Requirements 60 | - Python>=3.8 61 | - Pytorch>=1.8.1 62 | - Numpy 63 | - Pandas 64 | 65 | 66 | # Project Structure 67 | - ```/data```: file to store processed data 68 | - ```/results```: file to store results such as trained model and metrics. 69 | - ```data_preprocess.py```: data preprocessing to filter sparse users and locations (fewer than 10 records) and merge consecutive records (same user and location on the same day). 70 | - ```data_prepare.py```: data preparation for CSLSL (split trajectory and generate data). 71 | - ```train_test.py```: the entry to train and test a new model. 72 | - ```evaluate.py```: the entry to evalute a pretrained model. 73 | - ```model.py```: model defination. 74 | - ```utils.py```: tools such as batch generation and metric calculation. 75 | 76 | 77 | 78 | 79 | # Usage 80 | 1. Train and test a new model 81 | > ```python 82 | > python train_test.py --data_name NYC 83 | > ``` 84 | 85 | 2. Evaluate a pretrained model 86 | > ```python 87 | > python evaluate.py --data_name NYC --model_name model_NYC 88 | > ``` 89 | 90 | Detailed parameter description refers to ```evaluate.py``` and ```train_test.py``` 91 | -------------------------------------------------------------------------------- /data/NYC_cat.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/urbanmobility/CSLSL/d59c882b1feb329d0d3f72067bb111e72a57cbe3/data/NYC_cat.pkl -------------------------------------------------------------------------------- /data_prepare.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import argparse 4 | import numpy as np 5 | import pickle 6 | from math import pi 7 | from collections import Counter 8 | 9 | 10 | 11 | class DataGeneration(object): 12 | def __init__(self, params): 13 | 14 | self.__dict__.update(params.__dict__) 15 | 16 | self.raw_data = {} # raw user's trajectory. {uid: [[pid, tim], ...]} 17 | self.poi_count = {} # raw location counts. {pid: count} 18 | self.data_filtered = {} 19 | self.uid_list = [] # filtered user id 20 | self.pid_dict = {} # filtered location id map 21 | self.train_data = {} # train data with history, {'uid': {'sid': {'loc': [], 'tim': [], 'target': [] (, 'cat': [])}}} 22 | self.train_id = {} # train data session id list 23 | self.test_data = {} 24 | self.test_id = {} 25 | self.tim_w = set() 26 | self.tim_h = set() 27 | 28 | self.raw_lat_lon = {} # count for latitude and longitude 29 | self.new_lat_lon = {} 30 | self.lat_lon_radians = {} 31 | 32 | if self.cat_contained: 33 | self.cid_dict = {} # cid_dict 34 | self.cid_count_dict = {} # cid count 35 | self.raw_cat_dict = {} # cid-cat 36 | self.new_cat_dict = {} 37 | self.pid_cid_dict = {} # pid-cid dict 38 | 39 | # 1. 
read trajectory data 40 | def load_trajectory(self): 41 | with open(self.path_in + self.data_name + '.txt', 'r') as fid: 42 | for i, line in enumerate(fid): 43 | if self.data_name in ['NYC', 'TKY']: 44 | uid, pid, cid, cat, lat, lon, _, tim = line.strip().split('\t') 45 | elif self.data_name == 'Dallas': 46 | uid, tim, pid, lon, lat = line.strip().split('\t') 47 | else: 48 | uid, tim, lat, lon, pid = line.strip().split('\t') 49 | 50 | # Note: user and location is id 51 | if self.cat_contained: 52 | # count uid records 53 | if uid not in self.raw_data: 54 | self.raw_data[uid] = [[pid, tim, cid]] 55 | else: 56 | self.raw_data[uid].append([pid, tim, cid]) 57 | # count raw_cid-cat 58 | if cid not in self.raw_cat_dict: 59 | self.raw_cat_dict[cid] = cat 60 | else: 61 | if uid not in self.raw_data: 62 | self.raw_data[uid] = [[pid, tim]] 63 | else: 64 | self.raw_data[uid].append([pid, tim]) 65 | if pid not in self.poi_count: 66 | self.poi_count[pid] = 1 67 | else: 68 | self.poi_count[pid] += 1 69 | 70 | # count poi latitude and longitude 71 | if pid not in self.raw_lat_lon: 72 | self.raw_lat_lon[pid] = [eval(lat), eval(lon)] 73 | 74 | # 2. filter users and locations, and then split trajectory into sessions 75 | def filter_and_divide_sessions(self): 76 | POI_MIN_RECORD_FOR_USER = 1 # keep same setting with DeepMove and LSTPM 77 | 78 | # filter user and location 79 | uid_list = [x for x in self.raw_data if len(self.raw_data[x]) > self.user_record_min] # uid list 80 | pid_list = [x for x in self.poi_count if self.poi_count[x] > self.poi_record_min] # pid list 81 | 82 | # iterate each user 83 | for uid in uid_list: 84 | user_records = self.raw_data[uid] # user_records is [[pid, tim (, cid)]] 85 | topk = Counter([x[0] for x in user_records]).most_common() # most common poi, [(poi, count), ...] 86 | topk_filter = [x[0] for x in topk if x[1] > POI_MIN_RECORD_FOR_USER] # the poi that the user go more than one time 87 | sessions = {} # sessions is {'sid' : [[pid, [week, hour] (, cid)], ...]} 88 | # iterate each record 89 | for i, record in enumerate(user_records): 90 | if self.cat_contained: 91 | poi, tim, cid = record 92 | else: 93 | poi, tim = record 94 | try: 95 | # time processing 96 | time_struct = time.strptime(tim, "%Y-%m-%dT%H:%M:%SZ") 97 | calendar_date = datetime.date(time_struct.tm_year, time_struct.tm_mon, time_struct.tm_mday).isocalendar() 98 | current_week = f'{calendar_date.year}-{calendar_date.week}' 99 | # Encode time 100 | # tim_code = [time_struct.tm_wday+1, time_struct.tm_hour*2+int(time_struct.tm_min/30)+1 ] # week(1~7), hours(1~48) 101 | tim_code = [time_struct.tm_wday+1, int(time_struct.tm_hour/2)+1 ] # week(1~7), hours(1~12) 102 | # revise record 103 | record[1] = tim_code 104 | except Exception as e: 105 | print('error:{}'.format(e)) 106 | raise Exception 107 | 108 | # divide session. 
Rule is: same week 109 | sid = len(sessions) # session id 110 | if poi not in pid_list and poi not in topk_filter: 111 | # filter the poi if poi not in topk_filter: 112 | continue 113 | if i == 0 or len(sessions) == 0: 114 | sessions[sid] = [record] 115 | else: 116 | if last_week != current_week: # new session 117 | sessions[sid] = [record] 118 | else: 119 | sessions[sid - 1].append(record) # Note: data is already merged 120 | last_week = current_week 121 | sessions_filtered = {} 122 | # filter session with session_min 123 | for s in sessions: 124 | if len(sessions[s]) >= self.session_min: 125 | sessions_filtered[len(sessions_filtered)] = sessions[s] 126 | # filter user with sessions_min, that is, the user must have sessions_min's sessions. 127 | if len(sessions_filtered) < self.sessions_min: 128 | continue 129 | 130 | # ReEncode location index (may encode category) 131 | for sid in sessions_filtered: # sessions is {'sid' : [[pid, [week, hour]], ...]} 132 | for idx, record in enumerate(sessions_filtered[sid]): 133 | # reEncode location 134 | if record[0] not in self.pid_dict: 135 | self.pid_dict[record[0]] = len(self.pid_dict) + 1 # the id start from 1 136 | new_pid = self.pid_dict[record[0]] 137 | # new pid for latitude and longitude 138 | self.new_lat_lon[new_pid] = self.raw_lat_lon[record[0]] 139 | self.lat_lon_radians[new_pid] = list(np.array(self.raw_lat_lon[record[0]]) * pi / 180) 140 | # assign new pid 141 | record[0] = new_pid 142 | # time 143 | self.tim_w.add(record[1][0]) 144 | self.tim_h.add(record[1][1]) 145 | # category 146 | if self.cat_contained: 147 | # encode cid 148 | if record[2] not in self.cid_dict: 149 | new_cid = len(self.cid_dict) + 1 # the id start from 1 150 | self.cid_dict[record[2]] = new_cid 151 | self.new_cat_dict[new_cid] = self.raw_cat_dict[record[2]] # raw_cid-cat to new_cid-cat 152 | self.cid_count_dict[new_cid] = 1 153 | # assign cid 154 | record[2] = self.cid_dict[record[2]] 155 | # count cid 156 | self.cid_count_dict[record[2]] += 1 157 | # pid-cid dict 158 | if new_pid not in self.pid_cid_dict: 159 | self.pid_cid_dict[new_pid] = record[2] 160 | # reassign record 161 | sessions_filtered[sid][idx] = record 162 | 163 | # divide train and test 164 | sessions_id = list(sessions_filtered.keys()) 165 | split_id = int(np.floor(self.train_split * len(sessions_id))) 166 | train_id = sessions_id[:split_id] 167 | test_id = sessions_id[split_id:] 168 | assert len(train_id) > 0, 'train sessions have error' 169 | assert len(test_id) > 0, 'test sessions have error' 170 | # preprare final data. (ReEncode user index), he id start from 1 171 | self.data_filtered[len(self.data_filtered)+1] = {'sessions_count': len(sessions_filtered), 'sessions': sessions_filtered, 'train': train_id, 'test': test_id} 172 | 173 | 174 | # final uid list 175 | self.uid_list = list(self.data_filtered.keys()) 176 | print(f'Final user is {len(self.uid_list)}, location is {len(self.pid_dict)}') 177 | 178 | 179 | # 3. 
generate data with history sessions 180 | def generate_history_sessions(self, mode): 181 | print('='*4, f'generate history sessions in mode={mode}') 182 | data_input = self.data_filtered 183 | data = {} 184 | data_id = {} 185 | user_list = data_input.keys() 186 | 187 | for uid in user_list: 188 | data[uid] = {} 189 | user_sid_list = data_input[uid][mode] 190 | data_id[uid] = user_sid_list.copy() 191 | for idx, sid in enumerate(user_sid_list): 192 | # require at least one session as history; one <- history_session_min 193 | if mode == 'train' and idx < self.history_session_min: 194 | data_id[uid].pop(idx) 195 | continue 196 | data[uid][sid] = {} 197 | loc_seq_cur = [] 198 | loc_seq_his = [] 199 | tim_seq_cur = [] 200 | tim_seq_his = [] 201 | 202 | if self.cat_contained: 203 | cat_seq_cur = [] 204 | cat_seq_his = [] 205 | 206 | if mode == 'test': # in test mode, append all train data as history 207 | train_sid = data_input[uid]['train'] 208 | for tmp_sid in train_sid: 209 | loc_seq_his.extend([record[0] for record in data_input[uid]['sessions'][tmp_sid]]) 210 | tim_seq_his.extend([record[1] for record in data_input[uid]['sessions'][tmp_sid]]) 211 | if self.cat_contained: 212 | cat_seq_his.extend([record[2] for record in data_input[uid]['sessions'][tmp_sid]]) 213 | # append past sessions 214 | for past_idx in range(idx): 215 | tmp_sid = user_sid_list[past_idx] 216 | loc_seq_his.extend([record[0] for record in data_input[uid]['sessions'][tmp_sid]]) 217 | tim_seq_his.extend([record[1] for record in data_input[uid]['sessions'][tmp_sid]]) 218 | if self.cat_contained: 219 | cat_seq_his.extend([record[2] for record in data_input[uid]['sessions'][tmp_sid]]) 220 | # current session 221 | loc_seq_cur.extend([record[0] for record in data_input[uid]['sessions'][sid][:-1]]) # [[pid1], [pid2], ...] 222 | tim_seq_cur.extend([record[1] for record in data_input[uid]['sessions'][sid][:-1]]) 223 | if self.cat_contained: 224 | cat_seq_cur.extend([record[2] for record in data_input[uid]['sessions'][sid][:-1]]) 225 | 226 | # store sequence 227 | data[uid][sid]['target_l'] = [record[0] for record in data_input[uid]['sessions'][sid][1:]] 228 | data[uid][sid]['target_th'] = [record[1][1] for record in data_input[uid]['sessions'][sid][1:]] 229 | data[uid][sid]['loc'] = [loc_seq_his, loc_seq_cur] # list 230 | data[uid][sid]['tim'] = [tim_seq_his, tim_seq_cur] # list 231 | if self.cat_contained: 232 | data[uid][sid]['cat'] = [cat_seq_his, cat_seq_cur] # list 233 | data[uid][sid]['target_c'] = [record[2] for record in data_input[uid]['sessions'][sid][1:]] 234 | 235 | # train/test_data is {'uid': {'sid': {'loc': [pid_seq], 'tim': [[week, hour], ...] (, 'cat': [cid_seq])}}} 236 | # 237 | if mode == 'train': 238 | self.train_data = data 239 | self.train_id = data_id 240 | elif mode == 'test': 241 | self.test_data = data 242 | self.test_id = data_id 243 | print('Finish') 244 | 245 | 246 | # 4. 
save variables 247 | def save_variables(self): 248 | dataset = {'train_data': self.train_data, 'train_id': self.train_id, 249 | 'test_data': self.test_data, 'test_id': self.test_id, 250 | 'pid_dict': self.pid_dict, 'uid_list': self.uid_list, 251 | 'pid_lat_lon': self.new_lat_lon, 252 | 'pid_lat_lon_radians' : self.lat_lon_radians, 253 | 'parameters': self.get_parameters()} 254 | if self.cat_contained: 255 | dataset['cid_dict'] = self.cid_dict 256 | dataset['cid_count_dict'] = self.cid_count_dict 257 | dataset['cid_cat_dict'] = self.new_cat_dict 258 | dataset['pid_cid_dict'] = self.pid_cid_dict 259 | pickle.dump(dataset, open(self.path_out + self.data_name + '_cat.pkl', 'wb')) 260 | else: 261 | pickle.dump(dataset, open(self.path_out + self.data_name + '.pkl', 'wb')) 262 | def get_parameters(self): 263 | parameters = self.__dict__.copy() 264 | del parameters['raw_data'] 265 | del parameters['poi_count'] 266 | del parameters['data_filtered'] 267 | del parameters['uid_list'] 268 | del parameters['pid_dict'] 269 | del parameters['train_data'] 270 | del parameters['train_id'] 271 | del parameters['test_data'] 272 | del parameters['test_id'] 273 | if self.cat_contained: 274 | del parameters['cid_dict'] 275 | 276 | return parameters 277 | 278 | 279 | def parse_args(): 280 | parser = argparse.ArgumentParser() 281 | parser.add_argument('--path_in', type=str, default='../data/', help="input data path") 282 | parser.add_argument('--path_out', type=str, default='./data/', help="output data path") 283 | parser.add_argument('--data_name', type=str, default='NYC', help="data name") 284 | parser.add_argument('--user_record_min', type=int, default=10, help="user record length filter threshold") 285 | parser.add_argument('--poi_record_min', type=int, default=10, help="location record length filter threshold") 286 | parser.add_argument('--session_min', type=int, default=2, help="control the length of session not too short") 287 | parser.add_argument('--sessions_min', type=int, default=5, help="the minimum amount of the user's sessions") 288 | parser.add_argument('--train_split', type=float, default=0.8, help="train/test ratio") 289 | parser.add_argument('--cat_contained', action='store_false', default=True, help="whether contain category") 290 | parser.add_argument('--history_session_min', type=int, default=1, help="minimun number of history session") 291 | 292 | if __name__ == '__main__': 293 | return parser.parse_args() 294 | else: 295 | return parser.parse_args([]) 296 | 297 | 298 | if __name__ == '__main__': 299 | 300 | start_time = time.time() 301 | 302 | params = parse_args() 303 | data_generator = DataGeneration(params) 304 | parameters = data_generator.get_parameters() 305 | print('='*20 + ' Parameter settings') 306 | print(', '.join([p + '=' + str(parameters[p]) for p in parameters])) 307 | print('='*20 + ' Start processing') 308 | print('==== Load trajectory from {}'.format(data_generator.path_in)) 309 | data_generator.load_trajectory() 310 | 311 | print('==== filter users') 312 | data_generator.filter_and_divide_sessions() 313 | 314 | 315 | print('==== generate history sessions') 316 | data_generator.generate_history_sessions('train') 317 | data_generator.generate_history_sessions('test') 318 | 319 | print('==== save prepared data') 320 | data_generator.save_variables() 321 | 322 | print('==== Preparetion Finished') 323 | print('Raw users:{} raw locations:{}'.format( 324 | len(data_generator.raw_data), len(data_generator.poi_count))) 325 | print(f'Final users:{len(data_generator.uid_list)}, 
min_id:{np.min(data_generator.uid_list)}, max_id:{np.max(data_generator.uid_list)}') 326 | pid_list = list(data_generator.pid_dict.values()) 327 | print(f'Final locations:{len(pid_list)}, min_id:{np.min(pid_list)}, max_id:{np.max(pid_list)}') 328 | print(f'Final time-week:{len(data_generator.tim_w)}, min_id:{np.min(list(data_generator.tim_w))}, max_id:{np.max(list(data_generator.tim_w))}') 329 | print(f'Final time-hour:{len(data_generator.tim_h)}, min_id:{np.min(list(data_generator.tim_h))}, max_id:{np.max(list(data_generator.tim_h))}') 330 | if params.cat_contained: 331 | cid_list = list(data_generator.cid_dict.values()) 332 | print(f'Final categories:{len(cid_list)}, min_id:{np.min(cid_list)}, max_id:{np.max(cid_list)}') 333 | print(f'Time cost is {time.time()-start_time:.0f}s') 334 | 335 | -------------------------------------------------------------------------------- /data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import time 4 | import argparse 5 | 6 | 7 | 8 | def settings(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--path', default='./data/', type=str, 11 | help='data path') 12 | parser.add_argument('--dataname', default='NYC', type=str, 13 | help='data name') 14 | parser.add_argument('--filetype', default='txt', type=str, 15 | help='file type') 16 | parser.add_argument('--user_po', default=0, type=int, 17 | help='Position in record for user') 18 | parser.add_argument('--loc_po', default=1, type=int, 19 | help='Position in record for location') 20 | parser.add_argument('--tim_po', default=2, type=int, 21 | help='Position in record for time') 22 | 23 | parser.add_argument('--user_record_min', default=10, type=int, 24 | help='Minimum record number for user') 25 | parser.add_argument('--loc_record_min', default=10, type=int, 26 | help='Minimum record number for location') 27 | return parser.parse_args() 28 | 29 | 30 | def preprocessing(params): 31 | '''Preprocessing data 32 | Note: 33 | 1. Raw data is sorted by user (1) and time (2) 34 | 2. Filter sparse users and locations by minimum record counts. 35 | 3. Encode user and location ids. 36 | 4. Merge consecutive records with the same user and location on the same day.
37 | 38 | ''' 39 | 40 | # Loading and Filtering sparse data 41 | print('='*20, 'Loading and preprocessing sparse data') 42 | filepath = f'{params.path}{params.dataname}.{params.filetype}' 43 | print(f'Path is {filepath}') 44 | loc_count = {} # store location info with loc-num 45 | user_count = {} # store user info with user-num 46 | user_id = {} 47 | loc_id = {} 48 | 49 | # load file and count numbers 50 | print('='*20, 'Loading and Counting') 51 | if params.filetype == 'txt': 52 | with open(filepath, 'r') as f: 53 | reader = csv.reader(f, delimiter='\t') 54 | for record in reader: 55 | if '' in record: 56 | continue 57 | user = record[params.user_po] 58 | loc = record[params.loc_po] 59 | if user not in user_count: 60 | user_count[user] = 1 61 | else: 62 | user_count[user] += 1 63 | if loc not in loc_count: 64 | loc_count[loc] = 1 65 | else: 66 | loc_count[loc] += 1 67 | 68 | record_num = os.popen(f'wc -l {filepath}').readlines()[0].split()[0] 69 | print(f'Finished, records is {record_num}, all user is {len(user_count)}, all location is {len(loc_count)}') 70 | 71 | 72 | # Filter and encode user and location 73 | print('='*20, 'Filtering and encoding') 74 | for i in user_count: 75 | if user_count[i] > params.user_record_min: 76 | user_id[i] = len(user_id) 77 | for i in loc_count: 78 | if loc_count[i] > params.loc_record_min: 79 | loc_id[i] = len(loc_id) 80 | 81 | # store 82 | filter_path = f'{params.path}{params.dataname}_filtered.txt' 83 | print(f'Filter path is {filter_path}') 84 | with open(filter_path, 'w') as f_out: 85 | writer = csv.writer(f_out, delimiter='\t') 86 | with open(filepath, 'r') as f_in: 87 | reader = csv.reader(f_in, delimiter='\t') 88 | for record in reader: 89 | if '' in record: 90 | continue 91 | user = record[params.user_po] 92 | loc = record[params.loc_po] 93 | if user in user_id and loc in loc_id: 94 | record[params.user_po] = user_id[user] 95 | record[params.loc_po] = loc_id[loc] 96 | writer.writerow(record) 97 | 98 | record_num = os.popen(f'wc -l {filter_path}').readlines()[0].split()[0] 99 | print(f'Finished, records is {record_num}, user is {len(user_id)}, location is {len(loc_id)}') 100 | 101 | 102 | # Merge data 103 | print('='*20, 'Merging') 104 | merge_path = f'{params.path}{params.dataname}_merged.txt' 105 | print(f'Merge path is {filter_path}') 106 | with open(merge_path, 'w') as f_out: 107 | writer = csv.writer(f_out, delimiter='\t') 108 | # get first record 109 | with open(filter_path, 'r') as f_in: 110 | pre_record = f_in.readlines()[0].split('\t') 111 | # all record 112 | with open(filter_path, 'r') as f_in: 113 | reader = csv.reader(f_in, delimiter='\t') 114 | for record in reader: 115 | # same person, same location, same day 116 | if record[params.user_po] == pre_record[params.user_po] and \ 117 | record[params.loc_po] == pre_record[params.loc_po] and \ 118 | record[params.tim_po].split('T')[0] == pre_record[params.tim_po].split('T')[0]: 119 | continue 120 | writer.writerow(record) 121 | pre_record = record 122 | 123 | record_num = os.popen(f'wc -l {merge_path}').readlines()[0].split()[0] 124 | print(f'Finished, records is {record_num}') 125 | 126 | 127 | 128 | if __name__ == '__main__': 129 | 130 | start_time = time.time() 131 | params = settings() 132 | preprocessing(params) 133 | print('Time cost:', f'{time.time()-start_time:.0f}') -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import 
argparse 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import torch 9 | 10 | from utils import * 11 | from model import * 12 | 13 | 14 | def settings(param=[]): 15 | parser = argparse.ArgumentParser() 16 | # data params 17 | parser.add_argument('--path_in', type=str, default='./data/', help="input data path") 18 | parser.add_argument('--path_out', type=str, default='./results/', help="output data path") 19 | parser.add_argument('--data_name', type=str, default='NYC', help="data name") 20 | parser.add_argument('--cat_contained', action='store_false', default=True, help="whether contain category") 21 | parser.add_argument('--out_filename', type=str, default='', help="output data filename") 22 | # train params 23 | parser.add_argument('--gpu', type=str, default='0', help="GPU index to choose") 24 | parser.add_argument('--run_num', type=int, default=10, help="run number") 25 | parser.add_argument('--model_path', type=str, default='./results/pretrained/', help="model path") 26 | parser.add_argument('--model_name', type=str, default='model_NYC', help="model name to load") 27 | parser.add_argument('--batch_size', type=int, default=64, help="batch size") 28 | 29 | # model params 30 | # embedding 31 | parser.add_argument('--user_embed_dim', type=int, default=20, help="user embedding dimension") 32 | parser.add_argument('--loc_embed_dim', type=int, default=200, help="loc embedding dimension") 33 | parser.add_argument('--tim_h_embed_dim', type=int, default=20, help="time hour embedding dimension") 34 | parser.add_argument('--tim_w_embed_dim', type=int, default=10, help="time week embedding dimension") 35 | parser.add_argument('--cat_embed_dim', type=int, default=100, help="category embedding dimension") 36 | # rnn 37 | parser.add_argument('--rnn_type', type=str, default='gru', help="rnn type") 38 | parser.add_argument('--rnn_layer_num', type=int, default=1, help="rnn layer number") 39 | parser.add_argument('--rnn_t_hid_dim', type=int, default=600, help="rnn hidden dimension for t") 40 | parser.add_argument('--rnn_c_hid_dim', type=int, default=600, help="rnn hidden dimension for c") 41 | parser.add_argument('--rnn_l_hid_dim', type=int, default=600, help="rnn hidden dimension for l") 42 | parser.add_argument('--dropout', type=float, default=0.1, help="drop out for rnn") 43 | 44 | 45 | if __name__ == '__main__' and param == []: 46 | params = parser.parse_args() 47 | else: 48 | params = parser.parse_args(param) 49 | 50 | if not os.path.exists(params.path_out): 51 | os.mkdir(params.path_out) 52 | 53 | return params 54 | 55 | 56 | def evaluate(params, dataset): 57 | '''Evaluate model performance 58 | ''' 59 | 60 | # dataset info 61 | params.uid_size = len(dataset['uid_list']) 62 | params.pid_size = len(dataset['pid_dict']) 63 | params.cid_size = len(dataset['cid_dict']) if params.cat_contained else 0 64 | # generate input data 65 | data_test, test_id = dataset['test_data'], dataset['test_id'] 66 | pid_lat_lon_radians = torch.tensor([[0, 0]] + list(dataset['pid_lat_lon_radians'].values())).to(params.device) 67 | 68 | # load model 69 | model = Model(params).to(params.device) 70 | model.load_state_dict(torch.load(params.model_path+params.model_name+'.pkl')) 71 | model.eval() 72 | 73 | 74 | l_acc_all = np.zeros(3) 75 | c_acc_all = np.zeros(3) 76 | t_mse_all = 0. 
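        # Note: despite the name, t_mse_all accumulates summed absolute error (the
        # l1_loss call below), so the per-target value returned here is the MAE that
        # the saved metrics file reports.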
77 | valid_num_all = 0 78 | model.eval() 79 | # evaluate with batch 80 | for mask_batch, target_batch, data_batch in generate_batch_data(data_test, test_id, params.device, params.batch_size, params.cat_contained): 81 | # model forward 82 | th_pred, c_pred, l_pred, valid_num = model(data_batch, mask_batch) 83 | 84 | # calculate metrics 85 | l_acc = calculate_recall(target_batch[0], l_pred) 86 | l_acc_all += l_acc 87 | t_mse_all += torch.nn.functional.l1_loss(th_pred.squeeze(-1), target_batch[1].squeeze(-1), reduction='sum').item() 88 | valid_num_all += valid_num 89 | 90 | if params.cat_contained: 91 | c_acc = calculate_recall(target_batch[2], c_pred) 92 | c_acc_all += c_acc 93 | 94 | return l_acc_all / valid_num_all, c_acc_all / valid_num_all, t_mse_all / valid_num_all 95 | 96 | 97 | if __name__ == '__main__': 98 | 99 | print('='*20, ' Program Start') 100 | params = settings() 101 | params.device = torch.device(f"cuda:{params.gpu}") 102 | print('Parameter is\n', params.__dict__) 103 | 104 | # file name to store 105 | FILE_NAME = [params.path_out, f'{time.strftime("%Y%m%d")}_{params.data_name}_'] 106 | FILE_NAME[1] += f'{params.out_filename}' 107 | 108 | # Load data 109 | print('='*20, ' Loading data') 110 | start_time = time.time() 111 | if params.cat_contained: 112 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}_cat.pkl', 'rb')) 113 | else: 114 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}.pkl', 'rb')) 115 | print(f'Finished, time cost is {time.time()-start_time:.1f}') 116 | 117 | # metrics 118 | metrics = pd.DataFrame() 119 | 120 | # start running 121 | print('='*20, "Start Evaluating") 122 | for i in range(params.run_num): 123 | print('='*20, f'Run {i}') 124 | 125 | # To Revise 126 | results, results_c, results_t = evaluate(params, dataset) 127 | metric_dict = {'Rec-l@1': results[0], 'Rec-l@5': results[1], 'Rec-l@10': results[2], 128 | 'MAE': results_t, 'Rec-c@1': results_c[0], 'Rec-c@5': results_c[1], 'Rec-c@10': results_c[2]} 129 | metric_tmp = pd.DataFrame(metric_dict, index=[i]) 130 | metrics = pd.concat([metrics, metric_tmp]) 131 | 132 | 133 | print('='*20, "Finished") 134 | mean = pd.DataFrame(metrics.mean()).T 135 | mean.index = ['mean'] 136 | std = pd.DataFrame(metrics.std()).T 137 | std.index = ['std'] 138 | metrics = pd.concat([metrics, mean, std]) 139 | print(metrics) 140 | 141 | # save 142 | metrics.to_csv(f'{FILE_NAME[0]}metrics_{FILE_NAME[1]}.csv') 143 | print('='*20, f'\nMetrics saved. 
File name is {FILE_NAME[0]}metrics_{FILE_NAME[1]}.csv') -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 5 | import numpy as np 6 | from utils import generate_mask 7 | 8 | NUM_TASK = 3 9 | EARTHRADIUS = 6371.0 10 | 11 | class Model(nn.Module): 12 | def __init__(self, params): 13 | super().__init__() 14 | self.__dict__.update(params.__dict__) 15 | self.th_size = 12 + 1 # +1 because of the padding value 0, 16 | self.tw_size = 7 + 1 17 | self.pid_size += 1 18 | self.uid_size += 1 19 | self.cid_size += 1 20 | # Dim 21 | RNN_input_dim = self.user_embed_dim + self.loc_embed_dim + self.tim_w_embed_dim + self.tim_h_embed_dim 22 | if self.cat_contained: 23 | RNN_input_dim += self.cat_embed_dim 24 | 25 | # Embedding; (all id is start from 1) 26 | self.user_embedder = nn.Embedding(self.uid_size, self.user_embed_dim) # without padding 27 | self.loc_embedder = nn.Embedding(self.pid_size, self.loc_embed_dim, padding_idx=0) 28 | self.tim_embedder_week = nn.Embedding(self.tw_size, self.tim_w_embed_dim, padding_idx=0) 29 | self.tim_embedder_hour = nn.Embedding(self.th_size, self.tim_h_embed_dim, padding_idx=0) 30 | self.cat_embedder = nn.Embedding(self.cid_size, self.cat_embed_dim, padding_idx=0) 31 | 32 | # Capturer 33 | # Version: Seperate 34 | self.capturer_t = SessionCapturer(RNN_input_dim, self.rnn_t_hid_dim, params) 35 | self.capturer_l = SessionCapturer(RNN_input_dim, self.rnn_l_hid_dim, params) 36 | 37 | if self.cat_contained: 38 | self.cat_embedder = nn.Embedding(self.cid_size, self.cat_embed_dim, padding_idx=0) 39 | self.capturer_c = SessionCapturer(RNN_input_dim, self.rnn_c_hid_dim, params) 40 | 41 | # CMTL 42 | if self.cat_contained: 43 | self.fc_c = nn.Linear(self.rnn_c_hid_dim, self.cid_size) 44 | self.label_trans_c = nn.Linear(self.cid_size, self.cat_embed_dim) 45 | self.fc_t = nn.Linear(self.rnn_t_hid_dim + self.cat_embed_dim, 1) 46 | self.label_trans_t = nn.Linear(1, self.tim_h_embed_dim) 47 | self.fc_l = nn.Linear(self.rnn_l_hid_dim + self.tim_h_embed_dim, self.pid_size) 48 | else: 49 | self.fc_t = nn.Linear(self.rnn_t_hid_dim, 1) 50 | self.label_trans_t = nn.Linear(1, self.tim_h_embed_dim) 51 | self.fc_l = nn.Linear(self.rnn_l_hid_dim + self.tim_h_embed_dim, self.pid_size) 52 | 53 | # Loss 54 | self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum', ignore_index=0) 55 | self.mae_loss = nn.L1Loss(reduction='sum') 56 | 57 | def forward(self, data, mask_batch): 58 | 59 | # Capture Short-term representation 60 | # 1) get data and generate mask 61 | if self.cat_contained: 62 | uid_tensor, loc_his_pad, loc_cur_pad, tim_w_his_pad, tim_w_cur_pad,\ 63 | tim_h_his_pad, tim_h_cur_pad, cat_his_pad, cat_cur_pad = data 64 | else: 65 | uid_tensor, loc_his_pad, loc_cur_pad, tim_w_his_pad, tim_w_cur_pad,\ 66 | tim_h_his_pad, tim_h_cur_pad = data 67 | target_mask = generate_mask(mask_batch[0]).unsqueeze(2).to(self.device) 68 | his_mask = generate_mask(mask_batch[1]).unsqueeze(2).to(self.device) 69 | cur_mask = generate_mask(mask_batch[2]).unsqueeze(2).to(self.device) 70 | 71 | # 2) embed 72 | uid_embed = self.user_embedder(uid_tensor) 73 | loc_his_embed, loc_cur_embed = self.loc_embedder(loc_his_pad), self.loc_embedder(loc_cur_pad) 74 | tim_week_his_embed, tim_week_cur_embed = self.tim_embedder_week(tim_w_his_pad), 
self.tim_embedder_week(tim_w_cur_pad) 75 | tim_hour_his_embed, tim_hour_cur_embed = self.tim_embedder_hour(tim_h_his_pad), self.tim_embedder_hour(tim_h_cur_pad) 76 | rnn_input_his_concat = torch.cat((uid_embed.expand(-1, loc_his_embed.shape[1], -1), loc_his_embed, tim_week_his_embed, tim_hour_his_embed), dim=-1) 77 | rnn_input_cur_concat = torch.cat((uid_embed.expand(-1, loc_cur_embed.shape[1], -1), loc_cur_embed, tim_week_cur_embed, tim_hour_cur_embed), dim=-1) 78 | if self.cat_contained: 79 | cat_his_embed, cat_cur_embed = self.cat_embedder(cat_his_pad), self.cat_embedder(cat_cur_pad) 80 | rnn_input_his_concat = torch.cat((rnn_input_his_concat, cat_his_embed), dim=-1) 81 | rnn_input_cur_concat = torch.cat((rnn_input_cur_concat, cat_cur_embed), dim=-1) 82 | 83 | # 3) rnn capturer 84 | # Version: seperate 85 | cur_t_rnn, hc_t = self.capturer_t(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:]) 86 | if self.cat_contained: 87 | cur_c_rnn, hc_c = self.capturer_c(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_t) 88 | cur_l_rnn, hc_l = self.capturer_l(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_c) 89 | 90 | # 4) tower, t,c,l 91 | # CMTL 92 | hc_t, hc_c, hc_l = hc_t.squeeze(), hc_c.squeeze(), hc_l.squeeze() 93 | 94 | c_pred = self.fc_c(hc_c) 95 | c_trans = self.label_trans_c(c_pred.clone()) 96 | t_pred = self.fc_t(torch.cat((hc_t, c_trans), dim=-1)) 97 | t_trans = self.label_trans_t(t_pred.clone()) 98 | l_pred = self.fc_l(torch.cat((hc_l, t_trans), dim=-1)) 99 | else: 100 | cur_l_rnn, hc_l = self.capturer_l(rnn_input_his_concat, rnn_input_cur_concat, his_mask, cur_mask, mask_batch[1:], hc_t) 101 | # 4) tower, t,c,l 102 | # CMTL 103 | hc_t, hc_l = hc_t.squeeze(), hc_l.squeeze() 104 | t_pred = self.fc_t(hc_t) 105 | t_trans = self.label_trans_t(t_pred.clone()) 106 | l_pred = self.fc_l(torch.cat((hc_l, t_trans), dim=-1)) 107 | 108 | valid_num = (target_mask==0).sum().item() 109 | 110 | if self.cat_contained: 111 | return t_pred, c_pred, l_pred, valid_num 112 | else: 113 | return t_pred, 0, l_pred, valid_num 114 | 115 | def calculate_loss(self, th_pred_in, c_pred_in, l_pred_in, target_batch, valid_num, pid_lat_lon_radians): 116 | 117 | # location loss with cross entropy 118 | l_pred = l_pred_in.reshape(-1, self.pid_size) 119 | l_target = target_batch[0].reshape(-1) 120 | loss_l = self.cross_entropy_loss(l_pred, l_target) / valid_num 121 | # time loss with mae loss 122 | th_target = target_batch[1].reshape(-1) 123 | loss_t = self.mae_loss(th_pred_in.squeeze(-1), th_target.float()) / valid_num 124 | 125 | loss_geocons = self.geo_con_loss(l_pred, l_target, pid_lat_lon_radians) / valid_num 126 | 127 | if self.cat_contained: 128 | # category loss with cross entropy 129 | c_pred = c_pred_in.reshape(-1, self.cid_size) 130 | c_target = target_batch[2].reshape(-1) 131 | loss_c = self.cross_entropy_loss(c_pred, c_target) / valid_num 132 | 133 | return loss_t, loss_c, loss_l, loss_geocons 134 | else: 135 | return loss_t, torch.tensor(0), loss_l, loss_geocons 136 | 137 | def geo_con_loss(self, l_pred_in, l_target, pid_lat_lon): 138 | 139 | log_softmax = nn.functional.log_softmax(l_pred_in, dim=-1) 140 | l_pred = torch.argmax(log_softmax, dim=-1) 141 | l_coor_pred = pid_lat_lon[l_pred] 142 | l_coor_tar = pid_lat_lon[l_target] 143 | 144 | dlat = l_coor_pred[:, 0] - l_coor_tar[:, 0] 145 | dlon = l_coor_pred[:, 1] - l_coor_tar[:, 1] 146 | # a = torch.sin(dlat/2) **2 + torch.cos(l_coor_pred[:, 0]) * 
torch.cos(l_coor_tar[:, 0]) * (torch.sin(dlon/2))**2 147 | # dist = EARTHRADIUS * 2 * torch.atan2(torch.sqrt(a), torch.sqrt(1-a)) 148 | # new version 149 | dist = dlat ** 2 + dlon ** 2 150 | loc_prob = log_softmax * dist.unsqueeze(-1) 151 | loss_geocons = F.nll_loss(loc_prob, l_target, ignore_index=0, reduction='sum') 152 | 153 | return loss_geocons 154 | 155 | 156 | 157 | class SessionCapturer(nn.Module): 158 | '''Expert module''' 159 | def __init__(self, RNN_input_dim, RNN_hid_dim, params): 160 | super().__init__() 161 | self.__dict__.update(params.__dict__) 162 | # RNN 163 | self.HistoryCapturer = RnnFactory(self.rnn_type).create(RNN_input_dim, RNN_hid_dim, self.rnn_layer_num, self.dropout) 164 | self.CurrentCapturer = RnnFactory(self.rnn_type).create(RNN_input_dim, RNN_hid_dim, self.rnn_layer_num, self.dropout) 165 | 166 | def forward(self, his_in, cur_in, his_mask, cur_mask, len_batch, hc=None): 167 | 168 | 169 | # 1) pack padded 170 | his_in_pack = pack_padded_sequence(his_in, len_batch[0], batch_first=True, enforce_sorted=False) 171 | cur_in_pack = pack_padded_sequence(cur_in, len_batch[1], batch_first=True, enforce_sorted=False) 172 | 173 | # 2) history capturer 174 | if hc == None: 175 | history_pack, history_hc = self.HistoryCapturer(his_in_pack) 176 | else: 177 | history_pack, history_hc = self.HistoryCapturer(his_in_pack, hc) 178 | # 3) current capturer 179 | current_pack, current_hc = self.CurrentCapturer(cur_in_pack, history_hc) 180 | 181 | # 4) unpack 182 | history_unpack, _ = pad_packed_sequence(history_pack, batch_first=True) 183 | current_unpack, _ = pad_packed_sequence(current_pack, batch_first=True) # (B, S, BH) 184 | 185 | 186 | # Version: concat 187 | # return current_unpack, current_hc, history_unpack 188 | 189 | # Version: basic 190 | return current_unpack, current_hc 191 | 192 | 193 | 194 | 195 | 196 | class RnnFactory(): 197 | ''' Creates the desired RNN unit. 
''' 198 | 199 | def __init__(self, rnn_type): 200 | self.rnn_type = rnn_type 201 | 202 | def create(self, input_dim, hidden_dim, num_layer, dropout=0): 203 | if self.rnn_type == 'rnn': 204 | return nn.RNN(input_dim, hidden_dim, num_layer, batch_first=True, dropout=dropout) 205 | if self.rnn_type == 'gru': 206 | return nn.GRU(input_dim, hidden_dim, num_layer, batch_first=True, dropout=dropout) 207 | if self.rnn_type == 'lstm': 208 | return nn.LSTM(input_dim, hidden_dim, num_layer, batch_first=True, dropout=dropout) 209 | -------------------------------------------------------------------------------- /train_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | 8 | import torch 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from utils import * 12 | from model import * 13 | 14 | 15 | def settings(param=[]): 16 | parser = argparse.ArgumentParser() 17 | # data params 18 | parser.add_argument('--path_in', type=str, default='./data/', help="input data path") 19 | parser.add_argument('--path_out', type=str, default='./results/', help="output data path") 20 | parser.add_argument('--data_name', type=str, default='NYC', help="data name") 21 | parser.add_argument('--cat_contained', action='store_false', default=True, help="whether contain category") 22 | parser.add_argument('--out_filename', type=str, default='', help="output data filename") 23 | # train params 24 | parser.add_argument('--gpu', type=str, default='0', help="GPU index to choose") 25 | parser.add_argument('--run_num', type=int, default=1, help="run number") 26 | parser.add_argument('--epoch_num', type=int, default=10, help="epoch number") 27 | parser.add_argument('--batch_size', type=int, default=128, help="batch size") 28 | parser.add_argument('--learning_rate', type=float, default=1e-4, help="learning rate") 29 | parser.add_argument('--weight_decay', type=float, default=1e-6, help="weight decay") 30 | parser.add_argument('--evaluate_step', type=int, default=1, help="evaluate step") 31 | parser.add_argument('--lam_t', type=float, default=5, help="loss lambda time") 32 | parser.add_argument('--lam_c', type=float, default=10, help="loss lambda category") 33 | parser.add_argument('--lam_s', type=float, default=10, help="loss lambda for geographcal consistency") 34 | # model params 35 | # embedding 36 | parser.add_argument('--user_embed_dim', type=int, default=20, help="user embedding dimension") 37 | parser.add_argument('--loc_embed_dim', type=int, default=200, help="loc embedding dimension") 38 | parser.add_argument('--tim_h_embed_dim', type=int, default=20, help="time hour embedding dimension") 39 | parser.add_argument('--tim_w_embed_dim', type=int, default=10, help="time week embedding dimension") 40 | parser.add_argument('--cat_embed_dim', type=int, default=100, help="category embedding dimension") 41 | # rnn 42 | parser.add_argument('--rnn_type', type=str, default='gru', help="rnn type") 43 | parser.add_argument('--rnn_layer_num', type=int, default=1, help="rnn layer number") 44 | parser.add_argument('--rnn_t_hid_dim', type=int, default=600, help="rnn hidden dimension for t") 45 | parser.add_argument('--rnn_c_hid_dim', type=int, default=600, help="rnn hidden dimension for c") 46 | parser.add_argument('--rnn_l_hid_dim', type=int, default=600, help="rnn hidden dimension for l") 47 | parser.add_argument('--dropout', type=float, default=0.1, help="drop out for rnn") 48 
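    # Note: --cat_contained is defined with action='store_false' and default=True,
    # so passing the flag on the command line disables category information; omit it
    # to keep the default category-aware setting.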
| 49 | 50 | if __name__ == '__main__' and param == []: 51 | params = parser.parse_args() 52 | else: 53 | params = parser.parse_args(param) 54 | 55 | if not os.path.exists(params.path_out): 56 | os.mkdir(params.path_out) 57 | 58 | return params 59 | 60 | 61 | def train(params, dataset): 62 | 63 | # dataset info 64 | params.uid_size = len(dataset['uid_list']) 65 | params.pid_size = len(dataset['pid_dict']) 66 | params.cid_size = len(dataset['cid_dict']) if params.cat_contained else 0 67 | # generate input data 68 | data_train, train_id = dataset['train_data'], dataset['train_id'] 69 | data_test, test_id = dataset['test_data'], dataset['test_id'] 70 | pid_lat_lon = torch.tensor([[0, 0]] + list(dataset['pid_lat_lon'].values())).to(params.device) 71 | 72 | # model and optimizer 73 | model = Model(params).to(params.device) 74 | optimizer = torch.optim.Adam(model.parameters(), lr=params.learning_rate, weight_decay=params.weight_decay) 75 | print('==== Model is \n', model) 76 | get_model_params(model) 77 | print('==== Optimizer is \n', optimizer) 78 | 79 | # iterate epoch 80 | best_info_train = {'epoch':0, 'Recall@1':0} # best metrics 81 | best_info_test = {'epoch':0, 'Recall@1':0} # best metrics 82 | print('='*10, ' Training') 83 | for epoch in range(params.epoch_num): 84 | model.train() 85 | # variable 86 | loss_l_all = 0. 87 | loss_t_all = 0. 88 | loss_c_all = 0. 89 | loss_s_all = 0. 90 | loss_all = 0. 91 | valid_all = 0 92 | 93 | # train with batch 94 | time_start = time.time() 95 | print('==== Train', end=', ') 96 | for mask_batch, target_batch, data_batch in generate_batch_data(data_train, train_id, params.device, params.batch_size, params.cat_contained): 97 | # model forward 98 | th_pred, c_pred, l_pred, valid_num = model(data_batch, mask_batch) 99 | # calcuate loss 100 | loss_t, loss_c, loss_l, loss_s = model.calculate_loss(th_pred, c_pred, l_pred, target_batch, valid_num, pid_lat_lon) 101 | loss = loss_l + params.lam_t * loss_t + params.lam_c * loss_c + params.lam_s * loss_s 102 | valid_all += valid_num 103 | loss_l_all += loss_l.item() * valid_num 104 | loss_t_all += loss_t.item() * valid_num 105 | loss_c_all += loss_c.item() * valid_num 106 | loss_s_all += loss_s.item() * valid_num 107 | loss_all += loss.item() * valid_num 108 | 109 | # backward 110 | optimizer.zero_grad() 111 | loss.backward() 112 | nn.utils.clip_grad_norm_(model.parameters(), 2.0) 113 | optimizer.step() 114 | 115 | train_time = time.time() - time_start 116 | loss_all /= valid_all 117 | loss_l_all /= valid_all 118 | loss_t_all /= valid_all 119 | loss_c_all /= valid_all 120 | loss_s_all /= valid_all 121 | 122 | 123 | # evaluation 124 | if epoch % params.evaluate_step == 0: 125 | # evaluate with train data 126 | print('==== Evaluate train data', end=', ') 127 | time_start = time.time() 128 | train_acc_l, train_acc_c, train_mse_t = evaluate(model, data_train, train_id, params) 129 | train_eval_time = time.time() - time_start 130 | # evaluate with test data 131 | print('==== Evaluate test data', end=', ') 132 | time_start = time.time() 133 | test_acc_l, test_acc_c, test_mse_t = evaluate(model, data_test, test_id, params) 134 | test_time = time.time() - time_start 135 | print(f'[Epoch={epoch+1}/{params.epoch_num}], loss={loss_all:.2f}, loss_l={loss_l_all:.2f},', end=' ') 136 | print(f'loss_t={loss_t_all:.2f}, loss_c={loss_c_all:.2f}, loss_s={loss_s_all:.2f};') 137 | print(f'Acc_loc: train_l={train_acc_l}, test_l={test_acc_l};') 138 | print(f'Acc_cat: train_c={train_acc_c}, test_c={test_acc_c};') 139 | print(f'MAE_time: 
train_t={train_mse_t:.2f}, test_t={test_mse_t:.2f};') 140 | print(f'Eval time cost: train={train_eval_time:.1f}s, test={test_time:.1f}s\n') 141 | # store info 142 | if best_info_train['Recall@1'] < train_acc_l[0]: 143 | best_info_train['epoch'] = epoch 144 | best_info_train['Recall@1'] = train_acc_l[0] 145 | best_info_train['Recall@all'] = train_acc_l 146 | best_info_train['MAE'] = train_mse_t 147 | best_info_train['model_params'] = model.state_dict() 148 | if best_info_test['Recall@1'] < test_acc_l[0]: 149 | best_info_test['epoch'] = epoch 150 | best_info_test['Recall@1'] = test_acc_l[0] 151 | best_info_test['Recall@all'] = test_acc_l 152 | best_info_test['MAE'] = test_mse_t 153 | best_info_test['model_params'] = model.state_dict() 154 | 155 | else: 156 | print(f'[Epoch={epoch+1}/{params.epoch_num}], loss={loss_all:.2f}, loss_l={loss_l_all:.2f},', end=' ') 157 | print(f'loss_t={loss_t_all:.2f}, loss_c={loss_c_all:.2f}, loss_s={loss_s_all:.2f};') 158 | 159 | # evaluation 160 | print('='*10, ' Testing') 161 | model.load_state_dict(best_info_test['model_params']) 162 | results_l, results_c, results_t = evaluate(model, data_test, test_id, params) 163 | print(f'Test results: loc={results_l}, cat={results_c}, tim={results_t:.2f}') 164 | 165 | # best metrics info 166 | print('='*10,' Run finished') 167 | print(f'Best train results is {best_info_train["Recall@all"]} at Epoch={best_info_train["epoch"]}') 168 | print(f'Best test results is {best_info_test["Recall@all"]} at Epoch={best_info_test["epoch"]}') 169 | 170 | return results_l, results_c, results_t, best_info_test 171 | 172 | 173 | def evaluate(model, data, data_id, params): 174 | '''Evaluate model performance 175 | ''' 176 | l_acc_all = np.zeros(3) 177 | c_acc_all = np.zeros(3) 178 | t_mse_all = 0. 
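    # l_acc_all / c_acc_all hold cumulative hit counts for Recall@1/5/10 (filled by
    # calculate_recall in utils.py) and are divided by valid_num_all before returning.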
179 | valid_num_all = 0 180 | model.eval() 181 | # evaluate with batch 182 | for mask_batch, target_batch, data_batch in generate_batch_data(data, data_id, params.device, params.batch_size, params.cat_contained): 183 | # model forward 184 | th_pred, c_pred, l_pred, valid_num = model(data_batch, mask_batch) 185 | 186 | # calculate metrics 187 | l_acc = calculate_recall(target_batch[0], l_pred) 188 | l_acc_all += l_acc 189 | t_mse_all += torch.nn.functional.l1_loss(th_pred.squeeze(-1), target_batch[1].squeeze(-1), reduction='sum').item() 190 | valid_num_all += valid_num 191 | 192 | if params.cat_contained: 193 | c_acc = calculate_recall(target_batch[2], c_pred) 194 | c_acc_all += c_acc 195 | 196 | return l_acc_all / valid_num_all, c_acc_all / valid_num_all, t_mse_all / valid_num_all 197 | 198 | 199 | if __name__ == '__main__': 200 | 201 | print('='*20, ' Program Start') 202 | params = settings() 203 | params.device = torch.device(f"cuda:{params.gpu}") 204 | print('Parameter is\n', params.__dict__) 205 | 206 | # file name to store 207 | params.file_out = f'{params.path_out}/{time.strftime("%Y%m%d")}_{params.data_name}_{params.out_filename}/' 208 | if not os.path.exists(params.file_out): 209 | os.makedirs(params.file_out) 210 | 211 | # Load data 212 | print('='*20, ' Loading data') 213 | start_time = time.time() 214 | if params.cat_contained: 215 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}_cat.pkl', 'rb')) 216 | else: 217 | dataset = pickle.load(open(f'{params.path_in}{params.data_name}.pkl', 'rb')) 218 | print(f'Finished, time cost is {time.time()-start_time:.1f}') 219 | 220 | # metrics 221 | metrics = pd.DataFrame() 222 | best_info_all_run = {'epoch':0, 'Recall@1':0} 223 | 224 | # start running 225 | print('='*20, "Start Training") 226 | print(time.strftime("%Y%m%d-%H:%M:%S")) 227 | for i in range(params.run_num): 228 | print('='*20, f'Run {i}') 229 | 230 | # To Revise 231 | results, results_c, results_t, best_info_one_run = train(params, dataset) 232 | metric_dict = {'Rec-l@1': results[0], 'Rec-l@5': results[1], 'Rec-l@10': results[2], 233 | 'MAE': results_t, 'Rec-c@1': results_c[0], 'Rec-c@5': results_c[1], 'Rec-c@10': results_c[2]} 234 | metric_tmp = pd.DataFrame(metric_dict, index=[i]) 235 | metrics = pd.concat([metrics, metric_tmp]) 236 | 237 | if best_info_all_run['Recall@1'] < best_info_one_run['Recall@1']: 238 | best_info_all_run = best_info_one_run.copy() 239 | best_info_all_run['run'] = i 240 | 241 | 242 | 243 | print('='*20, "Finished") 244 | mean = pd.DataFrame(metrics.mean()).T 245 | mean.index = ['mean'] 246 | std = pd.DataFrame(metrics.std()).T 247 | std.index = ['std'] 248 | metrics = pd.concat([metrics, mean, std]) 249 | print(metrics) 250 | 251 | # save 252 | metrics.to_csv(f'{params.file_out}metrics.csv') 253 | print('='*20, f'\nMetrics saved.') 254 | torch.save(best_info_all_run["model_params"], f'{params.file_out}model.pkl') 255 | print(f'Model saved (Run={best_info_all_run["run"]}, Epoch={best_info_all_run["epoch"]})') 256 | print(time.strftime("%Y%m%d-%H:%M:%S")) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | from torch.nn.utils.rnn import pad_sequence 6 | import random 7 | 8 | def generate_batch_data(data_input, data_id, device, batch_size, cat_contained): 9 | '''generate batch data''' 10 | 11 | # generate (uid, sid) queue 12 | data_queue = list() 13 
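    # data_queue holds one (uid, sid, target_index) tuple per prediction target;
    # after shuffling, consecutive slices of batch_size tuples form the mini-batches.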
| uid_list = data_id.keys() 14 | for uid in uid_list: 15 | for sid in data_id[uid]: 16 | for tar_idx in range(len(data_input[uid][sid]['target_l'])): 17 | data_queue.append((uid, sid, tar_idx)) 18 | 19 | # generate batch data 20 | data_len = len(data_queue) 21 | batch_num = int(data_len/batch_size) 22 | random.shuffle(data_queue) 23 | print(f'Number of batch is {batch_num}') 24 | # iterate batch number times 25 | for i in range(batch_num): 26 | # batch data 27 | uid_batch = [] 28 | loc_cur_batch = [] 29 | tim_w_cur_batch = [] 30 | tim_h_cur_batch = [] 31 | loc_his_batch = [] 32 | tim_w_his_batch = [] 33 | tim_h_his_batch = [] 34 | target_l_batch = [] 35 | target_c_batch = [] 36 | target_th_batch = [] 37 | target_len_batch = [] 38 | history_len_batch = [] 39 | current_len_batch = [] 40 | if cat_contained: 41 | cat_cur_batch = [] 42 | cat_his_batch = [] 43 | 44 | if i % 100 == 0: 45 | print('====', f'[Batch={i}/{batch_num}]', end=', ') 46 | 47 | batch_data = data_queue[i * batch_size : (i+1) * batch_size] 48 | # iterate batch index 49 | for one_data in batch_data: 50 | uid, sid, tar_idx = one_data 51 | uid_batch.append([uid]) 52 | # current 53 | in_idx = tar_idx + 1 54 | loc_cur_batch.append(torch.LongTensor(data_input[uid][sid]['loc'][1][:in_idx])) 55 | tim_cur_ts = torch.LongTensor(data_input[uid][sid]['tim'][1][:in_idx]) 56 | tim_w_cur_batch.append(tim_cur_ts[:, 0]) 57 | tim_h_cur_batch.append(tim_cur_ts[:, 1]) 58 | current_len_batch.append(tim_cur_ts.shape[0]) 59 | # history 60 | loc_his_batch.append(torch.LongTensor(data_input[uid][sid]['loc'][0])) 61 | tim_his_ts = torch.LongTensor(data_input[uid][sid]['tim'][0]) 62 | tim_w_his_batch.append(tim_his_ts[:, 0]) 63 | tim_h_his_batch.append(tim_his_ts[:, 1]) 64 | history_len_batch.append(tim_his_ts.shape[0]) 65 | # target 66 | target_l = torch.LongTensor([data_input[uid][sid]['target_l'][tar_idx]]) 67 | target_l_batch.append(target_l) 68 | target_len_batch.append(target_l.shape[0]) 69 | target_th_batch.append(torch.LongTensor([data_input[uid][sid]['target_th'][tar_idx]])) 70 | # catrgory 71 | if cat_contained: 72 | cat_his_batch.append(torch.LongTensor(data_input[uid][sid]['cat'][0])) 73 | cat_cur_batch.append(torch.LongTensor(data_input[uid][sid]['cat'][1][:in_idx])) 74 | target_c_batch.append(torch.LongTensor([data_input[uid][sid]['target_c'][tar_idx]])) 75 | 76 | 77 | 78 | # padding 79 | uid_batch_tensor = torch.LongTensor(uid_batch).to(device) 80 | # current 81 | loc_cur_batch_pad = pad_sequence(loc_cur_batch, batch_first=True).to(device) 82 | tim_w_cur_batch_pad = pad_sequence(tim_w_cur_batch, batch_first=True).to(device) 83 | tim_h_cur_batch_pad = pad_sequence(tim_h_cur_batch, batch_first=True).to(device) 84 | # history 85 | loc_his_batch_pad = pad_sequence(loc_his_batch, batch_first=True).to(device) 86 | tim_w_his_batch_pad = pad_sequence(tim_w_his_batch, batch_first=True).to(device) 87 | tim_h_his_batch_pad = pad_sequence(tim_h_his_batch, batch_first=True).to(device) 88 | # target 89 | target_l_batch_pad = pad_sequence(target_l_batch, batch_first=True).to(device) 90 | target_th_batch_pad = pad_sequence(target_th_batch, batch_first=True).to(device) 91 | 92 | if cat_contained: 93 | cat_his_batch_pad = pad_sequence(cat_his_batch, batch_first=True).to(device) 94 | cat_cur_batch_pad = pad_sequence(cat_cur_batch, batch_first=True).to(device) 95 | target_c_batch_pad = pad_sequence(target_c_batch, batch_first=True).to(device) 96 | yield (target_len_batch, history_len_batch, current_len_batch),\ 97 | (target_l_batch_pad, target_th_batch_pad, 
target_c_batch_pad),\ 98 | (uid_batch_tensor,\ 99 | loc_his_batch_pad, loc_cur_batch_pad,\ 100 | tim_w_his_batch_pad, tim_w_cur_batch_pad,\ 101 | tim_h_his_batch_pad, tim_h_cur_batch_pad,\ 102 | cat_his_batch_pad, cat_cur_batch_pad) 103 | else: 104 | yield (target_len_batch, history_len_batch, current_len_batch),\ 105 | (target_l_batch_pad, target_th_batch_pad),\ 106 | (uid_batch_tensor,\ 107 | loc_his_batch_pad, loc_cur_batch_pad,\ 108 | tim_w_his_batch_pad, tim_w_cur_batch_pad,\ 109 | tim_h_his_batch_pad, tim_h_cur_batch_pad) 110 | 111 | 112 | 113 | 114 | print('Batch Finished') 115 | 116 | 117 | def generate_mask(data_len): 118 | '''Generate mask 119 | Args: 120 | data_len : one dimension list, reflect sequence length 121 | ''' 122 | mask = [] 123 | for i_len in data_len: 124 | mask.append(torch.ones(i_len).bool()) 125 | return ~pad_sequence(mask, batch_first=True) 126 | 127 | 128 | 129 | def calculate_recall(target_pad, pred_pad): 130 | '''Calculate recall 131 | Args: 132 | target: (batch, max_seq_len), padded target 133 | pred: (batch, max_seq_len, pred_scores), padded 134 | ''' 135 | # variable 136 | acc = np.zeros(3) # 1, 5, 10 137 | 138 | # reshape and to numpy 139 | target_list = target_pad.data.reshape(-1).cpu().numpy() 140 | # topK 141 | pid_size = pred_pad.shape[-1] 142 | _, pred_list = pred_pad.data.reshape(-1, pid_size).topk(20) 143 | pred_list = pred_list.cpu().numpy() 144 | 145 | for idx, pred in enumerate(pred_list): 146 | target = target_list[idx] 147 | if target == 0: # pad 148 | continue 149 | if target in pred[:1]: 150 | acc += 1 151 | elif target in pred[:5]: 152 | acc[1:] += 1 153 | elif target in pred[:10]: 154 | acc[2:] += 1 155 | 156 | return acc 157 | 158 | 159 | def get_model_params(model): 160 | total_num = sum(param.numel() for param in model.parameters()) 161 | trainable_num = sum(param.numel() for param in model.parameters() if param.requires_grad) 162 | print(f'==== Parameter numbers:\n total={total_num}, trainable={trainable_num}') 163 | 164 | 165 | 166 | class progress_supervisor(object): 167 | def __init__(self, all_num, path): 168 | self.cur_num = 1 169 | self.start_time = time.time() 170 | self.all_num = all_num 171 | self.path = path 172 | 173 | with open(self.path, 'w') as f: 174 | f.write('Start') 175 | 176 | def update(self): 177 | '''Usage: 178 | count_time = count_run_time(5 * 4 * 4) 179 | count_time.path = f'{args.out_dir}{args.model_name}_{args.data_name}.txt' 180 | main() 181 | count_time.current_count() 182 | ''' 183 | 184 | past_time = time.time()-self.start_time 185 | avg_time = past_time / self.cur_num 186 | fut_time = avg_time * (self.all_num - self.cur_num) 187 | 188 | content = '=' * 10 + ' Progress observation' 189 | content += f'Current time is {time.strftime("%Y-%m-%d %H:%M:%S")}\n' 190 | content += f'Current Num: {self.cur_num} / {self.all_num}\n' 191 | content += f'Past time: {past_time:.2f}s ({past_time/3600:.2f}h)\n' 192 | content += f'Average time: {avg_time:.2f}s ({avg_time/3600:.2f}h)\n' 193 | content += f'Future time: {fut_time:.2f}s ({fut_time/3600:.2f}h)\n' 194 | 195 | with open(self.path, 'w') as f: 196 | f.write(content) 197 | 198 | self.cur_num += 1 199 | return content 200 | 201 | def delete(self): 202 | if os.path.exists(self.path): 203 | os.remove(self.path) 204 | 205 | if not os.path.exists(self.path): 206 | print('Supervisor file delete success') --------------------------------------------------------------------------------
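A minimal sketch of inspecting the prepared dataset written by ```data_prepare.py``` (the key names follow ```save_variables``` above; the path assumes the default ```./data/``` output and the provided ```NYC_cat.pkl```):

```python
import pickle

# Load the category-aware NYC dataset produced by data_prepare.py
with open('./data/NYC_cat.pkl', 'rb') as f:
    dataset = pickle.load(f)

# Top-level keys: train_data, train_id, test_data, test_id, pid_dict, uid_list,
# pid_lat_lon, pid_lat_lon_radians, parameters (+ cid_* keys when categories are kept)
print(sorted(dataset.keys()))
print(f"users={len(dataset['uid_list'])}, locations={len(dataset['pid_dict'])}")

# One training sample: 'loc'/'tim'/'cat' hold [history, current] sequences,
# 'target_l'/'target_th'/'target_c' hold the next-step labels
uid = dataset['uid_list'][0]
sid = dataset['train_id'][uid][0]
print(dataset['train_data'][uid][sid].keys())
```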