├── .gitignore
├── LICENSE
├── README.md
├── conceptnet_edges2nd.txt
├── data
│   ├── entity2entityId.pkl
│   ├── id2entity.pkl
│   ├── movie_ids.pkl
│   ├── subkg.pkl
│   ├── test_data.jsonl
│   ├── text_dict.pkl
│   ├── train_data.jsonl
│   └── valid_data.jsonl
├── dataset.py
├── dataset_vis.py
├── e2e_model.py
├── e2e_run.py
├── install_geometric.sh
├── key2index_3rd.json
├── mask4key.npy
├── mask4movie.npy
├── model.py
├── model_novel.py
├── models
│   ├── graph.py
│   ├── transformer.py
│   └── utils.py
├── movieID2selection_label.pkl
├── run.py
├── run_novel.py
├── stopwords.txt
├── utils.py
├── visualize_dataset.py
└── word2index_redial.json
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | saved_model/ 2 | origin_w2v_data/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # pytype static type analyzer 138 | .pytype/ 139 | 140 | # Cython debug symbols 141 | cython_debug/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NTRD 2 | This repository is the PyTorch implementation of our paper "[**Learning Neural Templates for Recommender Dialogue System**](https://arxiv.org/abs/2109.12302)" in **EMNLP 2021**.
3 | 4 | 5 | 6 | In this paper, we introduce NTRD, a novel recommender dialogue system (i.e., conversational recommendation system) framework that decouples dialogue generation from item recommendation via a two-stage strategy. Our approach makes the recommender dialogue system more flexible and controllable. Extensive experiments show that our approach significantly outperforms previous state-of-the-art methods. 7 | 8 | The code is still being organized; feel free to contact me if you encounter any problems. 9 | 10 | # Dependencies 11 | ``` 12 | pytorch==1.6.0 13 | gensim==3.8.3 14 | torch_geometric==1.6.3 15 | torch-cluster==1.5.8 16 | torch-scatter==2.0.5 17 | torch-sparse==0.6.8 18 | torch-spline-conv==1.2.0 19 | ``` 20 | 21 | 22 | 23 | 24 | The required data file **word2vec_redial.npy** can be produced by the function ```dataset.prepare_word2vec()``` (see the sketch at the end of this README). 25 | 26 | # Run 27 | Run the script below to pre-train the recommender module. It converges after 3 epochs of pre-training and 3 epochs of fine-tuning. 28 | 29 | ```python 30 | python run.py 31 | ``` 32 | 33 | Then, run the following script to train the seq2seq dialogue task. The Transformer model is slow to converge, so it needs many epochs to reach convergence. Please be patient when training this model. 34 | 35 | ```python 36 | python run.py --is_finetune True 37 | ``` 38 | 39 | The model will automatically report results on the test data after convergence. 40 | 41 | To run the novel experiments, you first need to generate ```data/full_data.jsonl``` by combining ```data/train_data.jsonl``` and ```data/test_data.jsonl``` into one file (see the sketch at the end of this README). 42 | 43 | Also, you need to uncomment the code in ```dataset.py``` at L117 and L317-L322. 44 | 45 | Then, run the following script to pre-train the recommender module. 46 | 47 | ```python 48 | python run_novel.py 49 | ``` 50 | 51 | The following step is the same as in the conventional setting; run the command below. 52 | 53 | ```python 54 | python run_novel.py --is_finetune True 55 | ``` 56 | 57 | # Citation 58 | 59 | If you find this codebase helpful for your research, please kindly consider citing our paper in your publications. 60 | 61 | ```bibtex 62 | @inproceedings{liang2021learning, 63 | title={Learning Neural Templates for Recommender Dialogue System}, 64 | author={Liang, Zujie and 65 | Hu, Huang and 66 | Xu, Can and 67 | Miao, Jian and 68 | He, Yingying and 69 | Chen, Yining and 70 | Geng, Xiubo and 71 | Liang, Fan and 72 | Jiang, Daxin}, 73 | booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (EMNLP)}, 74 | year={2021} 75 | } 76 | ``` 77 | 78 | # Acknowledgment 79 | 80 | This codebase is implemented based on [KGSF](https://github.com/RUCAIBox/KGSF). Many thanks to the authors for their open-source project.
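# Appendix

The snippets below are minimal sketches added for clarity; they are not scripts shipped in this repo. First, a hypothetical helper for building ```data/full_data.jsonl``` for the novel experiments: it simply concatenates the train and test splits line by line.

```python
# Hypothetical helper (not part of the repo): build data/full_data.jsonl for
# the novel experiments by concatenating the train and test JSONL splits.
with open('data/full_data.jsonl', 'w', encoding='utf-8') as out:
    for split in ('data/train_data.jsonl', 'data/test_data.jsonl'):
        with open(split, encoding='utf-8') as f:
            for line in f:
                out.write(line)
```

Second, a sketch of how **word2vec_redial.npy** is produced: constructing the training ```dataset``` object calls ```prepare_word2vec()``` internally (see ```dataset.py```), which writes ```word2index_redial.json``` and ```word2vec_redial.npy``` to the working directory. The ```opt``` values below are illustrative placeholders only; in practice ```run.py``` supplies the real hyperparameters.

```python
from dataset import dataset

# Placeholder option values (assumed, not the official defaults);
# dataset.__init__ reads exactly these keys.
opt = {'batch_size': 32, 'max_c_length': 256, 'max_r_length': 30,
       'max_count': 5, 'n_entity': 64368}
# A filename containing 'train' triggers prepare_word2vec() inside __init__.
ds = dataset('data/train_data.jsonl', opt)
```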
81 | -------------------------------------------------------------------------------- /data/entity2entityId.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/data/entity2entityId.pkl -------------------------------------------------------------------------------- /data/id2entity.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/data/id2entity.pkl -------------------------------------------------------------------------------- /data/movie_ids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/data/movie_ids.pkl -------------------------------------------------------------------------------- /data/subkg.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/data/subkg.pkl -------------------------------------------------------------------------------- /data/text_dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/data/text_dict.pkl -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import pickle as pkl 4 | import json 5 | from nltk import word_tokenize 6 | import re 7 | from torch.utils.data.dataset import Dataset 8 | import numpy as np 9 | from copy import deepcopy 10 | 11 | # import sys 12 | # # from importlib import reload 13 | # # reload(sys) 14 | # # sys.setdefaultencoding('utf-8') 15 | # import importlib 16 | # importlib.reload(sys) 17 | 18 | MOVIE_TOKEN = '__MOVIE__' 19 | 20 | class dataset(object): 21 | def __init__(self,filename,opt): 22 | self.entity2entityId=pkl.load(open('data/entity2entityId.pkl','rb')) 23 | self.movieID2selection_label=pkl.load(open('movieID2selection_label.pkl','rb')) 24 | self.selection_label2movieID={self.movieID2selection_label[key]:key for key in self.movieID2selection_label} 25 | self.entity_max=len(self.entity2entityId) 26 | 27 | self.id2entity=pkl.load(open('data/id2entity.pkl','rb')) 28 | self.subkg=pkl.load(open('data/subkg.pkl','rb')) #need not back process 29 | self.text_dict=pkl.load(open('data/text_dict.pkl','rb')) 30 | 31 | self.batch_size=opt['batch_size'] 32 | self.max_c_length=opt['max_c_length'] 33 | self.max_r_length=opt['max_r_length'] 34 | self.max_count=opt['max_count'] 35 | self.entity_num=opt['n_entity'] 36 | #self.word2index=json.load(open('word2index.json',encoding='utf-8')) 37 | 38 | f=open(filename,encoding='utf-8') 39 | self.data=[] 40 | self.corpus=[] 41 | for line in tqdm(f): 42 | lines=json.loads(line.strip()) 43 | seekerid=lines["initiatorWorkerId"] 44 | recommenderid=lines["respondentWorkerId"] 45 | contexts=lines['messages'] 46 | movies=lines['movieMentions'] 47 | altitude=lines['respondentQuestions'] 48 | initial_altitude=lines['initiatorQuestions'] 49 | cases=self._context_reformulate(contexts,movies,altitude,initial_altitude,seekerid,recommenderid) 50 | self.data.extend(cases) 51 | 52 | if 'train' in filename: 53 | 
self.prepare_word2vec() 54 | self.word2index = json.load(open('word2index_redial.json', encoding='utf-8')) 55 | self.key2index=json.load(open('key2index_3rd.json',encoding='utf-8')) 56 | 57 | self.stopwords=set([word.strip() for word in open('stopwords.txt',encoding='utf-8')]) 58 | 59 | #self.co_occurance_ext(self.data) 60 | #exit() 61 | 62 | def prepare_word2vec(self): 63 | import gensim 64 | # model=gensim.models.word2vec.Word2Vec(self.corpus,vector_size=300,min_count=1) 65 | model=gensim.models.word2vec.Word2Vec(self.corpus,size=300,min_count=1) 66 | # model.save('word2vec_redial') 67 | # word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} 68 | word2index = {word: i + 4 for i, word in enumerate(model.wv.index2word)} 69 | #word2index['_split_']=len(word2index)+4 70 | #json.dump(word2index, open('word2index_redial.json', 'w', encoding='utf-8'), ensure_ascii=False) 71 | word2embedding = [[0] * 300] * 4 + [model[word] for word in word2index]+[[0]*300] 72 | import numpy as np 73 | 74 | word2index['_split_']=len(word2index)+4 75 | print('---Saving word2vec data... ----') 76 | json.dump(word2index, open('word2index_redial.json', 'w', encoding='utf-8'), ensure_ascii=False) 77 | 78 | print(np.shape(word2embedding)) 79 | np.save('word2vec_redial.npy', word2embedding) 80 | 81 | def padding_w2v(self,sentence,max_length,transformer=True,is_response=False,movies_gth=None,pad=0,end=2,unk=3): 82 | vector=[] 83 | concept_mask=[] 84 | dbpedia_mask=[] 85 | movie_mask=[] 86 | 87 | 88 | cur_movie = 0 89 | for word in sentence: 90 | vector.append(self.word2index.get(word,unk)) 91 | #if word.lower() not in self.stopwords: 92 | # key2index file may have problem due to the vocab size, TODO by Jokie 2021/04/12 93 | concept_mask.append(self.key2index.get(word.lower(),0)) 94 | #else: 95 | # concept_mask.append(0) 96 | 97 | # if '@' in word: 98 | # try: 99 | # entity = self.id2entity[int(word[1:])] 100 | # id=self.entity2entityId[entity] 101 | # except: 102 | # id=self.entity_max 103 | # dbpedia_mask.append(id) 104 | # else: 105 | # dbpedia_mask.append(self.entity_max) 106 | 107 | if MOVIE_TOKEN in word: 108 | # if is_response and MOVIE_TOKEN in word: 109 | movie_mask.append(1) 110 | try: 111 | #WAY1: original movie gth 112 | cur_movie_id = self.selection_label2movieID[movies_gth[cur_movie]] 113 | entity = self.id2entity[int(cur_movie_id)] 114 | id=self.entity2entityId[entity] 115 | #WAY2: hacked for novel exp by Jokie 2021/05/01 116 | # id=movies_gth[cur_movie] 117 | except: 118 | id=self.entity_max 119 | cur_movie+=1 120 | dbpedia_mask.append(id) 121 | else: 122 | dbpedia_mask.append(self.entity_max) 123 | movie_mask.append(0) 124 | 125 | if movies_gth is not None: 126 | # print('cur_movie:', cur_movie) 127 | # print('len(movies_gth):', len(movies_gth)) 128 | assert cur_movie == len(movies_gth) 129 | 130 | 131 | vector.append(end) 132 | concept_mask.append(0) 133 | dbpedia_mask.append(self.entity_max) 134 | movie_mask.append(0) 135 | 136 | if len(vector)>max_length: 137 | if transformer: 138 | movie_nums = sum(movie_mask[-max_length:]) 139 | if movie_nums!=0: 140 | movies_gth = movies_gth[-movie_nums:] 141 | else: 142 | movies_gth = [] 143 | return vector[-max_length:],max_length,concept_mask[-max_length:],dbpedia_mask[-max_length:],movies_gth 144 | else: 145 | movie_nums = sum(movie_mask[:max_length]) 146 | movies_gth = movies_gth[:movie_nums] 147 | return vector[:max_length],max_length,concept_mask[:max_length],dbpedia_mask[:max_length],movies_gth 148 | else: 149 | length=len(vector) 150 
| return vector+(max_length-len(vector))*[pad],length,\ 151 | concept_mask+(max_length-len(vector))*[0],dbpedia_mask+(max_length-len(vector))*[0],movies_gth 152 | # return vector+(max_length-len(vector))*[pad],length,\ 153 | # concept_mask+(max_length-len(vector))*[0],dbpedia_mask+(max_length-len(vector))*[self.entity_max] 154 | 155 | def padding_all_movies(self,movies,max_length,transformer=True,pad=-1): 156 | 157 | 158 | if len(movies)>max_length: 159 | # if transformer: 160 | return movies[-max_length:] 161 | # else: 162 | # return vector[:max_length],max_length,concept_mask[:max_length],dbpedia_mask[:max_length] 163 | else: 164 | length=len(movies) 165 | return movies+(max_length-len(movies))*[pad], length 166 | 167 | def padding_context(self,contexts,pad=0,transformer=True): 168 | vectors=[] 169 | vec_lengths=[] 170 | #--------------------Useless --------------------- 171 | if transformer==False: 172 | if len(contexts)>self.max_count: 173 | for sen in contexts[-self.max_count:]: 174 | vec,v_l=self.padding_w2v(sen,self.max_r_length,transformer) 175 | vectors.append(vec) 176 | vec_lengths.append(v_l) 177 | return vectors,vec_lengths,self.max_count 178 | else: 179 | length=len(contexts) 180 | for sen in contexts: 181 | vec, v_l = self.padding_w2v(sen,self.max_r_length,transformer) 182 | vectors.append(vec) 183 | vec_lengths.append(v_l) 184 | return vectors+(self.max_count-length)*[[pad]*self.max_c_length],vec_lengths+[0]*(self.max_count-length),length 185 | #--------------------Useless(end) --------------------- 186 | else: 187 | contexts_com=[] 188 | for sen in contexts[-self.max_count:-1]: 189 | contexts_com.extend(sen) 190 | contexts_com.append('_split_') 191 | contexts_com.extend(contexts[-1]) 192 | vec,v_l,concept_mask,dbpedia_mask,_=self.padding_w2v(contexts_com,self.max_c_length,transformer) 193 | return vec,v_l,concept_mask,dbpedia_mask,0 194 | 195 | def response_delibration(self,response,unk='MASKED_WORD'): 196 | new_response=[] 197 | for word in response: 198 | # if word in self.key2index: 199 | if '@' in word: 200 | # print(word) 201 | new_response.append(unk) 202 | else: 203 | new_response.append(word) 204 | return new_response 205 | 206 | def data_process(self,is_finetune=False): 207 | data_set = [] 208 | context_before = [] 209 | for line in self.data: 210 | #if len(line['contexts'])>2: 211 | # continue 212 | if is_finetune and line['contexts'] == context_before: 213 | continue 214 | else: 215 | context_before = line['contexts'] 216 | 217 | 218 | 219 | context,c_lengths,concept_mask,dbpedia_mask,_=self.padding_context(line['contexts']) 220 | # context,c_lengths,concept_mask,dbpedia_mask_context,_=self.padding_context(line['contexts']) 221 | response,r_length,_,_,movies_gth=self.padding_w2v(line['response'],self.max_r_length,transformer=True,is_response=True,movies_gth=line['all_movies']) 222 | 223 | #padding all_movies 224 | movies_gth, movies_num = self.padding_all_movies(movies_gth,self.max_r_length) 225 | # movies_gth, movies_num = self.padding_all_movies(line['all_movies'],self.max_r_length) 226 | 227 | if False: 228 | mask_response,mask_r_length,_,_,_=self.padding_w2v(self.response_delibration(line['response']),self.max_r_length) 229 | else: 230 | mask_response, mask_r_length=response,r_length 231 | assert len(context)==self.max_c_length 232 | assert len(concept_mask)==self.max_c_length 233 | assert len(dbpedia_mask)==self.max_c_length 234 | # assert len(dbpedia_mask)==self.max_c_length or len(dbpedia_mask)==self.max_r_length 235 | 236 | 
data_set.append([np.array(context),c_lengths,np.array(response),r_length,np.array(mask_response),mask_r_length,line['entity'], 237 | line['movie'],concept_mask,dbpedia_mask,line['rec'], np.array(movies_gth), movies_num]) 238 | return data_set 239 | 240 | def co_occurance_ext(self,data): 241 | stopwords=set([word.strip() for word in open('stopwords.txt',encoding='utf-8')]) 242 | keyword_sets=set(self.key2index.keys())-stopwords 243 | movie_wordset=set() 244 | for line in data: 245 | movie_words=[] 246 | if line['rec']==1: 247 | for word in line['response']: 248 | if '@' in word: 249 | try: 250 | num=self.entity2entityId[self.id2entity[int(word[1:])]] 251 | movie_words.append(word) 252 | movie_wordset.add(word) 253 | except: 254 | pass 255 | line['movie_words']=movie_words 256 | new_edges=set() 257 | for line in data: 258 | if len(line['movie_words'])>0: 259 | before_set=set() 260 | after_set=set() 261 | co_set=set() 262 | for sen in line['contexts']: 263 | for word in sen: 264 | if word in keyword_sets: 265 | before_set.add(word) 266 | if word in movie_wordset: 267 | after_set.add(word) 268 | for word in line['response']: 269 | if word in keyword_sets: 270 | co_set.add(word) 271 | 272 | for movie in line['movie_words']: 273 | for word in list(before_set): 274 | new_edges.add('co_before'+'\t'+movie+'\t'+word+'\n') 275 | for word in list(co_set): 276 | new_edges.add('co_occurance' + '\t' + movie + '\t' + word + '\n') 277 | for word in line['movie_words']: 278 | if word!=movie: 279 | new_edges.add('co_occurance' + '\t' + movie + '\t' + word + '\n') 280 | for word in list(after_set): 281 | new_edges.add('co_after'+'\t'+word+'\t'+movie+'\n') 282 | for word_a in list(co_set): 283 | new_edges.add('co_after'+'\t'+word+'\t'+word_a+'\n') 284 | f=open('co_occurance.txt','w',encoding='utf-8') 285 | f.writelines(list(new_edges)) 286 | f.close() 287 | json.dump(list(movie_wordset),open('movie_word.json','w',encoding='utf-8'),ensure_ascii=False) 288 | print(len(new_edges)) 289 | print(len(movie_wordset)) 290 | 291 | def entities2ids(self,entities): 292 | return [self.entity2entityId[word] for word in entities] 293 | 294 | def detect_movie(self,sentence,movies): 295 | token_text = word_tokenize(sentence) 296 | num=0 297 | token_text_com=[] 298 | masked_movie_by = MOVIE_TOKEN 299 | masked_movie_num = 0 300 | while num0: 351 | # token_text,movie_rec,all_movie_selection_label,masked_movie_num=self.detect_movie(message['text'],movies) 352 | 353 | # # print('processed context', u' '.join(token_text).encode('utf-8').strip()) 354 | # # print('movie rec:', movie_rec) 355 | # # print('all_movie_selection_label :', all_movie_selection_label) 356 | # else: 357 | # token_text,movie_rec,_,_=self.detect_movie(message['text'],movies) 358 | # all_movie_selection_label=[] 359 | 360 | if len(context_list)==0: 361 | # context_dict={'text':token_text,'entity':entities+movie_rec,'user':message['senderWorkerId'],'movie':movie_rec} 362 | context_dict={'text':token_text,'entity':entities+movie_rec,'user':message['senderWorkerId'],'movie':movie_rec,'movies_selection_labels': all_movie_selection_label} 363 | context_list.append(context_dict) 364 | last_id=message['senderWorkerId'] 365 | continue 366 | if message['senderWorkerId']==last_id: 367 | context_list[-1]['text']+=token_text 368 | context_list[-1]['entity']+=entities+movie_rec 369 | context_list[-1]['movie']+=movie_rec 370 | # if message['senderWorkerId']==re_id and len(context_list)>0: 371 | # context_list[-1]['movies_selection_labels']+=all_movie_selection_label 372 | 
context_list[-1]['movies_selection_labels']+=all_movie_selection_label 373 | else: 374 | # context_dict = {'text': token_text, 'entity': entities+movie_rec, 375 | # 'user': message['senderWorkerId'], 'movie':movie_rec, 'movies_selection_labels': all_movie_selection_label} 376 | context_dict = {'text': token_text, 'entity': entities+movie_rec, 377 | 'user': message['senderWorkerId'], 'movie':movie_rec} 378 | # if message['senderWorkerId']==re_id and len(context_list)>0: 379 | # context_dict['movies_selection_labels']=all_movie_selection_label 380 | context_dict['movies_selection_labels']=all_movie_selection_label 381 | context_list.append(context_dict) 382 | last_id = message['senderWorkerId'] 383 | 384 | # if message['senderWorkerId']==re_id and len(context_list)>0: 385 | movie_num=context_list[-1]['text'].count(MOVIE_TOKEN) 386 | len_labels = len(context_list[-1]['movies_selection_labels']) 387 | assert movie_num == len_labels 388 | 389 | # if 'zombie' in ' '.join(context_list[-1]['text']) and 'focuses' in ' '.join(context_list[-1]['text']) and 'survive' in ' '.join(context_list[-1]['text']) and 'individual' in ' '.join(context_list[-1]['text']): 390 | # print(u' '.join(context_list[-1]['text']).encode('utf-8').strip()) 391 | # print('movie_num', movie_num) 392 | # # print('len(movies_selection_labels', len_labels) 393 | # # print(u' '.join(context_list[-1]['text']).encode('utf-8').strip()) 394 | # print(context_list[-1]['movies_selection_labels']) 395 | # print(context_list[-1]['movie']) 396 | # # print('-----------------------------------') 397 | 398 | 399 | 400 | cases=[] 401 | contexts=[] 402 | entities_set=set() 403 | entities=[] 404 | for context_dict in context_list: 405 | self.corpus.append(context_dict['text']) 406 | if context_dict['user']==re_id and len(contexts)>0: 407 | response=context_dict['text'] 408 | 409 | #entity_vec=np.zeros(self.entity_num) 410 | #for en in list(entities): 411 | # entity_vec[en]=1 412 | #movie_vec=np.zeros(self.entity_num+1,dtype=np.float) 413 | if len(context_dict['movie'])!=0: 414 | for movie in context_dict['movie']: 415 | #if movie not in entities_set: 416 | cases.append({'contexts': deepcopy(contexts), 'response': response, 'entity': deepcopy(entities), 'movie': movie, 'rec':1, 'all_movies':context_dict['movies_selection_labels']}) 417 | # cases[-1]['all_movies']=context_dict['movie'] 418 | else: 419 | cases.append({'contexts': deepcopy(contexts), 'response': response, 'entity': deepcopy(entities), 'movie': 0, 'rec':0,'all_movies':context_dict['movies_selection_labels']}) 420 | 421 | contexts.append(context_dict['text']) 422 | for word in context_dict['entity']: 423 | if word not in entities_set: 424 | entities.append(word) 425 | entities_set.add(word) 426 | else: 427 | contexts.append(context_dict['text']) 428 | for word in context_dict['entity']: 429 | if word not in entities_set: 430 | entities.append(word) 431 | entities_set.add(word) 432 | return cases 433 | 434 | class CRSdataset(Dataset): 435 | def __init__(self, dataset, entity_num, concept_num): 436 | self.data=dataset 437 | self.entity_num = entity_num 438 | self.concept_num = concept_num+1 439 | 440 | def __getitem__(self, index): 441 | ''' 442 | movie_vec = np.zeros(self.entity_num, dtype=np.float) 443 | context, c_lengths, response, r_length, entity, movie, concept_mask, dbpedia_mask, rec = self.data[index] 444 | for en in movie: 445 | movie_vec[en] = 1 / len(movie) 446 | return context, c_lengths, response, r_length, entity, movie_vec, concept_mask, dbpedia_mask, rec 447 | ''' 448 | 
context, c_lengths, response, r_length, mask_response, mask_r_length, entity, movie, concept_mask, dbpedia_mask, rec, movies_gth, movies_num= self.data[index] 449 | entity_vec = np.zeros(self.entity_num) 450 | entity_vector=np.zeros(50,dtype=np.int) 451 | point=0 452 | for en in entity: 453 | entity_vec[en]=1 454 | entity_vector[point]=en 455 | point+=1 456 | 457 | concept_vec=np.zeros(self.concept_num) 458 | for con in concept_mask: 459 | if con!=0: 460 | concept_vec[con]=1 461 | 462 | db_vec=np.zeros(self.entity_num) 463 | for db in dbpedia_mask: 464 | if db!=0: 465 | db_vec[db]=1 466 | 467 | return context, c_lengths, response, r_length, mask_response, mask_r_length, entity_vec, entity_vector, movie, np.array(concept_mask), np.array(dbpedia_mask), concept_vec, db_vec, rec, movies_gth, movies_num 468 | 469 | def __len__(self): 470 | return len(self.data) 471 | 472 | if __name__=='__main__': 473 | ds=dataset('data/train_data.jsonl') 474 | print() 475 | -------------------------------------------------------------------------------- /dataset_vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import pickle as pkl 4 | import json 5 | from nltk import word_tokenize 6 | import re 7 | from torch.utils.data.dataset import Dataset 8 | import numpy as np 9 | from copy import deepcopy 10 | 11 | # import sys 12 | # # from importlib import reload 13 | # # reload(sys) 14 | # # sys.setdefaultencoding('utf-8') 15 | # import importlib 16 | # importlib.reload(sys) 17 | 18 | MOVIE_TOKEN = '__MOVIE__' 19 | 20 | class dataset(object): 21 | def __init__(self,filename): 22 | self.entity2entityId=pkl.load(open('data/entity2entityId.pkl','rb')) 23 | self.movieID2selection_label=pkl.load(open('movieID2selection_label.pkl','rb')) 24 | self.selection_label2movieID={self.movieID2selection_label[key]:key for key in self.movieID2selection_label} 25 | self.entity_max=len(self.entity2entityId) 26 | 27 | self.id2entity=pkl.load(open('data/id2entity.pkl','rb')) 28 | self.subkg=pkl.load(open('data/subkg.pkl','rb')) #need not back process 29 | self.text_dict=pkl.load(open('data/text_dict.pkl','rb')) 30 | 31 | # self.batch_size=opt['batch_size'] 32 | # self.max_c_length=opt['max_c_length'] 33 | # self.max_r_length=opt['max_r_length'] 34 | # self.max_count=opt['max_count'] 35 | # self.entity_num=opt['n_entity'] 36 | #self.word2index=json.load(open('word2index.json',encoding='utf-8')) 37 | 38 | f=open(filename,encoding='utf-8') 39 | self.data=[] 40 | self.corpus=[] 41 | self.all_movies=[] 42 | for line in tqdm(f): 43 | lines=json.loads(line.strip()) 44 | seekerid=lines["initiatorWorkerId"] 45 | recommenderid=lines["respondentWorkerId"] 46 | contexts=lines['messages'] 47 | movies=lines['movieMentions'] 48 | self.all_movies.extend([int(movie_id) for movie_id in movies]) 49 | altitude=lines['respondentQuestions'] 50 | initial_altitude=lines['initiatorQuestions'] 51 | # cases=self._context_reformulate(contexts,movies,altitude,initial_altitude,seekerid,recommenderid) 52 | # self.data.extend(cases) 53 | 54 | #if 'train' in filename: 55 | 56 | # self.prepare_word2vec() 57 | self.word2index = json.load(open('word2index_redial.json', encoding='utf-8')) 58 | self.key2index=json.load(open('key2index_3rd.json',encoding='utf-8')) 59 | 60 | self.stopwords=set([word.strip() for word in open('stopwords.txt',encoding='utf-8')]) 61 | 62 | #self.co_occurance_ext(self.data) 63 | #exit() 64 | 65 | def prepare_word2vec(self): 66 | import gensim 67 | # 
model=gensim.models.word2vec.Word2Vec(self.corpus,vector_size=300,min_count=1) 68 | model=gensim.models.word2vec.Word2Vec(self.corpus,size=300,min_count=1) 69 | model.save('word2vec_redial') 70 | # word2index = {word: i + 4 for i, word in enumerate(model.wv.index_to_key)} 71 | word2index = {word: i + 4 for i, word in enumerate(model.wv.index2word)} 72 | #word2index['_split_']=len(word2index)+4 73 | #json.dump(word2index, open('word2index_redial.json', 'w', encoding='utf-8'), ensure_ascii=False) 74 | word2embedding = [[0] * 300] * 4 + [model[word] for word in word2index]+[[0]*300] 75 | import numpy as np 76 | 77 | word2index['_split_']=len(word2index)+4 78 | print('---Saving word2vec data... ----') 79 | json.dump(word2index, open('word2index_redial.json', 'w', encoding='utf-8'), ensure_ascii=False) 80 | 81 | print(np.shape(word2embedding)) 82 | np.save('word2vec_redial.npy', word2embedding) 83 | 84 | def padding_w2v(self,sentence,max_length,transformer=True,is_response=False,movies_gth=None,pad=0,end=2,unk=3): 85 | vector=[] 86 | concept_mask=[] 87 | dbpedia_mask=[] 88 | movie_mask=[] 89 | 90 | 91 | cur_movie = 0 92 | for word in sentence: 93 | vector.append(self.word2index.get(word,unk)) 94 | #if word.lower() not in self.stopwords: 95 | # key2index file may have problem due to the vocab size, TODO by Jokie 2021/04/12 96 | concept_mask.append(self.key2index.get(word.lower(),0)) 97 | #else: 98 | # concept_mask.append(0) 99 | 100 | # if '@' in word: 101 | # try: 102 | # entity = self.id2entity[int(word[1:])] 103 | # id=self.entity2entityId[entity] 104 | # except: 105 | # id=self.entity_max 106 | # dbpedia_mask.append(id) 107 | # else: 108 | # dbpedia_mask.append(self.entity_max) 109 | 110 | if MOVIE_TOKEN in word: 111 | # if is_response and MOVIE_TOKEN in word: 112 | movie_mask.append(1) 113 | try: 114 | cur_movie_id = self.selection_label2movieID[movies_gth[cur_movie]] 115 | entity = self.id2entity[int(cur_movie_id)] 116 | id=self.entity2entityId[entity] 117 | except: 118 | id=self.entity_max 119 | cur_movie+=1 120 | dbpedia_mask.append(id) 121 | else: 122 | dbpedia_mask.append(self.entity_max) 123 | movie_mask.append(0) 124 | 125 | if movies_gth is not None: 126 | # print('cur_movie:', cur_movie) 127 | # print('len(movies_gth):', len(movies_gth)) 128 | assert cur_movie == len(movies_gth) 129 | 130 | 131 | vector.append(end) 132 | concept_mask.append(0) 133 | dbpedia_mask.append(self.entity_max) 134 | movie_mask.append(0) 135 | 136 | if len(vector)>max_length: 137 | if transformer: 138 | movie_nums = sum(movie_mask[-max_length:]) 139 | if movie_nums!=0: 140 | movies_gth = movies_gth[-movie_nums:] 141 | else: 142 | movies_gth = [] 143 | return vector[-max_length:],max_length,concept_mask[-max_length:],dbpedia_mask[-max_length:],movies_gth 144 | else: 145 | movie_nums = sum(movie_mask[:max_length]) 146 | movies_gth = movies_gth[:movie_nums] 147 | return vector[:max_length],max_length,concept_mask[:max_length],dbpedia_mask[:max_length],movies_gth 148 | else: 149 | length=len(vector) 150 | return vector+(max_length-len(vector))*[pad],length,\ 151 | concept_mask+(max_length-len(vector))*[0],dbpedia_mask+(max_length-len(vector))*[0],movies_gth 152 | # return vector+(max_length-len(vector))*[pad],length,\ 153 | # concept_mask+(max_length-len(vector))*[0],dbpedia_mask+(max_length-len(vector))*[self.entity_max] 154 | 155 | def padding_all_movies(self,movies,max_length,transformer=True,pad=-1): 156 | 157 | 158 | if len(movies)>max_length: 159 | # if transformer: 160 | return movies[-max_length:] 161 | 
# else: 162 | # return vector[:max_length],max_length,concept_mask[:max_length],dbpedia_mask[:max_length] 163 | else: 164 | length=len(movies) 165 | return movies+(max_length-len(movies))*[pad], length 166 | 167 | def padding_context(self,contexts,pad=0,transformer=True): 168 | vectors=[] 169 | vec_lengths=[] 170 | #--------------------Useless --------------------- 171 | if transformer==False: 172 | if len(contexts)>self.max_count: 173 | for sen in contexts[-self.max_count:]: 174 | vec,v_l=self.padding_w2v(sen,self.max_r_length,transformer) 175 | vectors.append(vec) 176 | vec_lengths.append(v_l) 177 | return vectors,vec_lengths,self.max_count 178 | else: 179 | length=len(contexts) 180 | for sen in contexts: 181 | vec, v_l = self.padding_w2v(sen,self.max_r_length,transformer) 182 | vectors.append(vec) 183 | vec_lengths.append(v_l) 184 | return vectors+(self.max_count-length)*[[pad]*self.max_c_length],vec_lengths+[0]*(self.max_count-length),length 185 | #--------------------Useless(end) --------------------- 186 | else: 187 | contexts_com=[] 188 | for sen in contexts[-self.max_count:-1]: 189 | contexts_com.extend(sen) 190 | contexts_com.append('_split_') 191 | contexts_com.extend(contexts[-1]) 192 | vec,v_l,concept_mask,dbpedia_mask,_=self.padding_w2v(contexts_com,self.max_c_length,transformer) 193 | return vec,v_l,concept_mask,dbpedia_mask,0 194 | 195 | def response_delibration(self,response,unk='MASKED_WORD'): 196 | new_response=[] 197 | for word in response: 198 | # if word in self.key2index: 199 | if '@' in word: 200 | # print(word) 201 | new_response.append(unk) 202 | else: 203 | new_response.append(word) 204 | return new_response 205 | 206 | def data_process(self,is_finetune=False): 207 | data_set = [] 208 | context_before = [] 209 | for line in self.data: 210 | #if len(line['contexts'])>2: 211 | # continue 212 | if is_finetune and line['contexts'] == context_before: 213 | continue 214 | else: 215 | context_before = line['contexts'] 216 | 217 | 218 | 219 | context,c_lengths,concept_mask,dbpedia_mask,_=self.padding_context(line['contexts']) 220 | # context,c_lengths,concept_mask,dbpedia_mask_context,_=self.padding_context(line['contexts']) 221 | response,r_length,_,_,movies_gth=self.padding_w2v(line['response'],self.max_r_length,transformer=True,is_response=True,movies_gth=line['all_movies']) 222 | 223 | #padding all_movies 224 | movies_gth, movies_num = self.padding_all_movies(movies_gth,self.max_r_length) 225 | # movies_gth, movies_num = self.padding_all_movies(line['all_movies'],self.max_r_length) 226 | 227 | if False: 228 | mask_response,mask_r_length,_,_,_=self.padding_w2v(self.response_delibration(line['response']),self.max_r_length) 229 | else: 230 | mask_response, mask_r_length=response,r_length 231 | assert len(context)==self.max_c_length 232 | assert len(concept_mask)==self.max_c_length 233 | assert len(dbpedia_mask)==self.max_c_length 234 | # assert len(dbpedia_mask)==self.max_c_length or len(dbpedia_mask)==self.max_r_length 235 | 236 | data_set.append([np.array(context),c_lengths,np.array(response),r_length,np.array(mask_response),mask_r_length,line['entity'], 237 | line['movie'],concept_mask,dbpedia_mask,line['rec'], np.array(movies_gth), movies_num]) 238 | return data_set 239 | 240 | def co_occurance_ext(self,data): 241 | stopwords=set([word.strip() for word in open('stopwords.txt',encoding='utf-8')]) 242 | keyword_sets=set(self.key2index.keys())-stopwords 243 | movie_wordset=set() 244 | for line in data: 245 | movie_words=[] 246 | if line['rec']==1: 247 | for word in 
line['response']: 248 | if '@' in word: 249 | try: 250 | num=self.entity2entityId[self.id2entity[int(word[1:])]] 251 | movie_words.append(word) 252 | movie_wordset.add(word) 253 | except: 254 | pass 255 | line['movie_words']=movie_words 256 | new_edges=set() 257 | for line in data: 258 | if len(line['movie_words'])>0: 259 | before_set=set() 260 | after_set=set() 261 | co_set=set() 262 | for sen in line['contexts']: 263 | for word in sen: 264 | if word in keyword_sets: 265 | before_set.add(word) 266 | if word in movie_wordset: 267 | after_set.add(word) 268 | for word in line['response']: 269 | if word in keyword_sets: 270 | co_set.add(word) 271 | 272 | for movie in line['movie_words']: 273 | for word in list(before_set): 274 | new_edges.add('co_before'+'\t'+movie+'\t'+word+'\n') 275 | for word in list(co_set): 276 | new_edges.add('co_occurance' + '\t' + movie + '\t' + word + '\n') 277 | for word in line['movie_words']: 278 | if word!=movie: 279 | new_edges.add('co_occurance' + '\t' + movie + '\t' + word + '\n') 280 | for word in list(after_set): 281 | new_edges.add('co_after'+'\t'+word+'\t'+movie+'\n') 282 | for word_a in list(co_set): 283 | new_edges.add('co_after'+'\t'+word+'\t'+word_a+'\n') 284 | f=open('co_occurance.txt','w',encoding='utf-8') 285 | f.writelines(list(new_edges)) 286 | f.close() 287 | json.dump(list(movie_wordset),open('movie_word.json','w',encoding='utf-8'),ensure_ascii=False) 288 | print(len(new_edges)) 289 | print(len(movie_wordset)) 290 | 291 | def entities2ids(self,entities): 292 | return [self.entity2entityId[word] for word in entities] 293 | 294 | def detect_movie(self,sentence,movies): 295 | token_text = word_tokenize(sentence) 296 | num=0 297 | token_text_com=[] 298 | masked_movie_by = MOVIE_TOKEN 299 | masked_movie_num = 0 300 | while num0: 344 | # token_text,movie_rec,all_movie_selection_label,masked_movie_num=self.detect_movie(message['text'],movies) 345 | 346 | # # print('processed context', u' '.join(token_text).encode('utf-8').strip()) 347 | # # print('movie rec:', movie_rec) 348 | # # print('all_movie_selection_label :', all_movie_selection_label) 349 | # else: 350 | # token_text,movie_rec,_,_=self.detect_movie(message['text'],movies) 351 | # all_movie_selection_label=[] 352 | 353 | if len(context_list)==0: 354 | # context_dict={'text':token_text,'entity':entities+movie_rec,'user':message['senderWorkerId'],'movie':movie_rec} 355 | context_dict={'text':token_text,'entity':entities+movie_rec,'user':message['senderWorkerId'],'movie':movie_rec,'movies_selection_labels': all_movie_selection_label} 356 | context_list.append(context_dict) 357 | last_id=message['senderWorkerId'] 358 | continue 359 | if message['senderWorkerId']==last_id: 360 | context_list[-1]['text']+=token_text 361 | context_list[-1]['entity']+=entities+movie_rec 362 | context_list[-1]['movie']+=movie_rec 363 | # if message['senderWorkerId']==re_id and len(context_list)>0: 364 | # context_list[-1]['movies_selection_labels']+=all_movie_selection_label 365 | context_list[-1]['movies_selection_labels']+=all_movie_selection_label 366 | else: 367 | # context_dict = {'text': token_text, 'entity': entities+movie_rec, 368 | # 'user': message['senderWorkerId'], 'movie':movie_rec, 'movies_selection_labels': all_movie_selection_label} 369 | context_dict = {'text': token_text, 'entity': entities+movie_rec, 370 | 'user': message['senderWorkerId'], 'movie':movie_rec} 371 | # if message['senderWorkerId']==re_id and len(context_list)>0: 372 | # context_dict['movies_selection_labels']=all_movie_selection_label 
373 | context_dict['movies_selection_labels']=all_movie_selection_label 374 | context_list.append(context_dict) 375 | last_id = message['senderWorkerId'] 376 | 377 | # if message['senderWorkerId']==re_id and len(context_list)>0: 378 | movie_num=context_list[-1]['text'].count(MOVIE_TOKEN) 379 | len_labels = len(context_list[-1]['movies_selection_labels']) 380 | assert movie_num == len_labels 381 | 382 | # if 'zombie' in ' '.join(context_list[-1]['text']) and 'focuses' in ' '.join(context_list[-1]['text']) and 'survive' in ' '.join(context_list[-1]['text']) and 'individual' in ' '.join(context_list[-1]['text']): 383 | # print(u' '.join(context_list[-1]['text']).encode('utf-8').strip()) 384 | # print('movie_num', movie_num) 385 | # # print('len(movies_selection_labels', len_labels) 386 | # # print(u' '.join(context_list[-1]['text']).encode('utf-8').strip()) 387 | # print(context_list[-1]['movies_selection_labels']) 388 | # print(context_list[-1]['movie']) 389 | # # print('-----------------------------------') 390 | 391 | 392 | 393 | cases=[] 394 | contexts=[] 395 | entities_set=set() 396 | entities=[] 397 | for context_dict in context_list: 398 | self.corpus.append(context_dict['text']) 399 | if context_dict['user']==re_id and len(contexts)>0: 400 | response=context_dict['text'] 401 | 402 | #entity_vec=np.zeros(self.entity_num) 403 | #for en in list(entities): 404 | # entity_vec[en]=1 405 | #movie_vec=np.zeros(self.entity_num+1,dtype=np.float) 406 | if len(context_dict['movie'])!=0: 407 | for movie in context_dict['movie']: 408 | #if movie not in entities_set: 409 | cases.append({'contexts': deepcopy(contexts), 'response': response, 'entity': deepcopy(entities), 'movie': movie, 'rec':1, 'all_movies':context_dict['movies_selection_labels']}) 410 | # cases[-1]['all_movies']=context_dict['movie'] 411 | else: 412 | cases.append({'contexts': deepcopy(contexts), 'response': response, 'entity': deepcopy(entities), 'movie': 0, 'rec':0,'all_movies':context_dict['movies_selection_labels']}) 413 | 414 | contexts.append(context_dict['text']) 415 | for word in context_dict['entity']: 416 | if word not in entities_set: 417 | entities.append(word) 418 | entities_set.add(word) 419 | else: 420 | contexts.append(context_dict['text']) 421 | for word in context_dict['entity']: 422 | if word not in entities_set: 423 | entities.append(word) 424 | entities_set.add(word) 425 | return cases 426 | 427 | class CRSdataset(Dataset): 428 | def __init__(self, dataset, entity_num, concept_num): 429 | self.data=dataset 430 | self.entity_num = entity_num 431 | self.concept_num = concept_num+1 432 | 433 | def __getitem__(self, index): 434 | ''' 435 | movie_vec = np.zeros(self.entity_num, dtype=np.float) 436 | context, c_lengths, response, r_length, entity, movie, concept_mask, dbpedia_mask, rec = self.data[index] 437 | for en in movie: 438 | movie_vec[en] = 1 / len(movie) 439 | return context, c_lengths, response, r_length, entity, movie_vec, concept_mask, dbpedia_mask, rec 440 | ''' 441 | context, c_lengths, response, r_length, mask_response, mask_r_length, entity, movie, concept_mask, dbpedia_mask, rec, movies_gth, movies_num= self.data[index] 442 | entity_vec = np.zeros(self.entity_num) 443 | entity_vector=np.zeros(50,dtype=np.int) 444 | point=0 445 | for en in entity: 446 | entity_vec[en]=1 447 | entity_vector[point]=en 448 | point+=1 449 | 450 | concept_vec=np.zeros(self.concept_num) 451 | for con in concept_mask: 452 | if con!=0: 453 | concept_vec[con]=1 454 | 455 | db_vec=np.zeros(self.entity_num) 456 | for db in 
dbpedia_mask: 457 | if db!=0: 458 | db_vec[db]=1 459 | 460 | return context, c_lengths, response, r_length, mask_response, mask_r_length, entity_vec, entity_vector, movie, np.array(concept_mask), np.array(dbpedia_mask), concept_vec, db_vec, rec, movies_gth, movies_num 461 | 462 | def __len__(self): 463 | return len(self.data) 464 | 465 | if __name__=='__main__': 466 | # train_dataset = dataset('data/full_data.jsonl') 467 | train_dataset = dataset('data/train_data.jsonl') 468 | val_dataset = dataset('data/valid_data.jsonl') 469 | # val_dataset = dataset('data/test_data.jsonl') 470 | match_movie_item = [] 471 | 472 | # print('val_dataset.all_movies') 473 | # print(val_dataset.all_movies) 474 | 475 | 476 | for val_movie in set(val_dataset.all_movies): 477 | if val_movie not in set(train_dataset.all_movies): 478 | # print(val_movie) 479 | match_movie_item.append(val_movie) 480 | print('-'*50) 481 | print(len(set(match_movie_item))) 482 | print('match movie(in test not in train):') 483 | print(set(match_movie_item)) 484 | print('-'*50) 485 | 486 | -------------------------------------------------------------------------------- /e2e_model.py: -------------------------------------------------------------------------------- 1 | from models.transformer import TorchGeneratorModel,_build_encoder,_build_decoder,_build_encoder_mask, _build_encoder4kg, _build_decoder4kg, _build_decoder_selection, _build_decoder_e2e_selection 2 | from models.utils import _create_embeddings,_create_entity_embeddings 3 | from models.graph import SelfAttentionLayer,SelfAttentionLayer_batch 4 | from torch_geometric.nn.conv.rgcn_conv import RGCNConv 5 | from torch_geometric.nn.conv.gcn_conv import GCNConv 6 | import pickle as pkl 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import os 11 | from collections import defaultdict 12 | import numpy as np 13 | import json 14 | 15 | def _load_kg_embeddings(entity2entityId, dim, embedding_path): 16 | kg_embeddings = torch.zeros(len(entity2entityId), dim) 17 | with open(embedding_path, 'r') as f: 18 | for line in f.readlines(): 19 | line = line.split('\t') 20 | entity = line[0] 21 | if entity not in entity2entityId: 22 | continue 23 | entityId = entity2entityId[entity] 24 | embedding = torch.Tensor(list(map(float, line[1:]))) 25 | kg_embeddings[entityId] = embedding 26 | return kg_embeddings 27 | 28 | EDGE_TYPES = [58, 172] 29 | def _edge_list(kg, n_entity, hop): 30 | edge_list = [] 31 | for h in range(hop): 32 | for entity in range(n_entity): 33 | # add self loop 34 | # edge_list.append((entity, entity)) 35 | # self_loop id = 185 36 | edge_list.append((entity, entity, 185)) 37 | if entity not in kg: 38 | continue 39 | for tail_and_relation in kg[entity]: 40 | if entity != tail_and_relation[1] and tail_and_relation[0] != 185 :# and tail_and_relation[0] in EDGE_TYPES: 41 | edge_list.append((entity, tail_and_relation[1], tail_and_relation[0])) 42 | edge_list.append((tail_and_relation[1], entity, tail_and_relation[0])) 43 | 44 | relation_cnt = defaultdict(int) 45 | relation_idx = {} 46 | for h, t, r in edge_list: 47 | relation_cnt[r] += 1 48 | for h, t, r in edge_list: 49 | if relation_cnt[r] > 1000 and r not in relation_idx: 50 | relation_idx[r] = len(relation_idx) 51 | 52 | return [(h, t, relation_idx[r]) for h, t, r in edge_list if relation_cnt[r] > 1000], len(relation_idx) 53 | 54 | def concept_edge_list4GCN(): 55 | node2index=json.load(open('key2index_3rd.json',encoding='utf-8')) 56 | f=open('conceptnet_edges2nd.txt',encoding='utf-8') 57 | 
edges=set() 58 | stopwords=set([word.strip() for word in open('stopwords.txt',encoding='utf-8')]) 59 | for line in f: 60 | lines=line.strip().split('\t') 61 | entity0=node2index[lines[1].split('/')[0]] 62 | entity1=node2index[lines[2].split('/')[0]] 63 | if lines[1].split('/')[0] in stopwords or lines[2].split('/')[0] in stopwords: 64 | continue 65 | edges.add((entity0,entity1)) 66 | edges.add((entity1,entity0)) 67 | edge_set=[[co[0] for co in list(edges)],[co[1] for co in list(edges)]] 68 | return torch.LongTensor(edge_set).cuda() 69 | 70 | class E2ECrossModel(nn.Module): 71 | def __init__(self, opt, dictionary, is_finetune=False, padding_idx=0, start_idx=1, end_idx=2, longest_label=1): 72 | # self.pad_idx = dictionary[dictionary.null_token] 73 | # self.start_idx = dictionary[dictionary.start_token] 74 | # self.end_idx = dictionary[dictionary.end_token] 75 | super().__init__() # self.pad_idx, self.start_idx, self.end_idx) 76 | self.batch_size = opt['batch_size'] 77 | self.max_r_length = opt['max_r_length'] 78 | self.beam = opt['beam'] 79 | self.is_finetune = is_finetune 80 | 81 | self.index2word={dictionary[key]:key for key in dictionary} 82 | 83 | self.movieID2selection_label=pkl.load(open('movieID2selection_label.pkl','rb')) 84 | 85 | self.NULL_IDX = padding_idx 86 | self.END_IDX = end_idx 87 | self.register_buffer('START', torch.LongTensor([start_idx])) 88 | self.longest_label = longest_label 89 | 90 | self.pad_idx = padding_idx 91 | self.embeddings = _create_embeddings( 92 | dictionary, opt['embedding_size'], self.pad_idx 93 | ) 94 | 95 | self.concept_embeddings=_create_entity_embeddings( 96 | opt['n_concept']+1, opt['dim'], 0) 97 | self.concept_padding=0 98 | 99 | self.kg = pkl.load( 100 | open("data/subkg.pkl", "rb") 101 | ) 102 | 103 | if opt.get('n_positions'): 104 | # if the number of positions is explicitly provided, use that 105 | n_positions = opt['n_positions'] 106 | else: 107 | # else, use the worst case from truncate 108 | n_positions = max( 109 | opt.get('truncate') or 0, 110 | opt.get('text_truncate') or 0, 111 | opt.get('label_truncate') or 0 112 | ) 113 | if n_positions == 0: 114 | # default to 1024 115 | n_positions = 1024 116 | 117 | if n_positions < 0: 118 | raise ValueError('n_positions must be positive') 119 | 120 | self.encoder = _build_encoder( 121 | opt, dictionary, self.embeddings, self.pad_idx, reduction=False, 122 | n_positions=n_positions, 123 | ) 124 | self.decoder = _build_decoder4kg( 125 | opt, dictionary, self.embeddings, self.pad_idx, 126 | n_positions=n_positions, 127 | ) 128 | 129 | self.selection_cross_attn_decoder = _build_decoder_e2e_selection( 130 | opt, len(self.movieID2selection_label), self.pad_idx, 131 | n_positions=n_positions, 132 | ) 133 | # self.selection_cross_attn_decoder = _build_decoder_selection( 134 | # opt, len(self.movieID2selection_label), self.pad_idx, 135 | # n_positions=n_positions, 136 | # ) 137 | self.db_norm = nn.Linear(opt['dim'], opt['embedding_size']) 138 | self.kg_norm = nn.Linear(opt['dim'], opt['embedding_size']) 139 | 140 | self.db_attn_norm=nn.Linear(opt['dim'],opt['embedding_size']) 141 | self.kg_attn_norm=nn.Linear(opt['dim'],opt['embedding_size']) 142 | 143 | self.enti_gcn_linear2_emb=nn.Linear(opt['dim'],opt['embedding_size']) 144 | 145 | self.criterion = nn.CrossEntropyLoss(reduce=False) 146 | 147 | self.self_attn = SelfAttentionLayer_batch(opt['dim'], opt['dim']) 148 | 149 | self.self_attn_db = SelfAttentionLayer(opt['dim'], opt['dim']) 150 | 151 | self.user_norm = nn.Linear(opt['dim']*2, opt['dim']) 152 | 
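        # Copy-mechanism heads used by decode_forced/decode_greedy below:
        # copy_norm fuses the ConceptNet and DBpedia attention summaries with the
        # decoder state (3*embedding_size -> embedding_size), and representation_bias
        # projects the fused vector onto the vocabulary (len(dictionary) + 4 covers
        # the pad/start/end/unk special tokens).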
self.gate_norm = nn.Linear(opt['dim'], 1) 153 | self.copy_norm = nn.Linear(opt['embedding_size']*2+opt['embedding_size'], opt['embedding_size']) 154 | self.representation_bias = nn.Linear(opt['embedding_size'], len(dictionary) + 4) 155 | 156 | self.info_con_norm = nn.Linear(opt['dim'], opt['dim']) 157 | self.info_db_norm = nn.Linear(opt['dim'], opt['dim']) 158 | self.info_output_db = nn.Linear(opt['dim'], opt['n_entity']) 159 | self.info_output_con = nn.Linear(opt['dim'], opt['n_concept']+1) 160 | self.info_con_loss = nn.MSELoss(size_average=False,reduce=False) 161 | self.info_db_loss = nn.MSELoss(size_average=False,reduce=False) 162 | 163 | self.user_representation_to_bias_1 = nn.Linear(opt['dim'], 512) 164 | self.user_representation_to_bias_2 = nn.Linear(512, len(dictionary) + 4) 165 | 166 | self.output_en = nn.Linear(opt['dim'], opt['n_entity']) 167 | 168 | self.matching_linear = nn.Linear(opt['embedding_size'], opt['n_movies']) 169 | 170 | self.embedding_size=opt['embedding_size'] 171 | self.dim=opt['dim'] 172 | 173 | edge_list, self.n_relation = _edge_list(self.kg, opt['n_entity'], hop=2) 174 | edge_list = list(set(edge_list)) 175 | print(len(edge_list), self.n_relation) 176 | self.dbpedia_edge_sets=torch.LongTensor(edge_list).cuda() 177 | self.db_edge_idx = self.dbpedia_edge_sets[:, :2].t() 178 | self.db_edge_type = self.dbpedia_edge_sets[:, 2] 179 | 180 | self.dbpedia_RGCN=RGCNConv(opt['n_entity'], self.dim, self.n_relation, num_bases=opt['num_bases']) 181 | #self.concept_RGCN=RGCNConv(opt['n_concept']+1, self.dim, self.n_con_relation, num_bases=opt['num_bases']) 182 | self.concept_edge_sets=concept_edge_list4GCN() 183 | self.concept_GCN=GCNConv(self.dim, self.dim) 184 | 185 | #self.concept_GCN4gen=GCNConv(self.dim, opt['embedding_size']) 186 | 187 | w2i=json.load(open('word2index_redial.json',encoding='utf-8')) 188 | self.i2w={w2i[word]:word for word in w2i} 189 | 190 | #---------------------------- still a hack ----------------------------2020/4/22 By Jokie 191 | self.mask4key=torch.Tensor(np.load('mask4key.npy')).cuda() 192 | self.mask4movie=torch.Tensor(np.load('mask4movie.npy')).cuda() 193 | # Original mask 194 | # self.mask4=self.mask4key+self.mask4movie 195 | 196 | # tmp hack for runable By Jokie 2021/04/12 for template generation task 197 | self.mask4=torch.ones(len(dictionary) + 4).cuda() 198 | 199 | # if is_finetune: 200 | # params = [self.dbpedia_RGCN.parameters(), self.concept_GCN.parameters(), 201 | # self.concept_embeddings.parameters(), 202 | # self.self_attn.parameters(), self.self_attn_db.parameters(), self.user_norm.parameters(), 203 | # self.gate_norm.parameters(), self.output_en.parameters()] 204 | # for param in params: 205 | # for pa in param: 206 | # pa.requires_grad = False 207 | 208 | def vector2sentence(self,batch_sen): 209 | sentences=[] 210 | for sen in batch_sen.numpy().tolist(): 211 | sentence=[] 212 | for word in sen: 213 | if word>3: 214 | sentence.append(self.index2word[word]) 215 | elif word==3: 216 | sentence.append('_UNK_') 217 | sentences.append(sentence) 218 | return sentences 219 | 220 | def _starts(self, bsz): 221 | """Return bsz start tokens.""" 222 | return self.START.detach().expand(bsz, 1) 223 | 224 | def decode_greedy(self, encoder_states, encoder_states_kg, encoder_states_db, attention_kg, attention_db, bsz, maxlen): 225 | """ 226 | Greedy search 227 | 228 | :param int bsz: 229 | Batch size. Because encoder_states is model-specific, it cannot 230 | infer this automatically. 
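        :param encoder_states_kg:
            (embeddings, mask) pair for the ConceptNet concepts in the
            context, i.e. the output of kg_norm plus its padding mask.

        :param encoder_states_db:
            (embeddings, mask) pair for the DBpedia entities in the
            context, i.e. the output of db_norm plus its padding mask.

        :param attention_kg:
            self-attended summary of the user's concepts, consumed by the
            copy mechanism.

        :param attention_db:
            self-attended summary of the user's entities, consumed by the
            copy mechanism.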
231 | 232 | :param encoder_states: 233 | Output of the encoder model. 234 | 235 | :type encoder_states: 236 | Model specific 237 | 238 | :param int maxlen: 239 | Maximum decoding length 240 | 241 | :return: 242 | pair (logits, choices) of the greedy decode 243 | 244 | :rtype: 245 | (FloatTensor[bsz, maxlen, vocab], LongTensor[bsz, maxlen]) 246 | """ 247 | xs = self._starts(bsz) 248 | incr_state = None 249 | logits = [] 250 | latents = [] 251 | for i in range(maxlen): 252 | # todo, break early if all beams saw EOS 253 | scores, incr_state = self.decoder(xs, encoder_states, encoder_states_kg, encoder_states_db, incr_state) 254 | #batch*1*hidden 255 | scores = scores[:, -1:, :] 256 | #scores = self.output(scores) 257 | kg_attn_norm = self.kg_attn_norm(attention_kg) 258 | 259 | db_attn_norm = self.db_attn_norm(attention_db) 260 | 261 | copy_latent = self.copy_norm(torch.cat([kg_attn_norm.unsqueeze(1), db_attn_norm.unsqueeze(1), scores], -1)) 262 | 263 | latents.append(scores) 264 | # latents.append(copy_latent) 265 | 266 | # logits = self.output(latent) 267 | con_logits = self.representation_bias(copy_latent)*self.mask4.unsqueeze(0).unsqueeze(0)#F.linear(copy_latent, self.embeddings.weight) 268 | voc_logits = F.linear(scores, self.embeddings.weight) 269 | # print(logits.size()) 270 | # print(mem_logits.size()) 271 | #gate = F.sigmoid(self.gen_gate_norm(scores)) 272 | 273 | sum_logits = voc_logits + con_logits #* (1 - gate) 274 | _, preds = sum_logits.max(dim=-1) 275 | 276 | #scores = F.linear(scores, self.embeddings.weight) 277 | 278 | #print(attention_map) 279 | #print(db_attention_map) 280 | #print(preds.size()) 281 | #print(con_logits.size()) 282 | #exit() 283 | #print(con_logits.squeeze(0).squeeze(0)[preds.squeeze(0).squeeze(0)]) 284 | #print(voc_logits.squeeze(0).squeeze(0)[preds.squeeze(0).squeeze(0)]) 285 | 286 | #print(torch.topk(voc_logits.squeeze(0).squeeze(0),k=50)[1]) 287 | 288 | #sum_logits = scores 289 | # print(sum_logits.size()) 290 | 291 | #_, preds = sum_logits.max(dim=-1) 292 | logits.append(sum_logits) 293 | xs = torch.cat([xs, preds], dim=1) 294 | # check if everyone has generated an end token 295 | all_finished = ((xs == self.END_IDX).sum(dim=1) > 0).sum().item() == bsz 296 | if all_finished: 297 | break 298 | logits = torch.cat(logits, 1) 299 | latents = torch.cat(latents, 1) 300 | # return logits, xs 301 | return logits, xs, latents 302 | 303 | def decode_beam_search_with_kg(self, token_encoding, encoder_states_kg, encoder_states_db, attention_kg, attention_db, maxlen=None, beam=4): 304 | entity_reps, entity_mask = encoder_states_db 305 | word_reps, word_mask = encoder_states_kg 306 | entity_emb_attn = attention_db 307 | word_emb_attn = attention_kg 308 | batch_size = token_encoding[0].shape[0] 309 | 310 | inputs = self._starts(batch_size).long().reshape(1, batch_size, -1) 311 | incr_state = None 312 | 313 | sequences = [[[list(), list(), 1.0]]] * batch_size 314 | all_latents = [] 315 | # for i in range(self.response_truncate): 316 | for i in range(maxlen): 317 | if i == 1: 318 | token_encoding = (token_encoding[0].repeat(beam, 1, 1), 319 | token_encoding[1].repeat(beam, 1, 1)) 320 | entity_reps = entity_reps.repeat(beam, 1, 1) 321 | entity_emb_attn = entity_emb_attn.repeat(beam, 1) 322 | entity_mask = entity_mask.repeat(beam, 1) 323 | word_reps = word_reps.repeat(beam, 1, 1) 324 | word_emb_attn = word_emb_attn.repeat(beam, 1) 325 | word_mask = word_mask.repeat(beam, 1) 326 | 327 | encoder_states_kg = word_reps, word_mask 328 | encoder_states_db = entity_reps, 
entity_mask 329 | 330 | # at the beginning there is 1 candidate; when i!=0 there are `beam` candidates 331 | if i != 0: 332 | inputs = [] 333 | for d in range(len(sequences[0])): 334 | for j in range(batch_size): 335 | text = sequences[j][d][0] 336 | inputs.append(text) 337 | inputs = torch.stack(inputs).reshape(beam, batch_size, -1) # (beam, batch_size, _) 338 | 339 | with torch.no_grad(): 340 | 341 | dialog_latent, incr_state = self.decoder(inputs.reshape(len(sequences[0]) * batch_size, -1), token_encoding, encoder_states_kg, encoder_states_db, incr_state) 342 | # dialog_latent, incr_state = self.conv_decoder( 343 | # inputs.reshape(len(sequences[0]) * batch_size, -1), 344 | # token_encoding, word_reps, word_mask, 345 | # entity_reps, entity_mask, incr_state 346 | # ) 347 | dialog_latent = dialog_latent[:, -1:, :] # (bs, 1, dim) 348 | 349 | concept_latent = self.kg_attn_norm(word_emb_attn).unsqueeze(1) 350 | db_latent = self.db_attn_norm(entity_emb_attn).unsqueeze(1) 351 | 352 | # print('concept_latent shape', concept_latent.shape) 353 | # print('db_latent shape', db_latent.shape) 354 | # print('dialog_latent shape', dialog_latent.shape) 355 | 356 | copy_latent = self.copy_norm(torch.cat((db_latent, concept_latent, dialog_latent), dim=-1)) 357 | 358 | # WAY1 359 | # if i != 0: 360 | # print('dialog_latent shape', dialog_latent.shape) 361 | # print('copy_latent shape', copy_latent.shape) 362 | # all_latents.append(copy_latent) 363 | #WAY2 364 | all_latents.append(copy_latent) 365 | 366 | # copy_logits = self.copy_output(copy_latent) * self.copy_mask.unsqueeze(0).unsqueeze(0) 367 | copy_logits = self.representation_bias(copy_latent)*self.mask4.unsqueeze(0).unsqueeze(0) 368 | gen_logits = F.linear(dialog_latent, self.embeddings.weight) 369 | sum_logits = copy_logits + gen_logits 370 | 371 | logits = sum_logits.reshape(len(sequences[0]), batch_size, 1, -1) 372 | # turn into probabilities over the vocabulary, in case of negative numbers 373 | probs, preds = torch.nn.functional.softmax(logits, dim=-1).topk(beam, dim=-1) 374 | 375 | # (candidate, bs, 1, beam) during first loop, candidate=1, otherwise candidate=beam 376 | 377 | for j in range(batch_size): 378 | all_candidates = [] 379 | for n in range(len(sequences[j])): 380 | for k in range(beam): 381 | prob = sequences[j][n][2] 382 | logit = sequences[j][n][1] 383 | if logit == []: 384 | logit_tmp = logits[n][j][0].unsqueeze(0) 385 | else: 386 | logit_tmp = torch.cat((logit, logits[n][j][0].unsqueeze(0)), dim=0) 387 | seq_tmp = torch.cat((inputs[n][j].reshape(-1), preds[n][j][0][k].reshape(-1))) 388 | candidate = [seq_tmp, logit_tmp, prob * probs[n][j][0][k]] 389 | all_candidates.append(candidate) 390 | ordered = sorted(all_candidates, key=lambda tup: tup[2], reverse=True) 391 | sequences[j] = ordered[:beam] 392 | 393 | # check if everyone has generated an end token 394 | all_finished = ((inputs == self.END_IDX).sum(dim=1) > 0).sum().item() == batch_size 395 | if all_finished: 396 | break 397 | 398 | # original solution 399 | # logits = torch.stack([seq[0][1] for seq in sequences]) 400 | # inputs = torch.stack([seq[0][0] for seq in sequences]) 401 | 402 | out_logits = [] 403 | out_preds = [] 404 | for beam_num in range(beam): 405 | cur_out_logits = torch.stack([seq[beam_num][1] for seq in sequences]) 406 | curout_preds = torch.stack([seq[beam_num][0] for seq in sequences]) 407 | out_logits.append(cur_out_logits) 408 | out_preds.append(curout_preds) 409 | 410 | logits = torch.cat([x for x in out_logits], dim=0) 411 | inputs = torch.cat([x for x in out_preds], dim=0) 412 |
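        # Stack the per-step copy latents so the stage-2 movie-selection head
        # can later score every generated position (decode_greedy returns its
        # latents the same way).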
all_latents = torch.cat(all_latents, 1) 413 | 414 | return logits, inputs, all_latents 415 | 416 | def decode_forced(self, encoder_states, encoder_states_kg, encoder_states_db, attention_kg, attention_db, ys): 417 | """ 418 | Decode with a fixed, true sequence, computing loss. Useful for 419 | training, or ranking fixed candidates. 420 | 421 | :param ys: 422 | the prediction targets. Contains both the start and end tokens. 423 | 424 | :type ys: 425 | LongTensor[bsz, time] 426 | 427 | :param encoder_states: 428 | Output of the encoder. Model specific types. 429 | 430 | :type encoder_states: 431 | model specific 432 | 433 | :return: 434 | pair (logits, choices) containing the logits and MLE predictions 435 | 436 | :rtype: 437 | (FloatTensor[bsz, ys, vocab], LongTensor[bsz, ys]) 438 | """ 439 | bsz = ys.size(0) 440 | seqlen = ys.size(1) 441 | inputs = ys.narrow(1, 0, seqlen - 1) 442 | inputs = torch.cat([self._starts(bsz), inputs], 1) 443 | latent, _ = self.decoder(inputs, encoder_states, encoder_states_kg, encoder_states_db) #batch*r_l*hidden 444 | 445 | kg_attention_latent=self.kg_attn_norm(attention_kg) 446 | 447 | #map=torch.bmm(latent,torch.transpose(kg_embs_norm,2,1)) 448 | #map_mask=((1-encoder_states_kg[1].float())*(-1e30)).unsqueeze(1) 449 | #attention_map=F.softmax(map*map_mask,dim=-1) 450 | #attention_latent=torch.bmm(attention_map,encoder_states_kg[0]) 451 | 452 | db_attention_latent=self.db_attn_norm(attention_db) 453 | 454 | #db_map=torch.bmm(latent,torch.transpose(db_embs_norm,2,1)) 455 | #db_map_mask=((1-encoder_states_db[1].float())*(-1e30)).unsqueeze(1) 456 | #db_attention_map=F.softmax(db_map*db_map_mask,dim=-1) 457 | #db_attention_latent=torch.bmm(db_attention_map,encoder_states_db[0]) 458 | 459 | copy_latent=self.copy_norm(torch.cat([kg_attention_latent.unsqueeze(1).repeat(1,seqlen,1), db_attention_latent.unsqueeze(1).repeat(1,seqlen,1), latent],-1)) 460 | 461 | #logits = self.output(latent) 462 | con_logits = self.representation_bias(copy_latent)*self.mask4.unsqueeze(0).unsqueeze(0)#F.linear(copy_latent, self.embeddings.weight) 463 | logits = F.linear(latent, self.embeddings.weight) 464 | # print('logit size', logits.size()) 465 | # print(mem_logits.size()) 466 | #gate=F.sigmoid(self.gen_gate_norm(latent)) 467 | 468 | sum_logits = logits+con_logits#*(1-gate) 469 | _, preds = sum_logits.max(dim=2) 470 | 471 | # return logits, preds, copy_latent 472 | return logits, preds, latent 473 | 474 | def infomax_loss(self, con_nodes_features, db_nodes_features, con_user_emb, db_user_emb, con_label, db_label, mask): 475 | #batch*dim 476 | #node_count*dim 477 | con_emb=self.info_con_norm(con_user_emb) 478 | db_emb=self.info_db_norm(db_user_emb) 479 | con_scores = F.linear(db_emb, con_nodes_features, self.info_output_con.bias) 480 | db_scores = F.linear(con_emb, db_nodes_features, self.info_output_db.bias) 481 | 482 | info_db_loss=torch.sum(self.info_db_loss(db_scores,db_label.cuda().float()),dim=-1)*mask.cuda() 483 | info_con_loss=torch.sum(self.info_con_loss(con_scores,con_label.cuda().float()),dim=-1)*mask.cuda() 484 | 485 | return torch.mean(info_db_loss), torch.mean(info_con_loss) 486 | 487 | def forward(self, xs, ys, mask_ys, concept_mask, db_mask, seed_sets, labels, con_label, db_label, entity_vector, rec,movies_gth=None, movie_nums=None, test=True, cand_params=None, prev_enc=None, maxlen=None, 488 | bsz=None): 489 | """ 490 | Get output predictions from the model. 
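        When ``test`` is False the decoder is teacher-forced on ``mask_ys`` (the
        response with every movie mention replaced by the __MOVIE__ placeholder,
        word id 6) and a generation loss plus a movie-selection loss are returned;
        when ``test`` is True the model decodes greedily and only predicts movies
        for the placeholder slots it actually generated.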
491 | 492 | :param xs: 493 | input to the encoder 494 | :type xs: 495 | LongTensor[bsz, seqlen] 496 | :param ys: 497 | Expected output from the decoder. Used 498 | for teacher forcing to calculate loss. 499 | :type ys: 500 | LongTensor[bsz, outlen] 501 | :param prev_enc: 502 | if you know you'll pass in the same xs multiple times, you can pass 503 | in the encoder output from the last forward pass to skip 504 | recalculating the same encoder output. 505 | :param maxlen: 506 | max number of tokens to decode. if not set, will use the length of 507 | the longest label this model has seen. ignored when ys is not None. 508 | :param bsz: 509 | if ys is not provided, then you must specify the bsz for greedy 510 | decoding. 511 | 512 | :return: 513 | an 11-tuple (scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss, selection_loss, matching_pred, matching_logits) 514 | 515 | - scores contains the summed generation + copy token logits. (FloatTensor[bsz, seqlen, vocab]) 516 | - preds are the teacher-forced (train) or greedily decoded (test) tokens. (LongTensor[bsz, seqlen]) 517 | - gen_loss and selection_loss are computed only when test is False, otherwise None. 518 | - matching_pred holds the selected movie for every generated __MOVIE__ slot (None when no slot was generated). 519 | - matching_logits holds the selection head's distribution over all movies at every response position. 520 | - the remaining slots (rec_scores, rec_loss, mask_loss, info_db_loss, info_con_loss) are always None in this model. 521 | """ 522 | if test == False: 523 | # TODO: get rid of longest_label 524 | # keep track of longest label we've ever seen 525 | # we'll never produce longer ones than that during prediction 526 | self.longest_label = max(self.longest_label, ys.size(1)) 527 | 528 | #------------------------- original rec sys------------------------------------------------------ 529 | # graph network 530 | db_nodes_features = self.dbpedia_RGCN(None, self.db_edge_idx, self.db_edge_type) 531 | con_nodes_features=self.concept_GCN(self.concept_embeddings.weight,self.concept_edge_sets) 532 | 533 | user_representation_list = [] 534 | db_con_mask=[] 535 | for i, seed_set in enumerate(seed_sets): 536 | if seed_set == []: 537 | user_representation_list.append(torch.zeros(self.dim).cuda()) 538 | db_con_mask.append(torch.zeros([1])) 539 | continue 540 | user_representation = db_nodes_features[seed_set] # index the RGCN node features with this user's entity ids 541 | user_representation = self.self_attn_db(user_representation) 542 | user_representation_list.append(user_representation) 543 | db_con_mask.append(torch.ones([1])) 544 | 545 | db_user_emb=torch.stack(user_representation_list) 546 | db_con_mask=torch.stack(db_con_mask) 547 | 548 | graph_con_emb=con_nodes_features[concept_mask] 549 | con_emb_mask=concept_mask==self.concept_padding 550 | 551 | con_user_emb=graph_con_emb 552 | con_user_emb,attention=self.self_attn(con_user_emb,con_emb_mask.cuda()) 553 | 554 | #-------------------generation--------------------------------------------------------------------------------------------------- 555 | encoder_states = prev_enc if prev_enc is not None else self.encoder(xs) 556 | con_nodes_features4gen=con_nodes_features#self.concept_GCN4gen(con_nodes_features,self.concept_edge_sets) 557 | con_emb4gen = con_nodes_features4gen[concept_mask] 558 | con_mask4gen = concept_mask != self.concept_padding 559 | #kg_encoding=self.kg_encoder(con_emb4gen.cuda(),con_mask4gen.cuda()) 560 | kg_encoding=(self.kg_norm(con_emb4gen),con_mask4gen.cuda()) 561 | 562 | db_emb4gen=db_nodes_features[entity_vector] #batch*50*dim 563 | db_mask4gen=entity_vector!=0 564 | #db_encoding=self.db_encoder(db_emb4gen.cuda(),db_mask4gen.cuda()) 565 | db_encoding=(self.db_norm(db_emb4gen),db_mask4gen.cuda()) 566 | 567 | if test == False: 568 | # use teacher forcing scores, pred: (FloatTensor[bsz, ys, vocab], LongTensor[bsz, ys]) 569 | # print('shape of
entity_scores', entity_scores.shape) 570 | # print('shape of rec label', labels.shape) 571 | # print('rec label', labels) 572 | # print('shape of movie label', movies_gth.shape) 573 | movies_gth = movies_gth * (movies_gth!=-1) 574 | 575 | # print('num of movies_gth', torch.sum(movies_gth!=0, dim=(0,1))) 576 | # print('num of gth masked hole', torch.sum((mask_ys == 6), dim=(0,1))) 577 | # print('movie_nums', movie_nums) 578 | # print('__MOVIE__ position ', torch.sum((mask_ys == 6), dim=(1))) 579 | assert torch.sum(movies_gth!=0, dim=(0,1)) == torch.sum((mask_ys == 6), dim=(0,1)) 580 | 581 | # cant run case : case 1 : [-15] case2: [-8] By Jokie tmp 2021/4/14 582 | # print(movies_gth[-15]) 583 | # print(mask_ys[-15]) 584 | # print(self.vector2sentence(mask_ys.cpu())[-15]) 585 | 586 | # print('shape of encoder_states,kg_encoding,db_encoding,con_user_emb, db_user_emb, mask_ys', encoder_states[0].shape,kg_encoding[0].shape,db_encoding[0].shape,con_user_emb.shape, db_user_emb.shape, mask_ys.shape) 587 | scores, preds, latent = self.decode_forced(encoder_states, kg_encoding, db_encoding, con_user_emb, db_user_emb, mask_ys) 588 | # print('shape of scores,preds, mask_ys, latent', scores.shape,preds.shape,mask_ys.shape,latent.shape) 589 | gen_loss = torch.mean(self.compute_loss(scores, mask_ys)) 590 | 591 | #-------------------------------- stage2 movie selection loss-------------- by Jokie 592 | 593 | masked_for_selection_token = (mask_ys == 6) 594 | 595 | #WAY1: simply linear 596 | # selected_token_latent = torch.masked_select(latent, masked_for_selection_token.unsqueeze(-1).expand_as(latent)).view(-1, latent.shape[-1]) 597 | # matching_logits = self.matching_linear(selected_token_latent) 598 | 599 | #WAY2: self attn 600 | matching_tensor, _ = self.selection_cross_attn_decoder(latent, encoder_states, db_encoding, kg_encoding) 601 | matching_logits_ = self.matching_linear(matching_tensor) 602 | 603 | matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) 604 | 605 | _, matching_pred = matching_logits.max(dim=-1) # [bsz * dynamic_movie_nums] 606 | movies_gth = torch.masked_select(movies_gth, (movies_gth!=0)) 607 | selection_loss = torch.mean(self.compute_loss(matching_logits, movies_gth)) # movies_gth.squeeze(0):[bsz * dynamic_movie_nums] 608 | 609 | 610 | else: 611 | #---------------------------------------------Beam Search decode---------------------------------------- 612 | # scores, preds, latent = self.decode_beam_search_with_kg( 613 | # encoder_states, kg_encoding, db_encoding, con_user_emb, db_user_emb, 614 | # maxlen, self.beam) 615 | # # #pred here is soft template prediction 616 | # # # --------------post process the prediction to full sentence 617 | # # #-------------------------------- stage2 movie selection loss-------------- by Jokie 618 | # preds_for_selection = preds[:, 1:] # skip the start_ind 619 | # # preds_for_selection = preds[:, 2:] # skip the start_ind 620 | # masked_for_selection_token = (preds_for_selection == 6) 621 | 622 | # # print('latent shape', latent.shape) 623 | # # print('preds_for_selection: ', preds_for_selection) 624 | # # print('masked_for_selection_token shape', masked_for_selection_token.shape) 625 | 626 | # selected_token_latent = torch.masked_select(latent, masked_for_selection_token.unsqueeze(-1).expand_as(latent)).view(-1, latent.shape[-1]) 627 | # print('selected_token_latent shape: ' , selected_token_latent) 628 | # matching_logits = 
self.matching_linear(selected_token_latent) 629 | 630 | # _, matching_pred = matching_logits.max(dim=-1) # [bsz * dynamic_movie_nums] 631 | # # print('matching_pred', matching_pred.shape) 632 | 633 | 634 | #---------------------------------------------Greedy decode------------------------------------------- 635 | scores, preds, latent = self.decode_greedy( 636 | encoder_states, kg_encoding, db_encoding, con_user_emb, db_user_emb, 637 | bsz, 638 | maxlen or self.longest_label 639 | ) 640 | 641 | # #pred here is soft template prediction 642 | # # --------------post process the prediction to full sentence 643 | # #-------------------------------- stage2 movie selection loss-------------- by Jokie 644 | preds_for_selection = preds[:, 1:] # skip the start_ind 645 | masked_for_selection_token = (preds_for_selection == 6) 646 | 647 | 648 | #WAY1: simply linear 649 | # selected_token_latent = torch.masked_select(latent, masked_for_selection_token.unsqueeze(-1).expand_as(latent)).view(-1, latent.shape[-1]) 650 | # matching_logits = self.matching_linear(selected_token_latent) 651 | 652 | #WAY2: self attn 653 | matching_tensor, _ = self.selection_cross_attn_decoder(latent, encoder_states, db_encoding, kg_encoding) 654 | matching_logits_ = self.matching_linear(matching_tensor) 655 | 656 | matching_logits = torch.masked_select(matching_logits_, masked_for_selection_token.unsqueeze(-1).expand_as(matching_logits_)).view(-1, matching_logits_.shape[-1]) 657 | 658 | if matching_logits.shape[0] != 0: 659 | _, matching_pred = matching_logits.max(dim=-1) # [bsz * dynamic_movie_nums] 660 | else: 661 | matching_pred = None 662 | # print('matching_pred', matching_pred.shape) 663 | #---------------------------------------------Greedy decode(end)------------------------------------------- 664 | 665 | gen_loss = None 666 | selection_loss = None 667 | 668 | return scores, preds, None, None, gen_loss, None, None, None, selection_loss, matching_pred, matching_logits_ 669 | 670 | 671 | def reorder_encoder_states(self, encoder_states, indices): 672 | """ 673 | Reorder encoder states according to a new set of indices. 674 | 675 | This is an abstract method, and *must* be implemented by the user. 676 | 677 | Its purpose is to provide beam search with a model-agnostic interface. 678 | For example, this method is used to sort hypotheses, 679 | expand beams, etc. 680 | 681 | For example, assume that encoder_states is a bsz x 1 tensor of values 682 | 683 | .. code-block:: python 684 | 685 | indices = [0, 2, 2] 686 | encoder_states = [[0.1] 687 | [0.2] 688 | [0.3]] 689 | 690 | then the output will be 691 | 692 | .. code-block:: python 693 | 694 | output = [[0.1] 695 | [0.3] 696 | [0.3]] 697 | 698 | :param encoder_states: 699 | output from encoder. type is model specific. 700 | 701 | :type encoder_states: 702 | model specific 703 | 704 | :param indices: 705 | the indices to select over. The user must support non-tensor 706 | inputs. 707 | 708 | :type indices: list[int] 709 | 710 | :return: 711 | The re-ordered encoder states. It should be of the same type as 712 | encoder states, and it must be a valid input to the decoder.
713 | 714 | :rtype: 715 | model specific 716 | """ 717 | enc, mask = encoder_states 718 | if not torch.is_tensor(indices): 719 | indices = torch.LongTensor(indices).to(enc.device) 720 | enc = torch.index_select(enc, 0, indices) 721 | mask = torch.index_select(mask, 0, indices) 722 | return enc, mask 723 | 724 | def reorder_decoder_incremental_state(self, incremental_state, inds): 725 | """ 726 | Reorder incremental state for the decoder. 727 | 728 | Used to expand selected beams in beam_search. Unlike reorder_encoder_states, 729 | implementing this method is optional. However, without incremental decoding, 730 | decoding a single beam becomes O(n^2) instead of O(n), which can make 731 | beam search impractically slow. 732 | 733 | In order to fall back to non-incremental decoding, just return None from this 734 | method. 735 | 736 | :param incremental_state: 737 | second output of model.decoder 738 | :type incremental_state: 739 | model specific 740 | :param inds: 741 | indices to select and reorder over. 742 | :type inds: 743 | LongTensor[n] 744 | 745 | :return: 746 | The re-ordered decoder incremental states. It should be the same 747 | type as incremental_state, and usable as an input to the decoder. 748 | This method should return None if the model does not support 749 | incremental decoding. 750 | 751 | :rtype: 752 | model specific 753 | """ 754 | # no support for incremental decoding at this time 755 | return None 756 | 757 | def compute_loss(self, output, scores): 758 | score_view = scores.view(-1) 759 | output_view = output.view(-1, output.size(-1)) 760 | loss = self.criterion(output_view.cuda(), score_view.cuda()) 761 | return loss 762 | 763 | def save_model(self,model_name='saved_model/net_parameter1.pkl'): 764 | torch.save(self.state_dict(), model_name) 765 | 766 | def load_model(self,model_name='saved_model/net_parameter1.pkl'): 767 | # self.load_state_dict(torch.load('saved_model/net_parameter1.pkl')) 768 | self.load_state_dict(torch.load(model_name), strict= False) 769 | 770 | def output(self, tensor): 771 | # project back to vocabulary 772 | output = F.linear(tensor, self.embeddings.weight) 773 | up_bias = self.user_representation_to_bias_2(F.relu(self.user_representation_to_bias_1(self.user_rep))) 774 | # up_bias = self.user_representation_to_bias_3(F.relu(self.user_representation_to_bias_2(F.relu(self.user_representation_to_bias_1(self.user_representation))))) 775 | # Expand to the whole sequence 776 | up_bias = up_bias.unsqueeze(dim=1) 777 | output += up_bias 778 | return output 779 | -------------------------------------------------------------------------------- /e2e_run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | """The standard way to train a model. After training, also computes validation 7 | and test error. 8 | 9 | The user must provide a model (with ``--model``) and a task (with ``--task`` or 10 | ``--pytorch-teacher-task``). 11 | 12 | Examples 13 | -------- 14 | 15 | .. 
code-block:: shell 16 | 17 | python -m parlai.scripts.train -m ir_baseline -t dialog_babi:Task:1 -mf /tmp/model 18 | python -m parlai.scripts.train -m seq2seq -t babi:Task10k:1 -mf '/tmp/model' -bs 32 -lr 0.5 -hs 128 19 | python -m parlai.scripts.train -m drqa -t babi:Task10k:1 -mf /tmp/model -bs 10 20 | 21 | """ # noqa: E501 22 | 23 | # TODO List: 24 | # * More logging (e.g. to files), make things prettier. 25 | 26 | import numpy as np 27 | from tqdm import tqdm 28 | from math import exp 29 | import os 30 | os.environ['CUDA_VISIBLE_DEVICES']='3' 31 | import signal 32 | import json 33 | import argparse 34 | import pickle as pkl 35 | from dataset import dataset,CRSdataset 36 | from e2e_model import E2ECrossModel 37 | import torch.nn as nn 38 | from torch import optim 39 | import torch 40 | try: 41 | import torch.version 42 | import torch.distributed as dist 43 | TORCH_AVAILABLE = True 44 | except ImportError: 45 | TORCH_AVAILABLE = False 46 | from nltk.translate.bleu_score import sentence_bleu 47 | 48 | def is_distributed(): 49 | """ 50 | Returns True if we are in distributed mode. 51 | """ 52 | return TORCH_AVAILABLE and dist.is_available() and dist.is_initialized() 53 | 54 | def setup_args(): 55 | train = argparse.ArgumentParser() 56 | train.add_argument("-max_c_length","--max_c_length",type=int,default=256) 57 | train.add_argument("-max_r_length","--max_r_length",type=int,default=30) 58 | train.add_argument("-beam","--beam",type=int,default=1) 59 | # train.add_argument("-max_r_length","--max_r_length",type=int,default=256) 60 | train.add_argument("-batch_size","--batch_size",type=int,default=128) 61 | train.add_argument("-max_count","--max_count",type=int,default=5) 62 | train.add_argument("-use_cuda","--use_cuda",type=bool,default=True) 63 | train.add_argument("-is_template","--is_template",type=bool,default=True) 64 | train.add_argument("-infomax_pretrain","--infomax_pretrain",type=bool,default=False) 65 | train.add_argument("-load_dict","--load_dict",type=str,default=None) 66 | train.add_argument("-learningrate","--learningrate",type=float,default=1e-3) 67 | train.add_argument("-optimizer","--optimizer",type=str,default='adam') 68 | train.add_argument("-momentum","--momentum",type=float,default=0) 69 | train.add_argument("-is_finetune","--is_finetune",type=bool,default=False) 70 | train.add_argument("-embedding_type","--embedding_type",type=str,default='random') 71 | train.add_argument("-save_exp_name","--save_exp_name",type=str,default='saved_model/sattn_e2eCRS') 72 | train.add_argument("-epoch","--epoch",type=int,default=50) 73 | train.add_argument("-gpu","--gpu",type=str,default='2') 74 | train.add_argument("-gradient_clip","--gradient_clip",type=float,default=0.1) 75 | train.add_argument("-embedding_size","--embedding_size",type=int,default=300) 76 | 77 | train.add_argument("-n_heads","--n_heads",type=int,default=2) 78 | train.add_argument("-n_layers","--n_layers",type=int,default=2) 79 | train.add_argument("-ffn_size","--ffn_size",type=int,default=300) 80 | 81 | train.add_argument("-dropout","--dropout",type=float,default=0.1) 82 | train.add_argument("-attention_dropout","--attention_dropout",type=float,default=0.0) 83 | train.add_argument("-relu_dropout","--relu_dropout",type=float,default=0.1) 84 | 85 | train.add_argument("-learn_positional_embeddings","--learn_positional_embeddings",type=bool,default=False) 86 | train.add_argument("-embeddings_scale","--embeddings_scale",type=bool,default=True) 87 | 88 | train.add_argument("-n_movies","--n_movies",type=int,default=6924) 89 | 
train.add_argument("-n_entity","--n_entity",type=int,default=64368) 90 | train.add_argument("-n_relation","--n_relation",type=int,default=214) 91 | train.add_argument("-n_concept","--n_concept",type=int,default=29308) 92 | train.add_argument("-n_con_relation","--n_con_relation",type=int,default=48) 93 | train.add_argument("-dim","--dim",type=int,default=128) 94 | train.add_argument("-n_hop","--n_hop",type=int,default=2) 95 | train.add_argument("-kge_weight","--kge_weight",type=float,default=1) 96 | train.add_argument("-l2_weight","--l2_weight",type=float,default=2.5e-6) 97 | train.add_argument("-n_memory","--n_memory",type=float,default=32) 98 | train.add_argument("-item_update_mode","--item_update_mode",type=str,default='0,1') 99 | train.add_argument("-using_all_hops","--using_all_hops",type=bool,default=True) 100 | train.add_argument("-num_bases", "--num_bases", type=int, default=8) 101 | 102 | 103 | 104 | 105 | 106 | return train 107 | 108 | class TrainLoop_fusion_e2e(): 109 | def __init__(self, opt, is_finetune): 110 | self.opt=opt 111 | self.train_dataset=dataset('data/train_data.jsonl',opt) 112 | 113 | self.dict=self.train_dataset.word2index 114 | self.index2word={self.dict[key]:key for key in self.dict} 115 | 116 | self.movieID2selection_label=pkl.load(open('movieID2selection_label.pkl','rb')) 117 | self.selection_label2movieID={self.movieID2selection_label[key]:key for key in self.movieID2selection_label} 118 | self.id2entity=pkl.load(open('data/id2entity.pkl','rb')) 119 | 120 | self.batch_size=self.opt['batch_size'] 121 | self.epoch=self.opt['epoch'] 122 | 123 | self.use_cuda=opt['use_cuda'] 124 | if opt['load_dict']!=None: 125 | self.load_data=True 126 | else: 127 | self.load_data=False 128 | self.is_finetune=False 129 | 130 | self.is_template = opt['is_template'] 131 | 132 | self.movie_ids = pkl.load(open("data/movie_ids.pkl", "rb")) 133 | # Note: we cannot change the type of metrics ahead of time, so you 134 | # should correctly initialize to floats or ints here 135 | 136 | self.metrics_rec={"recall@1":0,"recall@10":0,"recall@50":0,"loss":0,"count":0} 137 | self.metrics_gen={"dist1":0,"dist2":0,"dist3":0,"dist4":0,"bleu1":0,"bleu2":0,"bleu3":0,"bleu4":0,"count":0, "true_recall_movie_count":0, "res_movie_recall":0.0,"recall@1":0,"recall@10":0,"recall@50":0} 138 | 139 | self.build_model(is_finetune=True) 140 | 141 | if opt['load_dict'] is not None: 142 | # load model parameters if available 143 | print('[ Loading existing model params from {} ]' 144 | ''.format(opt['load_dict'])) 145 | states = self.model.load(opt['load_dict']) 146 | else: 147 | states = {} 148 | 149 | self.init_optim( 150 | [p for p in self.model.parameters() if p.requires_grad], 151 | optim_states=states.get('optimizer'), 152 | saved_optim_type=states.get('optimizer_type') 153 | ) 154 | 155 | def build_model(self,is_finetune): 156 | self.model = E2ECrossModel(self.opt, self.dict, is_finetune) 157 | if self.opt['embedding_type'] != 'random': 158 | pass 159 | if self.use_cuda: 160 | self.model.cuda() 161 | 162 | def train(self): 163 | # self.model.load_model() 164 | losses=[] 165 | best_val_gen=0 166 | best_val_rec=0.0 167 | gen_stop=False 168 | for i in range(self.epoch*3): 169 | train_set=CRSdataset(self.train_dataset.data_process(True),self.opt['n_entity'],self.opt['n_concept']) 170 | train_dataset_loader = torch.utils.data.DataLoader(dataset=train_set, 171 | batch_size=self.batch_size, 172 | shuffle=False) 173 | num=0 174 | for 
context,c_lengths,response,r_length,mask_response,mask_r_length,entity,entity_vector,movie,concept_mask,dbpedia_mask,concept_vec, db_vec,rec,movies_gth,movie_nums in tqdm(train_dataset_loader): 175 | seed_sets = [] 176 | batch_size = context.shape[0] 177 | for b in range(batch_size): 178 | seed_set = entity[b].nonzero().view(-1).tolist() 179 | seed_sets.append(seed_set) 180 | self.model.train() 181 | self.zero_grad() 182 | 183 | scores, preds, rec_scores, rec_loss, gen_loss, mask_loss, info_db_loss, info_con_loss, selection_loss, matching_pred, matching_scores=self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec,movies_gth.cuda(),movie_nums, test=False) 184 | 185 | joint_loss=gen_loss + selection_loss 186 | 187 | losses.append([gen_loss, selection_loss]) 188 | self.backward(joint_loss) 189 | self.update_params() 190 | if num%20==0: 191 | print('gen loss is %f'%(sum([l[0] for l in losses])/len(losses))) 192 | print('selection_loss is %f'%(sum([l[1] for l in losses])/len(losses))) 193 | losses=[] 194 | num+=1 195 | 196 | output_metrics_gen = self.val(True) 197 | if best_val_gen > output_metrics_gen["dist4"]: 198 | pass 199 | else: 200 | best_val_gen = output_metrics_gen["dist4"] 201 | self.model.save_model(model_name= self.opt['save_exp_name'] + '_best_dist4.pkl') 202 | print("generator model saved once------------------------------------------------") 203 | print("best dist4 is :", best_val_gen) 204 | 205 | if best_val_rec > output_metrics_gen["res_movie_recall"]: 206 | pass 207 | else: 208 | best_val_rec = output_metrics_gen["res_movie_recall"] 209 | self.model.save_model(model_name= self.opt['save_exp_name'] + '_best_Rec1.pkl') 210 | print("generator model saved once------------------------------------------------") 211 | print("best res_movie_R@1 is :", best_val_rec) 212 | 213 | # if i % 5 ==0: # save each 5 epoch 214 | # model_name = self.opt['save_exp_name'] + '_' + str(i) + '.pkl' 215 | # self.model.save_model(model_name=model_name) 216 | # print("generator model saved once------------------------------------------------") 217 | # print('cur selection_loss is %f'%(sum([l[1] for l in losses])/len(losses))) 218 | 219 | _=self.val(is_test=True) 220 | 221 | def val(self,is_test=False): 222 | self.metrics_gen={"ppl":0,"dist1":0,"dist2":0,"dist3":0,"dist4":0,"bleu1":0,"bleu2":0,"bleu3":0,"bleu4":0,"count":0,"true_recall_movie_count":0, "res_movie_recall":0.0,"recall@1":0,"recall@10":0,"recall@50":0} 223 | self.metrics_rec={"recall@1":0,"recall@10":0,"recall@50":0,"loss":0,"gate":0,"count":0,'gate_count':0} 224 | self.model.eval() 225 | if is_test: 226 | val_dataset = dataset('data/test_data.jsonl', self.opt) 227 | else: 228 | val_dataset = dataset('data/valid_data.jsonl', self.opt) 229 | val_set=CRSdataset(val_dataset.data_process(True),self.opt['n_entity'],self.opt['n_concept']) 230 | val_dataset_loader = torch.utils.data.DataLoader(dataset=val_set, 231 | batch_size=self.batch_size, 232 | shuffle=False) 233 | inference_sum=[] 234 | golden_sum=[] 235 | context_sum=[] 236 | losses=[] 237 | recs=[] 238 | for context, c_lengths, response, r_length, mask_response, mask_r_length, entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec,movies_gth,movie_nums in tqdm(val_dataset_loader): 239 | with torch.no_grad(): 240 | seed_sets = [] 241 | batch_size = context.shape[0] 242 | for b in range(batch_size): 243 | seed_set = entity[b].nonzero().view(-1).tolist() 244 | 
seed_sets.append(seed_set) 245 | 246 | #----- redundant extra pass: run with test=False only to get the gen_loss; could be optimized here ------By Jokie 2021/04/15 247 | _, _, _, _, gen_loss, mask_loss, info_db_loss, info_con_loss, selection_loss, _, _ = self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec,movies_gth.cuda(),movie_nums, test=False) 248 | scores, preds, rec_scores, rec_loss, _, mask_loss, info_db_loss, info_con_loss, selection_loss, matching_pred, matching_scores = self.model(context.cuda(), response.cuda(), mask_response.cuda(), concept_mask, dbpedia_mask, seed_sets, movie, concept_vec, db_vec, entity_vector.cuda(), rec,movies_gth.cuda(),movie_nums,test=True, maxlen=20, bsz=batch_size) 249 | 250 | self.all_response_movie_recall_cal(preds.cpu(), matching_scores.cpu(),movies_gth.cpu()) 251 | 252 | # golden_sum.extend(self.vector2sentence(response.cpu())) 253 | # inference_sum.extend(self.vector2sentence(preds.cpu())) 254 | # context_sum.extend(self.vector2sentence(context.cpu())) 255 | 256 | #-----------template post-process: fill the gth response and prediction placeholders-------------------- 257 | if self.is_template: 258 | golden_sum.extend(self.template_vector2sentence(response.cpu(), movies_gth.cpu())) 259 | if matching_pred is not None: 260 | inference_sum.extend(self.template_vector2sentence(preds.cpu(), matching_pred.cpu())) 261 | else: 262 | inference_sum.extend(self.template_vector2sentence(preds.cpu(), None)) 263 | 264 | else: 265 | golden_sum.extend(self.vector2sentence(response.cpu())) 266 | inference_sum.extend(self.vector2sentence(preds.cpu())) 267 | context_sum.extend(self.vector2sentence(context.cpu())) 268 | 269 | 270 | recs.extend(rec.cpu()) 271 | losses.append(torch.mean(gen_loss)) 272 | #print(losses) 273 | #exit() 274 | 275 | self.metrics_cal_gen(losses,inference_sum,golden_sum,recs, beam=self.opt['beam']) 276 | 277 | output_dict_gen={} 278 | for key in self.metrics_gen: 279 | if 'bleu' in key: 280 | output_dict_gen[key]=self.metrics_gen[key]/self.metrics_gen['count'] 281 | else: 282 | output_dict_gen[key]=self.metrics_gen[key] 283 | print(output_dict_gen) 284 | 285 | # f=open('context_test.txt','w',encoding='utf-8') 286 | # f.writelines([' '.join(sen)+'\n' for sen in context_sum]) 287 | # f.close() 288 | 289 | f=open('self_attn_best_rec_output.txt','w',encoding='utf-8') 290 | f.writelines([' '.join(sen)+'\n' for sen in inference_sum]) 291 | f.close() 292 | 293 | # f=open('golden_test.txt','w',encoding='utf-8') 294 | # f.writelines([' '.join(sen)+'\n' for sen in golden_sum]) 295 | # f.close() 296 | 297 | # f=open('case_visualize.txt','w',encoding='utf-8') 298 | # for cont, hypo, gold in zip(context_sum, inference_sum, golden_sum): 299 | # f.writelines('context: '+' '.join(cont)+'\n') 300 | # f.writelines('hypo: '+' '.join(hypo)+'\n') 301 | # f.writelines('gold: '+' '.join(gold)+'\n') 302 | # f.writelines('\n') 303 | # f.close() 304 | 305 | return output_dict_gen 306 | 307 | def all_response_movie_recall_cal(self,decode_preds, matching_scores,labels): 308 | 309 | # matching_scores is non-mask version [bsz, seq_len, matching_vocab] 310 | # decode_preds [bsz, seq_len] 311 | # labels [bsz, movie_length_with_padding] 312 | # print('decode_preds shape', decode_preds.shape) 313 | # print('matching_scores shape', matching_scores.shape) 314 | # print('labels shape', labels.shape) 315 | decode_preds = decode_preds[:, 1:] # skip the start token 316 | 317 | labels = labels * (labels!=-1) # removing the padding token 318 |
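        # For every __MOVIE__ slot (word id 6) the model actually decoded, the loop
        # below takes the top-100 selection scores at that position and counts a hit
        # whenever a ground-truth movie appears in the top-1/10/50; the accumulated
        # counts are normalised later in metrics_cal_gen by the number of gold
        # responses that mention a movie.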
319 | batch_size, seq_len = decode_preds.shape[0], decode_preds.shape[1] 320 | for cur_b in range(batch_size): 321 | for cur_seq_len in range(seq_len): 322 | if decode_preds[cur_b][cur_seq_len] ==6: # word id is 6 323 | _, pred_idx = torch.topk(matching_scores[cur_b][cur_seq_len], k=100, dim=-1) 324 | targets = labels[cur_b] 325 | for target in targets: 326 | self.metrics_gen["recall@1"] += int(target in pred_idx[:1].tolist()) 327 | self.metrics_gen["recall@10"] += int(target in pred_idx[:10].tolist()) 328 | self.metrics_gen["recall@50"] += int(target in pred_idx[:50].tolist()) 329 | 330 | 331 | def metrics_cal_gen(self,rec_loss,preds,responses,recs, beam=1): 332 | def bleu_cal(sen1, tar1): 333 | bleu1 = sentence_bleu([tar1], sen1, weights=(1, 0, 0, 0)) 334 | bleu2 = sentence_bleu([tar1], sen1, weights=(0, 1, 0, 0)) 335 | bleu3 = sentence_bleu([tar1], sen1, weights=(0, 0, 1, 0)) 336 | bleu4 = sentence_bleu([tar1], sen1, weights=(0, 0, 0, 1)) 337 | return bleu1, bleu2, bleu3, bleu4 338 | 339 | def response_movie_recall_cal(sen1, tar1): 340 | for word in sen1: 341 | if '@' in word: # if is movie 342 | if word in tar1: # if in gth 343 | return int(1) 344 | else: 345 | return int(0) 346 | return int(0) 347 | 348 | def distinct_metrics(outs): 349 | # outputs is a list which contains several sentences, each sentence contains several words 350 | unigram_count = 0 351 | bigram_count = 0 352 | trigram_count=0 353 | quagram_count=0 354 | unigram_set = set() 355 | bigram_set = set() 356 | trigram_set=set() 357 | quagram_set=set() 358 | for sen in outs: 359 | for word in sen: 360 | unigram_count += 1 361 | unigram_set.add(word) 362 | for start in range(len(sen) - 1): 363 | bg = str(sen[start]) + ' ' + str(sen[start + 1]) 364 | bigram_count += 1 365 | bigram_set.add(bg) 366 | for start in range(len(sen)-2): 367 | trg=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2]) 368 | trigram_count+=1 369 | trigram_set.add(trg) 370 | for start in range(len(sen)-3): 371 | quag=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2]) + ' ' + str(sen[start + 3]) 372 | quagram_count+=1 373 | quagram_set.add(quag) 374 | dis1 = len(unigram_set) / len(outs)#unigram_count 375 | dis2 = len(bigram_set) / len(outs)#bigram_count 376 | dis3 = len(trigram_set)/len(outs)#trigram_count 377 | dis4 = len(quagram_set)/len(outs)#quagram_count 378 | return dis1, dis2, dis3, dis4 379 | 380 | predict_s=preds 381 | golden_s=responses 382 | #print(rec_loss[0]) 383 | self.metrics_gen["ppl"]+=sum([exp(ppl) for ppl in rec_loss])/len(rec_loss) 384 | generated=[] 385 | total_movie_gth_response_cnt = 0 386 | have_movie_res_cnt = 0 387 | loop = 0 388 | # for out, tar, rec in zip(predict_s, golden_s, recs): 389 | for out in predict_s: 390 | tar = golden_s[loop // beam] 391 | loop = loop+1 392 | bleu1, bleu2, bleu3, bleu4=bleu_cal(out, tar) 393 | generated.append(out) 394 | self.metrics_gen['bleu1']+=bleu1 395 | self.metrics_gen['bleu2']+=bleu2 396 | self.metrics_gen['bleu3']+=bleu3 397 | self.metrics_gen['bleu4']+=bleu4 398 | self.metrics_gen['count']+=1 399 | self.metrics_gen['true_recall_movie_count']+=response_movie_recall_cal(out, tar) 400 | 401 | for tar in golden_s: 402 | for word in tar: 403 | if '@' in word: 404 | total_movie_gth_response_cnt+=1 405 | for word in tar: 406 | if '@' in word: 407 | have_movie_res_cnt+=1 408 | break 409 | 410 | dis1, dis2, dis3, dis4=distinct_metrics(generated) 411 | self.metrics_gen['dist1']=dis1 412 | self.metrics_gen['dist2']=dis2 413 | self.metrics_gen['dist3']=dis3 414 
| self.metrics_gen['dist4']=dis4 415 | 416 | self.metrics_gen['res_movie_recall'] = self.metrics_gen['true_recall_movie_count'] / have_movie_res_cnt 417 | self.metrics_gen["recall@1"] = self.metrics_gen["recall@1"] / have_movie_res_cnt 418 | self.metrics_gen["recall@10"] = self.metrics_gen["recall@10"] / have_movie_res_cnt 419 | self.metrics_gen["recall@50"] = self.metrics_gen["recall@50"] / have_movie_res_cnt 420 | print('total_movie_gth_response_cnt: ', total_movie_gth_response_cnt) 421 | print('have_movie_res_cnt: ', have_movie_res_cnt) 422 | 423 | def vector2sentence(self,batch_sen): 424 | sentences=[] 425 | for sen in batch_sen.numpy().tolist(): 426 | sentence=[] 427 | for word in sen: 428 | if word>3: 429 | sentence.append(self.index2word[word]) 430 | # if word==6: #if MOVIE token 431 | # sentence.append(self.selection_label2movieID[selection_label]) 432 | elif word==3: 433 | sentence.append('_UNK_') 434 | sentences.append(sentence) 435 | return sentences 436 | 437 | def template_vector2sentence(self,batch_sen, batch_selection_pred): 438 | sentences=[] 439 | all_movie_labels = [] 440 | if batch_selection_pred is not None: 441 | batch_selection_pred = batch_selection_pred * (batch_selection_pred!=-1) 442 | batch_selection_pred = torch.masked_select(batch_selection_pred, (batch_selection_pred!=0)) 443 | for movie in batch_selection_pred.numpy().tolist(): 444 | all_movie_labels.append(movie) 445 | 446 | # print('all_movie_labels:', all_movie_labels) 447 | curr_movie_token = 0 448 | for sen in batch_sen.numpy().tolist(): 449 | sentence=[] 450 | for word in sen: 451 | if word>3: 452 | if word==6: #if MOVIE token 453 | # print('all_movie_labels[curr_movie_token]',all_movie_labels[curr_movie_token]) 454 | # print('selection_label2movieID',self.selection_label2movieID[all_movie_labels[curr_movie_token]]) 455 | 456 | # WAY1: original method 457 | sentence.append('@' + str(self.selection_label2movieID[all_movie_labels[curr_movie_token]])) 458 | 459 | 460 | # WAY2: print out the movie name, but should comment when calculating the gen metrics 461 | # if self.id2entity[self.selection_label2movieID[all_movie_labels[curr_movie_token]]] is not None: 462 | # sentence.append(self.id2entity[self.selection_label2movieID[all_movie_labels[curr_movie_token]]].split('/')[-1]) 463 | # else: 464 | # sentence.append('@' + str(self.selection_label2movieID[all_movie_labels[curr_movie_token]])) 465 | 466 | 467 | curr_movie_token +=1 468 | else: 469 | sentence.append(self.index2word[word]) 470 | 471 | elif word==3: 472 | sentence.append('_UNK_') 473 | sentences.append(sentence) 474 | 475 | # print('[DEBUG]sentence : ') 476 | # print(u' '.join(sentence).encode('utf-8').strip()) 477 | 478 | assert curr_movie_token == len(all_movie_labels) 479 | return sentences 480 | 481 | @classmethod 482 | def optim_opts(self): 483 | """ 484 | Fetch optimizer selection. 485 | 486 | By default, collects everything in torch.optim, as well as importing: 487 | - qhm / qhmadam if installed from github.com/facebookresearch/qhoptim 488 | 489 | Override this (and probably call super()) to add your own optimizers. 
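        For example (a minimal sketch, assuming a built ``self.model``; 'adam'
        matches this repo's default --optimizer flag):

        .. code-block:: python

            optims = TrainLoop_fusion_e2e.optim_opts()
            optim_class = optims['adam']  # torch.optim.Adam
            optimizer = optim_class(self.model.parameters(), lr=1e-3, amsgrad=True)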
490 | """ 491 | # first pull torch.optim in 492 | optims = {k.lower(): v for k, v in optim.__dict__.items() 493 | if not k.startswith('__') and k[0].isupper()} 494 | try: 495 | import apex.optimizers.fused_adam as fused_adam 496 | optims['fused_adam'] = fused_adam.FusedAdam 497 | except ImportError: 498 | pass 499 | 500 | try: 501 | # https://openreview.net/pdf?id=S1fUpoR5FQ 502 | from qhoptim.pyt import QHM, QHAdam 503 | optims['qhm'] = QHM 504 | optims['qhadam'] = QHAdam 505 | except ImportError: 506 | # no QHM installed 507 | pass 508 | 509 | return optims 510 | 511 | def init_optim(self, params, optim_states=None, saved_optim_type=None): 512 | """ 513 | Initialize optimizer with model parameters. 514 | 515 | :param params: 516 | parameters from the model 517 | 518 | :param optim_states: 519 | optional argument providing states of optimizer to load 520 | 521 | :param saved_optim_type: 522 | type of optimizer being loaded, if changed will skip loading 523 | optimizer states 524 | """ 525 | 526 | opt = self.opt 527 | 528 | # set up optimizer args 529 | lr = opt['learningrate'] 530 | kwargs = {'lr': lr} 531 | kwargs['amsgrad'] = True 532 | kwargs['betas'] = (0.9, 0.999) 533 | 534 | optim_class = self.optim_opts()[opt['optimizer']] 535 | self.optimizer = optim_class(params, **kwargs) 536 | 537 | def backward(self, loss): 538 | """ 539 | Perform a backward pass. It is recommended you use this instead of 540 | loss.backward(), for integration with distributed training and FP16 541 | training. 542 | """ 543 | loss.backward() 544 | 545 | def update_params(self): 546 | """ 547 | Perform step of optimization, clipping gradients and adjusting LR 548 | schedule if needed. Gradient accumulation is also performed if agent 549 | is called with --update-freq. 550 | 551 | It is recommended (but not forced) that you call this in train_step. 552 | """ 553 | update_freq = 1 554 | if update_freq > 1: 555 | # we're doing gradient accumulation, so we don't only want to step 556 | # every N updates instead 557 | self._number_grad_accum = (self._number_grad_accum + 1) % update_freq 558 | if self._number_grad_accum != 0: 559 | return 560 | 561 | if self.opt['gradient_clip'] > 0: 562 | torch.nn.utils.clip_grad_norm_( 563 | self.model.parameters(), self.opt['gradient_clip'] 564 | ) 565 | 566 | self.optimizer.step() 567 | 568 | def zero_grad(self): 569 | """ 570 | Zero out optimizer. 571 | 572 | It is recommended you call this in train_step. It automatically handles 573 | gradient accumulation if agent is called with --update-freq. 
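        For example, one optimisation step in ``train()`` above runs:

        .. code-block:: python

            self.zero_grad()
            joint_loss = gen_loss + selection_loss
            self.backward(joint_loss)
            self.update_params()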
574 | """ 575 | self.optimizer.zero_grad() 576 | 577 | if __name__ == '__main__': 578 | args=setup_args().parse_args() 579 | # import os 580 | # os.environ['CUDA_VISIBLE_DEVICES']=args.gpu 581 | print('CUDA_VISIBLE_DEVICES:', os.environ['CUDA_VISIBLE_DEVICES']) 582 | print(vars(args)) 583 | loop=TrainLoop_fusion_e2e(vars(args),is_finetune=True) 584 | 585 | loop.model.load_model('saved_model/sattn_e2eCRS_best_Rec1.pkl') 586 | # loop.model.load_model('saved_model/sattn_e2eCRS_best_dist4.pkl') 587 | # loop.model.load_model('saved_model/self_attn_generation_model_22.pkl') 588 | # loop.model.load_model() 589 | loop.train() 590 | 591 | -------------------------------------------------------------------------------- /install_geometric.sh: -------------------------------------------------------------------------------- 1 | TORCH=1.6.0 2 | CUDA=cu101 3 | 4 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 5 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 6 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 7 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html 8 | pip install torch-geometric 9 | -------------------------------------------------------------------------------- /mask4key.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/mask4key.npy -------------------------------------------------------------------------------- /mask4movie.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/mask4movie.npy -------------------------------------------------------------------------------- /models/graph.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from sklearn.metrics import roc_auc_score 9 | 10 | from torch_geometric.nn.conv.gcn_conv import GCNConv 11 | from torch_geometric.nn.conv.gat_conv import GATConv 12 | 13 | 14 | def kaiming_reset_parameters(linear_module): 15 | nn.init.kaiming_uniform_(linear_module.weight, a=math.sqrt(5)) 16 | if linear_module.bias is not None: 17 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(linear_module.weight) 18 | bound = 1 / math.sqrt(fan_in) 19 | nn.init.uniform_(linear_module.bias, -bound, bound) 20 | 21 | class GraphConvolution(nn.Module): 22 | """ 23 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 24 | """ 25 | 26 | def __init__(self, in_features, out_features, bias=True): 27 | super(GraphConvolution, self).__init__() 28 | self.in_features = in_features 29 | self.out_features = out_features 30 | self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features)) 31 | if bias: 32 | self.bias = nn.Parameter(torch.FloatTensor(out_features)) 33 | else: 34 | self.register_parameter('bias', None) 35 | self.reset_parameters() 36 | 37 | def reset_parameters(self): 38 | # stdv = 1. 
/ math.sqrt(self.weight.size(1)) 39 | # self.weight.data.uniform_(-stdv, stdv) 40 | # if self.bias is not None: 41 | # self.bias.data.uniform_(-stdv, stdv) 42 | 43 | kaiming_reset_parameters(self) 44 | 45 | def forward(self, input, adj): 46 | support = torch.mm(input, self.weight) 47 | output = torch.spmm(adj, support) 48 | if self.bias is not None: 49 | return output + self.bias 50 | else: 51 | return output 52 | 53 | def __repr__(self): 54 | return self.__class__.__name__ + ' (' \ 55 | + str(self.in_features) + ' -> ' \ 56 | + str(self.out_features) + ')' 57 | 58 | class GCN(nn.Module): 59 | def __init__(self, ninp, nhid, dropout=0.5): 60 | super(GCN, self).__init__() 61 | 62 | # self.gc1 = GraphConvolution(ninp, nhid) 63 | self.gc2 = GraphConvolution(ninp, nhid) 64 | self.dropout = dropout 65 | 66 | def forward(self, x, adj): 67 | """x: shape (|V|, |D|); adj: shape(|V|, |V|)""" 68 | # x = F.relu(self.gc1(x, adj)) 69 | # x = F.dropout(x, self.dropout, training=self.training) 70 | x = self.gc2(x, adj) 71 | return x 72 | # return F.log_softmax(x, dim=1) 73 | 74 | class GraphAttentionLayer(nn.Module): 75 | """ 76 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 77 | """ 78 | 79 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 80 | super(GraphAttentionLayer, self).__init__() 81 | self.dropout = dropout 82 | self.in_features = in_features 83 | self.out_features = out_features 84 | self.alpha = alpha 85 | self.concat = concat 86 | 87 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 88 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 89 | self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) 90 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 91 | 92 | self.leakyrelu = nn.LeakyReLU(self.alpha) 93 | 94 | def forward(self, input, adj): 95 | h = torch.mm(input, self.W) 96 | N = h.size()[0] 97 | 98 | a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features) 99 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 100 | 101 | zero_vec = -9e15*torch.ones_like(e) 102 | attention = torch.where(adj > 0, e, zero_vec) 103 | attention = F.softmax(attention, dim=1) 104 | attention = F.dropout(attention, self.dropout, training=self.training) 105 | h_prime = torch.matmul(attention, h) 106 | 107 | if self.concat: 108 | return F.elu(h_prime) 109 | else: 110 | return h_prime 111 | 112 | def __repr__(self): 113 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 114 | 115 | class SelfAttentionLayer(nn.Module): 116 | def __init__(self, dim, da, alpha=0.2, dropout=0.5): 117 | super(SelfAttentionLayer, self).__init__() 118 | self.dim = dim 119 | self.da = da 120 | self.alpha = alpha 121 | self.dropout = dropout 122 | # self.a = nn.Parameter(torch.zeros(size=(2*self.dim, 1))) 123 | # nn.init.xavier_uniform_(self.a.data, gain=1.414) 124 | self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da))) 125 | self.b = nn.Parameter(torch.zeros(size=(self.da, 1))) 126 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 127 | nn.init.xavier_uniform_(self.b.data, gain=1.414) 128 | # self.leakyrelu = nn.LeakyReLU(self.alpha) 129 | 130 | def forward(self, h): 131 | N = h.shape[0] 132 | assert self.dim == h.shape[1] 133 | # a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.dim) 134 | # e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 135 | # attention = F.softmax(e, dim=1) 136 | e = 
torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b).squeeze(dim=1)
137 | attention = F.softmax(e, dim=0)
138 | # attention = F.dropout(attention, self.dropout, training=self.training)
139 | return torch.matmul(attention, h)
140 | 
141 | class SelfAttentionLayer_batch(nn.Module):
142 | def __init__(self, dim, da, alpha=0.2, dropout=0.5):
143 | super(SelfAttentionLayer_batch, self).__init__()
144 | self.dim = dim
145 | self.da = da
146 | self.alpha = alpha
147 | self.dropout = dropout
148 | # self.a = nn.Parameter(torch.zeros(size=(2*self.dim, 1)))
149 | # nn.init.xavier_uniform_(self.a.data, gain=1.414)
150 | self.a = nn.Parameter(torch.zeros(size=(self.dim, self.da)))
151 | self.b = nn.Parameter(torch.zeros(size=(self.da, 1)))
152 | nn.init.xavier_uniform_(self.a.data, gain=1.414)
153 | nn.init.xavier_uniform_(self.b.data, gain=1.414)
154 | # self.leakyrelu = nn.LeakyReLU(self.alpha)
155 | 
156 | def forward(self, h, mask):
157 | N = h.shape[0]
158 | assert self.dim == h.shape[2]
159 | # a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.dim)
160 | # e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
161 | # attention = F.softmax(e, dim=1)
162 | mask=1e-30*mask.float() # note: at this scale the additive mask has no numerical effect on the softmax below
163 | 
164 | e = torch.matmul(torch.tanh(torch.matmul(h, self.a)), self.b)
165 | #print(e.size())
166 | #print(mask.size())
167 | attention = F.softmax(e+mask.unsqueeze(-1),dim=1)
168 | # attention = F.dropout(attention, self.dropout, training=self.training)
169 | return torch.matmul(torch.transpose(attention,1,2), h).squeeze(1),attention
170 | 
171 | class SelfAttentionLayer2(nn.Module):
172 | def __init__(self, dim, da):
173 | super(SelfAttentionLayer2, self).__init__()
174 | self.dim = dim
175 | self.Wq = nn.Parameter(torch.zeros(self.dim, self.dim))
176 | self.Wk = nn.Parameter(torch.zeros(self.dim, self.dim))
177 | nn.init.xavier_uniform_(self.Wq.data, gain=1.414)
178 | nn.init.xavier_uniform_(self.Wk.data, gain=1.414)
179 | # self.leakyrelu = nn.LeakyReLU(self.alpha)
180 | 
181 | def forward(self, h):
182 | N = h.shape[0]
183 | assert self.dim == h.shape[1]
184 | q = torch.matmul(h, self.Wq)
185 | k = torch.matmul(h, self.Wk)
186 | e = torch.matmul(q, k.t()) / math.sqrt(self.dim)
187 | attention = F.softmax(e, dim=1)
188 | attention = attention.mean(dim=0)
189 | x = torch.matmul(attention, h)
190 | return x
191 | 
192 | class BiAttention(nn.Module):
193 | def __init__(self, input_size, dropout):
194 | super().__init__()
195 | self.dropout = nn.Dropout(p=dropout)
196 | self.input_linear = nn.Linear(input_size, 1, bias=False)
197 | self.memory_linear = nn.Linear(input_size, 1, bias=False)
198 | 
199 | self.dot_scale = nn.Parameter(torch.Tensor(input_size).uniform_(1.0 / (input_size ** 0.5)))
200 | 
201 | def forward(self, input, memory, mask=None):
202 | bsz, input_len, memory_len = input.size(0), input.size(1), memory.size(1)
203 | 
204 | input = self.dropout(input)
205 | memory = self.dropout(memory)
206 | 
207 | input_dot = self.input_linear(input)
208 | memory_dot = self.memory_linear(memory).view(bsz, 1, memory_len)
209 | cross_dot = torch.bmm(input * self.dot_scale, memory.permute(0, 2, 1).contiguous())
210 | att = input_dot + memory_dot + cross_dot
211 | if mask is not None:
212 | att = att - 1e30 * (1 - mask[:,None])
213 | 
214 | weight_one = F.softmax(att, dim=-1)
215 | output_one = torch.bmm(weight_one, memory)
216 | weight_two = F.softmax(att.max(dim=-1)[0], dim=-1).view(bsz, 1, input_len)
217 | output_two = torch.bmm(weight_two, input)
218 | return torch.cat([input, output_one, input*output_one, output_two*output_one], dim=-1)
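# --- Editor's sketch (hedged; not part of the original file): a shape
# walk-through for the BiAttention layer above. All sizes are illustrative,
# and the function is never invoked by the repository code. The output width
# is 4 * input_size because forward() concatenates [input, attended memory,
# gated input, fused context] along the last dimension.
def _biattention_shape_sketch():
    bsz, input_len, memory_len, dim = 2, 5, 7, 16
    layer = BiAttention(input_size=dim, dropout=0.0)
    inp = torch.randn(bsz, input_len, dim)
    mem = torch.randn(bsz, memory_len, dim)
    mask = torch.ones(bsz, memory_len)   # 1 = attend, 0 = masked out
    out = layer(inp, mem, mask)
    assert out.shape == (bsz, input_len, 4 * dim)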
219 | 
220 | class GAT(nn.Module):
221 | def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
222 | """Dense version of GAT."""
223 | super(GAT, self).__init__()
224 | self.dropout = dropout
225 | 
226 | self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
227 | for i, attention in enumerate(self.attentions):
228 | self.add_module('attention_{}'.format(i), attention)
229 | 
230 | self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
231 | 
232 | def forward(self, x, adj):
233 | x = F.dropout(x, self.dropout, training=self.training)
234 | x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
235 | x = F.dropout(x, self.dropout, training=self.training)
236 | x = F.elu(self.out_att(x, adj))
237 | return F.log_softmax(x, dim=1)
238 | 
239 | class SpecialSpmmFunction(torch.autograd.Function):
240 | """Special function for sparse region backpropagation."""
241 | @staticmethod
242 | def forward(ctx, indices, values, shape, b):
243 | assert not indices.requires_grad
244 | a = torch.sparse_coo_tensor(indices, values, shape)
245 | ctx.save_for_backward(a, b)
246 | ctx.N = shape[0]
247 | return torch.matmul(a, b)
248 | 
249 | @staticmethod
250 | def backward(ctx, grad_output):
251 | a, b = ctx.saved_tensors
252 | grad_values = grad_b = None
253 | if ctx.needs_input_grad[1]:
254 | grad_a_dense = grad_output.matmul(b.t())
255 | edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
256 | grad_values = grad_a_dense.view(-1)[edge_idx]
257 | if ctx.needs_input_grad[3]:
258 | grad_b = a.t().matmul(grad_output)
259 | return None, grad_values, None, grad_b
260 | 
261 | class SpecialSpmm(nn.Module):
262 | def forward(self, indices, values, shape, b):
263 | return SpecialSpmmFunction.apply(indices, values, shape, b)
264 | 
265 | class SpGraphAttentionLayer(nn.Module):
266 | """
267 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
268 | """
269 | 
270 | def __init__(self, in_features, out_features, dropout, alpha, concat=True):
271 | super(SpGraphAttentionLayer, self).__init__()
272 | self.in_features = in_features
273 | self.out_features = out_features
274 | self.alpha = alpha
275 | self.concat = concat
276 | 
277 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
278 | nn.init.xavier_normal_(self.W.data, gain=1.414)
279 | 
280 | # self.a = nn.Parameter(torch.zeros(size=(1, 2*out_features)))
281 | self.a = nn.Parameter(torch.zeros(size=(1, out_features)))
282 | nn.init.xavier_normal_(self.a.data, gain=1.414)
283 | 
284 | # self.dropout = nn.Dropout(dropout)
285 | self.leakyrelu = nn.LeakyReLU(self.alpha)
286 | self.special_spmm = SpecialSpmm()
287 | 
288 | def forward(self, input, adj):
289 | N = input.size()[0]
290 | # edge = adj.nonzero().t()
291 | edge = adj._indices()
292 | 
293 | h = torch.mm(input, self.W)
294 | # h: N x out
295 | assert not torch.isnan(h).any()
296 | 
297 | # Self-attention on the nodes - Shared attention mechanism
298 | # edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t()
299 | edge_h = h[edge[1, :], :].t()
300 | # edge: 2*D x E
301 | 
302 | edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze()))
303 | assert not torch.isnan(edge_e).any()
304 | # edge_e: E
305 | 
306 | e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N,1)).cuda())
307 | # e_rowsum: N x 1
308 | 
309 | # edge_e
= self.dropout(edge_e) 310 | # edge_e: E 311 | 312 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) 313 | assert not torch.isnan(h_prime).any() 314 | # h_prime: N x out 315 | 316 | h_prime = h_prime.div(e_rowsum) 317 | # h_prime: N x out 318 | assert not torch.isnan(h_prime).any() 319 | 320 | if self.concat: 321 | # if this layer is not last layer, 322 | return F.elu(h_prime) 323 | else: 324 | # if this layer is last layer, 325 | return h_prime 326 | 327 | def __repr__(self): 328 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 329 | 330 | class SpGAT(nn.Module): 331 | def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads): 332 | """Sparse version of GAT.""" 333 | super(SpGAT, self).__init__() 334 | self.dropout = dropout 335 | 336 | # self.attentions = [SpGraphAttentionLayer(nfeat, 337 | # nhid, 338 | # dropout=dropout, 339 | # alpha=alpha, 340 | # concat=True) for _ in range(nheads)] 341 | # for i, attention in enumerate(self.attentions): 342 | # self.add_module('attention_{}'.format(i), attention) 343 | 344 | # self.out_att = SpGraphAttentionLayer(nhid * nheads, 345 | # nclass, 346 | # dropout=dropout, 347 | # alpha=alpha, 348 | # concat=False) 349 | self.out_att = SpGraphAttentionLayer(nhid, 350 | nclass, 351 | dropout=dropout, 352 | alpha=alpha, 353 | concat=False) 354 | 355 | def forward(self, x, adj): 356 | # x = F.dropout(x, self.dropout, training=self.training) 357 | # x = torch.cat([att(x, adj) for att in self.attentions], dim=1) 358 | # x = F.dropout(x, self.dropout, training=self.training) 359 | # x = F.elu(self.out_att(x, adj)) 360 | x = self.out_att(x, adj) 361 | return x 362 | # return F.log_softmax(x, dim=1) 363 | 364 | def _add_neighbors(kg, g, seed_set, hop): 365 | tails_of_last_hop = seed_set 366 | for h in range(hop): 367 | next_tails_of_last_hop = [] 368 | for entity in tails_of_last_hop: 369 | if entity not in kg: 370 | continue 371 | for tail_and_relation in kg[entity]: 372 | g.add_edge(entity, tail_and_relation[1]) 373 | if entity != tail_and_relation[1]: 374 | next_tails_of_last_hop.append(tail_and_relation[1]) 375 | tails_of_last_hop = next_tails_of_last_hop 376 | 377 | # http://dbpedia.org/ontology/director 378 | 379 | 380 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from functools import lru_cache 3 | import math 4 | import os 5 | import random 6 | import time 7 | import warnings 8 | import heapq 9 | import numpy as np 10 | 11 | # some of the utility methods are helpful for Torch 12 | import torch 13 | import torch.nn as nn 14 | # default type in padded3d needs to be protected if torch 15 | # isn't installed. 
16 | TORCH_LONG = torch.long
17 | __TORCH_AVAILABLE = True
18 | 
19 | 
20 | """Near infinity, useful as a large penalty for scoring when inf is bad."""
21 | NEAR_INF = 1e20
22 | NEAR_INF_FP16 = 65504
23 | 
24 | def neginf(dtype):
25 | """Returns a representable finite number near -inf for a dtype."""
26 | if dtype is torch.float16:
27 | return -NEAR_INF_FP16
28 | else:
29 | return -NEAR_INF
30 | 
31 | def _create_embeddings(dictionary, embedding_size, padding_idx):
32 | """Create and initialize word embeddings."""
33 | #e=nn.Embedding.from_pretrained(data, freeze=False, padding_idx=0).double()
34 | e = nn.Embedding(len(dictionary)+4, embedding_size, padding_idx)
35 | e.weight.data.copy_(torch.from_numpy(np.load('word2vec_redial.npy')))
36 | #nn.init.normal_(e.weight, mean=0, std=embedding_size ** -0.5)
37 | #e.weight=data
38 | #nn.init.constant_(e.weight[padding_idx], 0)
39 | return e
40 | 
41 | 
42 | def _create_entity_embeddings(entity_num, embedding_size, padding_idx):
43 | """Create and initialize entity embeddings."""
44 | e = nn.Embedding(entity_num, embedding_size)
45 | nn.init.normal_(e.weight, mean=0, std=embedding_size ** -0.5)
46 | nn.init.constant_(e.weight[padding_idx], 0)
47 | return e
48 | 
49 | 
--------------------------------------------------------------------------------
/movieID2selection_label.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jokieleung/NTRD/899c42f666e010902051f0663d76188f0e4f67e3/movieID2selection_label.pkl
--------------------------------------------------------------------------------
/stopwords.txt:
--------------------------------------------------------------------------------
1 | 'll 2 | 'tis 3 | 'twas 4 | 've 5 | 10 6 | 39 7 | a 8 | a's 9 | able 10 | ableabout
11 | about 12 | above 13 | abroad 14 | abst 15 | accordance 16 | according 17 | accordingly 18 | across 19 | act 20 | actually
21 | ad 22 | added 23 | adj 24 | adopted 25 | ae 26 | af 27 | affected 28 | affecting 29 | affects 30 | after
31 | afterwards 32 | ag 33 | again 34 | against 35 | ago 36 | ah 37 | ahead 38 | ai 39 | ain't 40 | aint
41 | al 42 | all 43 | allow 44 | allows 45 | almost 46 | alone 47 | along 48 | alongside 49 | already 50 | also
51 | although 52 | always 53 | am 54 | amid 55 | amidst 56 | among 57 | amongst 58 | amoungst 59 | amount 60 | an
61 | and 62 | announce 63 | another 64 | any 65 | anybody 66 | anyhow 67 | anymore 68 | anyone 69 | anything 70 | anyway
71 | anyways 72 | anywhere 73 | ao 74 | apart 75 | apparently 76 | appear 77 | appreciate 78 | appropriate 79 | approximately 80 | aq
81 | ar 82 | are 83 | area 84 | areas 85 | aren 86 | aren't 87 | arent 88 | arise 89 | around 90 | arpa
91 | as 92 | aside 93 | ask 94 | asked 95 | asking 96 | asks 97 | associated 98 | at 99 | au 100 | auth
101 | available 102 | aw 103 | away 104 | awfully 105 | az 106 | b 107 | ba 108 | back 109 | backed 110 | backing
111 | backs 112 | backward 113 | backwards 114 | bb 115 | bd 116 | be 117 | became 118 | because 119 | become 120 | becomes
121 | becoming 122 | been 123 | before 124 | beforehand 125 | began 126 | begin 127 | beginning 128 | beginnings 129 | begins 130 | behind
131 | being 132 | beings 133 | believe 134 | below 135 | beside 136 | besides 137 | best 138 | better 139 | between 140 | beyond
141 | bf 142 | bg 143 | bh 144 | bi 145 | big 146 | bill 147 | billion 148 | biol 149 | bj 150 | bm
151 | bn 152 | bo 153 | both 154 | bottom 155 | br 156 | brief 157 | briefly 158 | bs 159 | bt 160 | but
161 | buy 162 | 
bv 163 | bw 164 | by 165 | bz 166 | c 167 | c'mon 168 | c's 169 | ca 170 | call 171 | came 172 | can 173 | can't 174 | cannot 175 | cant 176 | caption 177 | case 178 | cases 179 | cause 180 | causes 181 | cc 182 | cd 183 | certain 184 | certainly 185 | cf 186 | cg 187 | ch 188 | changes 189 | ci 190 | ck 191 | cl 192 | clear 193 | clearly 194 | click 195 | cm 196 | cmon 197 | cn 198 | co 199 | co. 200 | com 201 | come 202 | comes 203 | computer 204 | con 205 | concerning 206 | consequently 207 | consider 208 | considering 209 | contain 210 | containing 211 | contains 212 | copy 213 | corresponding 214 | could 215 | could've 216 | couldn 217 | couldn't 218 | couldnt 219 | course 220 | cr 221 | cry 222 | cs 223 | cu 224 | currently 225 | cv 226 | cx 227 | cy 228 | cz 229 | d 230 | dare 231 | daren't 232 | darent 233 | date 234 | de 235 | dear 236 | definitely 237 | describe 238 | described 239 | despite 240 | detail 241 | did 242 | didn 243 | didn't 244 | didnt 245 | differ 246 | different 247 | differently 248 | directly 249 | dj 250 | dk 251 | dm 252 | do 253 | does 254 | doesn 255 | doesn't 256 | doesnt 257 | doing 258 | don 259 | don't 260 | done 261 | dont 262 | doubtful 263 | down 264 | downed 265 | downing 266 | downs 267 | downwards 268 | due 269 | during 270 | dz 271 | e 272 | each 273 | early 274 | ec 275 | ed 276 | edu 277 | ee 278 | effect 279 | eg 280 | eh 281 | eight 282 | eighty 283 | either 284 | eleven 285 | else 286 | elsewhere 287 | empty 288 | end 289 | ended 290 | ending 291 | ends 292 | enough 293 | entirely 294 | er 295 | es 296 | especially 297 | et 298 | et-al 299 | etc 300 | even 301 | evenly 302 | ever 303 | evermore 304 | every 305 | everybody 306 | everyone 307 | everything 308 | everywhere 309 | ex 310 | exactly 311 | example 312 | except 313 | f 314 | face 315 | faces 316 | fact 317 | facts 318 | fairly 319 | far 320 | farther 321 | felt 322 | few 323 | fewer 324 | ff 325 | fi 326 | fifteen 327 | fifth 328 | fifty 329 | fify 330 | fill 331 | find 332 | finds 333 | fire 334 | first 335 | five 336 | fix 337 | fj 338 | fk 339 | fm 340 | fo 341 | followed 342 | following 343 | follows 344 | for 345 | forever 346 | former 347 | formerly 348 | forth 349 | forty 350 | forward 351 | found 352 | four 353 | fr 354 | free 355 | from 356 | front 357 | full 358 | fully 359 | further 360 | furthered 361 | furthering 362 | furthermore 363 | furthers 364 | fx 365 | g 366 | ga 367 | gave 368 | gb 369 | gd 370 | ge 371 | general 372 | generally 373 | get 374 | gets 375 | getting 376 | gf 377 | gg 378 | gh 379 | gi 380 | give 381 | given 382 | gives 383 | giving 384 | gl 385 | gm 386 | gmt 387 | gn 388 | go 389 | goes 390 | going 391 | gone 392 | good 393 | goods 394 | got 395 | gotten 396 | gov 397 | gp 398 | gq 399 | gr 400 | great 401 | greater 402 | greatest 403 | greetings 404 | group 405 | grouped 406 | grouping 407 | groups 408 | gs 409 | gt 410 | gu 411 | gw 412 | gy 413 | h 414 | had 415 | hadn't 416 | hadnt 417 | half 418 | happens 419 | hardly 420 | has 421 | hasn 422 | hasn't 423 | hasnt 424 | have 425 | haven 426 | haven't 427 | havent 428 | having 429 | he 430 | he'd 431 | he'll 432 | he's 433 | hed 434 | hell 435 | hello 436 | help 437 | hence 438 | her 439 | here 440 | here's 441 | hereafter 442 | hereby 443 | herein 444 | heres 445 | hereupon 446 | hers 447 | herself 448 | herse” 449 | hes 450 | hi 451 | hid 452 | high 453 | higher 454 | highest 455 | him 456 | himself 457 | himse” 458 | his 459 | hither 460 | hk 461 | hm 462 | hn 463 | home 464 | homepage 465 | 
hopefully 466 | how 467 | how'd 468 | how'll 469 | how's 470 | howbeit 471 | however 472 | hr 473 | ht 474 | htm 475 | html 476 | http 477 | hu 478 | hundred 479 | i 480 | i'd 481 | i'll 482 | i'm 483 | i've 484 | i.e. 485 | id 486 | ie 487 | if 488 | ignored 489 | ii 490 | il 491 | ill 492 | im 493 | immediate 494 | immediately 495 | importance 496 | important 497 | in 498 | inasmuch 499 | inc 500 | inc. 501 | indeed 502 | index 503 | indicate 504 | indicated 505 | indicates 506 | information 507 | inner 508 | inside 509 | insofar 510 | instead 511 | int 512 | interest 513 | interested 514 | interesting 515 | interests 516 | into 517 | invention 518 | inward 519 | io 520 | iq 521 | ir 522 | is 523 | isn 524 | isn't 525 | isnt 526 | it 527 | it'd 528 | it'll 529 | it's 530 | itd 531 | itll 532 | its 533 | itself 534 | itse” 535 | ive 536 | j 537 | je 538 | jm 539 | jo 540 | join 541 | jp 542 | just 543 | k 544 | ke 545 | keep 546 | keeps 547 | kept 548 | keys 549 | kg 550 | kh 551 | ki 552 | kind 553 | km 554 | kn 555 | knew 556 | know 557 | known 558 | knows 559 | kp 560 | kr 561 | kw 562 | ky 563 | kz 564 | l 565 | la 566 | large 567 | largely 568 | last 569 | lately 570 | later 571 | latest 572 | latter 573 | latterly 574 | lb 575 | lc 576 | least 577 | length 578 | less 579 | lest 580 | let 581 | let's 582 | lets 583 | li 584 | like 585 | liked 586 | likely 587 | likewise 588 | line 589 | little 590 | lk 591 | ll 592 | long 593 | longer 594 | longest 595 | look 596 | looking 597 | looks 598 | low 599 | lower 600 | lr 601 | ls 602 | lt 603 | ltd 604 | lu 605 | lv 606 | ly 607 | m 608 | ma 609 | made 610 | mainly 611 | make 612 | makes 613 | making 614 | man 615 | many 616 | may 617 | maybe 618 | mayn't 619 | maynt 620 | mc 621 | md 622 | me 623 | mean 624 | means 625 | meantime 626 | meanwhile 627 | member 628 | members 629 | men 630 | merely 631 | mg 632 | mh 633 | microsoft 634 | might 635 | might've 636 | mightn't 637 | mightnt 638 | mil 639 | mill 640 | million 641 | mine 642 | minus 643 | miss 644 | mk 645 | ml 646 | mm 647 | mn 648 | mo 649 | more 650 | moreover 651 | most 652 | mostly 653 | move 654 | mp 655 | mq 656 | mr 657 | mrs 658 | ms 659 | msie 660 | mt 661 | mu 662 | much 663 | mug 664 | must 665 | must've 666 | mustn't 667 | mustnt 668 | mv 669 | mw 670 | mx 671 | my 672 | myself 673 | myse” 674 | mz 675 | n 676 | na 677 | name 678 | namely 679 | nay 680 | nc 681 | nd 682 | ne 683 | near 684 | nearly 685 | necessarily 686 | necessary 687 | need 688 | needed 689 | needing 690 | needn't 691 | neednt 692 | needs 693 | neither 694 | net 695 | netscape 696 | never 697 | neverf 698 | neverless 699 | nevertheless 700 | new 701 | newer 702 | newest 703 | next 704 | nf 705 | ng 706 | ni 707 | nine 708 | ninety 709 | nl 710 | no 711 | no-one 712 | nobody 713 | non 714 | none 715 | nonetheless 716 | noone 717 | nor 718 | normally 719 | nos 720 | not 721 | noted 722 | nothing 723 | notwithstanding 724 | novel 725 | now 726 | nowhere 727 | np 728 | nr 729 | nu 730 | null 731 | number 732 | numbers 733 | nz 734 | o 735 | obtain 736 | obtained 737 | obviously 738 | of 739 | off 740 | often 741 | oh 742 | ok 743 | okay 744 | old 745 | older 746 | oldest 747 | om 748 | omitted 749 | on 750 | once 751 | one 752 | one's 753 | ones 754 | only 755 | onto 756 | open 757 | opened 758 | opening 759 | opens 760 | opposite 761 | or 762 | ord 763 | order 764 | ordered 765 | ordering 766 | orders 767 | org 768 | other 769 | others 770 | otherwise 771 | ought 772 | oughtn't 773 | oughtnt 774 | our 
775 | ours 776 | ourselves 777 | out 778 | outside 779 | over 780 | overall 781 | owing 782 | own 783 | p 784 | pa 785 | page 786 | pages 787 | part 788 | parted 789 | particular 790 | particularly 791 | parting 792 | parts 793 | past 794 | pe 795 | per 796 | perhaps 797 | pf 798 | pg 799 | ph 800 | pk 801 | pl 802 | place 803 | placed 804 | places 805 | please 806 | plus 807 | pm 808 | pmid 809 | pn 810 | point 811 | pointed 812 | pointing 813 | points 814 | poorly 815 | possible 816 | possibly 817 | potentially 818 | pp 819 | pr 820 | predominantly 821 | present 822 | presented 823 | presenting 824 | presents 825 | presumably 826 | previously 827 | primarily 828 | probably 829 | problem 830 | problems 831 | promptly 832 | proud 833 | provided 834 | provides 835 | pt 836 | put 837 | puts 838 | pw 839 | py 840 | q 841 | qa 842 | que 843 | quickly 844 | quite 845 | qv 846 | r 847 | ran 848 | rather 849 | rd 850 | re 851 | readily 852 | really 853 | reasonably 854 | recent 855 | recently 856 | ref 857 | refs 858 | regarding 859 | regardless 860 | regards 861 | related 862 | relatively 863 | research 864 | reserved 865 | respectively 866 | resulted 867 | resulting 868 | results 869 | right 870 | ring 871 | ro 872 | room 873 | rooms 874 | round 875 | ru 876 | run 877 | rw 878 | s 879 | sa 880 | said 881 | same 882 | saw 883 | say 884 | saying 885 | says 886 | sb 887 | sc 888 | sd 889 | se 890 | sec 891 | second 892 | secondly 893 | seconds 894 | section 895 | see 896 | seeing 897 | seem 898 | seemed 899 | seeming 900 | seems 901 | seen 902 | sees 903 | self 904 | selves 905 | sensible 906 | sent 907 | serious 908 | seriously 909 | seven 910 | seventy 911 | several 912 | sg 913 | sh 914 | shall 915 | shan't 916 | shant 917 | she 918 | she'd 919 | she'll 920 | she's 921 | shed 922 | shell 923 | shes 924 | should 925 | should've 926 | shouldn 927 | shouldn't 928 | shouldnt 929 | show 930 | showed 931 | showing 932 | shown 933 | showns 934 | shows 935 | si 936 | side 937 | sides 938 | significant 939 | significantly 940 | similar 941 | similarly 942 | since 943 | sincere 944 | site 945 | six 946 | sixty 947 | sj 948 | sk 949 | sl 950 | slightly 951 | sm 952 | small 953 | smaller 954 | smallest 955 | sn 956 | so 957 | some 958 | somebody 959 | someday 960 | somehow 961 | someone 962 | somethan 963 | something 964 | sometime 965 | sometimes 966 | somewhat 967 | somewhere 968 | soon 969 | sorry 970 | specifically 971 | specified 972 | specify 973 | specifying 974 | sr 975 | st 976 | state 977 | states 978 | still 979 | stop 980 | strongly 981 | su 982 | sub 983 | substantially 984 | successfully 985 | such 986 | sufficiently 987 | suggest 988 | sup 989 | sure 990 | sv 991 | sy 992 | system 993 | sz 994 | t 995 | t's 996 | take 997 | taken 998 | taking 999 | tc 1000 | td 1001 | tell 1002 | ten 1003 | tends 1004 | test 1005 | text 1006 | tf 1007 | tg 1008 | th 1009 | than 1010 | thank 1011 | thanks 1012 | thanx 1013 | that 1014 | that'll 1015 | that's 1016 | that've 1017 | thatll 1018 | thats 1019 | thatve 1020 | the 1021 | their 1022 | theirs 1023 | them 1024 | themselves 1025 | then 1026 | thence 1027 | there 1028 | there'd 1029 | there'll 1030 | there're 1031 | there's 1032 | there've 1033 | thereafter 1034 | thereby 1035 | thered 1036 | therefore 1037 | therein 1038 | therell 1039 | thereof 1040 | therere 1041 | theres 1042 | thereto 1043 | thereupon 1044 | thereve 1045 | these 1046 | they 1047 | they'd 1048 | they'll 1049 | they're 1050 | they've 1051 | theyd 1052 | theyll 1053 | theyre 1054 | 
theyve 1055 | thick 1056 | thin 1057 | thing 1058 | things 1059 | think 1060 | thinks 1061 | third 1062 | thirty 1063 | this 1064 | thorough 1065 | thoroughly 1066 | those 1067 | thou 1068 | though 1069 | thoughh 1070 | thought 1071 | thoughts 1072 | thousand 1073 | three 1074 | throug 1075 | through 1076 | throughout 1077 | thru 1078 | thus 1079 | til 1080 | till 1081 | tip 1082 | tis 1083 | tj 1084 | tk 1085 | tm 1086 | tn 1087 | to 1088 | today 1089 | together 1090 | too 1091 | took 1092 | top 1093 | toward 1094 | towards 1095 | tp 1096 | tr 1097 | tried 1098 | tries 1099 | trillion 1100 | truly 1101 | try 1102 | trying 1103 | ts 1104 | tt 1105 | turn 1106 | turned 1107 | turning 1108 | turns 1109 | tv 1110 | tw 1111 | twas 1112 | twelve 1113 | twenty 1114 | twice 1115 | two 1116 | tz 1117 | u 1118 | ua 1119 | ug 1120 | uk 1121 | um 1122 | un 1123 | under 1124 | underneath 1125 | undoing 1126 | unfortunately 1127 | unless 1128 | unlike 1129 | unlikely 1130 | until 1131 | unto 1132 | up 1133 | upon 1134 | ups 1135 | upwards 1136 | us 1137 | use 1138 | used 1139 | useful 1140 | usefully 1141 | usefulness 1142 | uses 1143 | using 1144 | usually 1145 | uucp 1146 | uy 1147 | uz 1148 | v 1149 | va 1150 | value 1151 | various 1152 | vc 1153 | ve 1154 | versus 1155 | very 1156 | vg 1157 | vi 1158 | via 1159 | viz 1160 | vn 1161 | vol 1162 | vols 1163 | vs 1164 | vu 1165 | w 1166 | want 1167 | wanted 1168 | wanting 1169 | wants 1170 | was 1171 | wasn 1172 | wasn't 1173 | wasnt 1174 | way 1175 | ways 1176 | we 1177 | we'd 1178 | we'll 1179 | we're 1180 | we've 1181 | web 1182 | webpage 1183 | website 1184 | wed 1185 | welcome 1186 | well 1187 | wells 1188 | went 1189 | were 1190 | weren 1191 | weren't 1192 | werent 1193 | weve 1194 | wf 1195 | what 1196 | what'd 1197 | what'll 1198 | what's 1199 | what've 1200 | whatever 1201 | whatll 1202 | whats 1203 | whatve 1204 | when 1205 | when'd 1206 | when'll 1207 | when's 1208 | whence 1209 | whenever 1210 | where 1211 | where'd 1212 | where'll 1213 | where's 1214 | whereafter 1215 | whereas 1216 | whereby 1217 | wherein 1218 | wheres 1219 | whereupon 1220 | wherever 1221 | whether 1222 | which 1223 | whichever 1224 | while 1225 | whilst 1226 | whim 1227 | whither 1228 | who 1229 | who'd 1230 | who'll 1231 | who's 1232 | whod 1233 | whoever 1234 | whole 1235 | wholl 1236 | whom 1237 | whomever 1238 | whos 1239 | whose 1240 | why 1241 | why'd 1242 | why'll 1243 | why's 1244 | widely 1245 | width 1246 | will 1247 | willing 1248 | wish 1249 | with 1250 | within 1251 | without 1252 | won 1253 | won't 1254 | wonder 1255 | wont 1256 | words 1257 | work 1258 | worked 1259 | working 1260 | works 1261 | world 1262 | would 1263 | would've 1264 | wouldn 1265 | wouldn't 1266 | wouldnt 1267 | ws 1268 | www 1269 | x 1270 | y 1271 | ye 1272 | year 1273 | years 1274 | yes 1275 | yet 1276 | you 1277 | you'd 1278 | you'll 1279 | you're 1280 | you've 1281 | youd 1282 | youll 1283 | young 1284 | younger 1285 | youngest 1286 | your 1287 | youre 1288 | yours 1289 | yourself 1290 | yourselves 1291 | youve 1292 | yt 1293 | yu 1294 | z 1295 | za 1296 | zero 1297 | zm 1298 | zr -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mask_softmax(x, lengths):#, dim=1) 4 | mask = torch.zeros_like(x).to(device=x.device, non_blocking=True) 5 | t_lengths = lengths[:,:,None].expand_as(mask) 6 | arange_id = 
torch.arange(mask.size(1)).to(device=x.device, non_blocking=True)
7 | arange_id = arange_id[None,:,None].expand_as(mask)
8 | 
9 | mask[arange_id<t_lengths] = 1
196 | if best_val_gen > output_metrics_gen["dist4"]:
197 | pass
198 | else:
199 | best_val_gen = output_metrics_gen["dist4"]
200 | self.model.save_model(model_name= self.opt['save_exp_name'] + '_best_dist4.pkl')
201 | print("Best Dist4 generator model saved once------------------------------------------------")
202 | print("best dist4 is :", best_val_gen)
203 | 
204 | if best_val_rec > output_metrics_gen["recall@50"] + output_metrics_gen["recall@1"]:
205 | pass
206 | else:
207 | best_val_rec = output_metrics_gen["recall@50"] + output_metrics_gen["recall@1"]
208 | self.model.save_model(model_name= self.opt['save_exp_name'] + '_best_Rec.pkl')
209 | print("Best Recall generator model saved once------------------------------------------------")
210 | print("best res_movie_R@1 is :", output_metrics_gen["recall@1"])
211 | print("best res_movie_R@10 is :", output_metrics_gen["recall@10"])
212 | print("best res_movie_R@50 is :", output_metrics_gen["recall@50"])
213 | print('cur selection_loss is %f'%(sum([l[1] for l in losses])/len(losses)))
214 | print('cur Epoch is : ', i)
215 | 
216 | # if i % 5 ==0: # save each 5 epoch
217 | # model_name = self.opt['save_exp_name'] + '_' + str(i) + '.pkl'
218 | # self.model.save_model(model_name=model_name)
219 | # print("generator model saved once------------------------------------------------")
220 | # print('cur selection_loss is %f'%(sum([l[1] for l in losses])/len(losses)))
221 | 
222 | _=self.val(is_test=True)
223 | 
224 | def val(self,is_test=False):
225 | self.metrics_gen={"ppl":0,"dist1":0,"dist2":0,"dist3":0,"dist4":0,"bleu1":0,"bleu2":0,"bleu3":0,"bleu4":0,"count":0,"true_recall_movie_count":0, "res_movie_recall":0.0,"recall@1":0,"recall@10":0,"recall@50":0}
226 | self.metrics_rec={"recall@1":0,"recall@10":0,"recall@50":0,"loss":0,"gate":0,"count":0,'gate_count':0}
227 | # self.model.eval()
228 | val_dataset = dataset('data/test_data.jsonl', self.opt)
229 | val_set=CRSdataset(val_dataset.data_process(True),self.opt['n_entity'],self.opt['n_concept'])
230 | val_dataset_loader = torch.utils.data.DataLoader(dataset=val_set,
231 | batch_size=self.batch_size,
232 | shuffle=False)
233 | 
234 | train_set=CRSdataset(self.train_dataset.data_process(True),self.opt['n_entity'],self.opt['n_concept'])
235 | train_dataset_loader = torch.utils.data.DataLoader(dataset=train_set,
236 | batch_size=self.batch_size,
237 | shuffle=False)
238 | inference_sum=[]
239 | golden_sum=[]
240 | gold_movie_ids=[]
241 | train_gold_movie_ids=[]
242 | train_golden_sum=[]
243 | context_sum=[]
244 | losses=[]
245 | recs=[]
246 | 
247 | match_movie_item = []
248 | for context, c_lengths, response, r_length, mask_response, mask_r_length, entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec,movies_gth,movie_nums in tqdm(val_dataset_loader):
249 | gold_res, cur_gold_movie_ids = self.template_vector2sentence(response.cpu(), movies_gth.cpu())
250 | # golden_sum.extend(gold_res)
251 | gold_movie_ids.extend(cur_gold_movie_ids)
252 | 
253 | for context, c_lengths, response, r_length, mask_response, mask_r_length, entity, entity_vector, movie, concept_mask, dbpedia_mask, concept_vec, db_vec, rec,movies_gth,movie_nums in tqdm(train_dataset_loader):
254 | 
255 | train_gold_res, cur_train_gold_movie_ids = self.template_vector2sentence(response.cpu(), movies_gth.cpu())
256 | # train_golden_sum.extend(train_gold_res)
257 | 
train_gold_movie_ids.extend(cur_train_gold_movie_ids) 258 | 259 | # train_golden_sum.extend(self.template_vector2sentence(response.cpu(), movies_gth.cpu())) 260 | 261 | 262 | for val_movie in set(gold_movie_ids): 263 | if val_movie not in set(train_gold_movie_ids): 264 | match_movie_item.append(val_movie) 265 | print('-'*50) 266 | print(len(set(match_movie_item))) 267 | print('match movie(in test not in train):') 268 | print(set(match_movie_item)) 269 | print('-'*50) 270 | 271 | 272 | 273 | 274 | 275 | def all_response_movie_recall_cal(self,decode_preds, matching_scores,labels): 276 | 277 | # matching_scores is non-mask version [bsz, seq_len, matching_vocab] 278 | # decode_preds [bsz, seq_len] 279 | # labels [bsz, movie_length_with_padding] 280 | # print('decode_preds shape', decode_preds.shape) 281 | # print('matching_scores shape', matching_scores.shape) 282 | # print('labels shape', labels.shape) 283 | decode_preds = decode_preds[:, 1:] # removing the start index 284 | 285 | labels = labels * (labels!=-1) # removing the padding token 286 | 287 | batch_size, seq_len = decode_preds.shape[0], decode_preds.shape[1] 288 | for cur_b in range(batch_size): 289 | for cur_seq_len in range(seq_len): 290 | if decode_preds[cur_b][cur_seq_len] ==6: # word id is 6 291 | _, pred_idx = torch.topk(matching_scores[cur_b][cur_seq_len], k=100, dim=-1) 292 | targets = labels[cur_b] 293 | for target in targets: 294 | self.metrics_gen["recall@1"] += int(target in pred_idx[:1].tolist()) 295 | self.metrics_gen["recall@10"] += int(target in pred_idx[:10].tolist()) 296 | self.metrics_gen["recall@50"] += int(target in pred_idx[:50].tolist()) 297 | 298 | def metrics_cal_gen(self,rec_loss,preds,responses,recs, beam=1): 299 | def bleu_cal(sen1, tar1): 300 | bleu1 = sentence_bleu([tar1], sen1, weights=(1, 0, 0, 0)) 301 | bleu2 = sentence_bleu([tar1], sen1, weights=(0, 1, 0, 0)) 302 | bleu3 = sentence_bleu([tar1], sen1, weights=(0, 0, 1, 0)) 303 | bleu4 = sentence_bleu([tar1], sen1, weights=(0, 0, 0, 1)) 304 | return bleu1, bleu2, bleu3, bleu4 305 | 306 | def response_movie_recall_cal(sen1, tar1): 307 | for word in sen1: 308 | if '@' in word: # if is movie 309 | if word in tar1: # if in gth 310 | return int(1) 311 | else: 312 | return int(0) 313 | return int(0) 314 | 315 | 316 | def distinct_metrics(outs): 317 | # outputs is a list which contains several sentences, each sentence contains several words 318 | unigram_count = 0 319 | bigram_count = 0 320 | trigram_count=0 321 | quagram_count=0 322 | unigram_set = set() 323 | bigram_set = set() 324 | trigram_set=set() 325 | quagram_set=set() 326 | for sen in outs: 327 | for word in sen: 328 | unigram_count += 1 329 | unigram_set.add(word) 330 | for start in range(len(sen) - 1): 331 | bg = str(sen[start]) + ' ' + str(sen[start + 1]) 332 | bigram_count += 1 333 | bigram_set.add(bg) 334 | for start in range(len(sen)-2): 335 | trg=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2]) 336 | trigram_count+=1 337 | trigram_set.add(trg) 338 | for start in range(len(sen)-3): 339 | quag=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2]) + ' ' + str(sen[start + 3]) 340 | quagram_count+=1 341 | quagram_set.add(quag) 342 | dis1 = len(unigram_set) / len(outs)#unigram_count 343 | dis2 = len(bigram_set) / len(outs)#bigram_count 344 | dis3 = len(trigram_set)/len(outs)#trigram_count 345 | dis4 = len(quagram_set)/len(outs)#quagram_count 346 | return dis1, dis2, dis3, dis4 347 | 348 | predict_s=preds 349 | golden_s=responses 350 | #print(rec_loss[0]) 351 
| self.metrics_gen["ppl"]+=sum([exp(ppl) for ppl in rec_loss])/len(rec_loss) 352 | generated=[] 353 | total_movie_gth_response_cnt = 0 354 | have_movie_res_cnt = 0 355 | loop = 0 356 | total_item_response_cnt=0 357 | total_hypo_word_count=0 358 | # for out, tar, rec in zip(predict_s, golden_s, recs): 359 | for out in predict_s: 360 | tar = golden_s[loop // beam] 361 | loop = loop+1 362 | bleu1, bleu2, bleu3, bleu4=bleu_cal(out, tar) 363 | generated.append(out) 364 | self.metrics_gen['bleu1']+=bleu1 365 | self.metrics_gen['bleu2']+=bleu2 366 | self.metrics_gen['bleu3']+=bleu3 367 | self.metrics_gen['bleu4']+=bleu4 368 | self.metrics_gen['count']+=1 369 | self.metrics_gen['true_recall_movie_count']+=response_movie_recall_cal(out, tar) 370 | for word in out: 371 | total_hypo_word_count +=1 372 | if '@' in word: 373 | total_item_response_cnt+=1 374 | 375 | total_target_word_count = 0 376 | for tar in golden_s: 377 | for word in tar: 378 | total_target_word_count +=1 379 | if '@' in word: 380 | total_movie_gth_response_cnt+=1 381 | for word in tar: 382 | if '@' in word: 383 | have_movie_res_cnt+=1 384 | break 385 | 386 | dis1, dis2, dis3, dis4=distinct_metrics(generated) 387 | self.metrics_gen['dist1']=dis1 388 | self.metrics_gen['dist2']=dis2 389 | self.metrics_gen['dist3']=dis3 390 | self.metrics_gen['dist4']=dis4 391 | 392 | self.metrics_gen['res_movie_recall'] = self.metrics_gen['true_recall_movie_count'] / have_movie_res_cnt 393 | self.metrics_gen["recall@1"] = self.metrics_gen["recall@1"] / have_movie_res_cnt 394 | self.metrics_gen["recall@10"] = self.metrics_gen["recall@10"] / have_movie_res_cnt 395 | self.metrics_gen["recall@50"] = self.metrics_gen["recall@50"] / have_movie_res_cnt 396 | print('----------'*10) 397 | print('total_movie_gth_response_cnt: ', total_movie_gth_response_cnt) 398 | print('total_gth_response_cnt: ', len(golden_s)) 399 | print('total_hypo_response_cnt: ', len(predict_s)) 400 | print('hypo item ratio: ', total_item_response_cnt / len(predict_s)) 401 | print('target item ratio: ', total_movie_gth_response_cnt / len(golden_s)) 402 | print('have_movie_res_cnt: ', have_movie_res_cnt) 403 | print('----------'*10) 404 | 405 | def vector2sentence(self,batch_sen): 406 | sentences=[] 407 | for sen in batch_sen.numpy().tolist(): 408 | sentence=[] 409 | for word in sen: 410 | if word>3: 411 | sentence.append(self.index2word[word]) 412 | # if word==6: #if MOVIE token 413 | # sentence.append(self.selection_label2movieID[selection_label]) 414 | elif word==3: 415 | sentence.append('_UNK_') 416 | sentences.append(sentence) 417 | return sentences 418 | 419 | def template_vector2sentence(self,batch_sen, batch_selection_pred): 420 | sentences=[] 421 | movie_ids=[] 422 | all_movie_labels = [] 423 | if batch_selection_pred is not None: 424 | batch_selection_pred = batch_selection_pred * (batch_selection_pred!=-1) 425 | batch_selection_pred = torch.masked_select(batch_selection_pred, (batch_selection_pred!=0)) 426 | for movie in batch_selection_pred.numpy().tolist(): 427 | all_movie_labels.append(movie) 428 | 429 | # print('all_movie_labels:', all_movie_labels) 430 | curr_movie_token = 0 431 | for sen in batch_sen.numpy().tolist(): 432 | sentence=[] 433 | for word in sen: 434 | if word>3: 435 | if word==6: #if MOVIE token 436 | # print('all_movie_labels[curr_movie_token]',all_movie_labels[curr_movie_token]) 437 | # print('selection_label2movieID',self.selection_label2movieID[all_movie_labels[curr_movie_token]]) 438 | 439 | # WAY1: original method 440 | str_movie_id = '@' + 
str(self.selection_label2movieID[all_movie_labels[curr_movie_token]])
441 | int_movie_id = self.selection_label2movieID[all_movie_labels[curr_movie_token]]
442 | sentence.append(str_movie_id)
443 | movie_ids.append(int_movie_id)
444 | 
445 | 
446 | 
447 | # WAY2: print out the movie name, but should comment when calculating the gen metrics
448 | # if self.id2entity[self.selection_label2movieID[all_movie_labels[curr_movie_token]]] is not None:
449 | # sentence.append(self.id2entity[self.selection_label2movieID[all_movie_labels[curr_movie_token]]].split('/')[-1])
450 | # else:
451 | # sentence.append('@' + str(self.selection_label2movieID[all_movie_labels[curr_movie_token]]))
452 | 
453 | 
454 | curr_movie_token +=1
455 | else:
456 | sentence.append(self.index2word[word])
457 | 
458 | elif word==3:
459 | sentence.append('_UNK_')
460 | sentences.append(sentence)
461 | 
462 | # print('[DEBUG]sentence : ')
463 | # print(u' '.join(sentence).encode('utf-8').strip())
464 | 
465 | assert curr_movie_token == len(all_movie_labels)
466 | return sentences, movie_ids
467 | 
468 | @classmethod
469 | def optim_opts(self):
470 | """
471 | Fetch optimizer selection.
472 | 
473 | By default, collects everything in torch.optim, as well as importing:
474 | - qhm / qhadam if installed from github.com/facebookresearch/qhoptim
475 | 
476 | Override this (and probably call super()) to add your own optimizers.
477 | """
478 | # first pull torch.optim in
479 | optims = {k.lower(): v for k, v in optim.__dict__.items()
480 | if not k.startswith('__') and k[0].isupper()}
481 | try:
482 | import apex.optimizers.fused_adam as fused_adam
483 | optims['fused_adam'] = fused_adam.FusedAdam
484 | except ImportError:
485 | pass
486 | 
487 | try:
488 | # https://openreview.net/pdf?id=S1fUpoR5FQ
489 | from qhoptim.pyt import QHM, QHAdam
490 | optims['qhm'] = QHM
491 | optims['qhadam'] = QHAdam
492 | except ImportError:
493 | # no QHM installed
494 | pass
495 | 
496 | return optims
497 | 
498 | def init_optim(self, params, optim_states=None, saved_optim_type=None):
499 | """
500 | Initialize optimizer with model parameters.
501 | 
502 | :param params:
503 | parameters from the model
504 | 
505 | :param optim_states:
506 | optional argument providing states of optimizer to load
507 | 
508 | :param saved_optim_type:
509 | type of optimizer being loaded, if changed will skip loading
510 | optimizer states
511 | """
512 | 
513 | opt = self.opt
514 | 
515 | # set up optimizer args
516 | lr = opt['learningrate']
517 | kwargs = {'lr': lr}
518 | kwargs['amsgrad'] = True
519 | kwargs['betas'] = (0.9, 0.999)
520 | 
521 | optim_class = self.optim_opts()[opt['optimizer']]
522 | self.optimizer = optim_class(params, **kwargs)
523 | 
524 | def backward(self, loss):
525 | """
526 | Perform a backward pass. It is recommended you use this instead of
527 | loss.backward(), for integration with distributed training and FP16
528 | training.
529 | """
530 | loss.backward()
531 | 
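    # --- Editor's sketch (hedged; never called by the training loop): how
    # optim_opts() and init_optim() above fit together. optim_opts()
    # lowercases every optimizer class found in torch.optim (plus the
    # optional apex / qhoptim ones), and init_optim() resolves
    # opt['optimizer'] against that table. 'adam' is a safe key because
    # torch.optim.Adam is always collected; the parameter list below is a
    # stand-in, not the model's real parameters.
    def _optim_lookup_sketch(self):
        optims = self.optim_opts()      # e.g. {'adam': torch.optim.Adam, 'sgd': ...}
        dummy_params = [torch.nn.Parameter(torch.zeros(1))]
        return optims['adam'](dummy_params, lr=1e-3,
                              amsgrad=True, betas=(0.9, 0.999))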
539 | """ 540 | update_freq = 1 541 | if update_freq > 1: 542 | # we're doing gradient accumulation, so we don't only want to step 543 | # every N updates instead 544 | self._number_grad_accum = (self._number_grad_accum + 1) % update_freq 545 | if self._number_grad_accum != 0: 546 | return 547 | 548 | if self.opt['gradient_clip'] > 0: 549 | torch.nn.utils.clip_grad_norm_( 550 | self.model.parameters(), self.opt['gradient_clip'] 551 | ) 552 | 553 | self.optimizer.step() 554 | 555 | def zero_grad(self): 556 | """ 557 | Zero out optimizer. 558 | 559 | It is recommended you call this in train_step. It automatically handles 560 | gradient accumulation if agent is called with --update-freq. 561 | """ 562 | self.optimizer.zero_grad() 563 | 564 | if __name__ == '__main__': 565 | args=setup_args().parse_args() 566 | # import os 567 | # os.environ['CUDA_VISIBLE_DEVICES']=args.gpu 568 | print('CUDA_VISIBLE_DEVICES:', os.environ['CUDA_VISIBLE_DEVICES']) 569 | print(vars(args)) 570 | if args.is_finetune==False: 571 | loop=TrainLoop_fusion_rec(vars(args),is_finetune=False) 572 | # loop.model.load_model('saved_model/net_parameter1_bu.pkl') 573 | loop.train() 574 | else: 575 | loop=TrainLoop_fusion_gen(vars(args),is_finetune=True) 576 | #Tips: should at least load one of the model By Jokie 577 | 578 | #if validation 579 | #WAY1: 580 | # loop.model.load_model('saved_model/matching_linear_model/generation_model_best.pkl') 581 | 582 | #WAY2: 583 | # loop.model.load_model('saved_model/sattn_dialog_model_best.pkl') 584 | # loop.model.load_model('saved_model/generation_model_best.pkl') 585 | # loop.model.load_model('saved_model/generation_model.pkl') 586 | # loop.model.load_model('saved_model/self_attn_generation_model_22.pkl') 587 | 588 | #WAY3: insert 589 | # loop.model.load_model() 590 | loop.model.load_model(args.load_model_pth) 591 | 592 | loop.train() 593 | --------------------------------------------------------------------------------