├── LICENSE ├── README.md ├── data │   ├── yelp-split.pkl │   ├── yelp.pkl │   └── yelp.txt ├── data_utils.py ├── main.py ├── model.py ├── options.py ├── preprocess_data.py ├── reco_gan_rl.ipynb ├── save_dir │   ├── best-loss │   ├── best-pre1 │   └── best-pre2 └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 rushhan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generative-Adversarial-User-Model-for-Reinforcement-Learning-Based-Recommendation-System-Pytorch 2 | PyTorch implementation of a recommendation system based on a Generative Adversarial User Model for reinforcement learning 3 | Implementation of the paper of the same title: [paper](http://proceedings.mlr.press/v97/chen19f/chen19f.pdf) 4 | 5 | ## This Repo Includes: 6 | 1. Necessary Data (Yelp Reviews) 7 | 2. Data Preprocessing 8 | 3. Position Weight (PW) Model 9 | 4. Hyperparameter-Tuned Model 10 | 11 | 12 | ## To Train: 13 | Simply run the reco_gan_rl Jupyter notebook. It includes both preprocessing of the data and training. 14 | 15 | ## To Do: 16 | 1. 
Add LSTM Model 17 | 18 | ## Note: 19 | The data processing is mainly based on the original implementation so that the results can be compared. [Source](https://github.com/xinshi-chen/GenerativeAdversarialUserModel) 20 | -------------------------------------------------------------------------------- /data/yelp-split.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rushhan/Generative-Adversarial-User-Model-for-Reinforcement-Learning-Based-Recommendation-System-Pytorch/d7e34edd8013f9ef468d48c3d7765d17e8ed09dc/data/yelp-split.pkl -------------------------------------------------------------------------------- /data/yelp.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rushhan/Generative-Adversarial-User-Model-for-Reinforcement-Learning-Based-Recommendation-System-Pytorch/d7e34edd8013f9ef468d48c3d7765d17e8ed09dc/data/yelp.pkl -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from past.builtins import xrange 4 | import pickle 5 | import numpy as np 6 | import os 7 | 8 | # closely follows the original implementation 9 | 10 | class Dataset(object): 11 | """docstring for Dataset""" 12 | def __init__(self, args): 13 | super(Dataset, self).__init__() 14 | 15 | self.data_folder = args.data_folder 16 | self.dataset = args.dataset 17 | self.model_type = args.user_model 18 | self.band_size = args.pw_band_size 19 | # load the data 20 | data_filename = os.path.join(args.data_folder, args.dataset+'.pkl') 21 | f = open(data_filename, 'rb') 22 | data_behavior = pickle.load(f) # time and user behavior 23 | item_feature = pickle.load(f) # identity matrix 24 | f.close() 25 | 26 | self.size_item = len(item_feature) 27 | self.size_user = len(data_behavior) 28 | self.f_dim = len(item_feature[0]) 29 | 30 | # load the indices for the train/validation/test split 31 | 32 | filename = os.path.join(self.data_folder, self.dataset+'-split.pkl') 33 | pkl_file = open(filename, 'rb') 34 | self.train_user = pickle.load(pkl_file) 35 | self.vali_user = pickle.load(pkl_file) 36 | self.test_user = pickle.load(pkl_file) 37 | pkl_file.close() 38 | 39 | 40 | # process the data 41 | 42 | # get the largest number of items displayed to an individual at a time 43 | k_max = 0 44 | for d_b in data_behavior: 45 | for disp in d_b[1]: 46 | k_max = max(k_max, len(disp)) 47 | 48 | self.data_click = [[] for x in xrange(self.size_user)] 49 | self.data_disp = [[] for x in xrange(self.size_user)] 50 | self.data_time = np.zeros(self.size_user, dtype=np.int64) 51 | self.data_news_cnt = np.zeros(self.size_user, dtype=np.int64) 52 | self.feature = [[] for x in xrange(self.size_user)] 53 | self.feature_click = [[] for x in xrange(self.size_user)] 54 | 55 | for user in xrange(self.size_user): 56 | # (1) count number of clicks 57 | click_t = 0 58 | num_events = len(data_behavior[user][1]) 59 | click_t += num_events 60 | self.data_time[user] = click_t 61 | # (2) 62 | news_dict = {} 63 | self.feature_click[user] = np.zeros([click_t, self.f_dim]) 64 | click_t = 0 65 | for event in xrange(num_events): 66 | disp_list = data_behavior[user][1][event] 67 | pick_id = data_behavior[user][2][event] 68 | for id in disp_list: 69 | if id not in news_dict: 70 | news_dict[id] = len(news_dict) # for each user, news ids start from 0 71 | id = pick_id 72 | self.data_click[user].append([click_t, news_dict[id]]) 73 | 
self.feature_click[user][click_t] = item_feature[id] 74 | for idd in disp_list: 75 | self.data_disp[user].append([click_t, news_dict[idd]]) 76 | click_t += 1 # splitter a event with 2 clickings to 2 events 77 | 78 | self.data_news_cnt[user] = len(news_dict) 79 | 80 | self.feature[user] = np.zeros([self.data_news_cnt[user], self.f_dim]) 81 | 82 | for id in news_dict: 83 | self.feature[user][news_dict[id]] = item_feature[id] 84 | self.feature[user] = self.feature[user].tolist() 85 | self.feature_click[user] = self.feature_click[user].tolist() 86 | self.max_disp_size = k_max 87 | 88 | def random_split_user(self): 89 | # dont think this one is really necessary if the initial split is random enough 90 | num_users = len(self.train_user) + len(self.vali_user) + len(self.test_user) 91 | shuffle_order = np.arange(num_users) 92 | np.random.shuffle(shuffle_order) 93 | self.train_user = shuffle_order[0:len(self.train_user)].tolist() 94 | self.vali_user = shuffle_order[len(self.train_user):len(self.train_user)+len(self.vali_user)].tolist() 95 | self.test_user = shuffle_order[len(self.train_user)+len(self.vali_user):].tolist() 96 | 97 | def data_process_for_placeholder(self, user_set): 98 | #print ("user_set",user_set) 99 | if self.model_type == 'PW': 100 | sec_cnt_x = 0 101 | news_cnt_short_x = 0 102 | news_cnt_x = 0 103 | click_2d_x = [] 104 | disp_2d_x = [] 105 | 106 | tril_indice = [] 107 | tril_value_indice = [] 108 | 109 | disp_2d_split_sec = [] 110 | feature_clicked_x = [] 111 | 112 | disp_current_feature_x = [] 113 | click_sub_index_2d = [] 114 | 115 | # started with the validation set 116 | #print (user_set) 117 | #[703, 713, 723, 733, 743, 753, 763, 773, 783, 793, 803, 813, 823, 833, 843, 853, 863, 873, 883, 893, 903, 913, 923, 933, 943, 953, 963, 973, 983, 993, 1003, 1013, 1023, 1033, 1043, 1053] 118 | #user_set = [703] 119 | for u in user_set: 120 | t_indice = [] 121 | #print ("the us is ",u) 703 122 | #print (self.band_size,self.data_time[u]) 20,1 123 | 124 | #print ("the loop",self.data_time[u]-1) 125 | 126 | for kk in xrange(min(self.band_size-1, self.data_time[u]-1)): 127 | t_indice += map(lambda x: [x + kk+1 + sec_cnt_x, x + sec_cnt_x], np.arange(self.data_time[u] - (kk+1))) 128 | # print (t_indice) [] for 703 129 | 130 | tril_indice += t_indice 131 | tril_value_indice += map(lambda x: (x[0] - x[1] - 1), t_indice) 132 | #print ("THE Click data is ",self.data_click[u]) #THE Click data is [[0, 0], [1, 8], [2, 14]] for u =15 133 | click_2d_tmp = map(lambda x: [x[0] + sec_cnt_x, x[1]], self.data_click[u]) 134 | click_2d_tmp = list(click_2d_tmp) 135 | #print (list(click_2d_tmp)) 136 | #print (list(click_2d_tmp)) 137 | click_2d_x += click_2d_tmp 138 | #print ("tenp is ",click_2d_x,list(click_2d_tmp)) # [[0, 0], [1, 8], [2, 14]] for u15 139 | #print ("dispaly data is ", self.data_disp[u]) [0,0] 140 | 141 | 142 | disp_2d_tmp = map(lambda x: [x[0] + sec_cnt_x, x[1]], self.data_disp[u]) 143 | disp_2d_tmp = list(disp_2d_tmp) 144 | #y=[] 145 | #y+=disp_2d_tmp 146 | 147 | 148 | 149 | 150 | #print (disp_2d_tmp, click_2d_tmp) 151 | click_sub_index_tmp = map(lambda x: disp_2d_tmp.index(x), (click_2d_tmp)) 152 | click_sub_index_tmp = list(click_sub_index_tmp) 153 | #print ("the mess is ",click_sub_index_tmp) 154 | click_sub_index_2d += map(lambda x: x+len(disp_2d_x), click_sub_index_tmp) 155 | #print ("click_sub_index_2d",click_sub_index_2d) 156 | disp_2d_x += disp_2d_tmp 157 | #print ("disp_2d_x",disp_2d_x) # [[0, 0]] 158 | #sys.exit() 159 | disp_2d_split_sec += map(lambda x: x[0] + sec_cnt_x, 
self.data_disp[u]) 160 | 161 | sec_cnt_x += self.data_time[u] 162 | news_cnt_short_x = max(news_cnt_short_x, self.data_news_cnt[u]) 163 | news_cnt_x += self.data_news_cnt[u] 164 | disp_current_feature_x += map(lambda x: self.feature[u][x], [idd[1] for idd in self.data_disp[u]]) 165 | feature_clicked_x += self.feature_click[u] 166 | 167 | out1 ={} 168 | out1['click_2d_x']=click_2d_x 169 | out1['disp_2d_x']=disp_2d_x 170 | out1['disp_current_feature_x']=disp_current_feature_x 171 | out1['sec_cnt_x']=sec_cnt_x 172 | out1['tril_indice']=tril_indice 173 | out1['tril_value_indice']=tril_value_indice 174 | out1['disp_2d_split_sec']=disp_2d_split_sec 175 | out1['news_cnt_short_x']=news_cnt_short_x 176 | out1['click_sub_index_2d']=click_sub_index_2d 177 | out1['feature_clicked_x']=feature_clicked_x 178 | # print ("out",out1['tril_value_indice']) 179 | # sys.exit() 180 | return out1 181 | 182 | else: 183 | news_cnt_short_x = 0 184 | u_t_dispid = [] 185 | u_t_dispid_split_ut = [] 186 | u_t_dispid_feature = [] 187 | 188 | u_t_clickid = [] 189 | 190 | size_user = len(user_set) 191 | max_time = 0 192 | 193 | click_sub_index = [] 194 | 195 | for u in user_set: 196 | max_time = max(max_time, self.data_time[u]) 197 | 198 | user_time_dense = np.zeros([size_user, max_time], dtype=np.float32) 199 | click_feature = np.zeros([max_time, size_user, self.f_dim]) 200 | 201 | for u_idx in xrange(size_user): 202 | u = user_set[u_idx] 203 | 204 | u_t_clickid_tmp = [] 205 | u_t_dispid_tmp = [] 206 | 207 | for x in self.data_click[u]: 208 | t, click_id = x 209 | click_feature[t][u_idx] = self.feature[u][click_id] 210 | u_t_clickid_tmp.append([u_idx, t, click_id]) 211 | user_time_dense[u_idx, t] = 1.0 212 | 213 | u_t_clickid = u_t_clickid + u_t_clickid_tmp 214 | 215 | for x in self.data_disp[u]: 216 | t, disp_id = x 217 | u_t_dispid_tmp.append([u_idx, t, disp_id]) 218 | u_t_dispid_split_ut.append([u_idx, t]) 219 | u_t_dispid_feature.append(self.feature[u][disp_id]) 220 | 221 | click_sub_index_tmp = map(lambda x: u_t_dispid_tmp.index(x), u_t_clickid_tmp) 222 | click_sub_index += map(lambda x: x+len(u_t_dispid), click_sub_index_tmp) 223 | 224 | u_t_dispid = u_t_dispid + u_t_dispid_tmp 225 | news_cnt_short_x = max(news_cnt_short_x, self.data_news_cnt[u]) 226 | 227 | if self.model_type != 'LSTM': 228 | print('model type not supported. 
using LSTM') 229 | 230 | out = {} 231 | 232 | out['size_user']=size_user 233 | out['max_time']=max_time 234 | out['news_cnt_short_x']=news_cnt_short_x 235 | out['u_t_dispid']=u_t_dispid 236 | out['u_t_dispid_split_ut']=u_t_dispid_split_ut 237 | out['u_t_dispid_feature']=np.array(u_t_dispid_feature) 238 | out['click_feature']=click_feature 239 | out['click_sub_index']=click_sub_index 240 | out['u_t_clickid']=u_t_clickid 241 | out['user_time_dense']=user_time_dense 242 | return out 243 | 244 | 245 | def data_process_for_placeholder_L2(self, user_set): 246 | news_cnt_short_x = 0 247 | u_t_dispid = [] 248 | u_t_dispid_split_ut = [] 249 | u_t_dispid_feature = [] 250 | 251 | u_t_clickid = [] 252 | 253 | size_user = len(user_set) 254 | max_time = 0 255 | 256 | click_sub_index = [] 257 | 258 | for u in user_set: 259 | max_time = max(max_time, self.data_time[u]) 260 | 261 | user_time_dense = np.zeros([size_user, max_time], dtype=np.float32) 262 | click_feature = np.zeros([max_time, size_user, self.f_dim]) 263 | 264 | for u_idx in xrange(size_user): 265 | u = user_set[u_idx] 266 | 267 | item_cnt = [{} for _ in xrange(self.data_time[u])] 268 | 269 | u_t_clickid_tmp = [] 270 | u_t_dispid_tmp = [] 271 | for x in self.data_disp[u]: 272 | t, disp_id = x 273 | u_t_dispid_split_ut.append([u_idx, t]) 274 | u_t_dispid_feature.append(self.feature[u][disp_id]) 275 | if disp_id not in item_cnt[t]: 276 | item_cnt[t][disp_id] = len(item_cnt[t]) 277 | u_t_dispid_tmp.append([u_idx, t, item_cnt[t][disp_id]]) 278 | 279 | for x in self.data_click[u]: 280 | t, click_id = x 281 | click_feature[t][u_idx] = self.feature[u][click_id] 282 | u_t_clickid_tmp.append([u_idx, t, item_cnt[t][click_id]]) 283 | user_time_dense[u_idx, t] = 1.0 284 | 285 | u_t_clickid = u_t_clickid + u_t_clickid_tmp 286 | 287 | click_sub_index_tmp = map(lambda x: u_t_dispid_tmp.index(x), u_t_clickid_tmp) 288 | click_sub_index += map(lambda x: x+len(u_t_dispid), click_sub_index_tmp) 289 | 290 | u_t_dispid = u_t_dispid + u_t_dispid_tmp 291 | # news_cnt_short_x = max(news_cnt_short_x, data_news_cnt[u]) 292 | news_cnt_short_x = self.max_disp_size 293 | 294 | out = {} 295 | 296 | out['size_user']=size_user 297 | out['max_time']=max_time 298 | out['news_cnt_short_x']=news_cnt_short_x 299 | out['u_t_dispid']=u_t_dispid 300 | out['u_t_dispid_split_ut']=u_t_dispid_split_ut 301 | out['u_t_dispid_feature']=np.array(u_t_dispid_feature) 302 | out['click_feature']=click_feature 303 | out['click_sub_index']=click_sub_index 304 | out['u_t_clickid']=u_t_clickid 305 | out['user_time_dense']=user_time_dense 306 | return out 307 | 308 | def prepare_validation_data_L2(self, num_sets, v_user): 309 | vali_thread_u = [[] for _ in xrange(num_sets)] 310 | size_user_v = [[] for _ in xrange(num_sets)] 311 | max_time_v = [[] for _ in xrange(num_sets)] 312 | news_cnt_short_v = [[] for _ in xrange(num_sets)] 313 | u_t_dispid_v = [[] for _ in xrange(num_sets)] 314 | u_t_dispid_split_ut_v = [[] for _ in xrange(num_sets)] 315 | u_t_dispid_feature_v = [[] for _ in xrange(num_sets)] 316 | click_feature_v = [[] for _ in xrange(num_sets)] 317 | click_sub_index_v = [[] for _ in xrange(num_sets)] 318 | u_t_clickid_v = [[] for _ in xrange(num_sets)] 319 | ut_dense_v = [[] for _ in xrange(num_sets)] 320 | for ii in xrange(len(v_user)): 321 | vali_thread_u[ii % num_sets].append(v_user[ii]) 322 | for ii in xrange(num_sets): 323 | out=self.data_process_for_placeholder_L2(vali_thread_u[ii]) 324 | size_user_v[ii], max_time_v[ii], news_cnt_short_v[ii], u_t_dispid_v[ii],\ 325 | 
u_t_dispid_split_ut_v[ii], u_t_dispid_feature_v[ii], click_feature_v[ii], \ 326 | click_sub_index_v[ii], u_t_clickid_v[ii], ut_dense_v[ii] = out['size_user'],\ 327 | out['max_time'],\ 328 | out['news_cnt_short_x'],\ 329 | out['u_t_dispid'], \ 330 | out['u_t_dispid_split_ut'],\ 331 | out['u_t_dispid_feature'],\ 332 | out['click_feature'],\ 333 | out['click_sub_index'],\ 334 | out['u_t_clickid'],\ 335 | out['user_time_dense'] 336 | 337 | out2={} 338 | out2['vali_thread_u']=vali_thread_u 339 | out2['size_user_v']=size_user_v 340 | out2['max_time_v']=max_time_v 341 | out2['news_cnt_short_v'] =news_cnt_short_v 342 | out2['u_t_dispid_v'] =u_t_dispid_v 343 | out2['u_t_dispid_split_ut_v']=u_t_dispid_split_ut_v 344 | out2['u_t_dispid_feature_v']=u_t_dispid_feature_v 345 | out2['click_feature_v']=click_feature_v 346 | out2['click_sub_index_v']=click_sub_index_v 347 | out2['u_t_clickid_v']=u_t_clickid_v 348 | out2['ut_dense_v']=ut_dense_v 349 | 350 | return out2 351 | 352 | 353 | def prepare_validation_data(self, num_sets, v_user): 354 | 355 | if self.model_type == 'PW': 356 | vali_thread_u = [[] for _ in xrange(num_sets)] 357 | click_2d_v = [[] for _ in xrange(num_sets)] 358 | disp_2d_v = [[] for _ in xrange(num_sets)] 359 | feature_v = [[] for _ in xrange(num_sets)] 360 | sec_cnt_v = [[] for _ in xrange(num_sets)] 361 | tril_ind_v = [[] for _ in xrange(num_sets)] 362 | tril_value_ind_v = [[] for _ in xrange(num_sets)] 363 | disp_2d_split_sec_v = [[] for _ in xrange(num_sets)] 364 | feature_clicked_v = [[] for _ in xrange(num_sets)] 365 | news_cnt_short_v = [[] for _ in xrange(num_sets)] 366 | click_sub_index_2d_v = [[] for _ in xrange(num_sets)] 367 | for ii in xrange(len(v_user)): 368 | vali_thread_u[ii % num_sets].append(v_user[ii]) 369 | for ii in xrange(num_sets): 370 | out=self.data_process_for_placeholder(vali_thread_u[ii]) 371 | # print ("out_val",out['tril_indice']) 372 | # sys.exit() 373 | 374 | click_2d_v[ii], disp_2d_v[ii], feature_v[ii], sec_cnt_v[ii], tril_ind_v[ii], tril_value_ind_v[ii], \ 375 | disp_2d_split_sec_v[ii], news_cnt_short_v[ii], click_sub_index_2d_v[ii], feature_clicked_v[ii] = out['click_2d_x'], \ 376 | out['disp_2d_x'], \ 377 | out['disp_current_feature_x'], \ 378 | out['sec_cnt_x'], \ 379 | out['tril_indice'], \ 380 | out['tril_value_indice'], \ 381 | out['disp_2d_split_sec'], \ 382 | out['news_cnt_short_x'], \ 383 | out['click_sub_index_2d'], \ 384 | out['feature_clicked_x'] 385 | 386 | out2={} 387 | out2['vali_thread_u']=vali_thread_u 388 | out2['click_2d_v']=click_2d_v 389 | out2['disp_2d_v']=disp_2d_v 390 | out2['feature_v']=feature_v 391 | out2['sec_cnt_v']=sec_cnt_v 392 | out2['tril_ind_v']=tril_ind_v 393 | out2['tril_value_ind_v']=tril_value_ind_v 394 | out2['disp_2d_split_sec_v']=disp_2d_split_sec_v 395 | out2['news_cnt_short_v']=news_cnt_short_v 396 | out2['click_sub_index_2d_v']=click_sub_index_2d_v 397 | out2['feature_clicked_v']=feature_clicked_v 398 | return out2 399 | 400 | else: 401 | if self.model_type != 'LSTM': 402 | print('model type not supported. 
using LSTM') 403 | vali_thread_u = [[] for _ in xrange(num_sets)] 404 | size_user_v = [[] for _ in xrange(num_sets)] 405 | max_time_v = [[] for _ in xrange(num_sets)] 406 | news_cnt_short_v = [[] for _ in xrange(num_sets)] 407 | u_t_dispid_v = [[] for _ in xrange(num_sets)] 408 | u_t_dispid_split_ut_v = [[] for _ in xrange(num_sets)] 409 | u_t_dispid_feature_v = [[] for _ in xrange(num_sets)] 410 | click_feature_v = [[] for _ in xrange(num_sets)] 411 | click_sub_index_v = [[] for _ in xrange(num_sets)] 412 | u_t_clickid_v = [[] for _ in xrange(num_sets)] 413 | ut_dense_v = [[] for _ in xrange(num_sets)] 414 | for ii in xrange(len(v_user)): 415 | vali_thread_u[ii % num_sets].append(v_user[ii]) 416 | for ii in xrange(num_sets): 417 | out = self.data_process_for_placeholder(vali_thread_u[ii]) 418 | size_user_v[ii], max_time_v[ii], news_cnt_short_v[ii], u_t_dispid_v[ii],\ 419 | u_t_dispid_split_ut_v[ii], u_t_dispid_feature_v[ii], click_feature_v[ii], \ 420 | click_sub_index_v[ii], u_t_clickid_v[ii], ut_dense_v[ii] = out['click_2d_x'], \ 421 | out['disp_2d_x'], \ 422 | out['disp_current_feature_x'], \ 423 | out['sec_cnt_x'], \ 424 | out['tril_indice'], \ 425 | out['tril_value_indice'], \ 426 | out['disp_2d_split_sec'], \ 427 | out['news_cnt_short_x'], \ 428 | out['click_sub_index_2d'], \ 429 | out['feature_clicked_x'] 430 | 431 | 432 | 433 | out2 = {} 434 | 435 | 436 | out2['vali_thread_u']=vali_thread_u 437 | out2['size_user_v']=size_user_v 438 | out2['max_time_v']=max_time_v 439 | out2['news_cnt_short_v']=news_cnt_short_v 440 | out2['u_t_dispid_v']=u_t_dispid_v 441 | out2['u_t_dispid_split_ut_v']=u_t_dispid_split_ut_v 442 | out2['u_t_dispid_feature_v']=u_t_dispid_feature_v 443 | out2['click_feature_v']=click_feature_v 444 | out2['click_sub_index_v']=click_sub_index_v 445 | out2['u_t_clickid_v']=u_t_clickid_v 446 | out2['ut_dense_v']=ut_dense_v 447 | return out2 448 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # define the main training and inference here 2 | import sys 3 | sys.path 4 | sys.path.append('../../') 5 | 6 | import torch 7 | from torch import nn 8 | #import Dataloader 9 | import tqdm 10 | from options import get_options 11 | from torch.utils.data import DataLoader 12 | import tqdm 13 | import os 14 | import json 15 | import pprint as pp 16 | from model import UserModelPW 17 | from data_utils import * 18 | import torch.optim as optim 19 | 20 | 21 | import datetime 22 | import numpy as np 23 | import os 24 | import threading 25 | 26 | #train,test,val=get_split(ops.Datadir) 27 | 28 | #train_input = get_data(train) 29 | #val_input = get_data(val) 30 | #test_input = get_data(test) 31 | 32 | 33 | 34 | ''' 35 | @for i in range(pot.iter): 36 | train_model 37 | 38 | 39 | if iter//some_fixed_no==0: 40 | validate 41 | 42 | test_model 43 | plot_figs 44 | ''' 45 | 46 | def multithread_compute_vali(opts,valid_data,model): 47 | global vali_sum, vali_cnt 48 | 49 | vali_sum = [0.0, 0.0, 0.0] 50 | vali_cnt = 0 51 | threads = [] 52 | for ii in xrange(opts.num_thread): 53 | #print ("got here") 54 | #print (dataset.model_type) 55 | #print (" [dataset.vali_user[ii]]", [dataset.vali_user[ii]]) 56 | #valid_data = dataset.prepare_validation_data(1, [dataset.vali_user[15]]) # is a dict 57 | 58 | # print ("valid_data",valid_data) 59 | #sys.exit() 60 | 61 | thread = threading.Thread(target=vali_eval, args=(1, ii,opts,valid_data,model)) 62 | thread.start() 63 | 
threads.append(thread) 64 | 65 | for thread in threads: 66 | thread.join() 67 | 68 | 69 | return vali_sum[0]/vali_cnt, vali_sum[1]/vali_cnt, vali_sum[2]/vali_cnt 70 | 71 | lock = threading.Lock() 72 | 73 | 74 | def vali_eval(xx, ii,opts,valid_data,model): 75 | global vali_sum, vali_cnt 76 | #print ("dataset.vali_user",dataset.vali_user) 77 | 78 | #valid_data = dataset.prepare_validation_data(1, [dataset.vali_user[ii]]) # is a dict 79 | 80 | #print ("valid_data",valid_data) 81 | #sys.exit() 82 | with torch.no_grad(): 83 | _,_,_, loss_sum, precision_1_sum, precision_2_sum, event_cnt = model(valid_data,index=ii) 84 | 85 | lock.acquire() 86 | vali_sum[0] += loss_sum 87 | vali_sum[1] += precision_1_sum 88 | vali_sum[2] +=precision_2_sum 89 | vali_cnt += event_cnt 90 | lock.release() 91 | 92 | 93 | lock = threading.Lock() 94 | 95 | 96 | 97 | def multithread_compute_test(opts,test_data,model): 98 | global test_sum, test_cnt 99 | 100 | num_sets = 1 * opts.num_thread 101 | 102 | thread_dist = [[] for _ in xrange(opts.num_thread)] 103 | for ii in xrange(num_sets): 104 | thread_dist[ii % opts.num_thread].append(ii) 105 | 106 | test_sum = [0.0, 0.0, 0.0] 107 | test_cnt = 0 108 | threads = [] 109 | for ii in xrange(opts.num_thread): 110 | thread = threading.Thread(target=test_eval, args=(1, thread_dist[ii],opts,test_data,model)) 111 | thread.start() 112 | threads.append(thread) 113 | 114 | for thread in threads: 115 | thread.join() 116 | 117 | return test_sum[0]/test_cnt, test_sum[1]/test_cnt, test_sum[2]/test_cnt 118 | 119 | 120 | def test_eval(xx, thread_dist,opts,test_data,model): 121 | global test_sum, test_cnt 122 | test_thread_eval = [0.0, 0.0, 0.0] 123 | test_thread_cnt = 0 124 | for ii in thread_dist: 125 | 126 | with torch.no_grad(): 127 | _,_,_, loss_sum, precision_1_sum, precision_2_sum, event_cnt = model(test_data,index=ii) 128 | 129 | test_thread_eval[0] += loss_sum 130 | test_thread_eval[1] +=precision_1_sum 131 | test_thread_eval[2] += precision_2_sum 132 | test_thread_cnt += event_cnt 133 | 134 | lock.acquire() 135 | test_sum[0] += test_thread_eval[0] 136 | test_sum[1] += test_thread_eval[1] 137 | test_sum[2] += test_thread_eval[2] 138 | test_cnt += test_thread_cnt 139 | lock.release() 140 | 141 | 142 | 143 | def init_weights(m): 144 | sd = 1e-3 145 | if type(m) == nn.Linear: 146 | torch.nn.init.normal_(m.weight) 147 | m.weight.data.clamp_(-sd,sd) # to mimic the normal clmaped weight initilization 148 | 149 | 150 | def main(opts): 151 | pp.pprint(vars(opts)) 152 | 153 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 154 | print("%s, start" % log_time) 155 | 156 | dataset = Dataset(opts) 157 | 158 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 159 | print("%s, load data completed" % log_time) 160 | 161 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 162 | print("%s, start prepare vali data" % log_time) 163 | 164 | 165 | valid_data=dataset.prepare_validation_data(opts.num_thread, dataset.vali_user) 166 | 167 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 168 | print("%s, prepare validation data, completed" % log_time) 169 | 170 | 171 | model = UserModelPW(dataset.f_dim, opts) 172 | model.apply(init_weights) 173 | 174 | #optimizer = optim.Adam( 175 | # [{'params': model.parameters(), 'lr': opts.learning_rate}]) 176 | 177 | optimizer = optim.Adam(model.parameters(), lr=opts.learning_rate, betas=(0.5, 0.999)) 178 | 179 | 180 | best_metric = [100000.0, 0.0, 0.0] 181 | 182 | vali_path = opts.save_dir+'/' 183 | if not 
os.path.exists(vali_path): 184 | os.makedirs(vali_path) 185 | 186 | 187 | #training_dataloader = DataLoader(training_dataset, batch_size=opts.batch_size, num_workers=1) # need to change the dataloader 188 | 189 | for i in xrange(opts.num_itrs): 190 | 191 | #model.train() 192 | for p in model.parameters(): 193 | p.requires_grad = True 194 | model.zero_grad() 195 | 196 | training_user_nos = np.random.choice(dataset.train_user, opts.batch_size, replace=False) 197 | 198 | training_user= dataset.data_process_for_placeholder(training_user_nos) 199 | for p in model.parameters(): 200 | p.data.clamp_(-1e0, 1e0) 201 | 202 | #for batch_id, batch in enumerate(tqdm(training_dataloader)): # the original code does not iterate over entire batch , so change this one 203 | 204 | loss,_,_,_,_,_,_= model(training_user,is_train=True) 205 | #print ("the loss is",loss) 206 | 207 | loss.backward() 208 | optimizer.step() 209 | 210 | if np.mod(i, 10) == 0: 211 | if i == 0: 212 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 213 | print("%s, start first iteration validation" % log_time) 214 | 215 | if np.mod(i, 10) == 0: 216 | if i == 0: 217 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 218 | print("%s, start first iteration validation" % log_time) 219 | vali_loss_prc = multithread_compute_vali(opts,valid_data,model) 220 | if i == 0: 221 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 222 | print("%s, first iteration validation complete" % log_time) 223 | 224 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 225 | print("%s: itr%d, vali: %.5f, %.5f, %.5f" % 226 | (log_time, i, vali_loss_prc[0], vali_loss_prc[1], vali_loss_prc[2])) 227 | 228 | if vali_loss_prc[0] < best_metric[0]: 229 | best_metric[0] = vali_loss_prc[0] 230 | best_save_path = os.path.join(vali_path, 'best-loss') 231 | torch.save(model.state_dict(), best_save_path) 232 | #best_save_path = saver.save(sess, best_save_path) 233 | if vali_loss_prc[1] > best_metric[1]: 234 | best_metric[1] = vali_loss_prc[1] 235 | best_save_path = os.path.join(vali_path, 'best-pre1') 236 | torch.save(model.state_dict(), best_save_path) 237 | if vali_loss_prc[2] > best_metric[2]: 238 | best_metric[2] = vali_loss_prc[2] 239 | best_save_path = os.path.join(vali_path, 'best-pre2') 240 | torch.save(model.state_dict(), best_save_path) 241 | 242 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 243 | print("%s, iteration %d train complete" % (log_time, i)) 244 | 245 | # test 246 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 247 | print("%s, start prepare test data" % log_time) 248 | 249 | test_data = dataset.prepare_validation_data(opts.num_thread, dataset.test_user) 250 | 251 | log_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') 252 | print("%s, prepare test data end" % log_time) 253 | 254 | 255 | best_save_path = os.path.join(vali_path, 'best-loss') 256 | model.load_state_dict(torch.load(best_save_path)) 257 | #saver.restore(sess, best_save_path) 258 | test_loss_prc = multithread_compute_test(opts,test_data,model) 259 | vali_loss_prc = multithread_compute_vali(opts,valid_data,model) 260 | print("test!!!loss!!!, test: %.5f, vali: %.5f" % (test_loss_prc[0], vali_loss_prc[0])) 261 | 262 | best_save_path = os.path.join(vali_path, 'best-pre1') 263 | model.load_state_dict(torch.load(best_save_path)) 264 | #saver.restore(sess, best_save_path) 265 | test_loss_prc = multithread_compute_test(opts,test_data,model) 266 | vali_loss_prc = 
multithread_compute_vali(opts,valid_data,model) 267 | print("test!!!pre1!!!, test: %.5f, vali: %.5f" % (test_loss_prc[1], vali_loss_prc[1])) 268 | 269 | best_save_path = os.path.join(vali_path, 'best-pre2') 270 | model.load_state_dict(torch.load(best_save_path)) 271 | #saver.restore(sess, best_save_path) 272 | test_loss_prc = multithread_compute_test(opts,test_data,model) 273 | vali_loss_prc = multithread_compute_vali(opts,valid_data,model) 274 | print("test!!!pre2!!!, test: %.5f, vali: %.5f" % (test_loss_prc[2], vali_loss_prc[2])) 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | if __name__ == "__main__": 296 | main(get_options()) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # define the gan and rl model here 2 | 3 | 4 | 5 | from __future__ import division 6 | from __future__ import print_function 7 | from __future__ import unicode_literals 8 | 9 | import torch 10 | from torch import nn 11 | import numpy as np 12 | from past.builtins import xrange 13 | 14 | 15 | class UserModelPW(nn.Module): 16 | """docstring for UserModelPW""" 17 | def __init__(self, f_dim,args): 18 | super(UserModelPW, self).__init__() 19 | self.f_dim = f_dim 20 | #self.placeholder = {} 21 | self.hidden_dims = args.dims 22 | self.lr = args.learning_rate 23 | self.pw_dim = args.pw_dim 24 | self.band_size = args.pw_band_size 25 | self.mlp_model = self.mlp(4020,args.dims,1, 1e-3, act_last=False) 26 | 27 | def mlp(self,x_shape, hidden_dims, output_dim, sd, act_last=False): 28 | hidden_dims = tuple(map(int, hidden_dims.split("-"))) 29 | #print ("hidden_dims",hidden_dims) 30 | #print ("imp is",x) 31 | #print (x.shape,x.dtype) 32 | cur = x_shape 33 | main_mod = nn.Sequential() 34 | for i,h in enumerate(hidden_dims): 35 | main_mod.add_module('Linear-{0}'.format(i),torch.nn.Linear(cur,h)) 36 | main_mod.add_module('act-{0}'.format(i),nn.ELU()) 37 | cur =h 38 | 39 | if act_last: 40 | main_mod.add_module("Linear_last",torch.nn.Linear(cur,output_dim)) 41 | main_mod.add_module("act_last",nn.ELU()) 42 | return main_mod 43 | else: 44 | main_mod.add_module("linear_last",torch.nn.Linear(cur,output_dim)) 45 | return main_mod 46 | 47 | 48 | 49 | 50 | def forward(self,inputs,is_train=False,index=None): 51 | # input is a dictionaty 52 | if is_train==True: 53 | 54 | disp_current_feature = torch.tensor(inputs['disp_current_feature_x']) 55 | Xs_clicked = torch.tensor(inputs['feature_clicked_x']) 56 | item_size= torch.tensor(inputs['news_cnt_short_x']) 57 | section_length= torch.tensor(inputs['sec_cnt_x']) 58 | click_values= torch.tensor(np.ones(len(inputs['click_2d_x']), dtype=np.float32)) 59 | click_indices = torch.tensor(inputs['click_2d_x']) 60 | disp_indices= torch.tensor(np.array(inputs['disp_2d_x'])) 61 | disp_2d_split_sec_ind= torch.tensor(inputs['disp_2d_split_sec']) 62 | cumsum_tril_indices= torch.tensor(inputs['tril_indice']) 63 | cumsum_tril_value_indices= torch.tensor(np.array(inputs['tril_value_indice'], dtype=np.int64)) 64 | click_2d_subindex= torch.tensor(inputs['click_sub_index_2d']) 65 | 66 | else: 67 | #define the inputs for val/tst here 68 | #print ("input_val",inputs) 69 | 70 | disp_current_feature = torch.tensor(inputs['feature_v'][index]) 71 | Xs_clicked = torch.tensor(inputs['feature_clicked_v'][index]) 72 | item_size= torch.tensor(inputs['news_cnt_short_v'][index]) 73 | section_length= 
torch.tensor(inputs['sec_cnt_v'][index]) 74 | click_values= torch.tensor(np.ones(len(inputs['click_2d_v'][index]), dtype=np.float32)) 75 | click_indices = torch.tensor(inputs['click_2d_v'][index]) 76 | disp_indices= torch.tensor(np.array(inputs['disp_2d_v'][index])) 77 | disp_2d_split_sec_ind= torch.tensor(inputs['disp_2d_split_sec_v'][index]) 78 | cumsum_tril_indices= torch.tensor(inputs['tril_ind_v'][index]) 79 | cumsum_tril_value_indices= torch.tensor(np.array(inputs['tril_value_ind_v'][index], dtype=np.int64)) 80 | click_2d_subindex= torch.tensor(inputs['click_sub_index_2d_v'][index]) 81 | 82 | 83 | 84 | denseshape = [section_length,item_size]# this wont work 85 | 86 | click_history = [[] for _ in xrange(self.pw_dim)] 87 | 88 | for ii in xrange(self.pw_dim): 89 | position_weight = torch.ones(size = [self.band_size]).to(dtype = torch.float64)* 0.0001 90 | #print (position_weight,cumsum_tril_value_indices) 91 | 92 | cumsum_tril_value = position_weight[cumsum_tril_value_indices]# tf.gather(position_weight, self.placeholder['cumsum_tril_value_indices']) 93 | # seel if torch gather could be better here 94 | 95 | #print ("cumsum_tril_indices",cumsum_tril_indices) 96 | #print ("cumsum_tril_value",cumsum_tril_value) 97 | #print ("section_length",section_length) 98 | cumsum_tril_matrix = torch.sparse.FloatTensor(cumsum_tril_indices.t(),cumsum_tril_value,[section_length,section_length]).to_dense() 99 | #print ("cumsum_tril_matrix",cumsum_tril_matrix) 100 | #print ("Xs_clicked",Xs_clicked.dtype) 101 | click_history[ii] = torch.matmul(cumsum_tril_matrix, Xs_clicked.to(dtype=torch.float64)) # Xs_clicked: section by _f_dim 102 | 103 | 104 | concat_history = torch.cat(click_history, axis=1) 105 | 106 | disp_history_feature = concat_history[disp_2d_split_sec_ind] 107 | 108 | 109 | # (4) combine features 110 | concat_disp_features = torch.reshape(torch.cat([disp_history_feature, disp_current_feature], axis=1), 111 | [-1, self.f_dim * self.pw_dim + self.f_dim]) 112 | 113 | # (5) compute utility 114 | #print ("the in pu t shape s ",concat_disp_features.shape) 115 | 116 | u_disp = self.mlp_model(concat_disp_features.float()) 117 | #net.apply(init_weights,sdv) 118 | # (5) 119 | exp_u_disp = torch.exp(u_disp) 120 | 121 | sum_exp_disp_ubar_ut = segment_sum(exp_u_disp, disp_2d_split_sec_ind) 122 | #print ("index",click_2d_subindex) 123 | sum_click_u_bar_ut = u_disp[click_2d_subindex] 124 | 125 | 126 | # (6) loss and precision 127 | #print ("click_values",click_values) 128 | #print ("click_indices",click_indices) 129 | #print ("denseshape",denseshape) 130 | click_tensor = torch.sparse.FloatTensor(click_indices.t(),click_values, denseshape).to_dense() 131 | click_cnt = click_tensor.sum(1) 132 | loss_sum = torch.sum(- sum_click_u_bar_ut + torch.log(sum_exp_disp_ubar_ut + 1)) 133 | event_cnt = torch.sum(click_cnt) 134 | loss = loss_sum / event_cnt 135 | 136 | exp_disp_ubar_ut = torch.sparse.FloatTensor(disp_indices.t(), torch.reshape(exp_u_disp, (-1,)), denseshape) 137 | dense_exp_disp_util = exp_disp_ubar_ut.to_dense() 138 | argmax_click = torch.argmax(click_tensor, dim=1) 139 | argmax_disp = torch.argmax(dense_exp_disp_util, dim=1) 140 | 141 | top_2_disp = torch.topk(dense_exp_disp_util, k=2, sorted=False)[1] 142 | 143 | 144 | # print ("argmax_click",argmax_click.shape) 145 | # #print ("argmax_disp",argmax_disp) 146 | # print ("top_2_disp",top_2_disp.shape) 147 | # sys.exit() 148 | precision_1_sum = torch.sum((torch.eq(argmax_click, argmax_disp))) 149 | precision_1 = precision_1_sum / event_cnt 150 | 151 | 152 
| precision_2_sum = (torch.eq(argmax_click[:,None].to(torch.int64), top_2_disp.to(torch.int64))).sum() 153 | precision_2 = precision_2_sum / event_cnt 154 | 155 | 156 | #self.lossL2 = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name]) * 0.05 # regularity 157 | # weight decay can be added in the optimizer for l2 decay 158 | return loss, precision_1, precision_2, loss_sum, precision_1_sum, precision_2_sum, event_cnt 159 | 160 | 161 | 162 | 163 | def segment_sum(data, segment_ids): 164 | """ 165 | Analogous to tf.segment_sum (https://www.tensorflow.org/api_docs/python/tf/math/segment_sum). 166 | 167 | :param data: A pytorch tensor of the data for segmented summation. 168 | :param segment_ids: A 1-D tensor containing the indices for the segmentation. 169 | :return: a tensor of the same type as data containing the results of the segmented summation. 170 | """ 171 | if not all(segment_ids[i] <= segment_ids[i + 1] for i in range(len(segment_ids) - 1)): 172 | raise AssertionError("elements of segment_ids must be sorted") 173 | 174 | if len(segment_ids.shape) != 1: 175 | raise AssertionError("segment_ids have be a 1-D tensor") 176 | 177 | if data.shape[0] != segment_ids.shape[0]: 178 | raise AssertionError("segment_ids should be the same size as dimension 0 of input.") 179 | 180 | # t_grp = {} 181 | # idx = 0 182 | # for i, s_id in enumerate(segment_ids): 183 | # s_id = s_id.item() 184 | # if s_id in t_grp: 185 | # t_grp[s_id] = t_grp[s_id] + data[idx] 186 | # else: 187 | # t_grp[s_id] = data[idx] 188 | # idx = i + 1 189 | # 190 | # lst = list(t_grp.values()) 191 | # tensor = torch.stack(lst) 192 | 193 | num_segments = len(torch.unique(segment_ids)) 194 | return unsorted_segment_sum(data, segment_ids, num_segments) 195 | 196 | 197 | def unsorted_segment_sum(data, segment_ids, num_segments): 198 | """ 199 | Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum. 200 | 201 | :param data: A tensor whose segments are to be summed. 202 | :param segment_ids: The segment indices tensor. 203 | :param num_segments: The number of segments. 204 | :return: A tensor of same data type as the data argument. 
205 | """ 206 | assert all([i in data.shape for i in segment_ids.shape]), "segment_ids.shape should be a prefix of data.shape" 207 | 208 | # segment_ids is a 1-D tensor; repeat it to have the same shape as data 209 | if len(segment_ids.shape) == 1: 210 | s = torch.prod(torch.tensor(data.shape[1:])).long() 211 | segment_ids = segment_ids.repeat_interleave(s).view(segment_ids.shape[0], *data.shape[1:]) 212 | 213 | assert data.shape == segment_ids.shape, "data.shape and segment_ids.shape should be equal" 214 | 215 | shape = [num_segments] + list(data.shape[1:]) 216 | tensor = torch.zeros(*shape).scatter_add(0, segment_ids, data.float()) 217 | tensor = tensor.type(data.dtype) 218 | return tensor -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | 6 | import os 7 | import time 8 | import argparse 9 | import torch 10 | import numpy as np 11 | 12 | 13 | 14 | def get_options(args=None): 15 | parser = argparse.ArgumentParser( 16 | description="Args for recommendation system reinforce_gan model") 17 | 18 | parser.add_argument('--data_folder',type =str,default ='./data/',help = 'dataset_folder') 19 | parser.add_argument('--dataset',type = str,default = 'yelp',help='choose from yelp, tb or rsc') 20 | parser.add_argument('--save_dir',type = str,default = './save_dir/',help='save folder') 21 | 22 | parser.add_argument('--resplit', type=eval, default=False) 23 | parser.add_argument('--num_thread', type=int, default=10, help='number of threads') 24 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='learning rate') 25 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 26 | parser.add_argument('--num_itrs', type=int, default=2000, help='num of iterations for q learning') 27 | # might change later to a policy-gradient method with attention rather than LSTM 28 | parser.add_argument('--rnn_hidden_dim', type=int, default=20, help='LSTM hidden sizes') 29 | parser.add_argument('--pw_dim', type=int, default=4, help='position weight dim') 30 | parser.add_argument('--pw_band_size', type=int, default=20, help='position weight banded size (i.e. 
length of history)') 31 | 32 | 33 | parser.add_argument('--dims', type=str, default='64-64') 34 | parser.add_argument('--user_model', type=str, default='PW', help='architecture choice: LSTM or PW') 35 | # dont think that the PW model could be used atm 36 | 37 | opts = parser.parse_args(args) 38 | 39 | return opts -------------------------------------------------------------------------------- /preprocess_data.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | 4 | import numpy as np 5 | import pickle 6 | import pandas as pd 7 | import argparse 8 | from past.builtins import xrange 9 | 10 | from options import get_options 11 | 12 | 13 | def main(opts): 14 | 15 | filename = opts.data_folder+opts.dataset+'.txt' 16 | 17 | raw_data = pd.read_csv(filename, sep='\t', usecols=[1, 3, 5, 7, 6], dtype={1: int, 3: int, 7: int, 5:int, 6:int}) 18 | 19 | 20 | raw_data.drop_duplicates(subset=['session_new_index','Time','item_new_index','is_click'], inplace=True) 21 | 22 | raw_data.sort_values(by='is_click',inplace=True) 23 | print (raw_data.head()) 24 | raw_data.drop_duplicates(keep='last', subset=['session_new_index','Time','item_new_index'], inplace=True) 25 | 26 | sizes = raw_data.nunique() 27 | print (sizes) 28 | size_user = sizes['session_new_index'] 29 | size_item = sizes['item_new_index'] 30 | 31 | data_user = raw_data.groupby(by='session_new_index') 32 | print (data_user) 33 | data_behavior = [[] for _ in xrange(size_user)] 34 | 35 | train_user = [] 36 | vali_user = [] 37 | test_user = [] 38 | 39 | sum_length = 0 40 | event_cnt = 0 41 | 42 | for user in xrange(size_user): 43 | data_behavior[user] = [[], [], []] 44 | data_behavior[user][0] = user 45 | data_u = data_user.get_group(user) 46 | split_tag = list(data_u['tr_val_tst'])[0] 47 | if split_tag == 0: 48 | train_user.append(user) 49 | elif split_tag == 1: 50 | vali_user.append(user) 51 | else: 52 | test_user.append(user) 53 | 54 | data_u_time = data_u.groupby(by='Time') 55 | time_set = np.array(list(set(data_u['Time']))) 56 | time_set.sort() 57 | 58 | true_t = 0 59 | for t in xrange(len(time_set)): 60 | display_set = data_u_time.get_group(time_set[t]) 61 | event_cnt += 1 62 | sum_length += len(display_set) 63 | 64 | data_behavior[user][1].append(list(display_set['item_new_index'])) 65 | data_behavior[user][2].append(int(display_set[display_set.is_click==1]['item_new_index'])) 66 | 67 | new_features = np.eye(size_item) 68 | 69 | filename = opts.data_folder+opts.dataset+'.pkl' 70 | file = open(filename, 'wb') 71 | print (data_behavior) 72 | pickle.dump(data_behavior, file, protocol=pickle.HIGHEST_PROTOCOL) 73 | pickle.dump(new_features, file, protocol=pickle.HIGHEST_PROTOCOL) 74 | file.close() 75 | 76 | filename = opts.data_folder+opts.dataset+'-split.pkl' 77 | file = open(filename, 'wb') 78 | pickle.dump(train_user, file, protocol=pickle.HIGHEST_PROTOCOL) 79 | pickle.dump(vali_user, file, protocol=pickle.HIGHEST_PROTOCOL) 80 | pickle.dump(test_user, file, protocol=pickle.HIGHEST_PROTOCOL) 81 | file.close() 82 | 83 | 84 | 85 | 86 | if __name__ == "__main__": 87 | main(get_options()) -------------------------------------------------------------------------------- /reco_gan_rl.ipynb: -------------------------------------------------------------------------------- 1 | 
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"reco_gan_rl.ipynb","provenance":[],"collapsed_sections":[],"mount_file_id":"1kUjgM741QgcTewKCoX9wv6seGKbqQ4LK","authorship_tag":"ABX9TyOo8LAlNrbBX8O9v6cKrigM"},"kernelspec":{"name":"python3","display_name":"Python 3"}},"cells":[{"cell_type":"code","metadata":{"id":"OQVLRArWLsKg","colab":{"base_uri":"https://localhost:8080/","height":102},"executionInfo":{"status":"ok","timestamp":1599833158129,"user_tz":-540,"elapsed":1908,"user":{"displayName":"Rushikesh Handal","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GgOz3selWel07DhMWJ_Qtdhqke75MMSNUhZbZegnA=s64","userId":"09084884598721117870"}},"outputId":"82731377-a2cf-46f5-c2f2-964f1e10a8d8"},"source":["%cd /content/drive/My\\ Drive/Colab\\ Notebooks/reco_gan_rl/"],"execution_count":null,"outputs":[{"output_type":"stream","text":["/content/drive/My Drive/Colab Notebooks/reco_gan_rl\n","data\t\t\t\t main.py\t\t __pycache__\n","data_utils_back.py\t\t model.py\t\t reco_gan_rl.ipynb\n","data_utils.py\t\t\t options.py\t save_dir\n","GenerativeAdversarialUserModel-master preprocess_data.py utils.py\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"zR2LmHjH9r41"},"source":["!python3 preprocess_data.py"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"pJWo_4LFqo0g"},"source":["!python main.py"],"execution_count":null,"outputs":[]}]} -------------------------------------------------------------------------------- /save_dir/best-loss: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rushhan/Generative-Adversarial-User-Model-for-Reinforcement-Learning-Based-Recommendation-System-Pytorch/d7e34edd8013f9ef468d48c3d7765d17e8ed09dc/save_dir/best-loss -------------------------------------------------------------------------------- /save_dir/best-pre1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rushhan/Generative-Adversarial-User-Model-for-Reinforcement-Learning-Based-Recommendation-System-Pytorch/d7e34edd8013f9ef468d48c3d7765d17e8ed09dc/save_dir/best-pre1 -------------------------------------------------------------------------------- /save_dir/best-pre2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rushhan/Generative-Adversarial-User-Model-for-Reinforcement-Learning-Based-Recommendation-System-Pytorch/d7e34edd8013f9ef468d48c3d7765d17e8ed09dc/save_dir/best-pre2 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # utils for the data and model --------------------------------------------------------------------------------
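Below is a minimal end-to-end usage sketch assembled from the modules above (options.py, data_utils.py, model.py). It is not a file in the repository: it assumes data/yelp.pkl and data/yelp-split.pkl have already been generated by preprocess_data.py, and it only mirrors a single training-mode forward pass of the PW user model, as done in main.py.

# Hypothetical quick-start sketch, not part of the repo.
from options import get_options
from data_utils import Dataset
from model import UserModelPW

opts = get_options([])   # default arguments: PW user model, yelp dataset
dataset = Dataset(opts)  # loads ./data/yelp.pkl and ./data/yelp-split.pkl
# Build one batch from a set of training-user indices (main.py samples these randomly).
batch = dataset.data_process_for_placeholder(dataset.train_user[:opts.batch_size])
model = UserModelPW(dataset.f_dim, opts)
# Note: model.py hard-codes an MLP input size of 4020 = f_dim * (pw_dim + 1),
# so this assumes the default Yelp features and pw_dim = 4.
loss, precision_1, precision_2, *_ = model(batch, is_train=True)
print(loss.item(), precision_1.item(), precision_2.item())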