├── material
│   ├── Many2Many.png
│   └── Many2One.png
├── LICENSE
├── README.md
├── deep_tools.py
└── deep_learning.ipynb


--------------------------------------------------------------------------------
/material/Many2Many.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drop-out/RNN-Active-User-Forecast/HEAD/material/Many2Many.png


--------------------------------------------------------------------------------
/material/Many2One.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drop-out/RNN-Active-User-Forecast/HEAD/material/Many2One.png


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 drop-out

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
**Author's note: Because the competition was on a tight schedule, parts of this code are not written to a high standard. For a cleaner TensorFlow RNN implementation, see the author's other project [tensorflow-RNN-toolkit](https://github.com/drop-out/tensorflow-RNN-toolkit), which builds RNNs from higher-level building blocks without sacrificing flexibility.**

## Task review

This is an active-user forecasting problem. Given Kuaishou users' registration, launch, video creation/viewing, and interaction logs, predict which users will be active in the next 7 days.

See the [competition page](https://www.kesci.com/home/competition/5ab8c36a8643e33f5138cba4) for details.

## RNN: Many2One vs Many2Many

With an RNN, the obvious solution is Many2One: take a user's behavior sequence over a window of days as input, and label the whole sequence with whether the user is active in the following 7 days.

![](https://github.com/drop-out/RNN-Active-User-Forecast/raw/master/material/Many2One.png)

To make full use of the data, this requires heavy sliding-window augmentation of the training data, which is computationally expensive. Moreover, each sequence carries only one label, so the gradient signal is weak and training is difficult. Instead, we can use a Many2Many structure: every time step outputs whether the user is active in the 7 days after that step. This exploits all of the available supervision, eases gradient propagation, and makes training much easier.

![](https://github.com/drop-out/RNN-Active-User-Forecast/raw/master/material/Many2Many.png)

A quick comparison of the two structures:

| | Many2One | Many2Many |
| ---------------- | -------- | --------- |
| No sliding window needed | | √ |
| Full use of supervision | | √ |
| Variable-length sequences | | √ |

## Input sequences

Compared with an xgboost solution built on historical aggregate features, the RNN needs little feature engineering: the raw behavior sequences are fed in directly. The per-day inputs are:

- Whether the user launched the app (0/1)
- Number of videos watched (log(1 + x))
- Number of action records per action_type (log(1 + x))
- Number of action records per page (log(1 + x))

## Intercept

In addition, an intercept is added directly at the output layer: day-of-sequence, device_type, and register_type are one-hot encoded and fed in, with low-frequency categories merged into a single bucket.

## Variable length

Because sequences have variable length, a dynamic RNN is used. Each batch contains sequences of one common length; different batches use different lengths, and each training step samples a batch of a randomly chosen length.

## Snapshot ensemble with cosine annealing

A snapshot ensemble with cosine-annealed learning rates collects many diverse local optima at very low cost; blending them yields a significant improvement.
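
## Minimal Many2Many sketch

For orientation before reading `deep_learning.ipynb`, the sketch below shows the core of the Many2Many setup in TensorFlow 1.x: a GRU over `[batch, time, features]` inputs, one shared output layer applied at every time step, and a per-step loss over the days whose 7-day-ahead labels are known. This is a minimal sketch, not the full pipeline; `n_features` and `n_hu` follow the notebook, while the intercept embeddings, data feeding, and evaluation are omitted here.

```python
import tensorflow as tf

n_features, n_hu = 12, 5   # per-day input features, GRU hidden units

# Both batch and time dimensions stay dynamic: each batch holds sequences
# of one common length, and different batches may use different lengths.
x = tf.placeholder(tf.float32, [None, None, n_features])  # [batch, time, features]
y = tf.placeholder(tf.float32, [None, None])              # [batch, time] 0/1 labels

cell = tf.nn.rnn_cell.GRUCell(n_hu)
outputs, _ = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)  # [batch, time, n_hu]

# One logit per user per day, produced by a single shared output layer.
W_out = tf.get_variable('W_out', [n_hu, 1])
b_out = tf.get_variable('b_out', [1])
logits = tf.matmul(tf.reshape(outputs, [-1, n_hu]), W_out) + b_out
logits = tf.reshape(logits, tf.stack([tf.shape(x)[0], tf.shape(x)[1]]))

# The last 7 days of a sequence have no 7-day-ahead label, so the loss is
# taken over every remaining time step -- this is the Many2Many part.
loss = tf.losses.sigmoid_cross_entropy(y[:, :-7], logits[:, :-7])
train_step = tf.train.AdamOptimizer(0.01).minimize(loss)
```

Because every retained time step contributes a label, one pass over a sequence supplies `seq_length - 7` supervised observations instead of one, which is what makes the sliding-window augmentation of the Many2One approach unnecessary.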

--------------------------------------------------------------------------------
/deep_tools.py:
--------------------------------------------------------------------------------
# coding: utf-8

import pandas as pd
import numpy as np
from random import shuffle

def f(table,name='prob'):
    # F1 score of the binarized predictions at several candidate probability thresholds
    table=table.copy()
    score=[]
    for i in [0.40,0.41,0.42,0.43,0.44,0.45]:
        table['pred']=1*(table[name]>i)
        c=((table.pred==1)&(table.label==1)).sum()
        p=c/table.pred.sum()
        r=c/table.label.sum()
        score.append(2*p*r/(p+r))
    return score

def record_to_sequence(table):
    # compress (user_id, day, value) records into one 'day:value,day:value,...' string per user
    table.columns=['user_id','day','value']
    table.sort_values(by=['user_id','day'],inplace=True)
    table['string']=table.day.map(str)+':'+table.value.map(str)
    table=table.groupby(['user_id'],as_index=False).agg({'string':lambda x:','.join(x)})
    return table

class user_seq:
    # per-user feature matrix of shape [seq_length, n_features], aligned to the registration day

    def __init__(self,register_day,seq_length,n_features):
        self.register_day=register_day
        self.seq_length=seq_length
        self.array=np.zeros([self.seq_length,n_features])
        self.array[0,0]=1  # feature 0 marks the registration day
        self.page_rank=np.zeros([self.seq_length])
        self.pointer=1

    def put_feature(self,feature_number,string):
        # mark the days on which the behavior occurred (0/1 presence flag; the aggregated counts in `value` are not used here)
        for i in string.split(','):
            pos,value=i.split(':')
            self.array[int(pos)-self.register_day,feature_number]=1

    def put_PR(self,string):
        for i in string.split(','):
            pos,value=i.split(':')
            self.page_rank[int(pos)-self.register_day]=value

    def get_array(self):
        return self.array

    def get_label(self):
        # label for day i: active (any of features 0-9) on any of the next 7 days;
        # the last 7 days have no label and stay None
        self.label=np.array([None]*self.seq_length)
        active=self.array[:,:10].sum(axis=1)
        for i in range(self.seq_length-7):
            self.label[i]=1*(np.sum(active[i+1:i+8])>0)
        return self.label


class DataGenerator:

    def __init__(self,register,launch,create,activity):

        register=register.copy()
        launch=launch.copy()
        create=create.copy()
        activity=activity.copy()

        # user_queue: bucket user_ids by sequence length (days from registration to day 30)
        register['seq_length']=31-register['register_day']
        self.user_queue={i:[] for i in range(1,31)}
        for index,row in register.iterrows():
            self.user_queue[row[-1]].append(row[0])  # row[-1] is seq_length, row[0] is user_id

        # initialize self.data
        n_features=12  # row[0] is user_id, row[1] is register_day, row[-1] is seq_length
        self.data={row[0]:user_seq(register_day=row[1],seq_length=row[-1],n_features=n_features) for index,row in register.iterrows()}

        # extract launch_seq (feature 1)
        launch['launch']=1
        launch_table=launch.groupby(['user_id','launch_day'],as_index=False).agg({'launch':'sum'})
        launch_table=record_to_sequence(launch_table)
        for index,row in launch_table.iterrows():
            self.data[row[0]].put_feature(1,row[1])  # row[0] is user_id, row[1] is the string

        # extract create_seq (feature 2)
        create['create']=1
        create_table=create.groupby(['user_id','create_day'],as_index=False).agg({'create':'sum'})
        create_table=record_to_sequence(create_table)
        for index,row in create_table.iterrows():
            self.data[row[0]].put_feature(2,row[1])

        # extract act_seq (features 3-8, one per act_type)
        for i in range(6):
            act=activity[activity.act_type==i].copy()
            act=act.groupby(['user_id','act_day'],as_index=False).agg({'video_id':'count'})
            act=record_to_sequence(act)
            for index,row in act.iterrows():
                self.data[row[0]].put_feature(i+3,row[1])

        # extract page_seq (feature 9; only page 0 is used)
        for i in range(1):
            act=activity[activity.page==i].copy()
            act=act.groupby(['user_id','act_day'],as_index=False).agg({'video_id':'count'})
            act=record_to_sequence(act)
            for index,row in act.iterrows():
                self.data[row[0]].put_feature(i+9,row[1])

        # extract watched (feature 10): the user's own videos watched by others
        watched=register.loc[:,['user_id']].copy()
        watched.columns=['author_id']
        watched=pd.merge(watched,activity[activity.author_id!=activity.user_id],how='inner')
        watched=watched.groupby(['author_id','act_day'],as_index=False).agg({'video_id':'count'})
        watched=record_to_sequence(watched)
        for index,row in watched.iterrows():
            self.data[row[0]].put_feature(10,row[1])

        # extract watched-by-self (feature 11)
        watched=activity[activity.author_id==activity.user_id].copy()
        watched=watched.groupby(['user_id','act_day'],as_index=False).agg({'video_id':'count'})
        watched=record_to_sequence(watched)
        for index,row in watched.iterrows():
            self.data[row[0]].put_feature(11,row[1])

        # extract labels
        self.label={user_id:user.get_label() for user_id,user in self.data.items()}

        # extract data arrays
        self.data={user_id:user.get_array() for user_id,user in self.data.items()}

        # set sampling strategy: weight each length by its number of labeled days
        self.local_random_list=[]
        for i in range(15,31):
            self.local_random_list+=[i]*(i-14)

        self.online_random_list=[]
        for i in range(8,31):
            self.online_random_list+=[i]*(i-7)

        self.local_train_list=list(range(15,31))
        self.local_test_list=list(range(8,31))
        self.online_train_list=list(range(8,31))
        self.online_test_list=list(range(1,31))

        self.pointer={i:0 for i in range(1,31)}

    def reset_pointer(self):
        self.pointer={i:0 for i in range(1,31)}
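    # Batching strategy (used by next_batch below): sequences are bucketed by
    # length; one length is drawn per batch, weighted by its number of labeled
    # days via local_random_list / online_random_list, and batch_size is then
    # rescaled so every batch carries roughly the same number of labeled
    # observations. Because all users in a batch share one length, batches are
    # rectangular and the dynamic RNN needs no padding or masking.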
    def next_batch(self,strategy='local',batch_size=1000):

        if strategy=='local':
            seq_length=self.local_random_list[np.random.randint(len(self.local_random_list))]
            batch_size=batch_size//(seq_length-14)+1
        else:
            seq_length=self.online_random_list[np.random.randint(len(self.online_random_list))]
            batch_size=batch_size//(seq_length-7)+1

        # reshuffle a bucket whenever it is exhausted
        if self.pointer[seq_length]+batch_size>len(self.user_queue[seq_length]):
            self.pointer[seq_length]=0
            shuffle(self.user_queue[seq_length])
        start=self.pointer[seq_length]
        user_list=self.user_queue[seq_length][start:start+batch_size]
        self.pointer[seq_length]+=batch_size

        user_matrix=np.array(user_list)
        data_matrix=np.array([self.data[i] for i in user_list])
        label_matrix=np.array([self.label[i] for i in user_list])

        return seq_length,user_matrix,data_matrix,label_matrix

    def get_set(self,strategy='local',usage='train'):
        # return one (users, data, labels) triple per sequence length
        if strategy=='local':
            if usage=='train':
                length_list=self.local_train_list
            else:
                length_list=self.local_test_list
        else:
            if usage=='train':
                length_list=self.online_train_list
            else:
                length_list=self.online_test_list
        user_list=[np.array(self.user_queue[seq_length]) for seq_length in length_list]
        data_list=[np.array([self.data[user_id] for user_id in self.user_queue[seq_length]]) for seq_length in length_list]
        label_list=[np.array([self.label[user_id] for user_id in self.user_queue[seq_length]]) for seq_length in length_list]
        return length_list,user_list,data_list,label_list


--------------------------------------------------------------------------------
/deep_learning.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import datetime\n",
    "from deep_tools import f\n",
    "from deep_tools import DataGenerator\n",
    "\n",
    "register=pd.read_csv('./data/user_register_log.txt',sep='\\t',names=['user_id','register_day','register_type','device_type'])\n",
    "launch=pd.read_csv('./data/app_launch_log.txt',sep='\\t',names=['user_id','launch_day'])\n",
    "create=pd.read_csv('./data/video_create_log.txt',sep='\\t',names=['user_id','create_day'])\n",
    "activity=pd.read_csv('./data/user_activity_log.txt',sep='\\t',names=['user_id','act_day','page','video_id','author_id','act_type'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# parameters\n",
    "n_features=12\n",
    "n_hu=5\n",
    "n_device=50\n",
    "n_register=7\n",
    "n_days=31"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data_generator=DataGenerator(register,launch,create,activity)"
   ]
  },
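  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Optional sanity check (illustrative, not required by the pipeline):\n",
    "# draw one batch and confirm the shapes match the placeholders defined below,\n",
    "# i.e. data [batch,seq_length,n_features] and labels [batch,seq_length].\n",
    "seq_length,user_matrix,data_matrix,label_matrix=data_generator.next_batch('local',1000)\n",
    "print(seq_length,data_matrix.shape,label_matrix.shape)"
   ]
  },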
"device_table=register.groupby(['device_type'],as_index=False).agg({'user_id':'count'})\n", 61 | "device_table=device_table.sort_values(by=['user_id'],ascending=False)\n", 62 | "device_table['device_type_map']=np.arange(len(device_table))\n", 63 | "device_table.drop('user_id',axis=1,inplace=True)\n", 64 | "register=pd.merge(register,device_table)\n", 65 | "device_dict={row[0]:row[-1] for index,row in register.iterrows()}\n", 66 | "\n", 67 | "#register_dict\n", 68 | "register_dict={row[0]:row[2] for index,row in register.iterrows()}" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "collapsed": true 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "tf.reset_default_graph()\n", 80 | "tf.set_random_seed(10)\n", 81 | "\n", 82 | "#Variables\n", 83 | "with tf.variable_scope('test4'):\n", 84 | " \n", 85 | " #变量与输入\n", 86 | " lr=tf.placeholder(tf.float32,[],name='learning_rate')\n", 87 | "\n", 88 | " W_out=tf.get_variable('W_out',[n_hu,1])\n", 89 | " b_out=tf.get_variable('b_out',[1])\n", 90 | "\n", 91 | " \n", 92 | " x=tf.placeholder(tf.float32,[None,None,n_features])\n", 93 | " y=tf.placeholder(tf.float32,[None,None])\n", 94 | " \n", 95 | " batch_size=tf.shape(x)[0]\n", 96 | " seq_length=tf.shape(x)[1]\n", 97 | "\n", 98 | " PR_input=tf.placeholder(tf.float32,[None,None,1])\n", 99 | " \n", 100 | " device_input=tf.placeholder(tf.int32,[None])\n", 101 | " register_input=tf.placeholder(tf.int32,[None])\n", 102 | " date_input=tf.placeholder(tf.int32,[None])\n", 103 | " \n", 104 | " device_embedding=tf.get_variable('device_embedding',[n_device,1],initializer=tf.zeros_initializer)\n", 105 | " register_embedding=tf.get_variable('register_embedding',[n_register,1],initializer=tf.zeros_initializer)\n", 106 | " date_embedding=tf.get_variable('date_embedding',[n_days,1],initializer=tf.zeros_initializer)\n", 107 | " \n", 108 | " #RNN层\n", 109 | " cell=tf.nn.rnn_cell.GRUCell(n_hu)\n", 110 | " initial_state = cell.zero_state(batch_size, dtype=tf.float32)\n", 111 | " outputs, state = tf.nn.dynamic_rnn(cell, x,\n", 112 | " initial_state=initial_state)\n", 113 | " \n", 114 | " #输出层\n", 115 | " outputs=tf.reshape(outputs,[-1,n_hu])\n", 116 | " logits=tf.matmul(outputs,W_out)+b_out\n", 117 | " logits=tf.reshape(logits,tf.stack([batch_size,seq_length]))\n", 118 | " \n", 119 | " device_intercept=tf.nn.embedding_lookup(device_embedding,device_input)\n", 120 | " register_intercept=tf.nn.embedding_lookup(register_embedding,register_input)\n", 121 | " date_intercept=tf.nn.embedding_lookup(date_embedding,date_input)\n", 122 | " date_intercept=tf.reshape(date_intercept,tf.stack([1,seq_length]))\n", 123 | " \n", 124 | " \n", 125 | " logits=logits+device_intercept+register_intercept+date_intercept" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "collapsed": true 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "#local_train\n", 137 | "logits_local_train=logits[:,:-14]\n", 138 | "label_local_train=y[:,:-14]\n", 139 | "\n", 140 | "regularizer=tf.contrib.layers.l2_regularizer(0.00001)\n", 141 | "penalty=tf.contrib.layers.apply_regularization(regularizer,tf.trainable_variables())\n", 142 | "\n", 143 | "obj_local=tf.losses.sigmoid_cross_entropy(label_local_train,logits_local_train)+penalty\n", 144 | "optimizer=tf.train.AdamOptimizer(lr)\n", 145 | "step_local=optimizer.minimize(obj_local)\n", 146 | "\n", 147 | "#local_test\n", 148 | "logits_local_test=logits[:,-8]\n", 149 | "label_local_test=y[:,-8]\n", 150 | 
"\n", 151 | "#online_train\n", 152 | "logits_online_train=logits[:,:-7]\n", 153 | "label_online_train=y[:,:-7]\n", 154 | "\n", 155 | "obj_online=tf.losses.sigmoid_cross_entropy(label_online_train,logits_online_train)+penalty\n", 156 | "optimizer=tf.train.AdamOptimizer(lr)\n", 157 | "step_online=optimizer.minimize(obj_online)\n", 158 | "\n", 159 | "#online_test\n", 160 | "logits_online_test=logits[:,-1]" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": { 167 | "collapsed": true 168 | }, 169 | "outputs": [], 170 | "source": [ 171 | "sess=tf.Session()\n", 172 | "sess.run(tf.global_variables_initializer())" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "metadata": { 179 | "collapsed": true 180 | }, 181 | "outputs": [], 182 | "source": [ 183 | "def test(strategy='local'):\n", 184 | " if strategy=='local':\n", 185 | " n_NA=14\n", 186 | " date_seq=[31]+list(range(2,16))+[16]*15\n", 187 | " variables_1=[obj_local,logits_local_train,label_local_train]\n", 188 | " variables_2=[logits_local_test,label_local_test]\n", 189 | " else:\n", 190 | " n_NA=7\n", 191 | " date_seq=[31]+list(range(2,23))+[23]*8\n", 192 | " variables_1=[obj_online,logits_online_train,label_online_train]\n", 193 | " variables_2=logits_online_test\n", 194 | " \n", 195 | " obs_count,cum_loss,correct=0,0,0\n", 196 | " user,prob,real=[],[],[]\n", 197 | "\n", 198 | " #训练损失\n", 199 | " for length,id_list,data_x,data_y in zip(*data_generator.get_set(strategy,'train')):\n", 200 | " _obj,_logits_train,_label_train=sess.run(variables_1,\n", 201 | " feed_dict={x:data_x,\n", 202 | " y:data_y,\n", 203 | " device_input:[device_dict[u] for u in id_list],\n", 204 | " register_input:[register_dict[u] for u in id_list],\n", 205 | " date_input:date_seq[-length:],\n", 206 | " lr:0.001})\n", 207 | "\n", 208 | " obs_count+=(length-n_NA)*len(id_list)\n", 209 | " cum_loss+=_obj*(length-n_NA)*len(id_list)\n", 210 | " correct+=np.sum((1*(_logits_train>0)==_label_train))\n", 211 | "\n", 212 | " #测试损失\n", 213 | " for length,id_list,data_x,data_y in zip(*data_generator.get_set(strategy,'test')):\n", 214 | " _=sess.run(variables_2,\n", 215 | " feed_dict={x:data_x,\n", 216 | " y:data_y,\n", 217 | " device_input:[device_dict[u] for u in id_list],\n", 218 | " register_input:[register_dict[u] for u in id_list],\n", 219 | " date_input:date_seq[-length:],\n", 220 | " lr:0.001})\n", 221 | " if strategy=='local':\n", 222 | " _logits_test,_label_test=_\n", 223 | " real+=list(_label_test)\n", 224 | " else:\n", 225 | " _logits_test=_\n", 226 | "\n", 227 | " user+=list(id_list)\n", 228 | " prob+=list(1/(1+np.exp(-_logits_test.reshape([-1]))))\n", 229 | " \n", 230 | " #训练损失\n", 231 | " print('train_loss',cum_loss/obs_count,correct/obs_count)\n", 232 | " \n", 233 | " #测试损失\n", 234 | " if strategy=='local':\n", 235 | " result=pd.DataFrame({'user_id':user,'prob':prob,'label':real})\n", 236 | " print('test_score:',f(result))\n", 237 | " else:\n", 238 | " result=pd.DataFrame({'user_id':user,'prob':prob})\n", 239 | " return result" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": { 246 | "collapsed": true 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "def train(strategy='local',n_obs=1000,step=1000,lr_feed=0.01):\n", 251 | " \n", 252 | " if strategy=='local':\n", 253 | " date_seq=[31]+list(range(2,16))+[16]*15\n", 254 | " variables=[step_local,obj_local,label_local_train,logits_local_train]\n", 255 | " else:\n", 256 | " 
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def cos_annealing_local(epoch=5):\n",
    "    all_result=None\n",
    "    for i in range(epoch):\n",
    "        train('local',n_obs=1000,step=2000,lr_feed=0.01)\n",
    "        train('local',n_obs=1000,step=2000,lr_feed=0.001)\n",
    "        result=test('local')\n",
    "        print(sess.run(penalty))\n",
    "        result=result.rename(columns={'prob':'prob%s'%i}) #keep one probability column per snapshot\n",
    "        if i==0:\n",
    "            all_result=result\n",
    "        else:\n",
    "            all_result=pd.merge(all_result,result)\n",
    "    return all_result\n",
    "\n",
    "def cos_annealing_online(epoch=5):\n",
    "    all_result=None\n",
    "    for i in range(epoch):\n",
    "        train('online',n_obs=1000,step=2000,lr_feed=0.01)\n",
    "        train('online',n_obs=1000,step=2000,lr_feed=0.001)\n",
    "        result=test('online')\n",
    "        print(sess.run(penalty))\n",
    "        result=result.rename(columns={'prob':'prob%s'%i})\n",
    "        if i==0:\n",
    "            all_result=result\n",
    "        else:\n",
    "            all_result=pd.merge(all_result,result)\n",
    "    return all_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# offline evaluation\n",
    "print(datetime.datetime.now())\n",
    "result=cos_annealing_local(5)\n",
    "print(datetime.datetime.now())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# online submission\n",
    "print(datetime.datetime.now())\n",
    "result=cos_annealing_online(5)\n",
    "print(datetime.datetime.now())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# blend the five snapshots by averaging their probabilities\n",
    "result['prob']=(result.prob0+result.prob1+result.prob2+result.prob3+result.prob4)/5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "result.sort_values(by='prob',ascending=False,inplace=True)\n",
    "result=result.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# submit the top-ranked users (.loc end is inclusive)\n",
    "result.loc[:24800,['user_id']].to_csv('output/result.csv',header=False,index=False)"
   ]
  },
"collapsed": true 386 | }, 387 | "outputs": [], 388 | "source": [] 389 | } 390 | ], 391 | "metadata": { 392 | "kernelspec": { 393 | "display_name": "Python 3", 394 | "language": "python", 395 | "name": "python3" 396 | }, 397 | "language_info": { 398 | "codemirror_mode": { 399 | "name": "ipython", 400 | "version": 3 401 | }, 402 | "file_extension": ".py", 403 | "mimetype": "text/x-python", 404 | "name": "python", 405 | "nbconvert_exporter": "python", 406 | "pygments_lexer": "ipython3", 407 | "version": "3.6.0" 408 | } 409 | }, 410 | "nbformat": 4, 411 | "nbformat_minor": 2 412 | } 413 | --------------------------------------------------------------------------------