├── .gitignore ├── LICENSE ├── README.md ├── build-semantic-graphs ├── README.md ├── answer_tag.py ├── build_graph.py ├── build_semantic_graph.py ├── build_tree.py ├── merge.py ├── merge_graph.py ├── preprocess │ ├── README.md │ ├── get_coref_and_dep_data.py │ └── preprocess_raw_data.py ├── prune_and_merge_tree.py ├── rearrange_tree.py └── tag.py ├── evaluate_metrics.py ├── model.jpg ├── scripts ├── preprocess_data.sh ├── train_classifier.sh ├── train_generator.sh └── translate.sh └── src ├── onqg ├── dataset │ ├── Constants.py │ ├── Dataset.py │ ├── Vocab.py │ ├── __init__.py │ └── data_processor.py ├── models │ ├── Decoders.py │ ├── Encoders.py │ ├── Models.py │ └── modules │ │ ├── Attention.py │ │ ├── DecAssist.py │ │ ├── Layers.py │ │ ├── MaxOut.py │ │ └── SubLayers.py └── utils │ ├── mask.py │ ├── model_builder.py │ ├── sinusoid.py │ ├── train │ ├── Loss.py │ ├── Optim.py │ ├── Train.py │ └── __init__.py │ └── translate │ ├── Beam.py │ ├── Translator.py │ └── __init__.py ├── pargs.py ├── preprocess.py ├── train.py ├── translate.py └── xargs.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Sigrid 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Graphs for Generating Deep Questions 2 | 3 | This repository contains code and models for the paper: [Semantic Graphs for Generating Deep Questions (ACL 2020)](https://www.aclweb.org/anthology/2020.acl-main.135/). Below is the framework of our proposed model (on the right) together with an input example (on the left). 4 | 5 | ![Model Framework](model.jpg) 6 | 7 | ## Requirements 8 | 9 | #### Environment 10 | 11 | ``` 12 | allennlp 0.9.0 13 | overrides 3.1.0 14 | allennlp-models 1.0.0 15 | 16 | pytorch 1.4.0 17 | nltk 3.4.4 18 | numpy 1.18.1 19 | tqdm 4.32.2 20 | ``` 21 | 22 | #### Data Preprocessing 23 | 24 | We release [all the datasets below](https://drive.google.com/drive/folders/1uPQaK-cWcbkZapmC3qROkmddC_st5uhv?usp=sharing) which are processed based on [HotpotQA](https://hotpotqa.github.io/). 25 | 26 | 1. get tokenized data files of `documents`, `questions`, `answers` 27 | 28 | * get results in folder [`text-data`](https://drive.google.com/drive/folders/1nhBfk2EvOHGDRq6vPCf8Pk8wZFL0dqbf?usp=sharing) 29 | 30 | 2. prepare the json files ready as illustrated in [`build-semantic-graphs`](https://github.com/YuxiXie/SG-Deep-Question-Generation/tree/master/build-semantic-graphs) 31 | 32 | * get results in folder [`json-data`](https://drive.google.com/drive/folders/10idPzICLR_OhEZHfGnvgZcqAB1x509mE?usp=sharing) 33 | 34 | 3. 
run [`scripts/preprocess_data.sh`](https://github.com/YuxiXie/SG-Deep-Question-Generation/blob/master/scripts/preprocess_data.sh) to get the preprocessed data ready for training 35 | 36 | * get results in folder [`preprocessed-data`](https://drive.google.com/drive/folders/1ZvMRDtb5EeEaylEC-pKSLID0COArJ6Nf?usp=sharing) 37 | 38 | * utilize `glove.840B.300d.txt` from [GloVe](https://nlp.stanford.edu/projects/glove/) to initialize the word embeddings 39 | 40 | #### Models 41 | 42 | We release both the classifier and generator models used in this work. The models are constructed based on a ***sequence-to-sequence*** architecture. Typically, we use ***GRU*** and ***GNN*** in the encoder and ***GRU*** in the decoder; you can choose other methods (*e.g.* ***Transformer***) which have also been implemented in our repository. 43 | 44 | * [classifier](https://drive.google.com/file/d/1X_fdQgQ1yv15e7QCOXkhbWpYLnoT80mH/view?usp=sharing): accuracy - 84.06773% 45 | 46 | * [generator](https://drive.google.com/file/d/1Fck0qVYNnLLz3f815CinRfWFrO2ceIfI/view?usp=sharing): BLEU-4 - 15.28304 47 | 48 | ## Training 49 | 50 | * run [`scripts/train_classifier.sh`](https://github.com/YuxiXie/SG-Deep-Question-Generation/blob/master/scripts/train_classifier.sh) to train on the ***Content Selection*** task 51 | 52 | * run [`scripts/train_generator.sh`](https://github.com/YuxiXie/SG-Deep-Question-Generation/blob/master/scripts/train_generator.sh) to train on the ***Question Generation*** task; the default setting finetunes the pretrained classifier 53 | 54 | ## Translating 55 | 56 | * run [`scripts/translate.sh`](https://github.com/YuxiXie/SG-Deep-Question-Generation/blob/master/scripts/translate.sh) to get predictions on the validation dataset 57 | 58 | ## Evaluating 59 | 60 | We use the [Evaluation codes for MS COCO caption generation](https://github.com/salaniz/pycocoevalcap) to evaluate on automatic metrics. 61 | 62 | - To install pycocoevalcap and the pycocotools dependency, run: 63 | 64 | ``` 65 | pip install git+https://github.com/salaniz/pycocoevalcap 66 | ``` 67 | 68 | - To evaluate the results in the translated file, _e.g._ `prediction.txt`, run: 69 | 70 | ``` 71 | python evaluate_metrics.py prediction.txt 72 | ``` 73 | 74 | ## Citation 75 | ``` 76 | @inproceedings{pan-etal-2020-DQG, 77 | title = {Semantic Graphs for Generating Deep Questions}, 78 | author = {Pan, Liangming and Xie, Yuxi and Feng, Yansong and Chua, Tat-Seng and Kan, Min-Yen}, 79 | booktitle = {Proceedings of the Annual Meeting of the Association for Computational Linguistics (ACL)}, 80 | year = {2020} 81 | } 82 | ``` 83 | -------------------------------------------------------------------------------- /build-semantic-graphs/README.md: -------------------------------------------------------------------------------- 1 | # Method of Building Semantic Graphs (DP-based) 2 | 3 | These files define the rules to build semantic graphs for source text based on the results of **Dependency Parsing**. 4 | The 3 crucial parts of this method are: 5 | 6 | * Extract _[entity, relation, entity]_ patterns from each sentence, similar to **Semantic Role Labeling** 7 | 8 | * Merge and Prune nodes in the graph to address the problems of _(1)_ sparsity brought by fine-grained nodes and _(2)_ graph noise caused by redundant and meaningless punctuation, conjunctions, etc.
9 | 10 | * Connect _SIMILAR_ nodes to help to connect sub-graphs of all sentences and get a unified graph for each evidence 11 | 12 | --- 13 | 14 | * Requirement 15 | 16 | ``` 17 | allennlp==0.9.0 18 | overrides==3.1.0 19 | ``` 20 | 21 | * Predictors for dependency parsing and coreference resolution 22 | 23 | - The links to the predictors in our code may not be up-to-date, you may need to check the availability before running the code. 24 | 25 | --- 26 | 27 | To run the codes, execute the commands below: 28 | 29 | * Get the raw json files from [HotpotQA](https://hotpotqa.github.io/) [`training set` & `dev set(distractor)`] preprocessed 30 | 31 | ```bash 32 | python preprocess/preprocess_raw_data.py train.json valid.json data 33 | ``` 34 | 35 | * Get the results of dependency parsing and coreference resolution 36 | 37 | - To initialize the predictors, you need to download the models of dependency parsing and coreference resolution, _e.g._, the latest models released from [AllenNLP](https://demo.allennlp.org/). 38 | 39 | ```bash 40 | python preprocess/get_coref_and_dep_data.py data.train.json data.valid.json dp.json crf_rsltn.json 41 | ``` 42 | 43 | - Since it will take long time to get these files finished, we provide the final data --- [dp.json](https://drive.google.com/file/d/1hdwS5nC86Jrss7HLQt1eds-RjSZjNJBC/view?usp=sharing) and [crf_rsltn.json](https://drive.google.com/file/d/1U9dNzAmNx1TyQ2BjVBYJ-oJ17Lws0dYE/view?usp=sharing). 44 | 45 | * Merge data file (train or valid) with the result files from **Coreference Resolution** and **Dependency Parsing** 46 | 47 | ```bash 48 | python merge.py data.json dp.json crf_rsltn.json merged_data.json 49 | ``` 50 | 51 | * Build Semantic Graphs with _Question Tags_ (i.e., whether a node contains span(s) in the question) as the groundtruth of **Context Selection** and also _Answer Tags_ (i.e., whether a node contains span(s) in the answer) 52 | 53 | - Here you also need to provide the corresponding tokenized `questions.txt` and `answers.txt` files (cf., [`text-data`](https://drive.google.com/drive/folders/1nhBfk2EvOHGDRq6vPCf8Pk8wZFL0dqbf?usp=sharing)) 54 | 55 | - This script will also generate the corresponding tokenized `source.txt`, so you need to provide the directory to dump the data as well. 
56 | 57 | ```bash 58 | python build_semantic_graph.py merged_data.json questions.txt answers.txt source.txt graph_with_tags.json 59 | ``` 60 | -------------------------------------------------------------------------------- /build-semantic-graphs/answer_tag.py: -------------------------------------------------------------------------------- 1 | import json 2 | import codecs 3 | import sys 4 | import re 5 | import nltk 6 | from tqdm import tqdm 7 | from nltk.corpus import stopwords 8 | from nltk.stem.porter import PorterStemmer 9 | 10 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 11 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 12 | 13 | pattern = re.compile('[\W]*') 14 | 15 | 16 | def ans_tag(answer, corpus): 17 | src = [] 18 | porter_stemmer, stopwords_eng = PorterStemmer(), stopwords.words('english') 19 | for index, sample in tqdm(enumerate(zip(answer, corpus))): 20 | ans, cps = sample[0], sample[1] 21 | ans, nodes, edges = ans.strip().split(), cps['nodes'], cps['edges'] 22 | node_words = [node['word'].split(' ') for node in nodes] 23 | node_stem = [[porter_stemmer.stem(w) for w in node] for node in node_words] 24 | ans = [w for w in ans if len(w) > 0 and w not in stopwords_eng and not pattern.fullmatch(w)] 25 | ans_stem = [porter_stemmer.stem(w) for w in ans] 26 | src.append(cps['text']) 27 | ans_indexes = [[] for _ in ans] 28 | node_contain = [0 for _ in nodes] 29 | for idx, word in enumerate(ans): 30 | for id_node, words in enumerate(node_words): 31 | if ans_stem[idx] in node_stem[id_node] or word in words or any([w.count(word) > 0 for w in words]): 32 | ans_indexes[idx].append(id_node) 33 | node_contain[id_node] += 1 34 | for indexes in ans_indexes: 35 | if len(indexes) > 0: 36 | indexes.sort(key=lambda idx: (node_contain[idx], -len(nodes[idx]['index']))) 37 | nodes[indexes[-1]]['ans'] = 1 38 | for idx in indexes[:-1]: 39 | if edges[idx][indexes[-1]] == 'CHILD': 40 | nodes[idx]['ans'] = 1 41 | for idx, node in enumerate(nodes): 42 | if 'ans' not in node: 43 | nodes[idx]['ans'] = 0 44 | #for idx, node in enumerate(nodes): 45 | #words = node['word'].strip().split() 46 | #flag = any([w in words for w in ans]) 47 | #if flag: 48 | #nodes[idx]['type'] = 'A' 49 | corpus[index]['nodes'] = nodes 50 | return src, corpus 51 | 52 | 53 | if __name__ == '__main__': 54 | answer = sys.argv[1] 55 | corpusf = sys.argv[2] 56 | source = sys.argv[3] 57 | 58 | with open(answer, 'r', encoding='utf-8') as f: 59 | ans = f.read().strip().split('\n') 60 | 61 | corpus = json_load(corpusf) 62 | 63 | src, corpus = ans_tag(answer, corpus) 64 | 65 | with open(source, 'w', encoding='utf-8') as f: 66 | f.write('\n'.join(src) + '\n') 67 | 68 | json_dump(corpus, corpusf) 69 | 70 | -------------------------------------------------------------------------------- /build-semantic-graphs/build_graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this code, we build a graph for each sentence. 
3 | 4 | There are two crucial parts at this step: 5 | 1st Merge nodes which are almost the same with each other as one node if meet requirements 6 | 2nd Redirect to make it sure that: 7 | (1) entity --> predicate <-- entity in each [entity, predicate, entity] triple 8 | (2) those parallel words connect to their real parents 9 | """ 10 | 11 | import sys 12 | from tqdm import tqdm 13 | 14 | import json 15 | import codecs 16 | 17 | 18 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 19 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 20 | 21 | 22 | subj_and_obj = ['nsubj', 'nsubjpass', 'csubj', 'csubjpass'] + ['dobj', 'pobj', 'iobj'] 23 | others_dep = ['poss', 'npadvmod', 'appos', 'nn'] 24 | conj = ['conj', 'cc', 'preconj', 'parataxis'] 25 | verb_pos = ['VBZ', 'VBN', 'VBD', 'VBP', 'VB', 'VBG', 'IN', 'TO', 'PP'] 26 | noun_pos = ['NN', 'NNP', 'NNS', 'NNPS'] 27 | subj = ['nsubj', 'nsubjpass', 'csubj', 'csubjpass'] 28 | 29 | 30 | def count_nodes(nodes, tree): 31 | ## initialize node 32 | str_words = ' '.join(tree['word']) 33 | node = {'type': tree['type'], 'dep': tree['dep'], 'pos':tree['pos'], 'word': str_words, 'index':tree['index']} 34 | ## merge almost the same nodes as one node if meet requirements 35 | for idx, exist in enumerate(nodes): 36 | # requirement one: has common words and has the same type 37 | if all([w in exist['word'].split(' ') for w in node['word'].split(' ')]): 38 | if all([w in node['word'].split(' ') for w in exist['word'].split(' ')]) and node['type'] == exist['type']: 39 | # requirement two: noun-like nodes 40 | if node['pos'] in noun_pos or node['dep'] in subj_and_obj + others_dep: 41 | # requirement three: has upper case and not quite short 42 | if any([w.isupper() for w in node['word']]) and len(node['word'].split(' ')) > 1: 43 | # requirement four: enough high level of overlapping 44 | if len(node['word'].split(' ')) / len(exist['word'].split(' ')) > 0.9: 45 | tree['node_num'] = idx 46 | break 47 | ## added as new node if not meet above requirements 48 | if 'node_num' not in tree: 49 | tree['node_num'] = len(nodes) 50 | nodes.append(node) 51 | ## collect child nodes 52 | if 'noun' in tree: 53 | for child in tree['noun']: 54 | count_nodes(nodes, child) 55 | if 'verb' in tree: 56 | for child in tree['verb']: 57 | count_nodes(nodes, child) 58 | if 'attribute' in tree: 59 | for child in tree['attribute']: 60 | count_nodes(nodes, child) 61 | 62 | 63 | def draw_graph(graph, tree): 64 | index = tree['node_num'] 65 | ## copy existed edges in tree 66 | children = tree['noun'] if 'noun' in tree else [] 67 | children += tree['verb'] if 'verb' in tree else [] 68 | children += tree['attribute'] if 'attribute' in tree else [] 69 | for child in children: 70 | if child['dep'] != 'punct' or 'noun' in child or 'verb' in child or 'attribute' in child: 71 | idx = child['node_num'] 72 | graph[idx][index] = child['dep'] 73 | draw_graph(graph, child) 74 | ## redirect to make it sure that: 75 | # 1. entity --> predicate <-- entity in each [entity, predicate, entity] triple 76 | # 2. 
those parallel words connect to their real parents 77 | if 'noun' in tree and 'verb' in tree: 78 | ## for 'V' type node 79 | if tree['type'] == 'V' or tree['pos'] in verb_pos: 80 | for verb in tree['verb']: 81 | v_idx = verb['node_num'] 82 | is_replace = False # whether to redirect 83 | for noun in tree['noun']: 84 | if noun['dep'] in ['nsubj', 'nsubjpass', 'csubj', 'csubjpass']: 85 | is_replace = True 86 | n_idx = noun['node_num'] 87 | graph[n_idx][v_idx] = noun['dep'] 88 | if is_replace and graph[v_idx][index] in conj: 89 | graph[v_idx][index] = '' 90 | ## for 'A'/'M' type node 91 | else: 92 | verbs = [] 93 | for v in tree['verb']: 94 | # collect verbs to do redirecting with 95 | if v['word'] not in [vrb['word'] for vrb in verbs] or 'noun' in v or 'attribute' in v or 'verb' in v: 96 | verbs.append(v) 97 | tree['verb'] = verbs 98 | for verb in tree['verb']: 99 | v_idx = verb['node_num'] 100 | for noun in tree['noun']: 101 | n_idx = noun['node_num'] 102 | # for parallel words 103 | if graph[n_idx][index] in conj: 104 | if 'verb' not in noun: 105 | graph[v_idx][n_idx] = verb['dep'] 106 | graph[n_idx][index] = '' 107 | else: 108 | nv_idx = noun['verb'][0]['node_num'] 109 | for nn in [tmp for tmp in tree['noun'] if tmp['dep'] in subj]: 110 | graph[nn['node_num']][nv_idx] = nn['dep'] 111 | graph[n_idx][index] = '' 112 | # for [entity, predicate, entity] triple where predicate is corpula 113 | if noun['dep'] in subj_and_obj and noun['dep'] != 'pobj': 114 | graph[n_idx][v_idx] = noun['dep'] 115 | graph[n_idx][index] = '' 116 | 117 | 118 | def get_graph(tree): 119 | ## collect nodes 120 | nodes = [] 121 | count_nodes(nodes, tree) 122 | ## draw edges 123 | edges = [['' for _ in nodes] for _ in nodes] 124 | draw_graph(edges, tree) 125 | 126 | return {'nodes': nodes, 'edges': edges} 127 | 128 | 129 | if __name__ == '__main__': 130 | ##=== load file ===## 131 | tree = json_load(sys.argv[1]) 132 | ##=== build graph ===## 133 | graph = [] 134 | for sample in tqdm(tree, desc=' - (Building Graphs) - '): 135 | evidence = [{'sequence':sent['sequence'], 'graph':get_graph(sent['tree'])} for sent in sample] 136 | graph.append(evidence) 137 | ##=== dump file ===## 138 | json_dump(graph, sys.argv[2]) -------------------------------------------------------------------------------- /build-semantic-graphs/build_semantic_graph.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | 4 | from build_tree import build_tree 5 | from prune_and_merge_tree import prune 6 | from rearrange_tree import rearrange 7 | from build_graph import get_graph 8 | from merge_graph import merge 9 | from tag import text_load, main 10 | from answer_tag import ans_tag 11 | 12 | import json 13 | import codecs 14 | 15 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 16 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 17 | 18 | 19 | if __name__ == '__main__': 20 | data = json_load(sys.argv[1]) 21 | 22 | graphs = [] 23 | for idx, sample in tqdm(enumerate(data), desc=' - (Building Graphs) - '): 24 | corpus = sample['evidence'] 25 | evidence = [] 26 | for sent in corpus: 27 | sent = build_tree(sent) 28 | sent = {'sequence':sent['words'], 'tree':prune(sent['tree'], sent['words'])} 29 | sent = {'sequence': sent['sequence'], 'tree': rearrange(sent['tree'], sent['sequence'])} 30 | evidence.append({'sequence':sent['sequence'], 'graph':get_graph(sent['tree'])}) 31 | graph = merge(evidence) 32 | 
graphs.append(graph) 33 | 34 | questions = text_load(sys.argv[2]) 35 | qu_graphs = main(graphs, questions) 36 | 37 | answers = text_load(sys.argv[3]) 38 | src, final_graphs = ans_tag(answers, qu_graphs) 39 | 40 | with open(sys.argv[4], 'w', encoding='utf-8') as f: 41 | f.write('\n'.join(src) + '\n') 42 | 43 | json_dump(final_graphs, sys.argv[5]) 44 | -------------------------------------------------------------------------------- /build-semantic-graphs/build_tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this code, we build a tree for each sentence based on dependency parsing. 3 | 4 | There are two crucial parts at this step: 5 | 1st classify each node into two types: 6 | 'V' for verbs and 'A' for arguments 7 | 2nd classify the child nodes of each node into three groups: 8 | 'verb' for predicates, 'noun' for subjects/objects and 'attribute' for others 9 | """ 10 | 11 | import sys 12 | from tqdm import tqdm 13 | 14 | import json 15 | import codecs 16 | 17 | 18 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 19 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 20 | 21 | verb_pos = ['VBZ', 'VBN', 'VBD', 'VBP', 'VB', 'VBG'] 22 | prep_pos = ['PP', 'IN', 'TO'] 23 | subj_and_obj = ['nsubj', 'nsubjpass', 'csubj', 'csubjpass'] + ['dobj', 'pobj', 'iobj'] 24 | conj = ['conj', 'parataxis'] 25 | modifier_pos = ['JJ', 'FW', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'] 26 | modifiers = ['amod', 'nn', 'mwe', 'advmod', 'quantmod', 'npadvmod', 'advcl', 'poss', 27 | 'possessive', 'neg', 'auxpass', 'aux', 'det', 'dep', 'predet', 'num'] 28 | 29 | 30 | def merge_node(raw, sequence): 31 | node = {k: v for k, v in raw.items()} 32 | attribute = raw['attribute'] 33 | 34 | attr1, attr2 = [], [] # attr1: ok to merge 35 | indexes = [idx for idx in node['index']] 36 | for a in attribute: 37 | if 'attribute' in a or 'noun' in a or 'verb' in a: 38 | attr2.append(a) 39 | elif (a['dep'] in modifiers or a['pos'] in modifier_pos) and a['pos'] not in prep_pos: 40 | attr1.append(a) 41 | indexes += [idx for idx in a['index']] 42 | else: 43 | attr2.append(a) 44 | 45 | if len(attr1) > 0: 46 | indexes.sort(key=lambda x:x) 47 | flags = [index not in indexes[:idx] for idx, index in enumerate(indexes)] 48 | if len(indexes) == indexes[-1] - indexes[0] + 1 and all(flags): # need to be consecutive modifiers 49 | node['word'] = [sequence[i] for i in indexes] 50 | node['index'] = indexes 51 | if len(attr2) > 0: 52 | node['attribute'] = [a for a in attr2] 53 | else: 54 | del node['attribute'] 55 | 56 | return node 57 | 58 | 59 | def build_detailed_tree(sequence, all_dep, root, word_type): 60 | 61 | def is_noun(node): 62 | return node['dep'] in subj_and_obj or (all_dep[root]['dep'] in subj_and_obj and node['dep'] == 'conj') 63 | 64 | def is_verb(node): 65 | return (node['dep'] == 'cop' and word_type == 'A') or (word_type == 'V' and node['dep'] == 'conj') 66 | ##=== initialize tree-node ===## 67 | element = all_dep[root] 68 | word_type = 'V' if element['pos'] in verb_pos else 'A' 69 | node = {'word': [sequence[root]], 'index': [root], 'type': word_type, 'dep': element['dep'], 'pos': element['pos']} 70 | ##=== classify child node sets ===## 71 | children = [(i, elem) for i, elem in enumerate(all_dep) if elem['head'] == root] 72 | nouns = [child for child in children if is_noun(child[1])] 73 | if len(nouns) > 0: 74 | node['noun'] = [build_detailed_tree(sequence, all_dep, child[0], 'A') for child in nouns] 75 | verbs = 
[child for child in children if is_verb(child[1])] 76 | if len(verbs) > 0: 77 | node['verb'] = [build_detailed_tree(sequence, all_dep, child[0], 'V') for child in verbs] 78 | attributes = [child for child in children if child not in nouns + verbs] 79 | if len(attributes) > 0: 80 | node['attribute'] = [build_detailed_tree(sequence, all_dep, child[0], 'A') for child in attributes] 81 | ##=== do node-merging ===## 82 | if 'attribute' in node: 83 | node = merge_node(node, sequence) 84 | 85 | return node 86 | 87 | 88 | def build_tree(sent): 89 | dep, sequence, title = sent['dependency_parse'], sent['coreference'], sent['title'] 90 | root = [i for i in range(len(dep)) if dep[i]['head'] == -1] 91 | heads_dep = [w['dep'] for w in dep if w['head'] == root[0]] 92 | 93 | word_type = 'V' if dep[root[0]]['pos'] in verb_pos or 'cop' not in heads_dep else 'A' 94 | tree = build_detailed_tree(sequence, dep, root[0], word_type) 95 | 96 | return {'words': sequence, 'tree': tree, 'title': title} 97 | 98 | 99 | if __name__ == '__main__': 100 | ##=== load raw file ===## 101 | data = json_load(sys.argv[1]) 102 | ##=== build trees ===## 103 | tree = [] 104 | for sample in tqdm(data, desc=' - (Building Trees) - '): 105 | answer, corpus = sample['answer'], sample['evidence'] 106 | evidence = [build_tree(sent) for sent in corpus] 107 | tree.append(evidence) 108 | ##=== dump file ===## 109 | json_dump(tree, sys.argv[2]) 110 | -------------------------------------------------------------------------------- /build-semantic-graphs/merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this code, we collect the results of coreference-resolution 3 | and semantic-role-labeling for each sentence. 4 | 5 | PS: we will enlarge the evidence of each sample by add sentence(s) 6 | to it in case that there are no sentences covering some import part 7 | of the question in the existed evidence. 
8 | """ 9 | 10 | import sys 11 | from tqdm import tqdm 12 | 13 | import json 14 | import codecs 15 | 16 | import nltk 17 | from nltk.corpus import stopwords 18 | 19 | 20 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 21 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 22 | 23 | 24 | verb_pos = ['VBZ', 'VBN', 'VBD', 'VBP', 'VB', 'VBG', 'IN', 'TO', 'PP'] 25 | noun_pos = ['NN', 'NNP', 'NNS', 'NNPS'] 26 | other_pos = ['JJ', 'FW', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'] 27 | 28 | 29 | def add_corpus(question, evidence, crf, dep): 30 | stopwords_eng = stopwords.words('english') 31 | 32 | question = question[:-1].strip().split() 33 | pos_tags = nltk.pos_tag(question) 34 | 35 | evidences = [word.lower() for evd in evidence for word in evd['coreference']] 36 | coref_list = [' '.join(evd['coreference']) for evd in evidence] 37 | titles = [evd['title'] for evd in evidence] 38 | 39 | for word, pos in zip(question, pos_tags): 40 | ## requirements: noun-type word; not short; upper case; not stop-word 41 | if pos[1] in noun_pos and len(word) >= 3 and not word.islower() and word.lower() not in stopwords_eng: 42 | if word.lower() not in evidences and all([w.count(word.lower()) == 0 for w in evidences]): 43 | for title in titles: 44 | flags = [word in sent for sent in crf[title]] 45 | if any(flags): 46 | ## found a sentence to add into the evidence 47 | index = flags.index(True) 48 | addlist = crf[title][index] 49 | if ' '.join(addlist) not in coref_list: 50 | coref_list.append(' '.join(addlist)) 51 | evidences += [w.lower() for w in addlist] 52 | evidence.append({'text': coref_list[-1], 'dependency_parse': dep[title][index], 53 | 'coreference': addlist, 'title': title}) 54 | break 55 | 56 | return evidence 57 | 58 | 59 | def merge(raw, dep, crf): 60 | corpus = [] 61 | for sample in tqdm(raw, desc=' - (MERGING) - '): 62 | question, answer = sample['question'], sample['answer'] 63 | evidence_index, evidence_sent = [evd['index'] for evd in sample['evidence']], [evd['text'] for evd in sample['evidence']] 64 | 65 | evidence, cnt = [], 0 66 | for index, sent in zip(evidence_index, evidence_sent): 67 | for i in range(index[1][0], index[1][1]): 68 | dependency = dep[index[0]][i] 69 | dependency = [] if dependency is None else dependency 70 | coref = crf[index[0]][i] 71 | evidence.append({'text':sent, 'dependency_parse':dependency, 'coreference':coref, 'title':index[0]}) 72 | cnt += 1 73 | 74 | sample = {'question':question, 'answer':answer, 'evidence':evidence} 75 | sample['evidence'] = add_corpus(question, evidence, crf, dep) ## enlarge the evidence if needed 76 | 77 | corpus.append(sample) 78 | 79 | return corpus 80 | 81 | 82 | if __name__ == '__main__': 83 | ##=== load files ===### 84 | raw = json_load(sys.argv[1]) 85 | dep = json_load(sys.argv[2]) 86 | crf = json_load(sys.argv[3]) 87 | ##=== merge files ===### 88 | data = merge(raw, dep, crf) 89 | ##=== dump file ===### 90 | json_dump(data, sys.argv[4]) -------------------------------------------------------------------------------- /build-semantic-graphs/merge_graph.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this code, we merge the graphs of the sentences in the same evidence into a unified graph. 
3 | 4 | To connect between subgraphs, we introduce 'SIMILAR' edges 5 | """ 6 | 7 | import sys 8 | from tqdm import tqdm 9 | 10 | import nltk 11 | from nltk.stem.porter import PorterStemmer 12 | from nltk.corpus import stopwords 13 | 14 | import json 15 | import codecs 16 | import re 17 | 18 | 19 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 20 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 21 | 22 | 23 | pattern = re.compile('[\W]*') 24 | verb_pos = ['VBZ', 'VBN', 'VBD', 'VBP', 'VB', 'VBG', 'IN', 'TO', 'PP'] 25 | noun_pos = ['NN', 'NNP', 'NNS', 'NNPS'] 26 | modifier_pos = ['JJ', 'FW', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'] 27 | subj_and_obj = ['nsubj', 'nsubjpass', 'csubj', 'csubjpass'] + ['dobj', 'pobj', 'iobj'] 28 | modifiers = ['amod', 'nn', 'mwe'] 29 | pronouns = ['it', 'its', 'him', 'he', 'his', 'she', 'her', 'hers', 'they', 'them', 'their', 'theirs'] 30 | 31 | 32 | def draw_edge(node_raw_id, final_nodes, raw_nodes, final_edges, raw_edges, 33 | accumulate_node, accumulate_word, reindex_list): 34 | stopwords_eng = stopwords.words('english') 35 | 36 | NDE, raw_node_num = raw_nodes[node_raw_id], len(raw_nodes) 37 | NDE_words = NDE['word'].strip().split(' ') 38 | 39 | final_nodes[accumulate_node + node_raw_id]['index'] = [ii + accumulate_word for i in NDE['index'] for ii in reindex_list[i]] 40 | 41 | ## copy existed edges 42 | for i, edge in enumerate(raw_edges[node_raw_id]): 43 | if i == node_raw_id: 44 | final_edges[accumulate_node + node_raw_id][accumulate_node + node_raw_id] = 'SELF' 45 | elif edge: 46 | final_edges[accumulate_node + node_raw_id][accumulate_node + i] = edge 47 | ## connect 'SIMILAR' nodes 48 | for i, node in enumerate(final_nodes[accumulate_node + raw_node_num: ]): 49 | i_idx = i + accumulate_node + raw_node_num 50 | words = node['word'].strip().split(' ') 51 | # get 'important' word list 52 | word_i = [w for w in words if len(w) > 0 and w not in stopwords_eng and not pattern.fullmatch(w) and w.lower() not in pronouns] 53 | word_j = [w for w in NDE_words if len(w) > 0 and w not in stopwords_eng and not pattern.fullmatch(w) and w.lower() not in pronouns] 54 | # get common 'important' words 55 | common1 = [w for w in word_i if w in word_j or w.lower() in [ww.lower() for ww in word_j]] 56 | common2 = [w for w in word_j if w in word_i or w.lower() in [ww.lower() for ww in word_i]] 57 | common = common1 if len(common1) < len(common2) else common2 58 | # whether have noun-or-modifier-like words in common 59 | mono_pos = nltk.pos_tag(common) 60 | flag_pos = any([mono[1] in noun_pos + modifier_pos for mono in mono_pos]) 61 | # whether have upper-case words in common 62 | flag_up = any([not w.islower() for w in common]) 63 | # whether is of the same kind of pos&dep tag 64 | pos_qualify = NDE['pos'] in verb_pos + noun_pos and node['pos'] in verb_pos + noun_pos 65 | dep_qualify = NDE['dep'] in subj_and_obj + modifiers and node['dep'] in subj_and_obj + modifiers 66 | 67 | if pos_qualify or dep_qualify: 68 | if (flag_up or flag_pos) and len(word_i) * len(word_j) > 0: 69 | prb1, prb2 = len(common1) / len(word_i), len(common2) / len(word_j) 70 | if max(prb1, prb2) > 1/2 and min(prb1, prb2) > 1/3: # requirement of the lavel of overlapping 71 | final_edges[accumulate_node + node_raw_id][i_idx] = final_edges[i_idx][accumulate_node + node_raw_id] = 'SIMILAR' 72 | 73 | 74 | def merge(corpus): 75 | 76 | def reindex(sequence): 77 | '''do reindexing 78 | because we have coreference resolution, which means 79 | there 
may be more than one words in the so-called 'one' word in fact 80 | ''' 81 | cnt, new_seq = 0, [] 82 | dicts = [[] for _ in sequence] 83 | for i, w in enumerate(sequence): 84 | wrd_cnt = max(len(w.strip().split(' ')), 1) 85 | dicts[i] = [i for i in range(cnt, cnt + wrd_cnt)] 86 | cnt += wrd_cnt 87 | new_seq += w.strip().split(' ') 88 | length = len(new_seq) 89 | return dicts, length, new_seq 90 | 91 | ## initialize node and edge list 92 | sequences, subgraphs = [sent['sequence'] for sent in corpus], [sent['graph'] for sent in corpus] 93 | 94 | final_nodes = [node for subgraph in subgraphs for node in subgraph['nodes']] 95 | nodes_num = len(final_nodes) 96 | final_edges = [['' for _ in range(nodes_num)] for _ in range(nodes_num)] 97 | 98 | ## merge subgraphs into final graph 99 | word_cnt, node_cnt, new_sequences = 0, 0, [] 100 | for sequence, subgraph in zip(sequences, subgraphs): 101 | indexes, length, new_seq = reindex(sequence) 102 | nodes, edges = subgraph['nodes'], subgraph['edges'] 103 | for idx, node in enumerate(nodes): 104 | draw_edge(node_raw_id=idx, final_nodes=final_nodes, raw_nodes=nodes, 105 | final_edges=final_edges, raw_edges=edges, accumulate_node=node_cnt, 106 | accumulate_word=word_cnt, reindex_list=indexes) 107 | 108 | node_cnt += len(nodes) 109 | word_cnt += length 110 | new_sequences += new_seq 111 | 112 | sequences = ' '.join(new_sequences) 113 | 114 | return {'nodes':final_nodes, 'edges':final_edges, 'text':sequences} 115 | 116 | 117 | if __name__ == '__main__': 118 | ##=== load file ===## 119 | raw = json_load(sys.argv[1]) 120 | ##=== merge graphs ===## 121 | graphs = [] 122 | for sample in tqdm(raw, desc=' - (Building Graphs) - '): 123 | graph = merge(sample) 124 | graphs.append(graph) 125 | ##=== dump file ===## 126 | json_dump(graphs, sys.argv[2]) -------------------------------------------------------------------------------- /build-semantic-graphs/preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Preprocessing the Raw HotpotQA Json Files 2 | 3 | Also see how to run the codes in the [previous directory](https://github.com/YuxiXie/SG-Deep-Question-Generation/tree/master/build-semantic-graphs). 4 | 5 | * get `data.json` files for training and validation 6 | 7 | - Download the raw json files from [HotpotQA](https://hotpotqa.github.io/) [`training set` & `dev set(distractor)`] then run: 8 | 9 | ```bash 10 | python preprocess_raw_data.py train.json valid.json data 11 | ``` 12 | 13 | * Get the results of dependency parsing and coreference resolution 14 | 15 | - You need to first download the model files (the links for the old models we used may change), or you could use the latest models released from [AllenNLP](https://demo.allennlp.org/). 16 | 17 | ```bash 18 | python get_coref_and_dep_data.py data.train.json data.valid.json dp.json crf_rsltn.json 19 | ``` 20 | 21 | - Since it will take long time to get these files finished, we provide the final data --- [dp.json](https://drive.google.com/file/d/1KnZXqchvHqMZnTh_7tuE57cd934aMBIF/view?usp=sharing) and [crf_rsltn.json](https://drive.google.com/file/d/1I8xTvhkEXpiq4D25Dr7XRUIoe779Ytve/view?usp=sharing). 
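- If the S3 links hard-coded in `get_coref_and_dep_data.py` are no longer reachable, the predictors can also be pointed at locally downloaded archives. Below is a minimal sketch; the archive file names are placeholders for whatever model archives you download, not files shipped with this repository, while the `Predictor.from_path` / `predict` calls mirror the ones used in `get_coref_and_dep_data.py`.

```python
from allennlp.predictors.predictor import Predictor

# Placeholder paths to locally downloaded model archives -- adjust to your setup.
DEP_ARCHIVE = "biaffine-dependency-parser-ptb-2018.08.23.tar.gz"
COREF_ARCHIVE = "coref-model-2018.02.05.tar.gz"

dependency_parser = Predictor.from_path(DEP_ARCHIVE)
coref_resolver = Predictor.from_path(COREF_ARCHIVE)

# Same call signatures as in get_coref_and_dep_data.py.
dep = dependency_parser.predict(sentence="Mark Twain wrote The Adventures of Tom Sawyer .")
print(dep["predicted_heads"], dep["predicted_dependencies"])

# Sentences in one document are joined with tab characters, as in coreference_resolution().
crf = coref_resolver.predict(document="Mark Twain wrote the novel .\tHe published it in 1876 .")
print(crf["clusters"])
```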
22 | -------------------------------------------------------------------------------- /build-semantic-graphs/preprocess/get_coref_and_dep_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import codecs 3 | import json 4 | from tqdm import tqdm 5 | from allennlp.predictors.predictor import Predictor 6 | from nltk import word_tokenize 7 | 8 | 9 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 10 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 11 | 12 | 13 | def get_context(data): 14 | context = {} 15 | for sample in tqdm(data): 16 | for k, v in sample['context'].items(): 17 | context[k] = v 18 | return context 19 | 20 | 21 | def get_dependency(sent, dependency_parser): 22 | if len(sent.strip()) == 0: 23 | return None 24 | try: 25 | sent = dependency_parser.predict(sentence=sent) 26 | except: 27 | import ipdb; ipdb.set_trace() 28 | words, pos, heads, dependencies = sent['words'], sent['pos'], sent['predicted_heads'], sent['predicted_dependencies'] 29 | result = [{'word':w, 'pos':p, 'head':h - 1, 'dep':d} for w, p, h, d in zip(words, pos, heads, dependencies)] 30 | return result 31 | 32 | 33 | def dependency_parse(raw, filename): 34 | dependency_parser = Predictor.from_path("https://s3-us-west-2.amazonaws.com/allennlp/models/biaffine-dependency-parser-ptb-2018.08.23.tar.gz") 35 | context = { 36 | key: [ 37 | get_dependency(sent, dependency_parser) for sent in value 38 | ] for key, value in tqdm(raw.items(), desc=' - (Dependency Parsing: 1st) - ') 39 | } 40 | json_dump(context, filename) 41 | 42 | 43 | def get_coreference(doc, coref_reslt, pronouns, title): 44 | 45 | def get_crf(span, words): 46 | phrase = [] 47 | for i in range(span[0], span[1] + 1): 48 | phrase += [words[i]] 49 | return (' '.join(phrase), span[0], span[1] - span[0] + 1) 50 | 51 | def get_best(crf): 52 | crf.sort(key=lambda x: x[2], reverse=True) 53 | if crf[0][2] == 1: 54 | crf.sort(key=lambda x: len(x[0]), reverse=True) 55 | for w in crf: 56 | if w[0].lower() not in pronouns and w[0].lower() != '\t': 57 | return w[0] 58 | return None 59 | 60 | doc = coref_reslt.predict(document=doc) 61 | words = [w.strip(' ') for w in doc['document']] 62 | clusters = doc['clusters'] 63 | 64 | for group in clusters: 65 | crf = [get_crf(span, words) for span in group] 66 | entity = get_best(crf) 67 | if entity in ['\t', None]: 68 | try: 69 | entity = coref_reslt.predict(document=title) 70 | entity = ' '.join(entity['document']) 71 | except: 72 | entity = ' '.join(word_tokenize(title)) 73 | if entity not in ['\t', None]: 74 | for phrase in crf: 75 | if phrase[0].lower() in pronouns: 76 | index = phrase[1] 77 | words[index] = entity 78 | 79 | doc, sent = [], [] 80 | for word in words: 81 | if word.strip(' ') == '\t': 82 | doc.append(sent) 83 | sent = [] 84 | else: 85 | if word.count('\t'): 86 | print(word) 87 | word = word.strip('\t') 88 | sent.append(word) 89 | doc.append(sent) 90 | return doc 91 | 92 | 93 | def coreference_resolution(raw, filename): 94 | pronouns = ['it', 'its', 'he', 'him', 'his', 'she', 'her', 'they', 'their', 'them'] 95 | raw = {k: '\t'.join(v) for k,v in raw.items()} 96 | coref_reslt = Predictor.from_path("https://s3-us-west-2.amazonaws.com/allennlp/models/coref-model-2018.02.05.tar.gz") 97 | context = { 98 | key: get_coreference(value, coref_reslt, pronouns, key) for key, value in tqdm(raw.items(), desc=' - (crf for evidence) ') 99 | } 100 | json_dump(context, filename) 101 | 102 | 103 | 
def get_ner(doc, ner_tagger): 104 | try: 105 | doc = ner_tagger.predict(sentence=doc) 106 | except: 107 | return [[doc, 'O']] 108 | words, tags = doc['words'], doc['tags'] 109 | return [[w, t] for w, t in zip(words, tags)] 110 | 111 | 112 | def ner_tag(raw, filename): 113 | ner_tagger = Predictor.from_path("https://s3-us-west-2.amazonaws.com/allennlp/models/ner-model-2018.12.18.tar.gz") 114 | raw = [[d[0], d[0]] for d in raw] #raw = [[d[0], '\t'.join(d[1])] for d in raw] 115 | context = {sample[0]: get_ner(sample[1], ner_tagger) for sample in tqdm(raw, desc=' - (ner for evidence) ')} 116 | json_dump(context, filename) 117 | 118 | 119 | def sr_labeling(sent, sr_labeler): 120 | if len(sent.strip()) == 0: 121 | return None 122 | try: 123 | sent = sr_labeler.predict(sentence=sent) 124 | except: 125 | import ipdb; ipdb.set_trace() 126 | length, words, verbs = len(sent['words']), sent['words'], sent['verbs'] 127 | tags = [verb['tags'] for verb in verbs] 128 | return {'words':words, 'tags':tags} 129 | 130 | 131 | def semantic_role_labeling(raw, filename): 132 | sr_labeler = Predictor.from_path("https://s3-us-west-2.amazonaws.com/allennlp/models/srl-model-2018.05.25.tar.gz") 133 | context = {sample[0]: [sr_labeling(sent, sr_labeler) for sent in sample[1]] for sample in tqdm(raw, desc=' - (Semantic Role Labeling: 1st) - ')} 134 | json_dump(context, filename) 135 | 136 | 137 | if __name__ == '__main__': 138 | train_data_file, valid_data_file = sys.argv[1], sys.argv[2] 139 | data = json_load(train_data_file) + json_load(valid_data_file) 140 | context = get_context(data) 141 | print('number of context:', len(context)) 142 | 143 | dependency_parse(context, sys.argv[3]) 144 | coreference_resolution(context, sys.argv[4]) 145 | 146 | # ner_tag(context, sys.argv[3]) 147 | # semantic_role_labeling(context, sys.argv[3]) 148 | -------------------------------------------------------------------------------- /build-semantic-graphs/preprocess/preprocess_raw_data.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | 4 | import json 5 | import codecs 6 | 7 | 8 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 9 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 10 | 11 | 12 | def add_evidence(key, dictionary, evidence, evidence_index, context): 13 | if key[0] in dictionary: 14 | try: 15 | evidence.append([dictionary[key[0]][key[1]]]) 16 | evidence_index.append([key[0], (key[1], key[1] + 1)]) 17 | except: 18 | print("ERROR 1") 19 | evidence.append(dictionary[key[0]]) 20 | evidence_index.append([key[0], (0, len(dictionary[key[0]]))]) 21 | context[key[0]] = dictionary[key[0]] 22 | else: 23 | print("ERROR 2") 24 | flags = [k.count(key[0]) for k in dictionary] 25 | if any(flags): 26 | index = flags.index(True) 27 | key[0] = list(dictionary.keys()) 28 | key[0] = key[0][index] 29 | try: 30 | evidence.append([dictionary[key[0]][key[1]]]) 31 | evidence_index.append([key[0], (key[1], key[1] + 1)]) 32 | except: 33 | print("ERROR 4") 34 | evidence.append(dictionary[key[0]]) 35 | evidence_index.append([key[0], (0, len(dictionary[key[0]]))]) 36 | context[key[0]] = dictionary[key[0]] 37 | else: 38 | print("ERROR 3") 39 | return False 40 | return True 41 | 42 | 43 | def extract(data): 44 | paragraphs = [sample['context'] for sample in data] 45 | paragraphs = {c[0]:c[1] for sample in paragraphs for c in sample} 46 | 47 | corpus = [] 48 | for sample in tqdm(data, desc=' - (Extract 
information) - '): 49 | context = {c[0]:c[1] for c in sample['context']} 50 | supporting_facts = sample['supporting_facts'] 51 | answer = sample['answer'] 52 | question = sample['question'] 53 | 54 | evidence, evidence_index = [], [] 55 | ctxt = {} 56 | 57 | for sf in supporting_facts: 58 | if not add_evidence(sf, context, evidence, evidence_index, ctxt): 59 | add_evidence(sf, paragraphs, evidence, evidence_index, ctxt) 60 | 61 | if len(evidence) > 0: 62 | evidence = [{'text':evd, 'index':idx} for evd, idx in zip(evidence, evidence_index)] 63 | sample = {'question':question, 'answer': answer, 'evidence':evidence, 'context':ctxt} 64 | corpus.append(sample) 65 | 66 | return corpus 67 | 68 | 69 | def overlap(corpus): 70 | train, valid = [], [] 71 | questions, sources = [], [] 72 | for sample in tqdm(corpus['train'], desc=' - (Deal with overlapping) - '): 73 | if sample['question'] not in questions: 74 | questions.append(sample['question']) 75 | train.append(sample) 76 | try: 77 | sources.append('\t'.join(['\t'.join(sent['text']) for sent in sample['evidence']])) 78 | except: 79 | import ipdb; ipdb.set_trace() 80 | for sample in tqdm(corpus['valid'], desc=' - (Deal with overlapping) - '): 81 | tmp = '\t'.join(['\t'.join(sent['text']) for sent in sample['evidence']]) 82 | if tmp not in sources: 83 | valid.append(sample) 84 | print(len(train), len(valid)) 85 | return {'train':train, 'valid':valid} 86 | 87 | 88 | def process(train, valid): 89 | corpus = {'train':extract(train), 'valid':extract(valid)} 90 | corpus = overlap(corpus) 91 | return corpus 92 | 93 | 94 | if __name__ == '__main__': 95 | ##=== load raw HotpotQA dataset ===## 96 | train = json_load(sys.argv[1]) 97 | valid = json_load(sys.argv[2]) 98 | ##=== run processing ===## 99 | corpus = process(train, valid) 100 | ##=== directory for saving train & valid data ===## 101 | json_dump(corpus['train'], sys.argv[3] + '.train.json') 102 | json_dump(corpus['valid'], sys.argv[3] + '.valid.json') 103 | -------------------------------------------------------------------------------- /build-semantic-graphs/prune_and_merge_tree.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this code, we do node pruning and merging in each tree. 3 | 4 | 1st prune: we prune nodes which represent unimportant information 5 | e.g. 
punctuation, conjunction 6 | 7 | 2nd merge: merge modifier-child nodes into their parent nodes 8 | """ 9 | 10 | import sys 11 | from tqdm import tqdm 12 | 13 | import json 14 | import codecs 15 | 16 | import re 17 | from nltk.corpus import stopwords 18 | 19 | 20 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 21 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 22 | 23 | 24 | prep_pos = ['PP', 'IN', 'TO'] 25 | modefier_pos = ['JJ', 'FW', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'] 26 | modifiers = ['amod', 'nn', 'mwe', 'num', 'quantmod', 'dep', 'number', 'auxpass', 'partmod', 'poss', 27 | 'possessive', 'neg', 'advmod', 'npadvmod', 'advcl', 'aux', 'det', 'predet', 'appos'] 28 | prune_list = ['punct', 'cc', 'preconj'] 29 | 30 | 31 | def merge(node, sequence): 32 | indexes = [idx for idx in node['index']] 33 | 34 | if 'attribute' in node and node['dep'] not in ['det', 'punct']: 35 | attr1, attr2 = [], [] 36 | for a in node['attribute']: 37 | if ('attribute' in a) or ('noun' in a) or ('verb' in a): 38 | attr2.append(a) 39 | elif a['dep'] in modifiers or a['pos'] in modefier_pos: 40 | attr1.append(a) 41 | indexes += [idx for idx in a['index']] 42 | else: 43 | attr2.append(a) 44 | 45 | if len(attr1) > 0: 46 | indexes.sort(key=lambda x:x, reverse=False) 47 | node['word'], node['index'] = [sequence[i] for i in indexes], indexes 48 | if len(attr2) > 0: 49 | node['attribute'] = [a for a in attr2] 50 | else: 51 | del node['attribute'] 52 | 53 | return node 54 | 55 | 56 | def prune(node, sequence): 57 | ## collect child nodes 58 | nouns = node['noun'] if 'noun' in node else [] 59 | verbs = node['verb'] if 'verb' in node else [] 60 | attributes = node['attribute'] if 'attribute' in node else [] 61 | ## prune and update child node sets 62 | Ns, Vs, As = [], [], [] 63 | for child in nouns + verbs + attributes: 64 | if child['pos'] not in prep_pos and child['dep'] in prune_list: 65 | Ns += child['noun'] if 'noun' in child else [] 66 | Vs += child['verb'] if 'verb' in child else [] 67 | As += child['attribute'] if 'attribute' in child else [] 68 | else: 69 | Ns += [child] if child in nouns else [] 70 | Vs += [child] if child in verbs else [] 71 | As += [child] if child in attributes else [] 72 | ## do pruning and merging on child nodes 73 | Ns = [prune(n, sequence) for n in Ns] 74 | Vs = [prune(v, sequence) for v in Vs] 75 | As = [prune(a, sequence) for a in As] 76 | ## do merging 77 | slf = {k:v for k,v in node.items() if k not in ['noun', 'verb', 'attribute']} 78 | if As: 79 | slf['attribute'] = As 80 | slf = merge(slf, sequence) 81 | ## get final node 82 | wrap = {'dep':slf['dep'], 'word':slf['word'], 'index':slf['index'], 'pos':slf['pos'], 83 | 'type':slf['type'], 'noun':Ns, 'verb':Vs} 84 | if 'attribute' in slf: 85 | wrap['attribute'] = slf['attribute'] 86 | if not Ns: 87 | del wrap['noun'] 88 | if not Vs: 89 | del wrap['verb'] 90 | return wrap 91 | 92 | 93 | if __name__ == '__main__': 94 | ##=== load raw file ===## 95 | raw = json_load(sys.argv[1]) 96 | ##=== prune and merge trees ===## 97 | tree = [] 98 | for sample in tqdm(raw, desc=' - (Pruning Trees) - '): 99 | evidence = [{'sequence':sent['words'], 'graph':prune(sent['tree'], sent['words'])} for sent in sample] 100 | tree.append(evidence) 101 | ##=== dump file ===## 102 | json_dump(tree, sys.argv[2]) 103 | -------------------------------------------------------------------------------- /build-semantic-graphs/rearrange_tree.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | 4 | import json 5 | import codecs 6 | 7 | 8 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 9 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 10 | 11 | subj_and_obj = ['nsubj', 'nsubjpass', 'csubj', 'csubjpass'] + ['dobj', 'pobj', 'iobj'] 12 | months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December'] 13 | 14 | 15 | def merge_month(node, sequence): 16 | indexes = [idx for idx in node['index']] 17 | 18 | if 'attribute' in node: 19 | for a in node['attribute']: 20 | index, _ = merge_month(a, sequence) 21 | indexes += index 22 | 23 | indexes.sort(key=lambda x:x) 24 | words = [sequence[i] for i in indexes] 25 | 26 | return indexes, words 27 | 28 | 29 | def rearrange(node, sequence): 30 | ## collect child node sets and do tree-rearranging on child nodes 31 | noun, verb = None, None 32 | slf = {k:v for k,v in node.items() if k not in ['noun', 'verb']} 33 | noun = [rearrange(n, sequence) for n in node['noun']] if 'noun' in node else [] 34 | verb = [rearrange(v, sequence) for v in node['verb']] if 'verb' in node else [] 35 | if 'attribute' in node: 36 | slf['attribute'] = [rearrange(a, sequence) for a in node['attribute']] 37 | ## redirect grandchild nodes to the current node 38 | ## rule: redirect the parallel words to their real parents 39 | if noun and node['type'] != 'V': # those nodes of type 'V' will be rearranged later 40 | for id_n in range(len(noun)): 41 | if noun[id_n]['dep'] in subj_and_obj and 'verb' not in noun[id_n] and 'noun' in noun[id_n]: 42 | new_nouns, rearrg_nouns = [], [] 43 | dep_list = [n['dep'] == 'conj' for n in noun[id_n]['noun']] # whether a parallel word 44 | for i, grandchild in enumerate(noun[id_n]['noun']): 45 | if dep_list[i]: 46 | grandchild['dep'] = noun[id_n]['dep'] 47 | rearrg_nouns.append(grandchild) 48 | else: 49 | new_nouns.append(grandchild) 50 | if len(new_nouns) > 0: 51 | noun[id_n]['noun'] = new_nouns 52 | else: 53 | del noun[id_n]['noun'] 54 | noun += rearrg_nouns 55 | ## merge preposition and its only child node (pobject) as one node [node type = 'M' (modifier)] 56 | if noun and (not verb) and ('attribute' not in node): 57 | if len(noun) == 1 and node['dep'] == 'prep' and noun[0]['dep'] in subj_and_obj: 58 | if ('noun' not in noun[0]) and ('verb' not in noun[0]): 59 | indexes = node['index'] + noun[0]['index'] 60 | indexes.sort(key=lambda x: x) 61 | wrap = {'dep':noun[0]['dep'], 'word':[sequence[i] for i in indexes], 'index':indexes, 62 | 'pos':noun[0]['pos'], 'type': 'M'} 63 | if 'attribute' in noun[0]: 64 | wrap['attribute'] = noun[0]['attribute'] 65 | return wrap 66 | ## if has more than one nodes, do redirecting 67 | elif 'verb' not in noun[0] and 'noun' in noun[0]: 68 | new_nouns, gg_nouns = [], [] 69 | dep_list = [n['dep'] == 'conj' for n in noun[0]['noun']] 70 | for i, grandchild in enumerate(noun[0]['noun']): 71 | if dep_list[i]: 72 | grandchild['dep'] = noun[0]['dep'] 73 | gg_nouns.append(grandchild) 74 | else: 75 | new_nouns.append(grandchild) 76 | if len(new_nouns) > 0: 77 | noun[0]['noun'] = new_nouns 78 | else: 79 | del noun[0]['noun'] 80 | noun += gg_nouns 81 | ## for node which represents time/date (i.e., contain month word), 82 | # merge it with all its attribute child nodes 83 | if 'attribute' in slf and any([w in months for w in slf['word']]): 84 | slf['index'], 
slf['word'] = merge_month(slf, sequence) 85 | del slf['attribute'] 86 | ## get final node 87 | wrap = {'dep':slf['dep'], 'word':slf['word'], 'index':slf['index'], 'pos':slf['pos'], 88 | 'type':slf['type'], 'noun':noun, 'verb':verb} 89 | if 'attribute' in slf: 90 | wrap['attribute'] = slf['attribute'] 91 | if not noun: 92 | del wrap['noun'] 93 | if not verb: 94 | del wrap['verb'] 95 | return wrap 96 | 97 | 98 | if __name__ == '__main__': 99 | ##=== load file ===## 100 | raw = json_load(sys.argv[1]) 101 | ##=== rearrange trees ===## 102 | graph = [] 103 | for sample in tqdm(raw, desc=' - (Merging Trees) - '): 104 | evidence = [{'sequence': sent['sequence'], 'tree': rearrange(sent['graph'], sent['sequence'])} for sent in sample] 105 | graph.append(evidence) 106 | ##=== dump file ===## 107 | json_dump(graph, sys.argv[2]) 108 | -------------------------------------------------------------------------------- /build-semantic-graphs/tag.py: -------------------------------------------------------------------------------- 1 | """Question Tagging 2 | Generate groudtruth for Context Selection 3 | """ 4 | 5 | import sys 6 | from tqdm import tqdm 7 | 8 | import json 9 | import codecs 10 | 11 | import nltk 12 | from nltk.stem.porter import PorterStemmer 13 | from nltk.corpus import stopwords 14 | 15 | import re 16 | 17 | json_load = lambda x: json.load(codecs.open(x, 'r', encoding='utf-8')) 18 | json_dump = lambda d, p: json.dump(d, codecs.open(p, 'w', 'utf-8'), indent=2, ensure_ascii=False) 19 | 20 | 21 | pattern = re.compile('[\W]*') 22 | verb_pos = ['VBZ', 'VBN', 'VBD', 'VBP', 'VB', 'VBG', 'IN', 'TO', 'PP'] 23 | noun_pos = ['NN', 'NNP', 'NNS', 'NNPS'] 24 | modifier_pos = ['JJ', 'FW', 'JJR', 'JJS', 'RB', 'RBR', 'RBS'] 25 | 26 | 27 | def text_load(filename): 28 | with open(filename, 'r', encoding='utf-8') as f: 29 | data = f.read().strip().split('\n') 30 | return data 31 | 32 | 33 | def tag(nodes, edges, question): 34 | porter_stemmer, stopwords_eng = PorterStemmer(), stopwords.words('english') 35 | ## get words and word-stems 36 | words, question = [node['word'].split(' ') for node in nodes], [w for w in question.split(' ') if len(w) > 0] 37 | node_stem, question_stem = [[porter_stemmer.stem(w) for w in sent] for sent in words], [porter_stemmer.stem(w) for w in question] 38 | ## search for the node list covering the question by each word 39 | question_index, node_contain = [[] for _ in question], [0 for _ in nodes] 40 | for index, word in enumerate(question): 41 | if len(word) > 0 and word not in stopwords_eng and not pattern.fullmatch(word): 42 | for idx, node_words in enumerate(words): 43 | if question_stem[index] in node_stem[idx] or word in node_words or any([w.count(word) > 0 for w in node_words]): 44 | question_index[index].append(idx) 45 | node_contain[idx] += 1 46 | ## tag the node which covers more in the question 47 | for index in question_index: 48 | if len(index) > 0: 49 | index.sort(key=lambda idx: (node_contain[idx], -len(nodes[idx]['index']))) 50 | nodes[index[-1]]['tag'] = 1 51 | ## tag the node which has 'SIMILAR' edge 52 | # (we assume this kind of nodes are important for asking questions) 53 | for index, node in enumerate(nodes): 54 | if 'tag' not in node: 55 | nodes[index]['tag'] = 1 if 'SIMILAR' in edges[index] else 0 56 | 57 | return nodes 58 | 59 | 60 | def main(raw, questions): 61 | for idx, sample in tqdm(enumerate(zip(raw, questions)), desc=' (TAGGING) '): 62 | sample, question = sample[0], sample[1] 63 | nodes, edges = sample['nodes'], sample['edges'] 64 | raw[idx]['nodes'] = 
tag(nodes, edges, question) 65 | return raw 66 | 67 | 68 | if __name__ == '__main__': 69 | ##=== load files ===## 70 | data = json_load(sys.argv[1]) 71 | questions = text_load(sys.argv[2]) 72 | ##=== tagging ===## 73 | new = main(data, questions) 74 | ##=== dump file ===## 75 | json_dump(new, sys.argv[3]) -------------------------------------------------------------------------------- /evaluate_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from pprint import pprint 5 | 6 | from pycocoevalcap.bleu.bleu import Bleu 7 | from pycocoevalcap.meteor.meteor import Meteor 8 | from pycocoevalcap.rouge.rouge import Rouge 9 | 10 | 11 | def text_load(filename): 12 | with open(filename, 'r', encoding='utf-8') as f: 13 | data = f.read().strip().strip('===========================').strip() 14 | data = data.split('\n===========================\n') 15 | data = [sample.strip().split('\n') for sample in data] 16 | gold = [sample[1].strip().split('\t')[1].lower() for sample in data] 17 | pred = [sample[2].strip().split('\t')[1].lower() for sample in data] 18 | 19 | return gold, pred 20 | 21 | 22 | if __name__ == "__main__": 23 | ground_turth, predictions = text_load(sys.argv[1]) 24 | 25 | scorers = { 26 | "Bleu": Bleu(4), 27 | "Meteor": Meteor(), 28 | "Rouge": Rouge() 29 | } 30 | 31 | gts = {} 32 | res = {} 33 | if len(predictions) == len(ground_turth): 34 | for ind, value in enumerate(predictions): 35 | # print(value) 36 | res[ind] = [value] 37 | 38 | for ind, value in enumerate(ground_turth): 39 | gts[ind] = [value] 40 | else: 41 | Min_Len = min(len(predictions), len(ground_turth)) 42 | for ind in range(Min_Len): 43 | res[ind] = [predictions[ind]] 44 | gts[ind] = [ground_turth[ind]] 45 | 46 | # param gts: Dictionary of reference sentences (id, sentence) 47 | # param res: Dictionary of hypothesis sentences (id, sentence) 48 | 49 | print('samples: {} / {}'.format(len(res.keys()), len(gts.keys()))) 50 | 51 | scores = {} 52 | for name, scorer in scorers.items(): 53 | score, all_scores = scorer.compute_score(gts, res) 54 | if isinstance(score, list): 55 | for i, sc in enumerate(score, 1): 56 | scores[name + str(i)] = sc 57 | else: 58 | scores[name] = score 59 | 60 | pprint(scores) 61 | -------------------------------------------------------------------------------- /model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WING-NUS/SG-Deep-Question-Generation/f9ac6f7922d0b7cbfc6e974df950562569cbeaf4/model.jpg -------------------------------------------------------------------------------- /scripts/preprocess_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | DATAHOME=${HOME}/datasets 6 | EXEHOME=${HOME}/src 7 | 8 | cd ${EXEHOME} 9 | 10 | python preprocess.py \ 11 | -train_src ${DATAHOME}/text-data/train.src.txt -train_tgt ${DATAHOME}/text-data/train.tgt.txt \ 12 | -valid_src ${DATAHOME}/text-data/valid.src.txt -valid_tgt ${DATAHOME}/text-data/valid.tgt.txt \ 13 | -train_ans ${DATAHOME}/text-data/train.ans.txt -valid_ans ${DATAHOME}/text-data/valid.ans.txt \ 14 | -train_graph ${DATAHOME}/json-data/train.tag.json -valid_graph ${DATAHOME}/json-data/valid.tag.json \ 15 | -node_feature \ 16 | -copy \ 17 | -answer \ 18 | -save_sequence_data ${DATAHOME}/preprocessed-data/preprcessed_sequence_data.pt \ 19 | -save_graph_data ${DATAHOME}/preprocessed-data/preprcessed_graph_data.pt \ 20 | 
-train_dataset ${DATAHOME}/Datasets/train_dataset.pt \ 21 | -valid_dataset ${DATAHOME}/Datasets/valid_dataset.pt \ 22 | -src_seq_length 200 -tgt_seq_length 50 \ 23 | -src_vocab_size 50000 -tgt_vocab_size 50000 \ 24 | -src_words_min_frequency 3 -tgt_words_min_frequency 2 \ 25 | -vocab_trunc_mode frequency \ 26 | -pre_trained_vocab ${GLOVEHOME}/glove.840B.300d.txt -word_vec_size 300 \ 27 | -batch_size 32 28 | -------------------------------------------------------------------------------- /scripts/train_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | DATAHOME=${HOME}/datasets 6 | EXEHOME=${HOME}/src 7 | MODELHOME=${HOME}/models 8 | LOGHOME=${HOME}/logs 9 | 10 | mkdir -p ${MODELHOME} 11 | mkdir -p ${LOGHOME} 12 | 13 | cd ${EXEHOME} 14 | 15 | python train.py \ 16 | -sequence_data ${DATAHOME}/preprcessed-data/preprcessed_sequence_data.pt \ 17 | -graph_data ${DATAHOME}/preprcessed-data/preprcessed_graph_data.pt \ 18 | -train_dataset ${DATAHOME}/Datasets/train_dataset.pt \ 19 | -valid_dataset ${DATAHOME}/Datasets/valid_dataset.pt \ 20 | -epoch 100 \ 21 | -batch_size 32 -eval_batch_size 16 \ 22 | -pre_trained_vocab \ 23 | -training_mode classify \ 24 | -max_token_src_len 200 -max_token_tgt_len 50 \ 25 | -sparse 0 \ 26 | -copy \ 27 | -coverage -coverage_weight 0.4 \ 28 | -node_feature \ 29 | -d_word_vec 300 \ 30 | -d_seq_enc_model 512 -d_graph_enc_model 256 -n_graph_enc_layer 3 \ 31 | -d_k 64 -brnn -enc_rnn gru \ 32 | -d_dec_model 512 -n_dec_layer 1 -dec_rnn gru \ 33 | -maxout_pool_size 2 -n_warmup_steps 10000 \ 34 | -dropout 0.5 -attn_dropout 0.1 \ 35 | -gpus 0 \ 36 | -save_mode best -save_model ${MODELHOME}/classifier \ 37 | -log_home ${LOGHOME} \ 38 | -logfile_train ${LOGHOME}/train_classifier \ 39 | -logfile_dev ${LOGHOME}/valid_classifier \ 40 | -translate_ppl 15 \ 41 | -curriculum 0 -extra_shuffle -optim adam \ 42 | -learning_rate 0.00025 -learning_rate_decay 0.75 \ 43 | -valid_steps 500 \ 44 | -decay_steps 500 -start_decay_steps 5000 -decay_bad_cnt 5 \ 45 | -max_grad_norm 5 -max_weight_value 20 46 | -------------------------------------------------------------------------------- /scripts/train_generator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | DATAHOME=${HOME}/datasets 6 | EXEHOME=${HOME}/src 7 | MODELHOME=${HOME}/models 8 | LOGHOME=${HOME}/logs 9 | 10 | mkdir -p ${MODELHOME} 11 | mkdir -p ${LOGHOME} 12 | 13 | cd ${EXEHOME} 14 | 15 | python train.py \ 16 | -sequence_data ${DATAHOME}/preprcessed-data/preprcessed_sequence_data.pt \ 17 | -graph_data ${DATAHOME}/preprcessed-data/preprcessed_graph_data.pt \ 18 | -train_dataset ${DATAHOME}/Datasets/train_dataset.pt \ 19 | -valid_dataset ${DATAHOME}/Datasets/valid_dataset.pt \ 20 | -checkpoint ${MODELHOME}/classifier_84.06773_accuracy.chkpt \ 21 | -epoch 100 \ 22 | -batch_size 32 -eval_batch_size 16 \ 23 | -pre_trained_vocab \ 24 | -training_mode generate \ 25 | -max_token_src_len 200 -max_token_tgt_len 50 \ 26 | -sparse 0 \ 27 | -copy \ 28 | -coverage -coverage_weight 0.4 \ 29 | -node_feature \ 30 | -d_word_vec 300 \ 31 | -d_seq_enc_model 512 -d_graph_enc_model 256 -n_graph_enc_layer 3 \ 32 | -d_k 64 -brnn -enc_rnn gru \ 33 | -d_dec_model 512 -n_dec_layer 1 -dec_rnn gru \ 34 | -maxout_pool_size 2 -n_warmup_steps 10000 \ 35 | -dropout 0.5 -attn_dropout 0.1 \ 36 | -gpus 0 \ 37 | -save_mode best -save_model ${MODELHOME}/generator \ 38 | -log_home ${LOGHOME} \ 39 | -logfile_train 
${LOGHOME}/train_generator \ 40 | -logfile_dev ${LOGHOME}/valid_generator \ 41 | -translate_ppl 15 \ 42 | -curriculum 0 -extra_shuffle -optim adam \ 43 | -learning_rate 0.00025 -learning_rate_decay 0.75 \ 44 | -valid_steps 500 \ 45 | -decay_steps 500 -start_decay_steps 5000 -decay_bad_cnt 5 \ 46 | -max_grad_norm 5 -max_weight_value 32 47 | -------------------------------------------------------------------------------- /scripts/translate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | 5 | DATAHOME=${HOME}/datasets 6 | EXEHOME=${HOME}/src 7 | MODELHOME=${HOME}/models 8 | LOGHOME=${HOME}/predictions 9 | 10 | mkdir -p ${LOGHOME} 11 | 12 | cd ${EXEHOME} 13 | 14 | python translate.py \ 15 | -model ${MODELHOME}/generator_15.12441_bleu4.chkpt \ 16 | -sequence_data ${DATAHOME}/preprocessed-data/preprcessed_sequence_data.pt \ 17 | -graph_data ${DATAHOME}/preprocessed-data/preprcessed_graph_data.pt \ 18 | -valid_data ${DATAHOME}/Datasets/valid_dataset.pt \ 19 | -output ${LOGHOME}/prediction.txt \ 20 | -beam_size 5 \ 21 | -batch_size 16 \ 22 | -gpus 0 23 | -------------------------------------------------------------------------------- /src/onqg/dataset/Constants.py: -------------------------------------------------------------------------------- 1 | 2 | PAD = 0 3 | UNK = 1 4 | BOS = 2 5 | EOS = 3 6 | SEP = 4 7 | 8 | PAD_WORD = '[PAD]' 9 | UNK_WORD = '[UNK]' 10 | BOS_WORD = '[BOS]' 11 | EOS_WORD = '[EOS]' 12 | SEP_WORD = '[SEP]' -------------------------------------------------------------------------------- /src/onqg/dataset/Dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import random 5 | import numpy as np 6 | 7 | import torch 8 | from torch import cuda 9 | 10 | import onqg.dataset.Constants as Constants 11 | 12 | 13 | class Dataset(object): 14 | 15 | def __init__(self, seq_datasets, graph_datasets, batchSize, 16 | copy=False, answer=False, ans_feature=False, 17 | feature=False, node_feature=False, opt_cuda=False): 18 | self.src, self.tgt = seq_datasets['src'], seq_datasets['tgt'] 19 | self.has_tgt = True if self.tgt else False 20 | 21 | self.graph_index = graph_datasets['index'] 22 | # self.graph_root = graph_datasets['root'] 23 | self.edge_in = graph_datasets['edge']['in'] 24 | self.edge_out = graph_datasets['edge']['out'] 25 | 26 | self.answer = answer 27 | self.ans = seq_datasets['ans'] if answer else None 28 | self.ans_feature_num = len(seq_datasets['ans_feature']) if ans_feature else 0 29 | self.ans_features = seq_datasets['ans_feature'] if self.ans_feature_num else None 30 | 31 | self.feature_num = len(seq_datasets['feature']) if feature else 0 32 | self.features = seq_datasets['feature'] if self.feature_num else None 33 | 34 | self.node_feature_num = len(graph_datasets['feature']) if node_feature else 0 35 | self.node_features = graph_datasets['feature'] if self.node_feature_num else None 36 | 37 | self.copy = copy 38 | self.copy_switch = seq_datasets['copy']['switch'] if copy else None 39 | self.copy_tgt = seq_datasets['copy']['tgt'] if copy else None 40 | 41 | self._update_data() 42 | 43 | if opt_cuda: 44 | cuda.set_device(opt_cuda[0]) 45 | self.device = torch.device("cuda" if cuda else "cpu") 46 | 47 | self.batchSize = batchSize 48 | self.numBatches = math.ceil(len(self.src) / batchSize) 49 | 50 | def _update_data(self): 51 | """sort all data by lengths of source text""" 52 | self.idxs = 
list(range(len(self.src))) 53 | lengths = [s.size(0) for s in self.src] 54 | RAW = [lengths, self.src, self.idxs] 55 | 56 | DATA = list(zip(*RAW)) 57 | DATA.sort(key=lambda x:x[0]) 58 | 59 | self.src = [d[1] for d in DATA] 60 | self.idxs = [d[2] for d in DATA] 61 | 62 | if self.tgt: 63 | self.tgt = [self.tgt[idx] for idx in self.idxs] 64 | if self.copy: 65 | self.copy_switch = [self.copy_switch[idx] for idx in self.idxs] 66 | self.copy_tgt = [self.copy_tgt[idx] for idx in self.idxs] 67 | if self.feature_num: 68 | self.features = [[feature[idx] for idx in self.idxs] for feature in self.features] 69 | if self.answer: 70 | self.ans = [self.ans[idx] for idx in self.idxs] 71 | if self.ans_feature_num: 72 | self.ans_features = [[feature[idx] for idx in self.idxs] for feature in self.ans_features] 73 | 74 | self.edge_in_dict = self._get_edge_dict(self.edge_in) 75 | self.edge_out_dict = self._get_edge_dict(self.edge_out) 76 | 77 | def _get_edge_dict(self, edges): 78 | edges_dict = [] 79 | for sample in edges: 80 | edge_dict = [] 81 | for edge_list in sample: 82 | edge_dict.append([(idx, edge.item()) for idx, edge in enumerate(edge_list) if edge.item() != Constants.PAD]) 83 | edges_dict.append(edge_dict) 84 | return edges_dict 85 | 86 | def _batchify(self, data, align_right=False, include_lengths=False, src_len=None): 87 | """get data in a batch while applying padding, return length if needed""" 88 | if src_len: 89 | lengths = src_len 90 | else: 91 | lengths = [x.size(0) for x in data] 92 | max_length = max(lengths) 93 | 94 | out = data[0].new(len(data), max_length).fill_(Constants.PAD) 95 | for i in range(len(data)): 96 | data_length = data[i].size(0) 97 | offset = max_length - data_length if align_right else 0 98 | out[i].narrow(0, offset, data_length).copy_(data[i]) 99 | 100 | if include_lengths: 101 | return out, lengths 102 | else: 103 | return out 104 | 105 | def _graph_length_info(self, indexes, edges): 106 | indexes = [index for sample in indexes for index in sample] 107 | index_length = [len(index) for index in indexes] 108 | 109 | node_length = [len(edge) for edge in edges] 110 | nodes = [torch.Tensor([i for i in range(length)]) for length in node_length] 111 | tmpNodesBatch = self._batchify(nodes, src_len=node_length) 112 | return index_length, node_length, tmpNodesBatch 113 | 114 | def _pad_edges(self, edges, node_length, align_right=False): 115 | max_length = max(node_length) 116 | edge_lengths = [] 117 | 118 | outs = [] 119 | for edge, data_length in zip(edges, node_length): 120 | out = edge[0].new_full((max_length, max_length), Constants.PAD) 121 | for i, e in enumerate(edge): 122 | offset = max_length - data_length if align_right else 0 123 | out[i].narrow(0, offset, data_length).copy_(e) 124 | outs.append(out.view(-1)) 125 | edge_lengths.append(data_length * data_length) 126 | outs = torch.stack(outs, dim=0) # batch_size x (max_length * max_length) 127 | 128 | return outs, edge_lengths 129 | 130 | def __getitem__(self, index): 131 | """get the exact batch using index, and transform data into Tensor form""" 132 | assert index < self.numBatches, "%d > %d" % (index, self.numBatches) 133 | 134 | srcBatch, lengths = self._batchify(self.src[index * self.batchSize: (index + 1) * self.batchSize], 135 | align_right=False, include_lengths=True) 136 | tgtBatch = None 137 | if self.tgt: 138 | tgtBatch = self._batchify(self.tgt[index * self.batchSize: (index + 1) * self.batchSize]) 139 | 140 | idxBatch = self.idxs[index * self.batchSize: (index + 1) * self.batchSize] 141 | 142 | graphIndexBatch 
= [self.graph_index[i] for i in idxBatch] 143 | # graphRootBatch = [self.graph_root[i] for i in idxBatch] 144 | edgeInBatch, edgeOutBatch = [self.edge_in[i] for i in idxBatch], [self.edge_out[i] for i in idxBatch] 145 | edgeInDict, edgeOutDict = [self.edge_in_dict[i] for i in idxBatch], [self.edge_out_dict[i] for i in idxBatch] 146 | node_index_length, node_lengths, tmpNodesBatch = self._graph_length_info(graphIndexBatch, edgeInBatch) 147 | edgeInBatch, edge_lengths = self._pad_edges(edgeInBatch, node_lengths) 148 | edgeOutBatch, _ = self._pad_edges(edgeOutBatch, node_lengths) 149 | 150 | nodeFeatBatches = None 151 | if self.node_feature_num: 152 | nodeFeatBatches = [ 153 | self._batchify([feat[i] for i in idxBatch], src_len=node_lengths) for feat in self.node_features 154 | ] 155 | 156 | featBatches = None 157 | if self.feature_num: 158 | featBatches = [ 159 | self._batchify(feat[index * self.batchSize: (index + 1) * self.batchSize], 160 | src_len=lengths) for feat in self.features 161 | ] 162 | 163 | copySwitchBatch, copyTgtBatch = None, None 164 | if self.copy: 165 | copySwitchBatch = self._batchify(self.copy_switch[index * self.batchSize: (index + 1) * self.batchSize]) 166 | copyTgtBatch = self._batchify(self.copy_tgt[index * self.batchSize: (index + 1) * self.batchSize]) 167 | 168 | ansBatch, ansFeatBatches = None, None 169 | if self.answer: 170 | ansBatch, ansLengths = self._batchify(self.ans[index * self.batchSize: (index + 1) * self.batchSize], 171 | align_right=False, include_lengths=True) 172 | if self.ans_feature_num: 173 | ansFeatBatches = [ 174 | self._batchify(feat[index * self.batchSize: (index + 1) * self.batchSize], src_len=ansLengths) 175 | for feat in self.ans_features 176 | ] 177 | 178 | def wrap(b): 179 | if b is None: 180 | return b 181 | b = torch.stack([x for x in b], dim=0).contiguous() 182 | b = b.to(self.device) 183 | return b 184 | 185 | # wrap lengths in a Variable to properly split it in DataParallel 186 | lengths = torch.LongTensor(lengths).view(1, -1).to(self.device) 187 | edge_lengths = torch.LongTensor(edge_lengths).view(1, -1).to(self.device) 188 | indices = range(len(srcBatch)) 189 | 190 | rst = {} 191 | rst['indice'] = indices 192 | rst['src'] = (wrap(srcBatch), lengths) 193 | rst['raw-index'] = idxBatch 194 | if self.has_tgt: 195 | rst['tgt'] = wrap(tgtBatch) 196 | if self.copy: 197 | rst['copy'] = (wrap(copySwitchBatch), wrap(copyTgtBatch)) 198 | if self.answer: 199 | ansLengths = torch.LongTensor(ansLengths).view(1, -1).to(self.device) 200 | rst['ans'] = (wrap(ansBatch), ansLengths) 201 | if self.ans_feature_num: 202 | rst['ans_feat'] = (tuple(wrap(x) for x in ansFeatBatches), ansLengths) 203 | if self.feature_num: 204 | rst['feat'] = (tuple(wrap(x) for x in featBatches), lengths) 205 | 206 | rst['edges'] = ((wrap(edgeInBatch), wrap(edgeOutBatch)), edge_lengths) 207 | rst['edges_dict'] = (edgeInDict, edgeOutDict) 208 | rst['tmp_nodes'] = (wrap(tmpNodesBatch), node_lengths) 209 | rst['graph_index'] = (graphIndexBatch, node_index_length) 210 | # rst['graph_root'] = graphRootBatch 211 | if self.node_feature_num: 212 | rst['node_feat'] = (tuple(wrap(x) for x in nodeFeatBatches), node_lengths) 213 | 214 | return rst 215 | 216 | def __len__(self): 217 | return self.numBatches 218 | 219 | def shuffle(self): 220 | """shuffle the order of data in every batch""" 221 | 222 | def shuffle_group(start, end, NEW): 223 | """shuffle the order of samples with index from start to end""" 224 | RAW = [self.src[start:end], self.tgt[start:end], self.idxs[start:end]] 225 | 
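            # Zipping the parallel lists lets the single random permutation below reorder
            # src, tgt, and idxs consistently; the optional fields (answers, features,
            # copy tensors) are then re-indexed with that same permutation.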
DATA = list(zip(*RAW)) 226 | index = torch.randperm(len(DATA)) 227 | 228 | src, tgt, idx = zip(*[DATA[i] for i in index]) 229 | NEW['SRCs'] += list(src) 230 | NEW['TGTs'] += list(tgt) 231 | NEW['IDXs'] += list(idx) 232 | 233 | if self.answer: 234 | ans = [self.ans[start:end][i] for i in index] 235 | NEW['ANSs'] += ans 236 | if self.ans_feature_num: 237 | ansft = [[feature[start:end][i] for i in index] for feature in self.ans_features] 238 | for i in range(self.ans_feature_num): 239 | NEW['ANSFTs'][i] += ansft[i] 240 | 241 | if self.feature_num: 242 | ft = [[feature[start:end][i] for i in index] for feature in self.features] 243 | for i in range(self.feature_num): 244 | NEW['FTs'][i] += ft[i] 245 | 246 | if self.copy: 247 | cpswt = [self.copy_switch[start:end][i] for i in index] 248 | cptgt = [self.copy_tgt[start:end][i] for i in index] 249 | NEW['COPYSWTs'] += cpswt 250 | NEW['COPYTGTs'] += cptgt 251 | 252 | return NEW 253 | 254 | assert self.tgt != None, "shuffle is only aimed for training data (with target given)" 255 | 256 | NEW = {'SRCs':[], 'TGTs':[], 'IDXs':[]} 257 | if self.copy: 258 | NEW['COPYSWTs'], NEW['COPYTGTs'] = [], [] 259 | if self.feature_num: 260 | NEW['FTs'] = [[] for i in range(self.feature_num)] 261 | if self.answer: 262 | NEW['ANSs'] = [] 263 | if self.ans_feature_num: 264 | NEW['ANSFTs'] = [[] for i in range(self.ans_feature_num)] 265 | 266 | shuffle_all = random.random() 267 | if shuffle_all > 0.75: # fix this magic number later 268 | start, end = 0, self.batchSize * self.numBatches 269 | NEW = shuffle_group(start, end, NEW) 270 | else: 271 | for batch_idx in range(self.numBatches): 272 | start = batch_idx * self.batchSize 273 | end = start + self.batchSize 274 | 275 | NEW = shuffle_group(start, end, NEW) 276 | 277 | self.src, self.tgt, self.idxs = NEW['SRCs'], NEW['TGTs'], NEW['IDXs'] 278 | if self.copy: 279 | self.copy_switch, self.copy_tgt = NEW['COPYSWTs'], NEW['COPYTGTs'] 280 | if self.answer: 281 | self.ans = NEW['ANSs'] 282 | if self.ans_feature_num: 283 | self.ans_features = NEW['ANSFTs'] 284 | if self.feature_num: 285 | self.features = NEW['FTs'] 286 | -------------------------------------------------------------------------------- /src/onqg/dataset/Vocab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | 4 | import onqg.dataset.Constants as Constants 5 | 6 | # from pytorch_pretrained_bert import BertTokenizer, GPT2Tokenizer 7 | 8 | 9 | def build_vocab(pretrained): 10 | if pretrained.count('bert'): 11 | tokenizer = BertTokenizer.from_pretrained(pretrained) 12 | labelToIdx = functools.partial(tokenizer.convert_tokens_to_ids) 13 | idxToLabel = functools.partial(tokenizer.convert_ids_to_tokens) 14 | model_file = 'bert' 15 | else: 16 | raise ValueError('Unsupported vocabulary type: ' + pretrained) 17 | 18 | return {'vocab':tokenizer, 'functions':[labelToIdx, idxToLabel], 'type':model_file} 19 | 20 | 21 | class Vocab(object): 22 | """ 23 | Class for vocabulary 24 | contain BERTTokenizer 25 | Default is defined by given functions 26 | """ 27 | def __init__(self, special_words, lower=False, opts=None): 28 | self.type = opts['type'] if opts is not None else 'default' 29 | self.lower = lower 30 | self.special_words = special_words 31 | 32 | # Special entries will not be pruned. 
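        # Two modes are supported: a 'default' vocabulary keeps its own
        # idxToLabel / labelToIdx / frequencies dicts and registers the special words
        # itself, while a pretrained vocabulary (e.g. BERT) delegates token<->id
        # conversion to the wrapped tokenizer's functions.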
33 | self.special = [] 34 | 35 | if self.type == 'default': 36 | self.idxToLabel, self.labelToIdx, self.frequencies = {}, {}, {} 37 | if len(special_words) > 0: 38 | self.addSpecials(special_words) 39 | else: 40 | self.tokenizer = opts['vocab'] 41 | self.label_to_idx, self.idx_to_label = opts['functions'][0], opts['functions'][1] 42 | self.special = self.label_to_idx(special_words) 43 | 44 | @classmethod 45 | def from_opt(cls, corpus=None, opt=None, pretrained=None): 46 | special_words = [Constants.PAD_WORD, Constants.UNK_WORD] 47 | if opt['tgt']: 48 | special_words += [Constants.BOS_WORD, Constants.EOS_WORD] 49 | 50 | if pretrained is not None: 51 | vocab = cls(special_words, lower=True, opts=build_vocab(pretrained)) 52 | else: 53 | assert corpus is not None and opt is not None 54 | vocab = cls(special_words, lower=opt['lower']) 55 | for sent in corpus: 56 | for word in sent: 57 | vocab.add(word, lower=opt['lower']) 58 | original_size = vocab.size 59 | vocab = vocab.prune(opt['size'], opt['frequency'], opt['mode']) 60 | print("Truncate vocabulary size from " + str(original_size) + " to " + str(vocab.size)) 61 | 62 | return vocab 63 | 64 | @property 65 | def size(self): 66 | if self.type == 'default': 67 | return len(self.idxToLabel) 68 | else: 69 | return len(self.tokenizer.vocab) 70 | 71 | def lookup(self, key, default=Constants.UNK): 72 | try: 73 | return self.labelToIdx[key] if self.type == 'default' else self.label_to_idx([key])[0] 74 | except KeyError: 75 | return default 76 | 77 | def getLabel(self, idx, default=Constants.UNK_WORD): 78 | try: 79 | return self.idxToLabel[idx] if self.type == 'default' else self.idx_to_label([idx])[0] 80 | except KeyError: 81 | return default 82 | 83 | # Mark this `label` and `idx` as special (i.e. will not be pruned). 84 | def addSpecial(self, label, idx=None): 85 | idx = self.add(label, idx, lower=False) 86 | self.special += [idx] 87 | 88 | # Mark all labels in `labels` as specials (i.e. will not be pruned). 89 | def addSpecials(self, labels): 90 | for label in labels: 91 | self.addSpecial(label) 92 | 93 | # Add `label` in the dictionary. Use `idx` as its index if given. 94 | def add(self, label, idx=None, lower=False): 95 | assert self.type == 'default', "BERT has already been pretrained" 96 | 97 | lower = self.lower if lower and label not in self.special_words else False 98 | label = label.lower() if lower else label 99 | 100 | if idx is not None: 101 | self.idxToLabel[idx] = label 102 | self.labelToIdx[label] = idx 103 | else: 104 | if label in self.labelToIdx: 105 | idx = self.labelToIdx[label] 106 | else: 107 | idx = len(self.idxToLabel) 108 | self.idxToLabel[idx] = label 109 | self.labelToIdx[label] = idx 110 | 111 | if idx not in self.frequencies: 112 | self.frequencies[idx] = 1 113 | else: 114 | self.frequencies[idx] += 1 115 | 116 | return idx 117 | 118 | # Return a new dictionary with the `size` most frequent entries. 119 | def prune(self, size, frequency, mode='size'): 120 | assert self.type == 'default', self.type.upper() + " has already been pretrained" 121 | 122 | freq = torch.Tensor([self.frequencies[i] for i in range(len(self.frequencies))]) 123 | freq, idx = torch.sort(freq, 0, True) 124 | 125 | newVocab = Vocab([self.idxToLabel[i] for i in self.special], lower=self.lower) 126 | 127 | if mode == 'size': 128 | if size >= self.size: 129 | return self 130 | # Only keep the `size` most frequent entries. 
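            # `idx` is already sorted by descending frequency (torch.sort above), so the
            # first `size` entries are the most frequent labels.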
131 | for i in idx[:size]: 132 | newVocab.add(self.idxToLabel[i.item()]) 133 | return newVocab 134 | elif mode == 'frequency': 135 | if frequency <= 1: 136 | return self 137 | for cnt, i in enumerate(idx): 138 | if freq[cnt] < frequency: 139 | return newVocab 140 | newVocab.add(self.idxToLabel[i.item()]) 141 | newVocab.frequencies[i.item()] = self.frequencies[i.item()] 142 | else: 143 | print("mode error in Vocab.prune! ") 144 | assert False 145 | 146 | return newVocab 147 | 148 | # Convert `labels` to indices. Use `unkWord` if not found. 149 | # Optionally insert `bosWord` at the beginning and `eosWord` at the . 150 | def convertToIdx(self, labels, unkWord=Constants.UNK_WORD): 151 | unk = self.lookup(unkWord) 152 | indexes = [self.lookup(label, default=unk) for label in labels] 153 | return torch.LongTensor(indexes) 154 | 155 | # Convert `idx` to labels. If index `stop` is reached, convert it and return. 156 | def convertToLabels(self, idx, stopList=[Constants.PAD, Constants.EOS]): 157 | labels = [self.getLabel(i) for i in idx if i not in stopList] 158 | return labels 159 | -------------------------------------------------------------------------------- /src/onqg/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from onqg.dataset.Dataset import Dataset 2 | from onqg.dataset.Vocab import Vocab -------------------------------------------------------------------------------- /src/onqg/dataset/data_processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import onqg.dataset.Constants as Constants 3 | from onqg.utils.mask import get_edge_mask 4 | 5 | 6 | def preprocess_batch(batch, n_edge_type, sparse=True, feature=False, dec_feature=0, 7 | answer=False, ans_feature=False, copy=False, node_feature=False, 8 | device=None): 9 | """Get a batch by indexing to the Dataset object, then preprocess it to get inputs for the model 10 | Input: batch 11 | raw-index: idxBatch 12 | 13 | src: (wrap(srcBatch), lengths) 14 | tgt: wrap(tgtBatch) 15 | copy: (wrap(copySwitchBatch), wrap(copyTgtBatch)) 16 | feat: (tuple(wrap(x) for x in featBatches), lengths) 17 | ans: (wrap(ansBatch), ansLengths) 18 | ans_feat: (tuple(wrap(x) for x in ansFeatBatches), ansLengths) 19 | 20 | edges: (wrap(edgeInBatch), wrap(edgeOutBatch), edge_lengths) 21 | edges_dict: (edgeInDict, edgeOutDict) 22 | tmp_nodes: (wrap(tmpNodesBatch), node_lengths) 23 | graph_index: (graphIndexBatch, node_index_length) 24 | graph_root: graphRootBatch 25 | node_feat: (tuple(wrap(x) for x in nodeFeatBatches), node_lengths) 26 | Output: 27 | (1) inputs dict: 28 | seq-encoder: src_seq, lengths, feat_seqs 29 | graph-encoder: edges, mask, type, feat_seqs (, adjacent_matrix) 30 | encoder-transform: index, lengths, root, node_lengths 31 | decoder: tgt_seq, src_seq, feat_seqs 32 | decoder-transform: index 33 | (2) (max_node_num, max_node_size) 34 | (3) (generation, classification) 35 | (4) (copy_gold, copy_switch) 36 | """ 37 | inputs = {'seq-encoder':{}, 'graph-encoder':{}, 'encoder-transform':{}, 38 | 'decoder':{}, 'decoder-transform':{}} 39 | ###===== RNN encoder =====### 40 | src_seq, tgt_seq = batch['src'], batch['tgt'] 41 | src_seq, lengths = src_seq[0], src_seq[1] 42 | inputs['seq-encoder']['src_seq'], inputs['seq-encoder']['lengths'] = src_seq, lengths 43 | ###===== encoder transform =====### 44 | edges, nodes = batch['edges'][0], batch['tmp_nodes'] 45 | nodes, node_lengths = nodes[0], nodes[1] 46 | graph_index = batch['graph_index'] 47 | # 
graph_root = batch['graph_root'] 48 | graph_index, index_lengths = graph_index[0], graph_index[1] 49 | inputs['encoder-transform']['index'], inputs['encoder-transform']['lengths'] = graph_index, index_lengths 50 | # inputs['encoder-transform']['root'] = graph_root 51 | inputs['encoder-transform']['node_lengths'] = node_lengths 52 | ###===== graph encoder =====### 53 | in_edge_mask, out_edge_mask = get_edge_mask(edges[0]), get_edge_mask(edges[1]) 54 | if sparse: 55 | in_edge_mask = in_edge_mask.view(in_edge_mask.size(0), -1, max(node_lengths)) 56 | out_edge_mask = out_edge_mask.view(in_edge_mask.size(0), -1, max(node_lengths)) 57 | inputs['graph-encoder']['adjacent_matrix'] = batch['edges_dict'] 58 | edge_type_list = torch.LongTensor([i for i in range(n_edge_type)]).to(device=device) 59 | inputs['graph-encoder']['edges'] = edge_type_list if sparse else (in_edge_mask, out_edge_mask) 60 | inputs['graph-encoder']['mask'] = (in_edge_mask, out_edge_mask) 61 | inputs['graph-encoder']['type'] = batch['node_feat'][0][0] 62 | ###===== classifier =====### 63 | classification = batch['node_feat'][0][-1] 64 | ###===== decoder transform =====### 65 | inputs['decoder-transform']['index'] = graph_index 66 | ###===== decoder =====### 67 | generation = tgt_seq[:, 1:] # exclude [BOS] token 68 | inputs['decoder']['tgt_seq'] = tgt_seq[:, :-1] 69 | inputs['decoder']['src_seq'] = src_seq # nodes 70 | inputs['decoder']['ans_seq'] = batch['ans'][0] 71 | ###===== auxiliary functions =====### 72 | src_feats, tgt_feats = None, None 73 | if feature: 74 | n_all_feature = len(batch['feat'][0]) 75 | # split all features into src and tgt parts, src_feats are those embedded in the encoder 76 | src_feats = batch['feat'][0][:n_all_feature - dec_feature] 77 | if dec_feature: 78 | # dec_feature: the number of features embedded in the decoder 79 | tgt_feats = batch['feat'][0][n_all_feature - dec_feature:] 80 | inputs['seq-encoder']['feat_seqs'], inputs['decoder']['feat_seqs'] = src_feats, tgt_feats 81 | 82 | copy_gold, copy_switch = None, None 83 | if copy: 84 | copy_gold, copy_switch = batch['copy'][1], batch['copy'][0] 85 | copy_gold, copy_switch = copy_gold[:, 1:], copy_switch[:, 1:] 86 | 87 | node_feats = batch['node_feat'][0][1:-1] if node_feature else None 88 | inputs['graph-encoder']['feat_seqs'] = node_feats 89 | 90 | max_node_size = max(index_lengths) 91 | max_node_num = max(node_lengths) 92 | 93 | return inputs, (max_node_num, max_node_size), (generation, classification), (copy_gold, copy_switch) 94 | -------------------------------------------------------------------------------- /src/onqg/models/Decoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | import onqg.dataset.Constants as Constants 6 | 7 | from onqg.models.modules.Attention import ConcatAttention 8 | from onqg.models.modules.MaxOut import MaxOut 9 | from onqg.models.modules.DecAssist import StackedRNN, DecInit 10 | 11 | 12 | class RNNDecoder(nn.Module): 13 | """ 14 | Input: (1) inputs['tgt_seq'] 15 | (2) inputs['src_seq'] 16 | (3) inputs['src_indexes'] 17 | (4) inputs['enc_output'] 18 | (5) inputs['hidden'] 19 | (6) inputs['feat_seqs'] 20 | Output: (1) rst['pred'] 21 | (2) rst['attn'] 22 | (3) rst['context'] 23 | (4) rst['copy_pred']; rst['copy_gate'] 24 | (5) rst['coverage_pred'] 25 | 26 | """ 27 | def __init__(self, n_vocab, ans_n_vocab, d_word_vec, d_model, n_layer, n_rnn_enc_layer, 28 | rnn, d_k, feat_vocab, d_feat_vec, 
d_rnn_enc_model, d_enc_model, 29 | n_enc_layer, input_feed, copy, answer, coverage, layer_attn, 30 | maxout_pool_size, dropout, device=None): 31 | self.name = 'rnn' 32 | 33 | super(RNNDecoder, self).__init__() 34 | 35 | self.n_layer = n_layer 36 | self.layer_attn = layer_attn 37 | self.coverage = coverage 38 | self.copy = copy 39 | self.maxout_pool_size = maxout_pool_size 40 | input_size = d_word_vec 41 | 42 | self.input_feed = input_feed 43 | if input_feed: 44 | input_size += d_rnn_enc_model + d_enc_model 45 | 46 | self.ans_emb = nn.Embedding(ans_n_vocab, d_word_vec, padding_idx=Constants.PAD) 47 | 48 | self.answer = answer 49 | tmp_in = d_word_vec if answer else d_rnn_enc_model 50 | self.decInit = DecInit(d_enc=tmp_in, d_dec=d_model, n_enc_layer=n_rnn_enc_layer) 51 | 52 | self.feature = False if not feat_vocab else True 53 | if self.feature: 54 | self.feat_embs = nn.ModuleList([ 55 | nn.Embedding(n_f_vocab, d_feat_vec, padding_idx=Constants.PAD) for n_f_vocab in feat_vocab 56 | ]) 57 | feat_size = len(feat_vocab) * d_feat_vec if self.feature else 0 58 | 59 | self.d_enc_model = d_rnn_enc_model + d_enc_model 60 | 61 | self.word_emb = nn.Embedding(n_vocab, d_word_vec, padding_idx=Constants.PAD) 62 | self.rnn = StackedRNN(n_layer, input_size, d_model, dropout, rnn=rnn) 63 | self.attn = ConcatAttention(self.d_enc_model + feat_size, d_model, d_k, coverage) 64 | 65 | self.readout = nn.Linear((d_word_vec + d_model + self.d_enc_model), d_model) 66 | self.maxout = MaxOut(maxout_pool_size) 67 | 68 | if copy: 69 | self.copy_switch = nn.Linear(self.d_enc_model + d_model, 1) 70 | 71 | self.hidden_size = d_model 72 | self.dropout = nn.Dropout(dropout) 73 | self.device = device 74 | 75 | @classmethod 76 | def from_opt(cls, opt): 77 | return cls(opt['n_vocab'], opt['ans_n_vocab'], opt['d_word_vec'], opt['d_model'], opt['n_layer'], opt['n_rnn_enc_layer'], 78 | opt['rnn'], opt['d_k'], opt['feat_vocab'], opt['d_feat_vec'], opt['d_rnn_enc_model'], 79 | opt['d_enc_model'], opt['n_enc_layer'], opt['input_feed'], opt['copy'], opt['answer'], 80 | opt['coverage'], opt['layer_attn'], opt['maxout_pool_size'], opt['dropout'], 81 | opt['device']) 82 | 83 | def attn_init(self, context): 84 | if isinstance(context, list): 85 | context = context[-1] 86 | if isinstance(context, tuple): 87 | context = torch.cat(context, dim=-1) 88 | batch_size = context.size(0) 89 | hidden_sizes = (batch_size, self.d_enc_model) 90 | return Variable(context.data.new(*hidden_sizes).zero_(), requires_grad=False) 91 | 92 | def forward(self, inputs, max_length=300): 93 | tgt_seq, src_seq, ans_seq = inputs['tgt_seq'], inputs['src_seq'], inputs['ans_seq'] 94 | enc_output, hidden = inputs['enc_output'], inputs['hidden'] 95 | feat_seqs = inputs['feat_seqs'] 96 | 97 | src_pad_mask = Variable(src_seq.data.eq(Constants.PAD).float(), requires_grad=False, volatile=False) 98 | if self.layer_attn: 99 | n_enc_layer = len(enc_output) 100 | src_pad_mask = src_pad_mask.repeat(1, n_enc_layer) 101 | enc_output = torch.cat(enc_output, dim=1) 102 | 103 | feat_inputs = None 104 | if self.feature: 105 | feat_inputs = [feat_emb(feat_seq) for feat_seq, feat_emb in zip(feat_seqs, self.feat_embs)] 106 | feat_inputs = torch.cat(feat_inputs, dim=2) 107 | if self.layer_attn: 108 | feat_inputs = feat_inputs.repeat(1, n_enc_layer, 1) 109 | 110 | dec_outputs, coverage_output, copy_output, copy_gate_output = [], [], [], [] 111 | cur_context = self.attn_init(enc_output) 112 | 113 | if self.answer: 114 | ans_words = torch.sum(self.ans_emb(ans_seq), dim=1) 115 | hidden = 
self.decInit(ans_words).unsqueeze(0) 116 | else: 117 | hidden = self.decInit(hidden).unsqueeze(0) 118 | # ans_words = torch.sum(self.ans_emb(ans_seq), dim=1) 119 | # hidden = self.decInit(hidden).unsqueeze(0) 120 | tmp_context, tmp_coverage = None, None 121 | 122 | dec_input = self.word_emb(tgt_seq) 123 | 124 | self.attn.apply_mask(src_pad_mask) 125 | 126 | attention_scores = None 127 | tag = False 128 | 129 | dec_input = dec_input.transpose(0, 1) 130 | for seq_idx, dec_input_emb in enumerate(dec_input.split(1)): 131 | dec_input_emb = dec_input_emb.squeeze(0) 132 | raw_dec_input_emb = dec_input_emb 133 | if self.input_feed: 134 | dec_input_emb = torch.cat((dec_input_emb, cur_context), dim=1) 135 | dec_output, hidden = self.rnn(dec_input_emb, hidden) 136 | 137 | if self.coverage: 138 | if tmp_coverage is None: 139 | tmp_coverage = Variable(torch.zeros((enc_output.size(0), enc_output.size(1)))) 140 | if self.device: 141 | tmp_coverage = tmp_coverage.to(self.device) 142 | cur_context, attn, tmp_context, next_coverage = self.attn(dec_output, enc_output, precompute=tmp_context, 143 | coverage=tmp_coverage, feat_inputs=feat_inputs, 144 | feature=self.feature) 145 | avg_tmp_coverage = tmp_coverage / max(1, seq_idx) 146 | coverage_loss = torch.sum(torch.min(attn, avg_tmp_coverage), dim=1) 147 | tmp_coverage = next_coverage 148 | coverage_output.append(coverage_loss) 149 | else: 150 | cur_context, attn, tmp_context = self.attn(dec_output, enc_output, precompute=tmp_context, 151 | feat_inputs=feat_inputs, feature=self.feature) 152 | 153 | attention_scores = attn if not tag else attn + attention_scores 154 | tag = True 155 | 156 | if self.copy: 157 | copy_prob = self.copy_switch(torch.cat((dec_output, cur_context), dim=1)) 158 | copy_prob = torch.sigmoid(copy_prob) 159 | 160 | if self.layer_attn: 161 | attn = attn.view(attn.size(0), n_enc_layer, -1) 162 | attn = attn.sum(1) 163 | 164 | copy_output.append(attn) 165 | copy_gate_output.append(copy_prob) 166 | 167 | readout = self.readout(torch.cat((raw_dec_input_emb, dec_output, cur_context), dim=1)) 168 | maxout = self.maxout(readout) 169 | output = self.dropout(maxout) 170 | 171 | dec_outputs.append(output) 172 | 173 | dec_output = torch.stack(dec_outputs).transpose(0, 1) 174 | 175 | sum_attention_scores = torch.sum(attention_scores, dim=1, keepdim=True) + 1e-8 176 | attention_scores = attention_scores / sum_attention_scores 177 | 178 | rst = {} 179 | rst['pred'], rst['attn'], rst['context'] = dec_output, attn, cur_context 180 | rst['attention_scores'] = (attention_scores, inputs['scores']) 181 | if self.copy: 182 | copy_output = torch.stack(copy_output).transpose(0, 1) 183 | copy_gate_output = torch.stack(copy_gate_output).transpose(0, 1) 184 | rst['copy_pred'], rst['copy_gate'] = copy_output, copy_gate_output 185 | if self.coverage: 186 | coverage_output = torch.stack(coverage_output).transpose(0, 1) 187 | rst['coverage_pred'] = coverage_output 188 | return rst 189 | 190 | 191 | class DecoderTransformer(nn.Module): 192 | ''' 193 | seq_output - [batch_size, seq_length, dim_seq_enc] 194 | graph_output - [batch_size, node_num, dim_graph_enc] 195 | indexes_list - [batch_size, node_num, index_num] (list) 196 | ''' 197 | def __init__(self, layer_attn, device=None): 198 | super(DecoderTransformer, self).__init__() 199 | self.layer_attn = layer_attn 200 | self.device = device 201 | 202 | def forward(self, inputs): 203 | seq_output, hidden = inputs['seq_output'], inputs['hidden'] 204 | graph_output, indexes_list = inputs['graph_output'], inputs['index'] 205 
| 206 | batch_size, seq_length = seq_output.size(0), seq_output.size(1) 207 | dim_graph_enc = graph_output.size(-1) if not self.layer_attn else graph_output[-1].size(-1) 208 | if 'scores' in inputs: 209 | scores = inputs['scores'] 210 | distribution = torch.full((batch_size, seq_length), 1e-8).to(self.device) 211 | 212 | if self.layer_attn: 213 | graph_hidden_states = [torch.full((batch_size, seq_length, dim_graph_enc), 1e-8).to(self.device) for _ in range(len(graph_output))] 214 | else: 215 | graph_hidden_states = torch.full((batch_size, seq_length, dim_graph_enc), 1e-8).to(self.device) 216 | 217 | graph_node_sizes = torch.full((batch_size, seq_length), 0).to(self.device) 218 | 219 | for sample_idx, indexes in enumerate(indexes_list): 220 | ##=== for each sample ===## 221 | for node_idx, index in enumerate(indexes): 222 | ##=== for each node ===## 223 | for i in index: 224 | ##=== for each word ===## 225 | if self.layer_attn: 226 | for idx in range(len(graph_hidden_states)): 227 | graph_hidden_states[idx][sample_idx].narrow(0, i, 1).add_(graph_output[idx][sample_idx][node_idx]) 228 | else: 229 | graph_hidden_states[sample_idx].narrow(0, i, 1).add_(graph_output[sample_idx][node_idx]) 230 | graph_node_sizes[sample_idx][i] += 1 231 | if 'scores' in inputs: 232 | distribution[sample_idx][i] = scores[sample_idx][node_idx] 233 | 234 | for i in range(batch_size): 235 | for j in range(seq_length): 236 | if graph_node_sizes[i][j].item() < 1: 237 | graph_node_sizes[i][j] = 1 238 | 239 | if self.layer_attn: 240 | graph_hidden_states = [x / graph_node_sizes.unsqueeze(2).repeat(1, 1, dim_graph_enc) for x in graph_hidden_states] 241 | else: 242 | graph_hidden_states = graph_hidden_states / graph_node_sizes.unsqueeze(2).repeat(1, 1, dim_graph_enc) 243 | if 'scores' in inputs: 244 | distribution = distribution / graph_node_sizes 245 | 246 | if isinstance(hidden, tuple) or isinstance(hidden, list) or hidden.dim() == 3: 247 | hidden = [h for h in hidden] 248 | hidden = torch.cat(hidden, dim=1) 249 | hidden = hidden.contiguous().view(hidden.size(0), -1) 250 | 251 | distribution = distribution if 'scores' in inputs else None 252 | if self.layer_attn: 253 | enc_output = [torch.cat((graph_output, seq_output), dim=-1) for graph_output in graph_hidden_states] 254 | else: 255 | enc_output = torch.cat((graph_hidden_states, seq_output), dim=-1) 256 | 257 | return enc_output, distribution, hidden -------------------------------------------------------------------------------- /src/onqg/models/Encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 4 | from torch.nn.utils.rnn import pack_padded_sequence as pack 5 | 6 | import onqg.dataset.Constants as Constants 7 | 8 | from onqg.models.modules.Attention import GatedSelfAttention, ConcatAttention 9 | from onqg.models.modules.Layers import GraphEncoderLayer, SparseGraphEncoderLayer 10 | 11 | # from pytorch_pretrained_bert import BertModel 12 | 13 | 14 | class RNNEncoder(nn.Module): 15 | """ 16 | Input: (1) inputs['src_seq'] 17 | (2) inputs['lengths'] 18 | (3) inputs['feat_seqs'] 19 | Output: (1) enc_output 20 | (2) hidden 21 | """ 22 | def __init__(self, n_vocab, d_word_vec, d_model, n_layer, 23 | brnn, rnn, feat_vocab, d_feat_vec, slf_attn, 24 | dropout): 25 | self.name = 'rnn' 26 | 27 | self.n_layer = n_layer 28 | self.num_directions = 2 if brnn else 1 29 | assert d_model % self.num_directions == 0, "d_model = hidden_size x 
direction_num" 30 | self.hidden_size = d_model // self.num_directions 31 | self.d_enc_model = d_model 32 | 33 | super(RNNEncoder, self).__init__() 34 | 35 | self.word_emb = nn.Embedding(n_vocab, d_word_vec, padding_idx=Constants.PAD) 36 | input_size = d_word_vec 37 | 38 | self.feature = False if not feat_vocab else True 39 | if self.feature: 40 | self.feat_embs = nn.ModuleList([ 41 | nn.Embedding(n_f_vocab, d_feat_vec, padding_idx=Constants.PAD) for n_f_vocab in feat_vocab 42 | ]) 43 | input_size += len(feat_vocab) * d_feat_vec 44 | 45 | self.slf_attn = slf_attn 46 | if slf_attn: 47 | self.gated_slf_attn = GatedSelfAttention(d_model) 48 | 49 | if rnn == 'lstm': 50 | self.rnn = nn.LSTM(input_size, self.hidden_size, num_layers=n_layer, 51 | dropout=dropout, bidirectional=brnn, batch_first=True) 52 | elif rnn == 'gru': 53 | self.rnn = nn.GRU(input_size, self.hidden_size, num_layers=n_layer, 54 | dropout=dropout, bidirectional=brnn, batch_first=True) 55 | else: 56 | raise ValueError("Only support 'LSTM' and 'GRU' for RNN-based Encoder ") 57 | 58 | @classmethod 59 | def from_opt(cls, opt): 60 | return cls(opt['n_vocab'], opt['d_word_vec'], opt['d_model'], opt['n_layer'], 61 | opt['brnn'], opt['rnn'], opt['feat_vocab'], opt['d_feat_vec'], 62 | opt['slf_attn'], opt['dropout']) 63 | 64 | def forward(self, inputs): 65 | src_seq, lengths, feat_seqs = inputs['src_seq'], inputs['lengths'], inputs['feat_seqs'] 66 | lengths = torch.LongTensor(lengths.data.view(-1).tolist()) 67 | 68 | enc_input = self.word_emb(src_seq) 69 | if self.feature: 70 | feat_outputs = [feat_emb(feat_seq) for feat_seq, feat_emb in zip(feat_seqs, self.feat_embs)] 71 | feat_outputs = torch.cat(feat_outputs, dim=2) 72 | enc_input = torch.cat((enc_input, feat_outputs), dim=-1) 73 | 74 | enc_input = pack(enc_input, lengths, batch_first=True, enforce_sorted=False) 75 | enc_output, hidden = self.rnn(enc_input, None) 76 | enc_output = unpack(enc_output, batch_first=True)[0] 77 | 78 | if self.slf_attn: 79 | mask = (src_seq == Constants.PAD).byte() 80 | enc_output = self.gated_slf_attn(enc_output, mask) 81 | 82 | # try: 83 | # mask = (src_seq == Constants.PAD).byte() 84 | # mask = mask.unsqueeze(2).repeat(1, 1, 300).float() 85 | # hidden = torch.sum(enc_input * mask, dim=1) 86 | # enc_output = enc_input 87 | # denominator = lengths.unsqueeze(1).repeat(1, 300).float().to(hidden.device) 88 | # hidden = hidden / denominator 89 | # except: 90 | # print(enc_input.size(), mask.size(), length.size()) 91 | 92 | return enc_output, hidden 93 | 94 | 95 | class GraphEncoder(nn.Module): 96 | """Combine GGNN (Gated Graph Neural Network) and GAT (Graph Attention Network) 97 | Input: (1) nodes - [batch_size, node_num, d_model] 98 | (2) edges - ([batch_size, node_num * node_num], [batch_size, node_num * node_num]) 1st-inlink, 2nd-outlink 99 | (3) mask - ([batch_size, node_num, node_num], [batch_size, node_num, node_num]) 1st-inlink, 2nd-outlink 100 | (4) node_feats - list of [batch_size, node_num] 101 | """ 102 | def __init__(self, n_edge_type, d_model, n_layer, alpha, d_feat_vec, 103 | feat_vocab, layer_attn, dropout, attn_dropout): 104 | self.name = 'graph' 105 | super(GraphEncoder, self).__init__() 106 | self.layer_attn = layer_attn 107 | 108 | self.hidden_size = d_model 109 | self.d_model = d_model 110 | ###=== node features ===### 111 | self.feature = True if feat_vocab else False 112 | if self.feature: 113 | self.feat_embs = nn.ModuleList([ 114 | nn.Embedding(n_f_vocab, d_feat_vec, padding_idx=Constants.PAD) for n_f_vocab in feat_vocab 115 | ]) 116 | 
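            # Node feature embeddings are concatenated with the node vectors in forward()
            # and projected back to hidden_size by feature_transform; the alternative of
            # widening hidden_size instead is left commented out below.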
#self.hidden_size += d_feat_vec * len(feat_vocab) 117 | self.feature_transform = nn.Linear(self.hidden_size + d_feat_vec * len(feat_vocab), self.hidden_size) 118 | ###=== edge embedding ===### 119 | # self.edge_in_emb = nn.Embedding(n_edge_type, self.hidden_size * d_model, padding_idx=Constants.PAD) 120 | # self.edge_out_emb = nn.Embedding(n_edge_type, self.hidden_size * d_model, padding_idx=Constants.PAD) 121 | # self.edge_bias = edge_bias 122 | # if edge_bias: 123 | # self.edge_in_emb_bias = nn.Embedding(n_edge_type, d_model, padding_idx=Constants.PAD) 124 | # self.edge_out_emb_bias = nn.Embedding(n_edge_type, d_model, padding_idx=Constants.PAD) 125 | ###=== graph encode layers===### 126 | self.layer_stack = nn.ModuleList([ 127 | GraphEncoderLayer(self.hidden_size, d_model, alpha, feature=self.feature, 128 | dropout=dropout, attn_dropout=attn_dropout) for _ in range(n_layer) 129 | ]) 130 | ###=== gated output ===### 131 | self.gate = nn.Linear(2 * d_model, d_model, bias=False) 132 | 133 | @classmethod 134 | def from_opt(cls, opt): 135 | return cls(opt['n_edge_type'], opt['d_model'], opt['n_layer'], opt['alpha'], 136 | opt['d_feat_vec'], opt['feat_vocab'], opt['layer_attn'], 137 | opt['dropout'], opt['attn_dropout']) 138 | 139 | def gated_output(self, outputs, inputs): 140 | concatenation = torch.cat((outputs, inputs), dim=2) 141 | g_t = torch.sigmoid(self.gate(concatenation)) 142 | 143 | output = g_t * outputs + (1 - g_t) * inputs 144 | return output 145 | 146 | def forward(self, inputs): 147 | nodes, mask = inputs['nodes'], inputs['mask'] 148 | node_feats, node_type = inputs['feat_seqs'], inputs['type'] 149 | nodes = self.activate(nodes) 150 | node_output = nodes # batch_size x node_num x d_model 151 | ###=== get embeddings ===### 152 | feat_hidden = None 153 | if self.feature: 154 | feat_hidden = [feat_emb(node_feat) for node_feat, feat_emb in zip(node_feats, self.feat_embs)] 155 | feat_hidden = torch.cat(feat_hidden, dim=2) # batch_size x node_num x (hidden_size - d_model) 156 | node_output = self.feature_transform(torch.cat((node_output, feat_hidden), dim=-1)) 157 | # batch_size x (node_num * node_num) x hidden_size x d_model 158 | # edge_in_hidden = self.edge_in_emb(edges[0]).view(nodes.size(0), -1, self.hidden_size, nodes.size(2)) 159 | # edge_out_hidden = self.edge_out_emb(edges[1]).view(nodes.size(0), -1, self.hidden_size, nodes.size(2)) 160 | # edge_hidden = (edge_in_hidden, edge_out_hidden) 161 | # if self.edge_bias: 162 | # # batch_size x (node_num * node_num) x d_model 163 | # edge_in_hidden_bias, edge_out_hidden_bias = self.edge_in_emb_bias(edges[0]), self.edge_out_emb_bias(edges[1]) 164 | # edge_hidden_bias = (edge_in_hidden_bias, edge_out_hidden_bias) if self.edge_bias else None 165 | ##=== forward ===### 166 | node_outputs = [] 167 | for enc_layer in self.layer_stack: 168 | # node_output = enc_layer(node_output, edge_hidden, mask, feat_hidden=feat_hidden, 169 | # edge_hidden_bias=edge_hidden_bias) 170 | node_output = enc_layer(node_output, mask, node_type, feat_hidden=feat_hidden) 171 | node_outputs.append(node_output) 172 | 173 | node_output = self.gated_output(node_output, nodes) 174 | node_outputs[-1] = node_output 175 | 176 | hidden = [layer_output.transpose(0, 1)[0] for layer_output in node_outputs] 177 | 178 | if self.layer_attn: 179 | node_output = node_outputs 180 | 181 | return node_output, hidden 182 | 183 | 184 | class SparseGraphEncoder(nn.Module): 185 | """Sparse version of Graph Encoder""" 186 | """Combine GGNN (Gated Graph Neural Network) and GAT (Graph 
Attention Network) 187 | Input: (1) nodes - [batch_size, node_num, d_model] 188 | (2) edges - [edge_type_num] 189 | (3) mask - ([batch_size, node_num, node_num], [batch_size, node_num, node_num]) 1st-inlink, 2nd-outlink 190 | (4) node_feats - list of [batch_size, node_num] 191 | (5) adjacent_matrix - 2 * [batch_size, real_node_num, real_neighbor_num] 1st-inlink, 2nd-outlink 192 | """ 193 | def __init__(self, n_edge_type, d_model, d_rnn_enc_model, n_layer, alpha, d_feat_vec, 194 | feat_vocab, edge_bias, layer_attn, dropout, attn_dropout): 195 | self.name = 'graph' 196 | super(SparseGraphEncoder, self).__init__() 197 | self.layer_attn = layer_attn 198 | 199 | self.hidden_size = d_model 200 | self.d_model = d_model 201 | ###=== node features ===### 202 | self.feature = True if feat_vocab else False 203 | if self.feature: 204 | self.feat_embs = nn.ModuleList([ 205 | nn.Embedding(n_f_vocab, d_feat_vec, padding_idx=Constants.PAD) for n_f_vocab in feat_vocab 206 | ]) 207 | self.hidden_size += d_feat_vec * len(feat_vocab) 208 | ###=== edge embedding ===### 209 | self.edge_in_emb = nn.Embedding(n_edge_type, self.hidden_size * d_model, padding_idx=Constants.PAD) 210 | self.edge_out_emb = nn.Embedding(n_edge_type, self.hidden_size * d_model, padding_idx=Constants.PAD) 211 | self.edge_bias = edge_bias 212 | if edge_bias: 213 | self.edge_in_emb_bias = nn.Embedding(n_edge_type, d_model, padding_idx=Constants.PAD) 214 | self.edge_out_emb_bias = nn.Embedding(n_edge_type, d_model, padding_idx=Constants.PAD) 215 | ###=== graph encode layers===### 216 | self.layer_stack = nn.ModuleList([ 217 | SparseGraphEncoderLayer(self.hidden_size, d_model, alpha, edge_bias=edge_bias, feature=self.feature, 218 | dropout=dropout, attn_dropout=attn_dropout) for _ in range(n_layer) 219 | ]) 220 | ###=== gated output ===### 221 | self.gate = nn.Linear(d_model * 2, d_model, bias=False) 222 | 223 | @classmethod 224 | def from_opt(cls, opt): 225 | return cls(opt['n_edge_type'], opt['d_model'], opt['d_rnn_enc_model'], opt['n_layer'], opt['alpha'], 226 | opt['d_feat_vec'], opt['feat_vocab'], opt['edge_bias'], opt['layer_attn'], 227 | opt['dropout'], opt['attn_dropout']) 228 | 229 | def gated_output(self, outputs, inputs): 230 | concatenation = torch.cat((outputs, inputs), dim=2) 231 | g_t = torch.sigmoid(self.gate(concatenation)) 232 | 233 | output = g_t * outputs + (1 - g_t) * inputs 234 | return output 235 | 236 | def forward(self, inputs): 237 | nodes, edges, mask = inputs['nodes'], inputs['edges'], inputs['mask'] 238 | node_feats, adjacent_matrix = inputs['feat_seqs'], inputs['adjacent_matrix'] 239 | nodes = self.activate(nodes) 240 | node_output = nodes # batch_size x node_num x d_model 241 | ###=== get embeddings ===### 242 | feat_hidden = None 243 | if self.feature: 244 | feat_hidden = [feat_emb(node_feat) for node_feat, feat_emb in zip(node_feats, self.feat_embs)] 245 | feat_hidden = torch.cat(feat_hidden, dim=2) # batch_size x node_num x (hidden_size - d_model) 246 | # batch_size x (node_num * node_num) x hidden_size x d_model 247 | edge_in_hidden = self.edge_in_emb(edges).view(-1, self.hidden_size, nodes.size(2)) 248 | edge_out_hidden = self.edge_out_emb(edges).view(-1, self.hidden_size, nodes.size(2)) 249 | edge_hidden = (edge_in_hidden, edge_out_hidden) 250 | if self.edge_bias: 251 | # batch_size x (node_num * node_num) x d_model 252 | edge_in_hidden_bias, edge_out_hidden_bias = self.edge_in_emb_bias(edges), self.edge_out_emb_bias(edges) 253 | edge_hidden_bias = (edge_in_hidden_bias, edge_out_hidden_bias) if 
self.edge_bias else None 254 | ###=== forward ===### 255 | node_outputs = [] 256 | for enc_layer in self.layer_stack: 257 | node_output = enc_layer(node_output, edge_hidden, mask, adjacent_matrix, 258 | feat_hidden=feat_hidden, edge_hidden_bias=edge_hidden_bias) 259 | node_outputs.append(node_output) 260 | 261 | node_output = self.gated_output(node_output, nodes) 262 | node_outputs[-1] = node_output 263 | 264 | hidden = [layer_output.transpose(0, 1)[0] for layer_output in node_outputs] 265 | 266 | if self.layer_attn: 267 | node_output = node_outputs 268 | 269 | return node_output, hidden 270 | 271 | 272 | class EncoderTransformer(nn.Module): 273 | """Transform RNN-Encoder's output to Graph-Encoder's input 274 | Input: seq_output - [batch_size, seq_length, rnn_enc_dim] (tensor) 275 | root_list - [batch_size, node_num] (list) 276 | indexes_list - [batch_size, node_num, index_num] (list) 277 | node_sizes - [batch_size * node_num, 1] (list) 278 | """ 279 | def __init__(self, d_model, d_k=64, device=None): 280 | super(EncoderTransformer, self).__init__() 281 | self.device = device 282 | self.d_k = d_k 283 | 284 | self.attn = ConcatAttention(d_model, d_model, d_k) 285 | 286 | def forward(self, inputs, max_length): 287 | 288 | def pad(vectors, data_length, max_length=None): 289 | hidden_size = (max_length, vectors.size(1)) 290 | out = torch.zeros(hidden_size, device=self.device) 291 | out.narrow(0, 0, data_length).copy_(vectors) # bag_size x rnn_enc_dim 292 | return out 293 | 294 | seq_output, hidden, indexes_list = inputs['seq_output'], inputs['hidden'], inputs['index'] 295 | 296 | if isinstance(hidden, tuple) or isinstance(hidden, list) or hidden.dim() == 3: 297 | hidden = [h for h in hidden] 298 | hidden = torch.cat(hidden, dim=1) 299 | hidden = hidden.contiguous().view(hidden.size(0), -1) 300 | 301 | # root_list = inputs['root'] 302 | node_sizes, node_lengths = inputs['lengths'], inputs['node_lengths'] 303 | max_length = max(node_sizes) 304 | ##===== prepare vectors (do padding) =====## 305 | roots, bags, cnt = [], [], 0 306 | # for sample_idx, sample in enumerate(zip(root_list, indexes_list)): 307 | for sample_idx, indexes in enumerate(indexes_list): 308 | # root, indexes = sample[0], sample[1] 309 | # for root_idx, indexes_idx in zip(root, indexes): 310 | for indexes_idx in indexes: 311 | # roots.append(seq_output[sample_idx][root_idx]) 312 | roots.append(hidden[sample_idx]) 313 | bag = pad(torch.stack([seq_output[sample_idx][idx] for idx in indexes_idx], dim=0), 314 | node_sizes[cnt], max_length) 315 | bags.append(bag) 316 | cnt += 1 317 | roots = torch.stack(roots, dim=0) # all_node_num x rnn_enc_dim 318 | bags = torch.stack(bags, dim=0) # all_node_num x bag_size x rnn_enc_dim 319 | ##===== cross attention =====## 320 | context, *_ = self.attn(roots, bags) # all_node_num x rnn_enc_dim 321 | ##===== get node vectors =====## 322 | max_length = max(node_lengths) 323 | nodes = [] 324 | for node_length in node_lengths: 325 | nodes.append(pad(context[:node_length], node_length, max_length)) 326 | context = context[node_length:] 327 | 328 | nodes = torch.stack(nodes, dim=0) # batch_size x node_num x d_model 329 | 330 | return nodes, hidden 331 | 332 | 333 | class TransfEncoder(nn.Module): 334 | """ 335 | Input: (1) inputs['src_seq'] 336 | (2) inputs['src_pos'] 337 | (3) inputs['feat_seqs'] 338 | Output: (1) enc_output 339 | (2) hidden 340 | """ 341 | def __init__(self, n_vocab, pretrained=None, model_name='default', layer_attn=False): 342 | self.name = 'transf' 343 | self.model_type = model_name 
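        # Only BERT is wired up here: from_opt loads a pretrained BertModel, and the
        # hidden size (768), head number, and layer number are hard-coded below;
        # other pretrained encoders raise a ValueError in from_opt.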
344 | 345 | super(TransfEncoder, self).__init__() 346 | 347 | self.layer_attn = layer_attn 348 | self.pretrained = pretrained 349 | 350 | if model_name == 'bert': 351 | self.d_enc_model = 768 352 | self.d_head = 8 353 | self.n_enc_layer = 12 354 | 355 | @classmethod 356 | def from_opt(cls, opt): 357 | if opt['pretrained'].count('bert'): 358 | pretrained = BertModel.from_pretrained(opt['pretrained']) 359 | return cls(opt['n_vocab'], pretrained=pretrained, layer_attn=opt['layer_attn'], model_name='bert') 360 | else: 361 | raise ValueError("Other pretrained models haven't been supported yet") 362 | 363 | def forward(self, inputs, return_attns=False): 364 | src_seq = inputs['src_seq'] 365 | 366 | if self.model_type == 'bert': 367 | enc_outputs, *_ = self.pretrained(src_seq, output_all_encoded_layers=True) 368 | enc_output = enc_outputs[-1] 369 | 370 | hidden = [layer_output.transpose(0, 1)[0] for layer_output in enc_outputs] 371 | if self.layer_attn: 372 | enc_output = enc_outputs 373 | 374 | return enc_output, hidden -------------------------------------------------------------------------------- /src/onqg/models/Models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | import onqg.dataset.Constants as Constants 6 | from onqg.models.modules.Attention import ConcatAttention 7 | 8 | 9 | class UnifiedModel(nn.Module): 10 | ''' Unify Sequence-Encoder and Graph-Encoder 11 | 12 | Input: seq-encoder: src_seq, lengths, feat_seqs 13 | graph-encoder: edges 14 | encoder-transform: index, lengths, root 15 | decoder: tgt_seq, src_seq, feat_seqs 16 | answer-encoder: src_seq, lengths, feat_seqs 17 | 18 | Output: results output from the Decoder (type: dict) 19 | ''' 20 | def __init__(self, model_type, seq_encoder, graph_encoder, encoder_transformer, 21 | decoder, decoder_transformer): 22 | super(UnifiedModel, self).__init__() 23 | 24 | self.model_type = model_type 25 | 26 | self.seq_encoder = seq_encoder 27 | 28 | self.encoder_transformer = encoder_transformer 29 | self.graph_encoder = graph_encoder 30 | 31 | self.decoder_transformer = decoder_transformer 32 | self.decoder = decoder 33 | 34 | def forward(self, inputs, max_length=None): 35 | #========== forward ==========# 36 | ## RNN encode ## 37 | seq_output, hidden = self.seq_encoder(inputs['seq-encoder']) 38 | ## encoder transform ## 39 | inputs['encoder-transform']['seq_output'] = seq_output 40 | inputs['encoder-transform']['hidden'] = hidden 41 | node_input, hidden = self.encoder_transformer(inputs['encoder-transform'], max_length) 42 | ## graph encode ## 43 | inputs['graph-encoder']['nodes'] = node_input 44 | node_output, _ = self.graph_encoder(inputs['graph-encoder']) 45 | 46 | outputs = {} 47 | 48 | #========== classify =========# 49 | if self.model_type != 'generate': 50 | scores = self.classifier(node_output) if not self.decoder.layer_attn else self.classifier(node_output[-1]) 51 | inputs['decoder-transform']['scores'] = scores 52 | outputs['classification'] = scores 53 | #========== generate =========# 54 | inputs['decoder-transform']['graph_output'] = node_output 55 | inputs['decoder-transform']['seq_output'] = seq_output 56 | inputs['decoder-transform']['hidden'] = hidden 57 | inputs['decoder']['enc_output'], inputs['decoder']['scores'], hidden = self.decoder_transformer(inputs['decoder-transform']) 58 | inputs['decoder']['hidden'] = hidden 59 | dec_output = self.decoder(inputs['decoder']) 60 | outputs['generation'] = 
dec_output 61 | #========== generate =========# 62 | if self.model_type != 'classify': 63 | outputs['generation']['pred'] = self.generator(dec_output['pred']) 64 | 65 | return outputs 66 | -------------------------------------------------------------------------------- /src/onqg/models/modules/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | ''' Scaled Dot-Product Attention ''' 9 | 10 | def __init__(self, temperature, attn_dropout=0.1): 11 | super().__init__() 12 | self.temperature = temperature 13 | self.dropout = nn.Dropout(attn_dropout) 14 | self.softmax = nn.Softmax(dim=2) 15 | 16 | def forward(self, q, k, v, mask=None): 17 | attn = torch.bmm(q, k.transpose(1, 2)) 18 | attn = attn / self.temperature 19 | 20 | attn = self.softmax(attn) 21 | 22 | if mask is not None: 23 | attn = attn.masked_fill(mask, 0) 24 | sumattn = torch.sum(attn, dim=2, keepdim=True) + 1e-8 25 | attn = attn / sumattn 26 | 27 | attn = self.dropout(attn) 28 | output = torch.bmm(attn, v) 29 | 30 | return output, attn 31 | 32 | 33 | class ConcatAttention(nn.Module): 34 | def __init__(self, attend_dim, query_dim, att_dim, is_coverage=False): 35 | super(ConcatAttention, self).__init__() 36 | 37 | self.attend_dim = attend_dim 38 | self.query_dim = query_dim 39 | self.att_dim = att_dim 40 | 41 | self.linear_pre = nn.Linear(attend_dim, att_dim, bias=True) 42 | self.linear_q = nn.Linear(query_dim, att_dim, bias=False) 43 | self.linear_v = nn.Linear(att_dim, 1, bias=False) 44 | 45 | self.sftmax = nn.Softmax(dim=1) 46 | self.tanh = nn.Tanh() 47 | 48 | self.mask = None 49 | 50 | self.is_coverage = is_coverage 51 | if is_coverage: 52 | self.linear_cov = nn.Linear(1, att_dim, bias=False) 53 | 54 | def apply_mask(self, mask): 55 | self.mask = mask 56 | 57 | def forward(self, input, context, precompute=None, coverage=None, feat_inputs=None, feature=False): 58 | """ 59 | input: batch x dim 60 | context: batch x sourceL x dim 61 | """ 62 | enc_output = torch.cat((context, feat_inputs), dim=2) if feature else context 63 | if precompute is None: 64 | precompute = self.linear_pre(enc_output) # batch x sourceL x att_dim 65 | targetT = self.linear_q(input).unsqueeze(1) # batch x 1 x att_dim 66 | 67 | tmp_sum = precompute + targetT.repeat(1, precompute.size(1), 1) # batch x sourceL x att_dim 68 | 69 | if self.is_coverage: 70 | weighted_coverage = self.linear_cov(coverage.unsqueeze(2)) # batch x sourceL x att_dim 71 | tmp_sum += weighted_coverage 72 | 73 | tmp_activated = self.tanh(tmp_sum) # batch x sourceL x att_dim 74 | energy = self.linear_v(tmp_activated).view(tmp_sum.size(0), tmp_sum.size(1)) # batch x sourceL 75 | if self.mask is not None: 76 | energy = energy * (1 - self.mask) + self.mask * (-1000000) 77 | 78 | score = self.sftmax(energy) # batch x sourceL 79 | 80 | weightedContext = torch.bmm(score.unsqueeze(1), context).squeeze(1) # batch x dim 81 | 82 | if self.is_coverage: 83 | coverage = coverage + score # batch x sourceL 84 | return weightedContext, score, precompute, coverage 85 | 86 | return weightedContext, score, precompute 87 | 88 | 89 | class GatedSelfAttention(nn.Module): 90 | def __init__(self, dim, attn_dim=64, dropout=0.1): 91 | super(GatedSelfAttention, self).__init__() 92 | 93 | self.m_translate = nn.Linear(dim, attn_dim) 94 | self.q_translate = nn.Linear(dim, attn_dim) 95 | 96 | self.update = nn.Linear(2 * dim, dim, bias=False) 97 
| 98 | self.gate = nn.Linear(2 * dim, dim, bias=False) 99 | 100 | if dropout > 0: 101 | self.dropout = nn.Dropout(dropout) 102 | self.has_dropout = True if dropout > 0 else False 103 | 104 | def forward(self, query, mask): 105 | raw = query 106 | 107 | memory = self.m_translate(query) # b_sz x src_len x 64 108 | query = self.q_translate(query) 109 | 110 | energy = torch.bmm(query, memory.transpose(1, 2)) # b_sz x src_len x src_len 111 | energy = energy.masked_fill(mask, value=-1e12) 112 | 113 | score = torch.softmax(energy, dim=2) 114 | if self.has_dropout: 115 | score = self.dropout(score) 116 | context = torch.bmm(score, raw) 117 | 118 | inputs = torch.cat((raw, context), dim=2) 119 | 120 | f_t = torch.tanh(self.update(inputs)) 121 | g_t = torch.sigmoid(self.gate(inputs)) 122 | 123 | output = g_t * f_t + (1 - g_t) * raw 124 | 125 | return output, score 126 | 127 | 128 | class GraphAttention(nn.Module): 129 | def __init__(self, d_q, d_v, alpha, dropout=0.1): 130 | super(GraphAttention, self).__init__() 131 | self.dropout = nn.Dropout(dropout) 132 | 133 | self.attention = nn.Linear(d_q + d_v, 1) 134 | self.leaky_relu = nn.LeakyReLU(alpha) 135 | 136 | def forward(self, query, value, mask): 137 | """ 138 | query - [batch_size, node_num * node_num, d_hidden] 139 | value - [batch_size, node_num * node_num, d_model] 140 | mask - [batch_size, node_num, node_num] 141 | """ 142 | node_num = int(query.size(1) **0.5) 143 | query = query.view(-1, node_num, query.size(2)) 144 | value = value.view(-1, node_num, value.size(2)) # (batch_size * node_num) x node_num x d_model 145 | 146 | pre_attention = torch.cat([query, value], dim=2) 147 | energy = self.leaky_relu(self.attention(pre_attention).squeeze(2)) # (batch_size * node_num) x node_num 148 | # energy = torch.ones((query.size(0) * node_num, node_num)).to(value.device) 149 | 150 | mask = mask.view(-1, node_num) 151 | zero_vec = -9e15 * torch.ones_like(energy) 152 | try: 153 | attention = torch.where(mask > 0, energy, zero_vec) 154 | except: 155 | print(mask.size(), zero_vec.size(), energy.size(), query.size(), value.size()) 156 | 157 | scores = torch.softmax(attention, dim=1) # (batch_size * node_num) x node_num 158 | scores = self.dropout(scores) 159 | 160 | value = torch.bmm(scores.unsqueeze(1), value).squeeze(1) # (batch_size * node_num) x d_model 161 | value = value.view(-1, node_num, value.size(-1)) 162 | 163 | return value 164 | -------------------------------------------------------------------------------- /src/onqg/models/modules/DecAssist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DecInit(nn.Module): 6 | def __init__(self, d_enc, d_dec, n_enc_layer): 7 | self.d_enc_model = d_enc 8 | self.n_enc_layer = n_enc_layer 9 | self.d_dec_model = d_dec 10 | 11 | super(DecInit, self).__init__() 12 | 13 | self.initer = nn.Linear(self.d_enc_model * self.n_enc_layer, self.d_dec_model) 14 | self.tanh = nn.Tanh() 15 | 16 | def forward(self, hidden): 17 | if isinstance(hidden, tuple) or isinstance(hidden, list) or hidden.dim() == 3: 18 | hidden = [h for h in hidden] 19 | hidden = torch.cat(hidden, dim=1) 20 | hidden = hidden.contiguous().view(hidden.size(0), -1) 21 | return self.tanh(self.initer(hidden)) 22 | 23 | 24 | class StackedRNN(nn.Module): 25 | def __init__(self, num_layers, input_size, rnn_size, dropout, rnn='lstm'): 26 | self.dropout = dropout 27 | self.num_layers = num_layers 28 | 29 | super(StackedRNN, self).__init__() 30 | 31 | self.layers = 
nn.ModuleList() 32 | self.name = rnn 33 | 34 | for _ in range(num_layers): 35 | if rnn == 'lstm': 36 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 37 | elif rnn == 'gru': 38 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 39 | else: 40 | raise ValueError("Supported StackedRNN: LSTM, GRU") 41 | input_size = rnn_size 42 | 43 | def forward(self, inputs, hidden): 44 | if self.name == 'lstm': 45 | h_0, c_0 = hidden 46 | elif self.name == 'gru': 47 | h_0 = hidden 48 | h_1, c_1 = [], [] 49 | 50 | for i, layer in enumerate(self.layers): 51 | if self.name == 'lstm': 52 | h_1_i, c_1_i = layer(inputs, (h_0[i], c_0[i])) 53 | elif self.name == 'gru': 54 | h_1_i = layer(inputs, h_0[i]) 55 | inputs = h_1_i 56 | if i + 1 != self.num_layers: 57 | inputs = self.dropout(inputs) 58 | h_1.append(h_1_i) 59 | if self.name == 'lstm': 60 | c_1.append(c_1_i) 61 | 62 | h_1 = torch.stack(h_1) 63 | if self.name == 'lstm': 64 | c_1 = torch.stack(c_1) 65 | h_1 = (h_1, c_1) 66 | 67 | return inputs, h_1 68 | 69 | -------------------------------------------------------------------------------- /src/onqg/models/modules/Layers.py: -------------------------------------------------------------------------------- 1 | ''' Define the Layers ''' 2 | import torch 3 | import torch.nn as nn 4 | from onqg.models.modules.SubLayers import MultiHeadAttention, PositionwiseFeedForward, Propagator 5 | from onqg.models.modules.Attention import GatedSelfAttention, GraphAttention 6 | import onqg.dataset.Constants as Constants 7 | 8 | 9 | class EncoderLayer(nn.Module): 10 | ''' Compose with two layers ''' 11 | 12 | def __init__(self, d_model, slf_attn, d_inner, n_head, d_k, d_v, 13 | dropout=0.1, attn_dropout=0.1): 14 | super(EncoderLayer, self).__init__() 15 | self.slf_attn = slf_attn 16 | if slf_attn == 'multi-head': 17 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout, attn_dropout=attn_dropout) 18 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 19 | else: 20 | self.gated_slf_attn = GatedSelfAttention(d_model, d_k, dropout=attn_dropout) 21 | 22 | def forward(self, enc_input, src_seq, non_pad_mask=None, slf_attn_mask=None, layer_id=-1): 23 | if self.slf_attn == 'gated': 24 | mask = (src_seq == Constants.PAD).unsqueeze(2) if slf_attn_mask is None else slf_attn_mask 25 | enc_output, enc_slf_attn = self.gated_slf_attn(enc_input, mask) 26 | else: 27 | enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask) 28 | enc_output *= non_pad_mask 29 | 30 | enc_output = self.pos_ffn(enc_output) 31 | enc_output *= non_pad_mask 32 | 33 | return enc_output, enc_slf_attn 34 | 35 | 36 | class GraphEncoderLayer(nn.Module): 37 | '''GGNN & GAT Layer''' 38 | def __init__(self, d_hidden, d_model, alpha, feature=False, dropout=0.1, attn_dropout=0.1): 39 | super(GraphEncoderLayer, self).__init__() 40 | self.d_hidden = d_hidden 41 | self.d_model = d_model 42 | self.feature = feature 43 | 44 | self.edge_num = 3 # TODO: fix this magic number 45 | bias_list = [False, False, False] 46 | self.edge_in_list = nn.ModuleList([nn.Linear(d_hidden, d_model, bias=bias_list[i]) for i in range(self.edge_num)]) 47 | self.edge_out_list = nn.ModuleList([nn.Linear(d_hidden, d_model, bias=bias_list[i]) for i in range(self.edge_num)]) 48 | # self.edge_in_emb = nn.Linear(d_hidden, d_model) 49 | # self.edge_out_emb = nn.Linear(d_hidden, d_model) 50 | 51 | self.graph_in_attention = GraphAttention(d_hidden, d_model, alpha, dropout=attn_dropout) 52 | self.graph_out_attention = 
GraphAttention(d_hidden, d_model, alpha, dropout=attn_dropout) 53 | self.output_gate = Propagator(d_model, dropout=dropout) 54 | 55 | def forward(self, nodes, mask, node_type, feat_hidden=None): 56 | ###=== concatenation ===### 57 | node_hidden = nodes # batch_size x node_num x d_model 58 | # if self.feature: 59 | # node_hidden = torch.cat((node_hidden, feat_hidden), dim=-1) # batch_size x node_num x d_hidden 60 | ###=== transform using edge matrix ===### 61 | in_masks = [(node_type == tag).float().unsqueeze(2).repeat(1, 1, self.d_model).to(nodes.device) 62 | for tag in range(2, 2 + self.edge_num)] 63 | node_in_hidden = torch.sum(torch.stack([in_emb(node_hidden) * in_masks[idx] 64 | for idx, in_emb in enumerate(self.edge_in_list)], dim=0), dim=0) 65 | out_masks = [(node_type == tag).float().unsqueeze(2).repeat(1, 1, self.d_model).to(nodes.device) 66 | for tag in range(2, 2 + self.edge_num)] 67 | node_out_hidden = torch.sum(torch.stack([out_emb(node_hidden) * out_masks[idx] 68 | for idx, out_emb in enumerate(self.edge_out_list)], dim=0), dim=0) # batch_size x node_num x d_model 69 | # node_in_hidden = self.edge_in_emb(node_hidden) 70 | # node_out_hidden = self.edge_out_emb(node_hidden) 71 | ###=== graph attention ===### 72 | node_hidden = node_hidden.unsqueeze(2).repeat(1, 1, nodes.size(1), 1).view(nodes.size(0), -1, self.d_hidden) 73 | node_in_hidden = self.graph_in_attention(node_hidden, node_in_hidden.repeat(1, nodes.size(1), 1), mask[0]) 74 | node_out_hidden = self.graph_out_attention(node_hidden, node_out_hidden.repeat(1, nodes.size(1), 1), mask[1]) 75 | ###=== gated recurrent unit ===### 76 | node_output = self.output_gate(nodes, node_in_hidden, node_out_hidden) 77 | 78 | return node_output 79 | 80 | 81 | class SparseGraphEncoderLayer(nn.Module): 82 | '''Sparse GGNN & GAT Layer''' 83 | def __init__(self, d_hidden, d_model, alpha, edge_bias=False, feature=False, dropout=0.1, attn_dropout=0.1): 84 | super(SparseGraphEncoderLayer, self).__init__() 85 | self.edge_bias = edge_bias 86 | self.d_hidden = d_hidden 87 | self.d_model = d_model 88 | self.feature = feature 89 | 90 | self.graph_in_attention = GraphAttention(d_hidden, d_model, alpha, dropout=attn_dropout) 91 | self.graph_out_attention = GraphAttention(d_hidden, d_model, alpha, dropout=attn_dropout) 92 | self.output_gate = Propagator(d_model, dropout=dropout) 93 | 94 | def forward(self, nodes, edges, mask, adjacent_matrixes, feat_hidden=None, 95 | edge_hidden_bias=None): 96 | ###=== concatenation ===### 97 | node_hidden = nodes 98 | if self.feature: 99 | node_hidden = torch.cat((node_hidden, feat_hidden), dim=-1) # batch_size x node_num x d_hidden 100 | ###=== forward ===### 101 | node_in_hidden = torch.zeros((nodes.size(0), nodes.size(1), nodes.size(1), self.d_model)).to(nodes.device) 102 | node_out_hidden = torch.zeros((nodes.size(0), nodes.size(1), nodes.size(1), self.d_model)).to(nodes.device) 103 | in_matrixes, out_matrixes = adjacent_matrixes[0], adjacent_matrixes[1] 104 | 105 | for sample_id, data in enumerate(zip(node_hidden, in_matrixes, out_matrixes)): 106 | sample, in_matrix, out_matrix = data[0], data[1], data[2] 107 | # in/out_matrix - [real_node_num, real_neighbor_num] 108 | for idx, indexes in enumerate(zip(in_matrix, out_matrix)): 109 | in_index, out_index = indexes[0], indexes[1] 110 | # in/out_index - [real_neighbor_num] 111 | for wrap in in_index: 112 | vector = torch.matmul(sample[wrap[0]], edges[0][wrap[1]]) # [d_model] 113 | node_in_hidden[sample_id][idx].narrow(0, wrap[0], 1).copy_(vector) 114 | for wrap in 
out_index: 115 | vector = torch.matmul(sample[wrap[0]], edges[1][wrap[1]]) # [d_model] 116 | node_out_hidden[sample_id][idx].narrow(0, wrap[0], 1).copy_(vector) 117 | 118 | ###=== graph-self-attention ===### 119 | node_in_hidden = node_in_hidden.view(nodes.size(0), -1, self.d_model) 120 | node_out_hidden = node_out_hidden.view(nodes.size(0), -1, self.d_model) 121 | node_hidden = node_hidden.unsqueeze(2).repeat(1, 1, nodes.size(1), 1).view(nodes.size(0), -1, self.d_hidden) 122 | node_in_hidden = self.graph_in_attention(node_hidden, node_in_hidden, mask[0]) 123 | node_out_hidden = self.graph_out_attention(node_hidden, node_out_hidden, mask[1]) 124 | ###=== gated recurrent unit ===### 125 | node_output = self.output_gate(nodes, node_in_hidden, node_out_hidden) 126 | 127 | return node_output 128 | -------------------------------------------------------------------------------- /src/onqg/models/modules/MaxOut.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class MaxOut(nn.Module): 5 | def __init__(self, pool_size): 6 | super(MaxOut, self).__init__() 7 | self.pool_size = pool_size 8 | 9 | def forward(self, ipt): 10 | """ 11 | input: 12 | reduce_size: 13 | """ 14 | input_size = list(ipt.size()) 15 | assert input_size[-1] % self.pool_size == 0 16 | output_size = [d for d in input_size] 17 | output_size[-1] = output_size[-1] // self.pool_size 18 | output_size.append(self.pool_size) 19 | last_dim = len(output_size) - 1 20 | ipt = ipt.view(*output_size) 21 | ipt, _ = ipt.max(last_dim, keepdim=True) 22 | output = ipt.squeeze(last_dim) 23 | 24 | return output -------------------------------------------------------------------------------- /src/onqg/models/modules/SubLayers.py: -------------------------------------------------------------------------------- 1 | ''' Define the sublayers in encoder/decoder layer ''' 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from onqg.models.modules.Attention import ScaledDotProductAttention 7 | from onqg.models.modules.MaxOut import MaxOut 8 | 9 | 10 | class MultiHeadAttention(nn.Module): 11 | ''' Multi-Head Attention module ''' 12 | def __init__(self, n_head, d_model, d_k, d_v, addition_input=0, dropout=0.1, attn_dropout=0.1): 13 | super().__init__() 14 | 15 | self.n_head = n_head 16 | self.d_k = d_k 17 | self.d_v = d_v 18 | 19 | self.w_qs = nn.Linear(d_model, n_head * d_k) 20 | self.w_ks = nn.Linear(d_model + addition_input, n_head * d_k) 21 | self.w_vs = nn.Linear(d_model + addition_input, n_head * d_v) 22 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 23 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + addition_input + d_k))) 24 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + addition_input + d_v))) 25 | 26 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), attn_dropout=attn_dropout) 27 | self.layer_norm = nn.LayerNorm(d_model) 28 | 29 | self.fc = nn.Linear(n_head * d_v, d_model) 30 | nn.init.xavier_normal_(self.fc.weight) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | 34 | def forward(self, q, k, v, mask=None): 35 | 36 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 37 | 38 | sz_b, len_q, _ = q.size() 39 | sz_b, len_k, _ = k.size() 40 | sz_b, len_v, _ = v.size() 41 | 42 | residual = q 43 | 44 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 45 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 46 | v = 
self.w_vs(v).view(sz_b, len_v, n_head, d_v) 47 | 48 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 49 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 50 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 51 | 52 | if mask is None: 53 | mask = None 54 | else: 55 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 56 | output, attn = self.attention(q, k, v, mask=mask) 57 | 58 | output = output.view(n_head, sz_b, len_q, d_v) 59 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 60 | 61 | output = self.dropout(self.fc(output)) 62 | output = self.layer_norm(output + residual) 63 | 64 | return output, attn 65 | 66 | class PositionwiseFeedForward(nn.Module): 67 | ''' A two-feed-forward-layer module ''' 68 | 69 | def __init__(self, d_in, d_hid, dropout=0.1): 70 | super().__init__() 71 | 72 | self.onelayer = d_hid == d_in 73 | if self.onelayer: # just to reduce the number of parameters 74 | self.w = nn.Linear(d_in, d_in, bias=False) 75 | self.tanh = nn.Tanh() 76 | else: 77 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 78 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 79 | 80 | self.layer_norm = nn.LayerNorm(d_in) 81 | 82 | self.dropout = nn.Dropout(dropout) 83 | 84 | def forward(self, x): 85 | residual = x 86 | 87 | if self.onelayer: 88 | output = self.w(x) 89 | output = self.tanh(output) 90 | else: 91 | output = x.transpose(1, 2) 92 | output = self.w_2(F.relu(self.w_1(output))) 93 | output = output.transpose(1, 2) # batch_size x seq_length x d_word_vec 94 | 95 | output = self.dropout(output) 96 | output = self.layer_norm(output + residual) 97 | 98 | return output 99 | 100 | 101 | class Propagator(nn.Module): 102 | def __init__(self, state_dim, dropout=0.1): 103 | super(Propagator, self).__init__() 104 | 105 | self.reset_gate = nn.Sequential( 106 | nn.Linear(state_dim * 3, state_dim), 107 | nn.Sigmoid(), 108 | nn.Dropout(dropout) 109 | ) 110 | self.update_gate = nn.Sequential( 111 | nn.Linear(state_dim * 3, state_dim), 112 | nn.Sigmoid(), 113 | nn.Dropout(dropout) 114 | ) 115 | self.transform = nn.Sequential( 116 | nn.Linear(state_dim * 3, state_dim), 117 | nn.Tanh() 118 | ) 119 | 120 | def forward(self, cur_state, in_vec, out_vec): 121 | """ 122 | cur_state - [batch_size, node_num, d_model] 123 | in_vec - [batch_size, node_num, d_model] 124 | out_vec - [batch_size, node_num, d_model] 125 | """ 126 | a = torch.cat([in_vec, out_vec, cur_state], dim=2) 127 | r = self.reset_gate(a) 128 | z = self.update_gate(a) 129 | 130 | joined_input = torch.cat([in_vec, out_vec, r * cur_state], dim=2) 131 | h_hat = self.transform(joined_input) 132 | 133 | output = (1 - z) * cur_state + z * h_hat # batch_size x node_num x d_model 134 | return output -------------------------------------------------------------------------------- /src/onqg/utils/mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | import onqg.dataset.Constants as Constants 5 | 6 | 7 | def get_non_pad_mask(seq): 8 | assert seq.dim() == 2 9 | return seq.ne(Constants.PAD).type(torch.float).unsqueeze(-1) 10 | 11 | 12 | def get_attn_key_pad_mask(seq_k, seq_q): 13 | ''' For masking out the padding part of key sequence. ''' 14 | # Expand to fit the shape of key query attention matrix. 
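    # --- Editor's note: hedged illustration, not part of the original source. ---
    # Assuming Constants.PAD == 0, a toy key batch seq_k = [[5, 7, 0, 0]] (b=1, lk=4)
    # with a query of length lq = 3 gives a padding_mask of shape (1, 3, 4):
    #     [[[0, 0, 1, 1],
    #       [0, 0, 1, 1],
    #       [0, 0, 1, 1]]]
    # i.e. every query position is blocked from attending to the two padded key positions.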
15 | len_q = seq_q.size(1) 16 | padding_mask = seq_k.eq(Constants.PAD) 17 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 18 | 19 | return padding_mask 20 | 21 | 22 | def get_subsequent_mask(seq): 23 | ''' For masking out the subsequent info. ''' 24 | sz_b, len_s = seq.size() 25 | subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), 26 | diagonal=1) 27 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 28 | 29 | return subsequent_mask 30 | 31 | 32 | def get_slf_attn_mask(attn_mask, lengths, device=None): 33 | ''' For masking out according to the given attention matrix ''' 34 | max_length = torch.max(lengths, 0)[0].item() 35 | mask = torch.ones((lengths.size(0), max_length, max_length), device=device, dtype=torch.uint8) 36 | 37 | for idx, sample in enumerate(attn_mask): 38 | seq_len = int(len(sample) **0.5) 39 | sample = sample.view(seq_len, seq_len) 40 | pad_sample = sample if max_length == seq_len else torch.cat((sample, torch.ones((max_length - seq_len, seq_len), 41 | dtype=torch.uint8)), dim=0) 42 | mask[idx].narrow(1, 0, seq_len).copy_(pad_sample) 43 | mask = mask.view(-1, max_length, max_length) 44 | 45 | return mask 46 | 47 | 48 | def get_slf_window_mask(seq, window_size=3, separate=-1): 49 | ''' For masking out the words in distance: 50 | only allow a word to attend to those near to it 51 | 'near' means: within window_size words 52 | ''' 53 | assert window_size >= 0, "Window size cannot be smaller than zero! " 54 | 55 | sz_b, len_s = seq.size() 56 | 57 | slf_window_mask = torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8) 58 | 59 | if separate >= 0: 60 | tmp_seq = [[w.item() for w in sent] for sent in seq] 61 | indexes = [sent.index(separate) for sent in tmp_seq] 62 | else: 63 | for idx in range(len_s): 64 | for i in range(idx - window_size, idx + window_size + 1): 65 | if i >= 0 and i < len_s: 66 | slf_window_mask[idx][i] = 0 67 | 68 | slf_window_mask = slf_window_mask.unsqueeze(0).repeat(sz_b, 1, 1) # b x ls x ls 69 | 70 | if separate >= 0: 71 | for b_idx in range(sz_b): 72 | sep = indexes[b_idx] 73 | for idx in range(len_s): 74 | sep_final = tmp_seq[b_idx].index(separate, sep + 1) 75 | if idx == 0: 76 | for i in range(0, sep_final + 1): 77 | slf_window_mask[b_idx][idx][i] = 0 78 | elif idx == sep: 79 | for i in range(0, sep + 1): 80 | slf_window_mask[b_idx][idx][i] = 0 81 | elif idx == sep_final: 82 | slf_window_mask[b_idx][idx][0] = 0 83 | for i in range(sep + 1, sep_final + 1): 84 | slf_window_mask[b_idx][idx][i] = 0 85 | else: 86 | slf_window_mask[b_idx][idx][0] = 0 87 | for i in range(idx - window_size, idx + window_size + 1): 88 | if i >= 0 and i < len_s: 89 | if (idx <= sep and i <= sep) or (idx > sep and i > sep): 90 | slf_window_mask[b_idx][idx][i] = 0 91 | if idx <= sep: 92 | slf_window_mask[b_idx][idx][sep] = 0 93 | else: 94 | slf_window_mask[b_idx][idx][sep_final] = 0 95 | 96 | return slf_window_mask 97 | 98 | 99 | def get_edge_mask(edges): 100 | ''' Get mask matrix for edges 101 | edges - [batch_size, node_num * node_num] 102 | return - [batch_size, node_num, node_num] 103 | ''' 104 | len_edges = edges.size(1) 105 | node_num = int(len_edges **0.5) 106 | 107 | mask = edges.eq(Constants.PAD) 108 | mask = mask.view(-1, node_num, node_num) 109 | 110 | return mask -------------------------------------------------------------------------------- /src/onqg/utils/model_builder.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | 4 | from onqg.models.Models import UnifiedModel 5 | from onqg.models.Encoders import RNNEncoder, GraphEncoder, EncoderTransformer, SparseGraphEncoder, TransfEncoder 6 | from onqg.models.Decoders import RNNDecoder, DecoderTransformer 7 | 8 | 9 | def build_encoder(opt, answer=False, graph=False): 10 | if graph: 11 | options = {'n_edge_type':opt.edge_vocab_size, 'd_model':opt.d_graph_enc_model, 12 | 'd_rnn_enc_model':opt.d_seq_enc_model, 'n_layer':opt.n_graph_enc_layer, 13 | 'alpha':opt.alpha, 'd_feat_vec':opt.d_feat_vec, 'feat_vocab':opt.node_feat_vocab, 14 | 'layer_attn':opt.layer_attn, 'dropout':opt.dropout, 'attn_dropout':opt.attn_dropout} 15 | model = SparseGraphEncoder.from_opt(options) if opt.sparse else GraphEncoder.from_opt(options) 16 | else: 17 | if opt.pretrained and not answer: 18 | options = {'pretrained':opt.pretrained, 'n_vocab':opt.src_vocab_size, 'layer_attn':opt.layer_attn} 19 | 20 | model = TransfEncoder.from_opt(options) 21 | for para in model.parameters(): 22 | para.requires_grad = False 23 | 24 | return model 25 | 26 | feat_vocab = opt.feat_vocab 27 | if feat_vocab: 28 | n_all_feat = len(feat_vocab) 29 | feat_vocab = feat_vocab[:n_all_feat - opt.dec_feature] 30 | 31 | options = {'n_vocab':opt.src_vocab_size, 'd_word_vec':opt.d_word_vec, 'd_model':opt.d_seq_enc_model, 32 | 'n_layer':opt.n_seq_enc_layer, 'brnn':opt.brnn, 'rnn':opt.enc_rnn, 'slf_attn':opt.slf_attn, 33 | 'feat_vocab':feat_vocab, 'd_feat_vec':opt.d_feat_vec, 'dropout':opt.dropout} 34 | 35 | model = RNNEncoder.from_opt(options) 36 | 37 | return model 38 | 39 | 40 | def build_decoder(opt, device): 41 | if opt.dec_feature: 42 | n_all_feat = len(opt.feat_vocab) 43 | feat_vocab = opt.feat_vocab[n_all_feat - opt.dec_feature:] 44 | else: 45 | feat_vocab = None 46 | 47 | options = {'n_vocab':opt.tgt_vocab_size, 'ans_n_vocab':opt.src_vocab_size, 'd_word_vec':opt.d_word_vec, 'd_model':opt.d_dec_model, 48 | 'n_layer':opt.n_dec_layer, 'n_rnn_enc_layer':opt.n_seq_enc_layer, 'rnn':opt.dec_rnn, 'd_k':opt.d_k, 49 | 'feat_vocab':feat_vocab, 'd_feat_vec':opt.d_feat_vec, 'd_enc_model':opt.d_graph_enc_model, 50 | 'd_rnn_enc_model':opt.d_seq_enc_model, 'n_enc_layer':opt.n_graph_enc_layer, 'input_feed':opt.input_feed, 51 | 'copy':opt.copy, 'coverage':opt.coverage, 'layer_attn':opt.layer_attn, 'answer':opt.answer, 52 | 'maxout_pool_size':opt.maxout_pool_size, 'dropout':opt.dropout, 'device':device} 53 | model = RNNDecoder.from_opt(options) 54 | 55 | return model 56 | 57 | 58 | def initialize(model, opt): 59 | parameters_cnt = 0 60 | for name, para in model.named_parameters(): 61 | size = list(para.size()) 62 | local_cnt = 1 63 | for d in size: 64 | local_cnt *= d 65 | 66 | if not opt.pretrained or not name.count('seq_encoder'): 67 | if para.dim() == 1: 68 | para.data.normal_(0, math.sqrt(6 / (1 + para.size(0)))) 69 | else: 70 | nn.init.xavier_normal(para, math.sqrt(3)) 71 | 72 | parameters_cnt += local_cnt 73 | 74 | if opt.pre_trained_vocab: 75 | assert opt.d_word_vec == 300, "Dimension of word vectors must equal to that of pretrained word-embedding" 76 | if not opt.pretrained: 77 | model.seq_encoder.word_emb.weight.data.copy_(opt.pre_trained_src_emb) 78 | model.decoder.word_emb.weight.data.copy_(opt.pre_trained_tgt_emb) 79 | if opt.answer: 80 | model.decoder.ans_emb.weight.data.copy_(opt.pre_trained_ans_emb) 81 | 82 | if opt.proj_share_weight: 83 | weight = 
model.decoder.maxout(model.decoder.word_emb.weight.data) 84 | model.generator.weight.data.copy_(weight) 85 | 86 | return model, parameters_cnt 87 | 88 | 89 | def build_model(opt, device, separate=-1, checkpoint=None): 90 | ## build model ## 91 | seq_encoder = build_encoder(opt) 92 | encoder_transformer = EncoderTransformer(opt.d_seq_enc_model, d_k=opt.d_k, device=device) 93 | graph_encoder = build_encoder(opt, graph=True) 94 | if opt.d_seq_enc_model != opt.d_graph_enc_model: 95 | graph_encoder.activate = nn.Sequential( 96 | nn.Linear(opt.d_seq_enc_model, opt.d_graph_enc_model, bias=False), 97 | nn.Tanh() 98 | ) 99 | else: 100 | graph_encoder.activate = nn.Tanh() 101 | decoder_transformer = DecoderTransformer(opt.layer_attn, device=device) 102 | decoder = build_decoder(opt, device) 103 | 104 | model = UnifiedModel(opt.training_mode, seq_encoder, graph_encoder, encoder_transformer, 105 | decoder, decoder_transformer) 106 | 107 | model.generator = nn.Linear(opt.d_dec_model // opt.maxout_pool_size, opt.tgt_vocab_size, bias=False) 108 | model.classifier = nn.Sequential( 109 | nn.Linear(opt.d_graph_enc_model, 1, bias=False), 110 | nn.Sigmoid() 111 | ) 112 | 113 | model, parameters_cnt = initialize(model, opt) 114 | 115 | if checkpoint is not None: 116 | model.load_state_dict(checkpoint['model']) 117 | del checkpoint 118 | 119 | ## move to gpus ## 120 | model = model.to(device) 121 | if len(opt.gpus) > 1: 122 | model = nn.DataParallel(model, device_ids=opt.gpus) 123 | 124 | return model, parameters_cnt 125 | -------------------------------------------------------------------------------- /src/onqg/utils/sinusoid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 6 | ''' Sinusoid position encoding table ''' 7 | 8 | def cal_angle(position, hid_idx): 9 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 10 | 11 | def get_posi_angle_vec(position): 12 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 13 | 14 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 15 | 16 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 17 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 18 | 19 | if padding_idx is not None: 20 | # zero vector for padding dimension 21 | sinusoid_table[padding_idx] = 0. 22 | 23 | return torch.FloatTensor(sinusoid_table) -------------------------------------------------------------------------------- /src/onqg/utils/train/Loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.functional as F 4 | import torch.nn.functional as funct 5 | 6 | import onqg.dataset.Constants as Constants 7 | 8 | 9 | class Loss(object): 10 | def __init__(self, name, criterion): 11 | self.name = name 12 | self.criterion = criterion 13 | if not issubclass(type(self.criterion), nn.modules.loss._Loss): 14 | raise ValueError("Criterion has to be a subclass of torch.nn._Loss") 15 | # accumulated loss 16 | self.acc_loss = 0 17 | # normalization term 18 | self.norm_term = 0 19 | 20 | def reset(self): 21 | self.acc_loss = 0 22 | self.norm_term = 0 23 | 24 | def get_loss(self): 25 | raise NotImplementedError 26 | 27 | def cuda(self): 28 | self.criterion.cuda() 29 | 30 | def backward(self): 31 | if type(self.acc_loss) is int: 32 | raise ValueError("No loss to back propagate. 
") 33 | self.acc_loss.backward() 34 | 35 | class NLLLoss(Loss): 36 | 37 | _NAME = "NLLLoss" 38 | 39 | def __init__(self, opt, weight=None, mask=None, size_average=True, coverage_weight=0.1): 40 | 41 | self.mask = mask 42 | self.size_average = size_average 43 | if mask is not None: 44 | if weight is None: 45 | raise ValueError("Must provide weight with a mask. ") 46 | weight[mask] = 0 47 | 48 | super(NLLLoss, self).__init__(self._NAME, nn.NLLLoss(weight=weight, size_average=size_average)) 49 | 50 | try: 51 | self.opt = opt 52 | if opt.copy: 53 | self.copy_loss = nn.NLLLoss(size_average=False) 54 | self.coverage_weight = coverage_weight 55 | except: 56 | self.coverage_weight = coverage_weight 57 | 58 | self.KL = nn.KLDivLoss() 59 | 60 | def get_loss(self): 61 | if isinstance(self.acc_loss, int): 62 | return 0 63 | # total loss for all batches 64 | loss = self.acc_loss.data.item() 65 | if self.size_average: 66 | # average loss per batch 67 | loss /= self.norm_term 68 | return loss 69 | 70 | def cal_loss(self, inputs): 71 | pred = inputs['pred'] 72 | gold = inputs['gold'] 73 | if self.opt.copy: 74 | copy_pred = inputs['copy_pred'] 75 | copy_gold = inputs['copy_gold'] 76 | copy_gate = inputs['copy_gate'] 77 | copy_switch = inputs['copy_switch'] 78 | if self.opt.coverage: 79 | coverage_pred = inputs['coverage_pred'] 80 | 81 | batch_size = gold.size(0) 82 | gold = gold.contiguous() 83 | norm = nn.Softmax(dim=1) 84 | 85 | pred = pred.contiguous().view(-1, pred.size(2)) 86 | pred = norm(pred) 87 | pred_prob_t = pred.contiguous().view(batch_size, -1, pred.size(1)) + 1e-8 # seq_len x batch_size x vocab_size 88 | 89 | if self.opt.copy: 90 | copy_pred_prob = copy_pred * copy_gate.expand_as(copy_pred) + 1e-8 91 | pred_prob = pred_prob_t * (1 - copy_gate).expand_as(pred_prob_t) + 1e-8 92 | 93 | copy_pred_prob_log = torch.log(copy_pred_prob) 94 | pred_prob_log = torch.log(pred_prob) 95 | copy_pred_prob_log = copy_pred_prob_log * (copy_switch.unsqueeze(2).expand_as(copy_pred_prob_log)) 96 | pred_prob_log = pred_prob_log * ((1 - copy_switch).unsqueeze(2).expand_as(pred_prob_log)) 97 | 98 | pred_prob_log = pred_prob_log.view(-1, pred_prob_log.size(2)) 99 | copy_pred_prob_log = copy_pred_prob_log.view(-1, copy_pred_prob_log.size(2)) 100 | 101 | pred_loss = self.criterion(pred_prob_log, gold.view(-1)) 102 | copy_loss = self.copy_loss(copy_pred_prob_log, copy_gold.contiguous().view(-1)) 103 | 104 | total_loss = pred_loss + copy_loss 105 | else: 106 | pred_prob_t_log = torch.log(pred_prob_t) 107 | pred_prob_t_log = pred_prob_t_log.view(-1, pred_prob_t_log.size(2)) 108 | pred_loss = self.criterion(pred_prob_t_log, gold.view(-1)) 109 | 110 | total_loss = pred_loss 111 | 112 | raw_loss = total_loss 113 | coverage_loss = None 114 | 115 | if self.opt.coverage: 116 | coverage_pred = [cv for cv in coverage_pred] 117 | 118 | coverage_loss = torch.sum(torch.stack(coverage_pred, 1), 1) 119 | coverage_loss = torch.sum(coverage_loss, 0) 120 | total_loss = total_loss + coverage_loss * self.coverage_weight 121 | 122 | return total_loss, coverage_loss, raw_loss 123 | 124 | def cal_loss_ner(self, pred, gold): 125 | device = gold.device 126 | golds = [] 127 | for batch in gold: 128 | tmp_sent = torch.stack([w for w in batch if w.item() != Constants.PAD]) 129 | golds.append(tmp_sent) 130 | golds = torch.cat(golds, dim=0).to(device) 131 | gold = golds.contiguous() 132 | 133 | pred = pred.contiguous().view(-1, pred.size(1)) 134 | pred_prob_t_log = torch.log(pred + 1e-8) 135 | 136 | pred_loss = self.criterion(pred_prob_t_log, 
gold.view(-1)) 137 | 138 | return pred_loss, gold 139 | -------------------------------------------------------------------------------- /src/onqg/utils/train/Optim.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import functools 3 | 4 | import torch 5 | import torch.optim as optim 6 | from torch.nn.utils import clip_grad_norm_ 7 | 8 | 9 | def build_torch_optimizer(model, opt): 10 | """Builds the PyTorch optimizer. 11 | Input: 12 | model: The model to optimize. 13 | opt: The dictionary of options. 14 | Output: 15 | A ``torch.optim.Optimizer`` instance. 16 | """ 17 | params = list(filter(lambda p: p.requires_grad, model.parameters())) # [p for p in model.parameters() if p.requires_grad] 18 | betas = [0.9, 0.999] # adam_beta1 & adam_beta2 19 | if opt.optim == 'sgd': 20 | optimizer = optim.SGD(params, lr=opt.learning_rate) 21 | elif opt.optim == 'adagrad': 22 | optimizer = optim.Adagrad(params, lr=opt.learning_rate, 23 | initial_accumulator_value=opt.adagrad_accumulator_init) 24 | elif opt.optim == 'adadelta': 25 | optimizer = optim.Adadelta(params, lr=opt.learning_rate) 26 | elif opt.optim == 'adam': 27 | optimizer = optim.Adam(params, lr=opt.learning_rate, betas=betas, eps=1e-9) 28 | elif opt.optim == 'fusedadam': 29 | import apex 30 | optimizer = apex.optimizers.FusedAdam(params, lr=opt.learning_rate, betas=betas) 31 | else: 32 | raise ValueError('Invalid optimizer type: ' + opt.optim) 33 | 34 | return {'optim':optimizer, 'para':params} 35 | 36 | 37 | def make_learning_rate_decay_fn(opt): 38 | """Returns the learning decay function from options.""" 39 | if opt.decay_method == 'noam': 40 | return functools.partial(noam_decay, warmup_steps=opt.n_warmup_steps, model_size=opt.d_model) 41 | elif opt.decay_method == 'noamwd': 42 | return functools.partial(noamwd_decay, warmup_steps=opt.n_warmup_steps, model_size=opt.d_model, 43 | rate=opt.learning_rate_decay, decay_steps=opt.decay_steps, start_step=opt.start_decay_steps) 44 | elif opt.decay_method == 'rsqrt': 45 | return functools.partial(rsqrt_decay, warmup_steps=opt.n_warmup_steps) 46 | elif opt.start_decay_steps is not None: 47 | return functools.partial(exponential_decay, rate=opt.learning_rate_decay, decay_steps=opt.decay_steps, start_step=opt.start_decay_steps) 48 | 49 | 50 | def noam_decay(step, warmup_steps, model_size): 51 | """Learning rate schedule described in https://arxiv.org/pdf/1706.03762.pdf. """ 52 | return (model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps**(-1.5))) 53 | 54 | def noamwd_decay(step, warmup_steps, model_size, rate, decay_steps, start_step=0): 55 | """Learning rate schedule optimized for huge batches""" 56 | return (model_size ** (-0.5) * min(step ** (-0.5), step * warmup_steps**(-1.5)) * 57 | rate ** (max(step - start_step + decay_steps, 0) // decay_steps)) 58 | 59 | def exponential_decay(step, rate, decay_steps, start_step=0): 60 | """A standard exponential decay, scaling the learning rate by :obj:`rate` every :obj:`decay_steps` steps. """ 61 | return rate ** (max(step - start_step + decay_steps, 0) // decay_steps) 62 | 63 | def rsqrt_decay(step, warmup_steps): 64 | """Decay based on the reciprocal of the step square root.""" 65 | return 1.0 / sqrt(max(step, warmup_steps)) 66 | 67 | 68 | class Optimizer(object): 69 | """ 70 | Controller class for optimization. Mostly a thin wrapper for `optim`, 71 | but also useful for implementing rate scheduling beyond what is currently available. 
72 | Also implements necessary methods for training RNNs such as grad manipulations. 73 | """ 74 | def __init__(self, optimizer_dict, learning_rate_decay_method, learning_rate, learning_rate_decay=0.5, 75 | lr_decay_step=1, start_decay_steps=5000, learning_rate_decay_fn=None, max_grad_norm=None, 76 | max_weight_value=None, decay_bad_cnt=None): 77 | self._optimizer = optimizer_dict['optim'] 78 | self._params = optimizer_dict['para'] 79 | self._learning_rate_decay_method = learning_rate_decay_method 80 | self._learning_rate = learning_rate 81 | self._learning_rate_decay = learning_rate_decay 82 | self._learning_rate_decay_fn = learning_rate_decay_fn 83 | self._max_grad_norm = max_grad_norm or 0 84 | self._max_weight_value = max_weight_value 85 | self._training_step = 0 86 | self._decay_step = lr_decay_step 87 | self._bad_cnt = 0 88 | self._decay_bad_cnt = decay_bad_cnt 89 | self._start_decay_steps = start_decay_steps 90 | 91 | @classmethod 92 | def from_opt(cls, model, opt, checkpoint=None): 93 | """Builds the optimizer from options. 94 | Input: 95 | cls: The ``Optimizer`` class to instantiate. 96 | model: The model to optimize. 97 | opt: The dict of user options. 98 | checkpoint: An optional checkpoint to load states from. 99 | Output: 100 | An ``Optimizer`` instance. 101 | """ 102 | optim_opt = opt 103 | # optim_state_dict = None 104 | 105 | optimizer = cls(build_torch_optimizer(model, optim_opt), 106 | optim_opt.decay_method, 107 | optim_opt.learning_rate, 108 | learning_rate_decay=optim_opt.learning_rate_decay, 109 | lr_decay_step=optim_opt.decay_steps, 110 | start_decay_steps=optim_opt.start_decay_steps, 111 | learning_rate_decay_fn=make_learning_rate_decay_fn(optim_opt), 112 | max_grad_norm=optim_opt.max_grad_norm, 113 | max_weight_value=optim_opt.max_weight_value, 114 | decay_bad_cnt=optim_opt.decay_bad_cnt) 115 | 116 | return optimizer 117 | 118 | @property 119 | def training_step(self): 120 | """The current training step.""" 121 | return self._training_step 122 | 123 | def learning_rate(self, better): 124 | """Returns the current learning rate.""" 125 | 126 | if better: 127 | self._bad_cnt = 0 128 | else: 129 | self._bad_cnt += 1 130 | 131 | if self._training_step % self._decay_step == 0 and self._training_step > self._start_decay_steps: 132 | 133 | if self._bad_cnt >= self._decay_bad_cnt and self._learning_rate >= 1e-5: 134 | 135 | if self._learning_rate_decay_method: 136 | scale = self._learning_rate_decay_fn(self._decay_step) 137 | self._decay_step += 1 138 | self._learning_rate *= scale 139 | else: 140 | self._learning_rate *= self._learning_rate_decay 141 | 142 | self._bad_cnt = 0 143 | 144 | return self._learning_rate 145 | 146 | def state_dict(self): 147 | return { 148 | 'training_step': self._training_step, 149 | 'decay_step': self._decay_step, 150 | 'optimizer': self._optimizer.state_dict() 151 | } 152 | 153 | def load_state_dict(self, state_dict): 154 | self._training_step = state_dict['training_step'] 155 | # State can be partially restored. 156 | if 'decay_step' in state_dict: 157 | self._decay_step = state_dict['decay_step'] 158 | if 'optimizer' in state_dict: 159 | self._optimizer.load_state_dict(state_dict['optimizer']) 160 | 161 | def zero_grad(self): 162 | """Zero the gradients of optimized parameters.""" 163 | self._optimizer.zero_grad() 164 | 165 | def backward(self, loss): 166 | """Wrapper for backward pass. 
Some optimizer requires ownership of the backward pass.""" 167 | loss.backward() 168 | 169 | def step(self): 170 | """Update the model parameters based on current gradients. """ 171 | learning_rate = self._learning_rate 172 | for group in self._optimizer.param_groups: 173 | group['lr'] = learning_rate 174 | if self._max_grad_norm > 0: 175 | clip_grad_norm_(group['params'], self._max_grad_norm) 176 | self._optimizer.step() 177 | if self._max_weight_value: 178 | for p in self._params: 179 | p.data.clamp_(0 - self._max_weight_value, self._max_weight_value) 180 | self._training_step += 1 181 | 182 | def update_learning_rate(self, better): 183 | lr0 = self._learning_rate 184 | lr = self.learning_rate(better) 185 | 186 | if lr != lr0: 187 | print("Update the learning rate to " + str(lr)) 188 | -------------------------------------------------------------------------------- /src/onqg/utils/train/__init__.py: -------------------------------------------------------------------------------- 1 | from onqg.utils.train.Train import SupervisedTrainer -------------------------------------------------------------------------------- /src/onqg/utils/translate/Beam.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | import torch 5 | import onqg.dataset.Constants as Constants 6 | 7 | 8 | class Beam(): 9 | def __init__(self, size, vocab_size, copy=False, device=None): 10 | self.vocab_size = vocab_size 11 | self.size = size 12 | self.copy = copy 13 | self.device = device 14 | 15 | self._done = False 16 | # Scores for each translation on the beam 17 | self.scores = torch.zeros((size, ), dtype=torch.float, device=device) 18 | self.all_scores = [] 19 | self.all_length = [] 20 | # Backpointers at each time step 21 | self.prev_ks = [] 22 | # Outputs at each time step 23 | self.next_ys = [torch.full((size, ), Constants.PAD, dtype=torch.long, device=device)] 24 | self.next_ys[0][0] = Constants.BOS 25 | self.next_ys_cp = [torch.full((size, ), Constants.PAD, dtype=torch.long, device=device)] 26 | self.next_ys_cp[0][0] = Constants.BOS 27 | # Attentions (matrix) for each time step 28 | self.attn = [] 29 | # Whether copy for each time step 30 | self.is_copy = [] 31 | 32 | def get_current_state(self): 33 | "Get the outputs for the current timestep." 34 | return self.get_tentative_hypothesis() 35 | #return self.next_ys[-1] 36 | 37 | def get_current_origin(self): 38 | "Get the backpointers for the current timestep." 39 | return self.prev_ks[-1] 40 | 41 | @property 42 | def done(self): 43 | return self._done 44 | 45 | def advance(self, pred_prob, copy_pred_prob=None, attn=None): 46 | "Update beam status and check if finished or not." 
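        # --- Editor's note: hedged summary of the step below, not part of the original source. ---
        # pred_prob arrives as a (beam_size, n_vocab) score matrix (log-probabilities in
        # Translator.translate_batch); when copying is enabled, copy_pred_prob of shape
        # (beam_size, src_len) is concatenated along the vocabulary axis, so a flat
        # best_scores_id >= n_vocab denotes copying a source token (mapped to UNK in
        # next_ys, with the raw index kept in next_ys_cp). Scores are accumulated per
        # hypothesis and length-normalised before the top `size` continuations are kept;
        # the beam is marked done once every surviving hypothesis ends in EOS.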
47 | num_words = pred_prob.size(1) 48 | raw_num_words = num_words 49 | if self.copy: 50 | assert copy_pred_prob is not None 51 | num_src_words = copy_pred_prob.size(1) 52 | num_words += num_src_words 53 | pred_prob = torch.cat((pred_prob, copy_pred_prob), dim=1) 54 | 55 | # Accumulate length for those who hasn't finished yet 56 | if len(self.prev_ks) > 0: 57 | finish_index = self.next_ys[-1].eq(Constants.EOS) # get the EOS indexes 58 | if any(finish_index): 59 | pred_prob.masked_fill_(finish_index.unsqueeze(1).expand_as(pred_prob), -float('inf')) 60 | for idx in range(self.size): 61 | if self.next_ys[-1][idx] == Constants.EOS: 62 | pred_prob[idx][Constants.EOS] = 0 63 | # set up the current step length 64 | cur_length = self.all_length[-1] 65 | for idx in range(self.size): 66 | cur_length[idx] += 0 if self.next_ys[-1][idx] == Constants.EOS else 1 67 | 68 | # Sum the previous scores 69 | if len(self.prev_ks) > 0: 70 | prev_score = self.all_scores[-1] 71 | now_acc_score = pred_prob + prev_score.unsqueeze(1).expand_as(pred_prob) 72 | beam_lk = now_acc_score / cur_length.unsqueeze(1).expand_as(now_acc_score) 73 | else: 74 | self.all_length.append(torch.FloatTensor(self.size).fill_(1).to(self.device)) 75 | beam_lk = pred_prob[0] 76 | 77 | flat_beam_lk = beam_lk.view(-1) 78 | 79 | best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 1st sort 80 | # best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True) # 2nd sort 81 | # self.all_scores.append(self.scores) 82 | self.scores = best_scores 83 | 84 | # bestScoresId is flattened as a (beam x word) array, 85 | # so we need to calculate which word and beam each score came from 86 | prev_k = best_scores_id // num_words 87 | predict = best_scores_id - prev_k * num_words 88 | if self.copy: 89 | is_copy = predict.ge(torch.LongTensor(self.size).fill_(raw_num_words).to(self.device)).long() 90 | else: 91 | is_copy = 0 92 | final_predict = predict * (1 - is_copy) + is_copy * Constants.UNK 93 | 94 | if len(self.prev_ks) > 0: 95 | self.all_length.append(cur_length.index_select(0, prev_k)) # 96 | self.all_scores.append(now_acc_score.view(-1).index_select(0, best_scores_id)) 97 | else: 98 | self.all_scores.append(self.scores) 99 | 100 | self.prev_ks.append(prev_k) 101 | self.next_ys.append(final_predict) 102 | self.next_ys_cp.append(predict) 103 | self.is_copy.append(is_copy) 104 | if attn: 105 | self.attn.append(attn.index_select(0, prev_k)) 106 | 107 | # End condition is when top-of-beam is EOS. 108 | if all(self.next_ys[-1].eq(Constants.EOS)): 109 | self._done = True 110 | 111 | return self._done 112 | 113 | def sort_scores(self): 114 | "Sort the scores." 115 | return torch.sort(self.scores, 0, True) 116 | 117 | def get_the_best_score_and_idx(self): 118 | "Get the score of the best in the beam." 119 | scores, ids = self.sort_scores() 120 | return scores[1], ids[1] 121 | 122 | def get_hypothesis(self, k): 123 | """ Walk back to construct the full hypothesis. 
""" 124 | hyp, copy_hyp = [], [] 125 | if len(self.attn) > 0: 126 | attn = [] 127 | if self.copy: 128 | is_copy = [] 129 | for j in range(len(self.prev_ks) - 1, -1, -1): 130 | hyp.append(self.next_ys[j+1][k].item()) 131 | if len(self.attn) > 0: 132 | attn.append(self.attn[j][k]) 133 | if self.copy: 134 | is_copy.append(self.is_copy[j][k]) 135 | copy_hyp.append(self.next_ys_cp[j + 1][k]) 136 | k = self.prev_ks[j][k] 137 | 138 | rst = {'hyp':hyp[::-1], 'cp_hyp':copy_hyp[::-1]} 139 | if len(self.attn) > 0: 140 | rst['attn'] = torch.stack(attn[::-1]) 141 | if self.copy: 142 | rst['is_cp'] = is_copy[::-1] 143 | 144 | return rst 145 | 146 | def get_tentative_hypothesis(self): 147 | "Get the decoded sequence for the current timestep." 148 | if len(self.next_ys) == 1: 149 | dec_seq = self.next_ys[0].unsqueeze(1) 150 | else: 151 | _, keys = self.sort_scores() 152 | hyps = [self.get_hypothesis(k)['hyp'] for k in keys] 153 | hyps = [[Constants.BOS] + h for h in hyps] 154 | dec_seq = torch.LongTensor(hyps) 155 | 156 | return dec_seq 157 | 158 | 159 | -------------------------------------------------------------------------------- /src/onqg/utils/translate/Translator.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tqdm import tqdm 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch import cuda 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | 10 | from onqg.utils.translate.Beam import Beam 11 | import onqg.dataset.Constants as Constants 12 | from onqg.dataset.data_processor import preprocess_batch 13 | 14 | from nltk.translate import bleu_score 15 | 16 | 17 | def add(tgt_list): 18 | tgt = [] 19 | for b in tgt_list: 20 | tgt += b 21 | return tgt 22 | 23 | 24 | def get_tokens(indexes, data): 25 | src, tgt = data['src'], data['tgt'] 26 | srcs = [src[i] for i in indexes] 27 | golds = [[[w for w in tgt[i] if w not in [Constants.BOS_WORD, Constants.EOS_WORD]]] for i in indexes] 28 | return srcs, golds 29 | 30 | 31 | class Translator(object): 32 | def __init__(self, opt, vocab, tokens, src_vocab): 33 | self.opt = opt 34 | self.max_token_seq_len = opt.max_token_tgt_len 35 | self.tokens = tokens 36 | if opt.gpus: 37 | cuda.set_device(opt.gpus[0]) 38 | self.device = torch.device('cuda' if opt.gpus else 'cpu') 39 | self.vocab = vocab 40 | 41 | def translate_batch(self, model, inputs, max_length): 42 | 43 | def get_inst_idx_to_tensor_position_map(inst_idx_list): 44 | ''' Indicate the position of an instance in a tensor. 
''' 45 | return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)} 46 | 47 | def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq): 48 | dec_partial_seq1 = [b.get_current_state() for b in inst_dec_beams if not b.done] 49 | dec_partial_seq2 = torch.stack(dec_partial_seq1).to(self.device) 50 | dec_partial_seq = dec_partial_seq2.view(-1, len_dec_seq) 51 | return dec_partial_seq 52 | 53 | def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm): 54 | dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=self.device) 55 | dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1) 56 | return dec_partial_pos 57 | 58 | def collect_active_inst_idx_list(inst_beams, pred_prob, copy_pred_prob, inst_idx_to_position_map, n_bm): 59 | active_inst_idx_list = [] 60 | pred_prob = pred_prob.unsqueeze(0).view(len(inst_idx_to_position_map), n_bm, -1) 61 | copy_pred_prob = None if not self.opt.copy else copy_pred_prob.unsqueeze(0).view(len(inst_idx_to_position_map), n_bm, -1) 62 | for inst_idx, inst_position in inst_idx_to_position_map.items(): 63 | copy_prob = None if not self.opt.copy else copy_pred_prob[inst_position] 64 | is_inst_complete = inst_beams[inst_idx].advance(pred_prob[inst_position], copy_prob) 65 | if not is_inst_complete: 66 | active_inst_idx_list += [inst_idx] 67 | 68 | return active_inst_idx_list 69 | 70 | def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm, layer=False): 71 | ''' Collect tensor parts associated to active instances. ''' 72 | tmp_beamed_tensor = beamed_tensor[0] if layer else beamed_tensor 73 | _, *d_hs = tmp_beamed_tensor.size() 74 | 75 | n_curr_active_inst = len(curr_active_inst_idx) 76 | new_shape = (n_curr_active_inst * n_bm, *d_hs) 77 | 78 | beamed_tensor = beamed_tensor if layer else [beamed_tensor] 79 | 80 | beamed_tensor = [layer_b_tensor.view(n_prev_active_inst, -1) for layer_b_tensor in beamed_tensor] 81 | beamed_tensor = [layer_b_tensor.index_select(0, curr_active_inst_idx) for layer_b_tensor in beamed_tensor] 82 | beamed_tensor = [layer_b_tensor.view(*new_shape) for layer_b_tensor in beamed_tensor] 83 | 84 | beamed_tensor = beamed_tensor if layer else beamed_tensor[0] 85 | 86 | return beamed_tensor 87 | 88 | with torch.no_grad(): 89 | ### ========== Prepare data ========== ### 90 | if len(self.opt.gpus) > 1: 91 | model = model.module 92 | ### ========== Encode ========== ### 93 | seq_output, hidden = model.seq_encoder(inputs['seq-encoder']) 94 | inputs['encoder-transform']['seq_output'] = seq_output 95 | inputs['encoder-transform']['hidden'] = hidden 96 | node_input, hidden = model.encoder_transformer(inputs['encoder-transform'], max_length) 97 | inputs['graph-encoder']['nodes'] = node_input 98 | node_output, _ = model.graph_encoder(inputs['graph-encoder']) 99 | ##===== Decode =====## 100 | inputs['decoder-transform']['graph_output'] = node_output 101 | inputs['decoder-transform']['seq_output'] = seq_output 102 | inputs['decoder-transform']['hidden'] = hidden 103 | inputs['decoder']['enc_output'], inputs['decoder']['scores'], hidden = model.decoder_transformer(inputs['decoder-transform']) 104 | inputs['decoder']['hidden'] = hidden 105 | ### ========== Repeat for beam search ========== ### 106 | n_bm = self.opt.beam_size 107 | enc_output = inputs['decoder']['enc_output'] 108 | if self.opt.layer_attn: 109 | n_inst, len_s, d_h = enc_output[0].size() 110 | inputs['decoder']['enc_output'] = [src_layer.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) 111 | 
for src_layer in enc_output] 112 | else: 113 | n_inst, len_s, d_h = enc_output.size() 114 | inputs['decoder']['enc_output'] = enc_output.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) 115 | inputs['decoder']['src_seq'] = inputs['decoder']['src_seq'].repeat(1, n_bm).view(n_inst * n_bm, len_s) 116 | inputs['decoder']['ans_seq'] = inputs['decoder']['ans_seq'].repeat(1, n_bm).view(n_inst * n_bm, -1) 117 | inputs['decoder']['hidden'] = hidden.repeat(1, n_bm).view(n_inst * n_bm, -1) # [h.repeat(1, n_bm).view(n_inst * n_bm, -1) for h in hidden] 118 | inputs['decoder']['feat_seqs'] = [feat_seq.repeat(1, n_bm).view(n_inst * n_bm, len_s) 119 | for feat_seq in inputs['decoder']['feat_seqs']] if self.opt.feature else None 120 | ### ========== Prepare beams ========== ### 121 | inst_dec_beams = [Beam(n_bm, self.vocab.size, self.opt.copy, device=self.device) for _ in range(n_inst)] 122 | ### ========== Bookkeeping for active or not ========== ### 123 | active_inst_idx_list = list(range(n_inst)) 124 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 125 | ### ========== Decode ========== ### 126 | norm = nn.Softmax(dim=1) 127 | for len_dec_seq in range(1, self.max_token_seq_len + 1): 128 | n_active_inst = len(inst_idx_to_position_map) 129 | ### ===== decoder forward ===== ### 130 | inputs['decoder']['tgt_seq'] = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq) # (n_bm x batch_size) x len_dec_seq 131 | 132 | rst = model.decoder(inputs['decoder']) 133 | rst['pred'] = model.generator(rst['pred']) 134 | pred = rst['pred'][:, -1, :] 135 | pred = norm(pred) 136 | if self.opt.copy: 137 | copy_pred, copy_gate = rst['copy_pred'][:, -1, :], rst['copy_gate'][:, -1, :] 138 | ### ===== log softmax ===== ### 139 | pred = norm(pred) + 1e-8 140 | pred_prob = pred 141 | copy_pred_log = None 142 | if self.opt.copy: 143 | copy_gate = copy_gate.ge(0.5).type(torch.cuda.FloatTensor) 144 | pred_prob_log = torch.log(pred_prob * ((1 - copy_gate).expand_as(pred_prob)) + 1e-25) 145 | copy_pred_log = torch.log(copy_pred * (copy_gate.expand_as(copy_pred)) + 1e-25) 146 | else: 147 | pred_prob_log = torch.log(pred_prob) 148 | ### ====== active list update ====== ### 149 | active_inst_idx_list = collect_active_inst_idx_list(inst_dec_beams, pred_prob_log, copy_pred_log, 150 | inst_idx_to_position_map, n_bm) 151 | if not active_inst_idx_list: 152 | break # all instances have finished their path to [EOS] 153 | ### ====== variables update ====== ### 154 | # Sentences which are still active are collected, 155 | # so the decoder will not run on completed sentences. 
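                # --- Editor's note: hedged illustration, not part of the original source. ---
                # Example: with n_inst = 3 instances and instance 1 finished, active_inst_idx_list
                # is [0, 2]; looked up through inst_idx_to_position_map this yields
                # active_inst_idx = tensor([0, 2]), and collect_active_part() keeps only the rows
                # of enc_output / src_seq / ans_seq / hidden that belong to the two unfinished
                # instances (each row actually spanning n_bm beam copies), before the position
                # map is rebuilt for the next decoding step.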
156 | n_prev_active_inst = len(inst_idx_to_position_map) 157 | active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list] 158 | active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device) 159 | 160 | inputs['decoder']['enc_output'] = collect_active_part(inputs['decoder']['enc_output'], active_inst_idx, n_prev_active_inst, 161 | n_bm, layer=self.opt.layer_attn) 162 | inputs['decoder']['src_seq'] = collect_active_part(inputs['decoder']['src_seq'], active_inst_idx, n_prev_active_inst, n_bm) 163 | inputs['decoder']['ans_seq'] = collect_active_part(inputs['decoder']['ans_seq'], active_inst_idx, n_prev_active_inst, n_bm) 164 | inputs['decoder']['hidden'] = collect_active_part(inputs['decoder']['hidden'], active_inst_idx, n_prev_active_inst, n_bm) 165 | inputs['decoder']['feat_seqs'] = [collect_active_part(feat_seq, active_inst_idx, n_prev_active_inst, n_bm) 166 | for feat_seq in inputs['decoder']['feat_seqs']] if self.opt.feature else None 167 | 168 | inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list) 169 | 170 | ### ========== Get hypothesis ========== ### 171 | all_hyp, all_scores = [], [] 172 | all_copy_hyp, all_is_copy = [], [] 173 | 174 | for inst_idx in range(len(inst_dec_beams)): 175 | scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores() 176 | all_scores.append(scores[: self.opt.n_best]) 177 | 178 | rsts = [inst_dec_beams[inst_idx].get_hypothesis(i) for i in tail_idxs[:self.opt.n_best]] 179 | hyp = [rst['hyp'] for rst in rsts] 180 | all_hyp.append(hyp) 181 | if self.opt.copy: 182 | copy_hyp = [rst['cp_hyp'] for rst in rsts] 183 | is_copy = [rst['is_cp'] for rst in rsts] 184 | all_copy_hyp.append(copy_hyp) 185 | all_is_copy.append(is_copy) 186 | 187 | if self.opt.copy: 188 | return all_hyp, all_scores, all_is_copy, all_copy_hyp 189 | else: 190 | return all_hyp, all_scores 191 | 192 | def eval_batch(self, model, inputs, max_length, gold, copy_gold=None, copy_switch=None, batchIdx=None): 193 | 194 | def get_preds(seq, is_copy_seq=None, copy_seq=None, src_words=None, attn=None): 195 | pred = [idx for idx in seq if idx not in [Constants.PAD, Constants.EOS]] 196 | for i, _ in enumerate(pred): 197 | if self.opt.copy and is_copy_seq[i].item(): 198 | pred[i] = src_words[copy_seq[i].item() - self.vocab.size] 199 | else: 200 | pred[i] = self.vocab.idxToLabel[pred[i]] 201 | return pred 202 | 203 | def src_tokens(src_words): 204 | if self.opt.pretrained.count('bert'): 205 | tmp_word, tmp_idx = '', 0 206 | for i, w in enumerate(src_words): 207 | if not w.startswith('##'): 208 | if tmp_word: 209 | src_words[tmp_idx] = tmp_word 210 | for j in range(tmp_idx + 1, i): 211 | src_words[j] = '' 212 | tmp_word, tmp_idx = w, i 213 | else: 214 | tmp_word += w.lstrip('##') 215 | src_words[tmp_idx] = tmp_word 216 | for j in range(tmp_idx + 1, i): 217 | src_words[j] = '' 218 | raw_words = [word for word in src_words if word != ''] 219 | return src_words, raw_words 220 | 221 | golds, preds, paras = [], [], [] 222 | if self.opt.copy: 223 | all_hyp, _, all_is_copy, all_copy_hyp = self.translate_batch(model, inputs, max_length) 224 | else: 225 | all_hyp, _ = self.translate_batch(model, inputs, max_length) 226 | 227 | src_sents, golds = get_tokens(batchIdx, self.tokens) 228 | for i, seqs in tqdm(enumerate(all_hyp), mininterval=2, desc=' - (Translating) ', leave=False): 229 | seq = seqs[0] 230 | src_words, raw_words = src_tokens(src_sents[i]) 231 | if self.opt.copy: 232 | is_copy_seq, copy_seq = all_is_copy[i][0], all_copy_hyp[i][0] 233 | 
preds.append(get_preds(seq, is_copy_seq=is_copy_seq, copy_seq=copy_seq, src_words=src_words)) 234 | else: 235 | preds.append(get_preds(seq)) 236 | paras.append(raw_words) 237 | 238 | return {'gold':golds, 'pred':preds, 'para':paras} 239 | 240 | def eval_all(self, model, validData, output_sent=False): 241 | all_golds, all_preds = [], [] 242 | if output_sent: 243 | all_paras = [] 244 | 245 | for idx in tqdm(range(len(validData)), mininterval=2, desc=' - (Translating) ', leave=False): 246 | ### ========== Prepare data ========== ### 247 | batch = validData[idx] 248 | inputs, max_lengths, gold, copy = preprocess_batch(batch, self.opt.edge_vocab_size, sparse=self.opt.sparse, feature=self.opt.feature, 249 | dec_feature=self.opt.dec_feature, copy=self.opt.copy, node_feature=self.opt.node_feature, 250 | device=self.device) 251 | copy_gold, copy_switch = copy[0], copy[1] 252 | ### ========== Translate ========== ### 253 | rst = self.eval_batch(model, inputs, max_lengths, gold, 254 | copy_gold=copy_gold, copy_switch=copy_switch, 255 | batchIdx=batch['raw-index']) 256 | 257 | all_golds += rst['gold'] 258 | all_preds += rst['pred'] 259 | if output_sent: 260 | all_paras += rst['para'] 261 | 262 | 263 | bleu = bleu_score.corpus_bleu(all_golds, all_preds) 264 | 265 | if output_sent: 266 | return bleu, (all_golds, all_preds, all_paras) 267 | return bleu -------------------------------------------------------------------------------- /src/onqg/utils/translate/__init__.py: -------------------------------------------------------------------------------- 1 | from onqg.utils.translate.Translator import Translator -------------------------------------------------------------------------------- /src/pargs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def add_options(parser): 5 | ##### ========== Data Files ========== ##### 6 | parser.add_argument('-train_src', help="Path to the training source data") 7 | parser.add_argument('-train_tgt', help="Path to the training target data") 8 | parser.add_argument('-valid_src', help="Path to the validation source data") 9 | parser.add_argument('-valid_tgt', help="Path to the validation target data") 10 | parser.add_argument('-train_dataset', help="Path to the training dataset object") 11 | parser.add_argument('-valid_dataset', help="Path to the validation dataset object") 12 | 13 | parser.add_argument('-train_graph', help="Path to the training source graph data") 14 | parser.add_argument('-valid_graph', help="Path to the validation source graph data") 15 | 16 | parser.add_argument('-train_ans', default='', help="Path to the training answer") 17 | parser.add_argument('-valid_ans', default='', help="Path to the validation answer") 18 | 19 | parser.add_argument('-feature', default=False, action='store_true') 20 | parser.add_argument('-node_feature', default=False, action='store_true') 21 | parser.add_argument('-train_feats', default=[], nargs='+', type=str, help="Train files of source features") 22 | parser.add_argument('-valid_feats', default=[], nargs='+', type=str, help="Valid files of source features") 23 | 24 | parser.add_argument('-answer', default=False, action='store_true') 25 | parser.add_argument('-ans_feature', default=False, action='store_true') 26 | parser.add_argument('-train_ans_feats', default=[], nargs='+', type=str, help="Train files of answer features") 27 | parser.add_argument('-valid_ans_feats', default=[], nargs='+', type=str, help="Valid files of answer features") 28 | 29 | ##### 
========== Data Preprocess Options ========== ##### 30 | parser.add_argument('-copy', default=False, action='store_true') 31 | 32 | parser.add_argument('-src_seq_length', type=int, default=300) 33 | parser.add_argument('-tgt_seq_length', type=int, default=100) 34 | 35 | parser.add_argument('-src_vocab_size', type=int, default=50000) 36 | parser.add_argument('-tgt_vocab_size', type=int, default=50000) 37 | parser.add_argument('-src_words_min_frequency', type=int, default=1) 38 | parser.add_argument('-tgt_words_min_frequency', type=int, default=1) 39 | parser.add_argument('-vocab_trunc_mode', default='size', 40 | help="How to truncate vocabulary size") 41 | 42 | parser.add_argument('-feat_vocab_size', type=int, default=1000) 43 | parser.add_argument('-feat_words_min_frequency', type=int, default=1) 44 | 45 | parser.add_argument('-share_vocab', action='store_true', default=False, 46 | help="Share source and target vocabulary") 47 | 48 | parser.add_argument('-pretrained', type=str, default='', help="choices: bert-base-uncased, gpt2, etc.") 49 | parser.add_argument('-pre_trained_vocab', default='', 50 | help="Path to the pre-trained vocab file") 51 | parser.add_argument('-word_vec_size', type=int, default=300) 52 | parser.add_argument('-batch_size', type=int, default=32) 53 | 54 | ##### ========== Final Results Directory ========== ##### 55 | parser.add_argument('-save_sequence_data', help="Output file for the prepared data") 56 | parser.add_argument('-save_graph_data', help="Output file for the prepared data") 57 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xargs 3 | import argparse 4 | 5 | import math 6 | import time 7 | import logging 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch import cuda 13 | 14 | import onqg.dataset.Constants as Constants 15 | from onqg.dataset.Dataset import Dataset 16 | 17 | from onqg.utils.model_builder import build_model 18 | from onqg.utils.train.Loss import NLLLoss 19 | from onqg.utils.train.Optim import Optimizer 20 | from onqg.utils.train import SupervisedTrainer 21 | from onqg.utils.translate import Translator 22 | 23 | 24 | def main(opt, logger): 25 | logger.info('My PID is {0}'.format(os.getpid())) 26 | logger.info('PyTorch version: {0}'.format(str(torch.__version__))) 27 | logger.info(opt) 28 | 29 | if torch.cuda.is_available() and not opt.gpus: 30 | logger.info("WARNING: You have a CUDA device, so you should probably run with -gpus 0") 31 | if opt.seed > 0: 32 | torch.manual_seed(opt.seed) 33 | if opt.gpus: 34 | if opt.cuda_seed > 0: 35 | torch.cuda.manual_seed(opt.cuda_seed) 36 | cuda.set_device(opt.gpus[0]) 37 | logger.info('My seed is {0}'.format(torch.initial_seed())) 38 | logger.info('My cuda seed is {0}'.format(torch.cuda.initial_seed())) 39 | 40 | ###### ==================== Loading Options ==================== ###### 41 | if opt.checkpoint: 42 | checkpoint = torch.load(opt.checkpoint) 43 | 44 | ###### ==================== Loading Dataset ==================== ###### 45 | opt.sparse = True if opt.sparse else False 46 | # logger.info('Loading sequential data ......') 47 | # sequences = torch.load(opt.sequence_data) 48 | # seq_vocabularies = sequences['dict'] 49 | # logger.info('Loading structural data ......') 50 | # graphs = torch.load(opt.graph_data) 51 | # graph_vocabularies = graphs['dict'] 52 | 53 | ### ===== load pre-trained vocabulary ===== ### 
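# Note: the pre-trained embeddings read below come from sequences['dict']['pre-trained'], so the sequential data has to be loaded before the pre-trained vocabulary can be prepared.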
54 | logger.info('Loading sequential data ......') 55 | sequences = torch.load(opt.sequence_data) 56 | seq_vocabularies = sequences['dict'] 57 | logger.info('Loading pre-trained vocabulary ......') 58 | if opt.pre_trained_vocab: 59 | if not opt.pretrained: 60 | opt.pre_trained_src_emb = seq_vocabularies['pre-trained']['src'] 61 | opt.pre_trained_tgt_emb = seq_vocabularies['pre-trained']['tgt'] 62 | if opt.answer: 63 | opt.pre_trained_ans_emb = seq_vocabularies['pre-trained']['src'] 64 | 65 | ### ===== wrap datasets ===== ### 66 | logger.info('Loading Dataset objects ......') 67 | trainData = torch.load(opt.train_dataset) 68 | validData = torch.load(opt.valid_dataset) 69 | trainData.batchSize = validData.batchSize = opt.batch_size 70 | trainData.numBatches = math.ceil(len(trainData.src) / trainData.batchSize) 71 | validData.numBatches = math.ceil(len(validData.src) / validData.batchSize) 72 | 73 | logger.info('Preparing vocabularies ......') 74 | opt.src_vocab_size = seq_vocabularies['src'].size 75 | opt.tgt_vocab_size = seq_vocabularies['tgt'].size 76 | opt.feat_vocab = [fv.size for fv in seq_vocabularies['feature']] if opt.feature else None 77 | 78 | logger.info('Loading structural data ......') 79 | graphs = torch.load(opt.graph_data) 80 | graph_vocabularies = graphs['dict'] 81 | del graphs 82 | 83 | opt.edge_vocab_size = graph_vocabularies['edge']['in'].size 84 | opt.node_feat_vocab = [fv.size for fv in graph_vocabularies['feature'][1:-1]] if opt.node_feature else None 85 | 86 | logger.info(' * vocabulary size. source = %d; target = %d' % (opt.src_vocab_size, opt.tgt_vocab_size)) 87 | logger.info(' * number of training batches. %d' % len(trainData)) 88 | logger.info(' * maximum batch size. %d' % opt.batch_size) 89 | 90 | ##### =================== Prepare Model =================== ##### 91 | device = torch.device('cuda' if opt.gpus else 'cpu') 92 | trainData.device = validData.device = device 93 | checkpoint = checkpoint if opt.checkpoint else None 94 | 95 | model, parameters_cnt = build_model(opt, device, checkpoint=checkpoint) 96 | del checkpoint 97 | 98 | logger.info(' * Number of parameters to learn = %d' % parameters_cnt) 99 | 100 | ##### ==================== Prepare Optimizer ==================== ##### 101 | optimizer = Optimizer.from_opt(model, opt) 102 | 103 | ##### ==================== Prepare Loss ==================== ##### 104 | weight = torch.ones(opt.tgt_vocab_size) 105 | weight[Constants.PAD] = 0 106 | loss = NLLLoss(opt, weight, size_average=False) 107 | if opt.gpus: 108 | loss.cuda() 109 | 110 | ##### ==================== Prepare Translator ==================== ##### 111 | translator = Translator(opt, seq_vocabularies['tgt'], sequences['valid']['tokens'], seq_vocabularies['src']) 112 | 113 | ##### ==================== Training ==================== ##### 114 | trainer = SupervisedTrainer(model, loss, optimizer, translator, logger, 115 | opt, trainData, validData, seq_vocabularies['src'], 116 | graph_vocabularies['feature']) 117 | del model 118 | del trainData 119 | del validData 120 | del seq_vocabularies['src'] 121 | del graph_vocabularies['feature'] 122 | trainer.train(device) 123 | 124 | 125 | if __name__ == '__main__': 126 | ##### ==================== parse the options ==================== ##### 127 | parser = argparse.ArgumentParser(description='train.py') 128 | xargs.add_data_options(parser) 129 | xargs.add_model_options(parser) 130 | xargs.add_train_options(parser) 131 | opt = parser.parse_args() 132 | 133 | ##### ==================== prepare the logger 
==================== ##### 134 | logging.basicConfig(format='%(asctime)s [%(levelname)s:%(name)s]: %(message)s', level=logging.INFO) 135 | log_file_name = time.strftime("%Y%m%d-%H%M%S") + '.log.txt' 136 | if opt.log_home: 137 | log_file_name = os.path.join(opt.log_home, log_file_name) 138 | file_handler = logging.FileHandler(log_file_name, encoding='utf-8') 139 | file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)-5.5s:%(name)s] %(message)s')) 140 | logging.root.addHandler(file_handler) 141 | logger = logging.getLogger(__name__) 142 | 143 | main(opt, logger) 144 | -------------------------------------------------------------------------------- /src/translate.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import cuda 4 | import torch.nn as nn 5 | import argparse 6 | from tqdm import tqdm 7 | 8 | from onqg.utils.translate import Translator 9 | from onqg.dataset import Dataset 10 | from onqg.utils.model_builder import build_model 11 | 12 | 13 | def dump(data, filename): 14 | golds, preds, paras = data[0], data[1], data[2] 15 | with open(filename, 'w', encoding='utf-8') as f: 16 | for g, p, pa in zip(golds, preds, paras): 17 | pa = [w for w in pa if w not in ['[PAD]', '[CLS]']] 18 | f.write('\t' + ' '.join(pa) + '\n') 19 | f.write('\t' + ' '.join(g[0]) + '\n') 20 | f.write('\t' + ' '.join(p) + '\n') 21 | f.write('===========================\n') 22 | 23 | 24 | def main(opt): 25 | device = torch.device('cuda' if opt.cuda else 'cpu') 26 | 27 | checkpoint = torch.load(opt.model) 28 | model_opt = checkpoint['settings'] 29 | model_opt.gpus = opt.gpus 30 | model_opt.beam_size, model_opt.batch_size = opt.beam_size, opt.batch_size 31 | 32 | ### Prepare Data ### 33 | sequences = torch.load(opt.sequence_data) 34 | seq_vocabularies = sequences['dict'] 35 | 36 | validData = torch.load(opt.valid_data) 37 | validData.batchSize = opt.batch_size 38 | validData.numBatches = math.ceil(len(validData.src) / validData.batchSize) 39 | 40 | ### Prepare Model ### 41 | validData.device = validData.device = device 42 | model, _ = build_model(model_opt, device) 43 | model.load_state_dict(checkpoint['model']) 44 | model.eval() 45 | 46 | translator = Translator(model_opt, seq_vocabularies['tgt'], sequences['valid']['tokens'], seq_vocabularies['src']) 47 | 48 | bleu, outputs = translator.eval_all(model, validData, output_sent=True) 49 | 50 | print('\nbleu-4', bleu, '\n') 51 | 52 | dump(outputs, opt.output) 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='translate.py') 57 | 58 | parser.add_argument('-model', required=True, help='Path to model .pt file') 59 | parser.add_argument('-sequence_data', required=True, help='Path to data file') 60 | parser.add_argument('-graph_data', required=True, help='Path to data file') 61 | parser.add_argument('-valid_data', required=True, help='Path to data file') 62 | parser.add_argument('-output', required=True, help='Path to output the predictions') 63 | parser.add_argument('-beam_size', type=int, default=5) 64 | parser.add_argument('-batch_size', type=int, default=32) 65 | parser.add_argument('-gpus', default=[], nargs='+', type=int) 66 | 67 | opt = parser.parse_args() 68 | opt.cuda = True if opt.gpus else False 69 | if opt.cuda: 70 | cuda.set_device(opt.gpus[0]) 71 | 72 | main(opt) -------------------------------------------------------------------------------- /src/xargs.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def add_data_options(parser): 5 | ## Data options 6 | parser.add_argument('-sequence_data', required=True) 7 | parser.add_argument('-graph_data', required=True) 8 | parser.add_argument('-train_dataset', required=True) 9 | parser.add_argument('-valid_dataset', required=True) 10 | 11 | # Test options 12 | parser.add_argument('-max_token_src_len', type=int, default=300) 13 | parser.add_argument('-max_token_tgt_len', type=int, default=50) 14 | parser.add_argument('-accum_count', type=int, nargs='+', default=[1], 15 | help="Accumulate gradient this many times. " 16 | "Approximately equivalent to updating batch_size * accum_count batches at once. " 17 | "Recommended for Transformer.") 18 | 19 | 20 | def add_model_options(parser): 21 | ## checkpoint 22 | parser.add_argument('-checkpoint', type=str, default='', 23 | help="Path of trained model to do further training") 24 | 25 | ## Model options 26 | parser.add_argument('-epoch', type=int, default=10) 27 | parser.add_argument('-batch_size', type=int, default=64) 28 | 29 | parser.add_argument('-answer', default=False, action='store_true') 30 | parser.add_argument('-node_feature', default=False, action='store_true', 31 | help="whether to incorporate node feature information") 32 | parser.add_argument('-feature', default=False, action='store_true', 33 | help="whether to incorporate word feature information") 34 | parser.add_argument('-dec_feature', type=int, default=0, 35 | help="""Number of features directly sent to the decoder""") 36 | 37 | parser.add_argument('-pretrained', type=str, default='', help="choices: bert-base-uncased, gpt2, etc.") 38 | parser.add_argument('-pre_trained_vocab', action='store_true', default=False, 39 | help="whether to use a pretrained word vector for embedding") 40 | 41 | parser.add_argument('-d_word_vec', type=int, default=512, 42 | help="hidden size of word-vector in embedding layer") 43 | parser.add_argument('-d_feat_vec', type=int, default=32, 44 | help="size of feature embedding vector") 45 | 46 | parser.add_argument('-sparse', type=int, default=1) 47 | 48 | parser.add_argument('-d_seq_enc_model', type=int, default=512, 49 | help="hidden_size * num_directions of vector in RNN encoder") 50 | parser.add_argument('-d_graph_enc_model', type=int, default=512, 51 | help="hidden_size of vector in graph encoder") 52 | parser.add_argument('-d_dec_model', type=int, default=512, 53 | help="hidden size of vector in decoder") 54 | parser.add_argument('-n_seq_enc_layer', type=int, default=1, help="number of rnn encoder layers") 55 | parser.add_argument('-n_graph_enc_layer', type=int, default=3, help="number of graph encoder layers") 56 | parser.add_argument('-n_dec_layer', type=int, default=1, help="number of decoder layer") 57 | 58 | parser.add_argument('-n_head', type=int, default=8, 59 | help="number of heads in multi-head self-attention mechanism") 60 | parser.add_argument('-d_inner', type=int, default=2048, 61 | help="size of inner vector of the 1st layer in feed forward network") 62 | parser.add_argument('-d_k', type=int, default=64, 63 | help="size of attention vector") 64 | parser.add_argument('-d_v', type=int, default=64, 65 | help="size of vectors which will be used to calculate weighted context vector") 66 | 67 | parser.add_argument('-alpha', type=float, default=0.1, help="for LeakyReLu") 68 | 69 | parser.add_argument('-brnn', default=False, action='store_true', 70 | help="whether to use a bidirectional RNN 
in encoder") 71 | parser.add_argument('-input_feed', type=bool, default=1, 72 | help="whether to incorporate encoder-hidden-state directly into RNN in decoder") 73 | parser.add_argument('-enc_rnn', type=str, choices=['gru', 'lstm'], default='gru') 74 | parser.add_argument('-dec_rnn', type=str, choices=['gru', 'lstm'], default='gru') 75 | 76 | parser.add_argument('-defined_slf_attn_mask', default='', type=str, 77 | help="Path to the self-attention matrix for each sequences") 78 | 79 | parser.add_argument('-copy', action='store_true', default=False, 80 | help="""Use a copy mechanism""") 81 | parser.add_argument('-coverage', action='store_true', default=False, 82 | help="""Use a coverage mechanism""") 83 | parser.add_argument('-coverage_weight', type=float, default=1.0, 84 | help="""Weight of the loss of coverage mechanism in final total loss""") 85 | 86 | parser.add_argument('-slf_attn', action='store_true', default=False, 87 | help="source self-attention encoding") 88 | 89 | parser.add_argument('-maxout_pool_size', type=int, default=2, 90 | help='Pooling size for MaxOut layer.') 91 | 92 | parser.add_argument('-layer_attn', default=False, action='store_true', 93 | help="""Wether to add a universal cross attention 94 | on all outputs of the encoder layers to generate 95 | a aggregate context of all layers""") 96 | 97 | parser.add_argument('-proj_share_weight', action='store_true', 98 | help="whether to share weight between embedding and final projecting layers") 99 | 100 | 101 | def add_train_options(parser): 102 | # log 103 | parser.add_argument('-log_home', required=True, help="""log home""") 104 | 105 | parser.add_argument('-save_model', default=None) 106 | parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best') 107 | 108 | parser.add_argument('-valid_steps', type=int, default=500, 109 | help="Number of interval steps between two adjacent times of evaluation") 110 | 111 | parser.add_argument('-logfile_train', default='', 112 | help="Path to save loss and evaluation reports on training data") 113 | parser.add_argument('-logfile_dev', default='', 114 | help="Path to save loss and evaluation reports on validation data") 115 | 116 | # BLeU-4 117 | parser.add_argument('-translate_ppl', type=float, default=40, 118 | help="Start to calculate BLeU4 on validation data when its PPL reach this number.") 119 | parser.add_argument('-translate_steps', type=int, default=2500, 120 | help="Number of interval steps between two adjacent times of translation") 121 | 122 | # training trick 123 | parser.add_argument('-training_mode', required=True, choices=['unify', 'generate', 'classify']) 124 | parser.add_argument('-n_warmup_steps', type=int, default=4000) 125 | parser.add_argument('-dropout', type=float, default=0.1) 126 | parser.add_argument('-attn_dropout', type=float, default=0.1) 127 | 128 | parser.add_argument('-curriculum', type=int, default=1, 129 | help="""For this many epochs, order the minibatches based 130 | on source sequence length. 
Sometimes setting this to 1 will 131 | increase convergence speed.""") 132 | parser.add_argument('-extra_shuffle', action="store_true", 133 | help="""By default only shuffle mini-batch order; when true, 134 | shuffle and re-assign mini-batches""") 135 | 136 | # learning rate 137 | parser.add_argument('-optim', default='sgd', 138 | choices=['sgd', 'adagrad', 'adadelta', 'adam', 'sparseadam', 'fusedadam'], 139 | help="Optimization method.") 140 | parser.add_argument('-learning_rate', type=float, default=1.0, 141 | help="Starting learning rate. " 142 | "Recommended settings: sgd = 1, adagrad = 0.1, adadelta = 1, adam = 0.001") 143 | parser.add_argument('-decay_method', type=str, default='', choices=['noam', 'noamwd', 'rsqrt', ''], 144 | help="Use a custom decay rate.") 145 | parser.add_argument('-learning_rate_decay', type=float, default=0.5, 146 | help="If update_learning_rate, decay learning rate by " 147 | "this much if steps have gone past " 148 | "start_decay_steps") 149 | parser.add_argument('-decay_steps', type=int, default=500, 150 | help="Decay every decay_steps") 151 | parser.add_argument('-start_decay_steps', type=int, default=1000, 152 | help="Start decaying every decay_steps after start_decay_steps") 153 | parser.add_argument('-max_grad_norm', type=float, default=5, 154 | help="If the norm of the gradient vector exceeds this, " 155 | "renormalize it to have the norm equal to max_grad_norm") 156 | parser.add_argument('-max_weight_value', type=float, default=15, 157 | help="If the norm of the gradient vector exceeds this, " 158 | "renormalize it to have the norm equal to max_grad_norm") 159 | parser.add_argument('-decay_bad_cnt', type=int, default=3) 160 | 161 | # GPU 162 | parser.add_argument('-gpus', default=[], nargs='+', type=int, 163 | help="Use CUDA on the listed devices.") 164 | 165 | parser.add_argument('-log_interval', type=int, default=100, 166 | help="logger.info stats at this interval.") 167 | 168 | parser.add_argument('-seed', type=int, default=-1, 169 | help="""Random seed used for the experiments 170 | reproducibility.""") 171 | parser.add_argument('-cuda_seed', type=int, default=-1, 172 | help="""Random CUDA seed used for the experiments 173 | reproducibility.""") 174 | 175 | # translate 176 | parser.add_argument('-eval_batch_size', type=int, default=16) 177 | parser.add_argument('-beam_size', type=int, default=5, help='Beam size') 178 | parser.add_argument('-n_best', type=int, default=1, 179 | help="""If verbose is set, will output the n_best 180 | decoded sentences""") --------------------------------------------------------------------------------