├── README.md
├── dataset.py
├── rouge
│   ├── __init__.py
│   ├── rouge.py
│   └── rouge_score.py
├── run.py
├── seq2seq_model.py
├── utils.py
└── vocab.py

/README.md:
--------------------------------------------------------------------------------
# MATINF - Multitask Chinese NLP Dataset
The dataset and PyTorch implementation for the ACL 2020 paper ["MATINF: A Jointly Labeled Large-Scale Dataset for Classification, Question Answering and Summarization"](https://arxiv.org/abs/2004.12302).

## Citation
If you use the dataset or code in your research, please kindly cite our work:

```bibtex
@inproceedings{xu-etal-2020-matinf,
    title = "{MATINF}: A Jointly Labeled Large-Scale Dataset for Classification, Question Answering and Summarization",
    author = "Xu, Canwen and
      Pei, Jiaxin and
      Wu, Hongtao and
      Liu, Yiyu and
      Li, Chenliang",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.330",
    pages = "3586--3596",
}
```

## Dataset
You can get the MATINF dataset by signing [the agreement on Google Forms](https://forms.gle/nkH4LVE4iNQeDzsc9) to request access. You will receive the download link and the zip password after filling out the form.
**ALL USE MUST BE NON-COMMERCIAL!!**

## Code
Please manually change the `stage` variable in `main()` to switch between training phases.

Then run:
```bash
python run.py
```
Code credit: [Hongtao Wu](mailto:wuhongtao@whu.edu.cn?cc=xucanwen@whu.edu.cn)
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch.utils.data as data

CLASS3_NAME_TO_INDEX = {
    '0-1岁': 0,
    '1-2岁': 1,
    '2-3岁': 2
}

CLASS18_NAME_TO_INDEX = {
    '动作发育': 0,
    '幼儿园': 1,
    '产褥期保健': 2,
    '婴幼常见病': 3,
    '家庭教育': 4,
    '未准父母': 5,
    '婴幼保健': 6,
    '婴幼期喂养': 7,
    '疫苗接种': 8,
    '腹泻': 9,
    '宝宝上火': 10,
    '婴幼心理': 11,
    '皮肤护理': 12,
    '流产和不孕': 13,
    '婴幼早教': 14,
    '儿童过敏': 15,
    '孕期保健': 16,
    '婴幼营养': 17
}


class Dataset(data.Dataset):
    def __init__(self):
        self.data = []

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def add_data(self, question, description, answer, category):
        # Copy the token sequences (iterating a string yields characters).
        q = list(question)
        d = list(description)
        a = list(answer)
        if category in CLASS3_NAME_TO_INDEX:
            c = CLASS3_NAME_TO_INDEX[category]
        else:
            c = CLASS18_NAME_TO_INDEX[category]

        self.data.append({
            'question': q,  # list of tokens
            'description': d,  # list of tokens
            'answer': a,  # list of tokens
            'category': c  # int
        })
--------------------------------------------------------------------------------
/rouge/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from rouge.rouge import FilesRouge, Rouge

__version__ = "0.3.2"
__all__ = ["FilesRouge", "Rouge"]
"Rouge"] 6 | -------------------------------------------------------------------------------- /rouge/rouge.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | import six 4 | import rouge.rouge_score as rouge_score 5 | import io 6 | import os 7 | 8 | 9 | class FilesRouge: 10 | def __init__(self, hyp_path, ref_path, metrics=None, stats=None, 11 | batch_lines=None): 12 | assert(os.path.isfile(hyp_path)) 13 | assert(os.path.isfile(ref_path)) 14 | 15 | self.rouge = Rouge(metrics=metrics, stats=stats) 16 | 17 | def line_count(path): 18 | count = 0 19 | with open(path, "rb") as f: 20 | for line in f: 21 | count += 1 22 | return count 23 | 24 | hyp_lc = line_count(hyp_path) 25 | ref_lc = line_count(ref_path) 26 | assert(hyp_lc == ref_lc) 27 | 28 | assert(batch_lines is None or type(batch_lines) == int) 29 | 30 | self.hyp_path = hyp_path 31 | self.ref_path = ref_path 32 | self.batch_lines = batch_lines 33 | 34 | def get_scores(self, avg=False, ignore_empty=False): 35 | """Calculate ROUGE scores between each pair of 36 | lines (hyp_file[i], ref_file[i]). 37 | Args: 38 | * hyp_path: hypothesis file path 39 | * ref_path: references file path 40 | * avg (False): whether to get an average scores or a list 41 | """ 42 | hyp_path, ref_path = self.hyp_path, self.ref_path 43 | 44 | with io.open(hyp_path, encoding="utf-8", mode="r") as hyp_file: 45 | hyps = [line[:-1] for line in hyp_file] 46 | with io.open(ref_path, encoding="utf-8", mode="r") as ref_file: 47 | refs = [line[:-1] for line in ref_file] 48 | 49 | return self.rouge.get_scores(hyps, refs, avg=avg, 50 | ignore_empty=ignore_empty) 51 | 52 | 53 | class Rouge: 54 | DEFAULT_METRICS = ["rouge-1", "rouge-2", "rouge-l"] 55 | AVAILABLE_METRICS = { 56 | "rouge-1": lambda hyp, ref: rouge_score.rouge_n(hyp, ref, 1), 57 | "rouge-2": lambda hyp, ref: rouge_score.rouge_n(hyp, ref, 2), 58 | "rouge-l": lambda hyp, ref: 59 | rouge_score.rouge_l_summary_level(hyp, ref), 60 | } 61 | DEFAULT_STATS = ["f", "p", "r"] 62 | AVAILABLE_STATS = ["f", "p", "r"] 63 | 64 | def __init__(self, metrics=None, stats=None): 65 | if metrics is not None: 66 | self.metrics = [m.lower() for m in metrics] 67 | 68 | for m in self.metrics: 69 | if m not in Rouge.AVAILABLE_METRICS: 70 | raise ValueError("Unknown metric '%s'" % m) 71 | else: 72 | self.metrics = Rouge.DEFAULT_METRICS 73 | 74 | if stats is not None: 75 | self.stats = [s.lower() for s in stats] 76 | 77 | for s in self.stats: 78 | if s not in Rouge.AVAILABLE_STATS: 79 | raise ValueError("Unknown stat '%s'" % s) 80 | else: 81 | self.stats = Rouge.DEFAULT_STATS 82 | 83 | def get_scores(self, hyps, refs, avg=False, ignore_empty=False): 84 | if isinstance(hyps, six.string_types): 85 | hyps, refs = [hyps], [refs] 86 | 87 | if ignore_empty: 88 | # Filter out hyps of 0 length 89 | hyps_and_refs = zip(hyps, refs) 90 | hyps_and_refs = [_ for _ in hyps_and_refs if len(_[0]) > 0] 91 | hyps, refs = zip(*hyps_and_refs) 92 | 93 | assert(type(hyps) == type(refs)) 94 | assert(len(hyps) == len(refs)) 95 | 96 | if not avg: 97 | return self._get_scores(hyps, refs) 98 | return self._get_avg_scores(hyps, refs) 99 | 100 | def _get_scores(self, hyps, refs): 101 | scores = [] 102 | for hyp, ref in zip(hyps, refs): 103 | sen_score = {} 104 | hyp = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0] 105 | ref = [" ".join(_.split()) for _ in ref.split(".") if len(_) > 0] 106 | 107 | for m in self.metrics: 108 | fn = 
    def _get_scores(self, hyps, refs):
        scores = []
        for hyp, ref in zip(hyps, refs):
            sen_score = {}
            hyp = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0]
            ref = [" ".join(_.split()) for _ in ref.split(".") if len(_) > 0]

            for m in self.metrics:
                fn = Rouge.AVAILABLE_METRICS[m]
                sc = fn(hyp, ref)
                sen_score[m] = {s: sc[s] for s in self.stats}
            scores.append(sen_score)
        return scores

    def _get_avg_scores(self, hyps, refs):
        scores = {m: {s: 0 for s in self.stats} for m in self.metrics}

        count = 0
        for (hyp, ref) in zip(hyps, refs):
            hyp = [" ".join(_.split()) for _ in hyp.split(".") if len(_) > 0]
            ref = [" ".join(_.split()) for _ in ref.split(".") if len(_) > 0]

            for m in self.metrics:
                fn = Rouge.AVAILABLE_METRICS[m]
                sc = fn(hyp, ref)
                scores[m] = {s: scores[m][s] + sc[s] for s in self.stats}
            count += 1
        scores = {m: {s: scores[m][s] / count for s in scores[m]}
                  for m in scores}
        return scores
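
# Editor's note: a minimal usage sketch of the Rouge API above (not part of
# the original module). Hypotheses and references may be plain strings; with
# avg=True a single dict of averaged stats per metric is returned. F1 here is
# fractionally below 0.8 because of the 1e-8 smoothing term in rouge_score:
#
#   >>> from rouge.rouge import Rouge
#   >>> r = Rouge(metrics=['rouge-1'], stats=['f', 'p', 'r'])
#   >>> r.get_scores('the cat sat on the mat', 'the cat is on the mat', avg=True)
#   {'rouge-1': {'f': 0.799999995..., 'p': 0.8, 'r': 0.8}}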
--------------------------------------------------------------------------------
/rouge/rouge_score.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ROUGE Metric Implementation

This is a very slightly modified version of:
https://github.com/pltrdy/seq2seq/blob/master/seq2seq/metrics/rouge.py

---

ROUGE metric implementation.

This is a modified and slightly extended version of
https://github.com/miso-belica/sumy/blob/dev/sumy/evaluation/rouge.py.
"""
from __future__ import absolute_import
from __future__ import division, print_function, unicode_literals
import itertools


def _get_ngrams(n, text):
    """Calculates n-grams.

    Args:
        n: which n-grams to calculate
        text: An array of tokens

    Returns:
        A set of n-grams
    """
    ngram_set = set()
    text_length = len(text)
    max_index_ngram_start = text_length - n
    for i in range(max_index_ngram_start + 1):
        ngram_set.add(tuple(text[i:i + n]))
    return ngram_set


def _split_into_words(sentences):
    """Splits multiple sentences into words and flattens the result"""
    return list(itertools.chain(*[_.split(" ") for _ in sentences]))


def _get_word_ngrams(n, sentences):
    """Calculates word n-grams for multiple sentences.
    """
    assert len(sentences) > 0
    assert n > 0

    words = _split_into_words(sentences)
    return _get_ngrams(n, words)


def _len_lcs(x, y):
    """
    Returns the length of the Longest Common Subsequence between sequences x
    and y.
    Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence

    Args:
        x: sequence of words
        y: sequence of words

    Returns:
        integer: Length of LCS between x and y
    """
    table = _lcs(x, y)
    n, m = len(x), len(y)
    return table[n, m]


def _lcs(x, y):
    """
    Computes the length of the longest common subsequence (LCS) between two
    strings. The implementation below uses a dynamic programming algorithm
    and runs in O(nm) time where n = len(x) and m = len(y).
    Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence

    Args:
        x: collection of words
        y: collection of words

    Returns:
        A dict mapping each coordinate (i, j) to the LCS length of x[:i], y[:j]
    """
    n, m = len(x), len(y)
    table = dict()
    for i in range(n + 1):
        for j in range(m + 1):
            if i == 0 or j == 0:
                table[i, j] = 0
            elif x[i - 1] == y[j - 1]:
                table[i, j] = table[i - 1, j - 1] + 1
            else:
                table[i, j] = max(table[i - 1, j], table[i, j - 1])
    return table


def _recon_lcs(x, y):
    """
    Returns the Longest Common Subsequence between x and y.
    Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence

    Args:
        x: sequence of words
        y: sequence of words

    Returns:
        sequence: LCS of x and y
    """
    i, j = len(x), len(y)
    table = _lcs(x, y)

    def _recon(i, j):
        """private recon calculation"""
        if i == 0 or j == 0:
            return []
        elif x[i - 1] == y[j - 1]:
            return _recon(i - 1, j - 1) + [(x[i - 1], i)]
        elif table[i - 1, j] > table[i, j - 1]:
            return _recon(i - 1, j)
        else:
            return _recon(i, j - 1)

    recon_tuple = tuple(map(lambda x: x[0], _recon(i, j)))
    return recon_tuple
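
# Editor's note: a small worked example for the LCS helpers above (not part
# of the original module). For x = "the cat sat" and y = "the dog sat", the
# longest common subsequence is ("the", "sat"), so its length is 2:
#
#   >>> x = "the cat sat".split()
#   >>> y = "the dog sat".split()
#   >>> _len_lcs(x, y)
#   2
#   >>> _recon_lcs(x, y)
#   ('the', 'sat')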
def multi_rouge_n(sequences, scores_ids, n=2):
    """
    Efficient way to compute highly repetitive scoring,
    i.e. when sequences are involved multiple times.

    Args:
        sequences(list[str]): list of sequences (either hyp or ref)
        scores_ids(list[tuple(int)]): list of pairs (hyp_id, ref_id),
            i.e. scores[i] = rouge_n(scores_ids[i][0],
                                     scores_ids[i][1])

    Returns:
        scores: list of length `len(scores_ids)` containing rouge `n`
            scores as a dict with 'f', 'r', 'p'
    Raises:
        KeyError: if there's a value of i in scores_ids that is not in
            [0, len(sequences))
    """
    ngrams = [_get_word_ngrams(n, sequence) for sequence in sequences]
    counts = [len(ngram) for ngram in ngrams]

    scores = []
    for hyp_id, ref_id in scores_ids:
        evaluated_ngrams = ngrams[hyp_id]
        evaluated_count = counts[hyp_id]

        reference_ngrams = ngrams[ref_id]
        reference_count = counts[ref_id]

        overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
        overlapping_count = len(overlapping_ngrams)

        scores += [f_r_p_rouge_n(evaluated_count,
                                 reference_count, overlapping_count)]
    return scores


def rouge_n(evaluated_sentences, reference_sentences, n=2):
    """
    Computes ROUGE-N of two text collections of sentences.
    Source: http://research.microsoft.com/en-us/um/people/cyl/download/
    papers/rouge-working-note-v1.3.1.pdf

    Args:
        evaluated_sentences: The sentences that have been picked by the
            summarizer
        reference_sentences: The sentences from the reference set
        n: Size of ngram. Defaults to 2.

    Returns:
        A dict with keys 'f', 'p' and 'r' (F1, precision, recall) for ROUGE-N

    Raises:
        ValueError: raises exception if a param has len <= 0
    """
    if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
        raise ValueError("Collections must contain at least 1 sentence.")

    evaluated_ngrams = _get_word_ngrams(n, evaluated_sentences)
    reference_ngrams = _get_word_ngrams(n, reference_sentences)
    reference_count = len(reference_ngrams)
    evaluated_count = len(evaluated_ngrams)

    # Gets the overlapping ngrams between evaluated and reference
    overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
    overlapping_count = len(overlapping_ngrams)

    return f_r_p_rouge_n(evaluated_count, reference_count, overlapping_count)


def f_r_p_rouge_n(evaluated_count, reference_count, overlapping_count):
    # Handle edge case. This isn't mathematically correct, but it's good enough
    if evaluated_count == 0:
        precision = 0.0
    else:
        precision = overlapping_count / evaluated_count

    if reference_count == 0:
        recall = 0.0
    else:
        recall = overlapping_count / reference_count

    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))

    return {"f": f1_score, "p": precision, "r": recall}
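
# Editor's note: a worked ROUGE-1 example for rouge_n()/f_r_p_rouge_n() above
# (not part of the original module). The two sentences share 4 of 5 unique
# unigrams, so precision = recall = 4/5; F1 is fractionally below 0.8 because
# of the 1e-8 smoothing term in the denominator:
#
#   >>> rouge_n(["the cat sat on the mat"], ["the cat is on the mat"], n=1)
#   {'f': 0.799999995..., 'p': 0.8, 'r': 0.8}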
def _union_lcs(evaluated_sentences, reference_sentence, prev_union=None):
    """
    Returns LCS_u(r_i, C), which is the LCS score of the union longest common
    subsequence between reference sentence r_i and candidate summary C.
    For example:
    if r_i = w1 w2 w3 w4 w5, and C contains two sentences: c1 = w1 w2 w6 w7 w8
    and c2 = w1 w3 w8 w9 w5, then the longest common subsequence of r_i and c1
    is "w1 w2" and the longest common subsequence of r_i and c2 is "w1 w3 w5".
    The union longest common subsequence of r_i, c1, and c2 is "w1 w2 w3 w5"
    and LCS_u(r_i, C) = 4/5.

    Args:
        evaluated_sentences: The sentences that have been picked by the
            summarizer
        reference_sentence: One of the sentences in the reference summaries
        prev_union: the LCS union accumulated over previous reference sentences

    Returns:
        A tuple (new_lcs_count, lcs_union): the number of words newly added to
        the union LCS, and the updated union

    Raises:
        ValueError: raises exception if a param has len <= 0
    """
    if prev_union is None:
        prev_union = set()

    if len(evaluated_sentences) <= 0:
        raise ValueError("Collections must contain at least 1 sentence.")

    lcs_union = prev_union
    prev_count = len(prev_union)
    reference_words = _split_into_words([reference_sentence])

    combined_lcs_length = 0
    for eval_s in evaluated_sentences:
        evaluated_words = _split_into_words([eval_s])
        lcs = set(_recon_lcs(reference_words, evaluated_words))
        combined_lcs_length += len(lcs)
        lcs_union = lcs_union.union(lcs)

    new_lcs_count = len(lcs_union) - prev_count
    return new_lcs_count, lcs_union


def rouge_l_summary_level(evaluated_sentences, reference_sentences):
    """
    Computes ROUGE-L (summary level) of two text collections of sentences.
    http://research.microsoft.com/en-us/um/people/cyl/download/papers/
    rouge-working-note-v1.3.1.pdf

    Calculated according to:
    R_lcs = SUM(1, u)[LCS(r_i,C)]/m
    P_lcs = SUM(1, u)[LCS(r_i,C)]/n
    F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs)

    where:
    SUM(i,u) = SUM from i through u
    u = number of sentences in reference summary
    C = Candidate summary made up of v sentences
    m = number of unique words in reference summary (this implementation
        deduplicates words; see the set() calls below)
    n = number of unique words in candidate summary

    Args:
        evaluated_sentences: The sentences that have been picked by the
            summarizer
        reference_sentences: The sentences from the reference set

    Returns:
        A dict with keys 'f', 'p' and 'r' (F_lcs, P_lcs, R_lcs)

    Raises:
        ValueError: raises exception if a param has len <= 0
    """
    if len(evaluated_sentences) <= 0 or len(reference_sentences) <= 0:
        raise ValueError("Collections must contain at least 1 sentence.")

    # number of unique words in reference sentences
    m = len(set(_split_into_words(reference_sentences)))

    # number of unique words in evaluated sentences
    n = len(set(_split_into_words(evaluated_sentences)))

    # print("m,n %d %d" % (m, n))
    union_lcs_sum_across_all_references = 0
    union = set()
    for ref_s in reference_sentences:
        lcs_count, union = _union_lcs(evaluated_sentences,
                                      ref_s,
                                      prev_union=union)
        union_lcs_sum_across_all_references += lcs_count

    llcs = union_lcs_sum_across_all_references
    r_lcs = llcs / m
    p_lcs = llcs / n
    beta = p_lcs / (r_lcs + 1e-12)
    num = (1 + (beta**2)) * r_lcs * p_lcs
    denom = r_lcs + ((beta**2) * p_lcs)
    f_lcs = num / (denom + 1e-12)
    return {"f": f_lcs, "p": p_lcs, "r": r_lcs}
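
# Editor's note: a worked example of rouge_l_summary_level() above (not part
# of the original module). The LCS of the two sentences is "the cat", and m
# and n both count 3 unique words, so R_lcs = P_lcs = 2/3 and F_lcs ~= 0.6667:
#
#   >>> s = rouge_l_summary_level(["the cat ran"], ["the cat sat"])
#   >>> (round(s['f'], 4), round(s['p'], 4), round(s['r'], 4))
#   (0.6667, 0.6667, 0.6667)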
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import time
import sys
import os
import math

import torch
import torch.nn as nn
import torch.utils.data as data

from typing import Dict, List
from rouge.rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import pickle
import re
from collections import namedtuple
from itertools import cycle
from tqdm import tqdm

from vocab import Vocab
from seq2seq_model import Seq2seq


MODEs = ('summ', 'qa', 'cls3', 'cls18')
CLASS3_NAME_TO_INDEX = {
    '0-1岁': 0,
    '1-2岁': 1,
    '2-3岁': 2
}

CLASS18_NAME_TO_INDEX = {
    '动作发育': 0,
    '幼儿园': 1,
    '产褥期保健': 2,
    '婴幼常见病': 3,
    '家庭教育': 4,
    '未准父母': 5,
    '婴幼保健': 6,
    '婴幼期喂养': 7,
    '疫苗接种': 8,
    '腹泻': 9,
    '宝宝上火': 10,
    '婴幼心理': 11,
    '皮肤护理': 12,
    '流产和不孕': 13,
    '婴幼早教': 14,
    '儿童过敏': 15,
    '孕期保健': 16,
    '婴幼营养': 17
}

seed = 2019

BATCH_SIZE_SUMM_QA = 80
BATCH_SIZE_CLS3 = 16
BATCH_SIZE_CLS18 = 64

SUMM_WEIGHT = 1.0
QA_WEIGHT = 1.0
CLS_WEIGHT = 1.0

log_every = 100
valid_niter = 500
max_patience = 10
max_epoch = 10

VALID_BATCH = 64
VALID_NUM = -1  # use -1 to validate on the whole dev set.
TEST_DATA_FILE = os.sep.join(['DATA', 'test.csv'])
model_save_path = 'checkpoints/'
model_load_path = ''
USE_CUDA = True
GPU_PARALLEL = False
is_training = True
OUTPUT_FILE = os.sep.join(['output', 'test_output.txt'])
base_learning_rate = 0.001

DATASET_TRAIN_CLS3 = './data/train_cls3.pkl'
DATASET_TRAIN_CLS18 = './data/train_cls18.pkl'
DATASET_DEV_CLS3 = './data/dev_cls3.pkl'
DATASET_DEV_CLS18 = './data/dev_cls18.pkl'
DATASET_TEST_CLS3 = './data/test_cls3.pkl'
DATASET_TEST_CLS18 = './data/test_cls18.pkl'

vocab_file = './data/vocab.json'
embeddings_file = './data/embeddings.pkl'

Hypothesis = namedtuple('Hypothesis', ['value', 'score'])

device = torch.device("cuda:0" if USE_CUDA else "cpu")
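
# Editor's note (sketch, not original code): train() below interleaves the
# four tasks batch-by-batch. Each DataLoader is wrapped in itertools.cycle()
# so that datasets of different sizes can be drawn from indefinitely without
# exhausting an iterator mid-epoch; one "epoch" is capped at the smallest
# task's number of iterations. A standalone toy illustration of the pattern:
#
#   >>> from itertools import cycle
#   >>> import torch.utils.data as data
#   >>> loader = cycle(data.DataLoader(list(range(4)), batch_size=2))
#   >>> [next(loader) for _ in range(3)]  # wraps around after two batches
#   [tensor([0, 1]), tensor([2, 3]), tensor([0, 1])]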
def train():
    print('Loading dataset...')
    with open(DATASET_TRAIN_CLS3, 'rb') as f:
        dataset_tr_cls3 = pickle.load(f)
    with open(DATASET_TRAIN_CLS18, 'rb') as f:
        dataset_tr_cls18 = pickle.load(f)
    dataset_tr_summ_qa = data.ConcatDataset([dataset_tr_cls3, dataset_tr_cls18])

    dataloader_tr_summ_qa = cycle(data.DataLoader(dataset=dataset_tr_summ_qa,
                                                  batch_size=BATCH_SIZE_SUMM_QA,
                                                  shuffle=True,
                                                  collate_fn=lambda x: x))
    dataloader_tr_cls3 = cycle(data.DataLoader(dataset=dataset_tr_cls3,
                                               batch_size=BATCH_SIZE_CLS3,
                                               shuffle=True,
                                               collate_fn=lambda x: x))
    dataloader_tr_cls18 = cycle(data.DataLoader(dataset=dataset_tr_cls18,
                                                batch_size=BATCH_SIZE_CLS18,
                                                shuffle=True,
                                                collate_fn=lambda x: x))

    print('Loading vocab...')
    vocab = Vocab.load(vocab_file)

    print('Loading embeddings...')
    with open(embeddings_file, 'rb') as f:
        embeddings = pickle.load(f)
    print('-----OK-----')

    if not os.path.exists(model_save_path):
        print('create dir: {}'.format(model_save_path))
        os.mkdir(model_save_path)

    if model_load_path:
        print('Loading model...')
        model = Seq2seq.load(model_load_path)
    else:
        model = Seq2seq(hidden_size=200, vocab=vocab, embddings=embeddings)

    if USE_CUDA:
        print('use device: %s' % device, file=sys.stderr)
        model = model.to(device)
        if GPU_PARALLEL:  # this path may be buggy; please keep it set to False.
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = torch.nn.DataParallel(model, device_ids=[0, 1])

    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=base_learning_rate)
    CELoss = nn.CrossEntropyLoss()

    epoch = train_iter = 0
    report_iter_num = 0
    report_loss_summ = report_loss_qa = report_loss_cls3 = 0
    report_loss_cls18 = 0
    cum_examples_summ_qa = cum_examples_cls3 = cum_examples_cls18 = 0
    best_results = [0, 0, 0, 0]  # best results for [summ, qa, cls3, cls18]

    # patience = 0
    iter_num_summ_qa = math.ceil(len(dataset_tr_summ_qa) / BATCH_SIZE_SUMM_QA)
    iter_num_cls3 = math.ceil(len(dataset_tr_cls3) / BATCH_SIZE_CLS3)
    iter_num_cls18 = math.ceil(len(dataset_tr_cls18) / BATCH_SIZE_CLS18)
    iter_num_of_one_epoch = min(iter_num_summ_qa, iter_num_cls3, iter_num_cls18)

    begin_time = time.time()

    while True:
        epoch += 1

        for i in range(iter_num_of_one_epoch):
            train_iter += 1
            report_iter_num += 1

            # --------------------------------------------------------------------
            # summ: D -> Q
            optimizer.zero_grad()

            mini_batch = next(dataloader_tr_summ_qa)
            question = [data['question'] for data in mini_batch]
            description = [data['description'] for data in mini_batch]

            for i in range(len(question)):
                question[i].insert(0, '<s>')
                question[i].insert(len(question[i]), '</s>')

            try:
                example_losses_summ = -model(description, question, mode='summ')
                batch_loss_summ = example_losses_summ.sum()  # total batch loss.
                loss_summ = batch_loss_summ / len(mini_batch) * SUMM_WEIGHT  # final (avg.) batch loss
                loss_summ.backward()

                report_loss_summ += batch_loss_summ.item()
                cum_examples_summ_qa += len(mini_batch)
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), 20)

            optimizer.step()

            # --------------------------------------------------------------------
            # QA: Q -> A
            optimizer.zero_grad()

            question = [data['question'] for data in mini_batch]
            answer = [data['answer'] for data in mini_batch]

            for i in range(len(answer)):
                answer[i].insert(0, '<s>')
                answer[i].insert(len(answer[i]), '</s>')

            try:
                example_losses_qa = -model(question, answer, mode='qa')
                batch_loss_qa = example_losses_qa.sum()  # total batch loss.
                loss_qa = batch_loss_qa / len(mini_batch) * QA_WEIGHT  # final (avg.) batch loss
                loss_qa.backward()

                report_loss_qa += batch_loss_qa.item()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), 20)

            optimizer.step()
            # --------------------------------------------------------------------
            # cls3: D, Q -> C
            optimizer.zero_grad()

            mini_batch = next(dataloader_tr_cls3)
            question = [data['question'] for data in mini_batch]
            description = [data['description'] for data in mini_batch]
            category = torch.tensor([data['category'] for data in mini_batch]).to(device)

            y_pred = model(source=description, source2=question, target=None, mode='cls3')
            loss_cls3 = CELoss(y_pred, category) * CLS_WEIGHT

            try:
                loss_cls3.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), 20)

            optimizer.step()

            report_loss_cls3 += loss_cls3.item()
            cum_examples_cls3 += len(mini_batch)

            # --------------------------------------------------------------------
            # cls18: D, Q -> C
            optimizer.zero_grad()

            mini_batch = next(dataloader_tr_cls18)
            question = [data['question'] for data in mini_batch]
            description = [data['description'] for data in mini_batch]
            category = torch.tensor([data['category'] for data in mini_batch]).to(device)

            y_pred = model(source=description, source2=question, target=None, mode='cls18')
            loss_cls18 = CELoss(y_pred, category) * CLS_WEIGHT

            try:
                loss_cls18.backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), 20)

            optimizer.step()

            report_loss_cls18 += loss_cls18.item()
            cum_examples_cls18 += len(mini_batch)
            if train_iter % log_every == 0:
                print('-' * 50)
                print('epoch:', epoch)
                print('iters:', train_iter)
                kwargs_summ = {
                    'report_loss': report_loss_summ,
                    'report_iter_num': report_iter_num,
                    'cum_examples': cum_examples_summ_qa,
                    'num_of_train_set': len(dataset_tr_summ_qa),
                    'begin_time': begin_time
                }
                report(mode='summ', **kwargs_summ)
                kwargs_qa = {
                    'report_loss': report_loss_qa,
                    'report_iter_num': report_iter_num,
                    'cum_examples': cum_examples_summ_qa,
                    'num_of_train_set': len(dataset_tr_summ_qa),
                    'begin_time': begin_time
                }
                report(mode='qa', **kwargs_qa)
                kwargs_cls3 = {
                    'report_loss': report_loss_cls3,
                    'report_iter_num': report_iter_num,
                    'cum_examples': cum_examples_cls3,
                    'num_of_train_set': len(dataset_tr_cls3),
                    'begin_time': begin_time
                }
                report(mode='cls3', **kwargs_cls3)
                kwargs_cls18 = {
                    'report_loss': report_loss_cls18,
                    'report_iter_num': report_iter_num,
                    'cum_examples': cum_examples_cls18,
                    'num_of_train_set': len(dataset_tr_cls18),
                    'begin_time': begin_time
                }
                report(mode='cls18', **kwargs_cls18)
                print('-' * 50)
                report_loss_summ = report_loss_qa = report_loss_cls3 = report_loss_cls18 = 0
                report_iter_num = 0

            if train_iter % valid_niter == 0:
                print('begin validation ...', file=sys.stderr)

                which_better = []
                save_model = False
                results = valid(model)
                for i in range(len(results)):
                    if i < 2 and isinstance(results[i], dict) and results[i]['rouge-l']['f'] > best_results[i]:
                        save_model = True
                        best_results[i] = results[i]['rouge-l']['f']
                        which_better.append(i)
                    elif i >= 2 and results[i] > best_results[i]:
                        save_model = True
                        best_results[i] = results[i]
                        which_better.append(i)

                if save_model:
                    print('Task {} got better scores!'.format(which_better))
                    Seq2seq.save(model, model_save_path + 'model_iter_{}.pt'.format(train_iter))

                cum_examples_summ_qa = cum_examples_cls3 = cum_examples_cls18 = 0
        # END one epoch.

        if epoch == max_epoch:
            print('reached maximum number of epochs!', file=sys.stderr)
            # exit(0)
            break


# dev
def valid(model, mode='all'):
    model.eval()
    with open(DATASET_DEV_CLS3, 'rb') as f:
        dataset_cls3 = pickle.load(f)
    with open(DATASET_DEV_CLS18, 'rb') as f:
        dataset_cls18 = pickle.load(f)
    dataset_summ_qa = data.ConcatDataset([dataset_cls3, dataset_cls18])

    cls3_loader = torch.utils.data.DataLoader(dataset=dataset_cls3,
                                              batch_size=VALID_BATCH,
                                              shuffle=False,
                                              collate_fn=lambda x: x)
    cls18_loader = torch.utils.data.DataLoader(dataset=dataset_cls18,
                                               batch_size=VALID_BATCH,
                                               shuffle=False,
                                               collate_fn=lambda x: x)

    rouge_summ = rouge_qa = None
    acc_cls3 = acc_cls18 = 0
    # --------------------------------------------------------------------
    if mode in ['all', 'summ', 'qa']:
        data_val_sum_qa = []
        if VALID_NUM > 0:
            for i in range(VALID_NUM):
                data_val_sum_qa.append(dataset_summ_qa[i])
        else:
            for i in range(len(dataset_summ_qa)):
                data_val_sum_qa.append(dataset_summ_qa[i])

        if mode in ['all', 'summ']:
            refs = [' '.join(data['question']) for data in data_val_sum_qa]
            x = [data['description'] for data in data_val_sum_qa]
            hyps = beam_search('summ', model, x)
            hyps = [' '.join(list(sent)) for sent in hyps]
            rouge = Rouge()
            try:
                rouge_summ = rouge.get_scores(hyps, refs, avg=True, ignore_empty=True)
                print_rouge(rouge_summ)
            except RuntimeError:
                print('Failed to compute Rouge!')

        if mode in ['all', 'qa']:
            refs = [' '.join(data['answer']) for data in data_val_sum_qa]
            x = [data['question'] for data in data_val_sum_qa]
            hyps = beam_search('qa', model, x)
            hyps = [' '.join(list(sent)) for sent in hyps]
            rouge = Rouge()
            try:
                rouge_qa = rouge.get_scores(hyps, refs, avg=True, ignore_empty=True)
                print_rouge(rouge_qa)
            except RuntimeError:
                print('Failed to compute Rouge!')
    # cls3 & cls18
    def iter_through_cls_dev(loader, mode):
        # NOTE: iterating the loader directly (rather than a fixed
        # ceil(VALID_NUM / VALID_BATCH) range) avoids a division by zero when
        # VALID_NUM == -1, i.e. when validating on the whole dev set.
        val_correct = 0
        val_num = 0
        for mini_batch in loader:
            question = [data['question'] for data in mini_batch]
            description = [data['description'] for data in mini_batch]
            y_gt = torch.tensor([data['category'] for data in mini_batch]).to(device)
            y_pred = model(source=description, source2=question, target=None, mode=mode)
            y_pred_labels = torch.argmax(y_pred, dim=1)
            val_correct += (y_gt == y_pred_labels).sum().item()
            val_num += len(mini_batch)
            if 0 < VALID_NUM <= val_num:
                break

        return val_correct / val_num

    if mode in ['all', 'cls3']:
        acc_cls3 = iter_through_cls_dev(cls3_loader, 'cls3')
        print('Acc_cls3:', acc_cls3)
    if mode in ['all', 'cls18']:
        acc_cls18 = iter_through_cls_dev(cls18_loader, 'cls18')
        print('Acc_cls18:', acc_cls18)

    if is_training:
        model.train()

    return rouge_summ, rouge_qa, acc_cls3, acc_cls18


def report(mode: str, **kwargs):
    if mode not in MODEs:
        print('Failed to report! Invalid mode {}.'.format(mode))
        return

    print('mode %s: avg. loss %.2f, progress %.2f%%, '
          'time elapsed %.2f sec' % (mode,
                                     kwargs['report_loss'] / kwargs['report_iter_num'],
                                     float(kwargs['cum_examples']) / kwargs['num_of_train_set'] * 100,
                                     time.time() - kwargs['begin_time']))


def print_rouge(rouge: Dict):
    # print('p: ', [str(rouge['rouge-1']['p']), str(rouge['rouge-2']['p']), str(rouge['rouge-l']['p'])])
    # print('r: ', [str(rouge['rouge-1']['r']), str(rouge['rouge-2']['r']), str(rouge['rouge-l']['r'])])
    print('f: ', [str(rouge['rouge-1']['f']), str(rouge['rouge-2']['f']), str(rouge['rouge-l']['f'])])
# Test
def evaluate_summ_qa(model, dataset, mode, batch_size=64):
    assert mode in ('summ', 'qa'), 'Invalid mode!'

    model.eval()

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              collate_fn=lambda x: x)

    rouge1_f_sum = rouge2_f_sum = rougeL_f_sum = bleu_sum = 0
    examples_rouge = examples_bleu = 0

    rouge = Rouge()
    count = 0
    if mode == 'summ':
        for mini_batch in tqdm(data_loader):
            count += 1
            refs = [' '.join(data['question']) for data in mini_batch]
            x = [data['description'] for data in mini_batch]
            hyps_raw = beam_search('summ', model, x)
            hyps = [' '.join(list(sent)) for sent in hyps_raw]
            try:
                rouge_score = rouge.get_scores(hyps, refs, avg=True, ignore_empty=True)
                rouge1_f_sum += rouge_score['rouge-1']['f'] * len(mini_batch)
                rouge2_f_sum += rouge_score['rouge-2']['f'] * len(mini_batch)
                rougeL_f_sum += rouge_score['rouge-l']['f'] * len(mini_batch)
                examples_rouge += len(mini_batch)
            except ValueError as e:
                print(str(e) + ' | continuing...')
                continue

    elif mode == 'qa':
        for mini_batch in tqdm(data_loader):
            count += 1
            refs = [' '.join(data['answer']) for data in mini_batch]
            x = [data['question'] for data in mini_batch]
            hyps_raw = beam_search('qa', model, x)
            hyps = [' '.join(list(sent)) for sent in hyps_raw]
            try:
                rouge_score = rouge.get_scores(hyps, refs, avg=True, ignore_empty=True)
                rouge1_f_sum += rouge_score['rouge-1']['f'] * len(mini_batch)
                rouge2_f_sum += rouge_score['rouge-2']['f'] * len(mini_batch)
                rougeL_f_sum += rouge_score['rouge-l']['f'] * len(mini_batch)
                examples_rouge += len(mini_batch)
            except ValueError as e:
                print(str(e) + ' | continuing...')
                continue

            # calculate BLEU score
            refs = [data['answer'] for data in mini_batch]
            hyps = [list(sent) for sent in hyps_raw]
            smoothie = SmoothingFunction().method4
            for i in range(len(hyps)):
                try:
                    bleu = sentence_bleu([refs[i]], hyps[i], smoothing_function=smoothie)
                    bleu_sum += bleu
                    examples_bleu += 1
                except ZeroDivisionError as e:
                    print(str(e) + ' | continuing...')
                    continue

    rouge_1_f = rouge1_f_sum / examples_rouge
    rouge_2_f = rouge2_f_sum / examples_rouge
    rouge_L_f = rougeL_f_sum / examples_rouge
    if mode == 'qa':
        bleu_score = bleu_sum / examples_bleu

    # with open('output/test_{}.txt'.format(mode), 'w', encoding='utf-8') as f:
    #     f.write('rouge-1 f: ' + str(rouge_1_f) + '\n')
    #     f.write('rouge-2 f: ' + str(rouge_2_f) + '\n')
    #     f.write('rouge-L f: ' + str(rouge_L_f) + '\n')
    #     f.write('\n')
    #
    #     for i in range(len(candidates)):
    #         f.write('input: ' + inputs[i] + '\n')
    #         f.write('hyp: ' + ''.join(candidates[i]) + '\n')
    #         f.write('ref: ' + targets[i] + '\n\n')

    if is_training:
        model.train()
    print('rouge-1 f: ' + str(rouge_1_f))
    print('rouge-2 f: ' + str(rouge_2_f))
    print('rouge-L f: ' + str(rouge_L_f))
    if mode == 'qa':
        print('bleu: ', bleu_score)
def evaluate_cls(model, dataset, mode, batch_size=16):
    assert mode in ('cls3', 'cls18'), 'Invalid mode!'
    model.eval()

    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              collate_fn=lambda x: x)

    val_correct = 0
    val_num = 0

    for mini_batch in tqdm(data_loader):
        question = [data['question'] for data in mini_batch]
        description = [data['description'] for data in mini_batch]
        y_gt = torch.tensor([data['category'] for data in mini_batch]).to(device)  # (batch,)
        y_pred = model(source=description, source2=question, target=None, mode=mode)  # (batch, 3|18)
        y_pred_labels = torch.argmax(y_pred, dim=1)
        val_correct += (y_gt == y_pred_labels).sum()
        val_num += len(mini_batch)

    accuracy = val_correct.item() / val_num
    # with open('output/test_{}.txt'.format(mode), 'w', encoding='utf-8') as f:
    #     f.write('accuracy: ' + str(accuracy))

    if is_training:
        model.train()

    print('mode:' + mode + ' | acc: ' + str(accuracy))


def test(mode, model_path):
    """Performs decoding on the test set and reports the scores.
    """
    assert mode in MODEs, 'Invalid mode!'
    print('mode:', mode)
    print("load test data...")
    if mode == 'cls3':
        with open(DATASET_TEST_CLS3, 'rb') as f:
            dataset_test = pickle.load(f)
    elif mode == 'cls18':
        with open(DATASET_TEST_CLS18, 'rb') as f:
            dataset_test = pickle.load(f)
    else:
        with open(DATASET_TEST_CLS3, 'rb') as f:
            dataset_cls3 = pickle.load(f)
        with open(DATASET_TEST_CLS18, 'rb') as f:
            dataset_cls18 = pickle.load(f)
        dataset_test = data.ConcatDataset([dataset_cls3, dataset_cls18])

    print("load model from {}".format(model_path))
    model = Seq2seq.load(model_path)

    if USE_CUDA:
        print('use device: %s' % device, file=sys.stderr)
        model = model.to(device)
        if GPU_PARALLEL:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = torch.nn.DataParallel(model, device_ids=[0, 1])

    if mode in ('summ', 'qa'):
        evaluate_summ_qa(model, dataset_test, mode, batch_size=128)
    else:
        evaluate_cls(model, dataset_test, mode, batch_size=512)
def beam_search(mode: str, model: Seq2seq, test_data_src: List[List[str]], beam_size: int = 5,
                max_decoding_time_step: int = 100):
    """ Run beam search to construct hypotheses for a list of src-language sentences.
    @param mode (str): one of MODEs
    @param model (Seq2seq): the trained Seq2seq model
    @param test_data_src (List[List[str]]): List of sentences (tokens) in the source language, from the test set.
    @param beam_size (int): beam size (# of hypotheses to hold for a translation at every step)
    @param max_decoding_time_step (int): maximum sentence length that beam search can produce
    @returns hypotheses (List[str]): the decoded sentence for every source sentence.
    """
    model.eval()

    hypotheses = []
    with torch.no_grad():
        for src_sent in test_data_src:
            example_hyps = model.beam_search(mode, src_sent, beam_size=beam_size,
                                             max_decoding_time_step=max_decoding_time_step)
            hypotheses.append(example_hyps)

    if is_training:
        model.train()

    # with open('output/check_{}.txt'.format(mode), 'w', encoding='UTF-8') as f:
    #     for i in range(50):
    #         f.write('Source: ' + ''.join(test_data_src[i]))
    #         f.write('\n')
    #         f.write('Output: ' + hypotheses[i])
    #         f.write('\n-------------------------\n')
    hypotheses = [re.sub(r'<s>|</s>', '', sent) for sent in hypotheses]
    return hypotheses
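
# Editor's note: a minimal usage sketch for beam_search() above (hypothetical
# checkpoint path; inputs are character-tokenized sentences, as elsewhere in
# this repo):
#
#   >>> # model = Seq2seq.load('./checkpoints/model_iter_500.pt')
#   >>> # hyps = beam_search('summ', model, [list('宝宝晚上不睡觉怎么办')], beam_size=5)
#   >>> # hyps[0]  # a decoded string with the <s>/</s> markers stripped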
def single_or_finetune(**kwargs):
    parameters = kwargs
    assert parameters['mode'] in MODEs
    print('Loading dataset...')
    if parameters['mode'] == 'cls3':
        with open(DATASET_TRAIN_CLS3, 'rb') as f:
            dataset = pickle.load(f)
        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=parameters['batch_size'],
                                                 shuffle=True,
                                                 collate_fn=lambda x: x)
    elif parameters['mode'] == 'cls18':
        with open(DATASET_TRAIN_CLS18, 'rb') as f:
            dataset = pickle.load(f)
        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=parameters['batch_size'],
                                                 shuffle=False,
                                                 collate_fn=lambda x: x)
    else:
        with open(DATASET_TRAIN_CLS3, 'rb') as f:
            dataset_cls3 = pickle.load(f)
        with open(DATASET_TRAIN_CLS18, 'rb') as f:
            dataset_cls18 = pickle.load(f)
        dataset = data.ConcatDataset([dataset_cls3, dataset_cls18])
        dataloader = torch.utils.data.DataLoader(dataset=dataset,
                                                 batch_size=parameters['batch_size'],
                                                 shuffle=True,
                                                 collate_fn=lambda x: x)

    print('Loading vocab...')
    vocab = Vocab.load(vocab_file)

    print('Loading embeddings...')
    with open(embeddings_file, 'rb') as f:
        embeddings = pickle.load(f)
    print('-----OK-----')

    if not os.path.exists(parameters['model_save_path']):
        print('create dir: {}'.format(parameters['model_save_path']))
        os.mkdir(parameters['model_save_path'])

    if parameters['task'] == 'finetune' and parameters['model_load_path']:
        print('Loading model from {}...'.format(parameters['model_load_path']))
        model = Seq2seq.load(parameters['model_load_path'])
    elif parameters['task'] == 'single':
        model = Seq2seq(hidden_size=200, vocab=vocab, embddings=embeddings,
                        enc_num_layers=1, dec_num_layers=1)
    else:
        raise RuntimeError('Parameters error!')

    if USE_CUDA:
        print('use device: %s' % device, file=sys.stderr)
        model = model.to(device)
        if GPU_PARALLEL:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = model.to('cuda:0')
            model = torch.nn.DataParallel(model, device_ids=[0, 1])

    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=parameters['lr'])
    CELoss = nn.CrossEntropyLoss()

    epoch = train_iter = 0
    report_iter_num = 0
    report_loss = cum_examples = 0
    best_results = 0

    print('Performing {} task, mode: {}.'.format(parameters['task'], parameters['mode']))
    begin_time = time.time()
    while True:
        epoch += 1
        for mini_batch in dataloader:
            train_iter += 1
            report_iter_num += 1

            # --------------------------------------------------------------------
            # summ
            if parameters['mode'] == 'summ':

                optimizer.zero_grad()

                question = [data['question'] for data in mini_batch]
                description = [data['description'] for data in mini_batch]

                for i in range(len(question)):
                    question[i].insert(0, '<s>')
                    question[i].insert(len(question[i]), '</s>')

                try:
                    example_losses_summ = -model(description, question, mode='summ')
                    batch_loss_summ = example_losses_summ.sum()
                    loss_summ = batch_loss_summ / len(mini_batch) * SUMM_WEIGHT
                    loss_summ.backward()

                    report_loss += batch_loss_summ.item()
                    cum_examples += len(mini_batch)
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print('| WARNING: ran out of memory')
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise e

                # clip gradient
                torch.nn.utils.clip_grad_norm_(model.parameters(), 25)

                optimizer.step()

            # --------------------------------------------------------------------
            # QA
            if parameters['mode'] == 'qa':
                optimizer.zero_grad()

                question = [data['question'] for data in mini_batch]
                answer = [data['answer'] for data in mini_batch]

                for i in range(len(answer)):
                    answer[i].insert(0, '<s>')
                    answer[i].insert(len(answer[i]), '</s>')

                try:
                    example_losses_qa = -model(question, answer, mode='qa')
                    batch_loss_qa = example_losses_qa.sum()
                    loss_qa = batch_loss_qa / len(mini_batch) * QA_WEIGHT
                    loss_qa.backward()

                    report_loss += batch_loss_qa.item()
                    cum_examples += len(mini_batch)
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print('| WARNING: ran out of memory')
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise e

                # clip gradient
                torch.nn.utils.clip_grad_norm_(model.parameters(), 25)

                optimizer.step()

            # --------------------------------------------------------------------
            # cls3
            if parameters['mode'] == 'cls3':
                optimizer.zero_grad()

                question = [data['question'] for data in mini_batch]
                description = [data['description'] for data in mini_batch]
                category = torch.tensor([data['category'] for data in mini_batch]).to(device)

                y_pred = model(source=description, source2=question, target=None, mode='cls3')
                loss_cls3 = CELoss(y_pred, category) * CLS_WEIGHT

                try:
                    loss_cls3.backward()
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print('| WARNING: ran out of memory')
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise e

                # clip gradient
                torch.nn.utils.clip_grad_norm_(model.parameters(), 20)

                optimizer.step()

                report_loss += loss_cls3.item()
                cum_examples += len(mini_batch)

            # --------------------------------------------------------------------
            # cls18
            if parameters['mode'] == 'cls18':
                optimizer.zero_grad()

                question = [data['question'] for data in mini_batch]
                description = [data['description'] for data in mini_batch]
                category = torch.tensor([data['category'] for data in mini_batch]).to(device)

                y_pred = model(source=description, source2=question, target=None, mode='cls18')
                loss_cls18 = CELoss(y_pred, category) * CLS_WEIGHT

                try:
                    loss_cls18.backward()
                except RuntimeError as e:
                    if 'out of memory' in str(e):
                        print('| WARNING: ran out of memory')
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise e

                # clip gradient
                torch.nn.utils.clip_grad_norm_(model.parameters(), 20)

                optimizer.step()

                report_loss += loss_cls18.item()
                cum_examples += len(mini_batch)
            if train_iter % log_every == 0:
                print('-' * 50)
                print('epoch:', epoch)
                print('iters:', train_iter)
                kwargs_report = {
                    'report_loss': report_loss,
                    'report_iter_num': report_iter_num,
                    'cum_examples': cum_examples,
                    'num_of_train_set': len(dataset),
                    'begin_time': begin_time
                }
                report(mode=parameters['mode'], **kwargs_report)

                print('-' * 50)
                report_loss = report_iter_num = 0

            if train_iter % valid_niter == 0:
                print('begin validation ...', file=sys.stderr)
                save_model = False
                results = valid(model, mode=parameters['mode'])
                i = MODEs.index(parameters['mode'])

                if i < 2 and isinstance(results[i], dict) and results[i]['rouge-l']['f'] > best_results:
                    save_model = True
                    best_results = results[i]['rouge-l']['f']
                elif i >= 2 and results[i] > best_results:
                    save_model = True
                    best_results = results[i]

                if save_model:
                    print('got a better score!')
                    Seq2seq.save(model, parameters['model_save_path'] + 'model_{}_{}.pt'.format(parameters['mode'],
                                                                                                train_iter))

                cum_examples = 0
        # END one epoch.

        if epoch == max_epoch:
            print('reached maximum number of epochs!', file=sys.stderr)
            exit(0)


def main():
    # set the random number generators
    torch.manual_seed(seed)
    if USE_CUDA:
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    stage = 'train'  # 'train' | 'test' | 'single' | 'finetune'
    global is_training

    if stage == 'train':
        is_training = True
        train()
    elif stage == 'test':
        is_training = False
        test(mode='qa', model_path='./checkpoints_seed_2019/model_13000.pt')
    elif stage in ['finetune', 'single']:
        is_training = True
        parameters = {
            'lr': 1e-3,
            'task': stage,
            'mode': 'cls18',
            'model_load_path': './checkpoints_seed_2019/model_13000.pt',
            'model_save_path': './{}_seed_{}/'.format(stage, seed),
            'batch_size': 128,
        }
        single_or_finetune(**parameters)
    else:
        raise RuntimeError('invalid run mode')
    exit(0)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/seq2seq_model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import sys
from typing import List, Tuple
from collections import namedtuple

Hypothesis = namedtuple('Hypothesis', ['value', 'score'])
MODE = ('summ', 'qa', 'cls3', 'cls18')


class Seq2seq(nn.Module):
    def __init__(self, hidden_size, vocab, embddings, enc_num_layers=1,
                 dec_num_layers=1):
        super(Seq2seq, self).__init__()
        self.hidden_size = hidden_size
        self.vocab = vocab
        self.embeddings = embddings
        self.embeddings.weight.requires_grad = True
        self.embed_size = self.embeddings.weight.shape[1]
        self.enc_num_layers = enc_num_layers
        self.dec_num_layers = dec_num_layers

        # 1 represents the 'Summarization' task.
        self.encoder1 = nn.LSTM(self.embed_size, self.hidden_size, self.enc_num_layers)
        self.decoder1 = nn.LSTMCell(self.embed_size, self.hidden_size)
        self.att_projection1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.h_projection1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.c_projection1 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.combined_output_projection1 = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.dropout1 = nn.Dropout(0.2)
        self.target_vocab_projection1 = nn.Linear(self.hidden_size, self.vocab.size())

        # 2 represents the 'QA' task.
        self.encoder2 = self.decoder1
        self.decoder2 = nn.LSTMCell(self.embed_size, self.hidden_size)
        self.h_projection2 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.c_projection2 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.att_projection2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.combined_output_projection2 = nn.Linear(2 * self.hidden_size, self.hidden_size, bias=False)
        self.dropout2 = nn.Dropout(0.2)
        self.target_vocab_projection2 = nn.Linear(self.hidden_size, self.vocab.size())

        self.cls_dropout = nn.Dropout(0.2)
        self.fc_share = nn.Linear(self.hidden_size * 2, self.hidden_size)
        # 3 represents 3-way classification; 18 represents 18-way classification.
        self.fc3 = nn.Linear(self.hidden_size, 3)
        self.fc18 = nn.Linear(self.hidden_size, 18)

    def forward(self, source, target, mode, source2=None):
        """
        :param
            source: (list[list[str]])
            target: (list[list[str]])
        :return
            scores (b,): Array of log-likelihoods of target sentences (for the summ & QA tasks),
            OR y_pred (b, 3|18) (for the classification tasks)
        """
        assert mode in MODE, 'unrecognized mode!'
        if mode == 'summ':
            source_lengths = [len(s) for s in source]
            source_padded = self.vocab.to_input_tensor(source, device=self.device)
            target_padded = self.vocab.to_input_tensor(target, device=self.device)

            enc_hiddens, dec_init_state, _ = self.encode_summ(source_padded, source_lengths)
            enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths)
            combined_outputs = self.decode_summ(dec_init_state, target_padded, enc_hiddens, enc_masks)

            P = F.log_softmax(self.target_vocab_projection1(combined_outputs),
                              dim=-1)
            target_masks = (target_padded != self.vocab.word2id['<pad>']).float()

            # Compute log probability of generating true target words
            target_gold_words_log_prob = torch.gather(P, index=(target_padded[1:]).unsqueeze(-1), dim=-1).squeeze(
                -1) * target_masks[1:]
            scores = target_gold_words_log_prob.sum(dim=0)
            return scores
        elif mode == 'qa':
            source_lengths = [len(s) for s in source]
            source_padded = self.vocab.to_input_tensor(source, device=self.device)
            target_padded = self.vocab.to_input_tensor(target, device=self.device)

            enc_hiddens, dec_init_state, _ = self.encode_qa(source_padded, source_lengths)
            enc_masks = self.generate_sent_masks(enc_hiddens, source_lengths)
            combined_outputs = self.decode_qa(dec_init_state, target_padded, enc_hiddens, enc_masks)

            P = F.log_softmax(self.target_vocab_projection2(combined_outputs),
                              dim=-1)  # (max_len_sents, batch_size, vocab_size)

            # Zero out probabilities for which we have nothing in the target text
            target_masks = (target_padded != self.vocab.word2id['<pad>']).float()

            # Compute log probability of generating true target words
            target_gold_words_log_prob = torch.gather(P, index=(target_padded[1:]).unsqueeze(-1), dim=-1).squeeze(
                -1) * target_masks[1:]
            scores = target_gold_words_log_prob.sum(dim=0)
            return scores
        elif mode.startswith('cls'):
            source_lengths1 = [len(s) for s in source]
            source_padded1 = self.vocab.to_input_tensor(source, device=self.device)

            source_lengths2 = [len(s) for s in source2]
            source_padded2 = self.vocab.to_input_tensor(source2, device=self.device)

            _, _, last_hidden1 = self.encode_summ(source_padded1, source_lengths1)
            _, _, last_hidden2 = self.encode_qa(source_padded2, source_lengths2)
            x1 = torch.squeeze(last_hidden1, dim=0)
            x = torch.cat((x1, last_hidden2), dim=1)
            x = self.fc_share(x)
            x = self.cls_dropout(x)
            if mode == 'cls3':
                y_pred = self.fc3(x)
            else:
                y_pred = self.fc18(x)

            return y_pred
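
    # Editor's note: a worked mini-example of the gather step used in
    # forward() above (sketch, not original code). For log-probs P of shape
    # (tgt_len, batch, vocab), torch.gather picks each gold token's
    # log-probability; summing over dim=0 gives one log-likelihood per example:
    #
    #   >>> P = F.log_softmax(torch.randn(4, 2, 6), dim=-1)
    #   >>> gold = torch.randint(0, 6, (4, 2))  # gold token ids, (tgt_len, batch)
    #   >>> picked = torch.gather(P, index=gold.unsqueeze(-1), dim=-1).squeeze(-1)
    #   >>> picked.shape
    #   torch.Size([4, 2])
    #   >>> scores = picked.sum(dim=0)  # (batch,)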
    def encode_summ(self, source_padded, source_lengths):
        """Apply the encoder to source_padded to obtain the hidden states.
        """
        X = self.embeddings(source_padded)
        self.encoder1.flatten_parameters()
        enc_hiddens, (last_hidden, last_cell) = self.encoder1(
            pack_padded_sequence(X, source_lengths, enforce_sorted=False))
        enc_hiddens = pad_packed_sequence(enc_hiddens, batch_first=True)[0]

        init_decoder_hidden = self.h_projection1(torch.squeeze(last_hidden, dim=0))
        init_decoder_cell = self.c_projection1(torch.squeeze(last_cell, dim=0))
        dec_init_state = (init_decoder_hidden, init_decoder_cell)

        return enc_hiddens, dec_init_state, last_hidden
138 | """ 139 | X = self.embeddings(source_padded) 140 | max_length = X.size()[0] 141 | batch_size = X.size()[1] 142 | embed_size = X.size()[2] 143 | enc_hiddens = torch.zeros(X.size()).to(self.device) 144 | last_hidden = None 145 | last_cell = None 146 | h_t = torch.zeros(batch_size, embed_size).to(self.device) 147 | c_t = torch.zeros(batch_size, embed_size).to(self.device) 148 | i = 0 149 | for x_t in torch.split(X, 1, dim=0): 150 | x_t = torch.squeeze(x_t, dim=0) 151 | (h_t, c_t) = self.encoder2(x_t, (h_t, c_t)) 152 | enc_hiddens[i] = h_t 153 | i = i + 1 154 | if i == max_length: 155 | last_hidden = h_t 156 | last_cell = c_t 157 | init_decoder_hidden = self.h_projection2(last_hidden) 158 | init_decoder_cell = self.c_projection2(last_cell) 159 | dec_init_state = (init_decoder_hidden, init_decoder_cell) 160 | 161 | return enc_hiddens.permute(1, 0, 2), dec_init_state, last_hidden 162 | 163 | def decode_summ(self, dec_init_state: (torch.Tensor, torch.Tensor), target_padded: torch.Tensor, 164 | enc_hiddens: torch.Tensor, enc_masks: torch.Tensor): 165 | # Chop of the token for max length sentences. 166 | batch_size = target_padded.size()[1] 167 | target_padded = target_padded[:-1] 168 | 169 | o_prev = torch.zeros(batch_size, self.hidden_size, device=self.device) 170 | dec_state = dec_init_state 171 | 172 | combined_outputs = [] 173 | 174 | enc_hiddens_proj = self.att_projection1(enc_hiddens) 175 | Y = self.embeddings(target_padded) 176 | 177 | for Y_t in torch.split(Y, 1, dim=0): 178 | Y_t = torch.squeeze(Y_t, dim=0) 179 | Ybar_t = torch.add(Y_t, o_prev) 180 | dec_state, o_t, e_t = self.step_summ(Ybar_t, 181 | dec_state, 182 | enc_hiddens, 183 | enc_hiddens_proj, 184 | enc_masks) 185 | combined_outputs.append(o_t) 186 | o_prev = o_t 187 | 188 | combined_outputs = torch.stack(combined_outputs, dim=0) 189 | return combined_outputs 190 | 191 | def decode_qa(self, dec_init_state: (torch.Tensor, torch.Tensor), target_padded: torch.Tensor, 192 | enc_hiddens: torch.Tensor, enc_masks: torch.Tensor): 193 | 194 | # Chop of the token for max length sentences. 195 | batch_size = target_padded.size()[1] 196 | target_padded = target_padded[:-1] 197 | 198 | o_prev = torch.zeros(batch_size, self.hidden_size, device=self.device) 199 | dec_state = dec_init_state 200 | 201 | # Initialize a list we will use to collect the combined output o_t on each step 202 | combined_outputs = [] 203 | 204 | enc_hiddens_proj = self.att_projection2(enc_hiddens) 205 | Y = self.embeddings(target_padded) 206 | 207 | for Y_t in torch.split(Y, 1, dim=0): 208 | Y_t = torch.squeeze(Y_t, dim=0) 209 | Ybar_t = torch.add(Y_t, o_prev) 210 | dec_state, o_t, e_t = self.step_qa(Ybar_t, 211 | dec_state, 212 | enc_hiddens, 213 | enc_hiddens_proj, 214 | enc_masks) 215 | combined_outputs.append(o_t) 216 | o_prev = o_t 217 | 218 | combined_outputs = torch.stack(combined_outputs, dim=0) 219 | return combined_outputs 220 | 221 | def step_summ(self, Ybar_t: torch.Tensor, 222 | dec_state: Tuple[torch.Tensor, torch.Tensor], 223 | enc_hiddens: torch.Tensor, 224 | enc_hiddens_proj: torch.Tensor, 225 | enc_masks: torch.Tensor): 226 | """One forward step of the decoder. 227 | :param Y_t: (batch_size, embed_size) The first tokens of each of the mini-batch of sents. 228 | :param dec_state: ... 229 | :returns dec_state: the current state of decoder. 230 | :returns output: the current hidden state of decoder. 
231 | """ 232 | dec_state = self.decoder1(Ybar_t, dec_state) 233 | dec_hidden, dec_cell = dec_state 234 | e_t = torch.squeeze( 235 | torch.bmm(enc_hiddens_proj, torch.unsqueeze(dec_hidden, 2)), dim=2) 236 | 237 | # Set e_t to -inf where enc_masks has 1 238 | if enc_masks is not None: 239 | e_t.data.masked_fill_(enc_masks.bool(), -float('inf')) 240 | 241 | alpha_t = F.softmax(e_t, dim=1) 242 | a_t = torch.squeeze( 243 | torch.bmm(torch.unsqueeze(alpha_t, 1), enc_hiddens), 244 | 1) 245 | U_t = torch.cat((a_t, dec_hidden), dim=1) 246 | V_t = self.combined_output_projection1(U_t) 247 | O_t = self.dropout1(torch.tanh(V_t)) 248 | 249 | combined_output = O_t 250 | return dec_state, combined_output, e_t 251 | 252 | def step_qa(self, Ybar_t: torch.Tensor, 253 | dec_state: Tuple[torch.Tensor, torch.Tensor], 254 | enc_hiddens: torch.Tensor, 255 | enc_hiddens_proj: torch.Tensor, 256 | enc_masks: torch.Tensor): 257 | """One forward step of the decoder. 258 | :param Y_t: (batch_size, embed_size) The first tokens of each of the mini-batch of sents. 259 | :param dec_state: ... 260 | :returns dec_state: the current state of decoder. 261 | :returns output: the current hidden state of decoder. 262 | """ 263 | 264 | dec_state = self.decoder2(Ybar_t, dec_state) 265 | dec_hidden, dec_cell = dec_state 266 | e_t = torch.squeeze( 267 | torch.bmm(enc_hiddens_proj, torch.unsqueeze(dec_hidden, 2)), 268 | dim=2) 269 | 270 | # Set e_t to -inf where enc_masks has 1 271 | if enc_masks is not None: 272 | e_t.data.masked_fill_(enc_masks.bool(), -float('inf')) 273 | 274 | alpha_t = F.softmax(e_t, dim=1) 275 | a_t = torch.squeeze( 276 | torch.bmm(torch.unsqueeze(alpha_t, 1), enc_hiddens), 277 | 1) 278 | U_t = torch.cat((a_t, dec_hidden), dim=1) 279 | V_t = self.combined_output_projection2(U_t) 280 | O_t = self.dropout2(torch.tanh(V_t)) 281 | 282 | combined_output = O_t 283 | return dec_state, combined_output, e_t 284 | 285 | @property 286 | def device(self) -> torch.device: 287 | return self.embeddings.weight.device 288 | 289 | def predict(self, src_sent: List[List[str]]): 290 | """Predict the output sentence according to the src_sent. 291 | """ 292 | 293 | src_sent_tensor = self.vocab.to_input_tensor(src_sent, self.device) 294 | dec_init_state = self.encode_summ(src_sent_tensor, [len(s) for s in src_sent]) 295 | 296 | dec_state = dec_init_state 297 | batch_size = dec_state[0].size()[0] 298 | 299 | hypotheses = [''] * batch_size 300 | 301 | flags = [False] * batch_size 302 | y_t = [['']] * batch_size 303 | y_t = self.vocab.to_input_tensor(y_t, device=self.device) 304 | y_t = self.embeddings(y_t) 305 | 306 | stop = False 307 | MAX_SENT = 50 308 | count = 0 309 | while not stop: 310 | count += 1 311 | stop = True 312 | y_t = torch.squeeze(y_t, dim=0) 313 | dec_state, output = self.step_summ(y_t, dec_state) 314 | top1_idxs = torch.argmax(F.log_softmax(self.target_vocab_projection1(output), dim=-1), -1) 315 | top1_idxs = top1_idxs.tolist() # Convert tensor to list with length of batch_size. 
            current_words = [self.vocab.id2word[idx] for idx in top1_idxs]
            for i in range(len(current_words)):
                if current_words[i] == '</s>':
                    flags[i] = True
                if not flags[i]:
                    hypotheses[i] = hypotheses[i] + current_words[i]
            for f in flags:
                if not f:
                    stop = False
            if count >= MAX_SENT:
                break
            y_t = [[hyp[-1]] if hyp else ['</s>'] for hyp in hypotheses]
            y_t = self.vocab.to_input_tensor(y_t, device=self.device)
            y_t = self.embeddings(y_t)

        return hypotheses

    def beam_search(self, mode: str, src_sent: List[str], beam_size: int = 5, max_decoding_time_step: int = 70):
        """ Given a single source sentence, perform beam search and return the best generation.
        @param mode (str): 'summ' or 'qa', selecting which encoder/decoder pair to use
        @param src_sent (List[str]): a single source sentence (words)
        @param beam_size (int): beam size
        @param max_decoding_time_step (int): maximum number of time steps to unroll the decoding RNN
        @returns the highest-scoring decoded sentence, joined into a single string
        """
        assert mode in ['summ', 'qa']
        src_sents_var = self.vocab.to_input_tensor([src_sent], self.device)
        if mode == 'summ':
            src_encodings, dec_init_vec, _ = self.encode_summ(src_sents_var, [len(src_sent)])
            src_encodings_att_linear = self.att_projection1(src_encodings)
        else:
            src_encodings, dec_init_vec, _ = self.encode_qa(src_sents_var, [len(src_sent)])
            src_encodings_att_linear = self.att_projection2(src_encodings)

        h_tm1 = dec_init_vec
        att_tm1 = torch.zeros(1, self.hidden_size, device=self.device)

        hypotheses = [['<s>']]
        hyp_scores = torch.zeros(len(hypotheses), dtype=torch.float, device=self.device)
        completed_hypotheses = []

        t = 0
        while len(completed_hypotheses) < beam_size and t < max_decoding_time_step:
            t += 1
            hyp_num = len(hypotheses)

            exp_src_encodings = src_encodings.expand(hyp_num,
                                                     src_encodings.size(1),
                                                     src_encodings.size(2))

            exp_src_encodings_att_linear = src_encodings_att_linear.expand(hyp_num,
                                                                           src_encodings_att_linear.size(1),
                                                                           src_encodings_att_linear.size(2))

            y_tm1 = torch.tensor([self.vocab.word2id[hyp[-1]] for hyp in hypotheses], dtype=torch.long,
                                 device=self.device)
            y_t_embed = self.embeddings(y_tm1)
            x = torch.add(y_t_embed, att_tm1)

            if mode == 'summ':
                (h_t, cell_t), att_t, _ = self.step_summ(x, h_tm1,
                                                         exp_src_encodings, exp_src_encodings_att_linear,
                                                         enc_masks=None)
                log_p_t = F.log_softmax(self.target_vocab_projection1(att_t), dim=-1)
            else:
                (h_t, cell_t), att_t, _ = self.step_qa(x, h_tm1,
                                                       exp_src_encodings, exp_src_encodings_att_linear,
                                                       enc_masks=None)
                log_p_t = F.log_softmax(self.target_vocab_projection2(att_t), dim=-1)

            live_hyp_num = beam_size - len(completed_hypotheses)
            continuing_hyp_scores = (hyp_scores.unsqueeze(1).expand_as(log_p_t) + log_p_t).view(-1)
            top_cand_hyp_scores, top_cand_hyp_pos = torch.topk(continuing_hyp_scores, k=live_hyp_num)

            # Integer division recovers which live hypothesis each candidate
            # extends; the remainder is the new word's vocabulary id.
            prev_hyp_ids = top_cand_hyp_pos // self.vocab.size()
            hyp_word_ids = top_cand_hyp_pos % self.vocab.size()

            new_hypotheses = []
            live_hyp_ids = []
            new_hyp_scores = []

            for prev_hyp_id, hyp_word_id, cand_new_hyp_score in zip(prev_hyp_ids, hyp_word_ids, top_cand_hyp_scores):
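                # Each candidate either finishes a hypothesis (when it emits
                # '</s>') or stays on the beam for the next step.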
                prev_hyp_id = prev_hyp_id.item()
                hyp_word_id = hyp_word_id.item()
                cand_new_hyp_score = cand_new_hyp_score.item()

                hyp_word = self.vocab.id2word[hyp_word_id]
                new_hyp_sent = hypotheses[prev_hyp_id] + [hyp_word]
                if hyp_word == '</s>':
                    completed_hypotheses.append(Hypothesis(value=new_hyp_sent[1:-1],
                                                           score=cand_new_hyp_score))
                else:
                    new_hypotheses.append(new_hyp_sent)
                    live_hyp_ids.append(prev_hyp_id)
                    new_hyp_scores.append(cand_new_hyp_score)

            if len(completed_hypotheses) == beam_size:
                break

            live_hyp_ids = torch.tensor(live_hyp_ids, dtype=torch.long, device=self.device)
            h_tm1 = (h_t[live_hyp_ids], cell_t[live_hyp_ids])
            att_tm1 = att_t[live_hyp_ids]

            hypotheses = new_hypotheses
            hyp_scores = torch.tensor(new_hyp_scores, dtype=torch.float, device=self.device)

        if len(completed_hypotheses) == 0:
            completed_hypotheses.append(Hypothesis(value=hypotheses[0][1:],
                                                   score=hyp_scores[0].item()))

        # After sorting by score in descending order, the best hypothesis is first.
        completed_hypotheses.sort(key=lambda hyp: hyp.score, reverse=True)
        hypothesis = completed_hypotheses[0]

        return ''.join(hypothesis.value)

    def generate_sent_masks(self, enc_hiddens: torch.Tensor, source_lengths: List[int]) -> torch.Tensor:
        """ Generate sentence masks for encoder hidden states.

        @param enc_hiddens (Tensor): encodings of shape (b, src_len, h), where b = batch size,
                                     src_len = max source length, h = hidden size.
        @param source_lengths (List[int]): actual length of each sentence in the batch.

        @returns enc_masks (Tensor): tensor of sentence masks of shape (b, src_len),
                                     with 1 at padded positions and 0 elsewhere.
        """
        enc_masks = torch.zeros(enc_hiddens.size(0), enc_hiddens.size(1), dtype=torch.float, device=self.device)
        for e_id, src_len in enumerate(source_lengths):
            enc_masks[e_id, src_len:] = 1
        return enc_masks.to(self.device)

    @staticmethod
    def load(model_path: str):
        """ Load the whole model object from a file (mapped onto cuda:0).
        """
        model = torch.load(model_path, map_location=torch.device('cuda:0'))

        return model

    @staticmethod
    def save(model, path: str):
        """ Save the model to a file.
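        Note that torch.save(model, path) pickles the entire module, so
        loading later requires the Seq2seq class definition to be importable
        (a general property of whole-model serialization in PyTorch, not
        something specific to this code).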
462 | """ 463 | print('save the model to [%s]' % path, file=sys.stderr) 464 | torch.save(model, path) 465 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import collections 4 | from typing import List 5 | 6 | from rouge import Rouge 7 | import torch 8 | import torch.nn 9 | import math 10 | import numpy as np 11 | import pandas as pd 12 | import pickle 13 | from dataset import Dataset 14 | 15 | CLASS3_NAME_TO_INDEX = { 16 | '0-1岁': 0, 17 | '1-2岁': 1, 18 | '2-3岁': 2 19 | } 20 | 21 | CLASS18_NAME_TO_INDEX = { 22 | '动作发育': 0, 23 | '幼儿园': 1, 24 | '产褥期保健': 2, 25 | '婴幼常见病': 3, 26 | '家庭教育': 4, 27 | '未准父母': 5, 28 | '婴幼保健': 6, 29 | '婴幼期喂养': 7, 30 | '疫苗接种': 8, 31 | '腹泻': 9, 32 | '宝宝上火': 10, 33 | '婴幼心理': 11, 34 | '皮肤护理': 12, 35 | '流产和不孕': 13, 36 | '婴幼早教': 14, 37 | '儿童过敏': 15, 38 | '孕期保健': 16, 39 | '婴幼营养': 17 40 | } 41 | 42 | 43 | def pad_sents(sents, pad_token): 44 | """pad list of sentences according to the longest sent. 45 | """ 46 | sents_padded = [] 47 | max_len = max([len(sent) for sent in sents]) 48 | for s in sents: 49 | if len(s) < max_len: 50 | s_len = len(s) 51 | sents_padded.append(s + (max_len - s_len) * [pad_token]) 52 | else: 53 | sents_padded.append(s) 54 | return sents_padded 55 | 56 | 57 | def build_embeddings(file_path, vocab): 58 | with open(file_path, encoding='UTF-8') as f: 59 | line = f.readline().strip().split(' ') 60 | size, dim = vocab.size(), int(line[1]) 61 | weight_matrix = torch.randn((size, dim), dtype=torch.float) 62 | 63 | for line in f: 64 | line = line.rstrip().split(' ') 65 | if line[0] in vocab.word2id.keys(): 66 | weight = list(map(float, line[-dim:])) 67 | weight = torch.tensor(weight, dtype=torch.float) 68 | weight_matrix[vocab.word2id[line[0]]] = torch.unsqueeze(weight, dim=0) 69 | 70 | return torch.nn.Embedding.from_pretrained(weight_matrix) 71 | 72 | 73 | def read_data(file_path): 74 | """Read dataset file. 75 | """ 76 | dataset_cls3 = Dataset() 77 | dataset_cls18 = Dataset() 78 | max_len = 256 79 | num_summ_qa = num_cls3 = num_cls18 = 0 80 | 81 | data_table = pd.read_csv(file_path, sep=',', encoding='UTF-8') 82 | 83 | for i in range(0, len(data_table)): 84 | question = str(data_table.iat[i, 1]).strip() 85 | description = str(data_table.iat[i, 2]).strip() 86 | answer = str(data_table.iat[i, 3]).strip() 87 | category = str(data_table.iat[i, 4]).strip() 88 | 89 | if len(description) > max_len or len(answer) > max_len: 90 | print('Too long: ', str(data_table.iat[i, 0])) 91 | continue 92 | 93 | num_summ_qa += 1 94 | if category in CLASS3_NAME_TO_INDEX: 95 | num_cls3 += 1 96 | dataset_cls3.add_data(question, description, answer, category) 97 | elif category in CLASS18_NAME_TO_INDEX: 98 | num_cls18 += 1 99 | dataset_cls18.add_data(question, description, answer, category) 100 | else: 101 | print('Unexpected category! id:{}'.format(data_table.iat[i, 0])) 102 | continue 103 | 104 | print('samples num for sum and qa:', num_summ_qa) 105 | print('samples num for cls3:', num_cls3) 106 | print('samples num for cls18:', num_cls18) 107 | 108 | return dataset_cls3, dataset_cls18 109 | 110 | 111 | def cal_rouge(hyps:List[str],refs:List[str],avg:bool=False,ignore_empty:bool=False): 112 | """ 113 | :param hyps: List of hyps, each hyp is a 'str' consists of a sequence of tokens separated by spaces. 

def batch_iter(data, batch_size, shuffle=False):
    """ Yield batches of source and target sentences, sorted by length within each batch (longest first).
    @param data:
        (list of (src_sents, tgt_sents)) list of tuples containing source and target sentences,
        OR
        (list of src_sents) list of source sentences.
    @param batch_size (int): batch size
    @param shuffle (bool): whether to randomly shuffle the dataset
    @return
        src_sents, tgt_sents: both list[list[str]] of length batch_size,
        OR
        examples: (list[list[str]]) of length batch_size.
    """
    batch_num = math.ceil(len(data) / batch_size)
    index_array = list(range(len(data)))

    if shuffle:
        np.random.shuffle(index_array)

    if isinstance(data[0], tuple):
        for i in range(batch_num):
            indices = index_array[i * batch_size: (i + 1) * batch_size]
            examples = [data[idx] for idx in indices]

            examples = sorted(examples, key=lambda e: len(e[0]), reverse=True)
            src_sents = [e[0] for e in examples]
            tgt_sents = [e[1] for e in examples]

            yield src_sents, tgt_sents
    elif isinstance(data[0], list):
        for i in range(batch_num):
            indices = index_array[i * batch_size: (i + 1) * batch_size]
            examples = [data[idx] for idx in indices]
            examples = sorted(examples, key=lambda e: len(e), reverse=True)
            yield examples


if __name__ == '__main__':
    print('read and split dataset...')
    for mode in ['train', 'dev', 'test']:
        dataset_cls3, dataset_cls18 = read_data('./data/{}.csv'.format(mode))
        with open('./data/{}_{}.pkl'.format(mode, 'cls3'), 'wb') as f:
            pickle.dump(dataset_cls3, f)
        with open('./data/{}_{}.pkl'.format(mode, 'cls18'), 'wb') as f:
            pickle.dump(dataset_cls18, f)
    exit(0)
--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import torch
from utils import build_embeddings, pad_sents
import pandas as pd
import pickle


class Vocab:
    def __init__(self, word2id=None):
        if word2id:
            self.word2id = word2id
        else:
            self.word2id = dict()
            self.word2id['<pad>'] = 0
            self.word2id['<s>'] = 1
            self.word2id['</s>'] = 2
            self.word2id['<unk>'] = 3
        self.special_num = 4  # the number of the special tokens above.
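        # '<pad>' pads batches to equal length, '<s>'/'</s>' mark the start and
        # end of a target sentence during decoding, and '<unk>' stands in for
        # characters unseen when the vocabulary was built.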

        self.id2word = {v: k for k, v in self.word2id.items()}

    def add(self, word):
        if word not in self.word2id.keys():
            self.word2id[word] = len(self.word2id)
            self.id2word[self.word2id[word]] = word

    def size(self):
        return len(self.word2id)

    def save(self, file_path):
        with open(file_path, 'w') as f:
            json.dump(self.word2id, f)

    def word2index(self, word: str):
        if word in self.word2id.keys():
            return self.word2id[word]
        else:
            return self.word2id['<unk>']

    def word2indices(self, sents):
        """Convert a list of words, or a list of sentences of words, into indices
        (a list of ids, or a list of lists of ids, respectively).
        """
        if isinstance(sents[0], list):
            return [[self.word2index(w) for w in s] for s in sents]
        else:
            return [self.word2index(w) for w in sents]

    def indices2words(self, word_ids):
        """ Convert a list of indices into words.
        """
        return [self.id2word[w_id] for w_id in word_ids]

    def to_input_tensor(self, sents, device) -> torch.Tensor:
        """Convert a list of sentences into a tensor, padding shorter sentences;
        the result has shape (max_sent_len, batch_size).
        """
        sents_padded = pad_sents(sents, '<pad>')
        word_ids = self.word2indices(sents_padded)
        sents_torch = torch.tensor(word_ids, dtype=torch.long, device=device)
        return sents_torch.t()

    @staticmethod
    def build(file: str):
        vocab = Vocab()

        # The vocabulary is character-level: iterate over the characters of
        # question + description + answer for every row.
        data = pd.read_csv(file, sep=',', encoding='UTF-8')
        for i in range(0, len(data)):
            sent = str(data.iat[i, 1]) + str(data.iat[i, 2]) + str(data.iat[i, 3])
            for word in sent:
                vocab.add(word)
        print('vocab size:', vocab.size())
        return vocab

    @staticmethod
    def load(file_path):
        with open(file_path) as f:
            word2id = json.load(f)
        return Vocab(word2id)


if __name__ == '__main__':
    print('Building vocab...')
    vocab = Vocab.build('./data/train.csv')
    vocab.save('./data/vocab.json')

    # print('Loading vocab...')
    # vocab = Vocab.load('./data/vocab.json')
    print('Building embeddings...')
    embeddings = build_embeddings('./data/ChineseEmbedding.txt', vocab)
    with open('./data/embeddings.pkl', 'wb') as f:
        pickle.dump(embeddings, f)
--------------------------------------------------------------------------------