├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── 0_gen_sampled_data.py
│   ├── 1_gen_sessions.py
│   ├── 2_gen_dien_input.py
│   ├── 2_gen_din_input.py
│   ├── 2_gen_dsin_input.py
│   ├── config.py
│   ├── train_dien.py
│   ├── train_din.py
│   └── train_dsin.py
├── raw_data
│   └── README.md
└── requirements.txt

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | sampled_data
3 | raw_data
4 | model_input
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 | 
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 | 
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 | 
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | .hypothesis/
53 | .pytest_cache/
54 | 
55 | # Translations
56 | *.mo
57 | *.pot
58 | 
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # celery beat schedule file
88 | celerybeat-schedule
89 | 
90 | # SageMath parsed files
91 | *.sage.py
92 | 
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 | 
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 | 
106 | # Rope project settings
107 | .ropeproject
108 | 
109 | # mkdocs documentation
110 | /site
111 | 
112 | # mypy
113 | .mypy_cache/
114 | .dmypy.json
115 | dmypy.json
116 | 
117 | # Pyre type checker
118 | .pyre/
119 | .DS_Store
120 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019 Weichen Shen 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Session Interest Network for Click-Through Rate Prediction
2 | 
3 | Experiment code on the Advertising Dataset for the paper [Deep Session Interest Network for Click-Through Rate Prediction](https://arxiv.org/abs/1905.06482).
4 | 
5 | [Yufei Feng](https://github.com/649435349), Fuyu Lv, Weichen Shen, Menghan Wang, Fei Sun, Yu Zhu and Keping Yang.
6 | 
7 | In Proceedings of the 28th International Joint Conference on Artificial Intelligence (IJCAI 2019).
8 | 
9 | ----------------
10 | ## Operating environment
11 | Please use
12 | `pip install -r requirements.txt`
13 | to set up the operating environment with `python3.6`.
14 | 
15 | --------------------------
16 | ## Download dataset and preprocess
17 | ### Download dataset
18 | 
19 | 1. Download the dataset [Ad Display/Click Data on Taobao.com](https://tianchi.aliyun.com/dataset/dataDetail?dataId=56)
20 | 2. Extract the files into the ``raw_data`` directory
21 | 
22 | ### Data preprocessing
23 | 
24 | 1. Run `0_gen_sampled_data.py` to sample the data by user.
25 | 2. Run `1_gen_sessions.py` to generate the historical session sequence for each user.
26 | 
27 | ## Training and Evaluation
28 | 
29 | ### Train DIN model
30 | 1. Run `2_gen_din_input.py` to generate the input data.
31 | 2. Run `train_din.py`.
32 | 
33 | ### Train DIEN model
34 | 1. Run `2_gen_dien_input.py` to generate the input data (sampling negative examples may take a long time).
35 | 2. Run `train_dien.py`.
36 | 
37 | ### Train DSIN model
38 | 1. Run `2_gen_dsin_input.py` to generate the input data.
39 | 2. Run `train_dsin.py`.
40 | > The loss of DSIN with `bias_encoding=True` may sometimes become NaN on the Advertising Dataset. This remains a confusing problem, since it never occurs in the production environment. We will keep working on it and would appreciate your help.
41 | 
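42 | For reference, the whole pipeline can be driven end to end by a small script along the lines of the sketch below. `run_all.py` is hypothetical and not part of this repository; each step reads the outputs of the previous one (swap in the `din`/`dien` scripts to train those models instead).
43 | 
44 | ```python
45 | # run_all.py -- hypothetical driver; runs preprocessing and DSIN training in order.
46 | import subprocess
47 | 
48 | steps = [
49 |     '0_gen_sampled_data.py',  # sample users, encode cate_id / brand
50 |     '1_gen_sessions.py',      # build per-user historical sessions
51 |     '2_gen_dsin_input.py',    # assemble the DSIN model input
52 |     'train_dsin.py',          # train and evaluate on the held-out last day
53 | ]
54 | for script in steps:
55 |     subprocess.run(['python', script], check=True, cwd='code')
56 | ```
57 | 
58 | ## License
59 | 
60 | This project is licensed under the terms of the Apache-2.0 license. See [LICENSE](./LICENSE) for additional details.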
--------------------------------------------------------------------------------
/code/0_gen_sampled_data.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | from sklearn.preprocessing import LabelEncoder
7 | 
8 | from config import FRAC
9 | 
10 | if __name__ == "__main__":
11 | 
12 |     user = pd.read_csv('../raw_data/user_profile.csv')
13 |     sample = pd.read_csv('../raw_data/raw_sample.csv')
14 | 
15 |     if not os.path.exists('../sampled_data/'):
16 |         os.mkdir('../sampled_data/')
17 | 
18 |     if os.path.exists('../sampled_data/user_profile_' + str(FRAC) + '.pkl') and os.path.exists(
19 |             '../sampled_data/raw_sample_' + str(FRAC) + '.pkl'):
20 |         user_sub = pd.read_pickle(
21 |             '../sampled_data/user_profile_' + str(FRAC) + '.pkl')
22 |         sample_sub = pd.read_pickle(
23 |             '../sampled_data/raw_sample_' + str(FRAC) + '.pkl')
24 |     else:
25 | 
26 |         if FRAC < 1.0:
27 |             user_sub = user.sample(frac=FRAC, random_state=1024)
28 |         else:
29 |             user_sub = user
30 |         sample_sub = sample.loc[sample.user.isin(user_sub.userid.unique())]
31 |         pd.to_pickle(user_sub, '../sampled_data/user_profile_' +
32 |                      str(FRAC) + '.pkl')
33 |         pd.to_pickle(sample_sub, '../sampled_data/raw_sample_' +
34 |                      str(FRAC) + '.pkl')
35 | 
36 |     if os.path.exists('../raw_data/behavior_log_pv.pkl'):
37 |         log = pd.read_pickle('../raw_data/behavior_log_pv.pkl')
38 |     else:
39 |         log = pd.read_csv('../raw_data/behavior_log.csv')
40 |         log = log.loc[log['btag'] == 'pv']  # keep only page-view behaviors
41 |         pd.to_pickle(log, '../raw_data/behavior_log_pv.pkl')
42 | 
43 |     userset = user_sub.userid.unique()
44 |     log = log.loc[log.user.isin(userset)]
45 |     # pd.to_pickle(log, '../sampled_data/behavior_log_pv_user_filter_' + str(FRAC) + '_.pkl')
46 | 
47 |     ad = pd.read_csv('../raw_data/ad_feature.csv')
48 |     ad['brand'] = ad['brand'].fillna(-1)
49 | 
50 |     lbe = LabelEncoder()
51 |     # unique_cate_id = ad['cate_id'].unique()
52 |     # log = log.loc[log.cate.isin(unique_cate_id)]
53 | 
54 |     unique_cate_id = np.concatenate(
55 |         (ad['cate_id'].unique(), log['cate'].unique()))
56 | 
57 |     lbe.fit(unique_cate_id)
58 |     ad['cate_id'] = lbe.transform(ad['cate_id']) + 1
59 |     log['cate'] = lbe.transform(log['cate']) + 1
60 | 
61 |     lbe = LabelEncoder()
62 |     # unique_brand = ad['brand'].unique()
63 |     # log = log.loc[log.brand.isin(unique_brand)]
64 | 
65 |     unique_brand = np.concatenate(
66 |         (ad['brand'].unique(), log['brand'].unique()))
67 | 
68 |     lbe.fit(unique_brand)
69 |     ad['brand'] = lbe.transform(ad['brand']) + 1
70 |     log['brand'] = lbe.transform(log['brand']) + 1
71 | 
72 |     log = log.loc[log.user.isin(sample_sub.user.unique())]
73 |     log.drop(columns=['btag'], inplace=True)
74 |     log = log.loc[log['time_stamp'] > 0]
75 | 
76 |     pd.to_pickle(ad, '../sampled_data/ad_feature_enc_' + str(FRAC) + '.pkl')
77 |     pd.to_pickle(
78 |         log, '../sampled_data/behavior_log_pv_user_filter_enc_' + str(FRAC) + '.pkl')
79 | 
80 |     print("0_gen_sampled_data done")
81 | 
--------------------------------------------------------------------------------
/code/1_gen_sessions.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import gc
3 | 
4 | import pandas as pd
5 | from joblib import Parallel, delayed
6 | 
7 | from config import FRAC
8 | 
9 | 
10 | def gen_session_list_dsin(uid, t):
11 |     t.sort_values('time_stamp', inplace=True, ascending=True)
12 |     last_time = 1483574401  # pd.to_datetime("2017-01-05 00:00:01")
13 |     session_list = []
14 |     session = []
15 |     for row in t.iterrows():
16 |         time_stamp = row[1]['time_stamp']
17 |         # pd_time = pd.to_datetime(timestamp_datetime(time_stamp))
18 |         delta = time_stamp - last_time
19 |         cate_id = row[1]['cate']
20 |         brand = row[1]['brand']
21 |         # delta.total_seconds()
22 |         if delta > 30 * 60:  # a new session begins when consecutive behaviors are more than 30 minutes apart (see the toy example at the end of this file)
23 |             if len(session) > 2:  # keep only sessions with more than 2 behaviors
24 |                 session_list.append(session[:])
25 |             session = []
26 | 
27 |         session.append((cate_id, brand, time_stamp))
28 |         last_time = time_stamp
29 |     if len(session) > 2:
30 |         session_list.append(session[:])
31 |     return uid, session_list
32 | 
33 | 
34 | def gen_session_list_din(uid, t):
35 |     t.sort_values('time_stamp', inplace=True, ascending=True)
36 |     session_list = []
37 |     session = []
38 |     for row in t.iterrows():
39 |         time_stamp = row[1]['time_stamp']
40 |         # pd_time = pd.to_datetime(timestamp_datetime())
41 |         # delta = pd_time - last_time
42 |         cate_id = row[1]['cate']
43 |         brand = row[1]['brand']
44 |         session.append((cate_id, brand, time_stamp))
45 | 
46 |     if len(session) > 2:
47 |         session_list.append(session[:])
48 |     return uid, session_list
49 | 
50 | 
51 | def applyParallel(df_grouped, func, n_jobs, backend='multiprocessing'):
52 |     """Apply func to each group in parallel with joblib."""  # backend='threading'
53 |     results = Parallel(n_jobs=n_jobs, verbose=4, backend=backend)(
54 |         delayed(func)(name, group) for name, group in df_grouped)
55 | 
56 |     return {k: v for k, v in results}
57 | 
58 | 
59 | def gen_user_hist_sessions(model, FRAC=0.25):
60 |     if model not in ['din', 'dsin']:
61 |         raise ValueError('model must be din or dsin')
62 | 
63 |     print("gen " + model + " hist sess", FRAC)
64 |     name = '../sampled_data/behavior_log_pv_user_filter_enc_' + str(FRAC) + '.pkl'
65 |     data = pd.read_pickle(name)
66 |     data = data.loc[data.time_stamp >= 1493769600]  # 0503-0513
67 |     # 0504~1493856000
68 |     # 0503 1493769600
69 | 
70 |     user = pd.read_pickle('../sampled_data/user_profile_' + str(FRAC) + '.pkl')
71 | 
72 |     n_samples = user.shape[0]
73 |     print(n_samples)
74 |     batch_size = 150000
75 |     iters = (n_samples - 1) // batch_size + 1
76 | 
77 |     print("total", iters, "iters", "batch_size", batch_size)
78 |     for i in range(0, iters):
79 |         target_user = user['userid'].values[i * batch_size:(i + 1) * batch_size]
80 |         sub_data = data.loc[data.user.isin(target_user)]
81 |         print(i, 'iter start')
82 |         df_grouped = sub_data.groupby('user')
83 |         if model == 'din':
84 |             user_hist_session = applyParallel(
85 |                 df_grouped, gen_session_list_din, n_jobs=10, backend='loky')
86 |         else:
87 |             user_hist_session = applyParallel(
88 |                 df_grouped, gen_session_list_dsin, n_jobs=10, backend='multiprocessing')
89 |         pd.to_pickle(user_hist_session, '../sampled_data/user_hist_session_' +
90 |                      str(FRAC) + '_' + model + '_' + str(i) + '.pkl')
91 |         print(i, 'pickled')
92 |         del user_hist_session
93 |         gc.collect()
94 |         print(i, 'del')
95 | 
96 |     print("1_gen " + model + " hist sess done")
97 | 
98 | 
99 | if __name__ == "__main__":
100 |     gen_user_hist_sessions('din', FRAC)
101 |     gen_user_hist_sessions('dsin', FRAC)
102 | 
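103 | 
104 | # A toy illustration of the 30-minute rule above (a sketch; the timestamps and
105 | # ids are made up, not taken from the dataset):
106 | #
107 | #   >>> toy = pd.DataFrame({'time_stamp': [0, 60, 120, 7200, 7260, 7320, 7380],
108 | #   ...                     'cate': [1, 1, 2, 3, 3, 4, 4],
109 | #   ...                     'brand': [5, 5, 6, 7, 7, 8, 8]})
110 | #   >>> _, sessions = gen_session_list_dsin(42, toy)
111 | #   >>> len(sessions)   # one gap > 30 min (120 -> 7200) splits the log in two,
112 | #   2                   # and both halves have > 2 behaviors, so both are kept
--------------------------------------------------------------------------------
/code/2_gen_dien_input.py: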
-------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat 8 | from sklearn.preprocessing import LabelEncoder, StandardScaler 9 | from tensorflow.python.keras.preprocessing.sequence import pad_sequences 10 | from tqdm import tqdm 11 | 12 | from config import DIN_SESS_MAX_LEN, FRAC, ID_OFFSET 13 | 14 | 15 | def gen_sess_feature_din(row): 16 | sess_max_len = DIN_SESS_MAX_LEN 17 | sess_input_dict = {'cate_id': [0], 'brand': [0]} 18 | sess_input_length = 0 19 | user, time_stamp = row[1]['user'], row[1]['time_stamp'] 20 | if user not in user_hist_session or len(user_hist_session[user]) == 0: 21 | 22 | sess_input_dict['cate_id'] = [0] 23 | sess_input_dict['brand'] = [0] 24 | sess_input_length = 0 25 | else: 26 | cur_sess = user_hist_session[user][0] 27 | for i in reversed(range(len(cur_sess))): 28 | if cur_sess[i][2] < time_stamp: 29 | sess_input_dict['cate_id'] = [e[0] 30 | for e in cur_sess[max(0, i + 1 - sess_max_len):i + 1]] 31 | sess_input_dict['brand'] = [e[1] 32 | for e in cur_sess[max(0, i + 1 - sess_max_len):i + 1]] 33 | sess_input_length = len(sess_input_dict['brand']) 34 | break 35 | return sess_input_dict['cate_id'], sess_input_dict['brand'], sess_input_length 36 | 37 | 38 | def sample(cate_id): 39 | global ad 40 | while True: 41 | i = np.random.randint(0, ad.shape[0]) 42 | sample_cate = ad.iloc[i]['cate_id'] 43 | if sample_cate != cate_id: 44 | break 45 | return sample_cate, ad.iloc[i]['brand'] 46 | 47 | 48 | def gen_sess_feature_dien(row): 49 | sess_max_len = DIN_SESS_MAX_LEN 50 | sess_input_dict = {'cate_id': [0], 'brand': [0]} 51 | neg_sess_input_dict = {'cate_id': [0], 'brand': [0]} 52 | sess_input_length = 0 53 | user, time_stamp = row[1]['user'], row[1]['time_stamp'] 54 | if user not in user_hist_session or len(user_hist_session[user]) == 0: 55 | 56 | sess_input_dict['cate_id'] = [0] 57 | sess_input_dict['brand'] = [0] 58 | neg_sess_input_dict['cate_id'] = [0] 59 | neg_sess_input_dict['brand'] = [0] 60 | sess_input_length = 0 61 | else: 62 | cur_sess = user_hist_session[user][0] 63 | for i in reversed(range(len(cur_sess))): 64 | if cur_sess[i][2] < time_stamp: 65 | sess_input_dict['cate_id'] = [e[0] 66 | for e in cur_sess[max(0, i + 1 - sess_max_len):i + 1]] 67 | sess_input_dict['brand'] = [e[1] 68 | for e in cur_sess[max(0, i + 1 - sess_max_len):i + 1]] 69 | 70 | neg_sess_input_dict = {'cate_id': [], 'brand': []} 71 | 72 | for c in sess_input_dict['cate_id']: 73 | neg_cate, neg_brand = sample(c) 74 | neg_sess_input_dict['cate_id'].append(neg_cate) 75 | neg_sess_input_dict['brand'].append(neg_brand) 76 | 77 | sess_input_length = len(sess_input_dict['brand']) 78 | break 79 | return sess_input_dict['cate_id'], sess_input_dict['brand'], neg_sess_input_dict['cate_id'], neg_sess_input_dict[ 80 | 'brand'], sess_input_length 81 | 82 | 83 | if __name__ == "__main__": 84 | 85 | user_hist_session = {} 86 | FILE_NUM = len( 87 | list( 88 | filter(lambda x: x.startswith('user_hist_session_' + str(FRAC) + '_din_'), os.listdir('../sampled_data/')))) 89 | 90 | print('total', FILE_NUM, 'files') 91 | for i in range(FILE_NUM): 92 | user_hist_session_ = pd.read_pickle( 93 | '../sampled_data/user_hist_session_' + str(FRAC) + '_din_' + str(i) + '.pkl') 94 | user_hist_session.update(user_hist_session_) 95 | del user_hist_session_ 96 | 97 | sample_sub = pd.read_pickle( 98 | '../sampled_data/raw_sample_' + 
str(FRAC) + '.pkl') 99 | 100 | ad = pd.read_pickle('../sampled_data/ad_feature_enc_' + str(FRAC) + '.pkl') 101 | 102 | sess_input_dict = {'cate_id': [], 'brand': []} 103 | neg_sess_input_dict = {'cate_id': [], 'brand': []} 104 | sess_input_length = [] 105 | for row in tqdm(sample_sub[['user', 'time_stamp']].iterrows()): 106 | a, b, n_a, n_b, c = gen_sess_feature_dien(row) 107 | sess_input_dict['cate_id'].append(a) 108 | sess_input_dict['brand'].append(b) 109 | neg_sess_input_dict['cate_id'].append(n_a) 110 | neg_sess_input_dict['brand'].append(n_b) 111 | sess_input_length.append(c) 112 | 113 | print('done') 114 | 115 | user = pd.read_pickle('../sampled_data/user_profile_' + str(FRAC) + '.pkl') 116 | ad = pd.read_pickle('../sampled_data/ad_feature_enc_' + str(FRAC) + '.pkl') 117 | user = user.fillna(-1) 118 | user.rename( 119 | columns={'new_user_class_level ': 'new_user_class_level'}, inplace=True) 120 | 121 | sample_sub = pd.read_pickle( 122 | '../sampled_data/raw_sample_' + str(FRAC) + '.pkl') 123 | sample_sub.rename(columns={'user': 'userid'}, inplace=True) 124 | 125 | data = pd.merge(sample_sub, user, how='left', on='userid', ) 126 | data = pd.merge(data, ad, how='left', on='adgroup_id') 127 | 128 | sparse_features = ['userid', 'adgroup_id', 'pid', 'cms_segid', 'cms_group_id', 'final_gender_code', 'age_level', 129 | 'pvalue_level', 'shopping_level', 'occupation', 'new_user_class_level', 'campaign_id', 130 | 'customer'] 131 | dense_features = ['price'] 132 | 133 | for feat in tqdm(sparse_features): 134 | lbe = LabelEncoder() # or Hash 135 | data[feat] = lbe.fit_transform(data[feat]) 136 | mms = StandardScaler() 137 | data[dense_features] = mms.fit_transform(data[dense_features]) 138 | 139 | sparse_feature_list = [SparseFeat(feat, vocabulary_size=data[feat].max( 140 | ) + ID_OFFSET) for feat in sparse_features + ['cate_id', 'brand']] 141 | 142 | dense_feature_list = [DenseFeat(feat, dimension=1) for feat in dense_features] 143 | sess_feature = ['cate_id', 'brand'] 144 | 145 | feature_dict = {} 146 | for feat in sparse_feature_list + dense_feature_list: 147 | feature_dict[feat.name] = data[feat.name].values 148 | for feat in sess_feature: 149 | feature_dict['hist_' + feat] = pad_sequences( 150 | sess_input_dict[feat], maxlen=DIN_SESS_MAX_LEN, padding='post') 151 | feature_dict['neg_hist_' + feat] = pad_sequences( 152 | neg_sess_input_dict[feat], maxlen=DIN_SESS_MAX_LEN, padding='post') 153 | feature_dict["seq_length"] = np.array(sess_input_length) 154 | 155 | sparse_feature_list += [ 156 | VarLenSparseFeat(SparseFeat('hist_cate_id', vocabulary_size=data['cate_id'].max( 157 | ) + ID_OFFSET, embedding_name='cate_id'), maxlen=DIN_SESS_MAX_LEN, length_name="seq_length"), 158 | VarLenSparseFeat(SparseFeat('hist_brand', vocabulary_size=data['brand'].max( 159 | ) + ID_OFFSET, embedding_name='brand'), maxlen=DIN_SESS_MAX_LEN, length_name="seq_length"), 160 | 161 | VarLenSparseFeat(SparseFeat('neg_hist_cate_id', vocabulary_size=data['cate_id'].max( 162 | ) + ID_OFFSET, embedding_name='cate_id'), maxlen=DIN_SESS_MAX_LEN, length_name="seq_length"), 163 | VarLenSparseFeat(SparseFeat('neg_hist_brand', vocabulary_size=data['brand'].max( 164 | ) + ID_OFFSET, embedding_name='brand'), maxlen=DIN_SESS_MAX_LEN, length_name="seq_length") 165 | ] 166 | 167 | feature_columns = sparse_feature_list + dense_feature_list 168 | model_input = feature_dict 169 | 170 | if not os.path.exists('../model_input/'): 171 | os.mkdir('../model_input/') 172 | 173 | pd.to_pickle(model_input, '../model_input/dien_input_' + 174 | 
str(FRAC) + '_' + str(DIN_SESS_MAX_LEN) + '.pkl') 175 | pd.to_pickle(data['clk'].values, '../model_input/dien_label_' + 176 | str(FRAC) + '_' + str(DIN_SESS_MAX_LEN) + '.pkl') 177 | pd.to_pickle(feature_columns, 178 | '../model_input/dien_fd_' + str(FRAC) + '_' + str(DIN_SESS_MAX_LEN) + '.pkl', ) 179 | 180 | print("gen dien input done") 181 | -------------------------------------------------------------------------------- /code/2_gen_din_input.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from config import DIN_SESS_MAX_LEN, FRAC, ID_OFFSET 8 | from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat 9 | from sklearn.preprocessing import LabelEncoder, StandardScaler 10 | from tensorflow.python.keras.preprocessing.sequence import pad_sequences 11 | from tqdm import tqdm 12 | 13 | 14 | def gen_sess_feature_din(row): 15 | sess_max_len = DIN_SESS_MAX_LEN 16 | sess_input_dict = {'cate_id': [0], 'brand': [0]} 17 | sess_input_length = 0 18 | user, time_stamp = row[1]['user'], row[1]['time_stamp'] 19 | if user not in user_hist_session or len(user_hist_session[user]) == 0: 20 | 21 | sess_input_dict['cate_id'] = [0] 22 | sess_input_dict['brand'] = [0] 23 | sess_input_length = 0 24 | else: 25 | cur_sess = user_hist_session[user][0] 26 | for i in reversed(range(len(cur_sess))): 27 | if cur_sess[i][2] < time_stamp: 28 | sess_input_dict['cate_id'] = [e[0] 29 | for e in cur_sess[max(0, i + 1 - sess_max_len):i + 1]] 30 | sess_input_dict['brand'] = [e[1] 31 | for e in cur_sess[max(0, i + 1 - sess_max_len):i + 1]] 32 | sess_input_length = len(sess_input_dict['brand']) 33 | break 34 | return sess_input_dict['cate_id'], sess_input_dict['brand'], sess_input_length 35 | 36 | 37 | if __name__ == "__main__": 38 | 39 | user_hist_session = {} 40 | FILE_NUM = len( 41 | list( 42 | filter(lambda x: x.startswith('user_hist_session_' + str(FRAC) + '_din_'), os.listdir('../sampled_data/')))) 43 | 44 | print('total', FILE_NUM, 'files') 45 | for i in range(FILE_NUM): 46 | user_hist_session_ = pd.read_pickle( 47 | '../sampled_data/user_hist_session_' + str(FRAC) + '_din_' + str(i) + '.pkl') 48 | user_hist_session.update(user_hist_session_) 49 | del user_hist_session_ 50 | 51 | sample_sub = pd.read_pickle( 52 | '../sampled_data/raw_sample_' + str(FRAC) + '.pkl') 53 | 54 | sess_input_dict = {'cate_id': [], 'brand': []} 55 | sess_input_length = [] 56 | for row in tqdm(sample_sub[['user', 'time_stamp']].iterrows()): 57 | a, b, c = gen_sess_feature_din(row) 58 | sess_input_dict['cate_id'].append(a) 59 | sess_input_dict['brand'].append(b) 60 | sess_input_length.append(c) 61 | 62 | print('done') 63 | 64 | user = pd.read_pickle('../sampled_data/user_profile_' + str(FRAC) + '.pkl') 65 | ad = pd.read_pickle('../sampled_data/ad_feature_enc_' + str(FRAC) + '.pkl') 66 | user = user.fillna(-1) 67 | user.rename( 68 | columns={'new_user_class_level ': 'new_user_class_level'}, inplace=True) 69 | 70 | sample_sub = pd.read_pickle( 71 | '../sampled_data/raw_sample_' + str(FRAC) + '.pkl') 72 | sample_sub.rename(columns={'user': 'userid'}, inplace=True) 73 | 74 | data = pd.merge(sample_sub, user, how='left', on='userid', ) 75 | data = pd.merge(data, ad, how='left', on='adgroup_id') 76 | 77 | sparse_features = ['userid', 'adgroup_id', 'pid', 'cms_segid', 'cms_group_id', 'final_gender_code', 'age_level', 78 | 'pvalue_level', 'shopping_level', 'occupation', 'new_user_class_level', 'campaign_id', 79 | 
'customer'] 80 | dense_features = ['price'] 81 | 82 | for feat in tqdm(sparse_features): 83 | lbe = LabelEncoder() # or Hash 84 | data[feat] = lbe.fit_transform(data[feat]) 85 | mms = StandardScaler() 86 | data[dense_features] = mms.fit_transform(data[dense_features]) 87 | 88 | sparse_feature_list = [SparseFeat(feat, vocabulary_size=data[feat].max( 89 | ) + ID_OFFSET) for feat in sparse_features + ['cate_id', 'brand']] 90 | 91 | dense_feature_list = [DenseFeat(feat, dimension=1) for feat in dense_features] 92 | 93 | sess_feature = ['cate_id', 'brand'] 94 | 95 | feature_dict = {} 96 | for feat in sparse_feature_list + dense_feature_list: 97 | feature_dict[feat.name] = data[feat.name].values 98 | for feat in sess_feature: 99 | feature_dict['hist_' + feat] = pad_sequences( 100 | sess_input_dict[feat], maxlen=DIN_SESS_MAX_LEN, padding='post') 101 | sparse_feature_list += [ 102 | VarLenSparseFeat(SparseFeat('hist_cate_id', vocabulary_size=data['cate_id'].max( 103 | ) + ID_OFFSET, embedding_name='cate_id'), maxlen=DIN_SESS_MAX_LEN), 104 | VarLenSparseFeat(SparseFeat('hist_brand', vocabulary_size=data['brand'].max( 105 | ) + ID_OFFSET, embedding_name='brand'), maxlen=DIN_SESS_MAX_LEN)] 106 | feature_columns = sparse_feature_list + dense_feature_list 107 | model_input = feature_dict 108 | if not os.path.exists('../model_input/'): 109 | os.mkdir('../model_input/') 110 | 111 | pd.to_pickle(model_input, '../model_input/din_input_' + 112 | str(FRAC) + '_' + str(DIN_SESS_MAX_LEN) + '.pkl') 113 | pd.to_pickle([np.array(sess_input_length)], '../model_input/din_input_len_' + 114 | str(FRAC) + '_' + str(DIN_SESS_MAX_LEN) + '.pkl') 115 | 116 | pd.to_pickle(data['clk'].values, '../model_input/din_label_' + 117 | str(FRAC) + '_' + str(DIN_SESS_MAX_LEN) + '.pkl') 118 | pd.to_pickle(feature_columns, 119 | '../model_input/din_fd_' + str(FRAC) + '_' + str(DIN_SESS_MAX_LEN) + '.pkl', ) 120 | 121 | print("gen din input done") 122 | -------------------------------------------------------------------------------- /code/2_gen_dsin_input.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from config import DSIN_SESS_COUNT, DSIN_SESS_MAX_LEN, FRAC, ID_OFFSET 8 | from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat 9 | from sklearn.preprocessing import LabelEncoder, StandardScaler 10 | from tensorflow.python.keras.preprocessing.sequence import pad_sequences 11 | from tqdm import tqdm 12 | 13 | FRAC = FRAC 14 | SESS_COUNT = DSIN_SESS_COUNT 15 | 16 | 17 | def gen_sess_feature_dsin(row): 18 | sess_count = DSIN_SESS_COUNT 19 | sess_max_len = DSIN_SESS_MAX_LEN 20 | sess_input_dict = {} 21 | sess_input_length_dict = {} 22 | for i in range(sess_count): 23 | sess_input_dict['sess_' + str(i)] = {'cate_id': [], 'brand': []} 24 | sess_input_length_dict['sess_' + str(i)] = 0 25 | sess_length = 0 26 | user, time_stamp = row[1]['user'], row[1]['time_stamp'] 27 | # sample_time = pd.to_datetime(timestamp_datetime(time_stamp )) 28 | if user not in user_hist_session: 29 | for i in range(sess_count): 30 | sess_input_dict['sess_' + str(i)]['cate_id'] = [0] 31 | sess_input_dict['sess_' + str(i)]['brand'] = [0] 32 | sess_input_length_dict['sess_' + str(i)] = 0 33 | sess_length = 0 34 | else: 35 | valid_sess_count = 0 36 | last_sess_idx = len(user_hist_session[user]) - 1 37 | for i in reversed(range(len(user_hist_session[user]))): 38 | cur_sess = user_hist_session[user][i] 39 | if cur_sess[0][2] < 
time_stamp: 40 | in_sess_count = 1 41 | for j in range(1, len(cur_sess)): 42 | if cur_sess[j][2] < time_stamp: 43 | in_sess_count += 1 44 | if in_sess_count > 2: 45 | sess_input_dict['sess_0']['cate_id'] = [e[0] for e in cur_sess[max(0, 46 | in_sess_count - sess_max_len):in_sess_count]] 47 | sess_input_dict['sess_0']['brand'] = [e[1] for e in 48 | cur_sess[max(0, in_sess_count - sess_max_len):in_sess_count]] 49 | sess_input_length_dict['sess_0'] = min( 50 | sess_max_len, in_sess_count) 51 | last_sess_idx = i 52 | valid_sess_count += 1 53 | break 54 | for i in range(1, sess_count): 55 | if last_sess_idx - i >= 0: 56 | cur_sess = user_hist_session[user][last_sess_idx - i] 57 | sess_input_dict['sess_' + str(i)]['cate_id'] = [e[0] 58 | for e in cur_sess[-sess_max_len:]] 59 | sess_input_dict['sess_' + str(i)]['brand'] = [e[1] 60 | for e in cur_sess[-sess_max_len:]] 61 | sess_input_length_dict['sess_' + 62 | str(i)] = min(sess_max_len, len(cur_sess)) 63 | valid_sess_count += 1 64 | else: 65 | sess_input_dict['sess_' + str(i)]['cate_id'] = [0] 66 | sess_input_dict['sess_' + str(i)]['brand'] = [0] 67 | sess_input_length_dict['sess_' + str(i)] = 0 68 | 69 | sess_length = valid_sess_count 70 | return sess_input_dict, sess_input_length_dict, sess_length 71 | 72 | 73 | if __name__ == "__main__": 74 | 75 | user_hist_session = {} 76 | FILE_NUM = len( 77 | list(filter(lambda x: x.startswith('user_hist_session_' + str(FRAC) + '_dsin_'), 78 | os.listdir('../sampled_data/')))) 79 | 80 | print('total', FILE_NUM, 'files') 81 | 82 | for i in range(FILE_NUM): 83 | user_hist_session_ = pd.read_pickle( 84 | '../sampled_data/user_hist_session_' + str(FRAC) + '_dsin_' + str(i) + '.pkl') # 19,34 85 | user_hist_session.update(user_hist_session_) 86 | del user_hist_session_ 87 | 88 | sample_sub = pd.read_pickle( 89 | '../sampled_data/raw_sample_' + str(FRAC) + '.pkl') 90 | 91 | index_list = [] 92 | sess_input_dict = {} 93 | sess_input_length_dict = {} 94 | for i in range(SESS_COUNT): 95 | sess_input_dict['sess_' + str(i)] = {'cate_id': [], 'brand': []} 96 | sess_input_length_dict['sess_' + str(i)] = [] 97 | 98 | sess_length_list = [] 99 | for row in tqdm(sample_sub[['user', 'time_stamp']].iterrows()): 100 | sess_input_dict_, sess_input_length_dict_, sess_length = gen_sess_feature_dsin( 101 | row) 102 | # index_list.append(index) 103 | for i in range(SESS_COUNT): 104 | sess_name = 'sess_' + str(i) 105 | sess_input_dict[sess_name]['cate_id'].append( 106 | sess_input_dict_[sess_name]['cate_id']) 107 | sess_input_dict[sess_name]['brand'].append( 108 | sess_input_dict_[sess_name]['brand']) 109 | sess_input_length_dict[sess_name].append( 110 | sess_input_length_dict_[sess_name]) 111 | sess_length_list.append(sess_length) 112 | 113 | print('done') 114 | 115 | user = pd.read_pickle('../sampled_data/user_profile_' + str(FRAC) + '.pkl') 116 | ad = pd.read_pickle('../sampled_data/ad_feature_enc_' + str(FRAC) + '.pkl') 117 | user = user.fillna(-1) 118 | user.rename( 119 | columns={'new_user_class_level ': 'new_user_class_level'}, inplace=True) 120 | 121 | sample_sub = pd.read_pickle( 122 | '../sampled_data/raw_sample_' + str(FRAC) + '.pkl') 123 | sample_sub.rename(columns={'user': 'userid'}, inplace=True) 124 | 125 | data = pd.merge(sample_sub, user, how='left', on='userid', ) 126 | data = pd.merge(data, ad, how='left', on='adgroup_id') 127 | 128 | sparse_features = ['userid', 'adgroup_id', 'pid', 'cms_segid', 'cms_group_id', 'final_gender_code', 'age_level', 129 | 'pvalue_level', 'shopping_level', 'occupation', 
'new_user_class_level', 'campaign_id',
130 |                        'customer']
131 | 
132 |     dense_features = ['price']
133 | 
134 |     for feat in tqdm(sparse_features):
135 |         lbe = LabelEncoder()  # or Hash
136 |         data[feat] = lbe.fit_transform(data[feat])
137 |     mms = StandardScaler()
138 |     data[dense_features] = mms.fit_transform(data[dense_features])
139 | 
140 |     sparse_feature_list = [SparseFeat(feat, vocabulary_size=data[feat].max(
141 |     ) + ID_OFFSET) for feat in sparse_features + ['cate_id', 'brand']]
142 |     dense_feature_list = [DenseFeat(feat, dimension=1) for feat in dense_features]
143 |     sess_feature = ['cate_id', 'brand']
144 | 
145 |     feature_dict = {}
146 |     for feat in sparse_feature_list + dense_feature_list:
147 |         feature_dict[feat.name] = data[feat.name].values
148 |     for i in tqdm(range(SESS_COUNT)):
149 |         sess_name = 'sess_' + str(i)
150 |         for feat in sess_feature:
151 |             feature_dict[sess_name + '_' + feat] = pad_sequences(
152 |                 sess_input_dict[sess_name][feat], maxlen=DSIN_SESS_MAX_LEN, padding='post')
153 |             sparse_feature_list.append(
154 |                 VarLenSparseFeat(SparseFeat(sess_name + '_' + feat, vocabulary_size=data[feat].max(
155 |                 ) + ID_OFFSET, embedding_name=feat),  # share the base cate_id/brand embedding
156 |                                  maxlen=DSIN_SESS_MAX_LEN))
157 |     feature_dict['sess_length'] = np.array(sess_length_list)
158 | 
159 |     feature_columns = sparse_feature_list + dense_feature_list
160 |     model_input = feature_dict
161 | 
162 |     if not os.path.exists('../model_input/'):
163 |         os.mkdir('../model_input/')
164 | 
165 |     pd.to_pickle(model_input, '../model_input/dsin_input_' +
166 |                  str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
167 |     pd.to_pickle(data['clk'].values, '../model_input/dsin_label_' +
168 |                  str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
169 |     pd.to_pickle(feature_columns,
170 |                  '../model_input/dsin_fd_' + str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
171 |     print("gen dsin input done")
172 | 
--------------------------------------------------------------------------------
/code/config.py:
--------------------------------------------------------------------------------
1 | FRAC = 0.25
2 | 
3 | DIN_SESS_MAX_LEN = 50
4 | 
5 | DSIN_SESS_COUNT = 5
6 | DSIN_SESS_MAX_LEN = 10
7 | ID_OFFSET = 1000
--------------------------------------------------------------------------------
/code/train_dien.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | 
4 | import pandas as pd
5 | import tensorflow as tf
6 | from sklearn.metrics import log_loss, roc_auc_score
7 | from tensorflow.python.keras import backend as K
8 | 
9 | from config import DIN_SESS_MAX_LEN, FRAC
10 | from deepctr.models import DIEN
11 | 
12 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
13 | tfconfig = tf.ConfigProto()
14 | tfconfig.gpu_options.allow_growth = True
15 | K.set_session(tf.Session(config=tfconfig))
16 | 
17 | if __name__ == "__main__":
18 |     DIEN_NEG_SAMPLING = True
19 |     FRAC = FRAC
20 |     SESS_MAX_LEN = DIN_SESS_MAX_LEN
21 |     dnn_feature_columns = pd.read_pickle('../model_input/dien_fd_' +
22 |                                          str(FRAC) + '_' + str(SESS_MAX_LEN) + '.pkl')
23 |     model_input = pd.read_pickle(
24 |         '../model_input/dien_input_' + str(FRAC) + '_' + str(SESS_MAX_LEN) + '.pkl')
25 |     label = pd.read_pickle('../model_input/dien_label_' +
26 |                            str(FRAC) + '_' + str(SESS_MAX_LEN) + '.pkl')
27 | 
28 |     sample_sub = pd.read_pickle(
29 |         '../sampled_data/raw_sample_' + str(FRAC) + '.pkl')
30 | 
31 |     sample_sub['idx'] = list(range(sample_sub.shape[0]))
32 |     train_idx = sample_sub.loc[sample_sub.time_stamp <
33 |                                1494633600, 'idx'].values  # 2017-05-13 00:00:00 UTC
34 |     test_idx = 
sample_sub.loc[sample_sub.time_stamp >=
35 |                               1494633600, 'idx'].values  # hold out the final day for testing
36 | 
37 |     train_input = {k: v[train_idx] for k, v in model_input.items()}
38 |     test_input = {k: v[test_idx] for k, v in model_input.items()}
39 | 
40 |     train_label = label[train_idx]
41 |     test_label = label[test_idx]
42 | 
43 |     sess_len_max = SESS_MAX_LEN
44 |     BATCH_SIZE = 4096
45 |     history_feature_list = ['cate_id', 'brand']
46 |     TEST_BATCH_SIZE = 2 ** 14
47 | 
48 |     model = DIEN(dnn_feature_columns, history_feature_list,
49 |                  gru_type="AUGRU", use_negsampling=DIEN_NEG_SAMPLING, dnn_hidden_units=(200, 80),
50 |                  dnn_activation='relu',
51 |                  att_hidden_units=(64, 16))
52 | 
53 |     model.compile('adagrad', 'binary_crossentropy',
54 |                   metrics=['binary_crossentropy', ])
55 | 
56 |     hist_ = model.fit(train_input, train_label, batch_size=BATCH_SIZE,
57 |                       epochs=1, initial_epoch=0, verbose=1, )
58 |     pred_ans = model.predict(test_input, TEST_BATCH_SIZE)
59 | 
60 |     print()
61 | 
62 |     print("test LogLoss", round(log_loss(test_label, pred_ans), 4), "test AUC",
63 |           round(roc_auc_score(test_label, pred_ans), 4))
64 | 
--------------------------------------------------------------------------------
/code/train_din.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | 
4 | import pandas as pd
5 | import tensorflow as tf
6 | from config import DIN_SESS_MAX_LEN, FRAC
7 | from deepctr.models import DIN
8 | from sklearn.metrics import log_loss, roc_auc_score
9 | from tensorflow.python.keras import backend as K
10 | 
11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
12 | tfconfig = tf.ConfigProto()
13 | tfconfig.gpu_options.allow_growth = True
14 | K.set_session(tf.Session(config=tfconfig))
15 | 
16 | if __name__ == "__main__":
17 |     FRAC = FRAC
18 |     SESS_MAX_LEN = DIN_SESS_MAX_LEN
19 |     dnn_feature_columns = pd.read_pickle('../model_input/din_fd_' +
20 |                                          str(FRAC) + '_' + str(SESS_MAX_LEN) + '.pkl')
21 |     model_input = pd.read_pickle(
22 |         '../model_input/din_input_' + str(FRAC) + '_' + str(SESS_MAX_LEN) + '.pkl')
23 |     label = pd.read_pickle('../model_input/din_label_' +
24 |                            str(FRAC) + '_' + str(SESS_MAX_LEN) + '.pkl')
25 | 
26 |     sample_sub = pd.read_pickle(
27 |         '../sampled_data/raw_sample_' + str(FRAC) + '.pkl')
28 | 
29 |     sample_sub['idx'] = list(range(sample_sub.shape[0]))
30 |     train_idx = sample_sub.loc[sample_sub.time_stamp <
31 |                                1494633600, 'idx'].values
32 |     test_idx = sample_sub.loc[sample_sub.time_stamp >=
33 |                               1494633600, 'idx'].values  # same last-day split; decoded in the note below
34 | 
35 |     train_input = {k: v[train_idx] for k, v in model_input.items()}
36 |     test_input = {k: v[test_idx] for k, v in model_input.items()}
37 |     train_label = label[train_idx]
38 |     test_label = label[test_idx]
39 | 
40 |     sess_len_max = SESS_MAX_LEN
41 |     BATCH_SIZE = 4096
42 | 
43 |     sess_feature_list = ['cate_id', 'brand']
44 |     TEST_BATCH_SIZE = 2 ** 14
45 | 
46 |     model = DIN(dnn_feature_columns, sess_feature_list, dnn_hidden_units=(200, 80), dnn_activation='relu',
47 |                 att_hidden_size=(64, 16))
48 | 
49 |     model.compile('adagrad', 'binary_crossentropy',
50 |                   metrics=['binary_crossentropy', ])
51 | 
52 |     hist_ = model.fit(train_input, train_label,
53 |                       batch_size=BATCH_SIZE, epochs=1, initial_epoch=0, verbose=1, )
54 |     pred_ans = model.predict(test_input, TEST_BATCH_SIZE)
55 | 
56 |     print()
57 |     print("test LogLoss", round(log_loss(test_label, pred_ans), 4), "test AUC",
58 |           round(roc_auc_score(test_label, pred_ans), 4))
59 | 
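60 | 
61 | # What the split timestamp above means (a sketch; decoded with the standard
62 | # library, treating the dataset timestamps as UTC, as the scripts here do):
63 | #
64 | #   >>> from datetime import datetime, timezone
65 | #   >>> datetime.fromtimestamp(1494633600, tz=timezone.utc)
66 | #   datetime.datetime(2017, 5, 13, 0, 0, tzinfo=datetime.timezone.utc)
67 | #
68 | # i.e. the earlier days of impressions train the model and the final day
69 | # (2017-05-13) is held out as the test set.
--------------------------------------------------------------------------------
/code/train_dsin.py: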
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | 
4 | import pandas as pd
5 | import tensorflow as tf
6 | from sklearn.metrics import log_loss, roc_auc_score
7 | from tensorflow.python.keras import backend as K
8 | 
9 | from config import DSIN_SESS_COUNT, DSIN_SESS_MAX_LEN, FRAC
10 | from deepctr.models import DSIN
11 | 
12 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
13 | tfconfig = tf.ConfigProto()
14 | tfconfig.gpu_options.allow_growth = True
15 | K.set_session(tf.Session(config=tfconfig))
16 | 
17 | if __name__ == "__main__":
18 |     SESS_COUNT = DSIN_SESS_COUNT
19 |     SESS_MAX_LEN = DSIN_SESS_MAX_LEN
20 | 
21 |     dnn_feature_columns = pd.read_pickle('../model_input/dsin_fd_' +
22 |                                          str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
23 |     model_input = pd.read_pickle(
24 |         '../model_input/dsin_input_' + str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
25 |     label = pd.read_pickle('../model_input/dsin_label_' +
26 |                            str(FRAC) + '_' + str(SESS_COUNT) + '.pkl')
27 | 
28 |     sample_sub = pd.read_pickle(
29 |         '../sampled_data/raw_sample_' + str(FRAC) + '.pkl')
30 | 
31 |     sample_sub['idx'] = list(range(sample_sub.shape[0]))
32 |     train_idx = sample_sub.loc[sample_sub.time_stamp <
33 |                                1494633600, 'idx'].values
34 |     test_idx = sample_sub.loc[sample_sub.time_stamp >=
35 |                               1494633600, 'idx'].values
36 | 
37 |     train_input = {k: v[train_idx] for k, v in model_input.items()}
38 |     test_input = {k: v[test_idx] for k, v in model_input.items()}
39 | 
40 |     train_label = label[train_idx]
41 |     test_label = label[test_idx]
42 | 
43 |     sess_count = SESS_COUNT
44 |     sess_len_max = SESS_MAX_LEN
45 |     BATCH_SIZE = 4096
46 | 
47 |     sess_feature_list = ['cate_id', 'brand']
48 |     TEST_BATCH_SIZE = 2 ** 14
49 | 
50 |     model = DSIN(dnn_feature_columns, sess_feature_list, sess_max_count=sess_count, bias_encoding=False,
51 |                  att_embedding_size=1, att_head_num=8, dnn_hidden_units=(200, 80), dnn_activation='relu',
52 |                  )
53 | 
54 |     model.compile('adagrad', 'binary_crossentropy',
55 |                   metrics=['binary_crossentropy', ])
56 | 
57 |     hist_ = model.fit(train_input, train_label, batch_size=BATCH_SIZE,
58 |                       epochs=1, initial_epoch=0, verbose=1, )
59 | 
60 |     pred_ans = model.predict(test_input, TEST_BATCH_SIZE)
61 | 
62 |     print()
63 |     print("test LogLoss", round(log_loss(test_label, pred_ans), 4), "test AUC",
64 |           round(roc_auc_score(test_label, pred_ans), 4))
65 | 
--------------------------------------------------------------------------------
/raw_data/README.md:
--------------------------------------------------------------------------------
1 | Download the dataset [Ad Display/Click Data on Taobao.com](https://tianchi.aliyun.com/dataset/dataDetail?dataId=56)
2 | 
3 | Extract the files into this directory:
4 | 
5 | - raw_sample.csv
6 | - ad_feature.csv
7 | - user_profile.csv
8 | - behavior_log.csv
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | deepctr==0.9.3
2 | joblib==1.0.1
3 | numpy==1.16.6
4 | pandas==1.1.5
5 | requests==2.27.1
6 | scikit-learn==0.24.2
7 | tensorflow-gpu==1.4.0
8 | tqdm==4.36.1
--------------------------------------------------------------------------------
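A quick sanity check before running `0_gen_sampled_data.py` (a sketch, not a file in the repository): it verifies that the four CSV files listed in raw_data/README.md are in place under `raw_data/`.

    import os

    required = ['raw_sample.csv', 'ad_feature.csv', 'user_profile.csv', 'behavior_log.csv']
    missing = [name for name in required if not os.path.exists(os.path.join('raw_data', name))]
    if missing:
        raise SystemExit('missing raw data files: ' + ', '.join(missing))
    print('raw_data is complete')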