├── .gitignore ├── LICENSE.md ├── README.md ├── data ├── __init__.py ├── batch_data.py ├── data_util.py ├── item_pv_dataloader.py ├── item_pv_dataset.py ├── prod_search_dataloader.py └── prod_search_dataset.py ├── main.py ├── models ├── PV.py ├── PVC.py ├── __init__.py ├── item_transformer.py ├── neural.py ├── optimizers.py ├── ps_model.py ├── text_encoder.py └── transformer.py ├── others ├── __init__.py ├── logging.py └── util.py ├── trainer.py └── tune_para.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | ConvProductSearchNF 3 | src 4 | HEM 5 | *log* 6 | *bak* 7 | *.sh 8 | *.pyc 9 | filter.py 10 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2020] [The ProdSearch Authors] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ProdSearch 2 | Personalized Product Search with Product Reviews 3 | 4 | ## Amazon Search Dataset: 5 | For each user, we sort his/her purchased items by time and divide items to train/validation/test in a chronological order. 
6 | 7 | Download the code and follow the "Data Preparation" section in this [link](https://github.com/QingyaoAi/Explainable-Product-Search-with-a-Dynamic-Relation-Embedding-Model), except for the data splitting in Section 4.3. 8 | Use "python ./utils/AmazonDataset/sequentially_split_train_test_data.py 0.2 0.3" instead. 9 | 10 | ## Train/Test a TEM model [1] 11 | To train a transformer-based embedding model (TEM) [1], run 12 | 13 | ``` 14 | python main.py --model_name item_transformer \ # TEM 15 | --mode train \ # set it to test when evaluating a model (see the example after the references) 16 | --pretrain_emb_dir PATH/TO/PRETRAINED_EMB_DIR \ # directory of the word embeddings pretrained on reviews. If set to "", embeddings are trained from scratch 17 | --data_dir PATH/TO/DATA \ # generated when preparing the data, e.g. Amazon/reviews_Sports_and_Outdoors_5.json.gz.stem.nostop/mincount_5 18 | --input_train_dir PATH/TO/SPLIT_DATA \ # Amazon/reviews_Sports_and_Outdoors_5.json.gz.stem.nostop/min_count5/seq_query_split 19 | --save_dir PATH/TO/SAVE/TRAINED/MODELS \ # where to store or load models 20 | --decay_method adam \ # use the decay method of adam instead of noam 21 | --max_train_epoch 20 --lr 0.0005 --batch_size 384 \ 22 | --uprev_review_limit 20 \ # the number of historically purchased items used for each user 23 | --embedding_size 128 \ 24 | --inter_layers 1 \ # the number of transformer layers 25 | --ff_size 512 --heads 8 # other hyper-parameters that may need tuning 26 | ``` 27 | ## Train/Test a RTM model [2] 28 | To run a review-based transformer model (RTM) [2], simply use a different model_name: 29 | ``` 30 | --model_name review_transformer 31 | ``` 32 | ## References 33 | [1] Keping Bi, Qingyao Ai, W. Bruce Croft. A Transformer-based Embedding Model for Personalized Product Search. In Proceedings of SIGIR'20. 34 | 35 | [2] Keping Bi, Qingyao Ai, W. Bruce Croft. Learning a Fine-Grained Review-based Transformer Model for Personalized Product Search. In Proceedings of SIGIR'21.
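To evaluate a trained TEM model (the test example referenced above), here is a minimal sketch of a test run. It assumes the test mode reuses the same data paths and model hyper-parameters as the training command and loads the saved model from --save_dir; adjust the flags to match whatever you actually trained with.

```
python main.py --model_name item_transformer \
    --mode test \
    --data_dir PATH/TO/DATA \
    --input_train_dir PATH/TO/SPLIT_DATA \
    --save_dir PATH/TO/SAVE/TRAINED/MODELS \
    --uprev_review_limit 20 --embedding_size 128 \
    --inter_layers 1 --ff_size 512 --heads 8
```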
36 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .prod_search_dataset import ProdSearchDataset 2 | from .prod_search_dataloader import ProdSearchDataLoader 3 | from .item_pv_dataset import ItemPVDataset 4 | from .item_pv_dataloader import ItemPVDataloader 5 | -------------------------------------------------------------------------------- /data/batch_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class ItemPVBatch(object): 4 | def __init__(self, query_word_idxs, target_prod_idxs, 5 | u_item_idxs, pos_iword_idxs = [], query_idxs=[], 6 | user_idxs=[], candi_prod_idxs=[], to_tensor=True): #"cpu" or "cuda" 7 | self.query_word_idxs = query_word_idxs 8 | self.target_prod_idxs = target_prod_idxs 9 | self.u_item_idxs = u_item_idxs 10 | self.pos_iword_idxs = pos_iword_idxs 11 | self.query_idxs = query_idxs 12 | self.user_idxs = user_idxs 13 | self.candi_prod_idxs = candi_prod_idxs 14 | if to_tensor: 15 | self.to_tensor() 16 | 17 | def to_tensor(self): 18 | self.query_word_idxs = torch.tensor(self.query_word_idxs) 19 | self.target_prod_idxs = torch.tensor(self.target_prod_idxs) 20 | self.candi_prod_idxs = torch.tensor(self.candi_prod_idxs) 21 | self.u_item_idxs = torch.tensor(self.u_item_idxs) 22 | self.pos_iword_idxs = torch.tensor(self.pos_iword_idxs) 23 | 24 | def to(self, device): 25 | if device == "cpu": 26 | return self 27 | else: 28 | query_word_idxs = self.query_word_idxs.to(device) 29 | target_prod_idxs = self.target_prod_idxs.to(device) 30 | candi_prod_idxs = self.candi_prod_idxs.to(device) 31 | u_item_idxs = self.u_item_idxs.to(device) 32 | pos_iword_idxs = self.pos_iword_idxs.to(device) 33 | 34 | return self.__class__( 35 | query_word_idxs, target_prod_idxs, u_item_idxs, 36 | pos_iword_idxs, self.query_idxs, self.user_idxs, 37 | candi_prod_idxs, to_tensor=False) 38 | 39 | class ItemPVBatch_seg(object): 40 | def __init__(self, query_word_idxs, target_prod_idxs, candi_prod_idxs, pos_seg_idxs, 41 | neg_seg_idxs, pos_seq_item_idxs, neg_seq_item_idxs, 42 | pos_iword_idxs = None, query_idxs=None, 43 | user_idxs=None, to_tensor=True): #"cpu" or "cuda" 44 | self.query_idxs = query_idxs 45 | self.target_prod_idxs = target_prod_idxs 46 | self.candi_prod_idxs = candi_prod_idxs 47 | self.user_idxs = user_idxs 48 | self.query_word_idxs = query_word_idxs 49 | self.pos_seg_idxs = pos_seg_idxs 50 | self.neg_seg_idxs = neg_seg_idxs 51 | self.pos_seq_item_idxs = pos_seq_item_idxs 52 | self.neg_seq_item_idxs = neg_seq_item_idxs 53 | self.pos_iword_idxs = pos_iword_idxs 54 | if to_tensor: 55 | self.to_tensor() 56 | 57 | def to_tensor(self): 58 | self.query_word_idxs = torch.tensor(self.query_word_idxs) 59 | self.target_prod_idxs = torch.tensor(self.target_prod_idxs) 60 | #self.candi_prod_idxs = torch.tensor(self.candi_prod_idxs) 61 | self.neg_seg_idxs = torch.tensor(self.neg_seg_idxs) 62 | self.neg_seq_item_idxs = torch.tensor(self.neg_seq_item_idxs) 63 | if self.pos_seg_idxs is not None: 64 | self.pos_seg_idxs = torch.tensor(self.pos_seg_idxs) 65 | if self.pos_seq_item_idxs is not None: 66 | self.pos_seq_item_idxs = torch.tensor(self.pos_seq_item_idxs) 67 | if self.pos_iword_idxs is not None: 68 | self.pos_iword_idxs = torch.tensor(self.pos_iword_idxs) 69 | 70 | def to(self, device): 71 | if device == "cpu": 72 | return self 73 | else: 74 | query_word_idxs = self.query_word_idxs.to(device) 75 | 
target_prod_idxs = self.target_prod_idxs.to(device) 76 | #candi_prod_idxs = self.candi_prod_idxs.to(device) 77 | neg_seg_idxs = self.neg_seg_idxs.to(device) 78 | neg_seq_item_idxs = self.neg_seq_item_idxs.to(device) 79 | pos_seg_idxs = self.pos_seg_idxs 80 | if self.pos_seg_idxs is not None: 81 | pos_seg_idxs = self.pos_seg_idxs.to(device) 82 | pos_seq_item_idxs = self.pos_seq_item_idxs 83 | if self.pos_seq_item_idxs is not None: 84 | pos_seq_item_idxs = self.pos_seq_item_idxs.to(device) 85 | pos_iword_idxs = self.pos_iword_idxs 86 | if self.pos_iword_idxs is not None: 87 | pos_iword_idxs = self.pos_iword_idxs.to(device) 88 | 89 | return self.__class__( 90 | query_word_idxs, target_prod_idxs, self.candi_prod_idxs, pos_seg_idxs, 91 | neg_seg_idxs, pos_seq_item_idxs, neg_seq_item_idxs, 92 | pos_iword_idxs, self.query_idxs, self.user_idxs, to_tensor=False) 93 | 94 | class ProdSearchTestBatch(object): 95 | def __init__(self, query_idxs, user_idxs, target_prod_idxs, candi_prod_idxs, query_word_idxs, 96 | candi_prod_ridxs, candi_seg_idxs, candi_seq_user_idxs, candi_seq_item_idxs, to_tensor=True): #"cpu" or "cuda" 97 | self.query_idxs = query_idxs 98 | #for output query, user to ranklist 99 | self.user_idxs = user_idxs 100 | self.target_prod_idxs = target_prod_idxs 101 | self.candi_prod_idxs = candi_prod_idxs 102 | self.candi_seq_user_idxs = candi_seq_user_idxs 103 | self.candi_seq_item_idxs = candi_seq_item_idxs 104 | 105 | self.query_word_idxs = query_word_idxs 106 | self.candi_prod_ridxs = candi_prod_ridxs 107 | self.candi_seg_idxs = candi_seg_idxs 108 | if to_tensor: 109 | self.to_tensor() 110 | 111 | def to_tensor(self): 112 | #self.query_idxs = torch.tensor(query_idxs) 113 | #self.user_idxs = torch.tensor(user_idxs) 114 | #self.target_prod_idxs = torch.tensor(target_prod_idxs) 115 | #self.candi_prod_idxs = torch.tensor(candi_prod_idxs) 116 | self.query_word_idxs = torch.tensor(self.query_word_idxs) 117 | self.candi_prod_ridxs = torch.tensor(self.candi_prod_ridxs) 118 | self.candi_seg_idxs = torch.tensor(self.candi_seg_idxs) 119 | self.candi_seq_user_idxs = torch.tensor(self.candi_seq_user_idxs) 120 | self.candi_seq_item_idxs = torch.tensor(self.candi_seq_item_idxs) 121 | 122 | def to(self, device): 123 | if device == "cpu": 124 | return self 125 | else: 126 | query_word_idxs = self.query_word_idxs.to(device) 127 | candi_prod_ridxs = self.candi_prod_ridxs.to(device) 128 | candi_seg_idxs = self.candi_seg_idxs.to(device) 129 | candi_seq_user_idxs = self.candi_seq_user_idxs.to(device) 130 | candi_seq_item_idxs = self.candi_seq_item_idxs.to(device) 131 | 132 | return self.__class__(self.query_idxs, self.user_idxs, 133 | self.target_prod_idxs, self.candi_prod_idxs, query_word_idxs, 134 | candi_prod_ridxs, candi_seg_idxs, candi_seq_user_idxs, 135 | candi_seq_item_idxs, to_tensor=False) 136 | 137 | class ProdSearchTrainBatch(object): 138 | def __init__(self, query_word_idxs, pos_prod_ridxs, pos_seg_idxs, 139 | pos_prod_rword_idxs, pos_prod_rword_masks, 140 | neg_prod_ridxs, neg_seg_idxs, 141 | pos_user_idxs, neg_user_idxs, 142 | pos_item_idxs, neg_item_idxs, 143 | neg_prod_rword_idxs=None, 144 | neg_prod_rword_masks=None, 145 | pos_prod_rword_idxs_pvc=None, 146 | neg_prod_rword_idxs_pvc=None, 147 | to_tensor=True): #"cpu" or "cuda" 148 | self.query_word_idxs = query_word_idxs 149 | self.pos_prod_ridxs = pos_prod_ridxs 150 | self.pos_seg_idxs = pos_seg_idxs 151 | self.pos_prod_rword_idxs = pos_prod_rword_idxs 152 | self.pos_prod_rword_masks = pos_prod_rword_masks 153 | self.neg_prod_ridxs = 
neg_prod_ridxs 154 | self.neg_seg_idxs = neg_seg_idxs 155 | self.neg_prod_rword_idxs = neg_prod_rword_idxs 156 | self.neg_prod_rword_masks = neg_prod_rword_masks 157 | self.pos_user_idxs = pos_user_idxs 158 | self.neg_user_idxs = neg_user_idxs 159 | self.pos_item_idxs = pos_item_idxs 160 | self.neg_item_idxs = neg_item_idxs 161 | #for pvc 162 | self.neg_prod_rword_idxs_pvc = neg_prod_rword_idxs_pvc 163 | self.pos_prod_rword_idxs_pvc = pos_prod_rword_idxs_pvc 164 | if to_tensor: 165 | self.to_tensor() 166 | 167 | def to_tensor(self): 168 | self.query_word_idxs = torch.tensor(self.query_word_idxs) 169 | self.pos_prod_ridxs = torch.tensor(self.pos_prod_ridxs) 170 | self.pos_seg_idxs = torch.tensor(self.pos_seg_idxs) 171 | self.pos_prod_rword_idxs = torch.tensor(self.pos_prod_rword_idxs) 172 | self.neg_prod_ridxs = torch.tensor(self.neg_prod_ridxs) 173 | self.neg_seg_idxs = torch.tensor(self.neg_seg_idxs) 174 | self.pos_prod_rword_masks = torch.ByteTensor(self.pos_prod_rword_masks) 175 | self.pos_user_idxs = torch.tensor(self.pos_user_idxs) 176 | self.neg_user_idxs = torch.tensor(self.neg_user_idxs) 177 | self.pos_item_idxs = torch.tensor(self.pos_item_idxs) 178 | self.neg_item_idxs = torch.tensor(self.neg_item_idxs) 179 | if self.neg_prod_rword_idxs is not None: 180 | self.neg_prod_rword_idxs = torch.tensor(self.neg_prod_rword_idxs) 181 | if self.neg_prod_rword_masks is not None: 182 | self.neg_prod_rword_masks = torch.ByteTensor(self.neg_prod_rword_masks) 183 | #for pvc 184 | if self.neg_prod_rword_idxs_pvc is not None: 185 | self.neg_prod_rword_idxs_pvc = torch.tensor(self.neg_prod_rword_idxs_pvc) 186 | if self.pos_prod_rword_idxs_pvc is not None: 187 | self.pos_prod_rword_idxs_pvc = torch.tensor(self.pos_prod_rword_idxs_pvc) 188 | 189 | def to(self, device): 190 | if device == "cpu": 191 | return self 192 | else: 193 | query_word_idxs = self.query_word_idxs.to(device) 194 | pos_prod_ridxs = self.pos_prod_ridxs.to(device) 195 | pos_seg_idxs = self.pos_seg_idxs.to(device) 196 | pos_prod_rword_idxs = self.pos_prod_rword_idxs.to(device) 197 | pos_prod_rword_masks = self.pos_prod_rword_masks.to(device) 198 | neg_prod_ridxs = self.neg_prod_ridxs.to(device) 199 | neg_seg_idxs = self.neg_seg_idxs.to(device) 200 | pos_user_idxs = self.pos_user_idxs.to(device) 201 | neg_user_idxs = self.neg_user_idxs.to(device) 202 | pos_item_idxs = self.pos_item_idxs.to(device) 203 | neg_item_idxs = self.neg_item_idxs.to(device) 204 | 205 | neg_prod_rword_idxs = None if self.neg_prod_rword_idxs is None \ 206 | else self.neg_prod_rword_idxs.to(device) 207 | neg_prod_rword_masks = None if self.neg_prod_rword_masks is None \ 208 | else self.neg_prod_rword_masks.to(device) 209 | #for pvc 210 | neg_prod_rword_idxs_pvc = None if self.neg_prod_rword_idxs_pvc is None \ 211 | else self.neg_prod_rword_idxs_pvc.to(device) 212 | pos_prod_rword_idxs_pvc = None if self.pos_prod_rword_idxs_pvc is None \ 213 | else self.pos_prod_rword_idxs_pvc.to(device) 214 | return self.__class__( 215 | query_word_idxs, pos_prod_ridxs, pos_seg_idxs, 216 | pos_prod_rword_idxs, pos_prod_rword_masks, 217 | neg_prod_ridxs, neg_seg_idxs, 218 | pos_user_idxs, neg_user_idxs, 219 | pos_item_idxs, neg_item_idxs, 220 | neg_prod_rword_idxs, 221 | neg_prod_rword_masks, 222 | pos_prod_rword_idxs_pvc, 223 | neg_prod_rword_idxs_pvc, to_tensor=False) 224 | -------------------------------------------------------------------------------- /data/data_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import numpy as np 3 | 4 | from others.logging import logger, init_logger 5 | from collections import defaultdict 6 | import others.util as util 7 | import gzip 8 | import os 9 | 10 | class ProdSearchData(): 11 | def __init__(self, args, input_train_dir, set_name, global_data): 12 | self.args = args 13 | self.neg_per_pos = args.neg_per_pos 14 | self.set_name = set_name 15 | self.global_data = global_data 16 | self.product_size = global_data.product_size 17 | self.user_size = global_data.user_size 18 | self.vocab_size = global_data.vocab_size 19 | self.sub_sampling_rate = None 20 | self.neg_sample_products = None 21 | self.word_dists = None 22 | self.subsampling_rate = args.subsampling_rate 23 | self.uq_pids = None 24 | if args.fix_emb: 25 | self.subsampling_rate = 0 26 | if set_name == "train": 27 | self.vocab_distribute = self.read_reviews("{}/{}.txt.gz".format(input_train_dir, set_name)) 28 | self.vocab_distribute = self.vocab_distribute.tolist() 29 | self.sub_sampling(self.subsampling_rate) 30 | self.word_dists = self.neg_distributes(self.vocab_distribute) 31 | 32 | if set_name == "train": 33 | self.product_query_idx = GlobalProdSearchData.read_arr_from_lines( 34 | "{}/{}_query_idx.txt.gz".format(input_train_dir, set_name)) 35 | self.review_info = global_data.train_review_info 36 | self.review_query_idx = global_data.train_query_idxs 37 | else: 38 | read_set_name = self.set_name 39 | if not self.args.has_valid: #if there is no validation set, use test as validation 40 | read_set_name = 'test' 41 | self.product_query_idx = GlobalProdSearchData.read_arr_from_lines( 42 | "{}/test_query_idx.txt.gz".format(input_train_dir)) #validation and test have same set of queries 43 | self.review_info, self.review_query_idx = GlobalProdSearchData.read_review_id( 44 | "{}/{}_id.txt.gz".format(input_train_dir, read_set_name), 45 | global_data.line_review_id_map) 46 | 47 | franklist = '{}/{}.bias_product.ranklist'.format(input_train_dir, read_set_name) 48 | #franklist = '{}/test.bias_product.ranklist'.format(input_train_dir) 49 | if args.test_candi_size > 0 and os.path.exists(franklist): #otherwise use all the product ids 50 | self.uq_pids = self.read_ranklist(franklist, global_data.product_asin2ids) 51 | #if args.train_review_only: 52 | #self.u_reviews, self.p_reviews = self.get_u_i_reviews( 53 | self.u_reviews, self.p_reviews = self.get_u_i_reviews_set( 54 | self.user_size, self.product_size, global_data.train_review_info) 55 | 56 | if args.prod_freq_neg_sample: 57 | self.product_distribute = self.collect_product_distribute(global_data.train_review_info) 58 | else: 59 | self.product_distribute = np.ones(self.product_size) 60 | self.product_dists = self.neg_distributes(self.product_distribute) 61 | #print(self.product_dists) 62 | 63 | self.set_review_size = len(self.review_info) 64 | #u:reviews i:reviews 65 | 66 | def read_ranklist(self, fname, product_asin2ids): 67 | uq_pids = defaultdict(list) 68 | with open(fname, 'r') as fin: 69 | for line in fin: 70 | arr = line.strip().split(' ') 71 | uid, qid = arr[0].split('_') 72 | asin = arr[2] 73 | uq_pids[(uid, int(qid))].append(product_asin2ids[asin]) 74 | return uq_pids 75 | 76 | def get_u_i_reviews(self, user_size, product_size, review_info): 77 | u_reviews = [[] for i in range(self.user_size)] 78 | p_reviews = [[] for i in range(self.product_size)] 79 | for _, u_idx, p_idx, r_idx in review_info: 80 | u_reviews[u_idx].append(r_idx) 81 | p_reviews[p_idx].append(r_idx) 82 | return u_reviews, p_reviews 83 | 84 | def get_u_i_reviews_set(self, user_size, 
product_size, review_info): 85 | u_reviews = [set() for i in range(self.user_size)] 86 | p_reviews = [set() for i in range(self.product_size)] 87 | for _, u_idx, p_idx, r_idx in review_info: 88 | u_reviews[u_idx].add(r_idx) 89 | p_reviews[p_idx].add(r_idx) 90 | return u_reviews, p_reviews 91 | 92 | def initialize_epoch(self): 93 | #self.neg_sample_products = np.random.randint(0, self.product_size, size = (self.set_review_size, self.neg_per_pos)) 94 | #exlude padding idx 95 | if self.args.model_name == "item_transformer": 96 | return 97 | self.neg_sample_products = np.random.choice(self.product_size, 98 | size = (self.set_review_size, self.neg_per_pos), replace=True, p=self.product_dists) 99 | #do subsampling to self.global_data.review_words 100 | if self.args.do_subsample_mask: 101 | #self.global_data.set_padded_review_words(self.global_data.review_words) 102 | return 103 | 104 | rand_numbers = np.random.random(sum(self.global_data.review_length)) 105 | updated_review_words = [] 106 | entry_id = 0 107 | for review in self.global_data.review_words[:-1]: 108 | filtered_review = [] 109 | for word_idx in review: 110 | if rand_numbers[entry_id] > self.sub_sampling_rate[word_idx]: 111 | continue 112 | filtered_review.append(word_idx) 113 | updated_review_words.append(filtered_review) 114 | updated_review_words.append([self.global_data.word_pad_idx]) 115 | updated_review_words = util.pad(updated_review_words, 116 | pad_id=self.global_data.word_pad_idx, width=self.args.review_word_limit) 117 | self.global_data.set_padded_review_words(updated_review_words) 118 | 119 | def collect_product_distribute(self, review_info): 120 | product_distribute = np.zeros(self.product_size) 121 | for _, uid, pid, _ in review_info: 122 | product_distribute[pid] += 1 123 | return product_distribute 124 | 125 | def read_reviews(self, fname): 126 | vocab_distribute = np.zeros(self.vocab_size) 127 | #review_info = [] 128 | with gzip.open(fname, 'rt') as fin: 129 | for line in fin: 130 | arr = line.strip().split('\t') 131 | #review_info.append((int(arr[0]), int(arr[1]))) # (user_idx, product_idx) 132 | review_text = [int(i) for i in arr[2].split(' ')] 133 | for idx in review_text: 134 | vocab_distribute[idx] += 1 135 | #return vocab_distribute, review_info 136 | return vocab_distribute 137 | 138 | def sub_sampling(self, subsample_threshold): 139 | self.sub_sampling_rate = np.asarray([1.0 for _ in range(self.vocab_size)]) 140 | if subsample_threshold == 0.0: 141 | return 142 | threshold = sum(self.vocab_distribute) * subsample_threshold 143 | for i in range(self.vocab_size): 144 | #vocab_distribute[i] could be zero if the word does not appear in the training set 145 | if self.vocab_distribute[i] == 0: 146 | self.sub_sampling_rate[i] = 0 147 | #if this word does not appear in training set, set the rate to 0. 
148 | continue 149 | self.sub_sampling_rate[i] = min(1.0, (np.sqrt(float(self.vocab_distribute[i]) / threshold) + 1) * threshold / float(self.vocab_distribute[i])) 150 | 151 | self.sample_count = sum([self.sub_sampling_rate[i] * self.vocab_distribute[i] for i in range(self.vocab_size)]) 152 | self.sub_sampling_rate = np.asarray(self.sub_sampling_rate) 153 | logger.info("sample_count:{}".format(self.sample_count)) 154 | 155 | def neg_distributes(self, weights, distortion = 0.75): 156 | #print weights 157 | weights = np.asarray(weights) 158 | #print weights.sum() 159 | wf = weights / weights.sum() 160 | wf = np.power(wf, distortion) 161 | wf = wf / wf.sum() 162 | return wf 163 | 164 | 165 | class GlobalProdSearchData(): 166 | def __init__(self, args, data_path, input_train_dir): 167 | 168 | self.product_ids = self.read_lines("{}/product.txt.gz".format(data_path)) 169 | self.product_asin2ids = {x:i for i,x in enumerate(self.product_ids)} 170 | self.product_size = len(self.product_ids) 171 | self.user_ids = self.read_lines("{}/users.txt.gz".format(data_path)) 172 | self.user_size = len(self.user_ids) 173 | self.words = self.read_lines("{}/vocab.txt.gz".format(data_path)) 174 | self.vocab_size = len(self.words) + 1 175 | self.query_words = self.read_words_in_lines("{}/query.txt.gz".format(input_train_dir)) 176 | self.word_pad_idx = self.vocab_size-1 177 | self.query_words = util.pad(self.query_words, pad_id=self.word_pad_idx) 178 | 179 | #review_word_limit = -1 180 | #if args.model_name == "review_transformer": 181 | # self.review_word_limit = args.review_word_limit 182 | self.review_words = self.read_words_in_lines( 183 | "{}/review_text.txt.gz".format(data_path)) #, cutoff=review_word_limit) 184 | #when using average word embeddings to train, review_word_limit is set 185 | self.review_length = [len(x) for x in self.review_words] 186 | self.review_count = len(self.review_words) + 1 187 | if args.model_name == "review_transformer": 188 | self.review_words.append([self.word_pad_idx]) # * args.review_word_limit) 189 | #so that review_words[-1] = -1, ..., -1 190 | if args.do_subsample_mask: 191 | self.review_words = util.pad(self.review_words, pad_id=self.vocab_size-1, width=args.review_word_limit) 192 | #if args.do_seq_review_train or args.do_seq_review_test: 193 | self.u_r_seq = self.read_arr_from_lines("{}/u_r_seq.txt.gz".format(data_path)) #list of review ids 194 | self.i_r_seq = self.read_arr_from_lines("{}/p_r_seq.txt.gz".format(data_path)) #list of review ids 195 | self.review_loc_time = self.read_arr_from_lines("{}/review_uloc_ploc_and_time.txt.gz".format(data_path)) #(loc_in_u, loc_in_i, time) of each review 196 | 197 | self.line_review_id_map = self.read_review_id_line_map("{}/review_id.txt.gz".format(data_path)) 198 | self.train_review_info, self.train_query_idxs = self.read_review_id( 199 | "{}/train_id.txt.gz".format(input_train_dir), self.line_review_id_map) 200 | self.review_u_p = self.read_arr_from_lines("{}/review_u_p.txt.gz".format(data_path)) #list of review ids 201 | 202 | logger.info("Data statistic: vocab %d, review %d, user %d, product %d" % (self.vocab_size, 203 | self.review_count, self.user_size, self.product_size)) 204 | self.padded_review_words = None 205 | 206 | def set_padded_review_words(self, review_words): 207 | self.padded_review_words = review_words 208 | #words after subsampling and cutoff and padding 209 | 210 | ''' 211 | def read_review_loc_time(self, fname): 212 | line_arr = [] 213 | line_no = 0 214 | with gzip.open(fname, 'r') as fin: 215 | for line in fin: 
216 | arr = line.strip().split(' ') 217 | arr = [int(x) for x in arr] 218 | line_arr.append([line_no] + arr) 219 | #review_id, location in user's review, timestamp 220 | line_no += 1 221 | line_arr.sort(lambda x:x[-1]) 222 | rtn_line_arr = [[] for i in range(len(line_arr))] 223 | for rank, review_info in enumerate(line_arr): 224 | review_id, loc, time = review_info 225 | rtn_line_arr[review_id] += [loc, time, rank] 226 | return rtn_line_arr 227 | ''' 228 | @staticmethod 229 | def read_review_id(fname, line_review_id_map): 230 | query_ids = [] 231 | review_info = [] 232 | with gzip.open(fname, 'rt') as fin: 233 | line_no = 0 234 | for line in fin: 235 | arr = line.strip().split('\t') 236 | review_id = line_review_id_map[int(arr[2].split('_')[-1])] 237 | review_info.append((line_no, int(arr[0]), int(arr[1]), review_id))#(user_idx, product_idx) 238 | if arr[-1].isdigit(): 239 | query_ids.append(int(arr[-1])) 240 | line_no += 1 241 | #if there is no query idx afer review_id, query_ids will be illegal and not used 242 | return review_info, query_ids 243 | 244 | @staticmethod 245 | def read_review_id_line_map(fname): 246 | line_review_id_map = dict() 247 | with gzip.open(fname, 'rt') as fin: 248 | idx = 0 249 | for line in fin: 250 | ori_line_id = int(line.strip().split('_')[-1]) 251 | line_review_id_map[ori_line_id] = idx 252 | idx += 1 253 | return line_review_id_map 254 | 255 | @staticmethod 256 | def read_arr_from_lines(fname): 257 | line_arr = [] 258 | with gzip.open(fname, 'rt') as fin: 259 | for line in fin: 260 | arr = line.strip().split(' ') 261 | filter_arr = [] 262 | for idx in arr: 263 | if len(idx) < 1: 264 | continue 265 | filter_arr.append(int(idx)) 266 | line_arr.append(filter_arr) 267 | return line_arr 268 | 269 | @staticmethod 270 | def read_lines(fname): 271 | arr = [] 272 | with gzip.open(fname, 'rt') as fin: 273 | for line in fin: 274 | arr.append(line.strip()) 275 | return arr 276 | 277 | @staticmethod 278 | def read_words_in_lines(fname, cutoff=-1): 279 | line_arr = [] 280 | with gzip.open(fname, 'rt') as fin: 281 | for line in fin: 282 | words = [int(i) for i in line.strip().split(' ')] 283 | if cutoff < 0: 284 | line_arr.append(words) 285 | else: 286 | line_arr.append(words[:cutoff]) 287 | return line_arr 288 | 289 | -------------------------------------------------------------------------------- /data/item_pv_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import others.util as util 4 | import numpy as np 5 | import random 6 | from data.batch_data import ProdSearchTrainBatch, ProdSearchTestBatch, ItemPVBatch 7 | 8 | 9 | class ItemPVDataloader(DataLoader): 10 | def __init__(self, args, dataset, prepare_pv=True, batch_size=1, shuffle=False, sampler=None, 11 | batch_sampler=None, num_workers=0, pin_memory=False, 12 | drop_last=False, timeout=0, worker_init_fn=None): 13 | super(ItemPVDataloader, self).__init__( 14 | dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, 15 | batch_sampler=batch_sampler, num_workers=num_workers, 16 | pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, 17 | worker_init_fn=worker_init_fn, collate_fn=self._collate_fn) 18 | self.args = args 19 | self.prod_pad_idx = self.dataset.prod_pad_idx 20 | self.word_pad_idx = self.dataset.word_pad_idx 21 | self.seg_pad_idx = self.dataset.seg_pad_idx 22 | self.global_data = self.dataset.global_data 23 | self.prod_data = self.dataset.prod_data 24 | 25 | def _collate_fn(self, 
batch): 26 | if self.prod_data.set_name == 'train': 27 | return self.get_train_batch(batch) 28 | else: #validation or test 29 | return self.get_test_batch(batch) 30 | 31 | def get_test_batch(self, batch): 32 | query_idxs = [entry[0] for entry in batch] 33 | query_word_idxs = [self.global_data.query_words[x] for x in query_idxs] 34 | user_idxs = [entry[1] for entry in batch] 35 | target_prod_idxs = [entry[2] for entry in batch] 36 | candi_prod_idxs = [entry[4] for entry in batch] 37 | candi_u_item_idxs = [] 38 | for _, user_idx, prod_idx, review_idx, candidate_items in batch: 39 | do_seq = self.args.do_seq_review_test and not self.args.train_review_only 40 | u_prev_review_idxs = self.get_user_review_idxs(user_idx, review_idx, do_seq, fix=True) 41 | u_item_idxs = [self.global_data.review_u_p[x][1] for x in u_prev_review_idxs] 42 | candi_u_item_idxs.append(u_item_idxs) 43 | 44 | candi_prod_idxs = util.pad(candi_prod_idxs, pad_id = self.prod_pad_idx) 45 | candi_u_item_idxs = util.pad(candi_u_item_idxs, pad_id = self.prod_pad_idx) 46 | 47 | batch = ItemPVBatch(query_word_idxs, target_prod_idxs, candi_u_item_idxs, 48 | query_idxs=query_idxs, user_idxs=user_idxs, candi_prod_idxs=candi_prod_idxs) 49 | return batch 50 | 51 | def get_test_batch_seq(self, batch): 52 | query_idxs = [entry[0] for entry in batch] 53 | query_word_idxs = [self.global_data.query_words[x] for x in query_idxs] 54 | user_idxs = [entry[1] for entry in batch] 55 | target_prod_idxs = [entry[2] for entry in batch] 56 | candi_prod_idxs = [entry[4] for entry in batch] 57 | candi_seg_idxs = [] 58 | candi_seq_item_idxs = [] 59 | for _, user_idx, prod_idx, review_idx, candidate_items in batch: 60 | do_seq = self.args.do_seq_review_test and not self.args.train_review_only 61 | u_prev_review_idxs = self.get_user_review_idxs(user_idx, review_idx, do_seq, fix=True) 62 | u_item_idxs = [self.global_data.review_u_p[x][1] for x in u_prev_review_idxs] 63 | 64 | candi_batch_item_idxs = [] 65 | candi_batch_seg_idxs = [] 66 | for candi_i in candidate_items: 67 | cur_candi_i_item_idxs = u_item_idxs + [candi_i] 68 | cur_candi_i_masks = [0] + [1] * len(u_prev_review_idxs) + [2] 69 | candi_batch_seg_idxs.append(cur_candi_i_masks) 70 | candi_batch_item_idxs.append(cur_candi_i_item_idxs) 71 | candi_seg_idxs.append(candi_batch_seg_idxs) 72 | candi_seq_item_idxs.append(candi_batch_item_idxs) 73 | 74 | candi_prod_idxs = util.pad(candi_prod_idxs, pad_id = -1) 75 | candi_seg_idxs = util.pad_3d(candi_seg_idxs, pad_id = self.seg_pad_idx, dim=1) 76 | candi_seg_idxs = util.pad_3d(candi_seg_idxs, pad_id = self.seg_pad_idx, dim=2) 77 | candi_seq_item_idxs = util.pad_3d(candi_seq_item_idxs, pad_id = self.prod_pad_idx, dim=1) 78 | candi_seq_item_idxs = util.pad_3d(candi_seq_item_idxs, pad_id = self.prod_pad_idx, dim=2) 79 | 80 | batch = ItemPVBatch(query_word_idxs, target_prod_idxs, candi_prod_idxs, None, 81 | candi_seg_idxs, None, candi_seq_item_idxs, 82 | query_idxs=query_idxs, user_idxs=user_idxs) 83 | return batch 84 | 85 | def get_user_review_idxs(self, user_idx, review_idx, do_seq, fix=True): 86 | u_seq_review_idxs = self.global_data.u_r_seq[user_idx] 87 | u_train_review_set = self.prod_data.u_reviews[user_idx] #set 88 | if do_seq: 89 | loc_in_u = self.global_data.review_loc_time[review_idx][0] 90 | u_prev_review_idxs = self.global_data.u_r_seq[user_idx][:loc_in_u] 91 | u_prev_review_idxs = u_prev_review_idxs[-self.args.uprev_review_limit:] 92 | else: 93 | u_seq_train_review_idxs = [x for x in u_seq_review_idxs if x in u_train_review_set and x!= review_idx] 
94 | u_prev_review_idxs = u_seq_train_review_idxs 95 | if len(u_seq_train_review_idxs) > self.args.uprev_review_limit: 96 | if fix: 97 | u_prev_review_idxs = u_seq_train_review_idxs[-self.args.uprev_review_limit:] 98 | else: 99 | rand_review_set = random.sample(u_seq_train_review_idxs, self.args.uprev_review_limit) 100 | rand_review_set = set(rand_review_set) 101 | u_prev_review_idxs = [x for x in u_seq_train_review_idxs if x in rand_review_set] 102 | return u_prev_review_idxs 103 | 104 | def get_user_review_idxs_prev(self, user_idx, review_idx, do_seq, fix=True): 105 | if do_seq: 106 | loc_in_u = self.global_data.review_loc_time[review_idx][0] 107 | u_prev_review_idxs = self.global_data.u_r_seq[user_idx][:loc_in_u] 108 | u_prev_review_idxs = u_prev_review_idxs[-self.args.uprev_review_limit:] 109 | #u_prev_review_idxs = self.global_data.u_r_seq[user_idx][max(0,loc_in_u-self.uprev_review_limit):loc_in_u] 110 | else: 111 | u_prev_review_idxs = self.prod_data.u_reviews[user_idx] 112 | if len(u_prev_review_idxs) > self.args.uprev_review_limit: 113 | if fix: 114 | u_prev_review_idxs = u_prev_review_idxs[:self.args.uprev_review_limit] 115 | #u_prev_review_idxs = u_prev_review_idxs[-self.args.uprev_review_limit:] 116 | else: 117 | u_prev_review_idxs = random.sample(u_prev_review_idxs, self.args.uprev_review_limit) 118 | u_prev_review_idxs = [x for x in u_prev_review_idxs if x != review_idx] 119 | return u_prev_review_idxs 120 | 121 | def get_train_batch(self, batch): 122 | batch_query_word_idxs, batch_word_idxs = [],[] 123 | batch_u_item_idxs, batch_target_prod_idxs = [],[] 124 | #batch_neg_prod_idxs = np.random.choice( 125 | # self.prod_data.product_size, 126 | # size=(len(batch), self.args.neg_per_pos), p=self.prod_data.product_dists) 127 | cur_no = 0 128 | for word_idxs, review_idx in batch: 129 | batch_word_idxs.append(word_idxs) 130 | user_idx, prod_idx = self.global_data.review_u_p[review_idx] 131 | query_idx = random.choice(self.prod_data.product_query_idx[prod_idx]) 132 | query_word_idxs = self.global_data.query_words[query_idx] 133 | 134 | u_prev_review_idxs = self.get_user_review_idxs( 135 | user_idx, review_idx, self.args.do_seq_review_train, fix=self.args.fix_train_review) 136 | u_item_idxs = [self.global_data.review_u_p[x][1] for x in u_prev_review_idxs] 137 | batch_query_word_idxs.append(query_word_idxs) 138 | batch_target_prod_idxs.append(prod_idx) 139 | batch_u_item_idxs.append(u_item_idxs) 140 | 141 | batch_u_item_idxs = util.pad(batch_u_item_idxs, pad_id = self.prod_pad_idx) 142 | batch = ItemPVBatch(batch_query_word_idxs, batch_target_prod_idxs, batch_u_item_idxs, batch_word_idxs) 143 | return batch 144 | 145 | def prepare_train_batch_pad_ui_seq(self, batch): 146 | batch_query_word_idxs, batch_word_idxs = [],[] 147 | batch_pos_seg_idxs, batch_pos_item_idxs = [],[] 148 | batch_neg_seg_idxs, batch_neg_item_idxs = [],[] 149 | batch_neg_prod_idxs = np.random.choice( 150 | self.prod_data.product_size, 151 | size=(len(batch), self.args.neg_per_pos), p=self.prod_data.product_dists) 152 | cur_no = 0 153 | for word_idxs, review_idx in batch: 154 | batch_word_idxs.append(word_idxs) 155 | user_idx, prod_idx = self.global_data.review_u_p[review_idx] 156 | query_idx = random.choice(self.prod_data.product_query_idx[prod_idx]) 157 | query_word_idxs = self.global_data.query_words[query_idx] 158 | 159 | u_prev_review_idxs = self.get_user_review_idxs(user_idx, review_idx, self.args.do_seq_review_train, fix=False) 160 | u_item_idxs = [self.global_data.review_u_p[x][1] for x in u_prev_review_idxs] 
161 | pos_seq_item_idxs = u_item_idxs + [prod_idx] 162 | pos_seg_idxs = [0] + [1] * len(u_prev_review_idxs) + [2] 163 | neg_seg_idxs = [] 164 | neg_seq_item_idxs = [] 165 | for neg_i in batch_neg_prod_idxs[cur_no]: 166 | cur_neg_i_item_idxs = u_item_idxs + [neg_i] 167 | cur_neg_i_masks = [0] + [1] * len(u_prev_review_idxs) + [2] 168 | neg_seq_item_idxs.append(cur_neg_i_item_idxs) 169 | neg_seg_idxs.append(cur_neg_i_masks) 170 | batch_query_word_idxs.append(query_word_idxs) 171 | batch_pos_seg_idxs.append(pos_seg_idxs) 172 | batch_pos_item_idxs.append(pos_seq_item_idxs) 173 | batch_neg_seg_idxs.append(neg_seg_idxs) 174 | batch_neg_item_idxs.append(neg_seq_item_idxs) 175 | 176 | data_batch = [batch_query_word_idxs, batch_word_idxs, batch_pos_seg_idxs, 177 | batch_neg_seg_idxs, batch_pos_item_idxs, batch_neg_item_idxs] 178 | return data_batch 179 | 180 | def get_train_batch_ui_seq(self, batch): 181 | query_word_idxs, pos_iword_idxs, pos_seg_idxs, neg_seg_idxs, \ 182 | pos_seq_item_idxs, neg_seq_item_idxs = self.prepare_train_batch(batch) 183 | target_prod_idxs = [x[-1] for x in pos_seq_item_idxs] 184 | pos_seg_idxs = util.pad(pos_seg_idxs, pad_id = self.seg_pad_idx) 185 | pos_seq_item_idxs = util.pad(pos_seq_item_idxs, pad_id = self.prod_pad_idx) 186 | batch_size, prev_item_count = np.asarray(pos_seq_item_idxs).shape 187 | #batch, neg_k, item_count 188 | neg_seg_idxs = util.pad_3d(neg_seg_idxs, pad_id = self.seg_pad_idx, dim=2) 189 | neg_seq_item_idxs = util.pad_3d(neg_seq_item_idxs, pad_id = self.prod_pad_idx, dim=2) 190 | 191 | batch = ItemPVBatch(query_word_idxs, target_prod_idxs, [], pos_seg_idxs, 192 | neg_seg_idxs, pos_seq_item_idxs, neg_seq_item_idxs, pos_iword_idxs=pos_iword_idxs) 193 | return batch 194 | 195 | -------------------------------------------------------------------------------- /data/item_pv_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import random 5 | import others.util as util 6 | 7 | from collections import defaultdict 8 | 9 | """ load training, validation and test data 10 | u,Q,i for a purchase i given u, Q 11 | negative samples i- for u, Q 12 | read reviews of u, Q before i (previous m reviews) 13 | reviews of others before r_i (review of i from u) 14 | load test data 15 | for each review, collect random words in reviews or words in a sliding window in reviews. 
16 | or all the words in the review 17 | """ 18 | 19 | class ItemPVDataset(Dataset): 20 | def __init__(self, args, global_data, prod_data): 21 | self.args = args 22 | self.valid_candi_size = args.valid_candi_size 23 | self.prod_pad_idx = global_data.product_size 24 | self.word_pad_idx = global_data.vocab_size - 1 25 | self.seg_pad_idx = 3 # 0, 1, 2 26 | self.pv_window_size = args.pv_window_size 27 | self.train_review_only = args.train_review_only 28 | self.uprev_review_limit = args.uprev_review_limit 29 | self.global_data = global_data 30 | self.prod_data = prod_data 31 | if prod_data.set_name == "train": 32 | self._data = self.collect_train_samples(self.global_data, self.prod_data) 33 | else: 34 | self._data = self.collect_test_samples(self.global_data, self.prod_data, args.candi_batch_size) 35 | 36 | def collect_test_samples(self, global_data, prod_data, candi_batch_size=1000): 37 | #Q, review of u + review of pos i, review of u + review of neg i; 38 | #words of pos reviews; words of neg reviews, all if encoder is not pv 39 | test_data = [] 40 | uq_set = set() 41 | for line_id, user_idx, prod_idx, review_idx in prod_data.review_info: 42 | if (line_id+1) % 10000 == 0: 43 | progress = (line_id+1.) / len(prod_data.review_info) * 100 44 | print("{}% data processed".format(progress)) 45 | #query_idx = prod_data.review_query_idx[line_id] 46 | query_idxs = prod_data.product_query_idx[prod_idx] 47 | for query_idx in query_idxs: 48 | if (user_idx, query_idx) in uq_set: 49 | continue 50 | uq_set.add((user_idx, query_idx)) 51 | 52 | #candidate item list according to user_idx and query_idx, or by default all the items 53 | if prod_data.uq_pids is None: 54 | if self.prod_data.set_name == "valid" and self.valid_candi_size > 1: 55 | candidate_items = np.random.choice(global_data.product_size, 56 | size=self.valid_candi_size-1, replace=False, p=prod_data.product_dists).tolist() 57 | candidate_items.append(prod_idx) 58 | random.shuffle(candidate_items) 59 | else: 60 | candidate_items = list(range(global_data.product_size)) 61 | else: 62 | candidate_items = prod_data.uq_pids[(global_data.user_ids[user_idx], query_idx)] 63 | random.shuffle(candidate_items) 64 | #print(len(candidate_items)) 65 | seg_count = int((len(candidate_items) - 1) / candi_batch_size) + 1 66 | for i in range(seg_count): 67 | test_data.append([query_idx, user_idx, prod_idx, review_idx, 68 | candidate_items[i*candi_batch_size:(i+1)*candi_batch_size]]) 69 | print(len(uq_set)) 70 | return test_data 71 | 72 | 73 | def collect_train_samples(self, global_data, prod_data): 74 | #Q, review of u + review of pos i, review of u + review of neg i; 75 | #words of pos reviews; words of neg reviews, all if encoder is not pv 76 | train_data = [] 77 | rand_numbers = np.random.random(sum(global_data.review_length)) 78 | entry_id = 0 79 | word_idxs = [] 80 | for line_no, user_idx, prod_idx, review_idx in prod_data.review_info: 81 | cur_review_word_idxs = self.global_data.review_words[review_idx] 82 | random.shuffle(cur_review_word_idxs) 83 | for word_idx in cur_review_word_idxs: 84 | if rand_numbers[entry_id] > prod_data.sub_sampling_rate[word_idx]: 85 | continue 86 | word_idxs.append(word_idx) 87 | if len(word_idxs) == self.pv_window_size: 88 | train_data.append([word_idxs, review_idx]) 89 | word_idxs = [] 90 | entry_id += 1 91 | if len(word_idxs) > 0: 92 | train_data.append([word_idxs+[self.word_pad_idx]*(self.pv_window_size-len(word_idxs)), review_idx]) 93 | return train_data 94 | 95 | def __len__(self): 96 | return len(self._data) 97 | 98 | def 
__getitem__(self, index): 99 | return self._data[index] 100 | -------------------------------------------------------------------------------- /data/prod_search_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | import others.util as util 4 | import numpy as np 5 | import random 6 | from data.batch_data import ProdSearchTrainBatch, ProdSearchTestBatch 7 | 8 | 9 | class ProdSearchDataLoader(DataLoader): 10 | def __init__(self, args, dataset, prepare_pv=True, batch_size=1, shuffle=False, sampler=None, 11 | batch_sampler=None, num_workers=0, pin_memory=False, 12 | drop_last=False, timeout=0, worker_init_fn=None): 13 | super(ProdSearchDataLoader, self).__init__( 14 | dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, 15 | batch_sampler=batch_sampler, num_workers=num_workers, 16 | pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, 17 | worker_init_fn=worker_init_fn, collate_fn=self._collate_fn) 18 | self.args = args 19 | self.prepare_pv = prepare_pv 20 | self.shuffle = shuffle 21 | self.prod_pad_idx = self.dataset.prod_pad_idx 22 | self.user_pad_idx = self.dataset.user_pad_idx 23 | self.review_pad_idx = self.dataset.review_pad_idx 24 | self.word_pad_idx = self.dataset.word_pad_idx 25 | self.seg_pad_idx = self.dataset.seg_pad_idx 26 | self.global_data = self.dataset.global_data 27 | self.prod_data = self.dataset.prod_data 28 | self.shuffle_review_words = self.dataset.shuffle_review_words 29 | self.total_review_limit = self.args.uprev_review_limit + self.args.iprev_review_limit 30 | if self.args.do_subsample_mask: 31 | self.review_words = self.global_data.review_words 32 | self.sub_sampling_rate = self.prod_data.sub_sampling_rate 33 | else: 34 | self.review_words = self.global_data.padded_review_words 35 | self.sub_sampling_rate = None 36 | #if subsampling_rate is 0 then sub_sampling_rate is [1,1,1], all the words are kept 37 | 38 | def _collate_fn(self, batch): 39 | if self.prod_data.set_name == 'train': 40 | return self.get_train_batch(batch) 41 | else: #validation or test 42 | return self.get_test_batch(batch) 43 | 44 | def get_test_batch(self, batch): 45 | query_idxs = [entry[0] for entry in batch] 46 | query_word_idxs = [self.global_data.query_words[x] for x in query_idxs] 47 | user_idxs = [entry[1] for entry in batch] 48 | target_prod_idxs = [entry[2] for entry in batch] 49 | candi_prod_idxs = [entry[4] for entry in batch] 50 | candi_prod_ridxs = [] 51 | candi_seg_idxs = [] 52 | candi_seq_user_idxs = [] 53 | candi_seq_item_idxs = [] 54 | #candi_prod_ridxs = [entry[4] for entry in batch] 55 | #candi_seg_idxs = [entry[5] for entry in batch] 56 | for _, user_idx, prod_idx, review_idx, candidate_items in batch: 57 | #if self.args.train_review_only: 58 | # u_prev_review_idxs = self.prod_data.u_reviews[user_idx][:self.args.uprev_review_limit] 59 | #else: 60 | do_seq = self.args.do_seq_review_test and not self.args.train_review_only 61 | u_prev_review_idxs = self.get_user_review_idxs(user_idx, review_idx, do_seq, fix=True) 62 | i_prev_review_idxs = self.get_item_review_idxs(prod_idx, review_idx, do_seq, fix=True) 63 | review_time_stamp = None 64 | if self.args.do_seq_review_test: 65 | review_time_stamp = self.global_data.review_loc_time[review_idx][2] 66 | u_item_idxs = [self.global_data.review_u_p[x][1] for x in u_prev_review_idxs] 67 | 68 | candi_batch_item_idxs = [] 69 | candi_batch_user_idxs = [] 70 | candi_batch_seg_idxs = [] 71 | candi_batch_prod_ridxs = [] 
72 | for candi_i in candidate_items: 73 | #if self.args.train_review_only: 74 | # candi_i_prev_review_idxs = self.prod_data.p_reviews[candi_i][:self.args.iprev_review_limit] 75 | #else: 76 | candi_i_prev_review_idxs = self.get_item_review_idxs( 77 | candi_i, None, do_seq, review_time_stamp, fix=True) 78 | candi_i_user_idxs = [self.global_data.review_u_p[x][0] for x in candi_i_prev_review_idxs] 79 | cur_candi_i_user_idxs = [self.user_pad_idx] + [user_idx] * len(u_prev_review_idxs) + candi_i_user_idxs 80 | cur_candi_i_user_idxs = cur_candi_i_user_idxs[:self.total_review_limit+1] 81 | cur_candi_i_item_idxs = [self.prod_pad_idx] + u_item_idxs + [candi_i] * len(candi_i_prev_review_idxs) 82 | cur_candi_i_item_idxs = cur_candi_i_item_idxs[:self.total_review_limit+1] 83 | cur_candi_i_masks = [0] + [1] * len(u_prev_review_idxs) + [2] * len(candi_i_prev_review_idxs) #might be 0 84 | cur_candi_i_masks = cur_candi_i_masks[:self.total_review_limit+1] 85 | cur_candi_i_review_idxs = u_prev_review_idxs + candi_i_prev_review_idxs 86 | cur_candi_i_review_idxs = cur_candi_i_review_idxs[:self.total_review_limit] 87 | candi_batch_seg_idxs.append(cur_candi_i_masks) 88 | candi_batch_prod_ridxs.append(cur_candi_i_review_idxs) 89 | candi_batch_item_idxs.append(cur_candi_i_item_idxs) 90 | candi_batch_user_idxs.append(cur_candi_i_user_idxs) 91 | candi_prod_ridxs.append(candi_batch_prod_ridxs) 92 | candi_seg_idxs.append(candi_batch_seg_idxs) 93 | candi_seq_item_idxs.append(candi_batch_item_idxs) 94 | candi_seq_user_idxs.append(candi_batch_user_idxs) 95 | 96 | candi_prod_idxs = util.pad(candi_prod_idxs, pad_id = -1) #pad reviews 97 | candi_prod_ridxs = util.pad_3d(candi_prod_ridxs, pad_id = self.review_pad_idx, dim=1) #pad candi products 98 | candi_prod_ridxs = util.pad_3d(candi_prod_ridxs, pad_id = self.review_pad_idx, dim=2) #pad reviews of each candi 99 | candi_seg_idxs = util.pad_3d(candi_seg_idxs, pad_id = self.seg_pad_idx, dim=1) 100 | candi_seg_idxs = util.pad_3d(candi_seg_idxs, pad_id = self.seg_pad_idx, dim=2) 101 | candi_seq_user_idxs = util.pad_3d(candi_seq_user_idxs, pad_id = self.user_pad_idx, dim=1) 102 | candi_seq_user_idxs = util.pad_3d(candi_seq_user_idxs, pad_id = self.user_pad_idx, dim=2) 103 | candi_seq_item_idxs = util.pad_3d(candi_seq_item_idxs, pad_id = self.prod_pad_idx, dim=1) 104 | candi_seq_item_idxs = util.pad_3d(candi_seq_item_idxs, pad_id = self.prod_pad_idx, dim=2) 105 | 106 | batch = ProdSearchTestBatch(query_idxs, user_idxs, target_prod_idxs, candi_prod_idxs, 107 | query_word_idxs, candi_prod_ridxs, candi_seg_idxs, 108 | candi_seq_user_idxs, candi_seq_item_idxs) 109 | return batch 110 | 111 | def get_item_review_idxs_prev(self, prod_idx, review_idx, do_seq, review_time_stamp=None,fix=True): 112 | if do_seq: 113 | if review_idx is None: 114 | loc_in_i = self.dataset.bisect_right( 115 | self.global_data.i_r_seq[prod_idx], self.global_data.review_loc_time, review_time_stamp) 116 | else: 117 | loc_in_i = self.global_data.review_loc_time[review_idx][1] 118 | if loc_in_i == 0: 119 | return [] 120 | i_prev_review_idxs = self.global_data.i_r_seq[prod_idx][:loc_in_i] 121 | i_prev_review_idxs = i_prev_review_idxs[-self.args.iprev_review_limit:] 122 | #i_prev_review_idxs = self.global_data.i_r_seq[prod_idx][max(0,loc_in_i-self.args.iprev_review_limit):loc_in_i] 123 | 124 | else: 125 | i_prev_review_idxs = self.prod_data.p_reviews[prod_idx] 126 | if len(i_prev_review_idxs) > self.args.iprev_review_limit: 127 | if fix: 128 | i_prev_review_idxs = i_prev_review_idxs[:self.args.iprev_review_limit] 129 
| #i_prev_review_idxs = i_prev_review_idxs[-self.args.iprev_review_limit:] 130 | else: 131 | i_prev_review_idxs = random.sample(i_prev_review_idxs, self.args.iprev_review_limit) 132 | 133 | return i_prev_review_idxs 134 | 135 | def get_item_review_idxs(self, prod_idx, review_idx, do_seq, review_time_stamp=None,fix=True): 136 | i_seq_review_idxs = self.global_data.i_r_seq[prod_idx] 137 | i_train_review_set = self.prod_data.p_reviews[prod_idx] 138 | if do_seq: 139 | if review_idx is None: 140 | loc_in_i = self.dataset.bisect_right( 141 | self.global_data.i_r_seq[prod_idx], self.global_data.review_loc_time, review_time_stamp) 142 | else: 143 | loc_in_i = self.global_data.review_loc_time[review_idx][1] 144 | if loc_in_i == 0: 145 | return [] 146 | i_prev_review_idxs = self.global_data.i_r_seq[prod_idx][:loc_in_i] 147 | i_prev_review_idxs = i_prev_review_idxs[-self.args.iprev_review_limit:] 148 | else: 149 | i_seq_train_review_idxs = [x for x in i_seq_review_idxs if x in i_train_review_set and x!= review_idx] 150 | i_prev_review_idxs = i_seq_train_review_idxs 151 | if len(i_prev_review_idxs) > self.args.iprev_review_limit: 152 | if fix: 153 | i_prev_review_idxs = i_prev_review_idxs[-self.args.iprev_review_limit:] 154 | else: 155 | rand_review_set = random.sample(i_seq_train_review_idxs, self.args.iprev_review_limit) 156 | rand_review_set = set(rand_review_set) 157 | i_prev_review_idxs = [x for x in i_seq_train_review_idxs if x in rand_review_set] 158 | 159 | return i_prev_review_idxs 160 | 161 | def get_user_review_idxs_prev(self, user_idx, review_idx, do_seq, fix=True): 162 | if do_seq: 163 | loc_in_u = self.global_data.review_loc_time[review_idx][0] 164 | u_prev_review_idxs = self.global_data.u_r_seq[user_idx][:loc_in_u] 165 | u_prev_review_idxs = u_prev_review_idxs[-self.args.uprev_review_limit:] 166 | #u_prev_review_idxs = self.global_data.u_r_seq[user_idx][max(0,loc_in_u-self.uprev_review_limit):loc_in_u] 167 | else: 168 | u_prev_review_idxs = self.prod_data.u_reviews[user_idx] 169 | if len(u_prev_review_idxs) > self.args.uprev_review_limit: 170 | if fix: 171 | u_prev_review_idxs = u_prev_review_idxs[:self.args.uprev_review_limit] 172 | #u_prev_review_idxs = u_prev_review_idxs[-self.args.uprev_review_limit:] 173 | else: 174 | u_prev_review_idxs = random.sample(u_prev_review_idxs, self.args.uprev_review_limit) 175 | return u_prev_review_idxs 176 | 177 | def get_user_review_idxs(self, user_idx, review_idx, do_seq, fix=True): 178 | u_seq_review_idxs = self.global_data.u_r_seq[user_idx] 179 | u_train_review_set = self.prod_data.u_reviews[user_idx] #set 180 | if do_seq: 181 | loc_in_u = self.global_data.review_loc_time[review_idx][0] 182 | u_prev_review_idxs = self.global_data.u_r_seq[user_idx][:loc_in_u] 183 | u_prev_review_idxs = u_prev_review_idxs[-self.args.uprev_review_limit:] 184 | else: 185 | u_seq_train_review_idxs = [x for x in u_seq_review_idxs if x in u_train_review_set and x!= review_idx] 186 | u_prev_review_idxs = u_seq_train_review_idxs 187 | if len(u_seq_train_review_idxs) > self.args.uprev_review_limit: 188 | if fix: 189 | u_prev_review_idxs = u_seq_train_review_idxs[-self.args.uprev_review_limit:] 190 | else: 191 | rand_review_set = random.sample(u_seq_train_review_idxs, self.args.uprev_review_limit) 192 | rand_review_set = set(rand_review_set) 193 | u_prev_review_idxs = [x for x in u_seq_train_review_idxs if x in rand_review_set] 194 | return u_prev_review_idxs 195 | 196 | def prepare_train_batch(self, batch): 197 | batch_query_word_idxs = [] 198 | batch_pos_prod_ridxs, 
batch_pos_seg_idxs, batch_pos_user_idxs, batch_pos_item_idxs = [],[],[],[] 199 | batch_neg_prod_ridxs, batch_neg_seg_idxs, batch_neg_user_idxs, batch_neg_item_idxs = [],[],[],[] 200 | for line_id, user_idx, prod_idx, review_idx in batch: 201 | query_idx = random.choice(self.prod_data.product_query_idx[prod_idx]) 202 | query_word_idxs = self.global_data.query_words[query_idx] 203 | u_prev_review_idxs = self.get_user_review_idxs(user_idx, review_idx, self.args.do_seq_review_train, fix=False) 204 | i_prev_review_idxs = self.get_item_review_idxs(prod_idx, review_idx, self.args.do_seq_review_train, fix=False) 205 | review_time_stamp = None 206 | if self.args.do_seq_review_train: 207 | review_time_stamp = self.global_data.review_loc_time[review_idx][2] 208 | if len(i_prev_review_idxs) == 0: 209 | continue 210 | i_user_idxs = [self.global_data.review_u_p[x][0] for x in i_prev_review_idxs] 211 | u_item_idxs = [self.global_data.review_u_p[x][1] for x in u_prev_review_idxs] 212 | pos_user_idxs = [self.user_pad_idx] + [user_idx] * len(u_prev_review_idxs) + i_user_idxs 213 | pos_user_idxs = pos_user_idxs[:self.total_review_limit + 1] 214 | pos_item_idxs = [self.prod_pad_idx] + u_item_idxs + [prod_idx] * len(i_prev_review_idxs) 215 | pos_item_idxs = pos_item_idxs[:self.total_review_limit + 1] 216 | pos_seg_idxs = [0] + [1] * len(u_prev_review_idxs) + [2] * len(i_prev_review_idxs) 217 | pos_seg_idxs = pos_seg_idxs[:self.total_review_limit + 1] 218 | pos_prod_ridxs = u_prev_review_idxs + i_prev_review_idxs 219 | pos_prod_ridxs = pos_prod_ridxs[:self.total_review_limit] # or select reviews with the most words 220 | 221 | neg_prod_idxs = self.prod_data.neg_sample_products[line_id] #neg_per_pos 222 | neg_prod_ridxs = [] 223 | neg_seg_idxs = [] 224 | neg_user_idxs = [] 225 | neg_item_idxs = [] 226 | for neg_i in neg_prod_idxs: 227 | neg_i_prev_review_idxs = self.get_item_review_idxs( 228 | neg_i, None, self.args.do_seq_review_train, review_time_stamp, fix=False) 229 | if len(neg_i_prev_review_idxs) == 0: 230 | continue 231 | neg_i_user_idxs = [self.global_data.review_u_p[x][0] for x in neg_i_prev_review_idxs] 232 | cur_neg_i_user_idxs = [self.user_pad_idx] + [user_idx] * len(u_prev_review_idxs) + neg_i_user_idxs 233 | cur_neg_i_user_idxs = cur_neg_i_user_idxs[:self.total_review_limit+1] 234 | cur_neg_i_item_idxs = [self.prod_pad_idx] + u_item_idxs + [neg_i] * len(neg_i_prev_review_idxs) 235 | cur_neg_i_item_idxs = cur_neg_i_item_idxs[:self.total_review_limit+1] 236 | cur_neg_i_masks = [0] + [1] * len(u_prev_review_idxs) + [2] * len(neg_i_prev_review_idxs) 237 | cur_neg_i_masks = cur_neg_i_masks[:self.total_review_limit+1] 238 | cur_neg_i_review_idxs = u_prev_review_idxs + neg_i_prev_review_idxs 239 | cur_neg_i_review_idxs = cur_neg_i_review_idxs[:self.total_review_limit] 240 | neg_user_idxs.append(cur_neg_i_user_idxs) 241 | neg_item_idxs.append(cur_neg_i_item_idxs) 242 | neg_seg_idxs.append(cur_neg_i_masks) 243 | neg_prod_ridxs.append(cur_neg_i_review_idxs) 244 | #neg_prod_rword_idxs.append([self.global_data.review_words[x] for x in cur_neg_i_review_idxs]) 245 | if len(neg_prod_ridxs) == 0: 246 | #all the neg prod do not have available reviews 247 | continue 248 | batch_query_word_idxs.append(query_word_idxs) 249 | batch_pos_prod_ridxs.append(pos_prod_ridxs) 250 | batch_pos_seg_idxs.append(pos_seg_idxs) 251 | batch_pos_user_idxs.append(pos_user_idxs) 252 | batch_pos_item_idxs.append(pos_item_idxs) 253 | batch_neg_prod_ridxs.append(neg_prod_ridxs) 254 | batch_neg_seg_idxs.append(neg_seg_idxs) 255 | 
batch_neg_user_idxs.append(neg_user_idxs) 256 | batch_neg_item_idxs.append(neg_item_idxs) 257 | 258 | data_batch = [batch_query_word_idxs, batch_pos_prod_ridxs, batch_pos_seg_idxs, 259 | batch_neg_prod_ridxs, batch_neg_seg_idxs, batch_pos_user_idxs, 260 | batch_neg_user_idxs, batch_pos_item_idxs, batch_neg_item_idxs] 261 | return data_batch 262 | ''' 263 | u, Q, i (positive, negative) 264 | Q; ru1,ru2,ri1,ri2 and k negative (ru1,ru2,rn1i1,rn1i2; ru1,ru2,rnji1,rnji2) 265 | segs 0; 1,1;pos 2,2, -1,-1 neg_1, neg_2 266 | r: word_id1, word_id2, ... 267 | pos_seg_idxs:0,1,1,2,2,-1 268 | word_count can be computed with words that are not padding 269 | review of u concat with review of i 270 | review of u concat with review of each negative i 271 | batch_size, review_count (u+i), max_word_count_per_review 272 | batch_size, neg_k, review_count (u+i), max_word_count_per_review 273 | ''' 274 | def get_train_batch(self, batch): 275 | query_word_idxs, pos_prod_ridxs, pos_seg_idxs, \ 276 | neg_prod_ridxs, neg_seg_idxs, pos_user_idxs, \ 277 | neg_user_idxs, pos_item_idxs, neg_item_idxs = self.prepare_train_batch(batch) 278 | if len(query_word_idxs) == 0: 279 | print("0 available instance in the batch") 280 | return None 281 | pos_prod_ridxs = util.pad(pos_prod_ridxs, pad_id = self.review_pad_idx) #pad reviews 282 | pos_seg_idxs = util.pad(pos_seg_idxs, pad_id = self.seg_pad_idx) 283 | pos_user_idxs = util.pad(pos_user_idxs, pad_id = self.user_pad_idx) 284 | pos_item_idxs = util.pad(pos_item_idxs, pad_id = self.prod_pad_idx) 285 | pos_prod_ridxs = np.asarray(pos_prod_ridxs) 286 | batch_size, pos_rcount = pos_prod_ridxs.shape 287 | pos_prod_rword_idxs = [self.review_words[x] for x in pos_prod_ridxs.reshape(-1)] 288 | #pos_prod_rword_idxs = util.pad(pos_prod_rword_idxs, pad_id = self.word_pad_idx) 289 | pos_prod_rword_idxs = np.asarray(pos_prod_rword_idxs).reshape(batch_size, pos_rcount, -1) 290 | pos_prod_rword_masks = self.dataset.get_pv_word_masks( 291 | #pos_prod_rword_idxs, self.prod_data.sub_sampling_rate, pad_id=self.word_pad_idx) 292 | pos_prod_rword_idxs, self.sub_sampling_rate, pad_id=self.word_pad_idx) 293 | neg_prod_ridxs = util.pad_3d(neg_prod_ridxs, pad_id = self.review_pad_idx, dim=1) #pad neg products 294 | neg_prod_ridxs = util.pad_3d(neg_prod_ridxs, pad_id = self.review_pad_idx, dim=2) #pad reviews of each neg 295 | neg_seg_idxs = util.pad_3d(neg_seg_idxs, pad_id = self.seg_pad_idx, dim=1) 296 | neg_seg_idxs = util.pad_3d(neg_seg_idxs, pad_id = self.seg_pad_idx, dim=2) 297 | neg_user_idxs = util.pad_3d(neg_user_idxs, pad_id = self.user_pad_idx, dim=1) 298 | neg_user_idxs = util.pad_3d(neg_user_idxs, pad_id = self.user_pad_idx, dim=2) 299 | neg_item_idxs = util.pad_3d(neg_item_idxs, pad_id = self.prod_pad_idx, dim=1) 300 | neg_item_idxs = util.pad_3d(neg_item_idxs, pad_id = self.prod_pad_idx, dim=2) 301 | neg_prod_ridxs = np.asarray(neg_prod_ridxs) 302 | _, neg_k, nr_count = neg_prod_ridxs.shape 303 | neg_prod_rword_idxs = [self.review_words[x] for x in neg_prod_ridxs.reshape(-1)] 304 | #neg_prod_rword_idxs = util.pad(neg_prod_rword_idxs, pad_id = self.word_pad_idx) 305 | neg_prod_rword_idxs = np.asarray(neg_prod_rword_idxs).reshape(batch_size, neg_k, nr_count, -1) 306 | 307 | if "pv" in self.dataset.review_encoder_name and self.prepare_pv: 308 | pos_prod_rword_idxs_pvc = pos_prod_rword_idxs 309 | neg_prod_rword_idxs_pvc = neg_prod_rword_idxs 310 | batch_size, pos_rcount, word_limit = pos_prod_rword_idxs.shape 311 | pv_window_size = self.dataset.pv_window_size 312 | if 
self.shuffle_review_words: 313 | self.dataset.shuffle_words_in_reviews(pos_prod_rword_idxs) 314 | slide_pos_prod_rword_idxs = self.dataset.slide_padded_matrices_for_pv( 315 | pos_prod_rword_idxs.reshape(-1, word_limit), 316 | pv_window_size, self.word_pad_idx) 317 | slide_pos_prod_rword_masks = self.dataset.slide_padded_matrices_for_pv( 318 | pos_prod_rword_masks.reshape(-1, word_limit), 319 | pv_window_size, pad_id = 0) 320 | #seg_count, batch_size * pos_rcount, pv_window_size 321 | seg_count = slide_pos_prod_rword_idxs.shape[0] 322 | slide_pos_prod_rword_idxs = slide_pos_prod_rword_idxs.reshape( 323 | seg_count, batch_size, pos_rcount, pv_window_size).reshape( 324 | -1, pos_rcount, pv_window_size) #seg_count, batch_size 325 | slide_pos_prod_rword_masks = slide_pos_prod_rword_masks.reshape( 326 | batch_size, pos_rcount, -1, pv_window_size).reshape( 327 | -1, pos_rcount, pv_window_size) #seg_count, batch_size 328 | batch_indices = np.repeat(np.expand_dims(np.arange(batch_size),0), seg_count, axis=0) 329 | if self.shuffle: 330 | I = np.random.permutation(batch_size * seg_count) 331 | batch_indices = batch_indices.reshape(-1)[I].reshape(seg_count, batch_size) 332 | slide_pos_prod_rword_idxs = slide_pos_prod_rword_idxs[I] 333 | slide_pos_prod_rword_masks = slide_pos_prod_rword_masks[I] 334 | slide_pos_prod_rword_idxs = slide_pos_prod_rword_idxs.reshape(seg_count, batch_size, pos_rcount, -1) 335 | slide_pos_prod_rword_masks = slide_pos_prod_rword_masks.reshape(seg_count, batch_size, pos_rcount, -1) 336 | query_word_idxs, pos_prod_ridxs, pos_seg_idxs, neg_prod_ridxs, neg_seg_idxs \ 337 | = map(np.asarray, [query_word_idxs, pos_prod_ridxs, pos_seg_idxs, neg_prod_ridxs, neg_seg_idxs]) 338 | batch = [ProdSearchTrainBatch(query_word_idxs[batch_indices[i]], 339 | pos_prod_ridxs[batch_indices[i]], pos_seg_idxs[batch_indices[i]], 340 | slide_pos_prod_rword_idxs[i], slide_pos_prod_rword_masks[i], 341 | neg_prod_ridxs[batch_indices[i]], neg_seg_idxs[batch_indices[i]], 342 | pos_user_idxs[batch_indices[i]], neg_user_idxs[batch_indices[i]], 343 | pos_item_idxs[batch_indices[i]], neg_item_idxs[batch_indices[i]], 344 | pos_prod_rword_idxs_pvc = pos_prod_rword_idxs_pvc[batch_indices[i]], 345 | neg_prod_rword_idxs_pvc = neg_prod_rword_idxs_pvc[batch_indices[i]]) for i in range(seg_count)] 346 | else: 347 | neg_prod_rword_masks = self.dataset.get_pv_word_masks( 348 | #neg_prod_rword_idxs, self.prod_data.sub_sampling_rate, pad_id=self.word_pad_idx) 349 | neg_prod_rword_idxs, self.sub_sampling_rate, pad_id=self.word_pad_idx) 350 | batch = ProdSearchTrainBatch(query_word_idxs, pos_prod_ridxs, pos_seg_idxs, 351 | pos_prod_rword_idxs, pos_prod_rword_masks, 352 | neg_prod_ridxs, neg_seg_idxs, 353 | pos_user_idxs, neg_user_idxs, 354 | pos_item_idxs, neg_item_idxs, 355 | neg_prod_rword_idxs = neg_prod_rword_idxs, 356 | neg_prod_rword_masks = neg_prod_rword_masks) 357 | return batch 358 | 359 | -------------------------------------------------------------------------------- /data/prod_search_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import random 5 | import others.util as util 6 | 7 | from collections import defaultdict 8 | 9 | """ load training, validation and test data 10 | u,Q,i for a purchase i given u, Q 11 | negative samples i- for u, Q 12 | read reviews of u, Q before i (previous m reviews) 13 | reviews of others before r_i (review of i from u) 14 | load test data 15 | for each 
review, collect random words in reviews or words in a sliding window in reviews. 16 | or all the words in the review 17 | """ 18 | 19 | class ProdSearchDataset(Dataset): 20 | def __init__(self, args, global_data, prod_data): 21 | self.args = args 22 | self.valid_candi_size = args.valid_candi_size 23 | self.user_pad_idx = global_data.user_size 24 | self.prod_pad_idx = global_data.product_size 25 | self.word_pad_idx = global_data.vocab_size - 1 26 | self.review_pad_idx = global_data.review_count - 1 27 | self.seg_pad_idx = 3 # 0, 1, 2 28 | self.shuffle_review_words = args.shuffle_review_words 29 | self.review_encoder_name = args.review_encoder_name 30 | self.pv_window_size = args.pv_window_size 31 | self.corrupt_rate = args.corrupt_rate 32 | self.train_review_only = args.train_review_only 33 | self.uprev_review_limit = args.uprev_review_limit 34 | self.iprev_review_limit = args.iprev_review_limit #can be a really large number, can random select 35 | self.total_review_limit = self.uprev_review_limit + self.iprev_review_limit 36 | self.global_data = global_data 37 | self.prod_data = prod_data 38 | if prod_data.set_name == "train": 39 | self._data = self.collect_train_samples(self.global_data, self.prod_data) 40 | else: 41 | self._data = self.collect_test_samples(self.global_data, self.prod_data, args.candi_batch_size) 42 | 43 | def collect_test_samples(self, global_data, prod_data, candi_batch_size=1000): 44 | #Q, review of u + review of pos i, review of u + review of neg i; 45 | #words of pos reviews; words of neg reviews, all if encoder is not pv 46 | test_data = [] 47 | uq_set = set() 48 | for line_id, user_idx, prod_idx, review_idx in prod_data.review_info: 49 | if (line_id+1) % 10000 == 0: 50 | progress = (line_id+1.) / len(prod_data.review_info) * 100 51 | print("{}% data processed".format(progress)) 52 | #user_idx, prod_idx, review_idx = review 53 | #query_idx = prod_data.review_query_idx[line_id] 54 | query_idxs = prod_data.product_query_idx[prod_idx] 55 | for query_idx in query_idxs: 56 | if (user_idx, query_idx) in uq_set: 57 | continue 58 | uq_set.add((user_idx, query_idx)) 59 | 60 | #candidate item list according to user_idx and query_idx, or by default all the items 61 | #candidate_items = list(range(global_data.product_size))[:1000] 62 | if prod_data.uq_pids is None: 63 | if self.prod_data.set_name == "valid" and self.valid_candi_size > 1: 64 | candidate_items = np.random.choice(global_data.product_size, 65 | size=self.valid_candi_size-1, replace=False, p=prod_data.product_dists).tolist() 66 | #candidate_items = np.random.randint(0, global_data.product_size, size =self.valid_candi_size-1).tolist() 67 | candidate_items.append(prod_idx) 68 | random.shuffle(candidate_items) 69 | else: 70 | candidate_items = list(range(global_data.product_size)) 71 | else: 72 | #print(global_data.user_ids[user_idx], query_idx) 73 | candidate_items = prod_data.uq_pids[(global_data.user_ids[user_idx], query_idx)] 74 | random.shuffle(candidate_items) 75 | #candidate_items = [global_data.product_asin2ids[x] for x in asin_list] 76 | #print(len(candidate_items)) 77 | seg_count = int((len(candidate_items) - 1) / candi_batch_size) + 1 78 | for i in range(seg_count): 79 | test_data.append([query_idx, user_idx, prod_idx, review_idx, 80 | candidate_items[i*candi_batch_size:(i+1)*candi_batch_size]]) 81 | print(len(uq_set)) 82 | 83 | return test_data 84 | 85 | 86 | def collect_train_samples(self, global_data, prod_data): 87 | #Q, review of u + review of pos i, review of u + review of neg i; 88 | #words of pos 
reviews; words of neg reviews, all if encoder is not pv 89 | return prod_data.review_info 90 | 91 | def get_pv_word_masks(self, prod_rword_idxs, subsampling_rate, pad_id): 92 | if subsampling_rate is not None: 93 | rand_numbers = np.random.random(prod_rword_idxs.shape) 94 | #subsampling_rate_arr = np.asarray([[subsampling_rate[prod_rword_idxs[i][j]] \ 95 | # for j in range(prod_rword_idxs.shape[1])] for i in range(prod_rword_idxs.shape[0])]) 96 | subsampling_rate_arr = subsampling_rate[prod_rword_idxs] 97 | masks = np.logical_and(prod_rword_idxs !=pad_id, rand_numbers < subsampling_rate_arr) 98 | else: 99 | masks = (prod_rword_idxs !=pad_id) 100 | return masks 101 | 102 | def shuffle_words_in_reviews(self, prod_rword_idxs): 103 | #consider random shuffle words 104 | for row in prod_rword_idxs: 105 | np.random.shuffle(row) 106 | 107 | def slide_matrices_for_pv(self, prod_rword_idxs, pv_window_size): 108 | #review_count * review_word_limit 109 | seg_prod_rword_idxs = [] 110 | cur_length = 0 111 | while cur_length < prod_rword_idxs.shape[1]: # review_word_limit 112 | seg_prod_rword_idxs.append(prod_rword_idxs[:,cur_length:cur_length+pv_window_size]) #).tolist()) 113 | cur_length += pv_window_size 114 | return seg_prod_rword_idxs 115 | 116 | def slide_padded_matrices_for_pv(self, prod_rword_idxs, pv_window_size, pad_id): 117 | ''' 118 | word_limit = prod_rword_idxs.shape[1] 119 | seg_count = word_limit / pv_window_size 120 | mod = word_limit % pv_window_size 121 | if mod > 0: 122 | seg_count += 1 123 | new_length = pv_window_size * seg_count 124 | prod_rword_idxs = util.pad_3d( 125 | prod_rword_idxs.tolist(), pad_id=pad_id, dim=2, width=new_length) #pad words 126 | #seg_count = (prod_rword_idxs.shape[1]-1)/pv_window_size + 1 127 | ''' 128 | pad_size = pv_window_size - (prod_rword_idxs.shape[1] % pv_window_size) 129 | if pad_size < pv_window_size: 130 | prod_rword_idxs = np.pad(prod_rword_idxs, ((0,0),(0,pad_size)),mode='constant', constant_values=pad_id) 131 | 132 | seg_count = int(prod_rword_idxs.shape[1]/pv_window_size) 133 | return np.asarray([prod_rword_idxs[:,i*pv_window_size:(i+1)*pv_window_size] for i in range(seg_count)]) 134 | 135 | def bisect_right(self, review_arr, review_loc_time_arr, timestamp, lo=0, hi=None): 136 | """Return the index where timestamp is larger than the review in review_arr (sorted) 137 | The return value i is such that all e in a[:i] have e <= x, and all e in 138 | a[i:] have e > x. So if x already appears in the list, a.insert(x) will 139 | insert just after the rightmost x already there. 140 | Optional args lo (default 0) and hi (default len(a)) bound the 141 | slice of a to be searched. 
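Example with hypothetical values: if review_arr = [5, 9, 12] and the stored
timestamps review_loc_time_arr[5][2], review_loc_time_arr[9][2],
review_loc_time_arr[12][2] are 100, 200, 300, then timestamp=250 returns 2
(reviews 5 and 9 occurred no later than 250) and timestamp=50 returns 0.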
142 | """ 143 | 144 | if lo < 0: 145 | raise ValueError('lo must be non-negative') 146 | if hi is None: 147 | hi = len(review_arr) 148 | while lo < hi: 149 | mid = (lo+hi)//2 150 | if timestamp < review_loc_time_arr[review_arr[mid]][2]: hi = mid 151 | else: lo = mid+1 152 | return lo 153 | 154 | def __len__(self): 155 | return len(self._data) 156 | 157 | def __getitem__(self, index): 158 | return self._data[index] 159 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | main entry of the script, train, validate and test 3 | """ 4 | import torch 5 | import argparse 6 | import random 7 | import glob 8 | import os 9 | 10 | from others.logging import logger, init_logger 11 | from models.ps_model import ProductRanker, build_optim 12 | from models.item_transformer import ItemTransformerRanker 13 | from data.data_util import GlobalProdSearchData, ProdSearchData 14 | from trainer import Trainer 15 | from data.prod_search_dataset import ProdSearchDataset 16 | 17 | def str2bool(v): 18 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 19 | return True 20 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 21 | return False 22 | else: 23 | raise argparse.ArgumentTypeError('Boolean value expected.') 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--seed', default=666, type=int) 29 | parser.add_argument("--train_from", default='') 30 | parser.add_argument("--model_name", default='review_transformer', 31 | choices=['review_transformer', 'item_transformer', 'ZAM', 'AEM', 'QEM'], help="which type of model is used to train") 32 | parser.add_argument("--sep_prod_emb", type=str2bool, nargs='?',const=True,default=False, 33 | help="whether to use separate embeddings for historical product and the target product") 34 | parser.add_argument("--pretrain_emb_dir", default='', help="pretrained paragraph and word embeddings") 35 | parser.add_argument("--pretrain_up_emb_dir", default='', help="pretrained user item embeddings") 36 | #parser.add_argument("--review_train_mode", type=str, default='random', choices=["random", "prev", "last"] 37 | # help="use (sorted) random reviews in training data; reviews before current purchase; reviews that occur at the last in training set for training; ") 38 | parser.add_argument("--fix_train_review", type=str2bool, nargs='?',const=True,default=False, 39 | help="fix train reviews (the last reviews in the training set); ") 40 | parser.add_argument("--do_seq_review_train", type=str2bool, nargs='?',const=True,default=False, 41 | help="only use reviews before current purchase for training; ") 42 | parser.add_argument("--do_seq_review_test", type=str2bool, nargs='?',const=True,default=False, 43 | help="during test time if only training data is available, use the most recent iprev and uprev reviews; if train_review_only is False, use all the sequential reviews available before current review, including validation and test.") 44 | parser.add_argument("--fix_emb", type=str2bool, nargs='?',const=True,default=False, 45 | help="fix word embeddings or review embeddings during training.") 46 | parser.add_argument("--use_dot_prod", type=str2bool, nargs='?',const=True,default=True, 47 | help="use positional embeddings when encoding reviews.") 48 | parser.add_argument("--sim_func", type=str, default="product", choices=["bias_product", "product", "cosine"], help="similarity computation method.") 49 | 
parser.add_argument("--use_pos_emb", type=str2bool, nargs='?',const=True,default=True, 50 | help="use positional embeddings when encoding reviews.") 51 | parser.add_argument("--use_seg_emb", type=str2bool, nargs='?',const=True,default=True, 52 | help="use segment embeddings when encoding reviews.") 53 | parser.add_argument("--use_item_pos", type=str2bool, nargs='?',const=True,default=False, 54 | help="use the embeddings corresponding to a candidate item as an output when encoding the purchased item sequence.") 55 | parser.add_argument("--use_item_emb", type=str2bool, nargs='?',const=True,default=False, 56 | help="use item embeddings when encoding review sequence.") 57 | parser.add_argument("--use_user_emb", type=str2bool, nargs='?',const=True,default=False, 58 | help="use user embeddings when encoding review sequence.") 59 | parser.add_argument("--rankfname", default="test.best_model.ranklist") 60 | parser.add_argument("--dropout", default=0.1, type=float) 61 | parser.add_argument("--token_dropout", default=0.1, type=float) 62 | parser.add_argument("--optim", type=str, default="adam", help="sgd or adam") 63 | parser.add_argument("--lr", default=0.002, type=float) #0.002 64 | parser.add_argument("--beta1", default= 0.9, type=float) 65 | parser.add_argument("--beta2", default=0.999, type=float) 66 | parser.add_argument("--decay_method", default='adam', choices=['noam', 'adam'],type=str) #warmup learning rate then decay 67 | parser.add_argument("--warmup_steps", default=8000, type=int) #10000 68 | parser.add_argument("--max_grad_norm", type=float, default=5.0, 69 | help="Clip gradients to this norm.") 70 | parser.add_argument("--subsampling_rate", type=float, default=1e-5, 71 | help="The rate to subsampling.") 72 | parser.add_argument("--do_subsample_mask", type=str2bool, nargs='?',const=True,default=False, 73 | help="do subsampling mask do the reviews with cutoff review_word_limit; otherwise do subsampling then do the cutoff.") 74 | parser.add_argument("--prod_freq_neg_sample", type=str2bool, nargs='?',const=True,default=False, 75 | help="whether to sample negative products according to their purchase frequency.") 76 | parser.add_argument("--pos_weight", type=str2bool, nargs='?',const=True,default=False, 77 | help="use pos_weight different from 1 during training.") 78 | parser.add_argument("--l2_lambda", type=float, default=0.0, 79 | help="The lambda for L2 regularization.") 80 | parser.add_argument("--batch_size", type=int, default=32, 81 | help="Batch size to use during training.") 82 | parser.add_argument("--has_valid", type=str2bool, nargs='?',const=True, default=False, 83 | help="whether there is validation set; if not use test as validation.") 84 | parser.add_argument("--valid_batch_size", type=int, default=24, 85 | help="Batch size for validation to use during training.") 86 | parser.add_argument("--valid_candi_size", type=int, default=500, # can be -1 if you want to evaluate on all the products. 87 | help="Random products used for validation. When it is 0 or less, all the products are used.") 88 | parser.add_argument("--test_candi_size", type=int, default=-1, # 89 | help="When it is 0 or less, all the products are used. 
Otherwise, test_candi_size products from the ranklist will be reranked.") 90 | parser.add_argument("--candi_batch_size", type=int, default=500, 91 | help="Number of candidate products scored per batch during evaluation.") 92 | parser.add_argument("--num_workers", type=int, default=4, 93 | help="Number of processes to load batches of data during training.") 94 | parser.add_argument("--data_dir", type=str, default="/tmp", help="Data directory") 95 | parser.add_argument("--input_train_dir", type=str, default="", help="The directory of training and testing data") 96 | parser.add_argument("--save_dir", type=str, default="/tmp", help="Model directory & output directory") 97 | parser.add_argument("--log_file", type=str, default="train.log", help="log file name") 98 | parser.add_argument("--query_encoder_name", type=str, default="fs", choices=["fs","avg"], 99 | help="Query encoder structure. Please read readme.txt for details.") 100 | parser.add_argument("--review_encoder_name", type=str, default="pvc", choices=["pv", "pvc", "fs", "avg"], 101 | help="Review encoder structure.") 102 | parser.add_argument("--embedding_size", type=int, default=128, help="Size of each embedding.") 103 | parser.add_argument("--ff_size", type=int, default=512, help="Size of feedforward layers in transformers.") 104 | parser.add_argument("--heads", default=8, type=int, help="attention heads in transformers") 105 | parser.add_argument("--inter_layers", default=2, type=int, help="transformer layers") 106 | parser.add_argument("--review_word_limit", type=int, default=100, 107 | help="the limit on the number of words in a review, for review_transformer.") 108 | parser.add_argument("--uprev_review_limit", type=int, default=20, 109 | help="the number of items the user previously purchased in TEM; \ 110 | the number of the user's previous reviews used in RTM.") 111 | parser.add_argument("--iprev_review_limit", type=int, default=30, 112 | help="the number of the item's previous reviews used.") 113 | parser.add_argument("--pv_window_size", type=int, default=1, help="Size of context window.") 114 | parser.add_argument("--corrupt_rate", type=float, default=0.9, help="the corruption rate that is used to represent the paragraph in the corruption module.") 115 | parser.add_argument("--shuffle_review_words", type=str2bool, nargs='?',const=True,default=True,help="shuffle review words before collecting sliding words.") 116 | parser.add_argument("--train_review_only", type=str2bool, nargs='?',const=True,default=True,help="whether to use only reviews from the training set to represent users and items.") 117 | parser.add_argument("--max_train_epoch", type=int, default=20, 118 | help="Limit on the epochs of training (0: no limit).") 119 | parser.add_argument("--train_pv_epoch", type=int, default=0, 120 | help="Limit on the epochs of training pv (0: do not train according to pv loss).") 121 | parser.add_argument("--start_epoch", type=int, default=0, 122 | help="the epoch where we start training.") 123 | parser.add_argument("--steps_per_checkpoint", type=int, default=200, 124 | help="How many training steps to do per checkpoint.") 125 | parser.add_argument("--neg_per_pos", type=int, default=5, 126 | help="How many negative samples are used to pair with positive results.") 127 | parser.add_argument("--sparse_emb", action='store_true', 128 | help="use sparse embedding or not.") 129 | parser.add_argument("--scale_grad", action='store_true', 130 | help="scale the gradient of word and av embeddings.") 131 | parser.add_argument("-nw", "--weight_distort",
action='store_true', 132 | help="Set to True to use 0.75 power to redistribute for neg sampling .") 133 | parser.add_argument("--mode", type=str, default="train", choices=["train", "valid", "test"]) 134 | parser.add_argument("--rank_cutoff", type=int, default=100, 135 | help="Rank cutoff for output ranklists.") 136 | parser.add_argument('--device', default='cuda', choices=['cpu', 'cuda'], help="use CUDA or cpu") 137 | return parser.parse_args() 138 | 139 | model_flags = ['embedding_size', 'ff_size', 'heads', 'inter_layers','review_encoder_name','query_encoder_name'] 140 | 141 | def create_model(args, global_data, prod_data, load_path=''): 142 | """Create translation model and initialize or load parameters in session.""" 143 | if args.model_name == "review_transformer": 144 | model = ProductRanker(args, args.device, global_data.vocab_size, 145 | global_data.review_count, global_data.product_size, global_data.user_size, 146 | global_data.review_words, global_data.words, word_dists=prod_data.word_dists) 147 | else: 148 | model = ItemTransformerRanker(args, args.device, global_data.vocab_size, 149 | global_data.product_size, global_data.words, word_dists=prod_data.word_dists) 150 | 151 | if os.path.exists(load_path): 152 | #if load_path != '': 153 | logger.info('Loading checkpoint from %s' % load_path) 154 | checkpoint = torch.load(load_path, 155 | map_location=lambda storage, loc: storage) 156 | opt = vars(checkpoint['opt']) 157 | for k in opt.keys(): 158 | if (k in model_flags): 159 | setattr(args, k, opt[k]) 160 | args.start_epoch = checkpoint['epoch'] 161 | model.load_cp(checkpoint) 162 | optim = build_optim(args, model, checkpoint) 163 | else: 164 | logger.info('No available model to load. Build new model.') 165 | optim = build_optim(args, model, None) 166 | logger.info(model) 167 | return model, optim 168 | 169 | def train(args): 170 | args.start_epoch = 0 171 | logger.info('Device %s' % args.device) 172 | 173 | torch.manual_seed(args.seed) 174 | random.seed(args.seed) 175 | torch.backends.cudnn.deterministic = True 176 | if args.device == "cuda": 177 | torch.cuda.manual_seed(args.seed) 178 | 179 | global_data = GlobalProdSearchData(args, args.data_dir, args.input_train_dir) 180 | train_prod_data = ProdSearchData(args, args.input_train_dir, "train", global_data) 181 | #subsampling has been done in train_prod_data 182 | model, optim = create_model(args, global_data, train_prod_data, args.train_from) 183 | trainer = Trainer(args, model, optim) 184 | valid_prod_data = ProdSearchData(args, args.input_train_dir, "valid", global_data) 185 | best_checkpoint_path = trainer.train(trainer.args, global_data, train_prod_data, valid_prod_data) 186 | test_prod_data = ProdSearchData(args, args.input_train_dir, "test", global_data) 187 | best_model, _ = create_model(args, global_data, train_prod_data, best_checkpoint_path) 188 | del trainer 189 | torch.cuda.empty_cache() 190 | trainer = Trainer(args, best_model, None) 191 | trainer.test(args, global_data, test_prod_data, args.rankfname) 192 | 193 | def validate(args): 194 | cp_files = sorted(glob.glob(os.path.join(args.save_dir, 'model_epoch_*.ckpt'))) 195 | global_data = GlobalProdSearchData(args, args.data_dir, args.input_train_dir) 196 | valid_prod_data = ProdSearchData(args, args.input_train_dir, "valid", global_data) 197 | #valid_prod_data = ProdSearchData(args, args.input_train_dir, "test", global_data) 198 | valid_dataset = ProdSearchDataset(args, global_data, valid_prod_data) 199 | best_mrr, best_model = 0, None 200 | for cur_model_file in 
cp_files: 201 | #logger.info("Loading {}".format(cur_model_file)) 202 | cur_model, _ = create_model(args, global_data, valid_prod_data, cur_model_file) 203 | trainer = Trainer(args, cur_model, None) 204 | mrr, prec = trainer.validate(args, global_data, valid_dataset) 205 | logger.info("MRR:{} P@1:{} Model:{}".format(mrr, prec, cur_model_file)) 206 | if mrr > best_mrr: 207 | best_mrr = mrr 208 | best_model = cur_model_file 209 | 210 | test_prod_data = ProdSearchData(args, args.input_train_dir, "test", global_data) 211 | 212 | best_model, _ = create_model(args, global_data, test_prod_data, best_model) 213 | trainer = Trainer(args, best_model, None) 214 | trainer.test(args, global_data, test_prod_data, args.rankfname) 215 | 216 | def get_product_scores(args): 217 | global_data = GlobalProdSearchData(args, args.data_dir, args.input_train_dir) 218 | test_prod_data = ProdSearchData(args, args.input_train_dir, "test", global_data) 219 | model_path = os.path.join(args.save_dir, 'model_best.ckpt') 220 | best_model, _ = create_model(args, global_data, test_prod_data, model_path) 221 | trainer = Trainer(args, best_model, None) 222 | trainer.test(args, global_data, test_prod_data, args.rankfname) 223 | 224 | def main(args): 225 | if not os.path.isdir(args.save_dir): 226 | os.makedirs(args.save_dir) 227 | init_logger(os.path.join(args.save_dir, args.log_file)) 228 | logger.info(args) 229 | if args.mode == "train": 230 | train(args) 231 | elif args.mode == "valid": 232 | validate(args) 233 | else: 234 | get_product_scores(args) 235 | if __name__ == '__main__': 236 | main(parse_args()) 237 | -------------------------------------------------------------------------------- /models/PV.py: -------------------------------------------------------------------------------- 1 | """Encode reviews. It can be: 2 | 1) Read from previously trained paragraph vectors. 3 | 2) From word embeddings [avg, projected weight avg, or CNN, RNN] 4 | 3) Train embedding jointly with the loss of purchases 5 | review_id, a group of words in the review (random -> PV with corruption; in order -> PV) 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.text_encoder import get_vector_mean 11 | from others.util import load_pretrain_embeddings 12 | 13 | import argparse 14 | 15 | class ParagraphVector(nn.Module): 16 | def __init__(self, word_embeddings, word_dists, review_count, 17 | dropout=0.0, pretrain_emb_path=None, fix_emb=False): 18 | super(ParagraphVector, self).__init__() 19 | self.word_embeddings = word_embeddings 20 | self.fix_emb = fix_emb 21 | self.dropout_ = dropout 22 | self.word_dists = word_dists 23 | self._embedding_size = self.word_embeddings.weight.size()[-1] 24 | self.review_count = review_count 25 | self.review_pad_idx = review_count-1 26 | self.pretrain_emb_path = pretrain_emb_path 27 | if pretrain_emb_path is not None: 28 | _, pretrained_weights = load_pretrain_embeddings(pretrain_emb_path) 29 | pretrained_weights.append([0. 
for _ in range(self._embedding_size)]) 30 | pretrained_weights = torch.FloatTensor(pretrained_weights) 31 | self.review_embeddings = nn.Embedding.from_pretrained(pretrained_weights) 32 | #, scale_grad_by_freq = scale_grad, sparse=self.is_emb_sparse 33 | else: 34 | self.review_embeddings = nn.Embedding( 35 | self.review_count, self._embedding_size, padding_idx=self.review_pad_idx) 36 | if self.fix_emb: 37 | self.review_embeddings.weight.requires_grad = False 38 | self.dropout_ = 0 39 | self.drop_layer = nn.Dropout(p=self.dropout_) 40 | self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(reduction='none')#by default it's mean 41 | 42 | @property 43 | def embedding_size(self): 44 | return self._embedding_size 45 | 46 | def get_para_vector(self, review_ids): 47 | review_emb = self.review_embeddings(review_ids) 48 | return review_emb 49 | 50 | def forward(self, review_ids, review_word_emb, review_word_mask, n_negs): 51 | batch_size, pv_window_size, embedding_size = review_word_emb.size() 52 | #for each target word, there is k words negative sampling 53 | review_emb = self.review_embeddings(review_ids) 54 | review_emb = self.drop_layer(review_emb) 55 | #vocab_size = self.word_embeddings.weight.size() - 1 56 | #compute the loss of review generating positive and negative words 57 | neg_sample_idxs = torch.multinomial(self.word_dists, batch_size * pv_window_size * n_negs, replacement=True) 58 | neg_sample_emb = self.word_embeddings(neg_sample_idxs.view(batch_size,-1)) 59 | output_pos = torch.bmm(review_word_emb, review_emb.unsqueeze(2)) # batch_size, pv_window_size, 1 60 | output_neg = torch.bmm(neg_sample_emb, review_emb.unsqueeze(2)).view(batch_size, pv_window_size, -1) 61 | scores = torch.cat((output_pos, output_neg), dim=-1) #batch_size, pv_window_size, 1+n_negs 62 | target = torch.cat((torch.ones(output_pos.size(), device=scores.device), 63 | torch.zeros(output_neg.size(), device=scores.device)), dim=-1) 64 | loss = self.bce_logits_loss(scores, target).sum(-1) #batch_size, pv_window_size 65 | loss = get_vector_mean(loss.unsqueeze(-1), review_word_mask) 66 | #cuda.longtensor 67 | #negative sampling according to x^0.75 68 | #each word has n_neg corresponding samples 69 | ''' 70 | oloss = torch.bmm(review_word_emb, review_emb.unsqueeze(2)).squeeze(-1) 71 | nloss = torch.bmm(neg_sample_emb.neg(), review_emb.unsqueeze(2)).squeeze(-1) 72 | nloss = nloss.view(batch_size, pv_window_size, -1) 73 | oloss = oloss.sigmoid().log() #batch_size, pv_window_size 74 | nloss = nloss.sigmoid().log().sum(2)# batch_size, pv_window_size#(n_negs->1) 75 | loss = -(nloss + oloss) # * review_word_mask.float() 76 | loss = get_vector_mean(loss.unsqueeze(-1), review_word_mask) 77 | #(batch_size, ) 78 | #loss = loss.sum() / review_ids.ne(self.review_pad_idx).float().sum() 79 | ''' 80 | return review_emb, loss 81 | 82 | def initialize_parameters(self, logger=None): 83 | if logger: 84 | logger.info(" ReviewEncoder initialization started.") 85 | #otherwise, load pretrained embeddings 86 | if self.pretrain_emb_path is None: 87 | nn.init.normal_(self.review_embeddings.weight) 88 | 89 | if logger: 90 | logger.info(" ReviewEncoder initialization finished.") 91 | 92 | -------------------------------------------------------------------------------- /models/PVC.py: -------------------------------------------------------------------------------- 1 | """Encode reviews. It can be: 2 | 1) Read from previously trained paragraph vectors. 
3 | 2) From word embeddings [avg, projected weight avg, or CNN, RNN] 4 | 3) Train embedding jointly with the loss of purchases 5 | review_id, a group of words in the review (random -> PV with corruption; in order -> PV) 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.text_encoder import get_vector_mean 11 | from others.util import load_pretrain_embeddings 12 | 13 | import argparse 14 | 15 | class ParagraphVectorCorruption(nn.Module): 16 | def __init__(self, word_embeddings, word_dists, corrupt_rate, 17 | dropout=0.0, pretrain_emb_path=None, vocab_words=None, fix_emb=False): 18 | super(ParagraphVectorCorruption, self).__init__() 19 | self.word_embeddings = word_embeddings 20 | self.word_dists = word_dists 21 | self._embedding_size = self.word_embeddings.weight.size()[-1] 22 | vocab_size = self.word_embeddings.weight.size()[0] 23 | self.word_pad_idx = vocab_size - 1 24 | if pretrain_emb_path is not None and vocab_words is not None: 25 | word_index_dic, pretrained_weights = load_pretrain_embeddings(pretrain_emb_path) 26 | word_indices = torch.tensor([0] + [word_index_dic[x] for x in vocab_words[1:]] + [self.word_pad_idx]) 27 | pretrained_weights = torch.FloatTensor(pretrained_weights) 28 | self.context_embeddings = nn.Embedding.from_pretrained(pretrained_weights[word_indices], padding_idx=self.word_pad_idx) 29 | else: 30 | self.context_embeddings = self.word_embeddings 31 | #self.context_embeddings = nn.Embedding( 32 | # vocab_size, self._embedding_size, padding_idx=self.word_pad_idx) 33 | if fix_emb: 34 | self.context_embeddings.weight.requires_grad = False 35 | self.dropout_ = 0 36 | self.corrupt_rate = corrupt_rate 37 | self.train_corrupt_rate = corrupt_rate 38 | self.dropout_ = dropout 39 | self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(reduction='none')#by default it's mean 40 | #vocab_size - 1 41 | 42 | @property 43 | def embedding_size(self): 44 | return self._embedding_size 45 | 46 | def apply_token_dropout(self, inputs, drop_prob): 47 | #randomly dropout some token in the review 48 | #batch_size, review_word_count, embedding_size 49 | probs = inputs.data.new().resize_(inputs.size()[:-1]).fill_(drop_prob) 50 | #batch_size, review_word_count 51 | mask = torch.bernoulli(probs).byte().unsqueeze(-1) #the probability of drawing 1 52 | #batch_size, review_word_count, 1 53 | inputs.data.masked_fill_(mask, 0).mul_(1./(1-drop_prob)) #drop_prob to fill data to 0 54 | return inputs 55 | 56 | def get_para_vector(self, prod_rword_idxs_pvc): 57 | pvc_word_emb = self.context_embeddings(prod_rword_idxs_pvc) 58 | if self.corrupt_rate > 0.: 59 | self.apply_token_dropout(pvc_word_emb, self.corrupt_rate) 60 | review_emb = get_vector_mean(pvc_word_emb, prod_rword_idxs_pvc.ne(self.word_pad_idx)) 61 | return review_emb 62 | 63 | def set_to_evaluation_mode(self): 64 | self.corrupt_rate = 0. 
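# Hypothetical usage sketch (variable names are illustrative, not from the repo):
# switching to evaluation mode disables token dropout, so get_para_vector
# returns a deterministic mean of the context word embeddings, e.g.
#   pvc_encoder.set_to_evaluation_mode()
#   review_emb = pvc_encoder.get_para_vector(prod_rword_idxs_pvc)
#   pvc_encoder.set_to_train_mode()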
65 | 66 | def set_to_train_mode(self): 67 | self.corrupt_rate = self.train_corrupt_rate 68 | 69 | def forward(self, review_word_emb, review_word_mask, prod_rword_idxs_pvc, n_negs): 70 | ''' 71 | prod_rword_idxs_pvc: batch_size (real_batch_size * review_count), review_word_limit 72 | review_word_emb: batch_size * reivew_count, embedding_size 73 | review_word_mask: indicate which target is valid 74 | ''' 75 | batch_size, pv_window_size, embedding_size = review_word_emb.size() 76 | pvc_word_emb = self.context_embeddings(prod_rword_idxs_pvc) 77 | review_emb = get_vector_mean(pvc_word_emb, prod_rword_idxs_pvc.ne(self.word_pad_idx)) 78 | self.apply_token_dropout(pvc_word_emb, self.corrupt_rate) 79 | corr_review_emb = get_vector_mean(pvc_word_emb, prod_rword_idxs_pvc.ne(self.word_pad_idx)) 80 | 81 | #for each target word, there is k words negative sampling 82 | #compute the loss of review generating positive and negative words 83 | neg_sample_idxs = torch.multinomial(self.word_dists, batch_size * pv_window_size * n_negs, replacement=True) 84 | neg_sample_emb = self.word_embeddings(neg_sample_idxs.view(batch_size, -1)) 85 | output_pos = torch.bmm(review_word_emb, corr_review_emb.unsqueeze(2)) # batch_size, pv_window_size, 1 86 | output_neg = torch.bmm(neg_sample_emb, corr_review_emb.unsqueeze(2)).view(batch_size, pv_window_size, -1) 87 | scores = torch.cat((output_pos, output_neg), dim=-1) #batch_size, pv_window_size, 1+n_negs 88 | target = torch.cat((torch.ones(output_pos.size(), device=scores.device), 89 | torch.zeros(output_neg.size(), device=scores.device)), dim=-1) 90 | loss = self.bce_logits_loss(scores, target).sum(-1) #batch_size, pv_window_size 91 | loss = get_vector_mean(loss.unsqueeze(-1), review_word_mask) 92 | 93 | #cuda.longtensor 94 | #negative sampling according to x^0.75 95 | return review_emb, loss 96 | 97 | def forward_deprecated(self, review_word_emb, review_word_mask, prod_rword_idxs_pvc, rand_prod_rword_idxs, n_negs): 98 | ''' rand_prod_rword_idxs: batch_size, review_count, pv_window_size * pvc_word_count 99 | prod_rword_idxs_pvc: batch_size, review_count, review_word_limit 100 | review_word_mask: indicate which target is valid 101 | ''' 102 | batch_size, pv_window_size, embedding_size = review_word_emb.size() 103 | _,_,word_count = prod_rword_idxs_pvc.size() 104 | pvc_word_count = word_count / pv_window_size 105 | 106 | rand_word_emb = self.word_embeddings(rand_prod_rword_idxs.view(-1, pvc_word_count)) 107 | corr_review_vector = get_vector_mean(rand_word_emb, rand_prod_rword_idxs.ne(self.word_pad_idx)) 108 | corr_review_vector = corr_review_vector.view(-1, embedding_size, 1) 109 | 110 | #for each target word, there is k words negative sampling 111 | vocab_size = word_embeddings.weight.size() - 1 112 | #compute the loss of review generating positive and negative words 113 | neg_sample_idxs = torch.multinomial(self.word_dists, batch_size * pv_window_size * n_negs, replacement=True) 114 | neg_sample_emb = self.word_embeddings(neg_sample_idxs) 115 | #cuda.longtensor 116 | #negative sampling according to x^0.75 117 | #each word has n_neg corresponding samples 118 | target_emb = review_word_emb.view(batch_size*pv_window_size, embedding_size).unsqueeze(1) 119 | oloss = torch.bmm(target_emb, corr_review_vector).squeeze(-1).squeeze(-1).view(batch_size, -1) 120 | nloss = torch.bmm(neg_sample_emb.unsqueeze(1).neg(), corr_review_vector).squeeze(-1).squeeze(-1) 121 | nloss = nloss.view(batch_size, pv_window_size, -1) 122 | oloss = oloss.sigmoid().log() #batch_size, pv_window_size 123 | 
nloss = nloss.sigmoid().log().sum(2)# batch_size, pv_window_size#(n_negs->1) 124 | loss = -(nloss + oloss) #* review_word_mask.float() 125 | loss = get_vector_mean(loss.unsqueeze(-1), review_word_mask) 126 | #(batch_size, ) 127 | #loss = get_vector_mean(loss.unsqueeze(-1), review_ids.ne(self.review_pad_idx)) 128 | #loss = loss.mean() 129 | _,rcount, review_word_limit = prod_rword_idxs_pvc.size() 130 | pvc_word_emb = self.word_embeddings(prod_rword_idxs_pvc.view(-1, review_word_limit)) 131 | review_emb = get_vector_mean(pvc_word_emb, prod_rword_idxs_pvc.ne(self.word_pad_idx)) 132 | 133 | return review_emb.view(-1, rcount, embedding_size), loss 134 | 135 | def initialize_parameters(self, logger=None): 136 | if logger: 137 | logger.info(" Another group of embeddings initialization started.") 138 | #otherwise, load pretrained embeddings 139 | #nn.init.normal_(self.review_embeddings.weight) 140 | if logger: 141 | logger.info(" Another group of embeddings initialization finished.") 142 | 143 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kepingbi/ProdSearch/449335ba652fe7c877a008e154157d7b2a4b0e76/models/__init__.py -------------------------------------------------------------------------------- /models/neural.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def gelu(x): 8 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 9 | 10 | 11 | class PositionwiseFeedForward(nn.Module): 12 | """ A two-layer Feed-Forward-Network with residual layer norm. 13 | 14 | Args: 15 | d_model (int): the size of input for the first-layer of the FFN. 16 | d_ff (int): the hidden layer size of the second-layer 17 | of the FNN. 18 | dropout (float): dropout probability in :math:`[0, 1)`. 19 | """ 20 | 21 | def __init__(self, d_model, d_ff, dropout=0.1): 22 | super(PositionwiseFeedForward, self).__init__() 23 | self.w_1 = nn.Linear(d_model, d_ff) 24 | self.w_2 = nn.Linear(d_ff, d_model) 25 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 26 | self.actv = gelu 27 | self.dropout_1 = nn.Dropout(dropout) 28 | self.dropout_2 = nn.Dropout(dropout) 29 | 30 | def forward(self, x): 31 | inter = self.dropout_1(self.actv(self.w_1(self.layer_norm(x)))) 32 | output = self.dropout_2(self.w_2(inter)) 33 | return output + x 34 | 35 | 36 | class MultiHeadedAttention(nn.Module): 37 | """ 38 | Multi-Head Attention module from 39 | "Attention is All You Need" 40 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. 41 | 42 | Similar to standard `dot` attention but uses 43 | multiple attention distributions simulataneously 44 | to select relevant items. 45 | 46 | .. mermaid:: 47 | 48 | graph BT 49 | A[key] 50 | B[value] 51 | C[query] 52 | O[output] 53 | subgraph Attn 54 | D[Attn 1] 55 | E[Attn 2] 56 | F[Attn N] 57 | end 58 | A --> D 59 | C --> D 60 | A --> E 61 | C --> E 62 | A --> F 63 | C --> F 64 | D --> O 65 | E --> O 66 | F --> O 67 | B --> O 68 | 69 | Also includes several additional tricks. 
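A minimal self-attention usage sketch (shapes are illustrative only)::

    attn = MultiHeadedAttention(head_count=8, model_dim=512)
    x = torch.rand(2, 10, 512)   # [batch, seq_len, model_dim]
    out = attn(x, x, x)          # context vectors, [2, 10, 512]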
70 | 71 | Args: 72 | head_count (int): number of parallel heads 73 | model_dim (int): the dimension of keys/values/queries, 74 | must be divisible by head_count 75 | dropout (float): dropout parameter 76 | """ 77 | 78 | def __init__(self, head_count, model_dim, dropout=0.1, use_final_linear=True): 79 | assert model_dim % head_count == 0 80 | self.dim_per_head = model_dim // head_count 81 | self.model_dim = model_dim 82 | 83 | super(MultiHeadedAttention, self).__init__() 84 | self.head_count = head_count 85 | 86 | self.linear_keys = nn.Linear(model_dim, 87 | head_count * self.dim_per_head) 88 | self.linear_values = nn.Linear(model_dim, 89 | head_count * self.dim_per_head) 90 | self.linear_query = nn.Linear(model_dim, 91 | head_count * self.dim_per_head) 92 | self.softmax = nn.Softmax(dim=-1) 93 | self.dropout = nn.Dropout(dropout) 94 | self.use_final_linear = use_final_linear 95 | if (self.use_final_linear): 96 | self.final_linear = nn.Linear(model_dim, model_dim) 97 | 98 | def forward(self, key, value, query, mask=None, 99 | layer_cache=None, type=None, predefined_graph_1=None): 100 | """ 101 | Compute the context vector and the attention vectors. 102 | 103 | Args: 104 | key (`FloatTensor`): set of `key_len` 105 | key vectors `[batch, key_len, dim]` 106 | value (`FloatTensor`): set of `key_len` 107 | value vectors `[batch, key_len, dim]` 108 | query (`FloatTensor`): set of `query_len` 109 | query vectors `[batch, query_len, dim]` 110 | mask: binary mask indicating which keys have 111 | non-zero attention `[batch, query_len, key_len]` 112 | Returns: 113 | (`FloatTensor`, `FloatTensor`) : 114 | 115 | * output context vectors `[batch, query_len, dim]` 116 | * one of the attention vectors `[batch, query_len, key_len]` 117 | """ 118 | 119 | # CHECKS 120 | # batch, k_len, d = key.size() 121 | # batch_, k_len_, d_ = value.size() 122 | # aeq(batch, batch_) 123 | # aeq(k_len, k_len_) 124 | # aeq(d, d_) 125 | # batch_, q_len, d_ = query.size() 126 | # aeq(batch, batch_) 127 | # aeq(d, d_) 128 | # aeq(self.model_dim % 8, 0) 129 | # if mask is not None: 130 | # batch_, q_len_, k_len_ = mask.size() 131 | # aeq(batch_, batch) 132 | # aeq(k_len_, k_len) 133 | # aeq(q_len_ == q_len) 134 | # END CHECKS 135 | 136 | batch_size = key.size(0) 137 | dim_per_head = self.dim_per_head 138 | head_count = self.head_count 139 | key_len = key.size(1) 140 | query_len = query.size(1) 141 | 142 | def shape(x): 143 | """ projection """ 144 | return x.view(batch_size, -1, head_count, dim_per_head) \ 145 | .transpose(1, 2) 146 | 147 | def unshape(x): 148 | """ compute context """ 149 | return x.transpose(1, 2).contiguous() \ 150 | .view(batch_size, -1, head_count * dim_per_head) 151 | 152 | # 1) Project key, value, and query. 
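# Added note: when layer_cache is provided, previously projected keys and
# values are reused rather than recomputed. For type == "self", the current
# query is projected to query/key/value and the new key/value are appended to
# the cached "self_keys"/"self_values"; for type == "context", the memory is
# projected once and read back from "memory_keys"/"memory_values" afterwards.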
153 | if layer_cache is not None: 154 | if type == "self": 155 | query, key, value = self.linear_query(query), \ 156 | self.linear_keys(query), \ 157 | self.linear_values(query) 158 | 159 | key = shape(key) 160 | value = shape(value) 161 | 162 | if layer_cache is not None: 163 | device = key.device 164 | if layer_cache["self_keys"] is not None: 165 | key = torch.cat( 166 | (layer_cache["self_keys"].to(device), key), 167 | dim=2) 168 | if layer_cache["self_values"] is not None: 169 | value = torch.cat( 170 | (layer_cache["self_values"].to(device), value), 171 | dim=2) 172 | layer_cache["self_keys"] = key 173 | layer_cache["self_values"] = value 174 | elif type == "context": 175 | query = self.linear_query(query) 176 | if layer_cache is not None: 177 | if layer_cache["memory_keys"] is None: 178 | key, value = self.linear_keys(key), \ 179 | self.linear_values(value) 180 | key = shape(key) 181 | value = shape(value) 182 | else: 183 | key, value = layer_cache["memory_keys"], \ 184 | layer_cache["memory_values"] 185 | layer_cache["memory_keys"] = key 186 | layer_cache["memory_values"] = value 187 | else: 188 | key, value = self.linear_keys(key), \ 189 | self.linear_values(value) 190 | key = shape(key) 191 | value = shape(value) 192 | else: 193 | key = self.linear_keys(key) 194 | value = self.linear_values(value) 195 | query = self.linear_query(query) 196 | key = shape(key) 197 | value = shape(value) 198 | 199 | query = shape(query) 200 | #batch_size, head_count, max_sent_count, dim_per_head 201 | 202 | key_len = key.size(2) 203 | query_len = query.size(2) 204 | 205 | # 2) Calculate and scale scores. 206 | query = query / math.sqrt(dim_per_head) 207 | scores = torch.matmul(query, key.transpose(2, 3)) 208 | #batch_size, head_count, query_len, key_len 209 | 210 | if mask is not None: 211 | #mask is (batch_size, 1, key_len) 212 | if len(mask.size()) == 3: 213 | mask = mask.unsqueeze(1).expand_as(scores) 214 | scores = scores.masked_fill(mask, -1e18) 215 | 216 | # 3) Apply attention dropout and compute context vectors. 
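# Added note: the softmax below normalizes the scores over key_len. If
# predefined_graph_1 is given, the last head's attention is reweighted by that
# graph and renormalized before dropout. The context is the attention-weighted
# sum of the value vectors, passed through final_linear when use_final_linear
# is True (otherwise the per-head context is returned directly).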
217 | 218 | attn = self.softmax(scores) 219 | 220 | if (not predefined_graph_1 is None): 221 | attn_masked = attn[:, -1] * predefined_graph_1 222 | attn_masked = attn_masked / (torch.sum(attn_masked, 2).unsqueeze(2) + 1e-9) 223 | 224 | attn = torch.cat([attn[:, :-1], attn_masked.unsqueeze(1)], 1) 225 | 226 | drop_attn = self.dropout(attn) 227 | if (self.use_final_linear): 228 | context = unshape(torch.matmul(drop_attn, value)) 229 | #batch_size, query_len, head_count * dim_per_head 230 | output = self.final_linear(context) 231 | return output 232 | else: 233 | context = torch.matmul(drop_attn, value) 234 | return context 235 | #batch_size, head_count, query_len, dim_per_head 236 | 237 | # CHECK 238 | # batch_, q_len_, d_ = output.size() 239 | # aeq(q_len, q_len_) 240 | # aeq(batch, batch_) 241 | # aeq(d, d_) 242 | 243 | # Return one attn 244 | 245 | -------------------------------------------------------------------------------- /models/optimizers.py: -------------------------------------------------------------------------------- 1 | """ Optimizers class """ 2 | import torch 3 | import torch.optim as optim 4 | from torch.nn.utils import clip_grad_norm_ 5 | 6 | 7 | # from onmt.utils import use_gpu 8 | 9 | 10 | def use_gpu(opt): 11 | """ 12 | Creates a boolean if gpu used 13 | """ 14 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 15 | (hasattr(opt, 'gpu') and opt.gpu > -1) 16 | 17 | def build_optim(model, opt, checkpoint): 18 | """ Build optimizer """ 19 | saved_optimizer_state_dict = None 20 | 21 | if opt.train_from: 22 | optim = checkpoint['optim'] 23 | # We need to save a copy of optim.optimizer.state_dict() for setting 24 | # the, optimizer state later on in Stage 2 in this method, since 25 | # the method optim.set_parameters(model.parameters()) will overwrite 26 | # optim.optimizer, and with ith the values stored in 27 | # optim.optimizer.state_dict() 28 | saved_optimizer_state_dict = optim.optimizer.state_dict() 29 | else: 30 | optim = Optimizer( 31 | opt.optim, opt.learning_rate, opt.max_grad_norm, 32 | lr_decay=opt.learning_rate_decay, 33 | start_decay_steps=opt.start_decay_steps, 34 | decay_steps=opt.decay_steps, 35 | beta1=opt.adam_beta1, 36 | beta2=opt.adam_beta2, 37 | adagrad_accum=opt.adagrad_accumulator_init, 38 | decay_method=opt.decay_method, 39 | warmup_steps=opt.warmup_steps) 40 | 41 | # Stage 1: 42 | # Essentially optim.set_parameters (re-)creates and optimizer using 43 | # model.paramters() as parameters that will be stored in the 44 | # optim.optimizer.param_groups field of the torch optimizer class. 45 | # Importantly, this method does not yet load the optimizer state, as 46 | # essentially it builds a new optimizer with empty optimizer state and 47 | # parameters from the model. 48 | optim.set_parameters(model.named_parameters()) 49 | 50 | if opt.train_from: 51 | # Stage 2: In this stage, which is only performed when loading an 52 | # optimizer from a checkpoint, we load the saved_optimizer_state_dict 53 | # into the re-created optimizer, to set the optim.optimizer.state 54 | # field, which was previously empty. For this, we use the optimizer 55 | # state saved in the "saved_optimizer_state_dict" variable for 56 | # this purpose. 
57 | # See also: https://github.com/pytorch/pytorch/issues/2830 58 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 59 | # Convert back the state values to cuda type if applicable 60 | if use_gpu(opt): 61 | for state in optim.optimizer.state.values(): 62 | for k, v in state.items(): 63 | if torch.is_tensor(v): 64 | state[k] = v.cuda() 65 | 66 | # We want to make sure that indeed we have a non-empty optimizer state 67 | # when we loaded an existing model. This should be at least the case 68 | # for Adam, which saves "exp_avg" and "exp_avg_sq" state 69 | # (Exponential moving average of gradient and squared gradient values) 70 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 71 | raise RuntimeError( 72 | "Error: loaded Adam optimizer from existing model" + 73 | " but optimizer state is empty") 74 | 75 | return optim 76 | 77 | 78 | class MultipleOptimizer(object): 79 | """ Implement multiple optimizers needed for sparse adam """ 80 | 81 | def __init__(self, op): 82 | """ ? """ 83 | self.optimizers = op 84 | 85 | def zero_grad(self): 86 | """ ? """ 87 | for op in self.optimizers: 88 | op.zero_grad() 89 | 90 | def step(self): 91 | """ ? """ 92 | for op in self.optimizers: 93 | op.step() 94 | 95 | @property 96 | def state(self): 97 | """ ? """ 98 | return {k: v for op in self.optimizers for k, v in op.state.items()} 99 | 100 | def state_dict(self): 101 | """ ? """ 102 | return [op.state_dict() for op in self.optimizers] 103 | 104 | def load_state_dict(self, state_dicts): 105 | """ ? """ 106 | assert len(state_dicts) == len(self.optimizers) 107 | for i in range(len(state_dicts)): 108 | self.optimizers[i].load_state_dict(state_dicts[i]) 109 | 110 | 111 | class Optimizer(object): 112 | """ 113 | Controller class for optimization. Mostly a thin 114 | wrapper for `optim`, but also useful for implementing 115 | rate scheduling beyond what is currently available. 116 | Also implements necessary methods for training RNNs such 117 | as grad manipulations. 118 | 119 | Args: 120 | method (:obj:`str`): one of [sgd, adagrad, adadelta, adam] 121 | lr (float): learning rate 122 | lr_decay (float, optional): learning rate decay multiplier 123 | start_decay_steps (int, optional): step to start learning rate decay 124 | beta1, beta2 (float, optional): parameters for adam 125 | adagrad_accum (float, optional): initialization parameter for adagrad 126 | decay_method (str, option): custom decay options 127 | warmup_steps (int, option): parameter for `noam` decay 128 | 129 | We use the default parameters for Adam that are suggested by 130 | the original paper https://arxiv.org/pdf/1412.6980.pdf 131 | These values are also used by other established implementations, 132 | e.g. https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer 133 | https://keras.io/optimizers/ 134 | Recently there are slightly different values used in the paper 135 | "Attention is all you need" 136 | https://arxiv.org/pdf/1706.03762.pdf, particularly the value beta2=0.98 137 | was used there however, beta2=0.999 is still arguably the more 138 | established value, so we use that here as well 139 | """ 140 | 141 | def __init__(self, method, learning_rate, max_grad_norm, 142 | lr_decay=1, start_decay_steps=None, decay_steps=None, 143 | beta1=0.9, beta2=0.999, 144 | adagrad_accum=0.0, 145 | decay_method=None, 146 | warmup_steps=4000, 147 | weight_decay = 0. 
148 | ): 149 | self.last_ppl = None 150 | self.learning_rate = learning_rate 151 | self.original_lr = learning_rate 152 | self.max_grad_norm = max_grad_norm 153 | self.method = method 154 | self.lr_decay = lr_decay 155 | self.start_decay_steps = start_decay_steps 156 | self.decay_steps = decay_steps 157 | self.start_decay = False 158 | self._step = 0 159 | self.betas = [beta1, beta2] 160 | self.adagrad_accum = adagrad_accum 161 | self.decay_method = decay_method 162 | self.warmup_steps = warmup_steps 163 | self.weight_decay = weight_decay 164 | 165 | def set_parameters(self, params): 166 | """ ? """ 167 | self.params = [] 168 | self.sparse_params = [] 169 | for k, p in params: 170 | if p.requires_grad: 171 | if self.method != 'sparseadam' or "embed" not in k: 172 | self.params.append(p) 173 | else: 174 | self.sparse_params.append(p) 175 | if self.method == 'sgd': 176 | self.optimizer = optim.SGD(self.params, lr=self.learning_rate, weight_decay=self.weight_decay) 177 | elif self.method == 'adagrad': 178 | self.optimizer = optim.Adagrad(self.params, lr=self.learning_rate, weight_decay=self.weight_decay) 179 | for group in self.optimizer.param_groups: 180 | for p in group['params']: 181 | self.optimizer.state[p]['sum'] = self.optimizer\ 182 | .state[p]['sum'].fill_(self.adagrad_accum) 183 | elif self.method == 'adadelta': 184 | self.optimizer = optim.Adadelta(self.params, lr=self.learning_rate, weight_decay=self.weight_decay) 185 | elif self.method == 'adam': 186 | self.optimizer = optim.Adam(self.params, lr=self.learning_rate, 187 | betas=self.betas, eps=1e-9, weight_decay=self.weight_decay) 188 | elif self.method == 'sparseadam': 189 | self.optimizer = MultipleOptimizer( 190 | [optim.Adam(self.params, lr=self.learning_rate, 191 | betas=self.betas, eps=1e-8, weight_decay=self.weight_decay), 192 | optim.SparseAdam(self.sparse_params, lr=self.learning_rate, 193 | betas=self.betas, eps=1e-8)]) 194 | else: 195 | raise RuntimeError("Invalid optim method: " + self.method) 196 | 197 | def _set_rate(self, learning_rate): 198 | self.learning_rate = learning_rate 199 | if self.method != 'sparseadam': 200 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 201 | else: 202 | for op in self.optimizer.optimizers: 203 | op.param_groups[0]['lr'] = self.learning_rate 204 | 205 | def step(self): 206 | """Update the model parameters based on current gradients. 207 | 208 | Optionally, will employ gradient modification or update learning 209 | rate. 210 | """ 211 | self._step += 1 212 | 213 | # Decay method used in tensor2tensor. 
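        # "noam" schedule used below: lr = original_lr * min(step^-0.5, step * warmup_steps^-1.5).
        # The two terms cross at step == warmup_steps, so the rate grows roughly linearly
        # during warmup and afterwards decays proportionally to 1/sqrt(step).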
214 | if self.decay_method == "noam": 215 | self._set_rate( 216 | self.original_lr * 217 | 218 | min(self._step ** (-0.5), 219 | self._step * self.warmup_steps**(-1.5))) 220 | 221 | # self._set_rate(self.original_lr *self.model_size ** (-0.5) *min(1.0, self._step / self.warmup_steps)*max(self._step, self.warmup_steps)**(-0.5)) 222 | # Decay based on start_decay_steps every decay_steps 223 | else: 224 | if ((self.start_decay_steps is not None) and ( 225 | self._step >= self.start_decay_steps)): 226 | self.start_decay = True 227 | if self.start_decay: 228 | if ((self._step - self.start_decay_steps) 229 | % self.decay_steps == 0): 230 | self.learning_rate = self.learning_rate * self.lr_decay 231 | 232 | #if self.method != 'sparseadam': 233 | if self.decay_method == "noam" and self.method != 'sparseadam': 234 | #only if we schedule the learning rate with noam decay, we set the learning rate 235 | #otherwise we just use the original one by the optimizer 236 | #revised by Keping 237 | self.optimizer.param_groups[0]['lr'] = self.learning_rate 238 | #print("Number of parameter groups", len(self.optimizer.param_groups)) 239 | #there are only one parameter groups in current model 240 | 241 | if self.max_grad_norm: 242 | clip_grad_norm_(self.params, self.max_grad_norm) 243 | self.optimizer.step() 244 | 245 | 246 | -------------------------------------------------------------------------------- /models/ps_model.py: -------------------------------------------------------------------------------- 1 | """ transformer based on reviews 2 | Q+r_{u1}+r_{u2} <> r_1, r_2 (of a target i) 3 | """ 4 | """ 5 | review_encoder 6 | query_encoder 7 | transformer 8 | """ 9 | import os 10 | import torch 11 | import torch.nn as nn 12 | from models.PV import ParagraphVector 13 | from models.PVC import ParagraphVectorCorruption 14 | from models.text_encoder import AVGEncoder, FSEncoder 15 | from models.transformer import TransformerEncoder 16 | from models.optimizers import Optimizer 17 | from others.logging import logger 18 | from others.util import pad, load_pretrain_embeddings, load_user_item_embeddings 19 | 20 | def build_optim(args, model, checkpoint): 21 | """ Build optimizer """ 22 | saved_optimizer_state_dict = None 23 | 24 | if args.train_from != '' and checkpoint is not None: 25 | optim = checkpoint['optim'] 26 | saved_optimizer_state_dict = optim.optimizer.state_dict() 27 | else: 28 | optim = Optimizer( 29 | args.optim, args.lr, args.max_grad_norm, 30 | beta1=args.beta1, beta2=args.beta2, 31 | decay_method=args.decay_method, 32 | warmup_steps=args.warmup_steps, 33 | weight_decay=args.l2_lambda) 34 | #self.start_decay_steps take effect when decay_method is not noam 35 | 36 | optim.set_parameters(list(model.named_parameters())) 37 | 38 | if args.train_from != '' and checkpoint is not None: 39 | optim.optimizer.load_state_dict(saved_optimizer_state_dict) 40 | if args.device == "cuda": 41 | for state in optim.optimizer.state.values(): 42 | for k, v in state.items(): 43 | if torch.is_tensor(v): 44 | state[k] = v.cuda() 45 | 46 | if (optim.method == 'adam') and (len(optim.optimizer.state) < 1): 47 | raise RuntimeError( 48 | "Error: loaded Adam optimizer from existing model" + 49 | " but optimizer state is empty") 50 | 51 | return optim 52 | 53 | class ProductRanker(nn.Module): 54 | def __init__(self, args, device, vocab_size, review_count, 55 | product_size, user_size, 56 | review_words, vocab_words, word_dists=None): 57 | super(ProductRanker, self).__init__() 58 | self.args = args 59 | self.device = device 60 | 
self.train_review_only = args.train_review_only 61 | self.embedding_size = args.embedding_size 62 | self.vocab_words = vocab_words 63 | self.word_dists = None 64 | if word_dists is not None: 65 | self.word_dists = torch.tensor(word_dists, device=device) 66 | self.prod_pad_idx = product_size 67 | self.user_pad_idx = user_size 68 | self.word_pad_idx = vocab_size - 1 69 | self.seg_pad_idx = 3 70 | self.review_pad_idx = review_count - 1 71 | self.emb_dropout = args.dropout 72 | self.review_encoder_name = args.review_encoder_name 73 | self.fix_emb = args.fix_emb 74 | 75 | padded_review_words = review_words 76 | if not self.args.do_subsample_mask: 77 | #otherwise, review_words should be already padded 78 | padded_review_words = pad(review_words, pad_id=self.word_pad_idx, width=args.review_word_limit) 79 | self.review_words = torch.tensor(padded_review_words, device=device) 80 | 81 | self.pretrain_emb_dir = None 82 | if os.path.exists(args.pretrain_emb_dir): 83 | self.pretrain_emb_dir = args.pretrain_emb_dir 84 | self.pretrain_up_emb_dir = None 85 | if os.path.exists(args.pretrain_up_emb_dir): 86 | self.pretrain_up_emb_dir = args.pretrain_up_emb_dir 87 | self.dropout_layer = nn.Dropout(p=args.dropout) 88 | 89 | if self.args.use_user_emb: 90 | if self.pretrain_up_emb_dir is None: 91 | self.user_emb = nn.Embedding(user_size+1, self.embedding_size, padding_idx=self.user_pad_idx) 92 | else: 93 | pretrain_user_emb_path = os.path.join(self.pretrain_up_emb_dir, "user_emb.txt") 94 | pretrained_weights = load_user_item_embeddings(pretrain_user_emb_path) 95 | pretrained_weights.append([0.] * len(pretrained_weights[0])) 96 | assert len(pretrained_weights[0]) == self.embedding_size 97 | self.user_emb = nn.Embedding.from_pretrained( 98 | torch.FloatTensor(pretrained_weights), padding_idx=self.user_pad_idx) 99 | 100 | if self.args.use_item_emb: 101 | if self.pretrain_up_emb_dir is None: 102 | self.product_emb = nn.Embedding(product_size+1, self.embedding_size, padding_idx=self.prod_pad_idx) 103 | else: 104 | pretrain_product_emb_path = os.path.join(self.pretrain_up_emb_dir, "product_emb.txt") 105 | pretrained_weights = load_user_item_embeddings(pretrain_product_emb_path) 106 | pretrained_weights.append([0.] 
* len(pretrained_weights[0])) 107 | self.product_emb = nn.Embedding.from_pretrained( 108 | torch.FloatTensor(pretrained_weights), padding_idx=self.prod_pad_idx) 109 | 110 | if self.pretrain_emb_dir is not None: 111 | #word_emb_fname = "word_emb.txt.gz" #for query and target words in pv and pvc 112 | word_emb_fname = "context_emb.txt.gz" if args.review_encoder_name == "pvc" else "word_emb.txt.gz" #for query and target words in pv and pvc 113 | pretrain_word_emb_path = os.path.join(self.pretrain_emb_dir, word_emb_fname) 114 | word_index_dic, pretrained_weights = load_pretrain_embeddings(pretrain_word_emb_path) 115 | word_indices = torch.tensor([0] + [word_index_dic[x] for x in self.vocab_words[1:]] + [self.word_pad_idx]) 116 | #print(len(word_indices)) 117 | #print(word_indices.cpu().tolist()) 118 | pretrained_weights = torch.FloatTensor(pretrained_weights) 119 | self.word_embeddings = nn.Embedding.from_pretrained(pretrained_weights[word_indices], padding_idx=self.word_pad_idx) 120 | #vectors of padding idx will not be updated 121 | else: 122 | self.word_embeddings = nn.Embedding( 123 | vocab_size, self.embedding_size, padding_idx=self.word_pad_idx) 124 | 125 | if self.fix_emb and args.review_encoder_name == "pvc": 126 | #if review embeddings are fixed, just load the aggregated embeddings which include all the words in the review 127 | #otherwise the reviews are cut off at review_word_limit 128 | self.review_encoder_name = "pv" 129 | 130 | self.transformer_encoder = TransformerEncoder( 131 | self.embedding_size, args.ff_size, args.heads, 132 | args.dropout, args.inter_layers) 133 | 134 | if self.review_encoder_name == "pv": 135 | pretrain_emb_path = None 136 | if self.pretrain_emb_dir is not None: 137 | pretrain_emb_path = os.path.join(self.pretrain_emb_dir, "doc_emb.txt.gz") 138 | self.review_encoder = ParagraphVector( 139 | self.word_embeddings, self.word_dists, 140 | review_count, self.emb_dropout, pretrain_emb_path, fix_emb=self.fix_emb) 141 | elif self.review_encoder_name == "pvc": 142 | pretrain_emb_path = None 143 | #if self.pretrain_emb_dir is not None: 144 | # pretrain_emb_path = os.path.join(self.pretrain_emb_dir, "context_emb.txt.gz") 145 | self.review_encoder = ParagraphVectorCorruption( 146 | self.word_embeddings, self.word_dists, args.corrupt_rate, 147 | self.emb_dropout, pretrain_emb_path, self.vocab_words, fix_emb=self.fix_emb) 148 | elif self.review_encoder_name == "fs": 149 | self.review_encoder = FSEncoder(self.embedding_size, self.emb_dropout) 150 | else: 151 | self.review_encoder = AVGEncoder(self.embedding_size, self.emb_dropout) 152 | 153 | if args.query_encoder_name == "fs": 154 | self.query_encoder = FSEncoder(self.embedding_size, self.emb_dropout) 155 | else: 156 | self.query_encoder = AVGEncoder(self.embedding_size, self.emb_dropout) 157 | self.seg_embeddings = nn.Embedding(4, self.embedding_size, padding_idx=self.seg_pad_idx) 158 | #for each q,u,i 159 | #Q, previous purchases of u, current available reviews for i, padding value 160 | #self.logsoftmax = torch.nn.LogSoftmax(dim = -1) 161 | #self.bce_logits_loss = torch.nn.BCEWithLogitsLoss(reduction='none')#by default it's mean 162 | 163 | self.review_embeddings = None 164 | if self.fix_emb: 165 | #self.word_embeddings.weight.requires_grad = False 166 | #embeddings of query words need to be update 167 | #self.emb_dropout = 0 168 | self.get_review_embeddings() #get model.review_embeddings 169 | 170 | self.initialize_parameters(logger) #logger 171 | self.to(device) #change model in place 172 | 173 | 174 | def 
load_cp(self, pt, strict=True): 175 | self.load_state_dict(pt['model'], strict=strict) 176 | 177 | def clear_review_embbeddings(self): 178 | #otherwise review_embeddings are always the same 179 | if not self.fix_emb: 180 | self.review_embeddings = None 181 | #del self.review_embeddings 182 | torch.cuda.empty_cache() 183 | 184 | def get_review_embeddings(self, batch_size=128): 185 | if hasattr(self, "review_embeddings") and self.review_embeddings is not None: 186 | return #if already computed and not deleted 187 | if self.review_encoder_name == "pv": 188 | self.review_embeddings = self.review_encoder.review_embeddings.weight 189 | else: 190 | review_count = self.review_pad_idx 191 | seg_count = int((review_count - 1) / batch_size) + 1 192 | self.review_embeddings = torch.zeros(review_count+1, self.embedding_size, device=self.device) 193 | #The last one is always 0 194 | for i in range(seg_count): 195 | slice_reviews = self.review_words[i*batch_size:(i+1)*batch_size] 196 | if self.review_encoder_name == "pvc": 197 | self.review_encoder.set_to_evaluation_mode() 198 | slice_review_emb = self.review_encoder.get_para_vector(slice_reviews) 199 | self.review_encoder.set_to_train_mode() 200 | else: #fs or avg 201 | slice_rword_emb = self.word_embeddings(slice_reviews) 202 | slice_review_emb = self.review_encoder(slice_rword_emb, slice_reviews.ne(self.word_pad_idx)) 203 | self.review_embeddings[i*batch_size:(i+1)*batch_size] = slice_review_emb 204 | 205 | def test(self, batch_data): 206 | query_word_idxs = batch_data.query_word_idxs 207 | candi_prod_ridxs = batch_data.candi_prod_ridxs 208 | candi_seg_idxs = batch_data.candi_seg_idxs 209 | candi_seq_item_idxs = batch_data.candi_seq_item_idxs 210 | candi_seq_user_idxs = batch_data.candi_seq_user_idxs 211 | query_word_emb = self.word_embeddings(query_word_idxs) 212 | query_emb = self.query_encoder(query_word_emb, query_word_idxs.ne(self.word_pad_idx)) 213 | batch_size, candi_k, candi_rcount = candi_prod_ridxs.size() 214 | candi_review_emb = self.review_embeddings[candi_prod_ridxs] 215 | 216 | #concat query_emb with pos_review_emb and candi_review_emb 217 | query_mask = torch.ones(batch_size, 1, dtype=torch.uint8, device=query_word_idxs.device) 218 | candi_prod_ridx_mask = candi_prod_ridxs.ne(self.review_pad_idx) 219 | candi_review_mask = torch.cat([query_mask.unsqueeze(1).expand(-1,candi_k,-1), candi_prod_ridx_mask], dim=2) 220 | #batch_size, 1, embedding_size 221 | candi_sequence_emb = torch.cat( 222 | (query_emb.unsqueeze(1).expand(-1, candi_k, -1).unsqueeze(2), candi_review_emb), dim=2) 223 | #batch_size, candi_k, max_review_count+1, embedding_size 224 | candi_seg_emb = self.seg_embeddings(candi_seg_idxs) #batch_size, candi_k, max_review_count+1, embedding_size 225 | if self.args.use_seg_emb: 226 | candi_sequence_emb += candi_seg_emb 227 | if self.args.use_user_emb: 228 | candi_seq_user_emb = self.user_emb(candi_seq_user_idxs) 229 | candi_sequence_emb += candi_seq_user_emb 230 | if self.args.use_item_emb: 231 | candi_seq_item_emb = self.product_emb(candi_seq_item_idxs) 232 | candi_sequence_emb += candi_seq_item_emb 233 | 234 | candi_scores = self.transformer_encoder( 235 | candi_sequence_emb.view(batch_size*candi_k, candi_rcount+1, -1), 236 | candi_review_mask.view(batch_size*candi_k, candi_rcount+1), 237 | use_pos=self.args.use_pos_emb) 238 | candi_scores = candi_scores.view(batch_size, candi_k) 239 | return candi_scores 240 | 241 | def forward(self, batch_data, train_pv=True): 242 | query_word_idxs = batch_data.query_word_idxs 243 | pos_prod_ridxs 
= batch_data.pos_prod_ridxs 244 | pos_seg_idxs = batch_data.pos_seg_idxs 245 | pos_prod_rword_idxs= batch_data.pos_prod_rword_idxs 246 | pos_prod_rword_masks = batch_data.pos_prod_rword_masks 247 | neg_prod_ridxs = batch_data.neg_prod_ridxs 248 | neg_seg_idxs = batch_data.neg_seg_idxs 249 | pos_user_idxs = batch_data.pos_user_idxs 250 | neg_user_idxs = batch_data.neg_user_idxs 251 | pos_item_idxs = batch_data.pos_item_idxs 252 | neg_item_idxs = batch_data.neg_item_idxs 253 | neg_prod_rword_idxs = batch_data.neg_prod_rword_idxs 254 | neg_prod_rword_masks = batch_data.neg_prod_rword_masks 255 | pos_prod_rword_idxs_pvc = batch_data.pos_prod_rword_idxs_pvc 256 | neg_prod_rword_idxs_pvc = batch_data.neg_prod_rword_idxs_pvc 257 | query_word_emb = self.word_embeddings(query_word_idxs) 258 | query_emb = self.query_encoder(query_word_emb, query_word_idxs.ne(self.word_pad_idx)) 259 | batch_size, pos_rcount, posr_word_limit = pos_prod_rword_idxs.size() 260 | _, neg_k, neg_rcount = neg_prod_ridxs.size() 261 | posr_word_emb = self.word_embeddings(pos_prod_rword_idxs.view(-1, posr_word_limit)) 262 | update_pos_prod_rword_masks = pos_prod_rword_masks.view(-1, posr_word_limit) 263 | pv_loss = None 264 | if "pv" in self.review_encoder_name: 265 | if train_pv: 266 | if self.review_encoder_name == "pv": 267 | pos_review_emb, pos_prod_loss = self.review_encoder( 268 | pos_prod_ridxs.view(-1), posr_word_emb, 269 | update_pos_prod_rword_masks, self.args.neg_per_pos) 270 | elif self.review_encoder_name == "pvc": 271 | pos_review_emb, pos_prod_loss = self.review_encoder( 272 | posr_word_emb, update_pos_prod_rword_masks, 273 | pos_prod_rword_idxs_pvc.view(-1, pos_prod_rword_idxs_pvc.size(-1)), 274 | self.args.neg_per_pos) 275 | sample_count = pos_prod_ridxs.ne(self.review_pad_idx).float().sum(-1) 276 | # it won't be less than batch_size since there is not any sequence with all padding indices 277 | #sample_count = sample_count.masked_fill(sample_count.eq(0),1) 278 | pv_loss = pos_prod_loss.sum() / sample_count.sum() 279 | else: 280 | if self.fix_emb: 281 | pos_review_emb = self.review_embeddings[pos_prod_ridxs] 282 | else: 283 | if self.review_encoder_name == "pv": 284 | pos_review_emb = self.review_encoder.get_para_vector(pos_prod_ridxs) 285 | elif self.review_encoder_name == "pvc": 286 | pos_review_emb = self.review_encoder.get_para_vector( 287 | #pos_prod_rword_idxs_pvc.view(-1, pos_prod_rword_idxs_pvc.size(-1))) 288 | pos_prod_rword_idxs.view(-1, pos_prod_rword_idxs.size(-1))) 289 | if self.fix_emb: 290 | neg_review_emb = self.review_embeddings[neg_prod_ridxs] 291 | else: 292 | if self.review_encoder_name == "pv": 293 | neg_review_emb = self.review_encoder.get_para_vector(neg_prod_ridxs) 294 | elif self.review_encoder_name == "pvc": 295 | if not train_pv: 296 | neg_prod_rword_idxs_pvc = neg_prod_rword_idxs 297 | neg_review_emb = self.review_encoder.get_para_vector( 298 | neg_prod_rword_idxs_pvc.view(-1, neg_prod_rword_idxs_pvc.size(-1))) 299 | pos_review_emb = self.dropout_layer(pos_review_emb) 300 | neg_review_emb = self.dropout_layer(neg_review_emb) 301 | else: 302 | negr_word_limit = neg_prod_rword_idxs.size()[-1] 303 | negr_word_emb = self.word_embeddings(neg_prod_rword_idxs.view(-1, negr_word_limit)) 304 | pos_review_emb = self.review_encoder(posr_word_emb, update_pos_prod_rword_masks) 305 | neg_review_emb = self.review_encoder(negr_word_emb, neg_prod_rword_masks.view(-1, negr_word_limit)) 306 | 307 | pos_review_emb = pos_review_emb.view(batch_size, pos_rcount, -1) 308 | neg_review_emb = 
neg_review_emb.view(batch_size, neg_k, neg_rcount, -1) 309 | 310 | #concat query_emb with pos_review_emb and neg_review_emb 311 | query_mask = torch.ones(batch_size, 1, dtype=torch.uint8, device=query_word_idxs.device) 312 | pos_review_mask = torch.cat([query_mask, pos_prod_ridxs.ne(self.review_pad_idx)], dim=1) #batch_size, 1+max_review_count 313 | neg_prod_ridx_mask = neg_prod_ridxs.ne(self.review_pad_idx) 314 | neg_review_mask = torch.cat([query_mask.unsqueeze(1).expand(-1,neg_k,-1), neg_prod_ridx_mask], dim=2) 315 | #batch_size, 1, embedding_size 316 | pos_sequence_emb = torch.cat((query_emb.unsqueeze(1), pos_review_emb), dim=1) 317 | pos_seg_emb = self.seg_embeddings(pos_seg_idxs) #batch_size, max_review_count+1, embedding_size 318 | neg_sequence_emb = torch.cat( 319 | (query_emb.unsqueeze(1).expand(-1, neg_k, -1).unsqueeze(2), neg_review_emb), dim=2) 320 | #batch_size, neg_k, max_review_count+1, embedding_size 321 | neg_seg_emb = self.seg_embeddings(neg_seg_idxs) #batch_size, neg_k, max_review_count+1, embedding_size 322 | if self.args.use_seg_emb: 323 | pos_sequence_emb += pos_seg_emb 324 | neg_sequence_emb += neg_seg_emb 325 | if self.args.use_item_emb: 326 | pos_seq_item_emb = self.product_emb(pos_item_idxs) 327 | neg_seq_item_emb = self.product_emb(neg_item_idxs) 328 | pos_sequence_emb += pos_seq_item_emb 329 | neg_sequence_emb += neg_seq_item_emb 330 | if self.args.use_user_emb: 331 | pos_seq_user_emb = self.user_emb(pos_user_idxs) 332 | neg_seq_user_emb = self.user_emb(neg_user_idxs) 333 | pos_sequence_emb += pos_seq_user_emb 334 | neg_sequence_emb += neg_seq_user_emb 335 | 336 | pos_scores = self.transformer_encoder(pos_sequence_emb, pos_review_mask, use_pos=self.args.use_pos_emb) 337 | neg_scores = self.transformer_encoder( 338 | neg_sequence_emb.view(batch_size*neg_k, neg_rcount+1, -1), 339 | neg_review_mask.view(batch_size*neg_k, neg_rcount+1), use_pos=self.args.use_pos_emb) 340 | neg_scores = neg_scores.view(batch_size, neg_k) 341 | pos_weight = 1 342 | if self.args.pos_weight: 343 | pos_weight = self.args.neg_per_pos 344 | prod_mask = torch.cat([torch.ones(batch_size, 1, dtype=torch.uint8, device=query_word_idxs.device) * pos_weight, 345 | neg_prod_ridx_mask.sum(-1).ne(0)], dim=-1) #batch_size, neg_k (valid products, some are padded) 346 | #TODO: this mask does not reflect true neg prods, when reviews are randomly selected all of them should valid since there is no need for padding 347 | prod_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1) 348 | target = torch.cat([torch.ones(batch_size, 1, device=query_word_idxs.device), 349 | torch.zeros(batch_size, neg_k, device=query_word_idxs.device)], dim=-1) 350 | #ps_loss = self.bce_logits_loss(prod_scores, target, weight=prod_mask.float()) 351 | ps_loss = nn.functional.binary_cross_entropy_with_logits( 352 | prod_scores, target, 353 | weight=prod_mask.float(), 354 | reduction='none') 355 | 356 | ps_loss = ps_loss.sum(-1).mean() 357 | loss = ps_loss + pv_loss if pv_loss is not None else ps_loss 358 | return loss 359 | 360 | def initialize_parameters(self, logger=None): 361 | if logger: 362 | logger.info(" ProductRanker initialization started.") 363 | if self.pretrain_emb_dir is None: 364 | nn.init.normal_(self.word_embeddings.weight) 365 | nn.init.normal_(self.seg_embeddings.weight) 366 | self.review_encoder.initialize_parameters(logger) 367 | self.query_encoder.initialize_parameters(logger) 368 | self.transformer_encoder.initialize_parameters(logger) 369 | if logger: 370 | logger.info(" ProductRanker 
initialization finished.") 371 | 372 | -------------------------------------------------------------------------------- /models/text_encoder.py: -------------------------------------------------------------------------------- 1 | """ query encoder can be the same as HEM: (fs) 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | 6 | def get_vector_mean(inputs, input_mask): 7 | #batch_size, max_word_count, embedding_size 8 | inputs_sum = (inputs * input_mask.float().unsqueeze(-1)).sum(1) 9 | word_count = input_mask.sum(-1) 10 | 11 | word_count = word_count.masked_fill( 12 | word_count.eq(0), 1).unsqueeze(-1) 13 | 14 | inputs_mean = inputs_sum / word_count.float() 15 | 16 | return inputs_mean 17 | 18 | 19 | class FSEncoder(nn.Module): 20 | def __init__(self, embedding_size, dropout=0.0): 21 | super(FSEncoder, self).__init__() 22 | self.dropout_ = dropout 23 | self.output_size_ = embedding_size 24 | self.f_W = nn.Linear(embedding_size, embedding_size) 25 | self.drop_layer = nn.Dropout(p=self.dropout_) 26 | #by default bias=True 27 | 28 | @property 29 | def size(self): 30 | return self.output_size_ 31 | 32 | def forward(self, inputs, input_mask): 33 | #batch_size, max_word_count, embedding_size 34 | inputs_mean = get_vector_mean(inputs, input_mask) 35 | inputs_mean = torch.dropout( 36 | inputs_mean, p=self.dropout_, train=self.training) 37 | #inputs_mean = self.drop_layer(inputs_mean) 38 | 39 | f_s = torch.tanh(self.f_W(inputs_mean)) 40 | return f_s 41 | 42 | def initialize_parameters(self, logger=None): 43 | if logger: 44 | logger.info(" FSEncoder initialization started.") 45 | for name, p in self.named_parameters(): 46 | if "weight" in name: 47 | if logger: 48 | logger.info(" {} ({}): Xavier normal init.".format( 49 | name, ",".join([str(x) for x in p.size()]))) 50 | nn.init.xavier_normal_(p) 51 | elif "bias" in name: 52 | if logger: 53 | logger.info(" {} ({}): constant (0) init.".format( 54 | name, ",".join([str(x) for x in p.size()]))) 55 | nn.init.constant_(p, 0) 56 | else: 57 | if logger: 58 | logger.info(" {} ({}): random normal init.".format( 59 | name, ",".join([str(x) for x in p.size()]))) 60 | nn.init.normal_(p) 61 | if logger: 62 | logger.info(" FSEncoder initialization finished.") 63 | 64 | class AVGEncoder(nn.Module): 65 | def __init__(self, embedding_size, dropout=0.0): 66 | super(AVGEncoder, self).__init__() 67 | self.dropout_ = dropout 68 | self.output_size_ = embedding_size 69 | self.drop_layer = nn.Dropout(p=self.dropout_) 70 | 71 | @property 72 | def size(self): 73 | return self.output_size_ 74 | 75 | def forward(self, inputs, input_mask): 76 | #batch_size, max_word_count, embedding_size 77 | inputs_mean = get_vector_mean(inputs, input_mask) 78 | #inputs_mean = torch.dropout( 79 | # inputs_mean, p=self.dropout_, train=self.training) 80 | inputs_mean = self.drop_layer(inputs_mean) #better managed than using torch.dropout 81 | 82 | return inputs_mean 83 | 84 | def initialize_parameters(self, logger=None): 85 | if logger: 86 | logger.info(" AveragingEncoder initialization skipped" 87 | " (no parameters).") 88 | 89 | #CNN or RNN 90 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.neural import MultiHeadedAttention, PositionwiseFeedForward 7 | 8 | class PositionalEncoding(nn.Module): 9 | 10 | def __init__(self, dropout, dim, max_len=5000): 11 | pe = 
torch.zeros(max_len, dim) 12 | position = torch.arange(0, max_len).unsqueeze(1) 13 | div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * 14 | -(math.log(10000.0) / dim))) 15 | pe[:, 0::2] = torch.sin(position.float() * div_term) 16 | pe[:, 1::2] = torch.cos(position.float() * div_term) 17 | pe = pe.unsqueeze(0) 18 | super(PositionalEncoding, self).__init__() 19 | self.register_buffer('pe', pe) 20 | self.dropout = nn.Dropout(p=dropout) 21 | self.dim = dim 22 | 23 | def forward(self, emb, step=None): 24 | emb = emb * math.sqrt(self.dim) 25 | if (step): 26 | emb = emb + self.pe[:, step][:, None, :] 27 | 28 | else: 29 | emb = emb + self.pe[:, :emb.size(1)] 30 | emb = self.dropout(emb) 31 | return emb 32 | 33 | def get_emb(self, emb): 34 | return self.pe[:, :emb.size(1)] 35 | 36 | 37 | class TransformerEncoderLayer(nn.Module): 38 | def __init__(self, d_model, heads, d_ff, dropout): 39 | super(TransformerEncoderLayer, self).__init__() 40 | 41 | self.self_attn = MultiHeadedAttention( 42 | heads, d_model, dropout=dropout) 43 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 44 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | def forward(self, iter, query, inputs, mask): 48 | if (iter != 0): 49 | input_norm = self.layer_norm(inputs) 50 | else: 51 | input_norm = inputs 52 | 53 | mask = mask.unsqueeze(1) 54 | context = self.self_attn(input_norm, input_norm, input_norm, 55 | mask=mask) 56 | out = self.dropout(context) + inputs 57 | return self.feed_forward(out) 58 | 59 | class TransformerEncoder(nn.Module): 60 | def __init__(self, d_model, d_ff, heads, dropout, num_inter_layers=0): 61 | super(TransformerEncoder, self).__init__() 62 | self.d_model = d_model 63 | self.num_inter_layers = num_inter_layers 64 | self.pos_emb = PositionalEncoding(dropout, d_model) 65 | self.transformer_inter = nn.ModuleList( 66 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 67 | for _ in range(num_inter_layers)]) 68 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 69 | self.wo = nn.Linear(d_model, 1, bias=True) 70 | 71 | def encode(self, input_vecs, mask, use_pos=True): 72 | """ See :obj:`EncoderBase.forward()`""" 73 | 74 | #input_vecs is batch_size, sequence_length, embedding_size 75 | batch_size, n_sents = input_vecs.size(0), input_vecs.size(1) 76 | #batch_size, n_sents, 77 | x = input_vecs * mask[:, :, None].float() 78 | 79 | if use_pos: 80 | pos_emb = self.pos_emb.pe[:, :n_sents] 81 | x = x + pos_emb 82 | 83 | for i in range(self.num_inter_layers): 84 | x = self.transformer_inter[i](i, x, x, 1 - mask) # all_sents * max_tokens * dim 85 | 86 | x = self.layer_norm(x) 87 | #out_pos can be 0 or -1 # represent query or item in the item_transformer model 88 | return x 89 | 90 | def forward(self, input_vecs, mask, use_pos=True, out_pos=0): 91 | """ See :obj:`EncoderBase.forward()`""" 92 | 93 | #out_pos can be 0 or -1 # represent query or item in the item_transformer model 94 | x = self.encode(input_vecs, mask, use_pos) 95 | out_emb = x[:,out_pos,:]# x[:,0,:] will return size batch_size, d_model 96 | scores = self.wo(out_emb).squeeze(-1) #* mask.float() 97 | #batch_size 98 | return scores 99 | 100 | def initialize_parameters(self, logger=None): 101 | if logger: 102 | logger.info(" Transformer initialization started.") 103 | for name, p in self.named_parameters(): 104 | if "weight" in name and p.dim() > 1: 105 | if logger: 106 | logger.info(" {} ({}): Xavier normal init.".format( 107 | name, ",".join([str(x) for x in p.size()]))) 
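                # Note (added): weight matrices (p.dim() > 1) get Xavier/Glorot normal init
                # to keep activation variance roughly constant across layers; 1-D "weight"
                # parameters such as LayerNorm scales fall through to the random-normal
                # branch below, and biases are zero-initialized.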
108 | nn.init.xavier_normal_(p) 109 | elif "bias" in name: 110 | if logger: 111 | logger.info(" {} ({}): constant (0) init.".format( 112 | name, ",".join([str(x) for x in p.size()]))) 113 | nn.init.constant_(p, 0) 114 | else: 115 | if logger: 116 | logger.info(" {} ({}): random normal init.".format( 117 | name, ",".join([str(x) for x in p.size()]))) 118 | nn.init.normal_(p) 119 | if logger: 120 | logger.info(" Transformer initialization finished.") 121 | -------------------------------------------------------------------------------- /others/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kepingbi/ProdSearch/449335ba652fe7c877a008e154157d7b2a4b0e76/others/__init__.py -------------------------------------------------------------------------------- /others/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setLevel(log_file_level) 21 | file_handler.setFormatter(log_format) 22 | logger.addHandler(file_handler) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- /others/util.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | from others.logging import logger 3 | 4 | def load_pretrain_embeddings(fname): 5 | embeddings = [] 6 | word_index_dic = dict() 7 | with gzip.open(fname, 'rt') as fin: 8 | count = int(fin.readline().strip()) 9 | emb_size = int(fin.readline().strip()) 10 | line_no = 0 11 | for line in fin: 12 | arr = line.strip(' ').split('\t')#the first element is empty 13 | word_index_dic[arr[0]] = line_no 14 | line_no += 1 15 | vector = arr[1].split() 16 | vector = [float(x) for x in vector] 17 | embeddings.append(vector) 18 | logger.info("Loading {}".format(fname)) 19 | logger.info("Count:{} Embeddings size:{}".format(len(embeddings), len(embeddings[0]))) 20 | return word_index_dic, embeddings 21 | 22 | def load_user_item_embeddings(fname): 23 | embeddings = [] 24 | #with gzip.open(fname, 'rt') as fin: 25 | with open(fname, 'r') as fin: 26 | count = int(fin.readline().strip()) 27 | emb_size = int(fin.readline().strip()) 28 | for line in fin: 29 | arr = line.strip().split(' ') 30 | vector = [float(x) for x in arr] 31 | embeddings.append(vector) 32 | logger.info("Loading {}".format(fname)) 33 | logger.info("Count:{} Embeddings size:{}".format(len(embeddings), len(embeddings[0]))) 34 | return embeddings 35 | 36 | def pad(data, pad_id, width=-1): 37 | if (width == -1): 38 | width = max(len(d) for d in data) 39 | rtn_data = [d[:width] + [pad_id] * (width - len(d)) for d in data]#if width < max(len(d)) of data 40 | return rtn_data 41 | 42 | def pad_3d(data, pad_id, dim=1, width=-1): 43 | #dim = 1 or 2 44 | if dim < 1 or dim > 2: 45 | return data 46 | if (width == -1): 47 | if (dim == 1): 48 | #dim 0,2 is same across the batch 49 | width = max(len(d) for d in data) 50 | 
elif (dim == 2): 51 | #dim 0,1 is same across the batch 52 | for entry in data: 53 | width = max(width, max(len(d) for d in entry)) 54 | #print(width) 55 | if dim == 1: 56 | rtn_data = [d[:width] + [[pad_id] * len(data[0][0])] * (width - len(d)) for d in data] 57 | elif dim == 2: 58 | rtn_data = [] 59 | for entry in data: 60 | rtn_data.append([d[:width] + [pad_id] * (width - len(d)) for d in entry]) 61 | return rtn_data 62 | 63 | def pad_4d_dim1(data, pad_id, width=-1): 64 | if (width == -1): 65 | #max width of dim1 66 | width = max(width, max(len(d) for d in data)) 67 | #print(width) 68 | rtn_data = [d[:width] + [[[pad_id] * len(data[0][0][0])]] * (width - len(d)) for d in data] 69 | return rtn_data 70 | 71 | def pad_4d_dim2(data, pad_id, width=-1): 72 | #only handle padding to dim = 2 73 | if (width == -1): 74 | #max width of dim2 75 | for entry in data: 76 | width = max(width, max(len(d) for d in entry)) 77 | #print(width) 78 | rtn_data = [] 79 | for entry_dim1 in data: 80 | rtn_data.append([d[:width] + [[pad_id] * len(data[0][0][0])] * (width - len(d)) for d in entry_dim1]) 81 | return rtn_data 82 | 83 | def main(): 84 | data = [[[[2,2,2],[2,2,2]],[[2,2,2]]],[[[2,2,2]]]] 85 | rtn = pad_4d_dim1(data, -1) 86 | rtn = pad_4d_dim2(rtn, -1) 87 | print(rtn) 88 | 89 | if __name__ == "__main__": 90 | main() 91 | 92 | 93 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from others.logging import logger 3 | #from data.prod_search_dataloader import ProdSearchDataloader 4 | #from data.prod_search_dataset import ProdSearchDataset 5 | import shutil 6 | import torch 7 | import numpy as np 8 | import data 9 | import os 10 | import time 11 | import sys 12 | 13 | def _tally_parameters(model): 14 | n_params = sum([p.nelement() for p in model.parameters()]) 15 | return n_params 16 | 17 | class Trainer(object): 18 | """ 19 | Class that controls the training process. 20 | """ 21 | def __init__(self, args, model, optim): 22 | # Basic attributes. 23 | self.args = args 24 | self.model = model 25 | self.optim = optim 26 | if (model): 27 | n_params = _tally_parameters(model) 28 | logger.info('* number of parameters: %d' % n_params) 29 | #self.device = "cpu" if self.n_gpu == 0 else "cuda" 30 | if args.model_name == "review_transformer": 31 | self.ExpDataset = data.ProdSearchDataset 32 | self.ExpDataloader = data.ProdSearchDataLoader 33 | else: 34 | self.ExpDataset = data.ItemPVDataset 35 | self.ExpDataloader = data.ItemPVDataloader 36 | 37 | def train(self, args, global_data, train_prod_data, valid_prod_data): 38 | """ 39 | The main training loops. 40 | """ 41 | logger.info('Start training...') 42 | # Set model in training mode. 43 | model_dir = args.save_dir 44 | valid_dataset = self.ExpDataset(args, global_data, valid_prod_data) 45 | step_time, loss = 0.,0. 46 | get_batch_time = 0.0 47 | start_time = time.time() 48 | current_step = 0 49 | best_mrr = 0. 
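        # Overview (added): each epoch re-initializes the training data via
        # initialize_epoch(); prepare_pv is True only for the first train_pv_epoch
        # epochs, so the paragraph-vector loss is optimized jointly with the ranking
        # loss only early in training. After every epoch the model is checkpointed and
        # validated, and the checkpoint with the best validation MRR is copied to
        # model_best.ckpt.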
50 | best_checkpoint_path = '' 51 | for current_epoch in range(args.start_epoch+1, args.max_train_epoch+1): 52 | self.model.train() 53 | logger.info("Initialize epoch:%d" % current_epoch) 54 | train_prod_data.initialize_epoch() 55 | dataset = self.ExpDataset(args, global_data, train_prod_data) 56 | prepare_pv = current_epoch < args.train_pv_epoch+1 57 | print(prepare_pv) 58 | dataloader = self.ExpDataloader( 59 | args, dataset, prepare_pv=prepare_pv, batch_size=args.batch_size, 60 | shuffle=True, num_workers=args.num_workers) 61 | pbar = tqdm(dataloader) 62 | pbar.set_description("[Epoch {}]".format(current_epoch)) 63 | time_flag = time.time() 64 | for batch_data_arr in pbar: 65 | if batch_data_arr is None: 66 | continue 67 | if type(batch_data_arr) is list: 68 | batch_data_arr = [x.to(args.device) for x in batch_data_arr] 69 | else: 70 | batch_data_arr = [batch_data_arr.to(args.device)] 71 | for batch_data in batch_data_arr: 72 | get_batch_time += time.time() - time_flag 73 | time_flag = time.time() 74 | step_loss = self.model(batch_data, train_pv=prepare_pv) 75 | #self.optim.optimizer.zero_grad() 76 | self.model.zero_grad() 77 | step_loss.backward() 78 | self.optim.step() 79 | step_loss = step_loss.item() 80 | pbar.set_postfix(step_loss=step_loss, lr=self.optim.learning_rate) 81 | loss += step_loss / args.steps_per_checkpoint #convert an tensor with dim 0 to value 82 | current_step += 1 83 | step_time += time.time() - time_flag 84 | 85 | # Once in a while, we print statistics. 86 | if current_step % args.steps_per_checkpoint == 0: 87 | ps_loss, item_loss = 0, 0 88 | if hasattr(self.model, "ps_loss"): 89 | ps_loss = self.model.ps_loss/args.steps_per_checkpoint 90 | if hasattr(self.model, "item_loss"): 91 | item_loss = self.model.item_loss/args.steps_per_checkpoint 92 | 93 | logger.info("Epoch %d lr = %5.6f loss = %6.2f ps_loss: %3.2f iw_loss: %3.2f time %.2f prepare_time %.2f step_time %.2f" % 94 | (current_epoch, self.optim.learning_rate, loss, ps_loss, item_loss, 95 | time.time()-start_time, get_batch_time, step_time))#, end="" 96 | step_time, get_batch_time, loss = 0., 0.,0. 97 | if hasattr(self.model, "ps_loss"): 98 | self.model.clear_loss() 99 | sys.stdout.flush() 100 | start_time = time.time() 101 | checkpoint_path = os.path.join(model_dir, 'model_epoch_%d.ckpt' % current_epoch) 102 | self._save(current_epoch, checkpoint_path) 103 | mrr, prec = self.validate(args, global_data, valid_dataset) 104 | logger.info("Epoch {}: MRR:{} P@1:{}".format(current_epoch, mrr, prec)) 105 | if mrr > best_mrr: 106 | best_mrr = mrr 107 | best_checkpoint_path = os.path.join(model_dir, 'model_best.ckpt') 108 | logger.info("Copying %s to checkpoint %s" % (checkpoint_path, best_checkpoint_path)) 109 | shutil.copyfile(checkpoint_path, best_checkpoint_path) 110 | return best_checkpoint_path 111 | 112 | def _save(self, epoch, checkpoint_path): 113 | checkpoint = { 114 | 'epoch': epoch, 115 | 'model': self.model.state_dict(), 116 | 'opt': self.args, 117 | 'optim': self.optim, 118 | } 119 | #model_dir = "%s/model" % (self.args.save_dir) 120 | #checkpoint_path = os.path.join(model_dir, 'model_epoch_%d.ckpt' % epoch) 121 | logger.info("Saving checkpoint %s" % checkpoint_path) 122 | torch.save(checkpoint, checkpoint_path) 123 | 124 | def validate(self, args, global_data, valid_dataset): 125 | """ Validate model. 
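
        Ranks the candidate products (all products when valid_candi_size < 1) for each
        validation instance and returns MRR and P@1 computed with a cutoff of 100.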
126 | """ 127 | candidate_size = args.valid_candi_size 128 | if args.valid_candi_size < 1: 129 | candidate_size = global_data.product_size 130 | dataloader = self.ExpDataloader( 131 | args, valid_dataset, batch_size=args.valid_batch_size, 132 | shuffle=False, num_workers=args.num_workers) 133 | all_prod_idxs, all_prod_scores, all_target_idxs, \ 134 | all_query_idxs, all_user_idxs \ 135 | = self.get_prod_scores(args, global_data, valid_dataset, dataloader, "Validation", candidate_size) 136 | sorted_prod_idxs = all_prod_scores.argsort(axis=-1)[:,::-1] #by default axis=-1, along the last axis 137 | mrr, prec = self.calc_metrics(all_prod_idxs, sorted_prod_idxs, all_target_idxs, candidate_size, cutoff=100) 138 | return mrr, prec 139 | 140 | def test(self, args, global_data, test_prod_data, rankfname="test.best_model.ranklist", cutoff=100): 141 | candidate_size = args.test_candi_size 142 | if args.test_candi_size < 1: 143 | candidate_size = global_data.product_size 144 | test_dataset = self.ExpDataset(args, global_data, test_prod_data) 145 | dataloader = self.ExpDataloader( 146 | args, test_dataset, batch_size=args.valid_batch_size, #batch_size 147 | shuffle=False, num_workers=args.num_workers) 148 | 149 | all_prod_idxs, all_prod_scores, all_target_idxs, \ 150 | all_query_idxs, all_user_idxs \ 151 | = self.get_prod_scores(args, global_data, test_dataset, dataloader, "Test", candidate_size) 152 | sorted_prod_idxs = all_prod_scores.argsort(axis=-1)[:,::-1] #by default axis=-1, along the last axis 153 | mrr, prec = self.calc_metrics(all_prod_idxs, sorted_prod_idxs, all_target_idxs, candidate_size, cutoff) 154 | logger.info("Test: MRR:{} P@1:{}".format(mrr, prec)) 155 | output_path = os.path.join(args.save_dir, rankfname) 156 | eval_count = all_prod_scores.shape[0] 157 | print(all_prod_scores.shape) 158 | with open(output_path, 'w') as rank_fout: 159 | for i in range(eval_count): 160 | user_id = global_data.user_ids[all_user_idxs[i]] 161 | qidx = all_query_idxs[i] 162 | ranked_product_ids = all_prod_idxs[i][sorted_prod_idxs[i]] 163 | ranked_product_scores = all_prod_scores[i][sorted_prod_idxs[i]] 164 | for rank in range(min(cutoff, candidate_size)): 165 | product_id = global_data.product_ids[ranked_product_ids[rank]] 166 | score = ranked_product_scores[rank] 167 | line = "%s_%d Q0 %s %d %f ReviewTransformer\n" \ 168 | % (user_id, qidx, product_id, rank+1, score) 169 | rank_fout.write(line) 170 | 171 | def calc_metrics(self, all_prod_idxs, sorted_prod_idxs, all_target_idxs, candidate_size, cutoff=100): 172 | eval_count = all_prod_idxs.shape[0] 173 | mrr, prec = 0, 0 174 | for i in range(eval_count): 175 | result = np.where(all_prod_idxs[i][sorted_prod_idxs[i]] == all_target_idxs[i]) 176 | if len(result[0]) == 0: #not occur in the list 177 | pass 178 | else: 179 | rank = result[0][0] + 1 180 | if cutoff < 0 or rank <= cutoff: 181 | mrr += 1/rank 182 | if rank == 1: 183 | prec +=1 184 | mrr /= eval_count 185 | prec /= eval_count 186 | print("MRR:{} P@1:{}".format(mrr, prec)) 187 | return mrr, prec 188 | 189 | def get_prod_scores(self, args, global_data, dataset, dataloader, description, candidate_size): 190 | self.model.eval() 191 | with torch.no_grad(): 192 | if args.model_name == "review_transformer": 193 | self.model.get_review_embeddings() #get model.review_embeddings 194 | pbar = tqdm(dataloader) 195 | pbar.set_description(description) 196 | seg_count = int((candidate_size - 1) / args.candi_batch_size) + 1 197 | all_prod_scores, all_target_idxs, all_prod_idxs = [], [], [] 198 | all_user_idxs, 
all_query_idxs = [], [] 199 | for batch_data in pbar: 200 | batch_data = batch_data.to(args.device) 201 | batch_scores = self.model.test(batch_data) 202 | #batch_size, candidate_batch_size 203 | all_user_idxs.append(np.asarray(batch_data.user_idxs)) 204 | all_query_idxs.append(np.asarray(batch_data.query_idxs)) 205 | candi_prod_idxs = batch_data.candi_prod_idxs 206 | if type(candi_prod_idxs) is torch.Tensor: 207 | candi_prod_idxs = candi_prod_idxs.cpu() 208 | all_prod_idxs.append(np.asarray(candi_prod_idxs)) 209 | all_prod_scores.append(batch_scores.cpu().numpy()) 210 | target_prod_idxs = batch_data.target_prod_idxs 211 | if type(target_prod_idxs) is torch.Tensor: 212 | target_prod_idxs = target_prod_idxs.cpu() 213 | all_target_idxs.append(np.asarray(target_prod_idxs)) 214 | #use MRR 215 | assert args.candi_batch_size <= candidate_size #otherwise results are wrong 216 | padded_length = seg_count * args.candi_batch_size 217 | all_prod_idxs = np.concatenate(all_prod_idxs, axis=0).reshape(-1, padded_length)[:, :candidate_size] 218 | all_prod_scores = np.concatenate(all_prod_scores, axis=0).reshape(-1, padded_length)[:, :candidate_size] 219 | all_target_idxs = np.concatenate(all_target_idxs, axis=0).reshape(-1, seg_count)[:,0] 220 | all_user_idxs = np.concatenate(all_user_idxs, axis=0).reshape(-1, seg_count)[:,0] 221 | all_query_idxs = np.concatenate(all_query_idxs, axis=0).reshape(-1, seg_count)[:,0] 222 | #target_scores = all_prod_scores[np.arange(eval_count), all_target_idxs] 223 | #all_prod_scores.sort(axis=-1) #ascending 224 | if args.model_name == "review_transformer": 225 | self.model.clear_review_embbeddings() 226 | return all_prod_idxs, all_prod_scores, all_target_idxs, all_query_idxs, all_user_idxs 227 | 228 | -------------------------------------------------------------------------------- /tune_para.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import argparse 4 | 5 | #titanx-long #1080ti-long #2080ti-long # Partition to submit to 6 | #SBATCH --mem=96000 # Memory in MB per node allocated 7 | config_str = """#!/bin/bash 8 | 9 | #SBATCH --partition=titanx-long 10 | #SBATCH --gres=gpu:1 11 | #SBATCH --ntasks=1 12 | #SBATCH --mem=64000 # Memory in MB per node allocated 13 | #SBATCH --ntasks-per-node=4 14 | """ 15 | datasets = ["Beauty",#dataset, divided into how many parts 16 | "Cell_Phones_and_Accessories", 17 | "Clothing_Shoes_and_Jewelry", 18 | "Health_and_Personal_Care", 19 | "Home_and_Kitchen", 20 | "Sports_and_Outdoors", 21 | "Electronics", 22 | "Movies_and_TV", 23 | "CDs_and_Vinyl", #CD too large 24 | "Kindle_Store", #Kindle too large 25 | ] 26 | 27 | WORKING_DIR="/mnt/nfs/scratch1/kbi/review_transformer/working/Amazon" 28 | #OUTPUT_DIR="/mnt/nfs/scratch1/kbi/review_transformer/output/Amazon/review_transformer_nodup" 29 | OUTPUT_DIR="/mnt/nfs/work1/croft/kbi/review_transformer/output/Amazon/review_transformer_nodup" 30 | 31 | script_path = "python main.py" 32 | #CONST_CMD_ARR = [("data_dir", data_dir),("input_train_dir", input_train_dir)] 33 | #CONST_CMD = " ".join(["--{} {}".format(x[0], x[1]) for x in CONST_CMD_ARR]) 34 | pretrain_pv_root_dir = "/mnt/nfs/scratch1/kbi/review_transformer/working/paragraph_embeddings/reviews_##_5.json.gz.stem.nostop/min_count5/" 35 | pv_path = "batch_size256.negative_sample5.learning_rate0.5.embed_size128.steps_per_checkpoint400.max_train_epoch20.L2_lambda0.0.net_structpv_hdc./" 36 | small_pvc_path =
"batch_size256.negative_sample5.learning_rate0.5.embed_size128.use_local_contextTrue.steps_per_checkpoint400.max_train_epoch20.L2_lambda0.0.net_structcdv_hdc." 37 | #small_pvc_path = "batch_size256.negative_sample5.learning_rate0.5.embed_size128.subsampling_rate1e-05.use_local_contextTrue.steps_per_checkpoint400.max_train_epoch20.L2_lambda0.0.net_structcdv_hdc." 38 | large_pvc_path = "batch_size512.negative_sample5.learning_rate0.5.embed_size128.subsampling_rate1e-06.use_local_contextTrue.steps_per_checkpoint400.max_train_epoch20.L2_lambda0.0.net_structcdv_hdc." #subsampling_rate can be 1e-5,1e-6,1e-7 39 | 40 | #pretrain_pv_emb_dir = "/mnt/nfs/scratch1/kbi/review_transformer/working/paragraph_embeddings/reviews_##_5.json.gz.stem.nostop/min_count5/batch_size256.negative_sample5.learning_rate0.5.embed_size128.steps_per_checkpoint400.max_train_epoch20.L2_lambda0.0.net_structpv_hdc." 41 | #pretrain_pvc_emb_dir = "/mnt/nfs/scratch1/kbi/review_transformer/working/paragraph_embeddings/reviews_##_5.json.gz.stem.nostop/min_count5/batch_size256.negative_sample5.learning_rate0.5.embed_size128.use_local_contextTrue.steps_per_checkpoint400.max_train_epoch20.L2_lambda0.0.net_structcdv_hdc." 42 | 43 | para_names = ['pretrain_type', 'review_encoder_name', 'max_train_epoch', 'lr', 'warmup_steps', 'batch_size', 'valid_candi_size', \ 44 | 'embedding_size', 'review_word_limit', 'iprev_review_limit', 'dropout', \ 45 | 'use_pos_emb', 'corrupt_rate', 'pos_weight', 'l2_lambda', 'ff_size', \ 46 | 'inter_layers', 'use_user_emb', 'use_item_emb', 'fix_train_review', 'do_subsample_mask', 'subsampling_rate'] 47 | short_names = ['pretrain', 'enc', 'me', 'lr', 'ws', 'bs', 'vcs', 'ebs', \ 48 | 'rwl', 'irl', 'drop', 'upos', 'cr', 'poswt', 'lambda', \ 49 | 'ff', 'ly', 'ue', 'ie', 'ftr','dsm', 'ssr'] 50 | 51 | paras = [ 52 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), #similar as previous, best setting 53 | ('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.005, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), #similar as previous, best setting 54 | ('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 1, False, False, False, True, 0), #similar as previous, best setting 55 | ('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 3, False, False, False, True, 0), #similar as previous, best setting 56 | 57 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.005, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 58 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.02, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 59 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 30, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 1, False, False, False, True, 0), 60 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 30, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 3, False, False, False, True, 0), 61 | 62 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 30, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 63 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, True, 0), 64 | #('Clothing_Shoes_and_Jewelry', 'pv', 
'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 65 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.002, 20000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 66 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, False, 1e-5), 67 | 68 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, True, 0), 69 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 70 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.002, 20000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 71 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, False, 1e-5), 72 | 73 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 74 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, True, 0), 75 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 76 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, False, 1e-5), 77 | 78 | #('Movies_and_TV', 'pv', 'fs', 20, 0.01, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 79 | #('Movies_and_TV', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, True, 0), 80 | #('Movies_and_TV', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 81 | #('Movies_and_TV', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, False, 1e-5), 82 | 83 | #('Electronics', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 84 | #('Electronics', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, True, 0), 85 | #('Electronics', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0), 86 | #('Electronics', 'pv', 'fs', 20, 0.002, 80000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False, False, False, 1e-5), 87 | 88 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, True, 0), #similar as previous, best setting 89 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, True, True, 0), 90 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, False, 1e-5), 91 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True, False, True, 1e-5), #do what we expected previously 92 | 93 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.005, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, 
False, 0, 512, 2, True, True), 94 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.005, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True), 95 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 1e-4, 512, 2, True, True), 96 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 1e-3, 512, 2, True, True), 97 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 1e-4, 512, 2, True, True), 98 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 1e-3, 512, 2, True, True), 99 | 100 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, 0.2, False, 0, 512, 2, True, True), 101 | #('Cell_Phones_and_Accessories', 'pvc', 'pvc', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.3, False, 0, 512, 2, True, True), 102 | #('Cell_Phones_and_Accessories', 'pvc', 'pvc', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, True, 0.4, False, 0, 512, 2, True, True), 103 | #('Clothing_Shoes_and_Jewelry', 'pvc', 'pvc', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True), 104 | #('Clothing_Shoes_and_Jewelry', 'pvc', 'pvc', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False), 105 | #('Clothing_Shoes_and_Jewelry', 'pvc', 'pvc', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, 0.2, False, 0, 512, 2, False, False), 106 | #('Sports_and_Outdoors', 'pvc', 'pvc', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, False, False), 107 | #('Sports_and_Outdoors', 'pvc', 'pvc', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, 0.2, False, 0, 512, 2, False, False), 108 | #('Sports_and_Outdoors', 'pvc', 'pvc', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, True, 0.2, False, 0, 512, 2, True, True), 109 | 110 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 111 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 112 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 200, 30, 0.1, False, False, False, 0, 512, 2, False, False), 113 | #('CDs_and_Vinyl', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 200, 30, 0.1, False, False, False, 0, 512, 2, True, True), 114 | 115 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 116 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.2, False, False, False, 0, 512, 2, True, True), 117 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 1e-4, 512, 2, True, True), 118 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 119 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 1e-4, 512, 2, False, False), 120 | 121 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 122 | #('Electronics', 'pv', 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 123 | #('Movies_and_TV', 'pv', 'fs', 20, 0.002, 8000, 128, 
1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 124 | 125 | #('Clothing_Shoes_and_Jewelry', None, 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 126 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 50, 0.01, 15000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 127 | #('Clothing_Shoes_and_Jewelry', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 128 | 129 | #('Sports_and_Outdoors', None, 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 130 | #('Sports_and_Outdoors', 'pv', 'fs', 50, 0.01, 15000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 131 | #('Sports_and_Outdoors', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 132 | 133 | #('Electronics', None, 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 134 | #('Electronics', 'pv', 'fs', 50, 0.01, 15000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 135 | #('Electronics', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 136 | 137 | #('Movies_and_TV', None, 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 138 | #('Movies_and_TV', 'pv', 'fs', 50, 0.01, 15000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 139 | #('Movies_and_TV', 'pv', 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 140 | 141 | #('Cell_Phones_and_Accessories', 'pvc', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 142 | #('Cell_Phones_and_Accessories', 'pvc', 'pvc', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 143 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 144 | #('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 145 | #('Cell_Phones_and_Accessories', 'pvc', 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 146 | #('Cell_Phones_and_Accessories', 'pvc', 'pvc', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 147 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 148 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 0.002, 8000, 128, 500, 128, 100, 30, 0.1, False, True, False, 0, 512, 2, False, False), 149 | 150 | #('Clothing_Shoes_and_Jewelry', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 151 | #('Electronics', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 152 | #('Movies_and_TV', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 153 | #('Kindle_Store', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 154 | #('Sports_and_Outdoors', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, True), 155 | 156 | 
#('Clothing_Shoes_and_Jewelry', 'pv', 'pv', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, True, False, 0, 512, 2, True, True), 157 | #('Electronics', 'pv', 'pv', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, True, False, 0, 512, 2, True, True), 158 | #('Movies_and_TV', 'pv', 'pv', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, True, False, 0, 512, 2, True, True), 159 | #('Kindle_Store', 'pv', 'pv', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, True, False, 0, 512, 2, True, True), 160 | #('Sports_and_Outdoors', 'pv', 'pv', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, True, False, 0, 512, 2, True, True), 161 | 162 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, False), 163 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, True, False), 164 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2, False, True), 165 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 1e-4, 512, 2, False, False), 166 | 167 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 40, 0.1, False, False, False, 0, 512, 2), 168 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 500, 256, 100, 40, 0.1, False, False, False, 0, 512, 2), 169 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 40, 0.1, False, False, False, 1e-4, 512, 2), 170 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 500, 128, 100, 40, 0.1, False, False, False, 0, 512, 4), 171 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 0.002, 8000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0, 512, 2), 172 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 0.002, 8000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0, 512, 2), 173 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 174 | #('Health_and_Personal_Care', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 175 | #('Clothing_Shoes_and_Jewelry', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 176 | #('Sports_and_Outdoors', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 177 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.01, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 178 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.005, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 179 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 180 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 96, 100, 30, 0.1, False, False, False, 0, 512, 2), 181 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 64, 100, 30, 0.1, False, False, False, 0, 512, 2), 182 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 2), 183 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 512, 1), 184 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, 
False, False, False, 0, 256, 2), 185 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 0.002, 8000, 128, 1000, 128, 100, 30, 0.1, False, False, False, 0, 256, 1), 186 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0), 187 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0), 188 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0), 189 | #('Cell_Phones_and_Accessories', 'pvc', 'pv', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0), 190 | #('Cell_Phones_and_Accessories', 'pvc', 'pvc', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0), 191 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0), 192 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, False, True, 0), 193 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0.001), 194 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 10000, 128, 1000, 128, 150, 40, 0.1, False, False, False, 0), 195 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 10000, 128, 1000, 128, 200, 40, 0.1, False, False, False, 0), 196 | #('Cell_Phones_and_Accessories', pretrain_pv_emb_dir, 'pv', 30, 3000, 128, 1000, 128, 100, 40, 0.1, True, True, True), 197 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 8000, 32, 1000, 128, 100, 40, 0.1, False, True, False, 0), 198 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 8000, 64, 1000, 128, 100, 40, 0.1, False, True, False, 0), 199 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 8000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0), 200 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 8000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0.001), 201 | #('Cell_Phones_and_Accessories', 'pv', 'pv', 20, 8000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0.0001), 202 | #('Cell_Phones_and_Accessories', pretrain_pv_emb_dir, 'pv', 30, 3000, 128, 1000, 128, 100, 40, 0.1, False, True, True), 203 | #('Cell_Phones_and_Accessories', pretrain_pv_emb_dir, 'pv', 30, 3000, 128, 1000, 128, 100, 40, 0.1, False, False, False), 204 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 8000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0), 205 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 8000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0.001), 206 | #('Cell_Phones_and_Accessories', None, 'fs', 20, 8000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0.0001), 207 | 208 | #('Clothing_Shoes_and_Jewelry', None, 'fs', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0), 209 | #('Clothing_Shoes_and_Jewelry', None, 'fs', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, False, False, 0.001), 210 | #('Clothing_Shoes_and_Jewelry', None, 'fs', 20, 10000, 128, 1000, 128, 100, 40, 0.2, False, False, False, 0), 211 | #('Clothing_Shoes_and_Jewelry', None, 'fs', 20, 10000, 128, 1000, 128, 100, 40, 0.3, False, False, False, 0), 212 | #('Clothing_Shoes_and_Jewelry', 'pv', 'pv', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0), 213 | #('Clothing_Shoes_and_Jewelry', 'pvc', 'pv', 20, 10000, 128, 1000, 128, 100, 40, 0.1, False, True, False, 0), 214 | #('Clothing_Shoes_and_Jewelry', 'fs', 20, 0, 8000, 128, 1000, 128, 100, 40, 0.1, 0.9), 215 | #('Movies_and_TV', 'fs', 20, 0, 8000, 128, 1000, 128, 100, 40, 0.1, 0.9), 216 | 
#('Sports_and_Outdoors', 'fs', 20, 0, 8000, 128, 1000, 128, 100, 40, 0.1, 0.9),
217 | #('Kindle_Store', 'fs', 20, 0, 8000, 128, 1000, 128, 100, 40, 0.1, 0.9),
218 | #('Cell_Phones_and_Accessories', 'fs', 20, 0, 8000, 128, 1000, 128, 100, 40, 0.1, 0.9),
219 | 
220 | #('Cell_Phones_and_Accessories', 'pv', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
221 | #('Cell_Phones_and_Accessories', 'pvc', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
222 | #('Cell_Phones_and_Accessories', 'fs', 30, 0, 8000, 128, 1000, 128, 100, 30, 0.1, 0.9),
223 | #('Cell_Phones_and_Accessories', 'avg', 20, 0, 8000, 128, 1000, 128, 100, 30, 0.1, 0.9),
224 | 
225 | #('CDs_and_Vinyl', 'pv', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
226 | #('CDs_and_Vinyl', 'pvc', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
227 | #('CDs_and_Vinyl', 'fs', 20, 0, 8000, 128, 1000, 128, 100, 30, 0.1, 0.9),
228 | #('CDs_and_Vinyl', 'avg', 20, 0, 8000, 128, 1000, 128, 100, 30, 0.1, 0.9),
229 | 
230 | #('Electronics', 'pv', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
231 | #('Electronics', 'pvc', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
232 | #('Electronics', 'fs', 20, 0, 8000, 128, 1000, 128, 100, 30, 0.1, 0.9),
233 | 
234 | #('Kindle_Store', 'pv', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
235 | #('Kindle_Store', 'pvc', 10, 2, 3000, 32, 1000, 128, 100, 30, 0.1, 0.9),
236 | ]
237 | 
238 | if __name__ == '__main__':
239 |     fscript = open("run_model.sh", 'w')  # collects one "sbatch <script>" line per generated job
240 |     parser = argparse.ArgumentParser()
241 |     #parser.add_argument("--log_dir", type=str, default='exp_log')
242 |     parser.add_argument("--log_dir", type=str, default='log_review_transformer_nodup')
243 |     parser.add_argument("--script_dir", type=str, default='script_review_transformer_nodup')
244 |     args = parser.parse_args()
245 |     #os.system("mkdir -p %s" % args.log_dir)
246 |     os.system("mkdir -p %s" % args.log_dir)
247 |     os.system("mkdir -p %s" % args.script_dir)
248 |     job_id = 1
249 |     for para in paras:
250 |         cmd_arr = []
251 |         cmd_arr.append(script_path)
252 |         dataset = para[0]
253 |         os.system("mkdir -p {}/{}".format(args.log_dir, dataset))
254 | 
255 |         if para[1] == 'pv':
256 |             pretrain_emb_dir = os.path.join(pretrain_pv_root_dir.replace('##', dataset), pv_path)
257 |         elif para[1] == 'pvc':
258 |             if dataset in datasets[:-4]:
259 |                 pretrain_emb_dir = os.path.join(pretrain_pv_root_dir.replace('##', dataset), small_pvc_path)
260 |             else:
261 |                 pretrain_emb_dir = os.path.join(pretrain_pv_root_dir.replace('##', dataset), large_pvc_path)
262 |         else:
263 |             pretrain_emb_dir = "None"
264 |         dataset_name = "reviews_%s_5.json.gz.stem.nostop" % dataset
265 |         data_dir = "%s/%s/min_count5" % (WORKING_DIR, dataset_name)
266 |         input_train_dir = os.path.join(data_dir, "seq_query_split")
267 |         #input_train_dir = os.path.join(data_dir, "query_split")
268 |         cmd_arr.append('--data_dir {}'.format(data_dir))
269 |         cmd_arr.append('--pretrain_emb_dir {}'.format(pretrain_emb_dir))
270 |         cmd_arr.append('--input_train_dir {}'.format(input_train_dir))
271 |         output_path = "%s/%s" % (OUTPUT_DIR, dataset_name)
272 |         model_name = "_".join(["{}{}".format(x, y) for x, y in zip(short_names, para[1:])])
273 |         save_dir = os.path.join(output_path, model_name)
274 |         cur_cmd_option = " ".join(["--{} {}".format(x, y) for x, y in zip(para_names[1:], para[2:])])
275 |         cmd_arr.append(cur_cmd_option)
276 |         cmd_arr.append("--save_dir %s" % save_dir)
277 |         cmd_arr.append("--has_valid False")  # use the test set as validation
278 |         model_name = "{}_{}".format(dataset, model_name)
279 |         #cmd_arr.append("--log_file %s/%s.log" % (args.log_dir, model_name))
280 |         #cmd_arr.append("&> %s/%s.log \n" % (args.log_dir, model_name))
281 |         cmd = " ".join(cmd_arr)  # training command
282 |         cmd_arr.append("--mode test")
283 |         cmd_arr.append("--train_review_only False")
284 |         cmd_arr.append("--do_seq_review_test")
285 |         cmd_arr.append("--rankfname test.seq_all.ranklist")
286 |         test_cmd = " ".join(cmd_arr)  # same command with the test-time options appended
287 |         #print(cmd)
288 |         #os.system(cmd)
289 |         fname = "%s/%s.sh" % (args.script_dir, model_name)
290 |         with open(fname, 'w') as fout:
291 |             fout.write(config_str)
292 |             fout.write("#SBATCH --job-name=%d.sh\n" % job_id)
293 |             fout.write("#SBATCH --output=%s/%s/%s.txt\n" % (args.log_dir, dataset, model_name))
294 |             fout.write("#SBATCH -e %s/%s/%s.err.txt\n" % (args.log_dir, dataset, model_name))
295 |             fout.write("\n")
296 |             fout.write(cmd)
297 |             fout.write("\n")
298 |             fout.write(test_cmd)
299 |             fout.write("\n\n")
300 |             fout.write("exit\n")
301 | 
302 |         fscript.write("sbatch %s\n" % fname)
303 |         job_id += 1
304 |     fscript.close()
305 | 
306 | 
307 | 
308 | 
--------------------------------------------------------------------------------
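
For reference, the sketch below is not part of the repository; it is a minimal illustration of how the loop in tune_para.py expands one entry of `paras` into a model name and the "--flag value" options passed to main.py. The `para_names`/`short_names` lists and the sample tuple are copied verbatim from the file above; the printed strings are what the same zip-based expansion yields.

# Minimal sketch (not part of the repo): expanding one `paras` tuple the way tune_para.py does.
para_names = ['pretrain_type', 'review_encoder_name', 'max_train_epoch', 'lr', 'warmup_steps',
              'batch_size', 'valid_candi_size', 'embedding_size', 'review_word_limit',
              'iprev_review_limit', 'dropout', 'use_pos_emb', 'corrupt_rate', 'pos_weight',
              'l2_lambda', 'ff_size', 'inter_layers', 'use_user_emb', 'use_item_emb',
              'fix_train_review', 'do_subsample_mask', 'subsampling_rate']
short_names = ['pretrain', 'enc', 'me', 'lr', 'ws', 'bs', 'vcs', 'ebs', 'rwl', 'irl', 'drop',
               'upos', 'cr', 'poswt', 'lambda', 'ff', 'ly', 'ue', 'ie', 'ftr', 'dsm', 'ssr']

# First active setting from the paras list above.
para = ('Cell_Phones_and_Accessories', 'pv', 'fs', 20, 0.005, 8000, 128, 500, 128, 100, 30,
        0.1, True, 0.2, False, 0, 512, 2, False, False, False, True, 0)

# Short-name string used for save_dir (and, prefixed with the dataset, for the script/log names).
model_name = "_".join("{}{}".format(s, v) for s, v in zip(short_names, para[1:]))

# Command-line options; pretrain_type (para[1]) is skipped because it only selects which
# pretrained-embedding directory is passed separately via --pretrain_emb_dir.
cmd_options = " ".join("--{} {}".format(n, v) for n, v in zip(para_names[1:], para[2:]))

print(model_name)   # pretrainpv_encfs_me20_lr0.005_ws8000_bs128_vcs500_ebs128_rwl100_irl30_drop0.1_uposTrue_cr0.2_poswtFalse_lambda0_ff512_ly2_ueFalse_ieFalse_ftrFalse_dsmTrue_ssr0
print(cmd_options)  # --review_encoder_name fs --max_train_epoch 20 --lr 0.005 --warmup_steps 8000 ... --subsampling_rate 0

Running main.py with these flags plus --data_dir, --input_train_dir, --pretrain_emb_dir, and --save_dir is exactly what each generated sbatch script does: once for training, and once more with --mode test and the ranking-output options appended.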