├── document_level ├── 1 ├── requirements.txt ├── train_sh.sh ├── generate_trie_dict.py ├── modeling.py ├── retriever_utils.py ├── utils.py ├── train_query_encoder.yaml └── train_query_encoder.py ├── passage-level ├── 1 ├── requirements.txt ├── train_sh.sh ├── generate_trie_dict.py ├── modeling.py ├── utils.py ├── retriever_utils.py ├── train_query_encoder.yaml └── train_query_encoder.py ├── LICENSE └── README.md /document_level/1: -------------------------------------------------------------------------------- 1 | 1 2 | -------------------------------------------------------------------------------- /passage-level/1: -------------------------------------------------------------------------------- 1 | 1 2 | -------------------------------------------------------------------------------- /passage-level/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | transformers==4.15.0 3 | datasets==2.0.0 4 | pytrec_eval 5 | deepspeed==0.6.0 6 | ir_datasets==0.5.0 7 | pyserini==0.15.0 8 | tqdm 9 | numpy -------------------------------------------------------------------------------- /document_level/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | transformers==4.15.0 3 | datasets==2.0.0 4 | pytrec_eval 5 | deepspeed==0.6.0 6 | ir_datasets==0.5.0 7 | pyserini==0.15.0 8 | tqdm 9 | numpy -------------------------------------------------------------------------------- /document_level/train_sh.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | 5 | 6 | python3 train_query_encoder.py \ 7 | --do_train True \ 8 | --do_gen True \ 9 | --do_test True \ 10 | --load_small True \ 11 | --num_train_epochs 1 \ 12 | --per_gpu_train_batch_size 4 \ 13 | --per_gpu_eval_batch_size 4 \ 14 | --per_gpu_test_batch_size 2 \ 15 | --overwrite_output_dir True \ -------------------------------------------------------------------------------- /passage-level/train_sh.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | 4 | 5 | 6 | python3 train_query_encoder.py \ 7 | --do_train True \ 8 | --do_gen True \ 9 | --do_test True \ 10 | --load_small True \ 11 | --num_train_epochs 1 \ 12 | --per_gpu_train_batch_size 4 \ 13 | --per_gpu_eval_batch_size 4 \ 14 | --per_gpu_test_batch_size 2 \ 15 | --overwrite_output_dir True \ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 liyongqi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /document_level/generate_trie_dict.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from tqdm import tqdm 3 | import re 4 | import json 5 | from transformers import T5Tokenizer 6 | from utils import Trie 7 | import pickle 8 | 9 | tokenizer = T5Tokenizer.from_pretrained("t5-base", do_lower_case=True, 10 | cache_dir="/home/v-yongqili/project/GCoQA/data/huggingface_cache/") 11 | page_title_dict = {} 12 | with open("/home/v-yongqili/project/GCoQA/data/full_wiki_document.json", 'r') as f:  # document level: collect page titles from the document corpus 13 | data = f.readlines() 14 | for line in tqdm(data): 15 | line = json.loads(line) 16 | if line['title'].strip() not in page_title_dict: 17 | page_title_dict[line['title'].strip()] = 1 18 | 19 | print("page_title_dict len %s" % len(page_title_dict)) 20 | 21 | title_sequence = [] 22 | for page_title in tqdm(page_title_dict): 23 | input_ids = tokenizer.encode( 24 | page_title, 25 | add_special_tokens=True, 26 | max_length=64, 27 | truncation=True) 28 | title_sequence.append([0] + input_ids)  # [0] is T5's decoder start token 29 | 30 | decoder_trie = Trie(title_sequence) 31 | with open("/home/v-yongqili/project/GCoQA/data/trie_dict_t5-base_page_level.pkl", 'wb') as f: 32 | pickle.dump(decoder_trie.trie_dict, f) -------------------------------------------------------------------------------- /passage-level/generate_trie_dict.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from tqdm import tqdm 3 | import re 4 | import json 5 | from transformers import T5Tokenizer 6 | from utils import Trie 7 | import pickle 8 | 9 | tokenizer = T5Tokenizer.from_pretrained("t5-base", do_lower_case=True, 10 | cache_dir="/home/v-yongqili/project/GCoQA/data/huggingface_cache/") 11 | page_title_dict = {} 12 | with open("/home/v-yongqili/project/GCoQA/data/full_wiki_segments.json", 'r') as f:  # passage level: collect section titles from the segment corpus 13 | data = f.readlines() 14 | for line in tqdm(data): 15 | line = json.loads(line) 16 | if line['title'].strip() not in page_title_dict: 17 | page_title_dict[line['title'].strip()] = 1 18 | 19 | print("page_title_dict len %s" % len(page_title_dict)) 20 | 21 | title_sequence = [] 22 | for page_title in tqdm(page_title_dict): 23 | input_ids = tokenizer.encode( 24 | page_title, 25 | add_special_tokens=True, 26 | max_length=64, 27 | truncation=True) 28 | title_sequence.append([0] + input_ids)  # [0] is T5's decoder start token 29 | 30 | decoder_trie = Trie(title_sequence) 31 | with open("/home/v-yongqili/project/GCoQA/data/trie_dict_t5-base_section_level.pkl", 'wb') as f: 32 | pickle.dump(decoder_trie.trie_dict, f) -------------------------------------------------------------------------------- /document_level/modeling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import collections 4 | import torch 5 | 6 | 7 | from torch import nn 8 | 9 | 10 | 11 | 12 | 13 | from transformers import T5Tokenizer, T5ForConditionalGeneration 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | 18 | class Generative_Retrieval(nn.Module): 19 | r""" 20 | A T5-based generative retriever: trained with teacher forcing on identifier tokens, decoded with trie-constrained beam search. 21 | """ 22 | def __init__(self, args): 23 | super(Generative_Retrieval, self).__init__() 24 | 
self.generator = T5ForConditionalGeneration.from_pretrained(args.model_type, cache_dir=args.cache_dir) 25 | 26 | def forward(self, args=None, query_input_ids=None, query_attention_mask=None, 27 | target_input_ids=None, target_attention_mask=None, 28 | prefix_allowed_tokens_fn=None, mode="train"): 29 | if mode == "train": 30 | loss = self.generator(input_ids=query_input_ids, attention_mask=query_attention_mask, labels=target_input_ids).loss  # teacher-forced cross-entropy over the identifier tokens 31 | return loss 32 | if mode == "dev": 33 | outputs = self.generator.generate(query_input_ids, 34 | attention_mask=query_attention_mask, 35 | num_beams=5, 36 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,  # restricts each step to valid identifier prefixes 37 | min_length=0, 38 | max_length=64, 39 | num_return_sequences=1) 40 | return outputs 41 | if mode == "test": 42 | outputs = self.generator.generate(query_input_ids, 43 | attention_mask=query_attention_mask, 44 | num_beams=args.beam_size, 45 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 46 | min_length=0, 47 | max_length=64, 48 | num_return_sequences=args.top_k) 49 | return outputs 50 | def dist_gather_tensor(t): 51 | if t is None: 52 | return None 53 | t = t.contiguous() 54 | 55 | all_tensors = [torch.empty_like(t) for _ in range(torch.distributed.get_world_size())] 56 | torch.distributed.all_gather(all_tensors, t) 57 | 58 | all_tensors[torch.distributed.get_rank()] = t 59 | all_tensors = torch.cat(all_tensors, dim=0) 60 | 61 | return all_tensors 62 | 63 | -------------------------------------------------------------------------------- /passage-level/modeling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import collections 4 | import torch 5 | 6 | 7 | from torch import nn 8 | 9 | 10 | 11 | 12 | 13 | from transformers import T5Tokenizer, T5ForConditionalGeneration 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | 18 | class Generative_Retrieval(nn.Module): 19 | r""" 20 | A T5-based generative retriever: trained with teacher forcing on identifier tokens, decoded with trie-constrained beam search. 21 | """ 22 | def __init__(self, args): 23 | super(Generative_Retrieval, self).__init__() 24 | self.generator = T5ForConditionalGeneration.from_pretrained(args.model_type, cache_dir=args.cache_dir) 25 | 26 | def forward(self, args=None, query_input_ids=None, query_attention_mask=None, 27 | target_input_ids=None, target_attention_mask=None, 28 | prefix_allowed_tokens_fn=None, mode="train"): 29 | if mode == "train": 30 | loss = self.generator(input_ids=query_input_ids, attention_mask=query_attention_mask, labels=target_input_ids).loss  # teacher-forced cross-entropy over the identifier tokens 31 | return loss 32 | if mode == "dev": 33 | outputs = self.generator.generate(query_input_ids, 34 | attention_mask=query_attention_mask, 35 | num_beams=5, 36 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,  # restricts each step to valid identifier prefixes 37 | min_length=0, 38 | max_length=64, 39 | num_return_sequences=1) 40 | return outputs 41 | if mode == "test": 42 | outputs = self.generator.generate(query_input_ids, 43 | attention_mask=query_attention_mask, 44 | num_beams=args.beam_size, 45 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 46 | min_length=0, 47 | max_length=64, 48 | num_return_sequences=args.top_k) 49 | return outputs 50 | def dist_gather_tensor(t): 51 | if t is None: 52 | return None 53 | t = t.contiguous() 54 | 55 | all_tensors = [torch.empty_like(t) for _ in range(torch.distributed.get_world_size())] 56 | torch.distributed.all_gather(all_tensors, t) 57 | 58 | all_tensors[torch.distributed.get_rank()] = t 59 | all_tensors = torch.cat(all_tensors, dim=0) 60 | 61 | return all_tensors 62 | 63 | --------------------------------------------------------------------------------
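A note on how these pieces fit together: the model trains with a teacher-forced seq2seq loss and retrieves with constrained beam search, where prefix_allowed_tokens_fn limits every decoding step to token ids that extend a valid identifier in the trie (see utils.py below). The following minimal sketch shows that wiring; the trie path, the example query, and the ad-hoc args namespace are illustrative assumptions, and the real end-to-end pipeline lives in train_query_encoder.py.
```python
import pickle
from types import SimpleNamespace

import torch
from transformers import T5Tokenizer

from modeling import Generative_Retrieval
from utils import Trie

# Ad-hoc stand-in for the training script's argparse namespace (assumption).
args = SimpleNamespace(model_type="t5-base", cache_dir=None, beam_size=5, top_k=5)

tokenizer = T5Tokenizer.from_pretrained(args.model_type)
model = Generative_Retrieval(args).eval()

# Rebuild the identifier trie from the dict pickled by generate_trie_dict.py.
with open("trie_dict_t5-base_section_level.pkl", "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

def prefix_allowed_tokens_fn(batch_id, sent):
    # Allow only token ids that extend a valid [0] + tokenized-title sequence.
    return trie.get(sent.tolist())

query = "who is finn m. w. caspersen? [SEP] where did he study?"
enc = tokenizer(query, return_tensors="pt")
with torch.no_grad():
    outputs = model(args=args,
                    query_input_ids=enc.input_ids,
                    query_attention_mask=enc.attention_mask,
                    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                    mode="test")  # beam search; returns args.top_k identifier sequences
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```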
/document_level/retriever_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import json 4 | import logging 5 | import math 6 | import collections 7 | import linecache 8 | import numpy as np 9 | from io import open 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset 12 | import pickle 13 | import csv 14 | from datasets import load_dataset 15 | 16 | 17 | import random 18 | from random import choice 19 | # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method) 20 | # from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | 26 | class FinetuningDataset(Dataset): 27 | def __init__(self, filename, tokenizer, 28 | load_small, query_max_seq_length, target_max_seq_length, prepend_answers): 29 | 30 | self._filename = filename 31 | self._tokenizer = tokenizer 32 | self._load_small = load_small 33 | 34 | self._query_max_seq_length = query_max_seq_length 35 | 36 | self._target_max_seq_length = target_max_seq_length 37 | 38 | self.data = [] 39 | with open(filename, 'r') as load_f: 40 | data = json.load(load_f) 41 | for entry in data: 42 | if entry['Answer'] != "UNANSWERABLE": 43 | self.data.append(entry) 44 | # if entry['Answer'] == "UNANSWERABLE": ###############for answer generation experiment 45 | # entry['Page'] = 'The Young and the Restless' 46 | # entry['Passage'] = { 47 | # "id": 9185426, 48 | # "title": "The Young and the Restless [SEP] Casting and story development", 49 | # "text": "Co-creators William J. Bell and Lee Phillip Bell centered The Young and the Restless around two core families, the wealthy Brooks and the poor Fosters. Bell borrowed this technique of soap opera building from his mentor, Irna Phillips. While casting for the series, Bell and executive producer John Conboy auditioned 540 actors for the 13 main characters. They assembled the youngest group of actors ever cast on a soap opera at the time, hiring mostly unknown actors whom they considered glamorous model types. Chemistry between actors also factored into the criteria for casting. The stories focused on the younger characters, with an emphasis in fantasy. The fantasy element was reflected in the love story between Jill Abbott and the millionaire Phillip Chancellor II; the Leslie Brooks, Brad Elliot, and Lorie Brooks love triangle; and Snapper Fosters romance with Chris Brooks. Sexuality also played a major role in the stories. Formerly, soap operas did not delve into the sexual side of their romances. Bell changed that, first during his time as head writer of Days of Our Lives and again on The Young and the Restless. William Gray Espys Snapper Foster is considered the first to discover sex on a soap opera. During the story, the character is engaged to Chris Brooks (Trish Stewart) and having a sexual relationship with Sally McGuire (Lee Crawford). Other plots reflected sexual themes as well. For the first time in the genre, the dialogue and the story situations included explicit sexual themes such as premarital intercourse, impotence, incest, and rape. The first two rape storylines that would be told on the serial were controversial at the time as they reflected a more introspective and analytic storytelling style, the first time rape storylines would be addressed in this manner in the genre. 
The first, in 1973–74, revolved around the rape of Chris Brooks and the aftermath, in which she entertained (and, eventually, rejected) the idea that she was perhaps at fault for her attack. The second, in 1976, involved Chriss sister Peggy (Pamela Peters Solow) and was meant to serve as a cut-and-dried story in which no viewer could justify this attack, committed out of the blue by an authority figure." 50 | # } 51 | # self.data.append(entry) 52 | self._total_data = 0 53 | if self._load_small: 54 | self._total_data = 100 55 | else: 56 | self._total_data = len(self.data) 57 | 58 | self.prepend_answers = prepend_answers 59 | def __len__(self): 60 | return self._total_data 61 | 62 | def __getitem__(self, idx): 63 | 64 | entry = self.data[idx] 65 | 66 | if self.prepend_answers:  # keep the full history: previous questions and answers 67 | question = " [SEP] ".join(entry['Context']) + " [SEP] " + entry['Question']  # build locally instead of mutating the cached entry across epochs 68 | else:  # keep only the even-indexed turns, i.e. the previous questions 69 | s = [] 70 | for i in range(len(entry['Context'])): 71 | if i%2 == 0: 72 | s.append(entry['Context'][i]) 73 | question = " [SEP] ".join(s) + " [SEP] " + entry['Question'] 74 | 75 | query_feature = text_to_feature(question, self._tokenizer, 76 | max_length=self._query_max_seq_length) 77 | 78 | target_text = entry["Page"].strip() 79 | 80 | 81 | target_feature = text_to_feature(target_text, self._tokenizer, 82 | max_length=self._target_max_seq_length) 83 | return_feature_dict = { 'query_input_ids': np.asarray(query_feature['input_ids']), 84 | 'query_attention_mask': np.asarray(query_feature['attention_mask']), 85 | 'query_text': question, 86 | 'answer_text': entry['Answer'], 87 | 'target_input_ids': np.asarray(target_feature['input_ids']), 88 | 'target_attention_mask': np.asarray(target_feature['attention_mask']), 89 | 'target_text': target_text 90 | } 91 | 92 | return return_feature_dict 93 | 94 | 95 | def text_to_feature(text, tokenizer, 96 | max_length=256, 97 | pad_on_left=False, 98 | pad_token=0, 99 | pad_token_segment_id=0, 100 | mask_padding_with_zero=True): 101 | 102 | input_ids = tokenizer.encode( 103 | text, 104 | add_special_tokens=True, 105 | max_length=max_length, 106 | truncation=True 107 | ) 108 | 109 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 110 | # tokens are attended to. 111 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 112 | 113 | # Zero-pad up to the sequence length. 
114 | padding_length = max_length - len(input_ids) 115 | if pad_on_left: 116 | input_ids = ([pad_token] * padding_length) + input_ids 117 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 118 | else: 119 | input_ids = input_ids + ([pad_token] * padding_length) 120 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 121 | 122 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) 123 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length) 124 | 125 | inputs = {} 126 | inputs["input_ids"] = input_ids 127 | inputs["attention_mask"] = attention_mask 128 | 129 | 130 | return inputs 131 | 132 | 133 | 134 | 135 | 136 | def normalize_question(question: str) -> str: 137 | return question 138 | 139 | def normalize_passage(ctx_text: str): 140 | return ctx_text -------------------------------------------------------------------------------- /passage-level/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Load QuAC dataset. 
""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import json 21 | import logging 22 | import math 23 | import collections 24 | import linecache 25 | import numpy as np 26 | from io import open 27 | from tqdm import tqdm 28 | import torch 29 | 30 | 31 | 32 | # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method) 33 | # from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | from typing import Dict, List 39 | 40 | try: 41 | import marisa_trie 42 | except ModuleNotFoundError: 43 | pass 44 | 45 | 46 | class Trie(object): 47 | def __init__(self, sequences: List[List[int]] = []): 48 | self.trie_dict = {} 49 | self.len = 0 50 | if sequences: 51 | for sequence in sequences: 52 | Trie._add_to_trie(sequence, self.trie_dict) 53 | self.len += 1 54 | 55 | self.append_trie = None 56 | self.bos_token_id = None 57 | 58 | def append(self, trie, bos_token_id): 59 | self.append_trie = trie 60 | self.bos_token_id = bos_token_id 61 | 62 | def add(self, sequence: List[int]): 63 | Trie._add_to_trie(sequence, self.trie_dict) 64 | self.len += 1 65 | 66 | def get(self, prefix_sequence: List[int]): 67 | 68 | return Trie._get_from_trie( 69 | prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id 70 | ) 71 | 72 | @staticmethod 73 | def load_from_dict(trie_dict): 74 | trie = Trie() 75 | trie.trie_dict = trie_dict 76 | trie.len = sum(1 for _ in trie) 77 | return trie 78 | 79 | @staticmethod 80 | def _add_to_trie(sequence: List[int], trie_dict: Dict): 81 | if sequence: 82 | if sequence[0] not in trie_dict: 83 | trie_dict[sequence[0]] = {} 84 | Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]]) 85 | 86 | @staticmethod 87 | def _get_from_trie( 88 | prefix_sequence: List[int], 89 | trie_dict: Dict, 90 | append_trie=None, 91 | bos_token_id: int = None, 92 | ): 93 | if len(prefix_sequence) == 0: 94 | output = list(trie_dict.keys()) 95 | if append_trie and bos_token_id in output: 96 | output.remove(bos_token_id) 97 | output += list(append_trie.trie_dict.keys()) 98 | return output 99 | elif prefix_sequence[0] in trie_dict: 100 | return Trie._get_from_trie( 101 | prefix_sequence[1:], 102 | trie_dict[prefix_sequence[0]], 103 | append_trie, 104 | bos_token_id, 105 | ) 106 | else: 107 | if append_trie: 108 | return append_trie.get(prefix_sequence) 109 | else: 110 | return [] 111 | 112 | def __iter__(self): 113 | def _traverse(prefix_sequence, trie_dict): 114 | if trie_dict: 115 | for next_token in trie_dict: 116 | yield from _traverse( 117 | prefix_sequence + [next_token], trie_dict[next_token] 118 | ) 119 | else: 120 | yield prefix_sequence 121 | 122 | return _traverse([], self.trie_dict) 123 | 124 | def __len__(self): 125 | return self.len 126 | 127 | def __getitem__(self, value): 128 | return self.get(value) 129 | 130 | 131 | class MarisaTrie(object): 132 | def __init__( 133 | self, 134 | sequences: List[List[int]] = [], 135 | cache_fist_branch=True, 136 | max_token_id=256001, 137 | ): 138 | 139 | self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + ( 140 | [chr(i) for i in range(65000, max_token_id + 10000)] 141 | if max_token_id >= 55000 142 | else [] 143 | ) 144 | self.char2int = {self.int2char[i]: i for i in range(max_token_id)} 145 | 146 | self.cache_fist_branch = cache_fist_branch 147 | if self.cache_fist_branch: 148 | self.zero_iter = list({sequence[0] for sequence in sequences}) 149 | assert 
len(self.zero_iter) == 1 150 | self.first_iter = list({sequence[1] for sequence in sequences}) 151 | 152 | self.trie = marisa_trie.Trie( 153 | "".join([self.int2char[i] for i in sequence]) for sequence in sequences 154 | ) 155 | 156 | def get(self, prefix_sequence: List[int]): 157 | if self.cache_fist_branch and len(prefix_sequence) == 0: 158 | return self.zero_iter 159 | elif ( 160 | self.cache_fist_branch 161 | and len(prefix_sequence) == 1 162 | and self.zero_iter == prefix_sequence 163 | ): 164 | return self.first_iter 165 | else: 166 | key = "".join([self.int2char[i] for i in prefix_sequence]) 167 | return list( 168 | { 169 | self.char2int[e[len(key)]] 170 | for e in self.trie.keys(key) 171 | if len(e) > len(key) 172 | } 173 | ) 174 | 175 | def __iter__(self): 176 | for sequence in self.trie.iterkeys(): 177 | yield [self.char2int[e] for e in sequence] 178 | 179 | def __len__(self): 180 | return len(self.trie) 181 | 182 | def __getitem__(self, value): 183 | return self.get(value) 184 | 185 | 186 | class DummyTrieMention(object): 187 | def __init__(self, return_values): 188 | self._return_values = return_values 189 | 190 | def get(self, indices=None): 191 | return self._return_values 192 | 193 | 194 | class DummyTrieEntity(object): 195 | def __init__(self, return_values, codes): 196 | self._return_values = list( 197 | set(return_values).difference( 198 | set( 199 | codes[e] 200 | for e in ( 201 | "start_mention_token", 202 | "end_mention_token", 203 | "start_entity_token", 204 | ) 205 | ) 206 | ) 207 | ) 208 | self._codes = codes 209 | 210 | def get(self, indices, depth=0): 211 | if len(indices) == 0 and depth == 0: 212 | return self._codes["end_mention_token"] 213 | elif len(indices) == 0 and depth == 1: 214 | return self._codes["start_entity_token"] 215 | elif len(indices) == 0: 216 | return self._return_values 217 | elif len(indices) == 1 and indices[0] == self._codes["end_entity_token"]: 218 | return self._codes["EOS"] 219 | else: 220 | return self.get(indices[1:], depth=depth + 1) 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /document_level/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Load QuAC dataset. 
""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import json 21 | import logging 22 | import math 23 | import collections 24 | import linecache 25 | import numpy as np 26 | from io import open 27 | from tqdm import tqdm 28 | import torch 29 | 30 | 31 | 32 | # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method) 33 | # from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | from typing import Dict, List 39 | 40 | try: 41 | import marisa_trie 42 | except ModuleNotFoundError: 43 | pass 44 | 45 | 46 | class Trie(object): 47 | def __init__(self, sequences: List[List[int]] = []): 48 | self.trie_dict = {} 49 | self.len = 0 50 | if sequences: 51 | for sequence in sequences: 52 | Trie._add_to_trie(sequence, self.trie_dict) 53 | self.len += 1 54 | 55 | self.append_trie = None 56 | self.bos_token_id = None 57 | 58 | def append(self, trie, bos_token_id): 59 | self.append_trie = trie 60 | self.bos_token_id = bos_token_id 61 | 62 | def add(self, sequence: List[int]): 63 | Trie._add_to_trie(sequence, self.trie_dict) 64 | self.len += 1 65 | 66 | def get(self, prefix_sequence: List[int]): 67 | 68 | return Trie._get_from_trie( 69 | prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id 70 | ) 71 | 72 | @staticmethod 73 | def load_from_dict(trie_dict): 74 | trie = Trie() 75 | trie.trie_dict = trie_dict 76 | trie.len = sum(1 for _ in trie) 77 | return trie 78 | 79 | @staticmethod 80 | def _add_to_trie(sequence: List[int], trie_dict: Dict): 81 | if sequence: 82 | if sequence[0] not in trie_dict: 83 | trie_dict[sequence[0]] = {} 84 | Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]]) 85 | 86 | @staticmethod 87 | def _get_from_trie( 88 | prefix_sequence: List[int], 89 | trie_dict: Dict, 90 | append_trie=None, 91 | bos_token_id: int = None, 92 | ): 93 | if len(prefix_sequence) == 0: 94 | output = list(trie_dict.keys()) 95 | if append_trie and bos_token_id in output: 96 | output.remove(bos_token_id) 97 | output += list(append_trie.trie_dict.keys()) 98 | return output 99 | elif prefix_sequence[0] in trie_dict: 100 | return Trie._get_from_trie( 101 | prefix_sequence[1:], 102 | trie_dict[prefix_sequence[0]], 103 | append_trie, 104 | bos_token_id, 105 | ) 106 | else: 107 | if append_trie: 108 | return append_trie.get(prefix_sequence) 109 | else: 110 | return [] 111 | 112 | def __iter__(self): 113 | def _traverse(prefix_sequence, trie_dict): 114 | if trie_dict: 115 | for next_token in trie_dict: 116 | yield from _traverse( 117 | prefix_sequence + [next_token], trie_dict[next_token] 118 | ) 119 | else: 120 | yield prefix_sequence 121 | 122 | return _traverse([], self.trie_dict) 123 | 124 | def __len__(self): 125 | return self.len 126 | 127 | def __getitem__(self, value): 128 | return self.get(value) 129 | 130 | 131 | class MarisaTrie(object): 132 | def __init__( 133 | self, 134 | sequences: List[List[int]] = [], 135 | cache_fist_branch=True, 136 | max_token_id=256001, 137 | ): 138 | 139 | self.int2char = [chr(i) for i in range(min(max_token_id, 55000))] + ( 140 | [chr(i) for i in range(65000, max_token_id + 10000)] 141 | if max_token_id >= 55000 142 | else [] 143 | ) 144 | self.char2int = {self.int2char[i]: i for i in range(max_token_id)} 145 | 146 | self.cache_fist_branch = cache_fist_branch 147 | if self.cache_fist_branch: 148 | self.zero_iter = list({sequence[0] for sequence in sequences}) 149 | assert 
len(self.zero_iter) == 1 150 | self.first_iter = list({sequence[1] for sequence in sequences}) 151 | 152 | self.trie = marisa_trie.Trie( 153 | "".join([self.int2char[i] for i in sequence]) for sequence in sequences 154 | ) 155 | 156 | def get(self, prefix_sequence: List[int]): 157 | if self.cache_fist_branch and len(prefix_sequence) == 0: 158 | return self.zero_iter 159 | elif ( 160 | self.cache_fist_branch 161 | and len(prefix_sequence) == 1 162 | and self.zero_iter == prefix_sequence 163 | ): 164 | return self.first_iter 165 | else: 166 | key = "".join([self.int2char[i] for i in prefix_sequence]) 167 | return list( 168 | { 169 | self.char2int[e[len(key)]] 170 | for e in self.trie.keys(key) 171 | if len(e) > len(key) 172 | } 173 | ) 174 | 175 | def __iter__(self): 176 | for sequence in self.trie.iterkeys(): 177 | yield [self.char2int[e] for e in sequence] 178 | 179 | def __len__(self): 180 | return len(self.trie) 181 | 182 | def __getitem__(self, value): 183 | return self.get(value) 184 | 185 | 186 | class DummyTrieMention(object): 187 | def __init__(self, return_values): 188 | self._return_values = return_values 189 | 190 | def get(self, indices=None): 191 | return self._return_values 192 | 193 | 194 | class DummyTrieEntity(object): 195 | def __init__(self, return_values, codes): 196 | self._return_values = list( 197 | set(return_values).difference( 198 | set( 199 | codes[e] 200 | for e in ( 201 | "start_mention_token", 202 | "end_mention_token", 203 | "start_entity_token", 204 | ) 205 | ) 206 | ) 207 | ) 208 | self._codes = codes 209 | 210 | def get(self, indices, depth=0): 211 | if len(indices) == 0 and depth == 0: 212 | return self._codes["end_mention_token"] 213 | elif len(indices) == 0 and depth == 1: 214 | return self._codes["start_entity_token"] 215 | elif len(indices) == 0: 216 | return self._return_values 217 | elif len(indices) == 1 and indices[0] == self._codes["end_entity_token"]: 218 | return self._codes["EOS"] 219 | else: 220 | return self.get(indices[1:], depth=depth + 1) 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | -------------------------------------------------------------------------------- /passage-level/retriever_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import json 4 | import logging 5 | import math 6 | import collections 7 | import linecache 8 | import numpy as np 9 | from io import open 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset 12 | import pickle 13 | import csv 14 | from datasets import load_dataset 15 | 16 | 17 | import random 18 | from random import choice 19 | 20 | # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method) 21 | # from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | 27 | class FinetuningDataset(Dataset): 28 | def __init__(self, filename, tokenizer, 29 | load_small, query_max_seq_length,target_max_seq_length, prepend_answers): 30 | 31 | self._filename = filename 32 | self._tokenizer = tokenizer 33 | self._load_small = load_small 34 | 35 | self._query_max_seq_length = query_max_seq_length 36 | 37 | self._target_max_seq_length = target_max_seq_length 38 | 39 | self.data = [] 40 | with open(filename, 'r') as load_f: 41 | data = json.load(load_f) 42 | for entry in data: 43 | 
if entry['Answer'] != "UNANSWERABLE": 44 | self.data.append(entry) 45 | # if entry['Answer'] == "UNANSWERABLE": ###############for answer generation experiment 46 | # entry['Page'] = 'The Young and the Restless' 47 | # entry['Passage'] = { 48 | # "id": 9185426, 49 | # "title": "The Young and the Restless [SEP] Casting and story development", 50 | # "text": "Co-creators William J. Bell and Lee Phillip Bell centered The Young and the Restless around two core families, the wealthy Brooks and the poor Fosters. Bell borrowed this technique of soap opera building from his mentor, Irna Phillips. While casting for the series, Bell and executive producer John Conboy auditioned 540 actors for the 13 main characters. They assembled the youngest group of actors ever cast on a soap opera at the time, hiring mostly unknown actors whom they considered glamorous model types. Chemistry between actors also factored into the criteria for casting. The stories focused on the younger characters, with an emphasis in fantasy. The fantasy element was reflected in the love story between Jill Abbott and the millionaire Phillip Chancellor II; the Leslie Brooks, Brad Elliot, and Lorie Brooks love triangle; and Snapper Fosters romance with Chris Brooks. Sexuality also played a major role in the stories. Formerly, soap operas did not delve into the sexual side of their romances. Bell changed that, first during his time as head writer of Days of Our Lives and again on The Young and the Restless. William Gray Espys Snapper Foster is considered the first to discover sex on a soap opera. During the story, the character is engaged to Chris Brooks (Trish Stewart) and having a sexual relationship with Sally McGuire (Lee Crawford). Other plots reflected sexual themes as well. For the first time in the genre, the dialogue and the story situations included explicit sexual themes such as premarital intercourse, impotence, incest, and rape. The first two rape storylines that would be told on the serial were controversial at the time as they reflected a more introspective and analytic storytelling style, the first time rape storylines would be addressed in this manner in the genre. The first, in 1973–74, revolved around the rape of Chris Brooks and the aftermath, in which she entertained (and, eventually, rejected) the idea that she was perhaps at fault for her attack. The second, in 1976, involved Chriss sister Peggy (Pamela Peters Solow) and was meant to serve as a cut-and-dried story in which no viewer could justify this attack, committed out of the blue by an authority figure."
51 | # } 52 | # self.data.append(entry) 53 | # if 'train' in filename: 54 | # random.shuffle(self.data) 55 | # self.data = self.data[:int(len(self.data)*0.0)] 56 | 57 | self._total_data = 0 58 | if self._load_small: 59 | self._total_data = 100 60 | else: 61 | self._total_data = len(self.data) 62 | 63 | self.prepend_answers = prepend_answers 64 | def __len__(self): 65 | return self._total_data 66 | 67 | def __getitem__(self, idx): 68 | 69 | entry = self.data[idx] 70 | 71 | if self.prepend_answers:  # keep the full history: previous questions and answers 72 | question = " [SEP] ".join(entry['Context']) + " [SEP] " + entry['Question']  # build locally instead of mutating the cached entry across epochs 73 | else:  # keep only the even-indexed turns, i.e. the previous questions 74 | s = [] 75 | for i in range(len(entry['Context'])): 76 | if i%2 == 0: 77 | s.append(entry['Context'][i]) 78 | question = " [SEP] ".join(s) + " [SEP] " + entry['Question'] 79 | 80 | query_feature = text_to_feature(question, self._tokenizer, 81 | max_length=self._query_max_seq_length) 82 | 83 | target_text = entry["Passage"]['title'].strip() 84 | 85 | 86 | target_feature = text_to_feature(target_text, self._tokenizer, 87 | max_length=self._target_max_seq_length) 88 | return_feature_dict = { 'query_input_ids': np.asarray(query_feature['input_ids']), 89 | 'query_attention_mask': np.asarray(query_feature['attention_mask']), 90 | 'query_text': question, 91 | 'target_input_ids': np.asarray(target_feature['input_ids']), 92 | 'target_attention_mask': np.asarray(target_feature['attention_mask']), 93 | 'target_text': target_text, 94 | 'answer_text': entry["Answer"], 95 | } 96 | 97 | return return_feature_dict 98 | 99 | 100 | def text_to_feature(text, tokenizer, 101 | max_length=256, 102 | pad_on_left=False, 103 | pad_token=0, 104 | pad_token_segment_id=0, 105 | mask_padding_with_zero=True): 106 | 107 | input_ids = tokenizer.encode( 108 | text, 109 | add_special_tokens=True, 110 | max_length=max_length, 111 | truncation=True 112 | ) 113 | 114 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 115 | # tokens are attended to. 116 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 117 | 118 | # Zero-pad up to the sequence length. 
119 | padding_length = max_length - len(input_ids) 120 | if pad_on_left: 121 | input_ids = ([pad_token] * padding_length) + input_ids 122 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 123 | else: 124 | input_ids = input_ids + ([pad_token] * padding_length) 125 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 126 | 127 | assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) 128 | assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length) 129 | 130 | inputs = {} 131 | inputs["input_ids"] = input_ids 132 | inputs["attention_mask"] = attention_mask 133 | 134 | 135 | return inputs 136 | 137 | 138 | 139 | 140 | 141 | def normalize_question(question: str) -> str: 142 | return question 143 | 144 | def normalize_passage(ctx_text: str): 145 | return ctx_text -------------------------------------------------------------------------------- /document_level/train_query_encoder.yaml: -------------------------------------------------------------------------------- 1 | description: GCoQA 2 | 3 | 4 | 5 | target: 6 | service: amlk8s 7 | name: itplabrr1cl1 8 | # name: itpeusp40cl 9 | environment: 10 | image: yongqili/gdpr-dgl:v5 11 | setup: 12 | - pip install datasets 13 | - pip install sentencepiece 14 | 15 | # target: 16 | # service: sing 17 | # name: msrresrchvc 18 | # environment: 19 | # image: wangliang/pytorch:1.7.1-transformers4.15-fix 20 | # username: resrchvc4cr 21 | # registry: resrchvc4cr.azurecr.io 22 | # setup: 23 | # - echo "export PATH=$PATH:$HOME/.local/bin" >> ~/.bashrc && source ~/.bashrc 24 | # - pip install -r requirements.txt 25 | # - pip install faiss 26 | # - echo "setup done" 27 | 28 | data: 29 | local_dir: /home/v-yongqili/project/GCoQA/data 30 | remote_dir: data/GCoQA/data 31 | 32 | code: 33 | local_dir: ./ 34 | 35 | 36 | jobs: 37 | 38 | 39 | 40 | # - name: GCoQA2_topiocqa_document_testforreader 41 | # sku: G8 42 | # priority: High 43 | # command: 44 | # - export MKL_SERVICE_FORCE_INTEL=1 45 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 46 | # --do_train False 47 | # --load_small False 48 | # --fp16 False 49 | # --num_train_epochs 40 50 | # --per_gpu_train_batch_size 8 51 | # --per_gpu_eval_batch_size 4 52 | # --per_gpu_test_batch_size 2 53 | # --overwrite_output_dir True 54 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 55 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 56 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 57 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 58 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 59 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 60 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 61 | # --learning_rate 1e-5 62 | # --prepend_answers True 63 | # --model_type t5-large 64 | # --top_k 5 65 | # --beam_size 5 66 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17930-95ac2da2-97cc-42c5-ba87-c05436cd84f6/release_test/checkpoint-22725model.pt 67 | 68 | # - name: GCoQA2_qrecc2_document_testforreader 69 | # sku: G8 70 | # priority: High 71 | # command: 72 | # - export MKL_SERVICE_FORCE_INTEL=1 73 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 74 | # --do_train False 75 | # --load_small False 76 | # --fp16 False 77 | # --num_train_epochs 40 78 
| # --per_gpu_train_batch_size 8 79 | # --per_gpu_eval_batch_size 4 80 | # --per_gpu_test_batch_size 2 81 | # --overwrite_output_dir True 82 | # --train_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_train.json 83 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_dev.json 84 | # --test_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_test.json 85 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 86 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 87 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 88 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 89 | # --learning_rate 1e-5 90 | # --prepend_answers True 91 | # --model_type t5-large 92 | # --top_k 5 93 | # --beam_size 5 94 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17825-6f2b0032-8234-4b7f-95e4-379d551e2d37/release_test/checkpoint-6920model.pt 95 | 96 | # - name: GCoQA2_orquac2_document_testforreader 97 | # sku: G8 98 | # priority: High 99 | # command: 100 | # - export MKL_SERVICE_FORCE_INTEL=1 101 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 102 | # --do_train False 103 | # --load_small False 104 | # --fp16 False 105 | # --num_train_epochs 40 106 | # --per_gpu_train_batch_size 4 107 | # --per_gpu_eval_batch_size 4 108 | # --per_gpu_test_batch_size 2 109 | # --overwrite_output_dir True 110 | # --train_file $$AMLT_DATA_DIR/QA_pairs/orquac2/orquac_train.json 111 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/orquac2/orquac_dev.json 112 | # --test_file $$AMLT_DATA_DIR/QA_pairs/orquac2/orquac_test.json 113 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 114 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 115 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 116 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 117 | # --learning_rate 1e-5 118 | # --prepend_answers False 119 | # --model_type t5-large 120 | # --top_k 5 121 | # --beam_size 5 122 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17721-eeaaebf9-3863-45ea-aabf-3c4b195225b9/release_test/checkpoint-20923model.pt 123 | 124 | 125 | 126 | 127 | - name: GCoQA2_topiocqa_document_testforbeam_size5 128 | sku: G2 129 | priority: High 130 | command: 131 | - export MKL_SERVICE_FORCE_INTEL=1 132 | - python3 -m torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 133 | --do_train False 134 | --load_small False 135 | --fp16 False 136 | --num_train_epochs 40 137 | --per_gpu_train_batch_size 8 138 | --per_gpu_eval_batch_size 4 139 | --per_gpu_test_batch_size 2 140 | --overwrite_output_dir True 141 | --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 142 | --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 143 | --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 144 | --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 145 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 146 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 147 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 148 | --learning_rate 1e-5 149 | --prepend_answers True 150 | --model_type t5-large 151 | --top_k 5 152 | --beam_size 5 153 | --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17930-95ac2da2-97cc-42c5-ba87-c05436cd84f6/release_test/checkpoint-22725model.pt 154 | - name: GCoQA2_topiocqa_document_testforbeam_size10 155 | sku: G2 156 | priority: High 157 | command: 158 | - export MKL_SERVICE_FORCE_INTEL=1 159 | - python3 -m 
torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 160 | --do_train False 161 | --load_small False 162 | --fp16 False 163 | --num_train_epochs 40 164 | --per_gpu_train_batch_size 8 165 | --per_gpu_eval_batch_size 4 166 | --per_gpu_test_batch_size 2 167 | --overwrite_output_dir True 168 | --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 169 | --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 170 | --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 171 | --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 172 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 173 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 174 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 175 | --learning_rate 1e-5 176 | --prepend_answers True 177 | --model_type t5-large 178 | --top_k 10 179 | --beam_size 10 180 | --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17930-95ac2da2-97cc-42c5-ba87-c05436cd84f6/release_test/checkpoint-22725model.pt 181 | 182 | - name: GCoQA2_topiocqa_document_testforbeam_size20 183 | sku: G2 184 | priority: High 185 | command: 186 | - export MKL_SERVICE_FORCE_INTEL=1 187 | - python3 -m torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 188 | --do_train False 189 | --load_small False 190 | --fp16 False 191 | --num_train_epochs 40 192 | --per_gpu_train_batch_size 8 193 | --per_gpu_eval_batch_size 4 194 | --per_gpu_test_batch_size 2 195 | --overwrite_output_dir True 196 | --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 197 | --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 198 | --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 199 | --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 200 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 201 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 202 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 203 | --learning_rate 1e-5 204 | --prepend_answers True 205 | --model_type t5-large 206 | --top_k 20 207 | --beam_size 20 208 | --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17930-95ac2da2-97cc-42c5-ba87-c05436cd84f6/release_test/checkpoint-22725model.pt 209 | 210 | 211 | - name: GCoQA2_topiocqa_document_testforbeam_size50 212 | sku: G2 213 | priority: High 214 | command: 215 | - export MKL_SERVICE_FORCE_INTEL=1 216 | - python3 -m torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 217 | --do_train False 218 | --load_small False 219 | --fp16 False 220 | --num_train_epochs 40 221 | --per_gpu_train_batch_size 8 222 | --per_gpu_eval_batch_size 4 223 | --per_gpu_test_batch_size 2 224 | --overwrite_output_dir True 225 | --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 226 | --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 227 | --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 228 | --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 229 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 230 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 231 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 232 | --learning_rate 1e-5 233 | --prepend_answers True 234 | --model_type t5-large 235 | --top_k 50 236 | --beam_size 50 237 | --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17930-95ac2da2-97cc-42c5-ba87-c05436cd84f6/release_test/checkpoint-22725model.pt 238 | 239 | # - name: 
GCoQA2_topiocqa_document_testforbeam_size100 240 | # sku: G8 241 | # priority: High 242 | # command: 243 | # - export MKL_SERVICE_FORCE_INTEL=1 244 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 245 | # --do_train False 246 | # --load_small False 247 | # --fp16 False 248 | # --num_train_epochs 40 249 | # --per_gpu_train_batch_size 8 250 | # --per_gpu_eval_batch_size 4 251 | # --per_gpu_test_batch_size 2 252 | # --overwrite_output_dir True 253 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 254 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 255 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 256 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_document.json 257 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 258 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl 259 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 260 | # --learning_rate 1e-5 261 | # --prepend_answers True 262 | # --model_type t5-large 263 | # --top_k 100 264 | # --beam_size 100 265 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7338574471.17930-95ac2da2-97cc-42c5-ba87-c05436cd84f6/release_test/checkpoint-22725model.pt 266 | 267 | 268 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | We have published several works on generative retrieval, as follows: 3 | ``` 4 | Multiview Identifiers Enhanced Generative Retrieval. ACL 2023. (MINDER) 5 | Generative Retrieval for Conversational Question Answering. IPM 2023. (GCoQA) 6 | Learning to Rank in Generative Retrieval. AAAI 2024. (LTRGR) 7 | Generative Cross-Modal Retrieval: Memorizing Images in Multimodal Language Models for Retrieval and Beyond. ACL 2024 (GRACE). 8 | Distillation Enhanced Generative Retrieval. ACL 2024 findings (DGR). 9 | ``` 10 | All code, data, and checkpoints of the above works are openly released: 11 | 1. MINDER, LTRGR, and DGR are a series of works on text retrieval. LTRGR and DGR are trained on top of the MINDER model, so we release MINDER, LTRGR, and DGR together in the same repository: https://github.com/liyongqi67/MINDER. 12 | 2. GCoQA is our work on conversational retrieval and is released at https://github.com/liyongqi67/GCoQA. 13 | 3. GRACE is our work on cross-modal retrieval and is released at https://github.com/liyongqi67/GRACE. 14 | 15 | You could also refer to our preprints on generative retrieval: 16 | ``` 17 | A Survey of Generative Search and Recommendation in the Era of Large Language Models. 18 | Revolutionizing Text-to-Image Retrieval as Autoregressive Token-to-Voken Generation. 19 | ``` 20 | 21 | 22 | # GCoQA 23 | This is the official implementation of the paper "Generative Retrieval for Conversational Question Answering". 24 | The paper is available at this [link](https://liyongqi67.github.io/papers/Generative%20Retrieval%20for%20Conversational%20Question%20Answering.pdf). 
25 | If you find our paper or code helpful, please consider citing it as follows: 26 | ```bibtex 27 | @article{LI2023103475, 28 | title = {Generative retrieval for conversational question answering}, 29 | journal = {Information Processing & Management}, 30 | volume = {60}, 31 | number = {5}, 32 | pages = {103475}, 33 | year = {2023}, 34 | issn = {0306-4573}, 35 | doi = {https://doi.org/10.1016/j.ipm.2023.103475}, 36 | url = {https://www.sciencedirect.com/science/article/pii/S0306457323002121}, 37 | author = {Yongqi Li and Nan Yang and Liang Wang and Furu Wei and Wenjie Li}, 38 | } 39 | ``` 40 | 41 | ## Dataset 42 | We conducted experiments on three conversational open-domain QA datasets: OR-QuAC, QRECC, and TOPIOCQA. To facilitate future research in this area, we unified the three datasets into a benchmark with the same corpus, as DPR did. 43 | ### 1. Corpus. 44 | 1.1 Passage-level corpus: full_wiki_segments.json. 45 | Format: 46 | ``` 47 | { 48 | 'id': 0, 49 | 'title': 'Eliza Fletcher [SEP] Introduction', 50 | 'text': 'Eliza Fletcher, née Dawson (15 January 1770 – 5 February 1858) was an English autobiographer and early travel writer.' 51 | } 52 | ``` 53 | "Eliza Fletcher" is the Wikipedia page title, and "Introduction" is the section title. 54 | 1.2 Document-level corpus: full_wiki_document.json. 55 | Format: 56 | ``` 57 | { 58 | 'id': 0, 59 | 'title': 'Eliza Fletcher', 60 | 'text': '......' 61 | } 62 | ``` 63 | "Eliza Fletcher" is the Wikipedia page title. 64 | ### 2. QA pairs. 65 | TOPIOCQA dataset: topiocqa_train.json, topiocqa_dev.json, topiocqa_test.json. 66 | QRECC dataset: qrecc_train.json, qrecc_dev.json, qrecc_test.json. 67 | OR-QuAC dataset: orquac_train.json, orquac_dev.json, orquac_test.json. 68 | Format: 69 | ``` 70 | { 71 | "Conversation_no": 3209, 72 | "Turn_no": 2, 73 | "Context": [ 74 | "who is finn m. w. caspersen?", 75 | "American financier and philanthropist." 76 | ], 77 | "Question": "where did he study?", 78 | "Gold_question": "", 79 | "Answer": "Peddie School, Brown University, and Harvard Law School.", 80 | "Page": "Finn M. W. Caspersen", 81 | "Section": "Early life and education", 82 | "Passage": { 83 | "id": "8114812", 84 | "text": "He later reflected that being Protestant was important. There was a kind of anti-Catholicism in the family. The family moved to homes in Andover, New Jersey, and Venice, Florida. Caspersen frequently visited Norway as a child, vacationing there during summers after 1947. Caspersen attended private schools until the ninth grade. He attended the Peddie School, a private preparatory school in Hightstown, New Jersey, and was graduated in 1959. Caspersen received a Bachelor of Arts (B.A.) degree from Brown University in 1963 and a law degree (LL.B.) from Harvard Law School in 1966.", 85 | "title": "Finn M. W. Caspersen [SEP] Early life and education" 86 | } 87 | } 88 | ``` 89 | ### 3. Trie. 90 | To implement constrained generation in the language model, we process the whole corpus and store all identifier token sequences in a trie structure. 91 | You could use the scripts passage-level/generate_trie_dict.py and document_level/generate_trie_dict.py to obtain the tries for passages and documents, respectively. 92 | You could also download our processed trie files: 93 | ``` 94 | trie_dict_t5-base_section_level.pkl is for the passage level. 95 | trie_dict_t5-base_page_level.pkl is for the document level. 96 | ``` 97 | ### 4. Download. 98 | You could download the above files via this [link](https://drive.google.com/drive/folders/18Sa7QPO0r6j-OSVdoiobzAqcADIn4cfM?usp=sharing). 
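For reference, here is a minimal sketch of loading and querying a trie file; the file path below is wherever you saved the download, and `Trie` is defined in passage-level/utils.py and document_level/utils.py:
```python
import pickle

from transformers import T5Tokenizer

from utils import Trie  # passage-level/utils.py or document_level/utils.py

# Rebuild the Trie from the pickled trie_dict produced by generate_trie_dict.py.
with open("trie_dict_t5-base_section_level.pkl", "rb") as f:
    trie = Trie.load_from_dict(pickle.load(f))

# Each stored sequence is [0] (T5's decoder start token) + the tokenized title.
tokenizer = T5Tokenizer.from_pretrained("t5-base")
prefix = [0] + tokenizer.encode("Eliza Fletcher", add_special_tokens=False)
print(trie.get([0]))     # token ids that may start an identifier
print(trie.get(prefix))  # token ids that may extend "Eliza Fletcher"
```
During decoding, `trie.get` backs the `prefix_allowed_tokens_fn` passed to `generate()`, so beam search can only produce identifiers that exist in the corpus.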
99 | ## Environment 100 | You could download the Docker image from Docker Hub to access the exact environment for this project: 101 | environment: 102 | image: yongqili/gdpr-dgl:v5 103 | setup: 104 | - pip install datasets 105 | - pip install sentencepiece 106 | ## Model training 107 | ### Passage_level 108 | The script for training on the TOPIOCQA dataset is: 109 | ```bash 110 | python3 -m torch.distributed.launch --nproc_per_node 8 passage-level/train_query_encoder.py 111 | --do_train True 112 | --load_small False 113 | --fp16 False 114 | --num_train_epochs 40 115 | --per_gpu_train_batch_size 8 116 | --per_gpu_eval_batch_size 4 117 | --per_gpu_test_batch_size 2 118 | --overwrite_output_dir True 119 | --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 120 | --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 121 | --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 122 | --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 123 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 124 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 125 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 126 | --learning_rate 1e-5 127 | --prepend_answers True 128 | --model_type t5-large 129 | --top_k 5 130 | --beam_size 5 131 | ``` 132 | The script for training on the QRECC dataset is: 133 | ```bash 134 | python3 -m torch.distributed.launch --nproc_per_node 8 passage-level/train_query_encoder.py 135 | --do_train True 136 | --load_small False 137 | --fp16 False 138 | --num_train_epochs 40 139 | --per_gpu_train_batch_size 8 140 | --per_gpu_eval_batch_size 4 141 | --per_gpu_test_batch_size 2 142 | --overwrite_output_dir True 143 | --train_file $$AMLT_DATA_DIR/QA_pairs/qrecc/qrecc_train.json 144 | --dev_file $$AMLT_DATA_DIR/QA_pairs/qrecc/qrecc_dev.json 145 | --test_file $$AMLT_DATA_DIR/QA_pairs/qrecc/qrecc_test.json 146 | --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 147 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 148 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 149 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 150 | --learning_rate 1e-5 151 | --prepend_answers True 152 | --model_type t5-large 153 | --top_k 5 154 | --beam_size 5 155 | ``` 156 | The script for training on the OR-QuAC dataset is: 157 | ```bash 158 | python3 -m torch.distributed.launch --nproc_per_node 8 passage-level/train_query_encoder.py 159 | --do_train True 160 | --load_small False 161 | --fp16 False 162 | --num_train_epochs 40 163 | --per_gpu_train_batch_size 8 164 | --per_gpu_eval_batch_size 4 165 | --per_gpu_test_batch_size 2 166 | --overwrite_output_dir True 167 | --train_file $$AMLT_DATA_DIR/QA_pairs/orquac/orquac_train.json 168 | --dev_file $$AMLT_DATA_DIR/QA_pairs/orquac/orquac_dev.json 169 | --test_file $$AMLT_DATA_DIR/QA_pairs/orquac/orquac_test.json 170 | --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 171 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 172 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 173 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 174 | --learning_rate 1e-5 175 | --prepend_answers False 176 | --model_type t5-large 177 | --top_k 5 178 | --beam_size 5 179 | ``` 180 | ### Document_level 181 | The script for training on the TOPIOCQA dataset is: 182 | ```bash 183 | python3 -m torch.distributed.launch --nproc_per_node 8 document_level/train_query_encoder.py 184 | --do_train True 185 | --load_small False 186 | --fp16 False 187 | --num_train_epochs 40 188 | --per_gpu_train_batch_size 8 189 | 
189 |     --per_gpu_eval_batch_size 4 \
190 |     --per_gpu_test_batch_size 2 \
191 |     --overwrite_output_dir True \
192 |     --train_file $AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json \
193 |     --dev_file $AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json \
194 |     --test_file $AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json \
195 |     --corpus_path $AMLT_DATA_DIR/full_wiki_document.json \
196 |     --cache_dir $AMLT_DATA_DIR/huggingface_cache/ \
197 |     --trie_dict $AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl \
198 |     --output_dir $AMLT_OUTPUT_DIR/release_test/ \
199 |     --learning_rate 1e-5 \
200 |     --prepend_answers True \
201 |     --model_type t5-large \
202 |     --top_k 5 \
203 |     --beam_size 5
204 | ```
205 | The script for training on the QReCC dataset is:
206 | ```bash
207 | python3 -m torch.distributed.launch --nproc_per_node 8 document_level/train_query_encoder.py \
208 |     --do_train True \
209 |     --load_small False \
210 |     --fp16 False \
211 |     --num_train_epochs 40 \
212 |     --per_gpu_train_batch_size 8 \
213 |     --per_gpu_eval_batch_size 4 \
214 |     --per_gpu_test_batch_size 2 \
215 |     --overwrite_output_dir True \
216 |     --train_file $AMLT_DATA_DIR/QA_pairs/qrecc/qrecc_train.json \
217 |     --dev_file $AMLT_DATA_DIR/QA_pairs/qrecc/qrecc_dev.json \
218 |     --test_file $AMLT_DATA_DIR/QA_pairs/qrecc/qrecc_test.json \
219 |     --corpus_path $AMLT_DATA_DIR/full_wiki_document.json \
220 |     --cache_dir $AMLT_DATA_DIR/huggingface_cache/ \
221 |     --trie_dict $AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl \
222 |     --output_dir $AMLT_OUTPUT_DIR/release_test/ \
223 |     --learning_rate 1e-5 \
224 |     --prepend_answers True \
225 |     --model_type t5-large \
226 |     --top_k 5 \
227 |     --beam_size 5
228 | ```
229 | The script for training on the OR-QuAC dataset is:
230 | ```bash
231 | python3 -m torch.distributed.launch --nproc_per_node 8 document_level/train_query_encoder.py \
232 |     --do_train True \
233 |     --load_small False \
234 |     --fp16 False \
235 |     --num_train_epochs 40 \
236 |     --per_gpu_train_batch_size 8 \
237 |     --per_gpu_eval_batch_size 4 \
238 |     --per_gpu_test_batch_size 2 \
239 |     --overwrite_output_dir True \
240 |     --train_file $AMLT_DATA_DIR/QA_pairs/orquac/orquac_train.json \
241 |     --dev_file $AMLT_DATA_DIR/QA_pairs/orquac/orquac_dev.json \
242 |     --test_file $AMLT_DATA_DIR/QA_pairs/orquac/orquac_test.json \
243 |     --corpus_path $AMLT_DATA_DIR/full_wiki_document.json \
244 |     --cache_dir $AMLT_DATA_DIR/huggingface_cache/ \
245 |     --trie_dict $AMLT_DATA_DIR/trie_dict_t5-base_page_level.pkl \
246 |     --output_dir $AMLT_OUTPUT_DIR/release_test/ \
247 |     --learning_rate 1e-5 \
248 |     --prepend_answers False \
249 |     --model_type t5-large \
250 |     --top_k 5 \
251 |     --beam_size 5
252 | ```
253 | 
254 | We trained the models on 8 × 32GB NVIDIA V100 GPUs.
255 | We release our trained model checkpoints for the three datasets at this [link](https://drive.google.com/drive/folders/19ea3tuIFJkUYiwZ8eGMOTaJ0xnmydSDS?usp=sharing); a sketch of a test-only run with a downloaded checkpoint is given at the end of this README.
256 | 
257 | 
258 | ## Contact
259 | If you run into any problems, please email liyongqi0@gmail.com. Do not hesitate to email me directly, as I do not check GitHub issues frequently.
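As referenced above, a released checkpoint can be evaluated without retraining by setting `--do_train False` and pointing `--test_ckpt_path` at the downloaded file. The command below is a sketch modelled on the test jobs in `passage-level/train_query_encoder.yaml`; the checkpoint path is a placeholder (saved checkpoints follow the naming pattern `checkpoint-<global_step>model.pt`):
```bash
python3 -m torch.distributed.launch --nproc_per_node 2 passage-level/train_query_encoder.py \
    --do_train False \
    --do_test True \
    --test_ckpt_path /path/to/checkpoint-34685model.pt \
    --test_file $AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json \
    --corpus_path $AMLT_DATA_DIR/full_wiki_segments.json \
    --cache_dir $AMLT_DATA_DIR/huggingface_cache/ \
    --trie_dict $AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl \
    --output_dir $AMLT_OUTPUT_DIR/release_test/ \
    --prepend_answers True \
    --model_type t5-large \
    --top_k 5 \
    --beam_size 5
```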
260 | -------------------------------------------------------------------------------- /passage-level/train_query_encoder.yaml: -------------------------------------------------------------------------------- 1 | description: GCoQA 2 | 3 | 4 | 5 | target: 6 | service: amlk8s 7 | name: itplabrr1cl1 8 | # name: itpeusp40cl 9 | environment: 10 | image: yongqili/gdpr-dgl:v5 11 | setup: 12 | - pip install datasets 13 | - pip install sentencepiece 14 | 15 | # target: 16 | # service: sing 17 | # name: msrresrchvc 18 | # environment: 19 | # image: wangliang/pytorch:1.7.1-transformers4.15-fix 20 | # username: resrchvc4cr 21 | # registry: resrchvc4cr.azurecr.io 22 | # setup: 23 | # - echo "export PATH=$PATH:$HOME/.local/bin" >> ~/.bashrc && source ~/.bashrc 24 | # - pip install -r requirements.txt 25 | # - pip install faiss 26 | # - echo "setup done" 27 | 28 | data: 29 | local_dir: /home/v-yongqili/project/GCoQA/data 30 | remote_dir: data/GCoQA/data 31 | 32 | code: 33 | local_dir: ./ 34 | 35 | 36 | jobs: 37 | 38 | 39 | - name: GCoQA2_topiocqa_t5-3b 40 | sku: G8 41 | priority: High 42 | command: 43 | - export MKL_SERVICE_FORCE_INTEL=1 44 | - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 45 | --do_train True 46 | --load_small False 47 | --fp16 True 48 | --num_train_epochs 40 49 | --per_gpu_train_batch_size 1 50 | --per_gpu_eval_batch_size 1 51 | --per_gpu_test_batch_size 1 52 | --gradient_accumulation_steps 8 53 | --overwrite_output_dir True 54 | --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 55 | --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 56 | --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 57 | --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 58 | --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 59 | --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 60 | --output_dir $$AMLT_OUTPUT_DIR/release_test/ 61 | --learning_rate 1e-5 62 | --prepend_answers True 63 | --model_type t5-3b 64 | --top_k 5 65 | --beam_size 5 66 | # - name: GCoQA2_topiocqa_t5-base 67 | # sku: G8 68 | # priority: High 69 | # command: 70 | # - export MKL_SERVICE_FORCE_INTEL=1 71 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 72 | # --do_train True 73 | # --load_small False 74 | # --fp16 False 75 | # --num_train_epochs 40 76 | # --per_gpu_train_batch_size 8 77 | # --per_gpu_eval_batch_size 4 78 | # --per_gpu_test_batch_size 2 79 | # --overwrite_output_dir True 80 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 81 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 82 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 83 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 84 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 85 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 86 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 87 | # --learning_rate 1e-5 88 | # --prepend_answers True 89 | # --model_type t5-base 90 | # --top_k 5 91 | # --beam_size 5 92 | 93 | 94 | 95 | # - name: GCoQA2_topiocqa_testforreader 96 | # sku: G8 97 | # priority: High 98 | # command: 99 | # - export MKL_SERVICE_FORCE_INTEL=1 100 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 101 | # --do_train False 102 | # --load_small False 103 | # --fp16 False 104 | # --num_train_epochs 40 105 | # --per_gpu_train_batch_size 8 106 | # --per_gpu_eval_batch_size 4 107 | # --per_gpu_test_batch_size 2 108 | # 
--overwrite_output_dir True 109 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 110 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 111 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 112 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 113 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 114 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 115 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 116 | # --learning_rate 1e-5 117 | # --prepend_answers True 118 | # --model_type t5-large 119 | # --top_k 5 120 | # --beam_size 5 121 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7339678094.74957-aa7d4e0a-7d19-437f-9765-f4e43f9e3d70/release_test/checkpoint-34685model.pt 122 | 123 | 124 | # - name: GCoQA2_qrecc2_testforreader 125 | # sku: G8 126 | # priority: High 127 | # command: 128 | # - export MKL_SERVICE_FORCE_INTEL=1 129 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 130 | # --do_train False 131 | # --load_small False 132 | # --fp16 False 133 | # --num_train_epochs 40 134 | # --per_gpu_train_batch_size 8 135 | # --per_gpu_eval_batch_size 4 136 | # --per_gpu_test_batch_size 2 137 | # --overwrite_output_dir True 138 | # --train_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_train.json 139 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_dev.json 140 | # --test_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_test.json 141 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 142 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 143 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 144 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 145 | # --learning_rate 1e-5 146 | # --prepend_answers True 147 | # --model_type t5-large 148 | # --top_k 5 149 | # --beam_size 5 150 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7339678094.75059-a95f219b-f991-41b4-a5e5-3b99028c1caf/release_test/checkpoint-12581model.pt 151 | 152 | # - name: GCoQA2_orquac2_testforreader 153 | # sku: G8 154 | # priority: High 155 | # command: 156 | # - export MKL_SERVICE_FORCE_INTEL=1 157 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 158 | # --do_train False 159 | # --load_small False 160 | # --fp16 False 161 | # --num_train_epochs 40 162 | # --per_gpu_train_batch_size 8 163 | # --per_gpu_eval_batch_size 4 164 | # --per_gpu_test_batch_size 2 165 | # --overwrite_output_dir True 166 | # --train_file $$AMLT_DATA_DIR/QA_pairs/orquac2/orquac_train.json 167 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/orquac2/orquac_dev.json 168 | # --test_file $$AMLT_DATA_DIR/QA_pairs/orquac2/orquac_test.json 169 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 170 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 171 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 172 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 173 | # --learning_rate 1e-5 174 | # --prepend_answers False 175 | # --model_type t5-large 176 | # --top_k 5 177 | # --beam_size 5 178 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7339793469.04332-d5ffc0fa-e260-4b3e-ab76-64b3ed5204e8/release_test/checkpoint-20923model.pt 179 | 180 | # - name: GCoQA2_qrecc2_t5_large_original 181 | # sku: G8 182 | # priority: High 183 | # command: 184 | # - export MKL_SERVICE_FORCE_INTEL=1 185 | # - python3 -m torch.distributed.launch --nproc_per_node 8 train_query_encoder.py 186 | 
# --do_train True 187 | # --load_small False 188 | # --fp16 False 189 | # --num_train_epochs 40 190 | # --per_gpu_train_batch_size 8 191 | # --per_gpu_eval_batch_size 4 192 | # --per_gpu_test_batch_size 2 193 | # --overwrite_output_dir True 194 | # --train_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_train.json 195 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_dev.json 196 | # --test_file $$AMLT_DATA_DIR/QA_pairs/qrecc2/qrecc_test.json 197 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 198 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 199 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 200 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 201 | # --learning_rate 1e-5 202 | # --prepend_answers True 203 | # --model_type t5-large 204 | # --top_k 5 205 | # --beam_size 5 206 | 207 | 208 | # - name: GCoQA2_topiocqa_testforbeam_size5 209 | # sku: G2 210 | # priority: High 211 | # command: 212 | # - export MKL_SERVICE_FORCE_INTEL=1 213 | # - python3 -m torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 214 | # --do_train False 215 | # --load_small False 216 | # --fp16 False 217 | # --num_train_epochs 40 218 | # --per_gpu_train_batch_size 8 219 | # --per_gpu_eval_batch_size 4 220 | # --per_gpu_test_batch_size 2 221 | # --overwrite_output_dir True 222 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 223 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 224 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 225 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 226 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 227 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 228 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 229 | # --learning_rate 1e-5 230 | # --prepend_answers True 231 | # --model_type t5-large 232 | # --top_k 5 233 | # --beam_size 5 234 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7339678094.74957-aa7d4e0a-7d19-437f-9765-f4e43f9e3d70/release_test/checkpoint-34685model.pt 235 | 236 | # - name: GCoQA2_topiocqa_testforbeam_size10 237 | # sku: G2 238 | # priority: High 239 | # command: 240 | # - export MKL_SERVICE_FORCE_INTEL=1 241 | # - python3 -m torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 242 | # --do_train False 243 | # --load_small False 244 | # --fp16 False 245 | # --num_train_epochs 40 246 | # --per_gpu_train_batch_size 8 247 | # --per_gpu_eval_batch_size 4 248 | # --per_gpu_test_batch_size 2 249 | # --overwrite_output_dir True 250 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 251 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 252 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 253 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 254 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 255 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 256 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 257 | # --learning_rate 1e-5 258 | # --prepend_answers True 259 | # --model_type t5-large 260 | # --top_k 10 261 | # --beam_size 10 262 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7339678094.74957-aa7d4e0a-7d19-437f-9765-f4e43f9e3d70/release_test/checkpoint-34685model.pt 263 | 264 | # - name: GCoQA2_topiocqa_testforbeam_size20 265 | # sku: G2 266 | # priority: High 267 | # command: 268 | # - export MKL_SERVICE_FORCE_INTEL=1 269 | # - python3 -m 
torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 270 | # --do_train False 271 | # --load_small False 272 | # --fp16 False 273 | # --num_train_epochs 40 274 | # --per_gpu_train_batch_size 8 275 | # --per_gpu_eval_batch_size 4 276 | # --per_gpu_test_batch_size 2 277 | # --overwrite_output_dir True 278 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 279 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 280 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 281 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 282 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 283 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 284 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 285 | # --learning_rate 1e-5 286 | # --prepend_answers True 287 | # --model_type t5-large 288 | # --top_k 20 289 | # --beam_size 20 290 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7339678094.74957-aa7d4e0a-7d19-437f-9765-f4e43f9e3d70/release_test/checkpoint-34685model.pt 291 | 292 | 293 | # - name: GCoQA2_topiocqa_testforbeam_size50 294 | # sku: G2 295 | # priority: High 296 | # command: 297 | # - export MKL_SERVICE_FORCE_INTEL=1 298 | # - python3 -m torch.distributed.launch --nproc_per_node 2 train_query_encoder.py 299 | # --do_train False 300 | # --load_small False 301 | # --fp16 False 302 | # --num_train_epochs 40 303 | # --per_gpu_train_batch_size 8 304 | # --per_gpu_eval_batch_size 4 305 | # --per_gpu_test_batch_size 2 306 | # --overwrite_output_dir True 307 | # --train_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_train.json 308 | # --dev_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_dev.json 309 | # --test_file $$AMLT_DATA_DIR/QA_pairs/topiocqa/topiocqa_test.json 310 | # --corpus_path $$AMLT_DATA_DIR/full_wiki_segments.json 311 | # --cache_dir $$AMLT_DATA_DIR/huggingface_cache/ 312 | # --trie_dict $$AMLT_DATA_DIR/trie_dict_t5-base_section_level.pkl 313 | # --output_dir $$AMLT_OUTPUT_DIR/release_test/ 314 | # --learning_rate 1e-5 315 | # --prepend_answers True 316 | # --model_type t5-large 317 | # --top_k 50 318 | # --beam_size 50 319 | # --test_ckpt_path //amltb6dbd4c6ed2130b077b2c15ea456aea9/projects/GCoQA/amlt-results/7339678094.74957-aa7d4e0a-7d19-437f-9765-f4e43f9e3d70/release_test/checkpoint-34685model.pt 320 | 321 | 322 | 323 | -------------------------------------------------------------------------------- /passage-level/train_query_encoder.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # In[1]: 4 | 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import os 9 | 10 | 11 | import argparse 12 | import logging 13 | import os 14 | import random 15 | import glob 16 | import timeit 17 | import json 18 | 19 | 20 | from tqdm import tqdm, trange 21 | from copy import copy 22 | import re 23 | import torch 24 | import copy 25 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 26 | TensorDataset) 27 | from torch.utils.data.distributed import DistributedSampler 28 | 29 | try: 30 | from torch.utils.tensorboard import SummaryWriter 31 | except: 32 | from tensorboardX import SummaryWriter 33 | import transformers 34 | from transformers import T5Tokenizer 35 | from transformers import AdamW, get_linear_schedule_with_warmup 36 | from retriever_utils import FinetuningDataset 37 | from modeling import Generative_Retrieval 38 | 39 | import pickle 40 | from 
torch.cuda.amp import autocast as autocast 41 | import numpy as np 42 | from datasets import load_dataset 43 | from multiprocessing import Pool 44 | from utils import Trie 45 | 46 | from contextlib import contextmanager 47 | # In[2]: 48 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 49 | 50 | logger = logging.getLogger(__name__) 51 | 52 | 53 | transformers.logging.set_verbosity_error() 54 | 55 | 56 | # In[3]: 57 | 58 | 59 | def set_seed(args): 60 | random.seed(args.seed) 61 | np.random.seed(args.seed) 62 | torch.manual_seed(args.seed) 63 | torch.cuda.manual_seed_all(args.seed) 64 | 65 | 66 | 67 | 68 | def str2bool(v): 69 | if isinstance(v, bool): 70 | return v 71 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 72 | return True 73 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 74 | return False 75 | else: 76 | raise argparse.ArgumentTypeError('Boolean value expected.') 77 | 78 | #######################################################yongqi 79 | def flat(l): 80 | for k in l: 81 | if not isinstance(k, (list, tuple)): 82 | yield k 83 | else: 84 | yield from flat(k) 85 | 86 | def prefix_allowed_tokens_fn(batch_id, sent): 87 | return decoder_trie.get(sent.tolist()) 88 | 89 | def dist_gather_tensor(t): 90 | if t is None: 91 | return None 92 | t = t.contiguous() 93 | 94 | all_tensors = [torch.empty_like(t) for _ in range(torch.distributed.get_world_size())] 95 | torch.distributed.all_gather(all_tensors, t) 96 | 97 | all_tensors[torch.distributed.get_rank()] = t 98 | all_tensors = torch.cat(all_tensors, dim=0) 99 | 100 | return all_tensors 101 | 102 | 103 | 104 | def train(args, model, tokenizer): 105 | DatasetClass = FinetuningDataset 106 | train_dataset = DatasetClass(args.train_file, tokenizer, 107 | args.load_small, 108 | query_max_seq_length=args.query_max_seq_length,target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers) 109 | 110 | """ Train the model """ 111 | if args.local_rank in [-1, 0]: 112 | tb_writer = SummaryWriter(os.path.join(args.output_dir, 'logs')) 113 | 114 | 115 | args.train_batch_size = args.per_gpu_train_batch_size 116 | train_sampler = RandomSampler( 117 | train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 118 | 119 | train_dataloader = DataLoader( 120 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=args.num_workers) 121 | 122 | 123 | if args.max_steps > 0: 124 | t_total = args.max_steps 125 | args.num_train_epochs = args.max_steps // ( 126 | len(train_dataloader) // args.gradient_accumulation_steps) + 1 127 | else: 128 | t_total = len( 129 | train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 130 | 131 | 132 | 133 | if args.fp16: 134 | scaler = torch.cuda.amp.GradScaler(enabled=args.fp16) 135 | 136 | 137 | # Distributed training (should be after apex fp16 initialization) 138 | if args.local_rank != -1: 139 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 140 | output_device=args.local_rank, 141 | find_unused_parameters=True) 142 | 143 | 144 | # Prepare optimizer and schedule (linear warmup and decay) 145 | no_decay = ['bias', 'LayerNorm.weight'] 146 | optimizer_grouped_parameters = [ 147 | {'params': [p for n, p in model.named_parameters() if not any( 148 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 149 | {'params': [p for n, p in model.named_parameters() if any( 150 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 151 | ] 152 | 153 | 154 | optimizer = 
AdamW(optimizer_grouped_parameters, 155 | lr=args.learning_rate, eps=args.adam_epsilon) 156 | 157 | if args.warmup_steps == 0: 158 | args.warmup_steps = int(t_total * args.warmup_portion) 159 | 160 | scheduler = get_linear_schedule_with_warmup( 161 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 162 | 163 | # Train! 164 | logger.info("***** Running training *****") 165 | logger.info(" Num examples = %d", len(train_dataset)) 166 | logger.info(" Num Epochs = %d", args.num_train_epochs) 167 | logger.info(" Instantaneous batch size per GPU = %d", 168 | args.per_gpu_train_batch_size) 169 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 170 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 171 | logger.info(" Gradient Accumulation steps = %d", 172 | args.gradient_accumulation_steps) 173 | logger.info(" Total optimization steps = %d", t_total) 174 | 175 | global_step = 1 176 | tr_loss, logging_loss = 0.0, 0.0 177 | model.zero_grad() 178 | train_iterator = trange(int(args.num_train_epochs), 179 | desc="Epoch", disable=args.local_rank not in [-1, 0]) 180 | # Added here for reproductibility (even between python 2 and 3) 181 | 182 | global_step_list = [] 183 | for epoch in train_iterator: 184 | 185 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", 186 | disable=args.local_rank not in [-1, 0]) 187 | if args.local_rank != -1: 188 | train_sampler.set_epoch(epoch) 189 | 190 | #######################################################yongqi 191 | for step, batch in enumerate(epoch_iterator): 192 | 193 | 194 | 195 | 196 | #######################################################yongqi 197 | model.train() 198 | query_input_ids = batch['query_input_ids'] 199 | query_attention_mask = batch['query_attention_mask'] 200 | 201 | target_input_ids = batch['target_input_ids'] 202 | target_attention_mask = batch['target_attention_mask'] 203 | 204 | target_input_ids[target_attention_mask == 0] = -100 205 | 206 | 207 | inputs = {'args': args, 208 | 'query_input_ids': query_input_ids.to(args.device), 209 | 'query_attention_mask': query_attention_mask.to(args.device), 210 | 'target_input_ids': target_input_ids.to(args.device), 211 | 'target_attention_mask': target_attention_mask.to(args.device), 212 | 'mode': "train"} 213 | if args.fp16: 214 | with torch.cuda.amp.autocast(enabled=args.fp16): 215 | loss = model(**inputs) 216 | else: 217 | loss = model(**inputs) 218 | 219 | 220 | if args.gradient_accumulation_steps > 1: 221 | loss = loss / args.gradient_accumulation_steps 222 | 223 | if args.fp16: 224 | scaler.scale(loss).backward() 225 | else: 226 | loss.backward() 227 | tr_loss += loss.item() 228 | if (step + 1) % args.gradient_accumulation_steps == 0: 229 | if args.fp16: 230 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 231 | scaler.step(optimizer) 232 | # Updates the scale for next iteration. 
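                    # NOTE: clip_grad_norm_ above operates on gradients that are still
                    # scaled by GradScaler; calling scaler.unscale_(optimizer) before
                    # clipping would make max_grad_norm apply to the true gradient norms.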
233 | scaler.update() 234 | 235 | scheduler.step() 236 | model.zero_grad() 237 | global_step += 1 238 | else: 239 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 240 | optimizer.step() 241 | 242 | scheduler.step() # Update learning rate schedule 243 | model.zero_grad() 244 | global_step += 1 245 | 246 | # print('loss', loss.item()) 247 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 248 | # Log metrics 249 | # Only evaluate when single GPU otherwise metrics may not average well 250 | tb_writer.add_scalar( 251 | 'lr', scheduler.get_lr()[0], global_step) 252 | tb_writer.add_scalar( 253 | 'loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 254 | logging_loss = tr_loss 255 | 256 | 257 | if args.save_steps == -1: 258 | global_step_list.append(global_step) 259 | if args.local_rank in [-1, 0]: 260 | 261 | # Take care of distributed/parallel training 262 | model_to_save = model.module if hasattr( 263 | model, 'module') else model 264 | # Save model checkpoint 265 | output_dir = os.path.join( 266 | args.output_dir, 'checkpoint-{}'.format(global_step)) 267 | torch.save(model_to_save.state_dict(), output_dir+"model.pt") 268 | logger.info("Saving model checkpoint to %s", output_dir) 269 | 270 | 271 | return global_step, tr_loss / global_step, global_step_list 272 | 273 | # In[5]: 274 | def evaluate_dev(args, model, tokenizer): 275 | args.eval_batch_size = args.per_gpu_eval_batch_size 276 | # eval dataset load here to avoid load every time 277 | DatasetClass = FinetuningDataset 278 | eva_dataset = DatasetClass(args.dev_file, tokenizer, 279 | args.load_small, 280 | query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers) 281 | eval_sampler = RandomSampler( 282 | eva_dataset) if args.local_rank == -1 else DistributedSampler(eva_dataset) 283 | 284 | eval_dataloader = DataLoader( 285 | eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers) 286 | 287 | 288 | # Distributed training (should be after apex fp16 initialization) 289 | if args.local_rank != -1: 290 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 291 | output_device=args.local_rank, 292 | find_unused_parameters=True) 293 | 294 | # Eval! 
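    # Dev-time metric: the model decodes one identifier per query under the trie
    # constraint, and accuracy is the exact-match rate between the decoded string
    # and the gold title, with counts summed across GPUs via dist_gather_tensor.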
295 | logger.info("***** Running evaluation dev *****") 296 | logger.info(" Num examples = %d", len(eva_dataset)) 297 | logger.info(" Batch size = %d", args.eval_batch_size) 298 | 299 | correct_num = 0.0 300 | total_num = 0.0 301 | 302 | 303 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): 304 | 305 | model.eval() 306 | query_input_ids = batch['query_input_ids'] 307 | query_attention_mask = batch['query_attention_mask'] 308 | 309 | target_input_ids = batch['target_input_ids'] 310 | target_attention_mask = batch['target_attention_mask'] 311 | 312 | query_text = batch['query_text'] 313 | target_text = batch['target_text'] 314 | 315 | with torch.no_grad(): 316 | inputs = {'args': args, 317 | 'query_input_ids': query_input_ids.to(args.device), 318 | 'query_attention_mask': query_attention_mask.to(args.device), 319 | 'prefix_allowed_tokens_fn': prefix_allowed_tokens_fn, 320 | 'mode': "dev"} 321 | outputs = model(**inputs) 322 | 323 | predicted_target_text = [tokenizer.decode(g, skip_special_tokens=True) for g in outputs] 324 | 325 | for i in range(len(query_text)): 326 | total_num = total_num + 1 327 | if target_text[i] == predicted_target_text[i]: 328 | correct_num = correct_num + 1 329 | correct_num_gather = torch.from_numpy(np.array([correct_num])).to(args.device) 330 | total_num_gather = torch.from_numpy(np.array([total_num])).to(args.device) 331 | 332 | if args.local_rank != -1: 333 | correct_num_gather = torch.sum(dist_gather_tensor(correct_num_gather)) 334 | total_num_gather = torch.sum(dist_gather_tensor(total_num_gather)) 335 | correct_num = correct_num_gather.item() 336 | total_num = total_num_gather.item() 337 | acc = correct_num/total_num 338 | logger.info(" correct_num = %d", correct_num) 339 | logger.info(" total_num = %d", total_num) 340 | logger.info(" acc = %s", str(acc)) 341 | return acc 342 | 343 | def evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader): 344 | 345 | 346 | # Eval! 
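    # Test-time metrics: generation returns args.top_k beam candidates per query;
    # correct_num[j] counts queries whose gold title appears at rank j. Recall@k
    # and MRR@k are derived below; for k larger than args.top_k, the slice
    # correct_num[:k] is implicitly capped at the number of returned candidates.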
347 | logger.info("***** Running evaluation dev *****") 348 | logger.info(" Num examples = %d", len(eva_dataset)) 349 | logger.info(" Batch size = %d", args.eval_batch_size) 350 | 351 | correct_num = [0.0]*args.top_k 352 | total_num = 0.0 353 | output_dict = [] 354 | 355 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): 356 | 357 | model.eval() 358 | query_input_ids = batch['query_input_ids'] 359 | query_attention_mask = batch['query_attention_mask'] 360 | 361 | target_input_ids = batch['target_input_ids'] 362 | target_attention_mask = batch['target_attention_mask'] 363 | 364 | query_text = batch['query_text'] 365 | target_text = batch['target_text'] 366 | 367 | 368 | with torch.no_grad(): 369 | inputs = {'args': args, 370 | 'query_input_ids': query_input_ids.to(args.device), 371 | 'query_attention_mask': query_attention_mask.to(args.device), 372 | 'prefix_allowed_tokens_fn': prefix_allowed_tokens_fn, 373 | 'mode': "test"} 374 | outputs = model(**inputs) 375 | 376 | 377 | 378 | predicted_target_text = [tokenizer.decode(g, skip_special_tokens=True) for g in outputs] 379 | 380 | for i in range(len(query_text)): 381 | total_num = total_num + 1 382 | entry={} 383 | entry['question'] = query_text[i] 384 | entry['answers'] = [batch['answer_text'][i]] 385 | entry['ctxs'] = [] 386 | 387 | for j in range(args.top_k): 388 | if target_text[i] == predicted_target_text[i*args.top_k+j]: 389 | correct_num[j] = correct_num[j] + 1 390 | if predicted_target_text[i*args.top_k+j] in title2idx: 391 | idx = title2idx[predicted_target_text[i*args.top_k+j]] 392 | entry['ctxs'].append(passage_corpus[idx]) 393 | else: 394 | entry['ctxs'].append({"id": 0, 'title':predicted_target_text[i*args.top_k+j], 'text': ""}) 395 | output_dict.append(entry) 396 | for k in [1,3,5,10,20,50,100]: 397 | new_correct_num = correct_num[:k] 398 | correct_num_k = sum(new_correct_num) 399 | recall_k = correct_num_k/total_num 400 | 401 | 402 | mrr = 0.0 403 | for j in range(len(new_correct_num)): 404 | mrr += float(new_correct_num[j])/(j+1) 405 | mrr = mrr/total_num 406 | logger.info("correct_num = %s", correct_num_k) 407 | logger.info("total_num = %s", total_num) 408 | logger.info("recall @ " + str(k) + " = %s", str(recall_k)) 409 | logger.info("mrr @ " + str(k) + " = %s", str(mrr)) 410 | return output_dict 411 | 412 | def dist_gather_tensor(t): 413 | if t is None: 414 | return None 415 | t = t.contiguous() 416 | 417 | all_tensors = [torch.empty_like(t) for _ in range(torch.distributed.get_world_size())] 418 | torch.distributed.all_gather(all_tensors, t) 419 | 420 | all_tensors[torch.distributed.get_rank()] = t 421 | all_tensors = torch.cat(all_tensors, dim=0) 422 | 423 | return all_tensors 424 | 425 | parser = argparse.ArgumentParser() 426 | 427 | # data 428 | parser.add_argument("--train_file", default="/home/v-yongqili/project/GCoQA/data/QA_pairs/topiocqa/topiocqa_train.json", 429 | type=str, required=False, 430 | help="training file ") 431 | parser.add_argument("--dev_file", default="/home/v-yongqili/project/GCoQA/data/QA_pairs/topiocqa/topiocqa_dev.json", 432 | type=str, required=False, 433 | help="dev_file ") 434 | parser.add_argument("--test_file", default="/home/v-yongqili/project/GCoQA/data/QA_pairs/topiocqa/topiocqa_test.json", 435 | type=str, required=False, 436 | help="test_file ") 437 | parser.add_argument("--corpus_path", default="/home/v-yongqili/project/GCoQA/data/full_wiki_segments.json", 438 | type=str, required=False, 439 | help="dev_file ") 440 | 
parser.add_argument("--trie_dict", default="/home/v-yongqili/project/GCoQA/data/trie_dict_t5-base_section_level.pkl", 441 | type=str, required=False, 442 | help="dev_file ") 443 | parser.add_argument("--cache_dir", default="/home/v-yongqili/project/GCoQA/data/huggingface_cache/", type=str, 444 | help="Where do you want to store the pre-trained models downloaded from s3") 445 | parser.add_argument("--output_dir", default='./release_test1', type=str, required=False, 446 | help="The output directory where the model checkpoints and predictions will be written.") 447 | 448 | parser.add_argument("--pretrained_ckpt_path", default=None, type=str, required=False, 449 | help="pretrained_passage_encoder_paramaters") 450 | parser.add_argument("--test_ckpt_path", default=None, type=str, required=False, 451 | help="trained_dul_encoder_paramaters") 452 | 453 | parser.add_argument("--load_small", default=False, type=str2bool, required=False, 454 | help="whether to load just a small portion of data during development") 455 | parser.add_argument("--num_workers", default=4, type=int, required=False, 456 | help="number of workers for dataloader") 457 | 458 | # training 459 | parser.add_argument("--do_train", default=True, type=str2bool, 460 | help="Whether to run training.") 461 | parser.add_argument("--do_test", default=True, type=str2bool, 462 | help="Whether to run eval on the test set.") 463 | # parameters 464 | parser.add_argument("--prepend_answers", default=False, type=str2bool, 465 | help="Whether to prepend answers.") 466 | 467 | 468 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 469 | help="Batch size per GPU/CPU for training.") 470 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 471 | help="Batch size per GPU/CPU for evaluation.") 472 | parser.add_argument("--per_gpu_test_batch_size", default=2, type=int, 473 | help="Batch size per GPU/CPU for evaluation.") 474 | 475 | parser.add_argument("--learning_rate", default=1e-4, type=float, 476 | help="The initial learning rate for Adam.") 477 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 478 | help="Number of updates steps to accumulate before performing a backward/update pass.") 479 | parser.add_argument("--weight_decay", default=0.0, type=float, 480 | help="Weight decay if we apply some.") 481 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 482 | help="Epsilon for Adam optimizer.") 483 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 484 | help="Max gradient norm.") 485 | parser.add_argument("--num_train_epochs", default=40, type=float, 486 | help="Total number of training epochs to perform.") 487 | parser.add_argument("--max_steps", default=-1, type=int, 488 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 489 | parser.add_argument("--warmup_steps", default=0, type=int, 490 | help="Linear warmup over warmup_steps.") 491 | parser.add_argument("--warmup_portion", default=0.1, type=float, 492 | help="Linear warmup over warmup_steps (=t_total * warmup_portion). 
override warmup_steps ") 493 | parser.add_argument('--do_lower_case', type=str2bool, default=True, 494 | help="tokenizer do_lower_case") 495 | 496 | parser.add_argument('--top_k', type=int, default=5, 497 | help="the number of retrieved passages") 498 | parser.add_argument('--beam_size', type=int, default=5, 499 | help="the number of retrieved passages") 500 | 501 | parser.add_argument('--query_max_seq_length', type=int, default=384, 502 | help="passage_max_seq_length") 503 | parser.add_argument('--target_max_seq_length', type=int, default=64, 504 | help="passage_max_seq_length") 505 | 506 | parser.add_argument('--logging_steps', type=int, default=10, 507 | help="Log every X updates steps.") 508 | parser.add_argument('--save_steps', type=int, default=-1, 509 | help="Save checkpoint every X updates steps.") 510 | 511 | parser.add_argument("--no_cuda", default=False, type=str2bool, 512 | help="Whether not to use CUDA when available") 513 | parser.add_argument('--overwrite_output_dir', default=True, type=str2bool, 514 | help="Overwrite the content of the output directory") 515 | parser.add_argument('--seed', type=int, default=42, 516 | help="random seed for initialization") 517 | 518 | parser.add_argument("--local_rank", type=int, default=-1, 519 | help="local_rank for distributed training on gpus") 520 | parser.add_argument('--fp16', default=False, type=str2bool, 521 | help="Whether to use 16-bit (mixed) precision") 522 | 523 | 524 | parser.add_argument("--model_type", default="t5-base", 525 | type=str, required=False, 526 | help="the type of pretrining model ") 527 | 528 | args, unknown = parser.parse_known_args() 529 | 530 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: 531 | raise ValueError( 532 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format(args.output_dir)) 533 | 534 | if args.local_rank == -1 or args.no_cuda: 535 | device = torch.device( 536 | "cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu") 537 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 538 | torch.cuda.set_device(args.local_rank) 539 | device = torch.device("cuda", args.local_rank) 540 | torch.distributed.init_process_group(backend='nccl') 541 | 542 | args.device = device 543 | 544 | # Setup logging 545 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 546 | datefmt='%m/%d/%Y %H:%M:%S', 547 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 548 | logger.warning("Process rank: %s, device: %s, distributed training: %s, 16-bits training: %s", 549 | args.local_rank, device, bool(args.local_rank != -1), args.fp16) 550 | 551 | 552 | 553 | # Set seed 554 | set_seed(args) 555 | 556 | # Load pretrained model and tokenizer 557 | if args.local_rank not in [-1, 0]: 558 | # Make sure only the first process in distributed training will download model & vocab 559 | torch.distributed.barrier() 560 | 561 | 562 | tokenizer = T5Tokenizer.from_pretrained(args.model_type, 563 | do_lower_case=args.do_lower_case, 564 | cache_dir=args.cache_dir) 565 | 566 | 567 | with open(args.trie_dict, 'rb') as f: 568 | decoder_trie = Trie.load_from_dict(pickle.load(f)) 569 | logger.info("decoder_trie len %s", decoder_trie.len) 570 | 571 | model = Generative_Retrieval(args) 572 | 573 | if args.pretrained_ckpt_path is not None: 574 | model.load_state_dict(torch.load(args.pretrained_ckpt_path, map_location=torch.device('cpu'))) 575 | logger.info("load checkpoint from %s", args.pretrained_ckpt_path) 576 | 577 | 578 | 579 | 580 | passage_corpus = load_dataset('json', data_files=args.corpus_path, split="train") 581 | logger.info("passage_corpus info %s", passage_corpus) 582 | 583 | 584 | title2idx = {} 585 | with open(args.corpus_path, 'r') as f: 586 | num = 0 587 | for line in tqdm(f): 588 | line = json.loads(line) 589 | 590 | title2idx[line['title']] = num 591 | num+=1 592 | print('len title2idx', len(title2idx)) 593 | 594 | 595 | 596 | if args.local_rank == 0: 597 | # Make sure only the first process in distributed training will download model & vocab 598 | torch.distributed.barrier() 599 | 600 | model.to(args.device) 601 | 602 | 603 | 604 | 605 | 606 | logger.info("Training/evaluation parameters %s", args) 607 | if args.do_train: 608 | global_step, tr_loss, global_step_list = train( 609 | args, model, tokenizer) 610 | logger.info(" global_step = %s, average loss = %s", 611 | global_step, tr_loss) 612 | 613 | if args.local_rank != -1: 614 | torch.distributed.barrier() 615 | 616 | # do eval 617 | if args.local_rank in [-1, 0]: 618 | tb_writer = SummaryWriter(os.path.join(args.output_dir, 'logs')) 619 | 620 | max_acc = -0.1 621 | best_global_step = 0 622 | 623 | for global_step in global_step_list: 624 | 625 | model = Generative_Retrieval(args) 626 | args.test_ckpt_path = os.path.join( 627 | args.output_dir, 'checkpoint-{}'.format(global_step))+"model.pt" 628 | 629 | model.load_state_dict(torch.load(args.test_ckpt_path, map_location=torch.device('cpu'))) 630 | 631 | model.to(args.device) 632 | 633 | 634 | logger.info("test the checkpoint from %s", args.test_ckpt_path) 635 | acc = evaluate_dev(args, model, tokenizer) 636 | if args.local_rank in [-1, 0]: 637 | tb_writer.add_scalar('acc_dev', acc, global_step) 638 | if acc > max_acc: 639 | 
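            # track the checkpoint with the best dev accuracy; it is reloaded below
            # for the final test run when args.do_test is set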
max_acc = acc 640 | best_global_step = global_step 641 | logger.info("max_acc = %s", str(max_acc)) 642 | logger.info("best_global_step = %s", str(best_global_step)) 643 | 644 | 645 | if args.do_test and args.local_rank in [-1, 0]: 646 | model = Generative_Retrieval(args) 647 | if args.do_train: 648 | args.test_ckpt_path = os.path.join( 649 | args.output_dir, 'checkpoint-{}'.format(best_global_step))+"model.pt" 650 | model.load_state_dict(torch.load(args.test_ckpt_path, map_location=torch.device('cpu'))) 651 | model.to(args.device) 652 | else: 653 | model.load_state_dict(torch.load(args.test_ckpt_path, map_location=torch.device('cpu'))) 654 | model.to(args.device) 655 | logger.info("test the checkpoint from %s", args.test_ckpt_path) 656 | 657 | 658 | args.eval_batch_size = args.per_gpu_test_batch_size 659 | # eval dataset load here to avoid load every time 660 | DatasetClass = FinetuningDataset 661 | eva_dataset = DatasetClass(args.test_file, tokenizer, 662 | args.load_small, 663 | query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers) 664 | eval_sampler = SequentialSampler(eva_dataset) 665 | 666 | eval_dataloader = DataLoader( 667 | eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers) 668 | 669 | 670 | # Evaluate on test set 671 | output_dict = evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader) 672 | output_dict_file = os.path.join( 673 | args.output_dir, 'output_dict_file-test.json') 674 | if not os.path.exists(args.output_dir): 675 | os.makedirs(args.output_dir) 676 | with open(output_dict_file, 'w') as json_file: 677 | json.dump(output_dict, json_file) 678 | 679 | 680 | # eva_dataset = DatasetClass(args.dev_file, tokenizer, 681 | # args.load_small, 682 | # query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers) 683 | # eval_sampler = SequentialSampler(eva_dataset) 684 | 685 | # eval_dataloader = DataLoader( 686 | # eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers) 687 | 688 | 689 | # # Evaluate on dev set 690 | # output_dict = evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader) 691 | # output_dict_file = os.path.join( 692 | # args.output_dir, 'output_dict_file-dev.json') 693 | # if not os.path.exists(args.output_dir): 694 | # os.makedirs(args.output_dir) 695 | # with open(output_dict_file, 'w') as json_file: 696 | # json.dump(output_dict, json_file) 697 | 698 | 699 | # eva_dataset = DatasetClass(args.train_file, tokenizer, 700 | # args.load_small, 701 | # query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers) 702 | # eval_sampler = SequentialSampler(eva_dataset) 703 | 704 | # eval_dataloader = DataLoader( 705 | # eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers) 706 | 707 | 708 | # # Evaluate on dev set 709 | # output_dict = evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader) 710 | # output_dict_file = os.path.join( 711 | # args.output_dir, 'output_dict_file-train.json') 712 | # if not os.path.exists(args.output_dir): 713 | # os.makedirs(args.output_dir) 714 | # with open(output_dict_file, 'w') as json_file: 715 | # json.dump(output_dict, json_file) -------------------------------------------------------------------------------- 
/document_level/train_query_encoder.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # In[1]: 4 | 5 | 6 | from __future__ import absolute_import, division, print_function 7 | 8 | import os 9 | 10 | 11 | import argparse 12 | import logging 13 | import os 14 | import random 15 | import glob 16 | import timeit 17 | import json 18 | 19 | 20 | from tqdm import tqdm, trange 21 | from copy import copy 22 | import re 23 | import torch 24 | import copy 25 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 26 | TensorDataset) 27 | from torch.utils.data.distributed import DistributedSampler 28 | 29 | try: 30 | from torch.utils.tensorboard import SummaryWriter 31 | except: 32 | from tensorboardX import SummaryWriter 33 | import transformers 34 | from transformers import T5Tokenizer 35 | from transformers import AdamW, get_linear_schedule_with_warmup 36 | from retriever_utils import FinetuningDataset 37 | from modeling import Generative_Retrieval 38 | 39 | import pickle 40 | from torch.cuda.amp import autocast as autocast 41 | import numpy as np 42 | from datasets import load_dataset 43 | from multiprocessing import Pool 44 | from utils import Trie 45 | 46 | from contextlib import contextmanager 47 | # In[2]: 48 | np_str_obj_array_pattern = re.compile(r'[SaUO]') 49 | 50 | logger = logging.getLogger(__name__) 51 | 52 | 53 | transformers.logging.set_verbosity_error() 54 | 55 | 56 | # In[3]: 57 | 58 | 59 | def set_seed(args): 60 | random.seed(args.seed) 61 | np.random.seed(args.seed) 62 | torch.manual_seed(args.seed) 63 | torch.cuda.manual_seed_all(args.seed) 64 | 65 | 66 | 67 | 68 | def str2bool(v): 69 | if isinstance(v, bool): 70 | return v 71 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 72 | return True 73 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 74 | return False 75 | else: 76 | raise argparse.ArgumentTypeError('Boolean value expected.') 77 | 78 | #######################################################yongqi 79 | def flat(l): 80 | for k in l: 81 | if not isinstance(k, (list, tuple)): 82 | yield k 83 | else: 84 | yield from flat(k) 85 | 86 | def prefix_allowed_tokens_fn(batch_id, sent): 87 | return decoder_trie.get(sent.tolist()) 88 | 89 | def dist_gather_tensor(t): 90 | if t is None: 91 | return None 92 | t = t.contiguous() 93 | 94 | all_tensors = [torch.empty_like(t) for _ in range(torch.distributed.get_world_size())] 95 | torch.distributed.all_gather(all_tensors, t) 96 | 97 | all_tensors[torch.distributed.get_rank()] = t 98 | all_tensors = torch.cat(all_tensors, dim=0) 99 | 100 | return all_tensors 101 | 102 | 103 | 104 | def train(args, model, tokenizer): 105 | DatasetClass = FinetuningDataset 106 | train_dataset = DatasetClass(args.train_file, tokenizer, 107 | args.load_small, 108 | query_max_seq_length=args.query_max_seq_length,target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers) 109 | 110 | """ Train the model """ 111 | if args.local_rank in [-1, 0]: 112 | tb_writer = SummaryWriter(os.path.join(args.output_dir, 'logs')) 113 | 114 | 115 | args.train_batch_size = args.per_gpu_train_batch_size 116 | train_sampler = RandomSampler( 117 | train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 118 | 119 | train_dataloader = DataLoader( 120 | train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, num_workers=args.num_workers) 121 | 122 | 123 | if args.max_steps > 0: 124 | t_total = args.max_steps 125 | args.num_train_epochs 
= args.max_steps // ( 126 | len(train_dataloader) // args.gradient_accumulation_steps) + 1 127 | else: 128 | t_total = len( 129 | train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 130 | 131 | 132 | 133 | if args.fp16: 134 | scaler = torch.cuda.amp.GradScaler(enabled=args.fp16) 135 | 136 | 137 | # Distributed training (should be after apex fp16 initialization) 138 | if args.local_rank != -1: 139 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 140 | output_device=args.local_rank, 141 | find_unused_parameters=True) 142 | 143 | 144 | # Prepare optimizer and schedule (linear warmup and decay) 145 | no_decay = ['bias', 'LayerNorm.weight'] 146 | optimizer_grouped_parameters = [ 147 | {'params': [p for n, p in model.named_parameters() if not any( 148 | nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 149 | {'params': [p for n, p in model.named_parameters() if any( 150 | nd in n for nd in no_decay)], 'weight_decay': 0.0} 151 | ] 152 | 153 | 154 | optimizer = AdamW(optimizer_grouped_parameters, 155 | lr=args.learning_rate, eps=args.adam_epsilon) 156 | 157 | if args.warmup_steps == 0: 158 | args.warmup_steps = int(t_total * args.warmup_portion) 159 | 160 | scheduler = get_linear_schedule_with_warmup( 161 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 162 | 163 | # Train! 164 | logger.info("***** Running training *****") 165 | logger.info(" Num examples = %d", len(train_dataset)) 166 | logger.info(" Num Epochs = %d", args.num_train_epochs) 167 | logger.info(" Instantaneous batch size per GPU = %d", 168 | args.per_gpu_train_batch_size) 169 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 170 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 171 | logger.info(" Gradient Accumulation steps = %d", 172 | args.gradient_accumulation_steps) 173 | logger.info(" Total optimization steps = %d", t_total) 174 | 175 | global_step = 1 176 | tr_loss, logging_loss = 0.0, 0.0 177 | model.zero_grad() 178 | train_iterator = trange(int(args.num_train_epochs), 179 | desc="Epoch", disable=args.local_rank not in [-1, 0]) 180 | # Added here for reproductibility (even between python 2 and 3) 181 | 182 | global_step_list = [] 183 | for epoch in train_iterator: 184 | 185 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", 186 | disable=args.local_rank not in [-1, 0]) 187 | if args.local_rank != -1: 188 | train_sampler.set_epoch(epoch) 189 | 190 | #######################################################yongqi 191 | for step, batch in enumerate(epoch_iterator): 192 | 193 | 194 | 195 | 196 | #######################################################yongqi 197 | model.train() 198 | query_input_ids = batch['query_input_ids'] 199 | query_attention_mask = batch['query_attention_mask'] 200 | 201 | target_input_ids = batch['target_input_ids'] 202 | target_attention_mask = batch['target_attention_mask'] 203 | 204 | target_input_ids[target_attention_mask == 0] = -100 205 | 206 | 207 | inputs = {'args': args, 208 | 'query_input_ids': query_input_ids.to(args.device), 209 | 'query_attention_mask': query_attention_mask.to(args.device), 210 | 'target_input_ids': target_input_ids.to(args.device), 211 | 'target_attention_mask': target_attention_mask.to(args.device), 212 | 'mode': "train"} 213 | if args.fp16: 214 | with torch.cuda.amp.autocast(enabled=args.fp16): 215 | loss = model(**inputs) 216 | else: 217 | 
loss = model(**inputs) 218 | 219 | 220 | if args.gradient_accumulation_steps > 1: 221 | loss = loss / args.gradient_accumulation_steps 222 | 223 | if args.fp16: 224 | scaler.scale(loss).backward() 225 | else: 226 | loss.backward() 227 | tr_loss += loss.item() 228 | if (step + 1) % args.gradient_accumulation_steps == 0: 229 | if args.fp16: 230 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 231 | scaler.step(optimizer) 232 | # Updates the scale for next iteration. 233 | scaler.update() 234 | 235 | scheduler.step() 236 | model.zero_grad() 237 | global_step += 1 238 | else: 239 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 240 | optimizer.step() 241 | 242 | scheduler.step() # Update learning rate schedule 243 | model.zero_grad() 244 | global_step += 1 245 | 246 | # print('loss', loss.item()) 247 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 248 | # Log metrics 249 | # Only evaluate when single GPU otherwise metrics may not average well 250 | tb_writer.add_scalar( 251 | 'lr', scheduler.get_lr()[0], global_step) 252 | tb_writer.add_scalar( 253 | 'loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 254 | logging_loss = tr_loss 255 | 256 | 257 | if args.save_steps == -1: 258 | global_step_list.append(global_step) 259 | if args.local_rank in [-1, 0]: 260 | 261 | # Take care of distributed/parallel training 262 | model_to_save = model.module if hasattr( 263 | model, 'module') else model 264 | # Save model checkpoint 265 | output_dir = os.path.join( 266 | args.output_dir, 'checkpoint-{}'.format(global_step)) 267 | torch.save(model_to_save.state_dict(), output_dir+"model.pt") 268 | logger.info("Saving model checkpoint to %s", output_dir) 269 | 270 | 271 | return global_step, tr_loss / global_step, global_step_list 272 | 273 | # In[5]: 274 | def evaluate_dev(args, model, tokenizer): 275 | args.eval_batch_size = args.per_gpu_eval_batch_size 276 | # eval dataset load here to avoid load every time 277 | DatasetClass = FinetuningDataset 278 | eva_dataset = DatasetClass(args.dev_file, tokenizer, 279 | args.load_small, 280 | query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers) 281 | eval_sampler = RandomSampler( 282 | eva_dataset) if args.local_rank == -1 else DistributedSampler(eva_dataset) 283 | 284 | eval_dataloader = DataLoader( 285 | eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers) 286 | 287 | 288 | # Distributed training (should be after apex fp16 initialization) 289 | if args.local_rank != -1: 290 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 291 | output_device=args.local_rank, 292 | find_unused_parameters=True) 293 | 294 | # Eval! 
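    # Dev-time metric: exact match between the trie-constrained decoded identifier
    # and the gold title, with counts summed across GPUs via dist_gather_tensor.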
295 | logger.info("***** Running evaluation dev *****") 296 | logger.info(" Num examples = %d", len(eva_dataset)) 297 | logger.info(" Batch size = %d", args.eval_batch_size) 298 | 299 | correct_num = 0.0 300 | total_num = 0.0 301 | 302 | 303 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): 304 | 305 | model.eval() 306 | query_input_ids = batch['query_input_ids'] 307 | query_attention_mask = batch['query_attention_mask'] 308 | 309 | target_input_ids = batch['target_input_ids'] 310 | target_attention_mask = batch['target_attention_mask'] 311 | 312 | query_text = batch['query_text'] 313 | target_text = batch['target_text'] 314 | 315 | with torch.no_grad(): 316 | inputs = {'args': args, 317 | 'query_input_ids': query_input_ids.to(args.device), 318 | 'query_attention_mask': query_attention_mask.to(args.device), 319 | 'prefix_allowed_tokens_fn': prefix_allowed_tokens_fn, 320 | 'mode': "dev"} 321 | outputs = model(**inputs) 322 | 323 | predicted_target_text = [tokenizer.decode(g, skip_special_tokens=True) for g in outputs] 324 | 325 | for i in range(len(query_text)): 326 | total_num = total_num + 1 327 | if target_text[i] == predicted_target_text[i]: 328 | correct_num = correct_num + 1 329 | correct_num_gather = torch.from_numpy(np.array([correct_num])).to(args.device) 330 | total_num_gather = torch.from_numpy(np.array([total_num])).to(args.device) 331 | 332 | if args.local_rank != -1: 333 | correct_num_gather = torch.sum(dist_gather_tensor(correct_num_gather)) 334 | total_num_gather = torch.sum(dist_gather_tensor(total_num_gather)) 335 | correct_num = correct_num_gather.item() 336 | total_num = total_num_gather.item() 337 | acc = correct_num/total_num 338 | logger.info(" correct_num = %d", correct_num) 339 | logger.info(" total_num = %d", total_num) 340 | logger.info(" acc = %s", str(acc)) 341 | return acc 342 | 343 | def evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader): 344 | 345 | 346 | # Eval! 
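    # Test-time metrics: args.top_k beam candidates per query; correct_num[j]
    # counts gold titles at rank j, from which recall@k and MRR@k are computed.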
347 | logger.info("***** Running evaluation dev *****") 348 | logger.info(" Num examples = %d", len(eva_dataset)) 349 | logger.info(" Batch size = %d", args.eval_batch_size) 350 | 351 | correct_num = [0.0]*args.top_k 352 | total_num = 0.0 353 | output_dict = [] 354 | 355 | for batch in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]): 356 | 357 | model.eval() 358 | query_input_ids = batch['query_input_ids'] 359 | query_attention_mask = batch['query_attention_mask'] 360 | 361 | target_input_ids = batch['target_input_ids'] 362 | target_attention_mask = batch['target_attention_mask'] 363 | 364 | query_text = batch['query_text'] 365 | target_text = batch['target_text'] 366 | 367 | 368 | with torch.no_grad(): 369 | inputs = {'args': args, 370 | 'query_input_ids': query_input_ids.to(args.device), 371 | 'query_attention_mask': query_attention_mask.to(args.device), 372 | 'prefix_allowed_tokens_fn': prefix_allowed_tokens_fn, 373 | 'mode': "test"} 374 | outputs = model(**inputs) 375 | 376 | 377 | 378 | predicted_target_text = [tokenizer.decode(g, skip_special_tokens=True) for g in outputs] 379 | 380 | for i in range(len(query_text)): 381 | total_num = total_num + 1 382 | entry={} 383 | entry['question'] = query_text[i] 384 | entry['answers'] = [batch['answer_text'][i]] 385 | entry['ctxs'] = [] 386 | 387 | for j in range(args.top_k): 388 | if target_text[i] == predicted_target_text[i*args.top_k+j]: 389 | correct_num[j] = correct_num[j] + 1 390 | if predicted_target_text[i*args.top_k+j] in title2idx: 391 | idx = title2idx[predicted_target_text[i*args.top_k+j]] 392 | entry['ctxs'].append(passage_corpus[idx]) 393 | else: 394 | entry['ctxs'].append({"id": 0, 'title':predicted_target_text[i*args.top_k+j], 'text': ""}) 395 | output_dict.append(entry) 396 | 397 | for k in [1,3,5,10,20,50]: 398 | new_correct_num = correct_num[:k] 399 | correct_num_k = sum(new_correct_num) 400 | recall_k = correct_num_k/total_num 401 | 402 | 403 | mrr = 0.0 404 | for j in range(len(new_correct_num)): 405 | mrr += float(new_correct_num[j])/(j+1) 406 | mrr = mrr/total_num 407 | logger.info("correct_num = %s", correct_num_k) 408 | logger.info("total_num = %s", total_num) 409 | logger.info("recall @ " + str(k) + " = %s", str(recall_k)) 410 | logger.info("mrr @ " + str(k) + " = %s", str(mrr)) 411 | return output_dict 412 | 413 | def dist_gather_tensor(t): 414 | if t is None: 415 | return None 416 | t = t.contiguous() 417 | 418 | all_tensors = [torch.empty_like(t) for _ in range(torch.distributed.get_world_size())] 419 | torch.distributed.all_gather(all_tensors, t) 420 | 421 | all_tensors[torch.distributed.get_rank()] = t 422 | all_tensors = torch.cat(all_tensors, dim=0) 423 | 424 | return all_tensors 425 | 426 | parser = argparse.ArgumentParser() 427 | 428 | # data 429 | parser.add_argument("--train_file", default="/home/v-yongqili/project/GCoQA/data/QA_pairs/topiocqa/topiocqa_train.json", 430 | type=str, required=False, 431 | help="training file ") 432 | parser.add_argument("--dev_file", default="/home/v-yongqili/project/GCoQA/data/QA_pairs/topiocqa/topiocqa_dev.json", 433 | type=str, required=False, 434 | help="dev_file ") 435 | parser.add_argument("--test_file", default="/home/v-yongqili/project/GCoQA/data/QA_pairs/topiocqa/topiocqa_test.json", 436 | type=str, required=False, 437 | help="test_file ") 438 | parser.add_argument("--corpus_path", default="/home/v-yongqili/project/GCoQA/data/full_wiki_document.json", 439 | type=str, required=False, 440 | help="dev_file ") 441 | 
parser.add_argument("--trie_dict", default="/home/v-yongqili/project/GCoQA/data/trie_dict_t5-base_page_level.pkl", 442 | type=str, required=False, 443 | help="dev_file ") 444 | parser.add_argument("--cache_dir", default="/home/v-yongqili/project/GCoQA/data/huggingface_cache/", type=str, 445 | help="Where do you want to store the pre-trained models downloaded from s3") 446 | parser.add_argument("--output_dir", default='./release_test1', type=str, required=False, 447 | help="The output directory where the model checkpoints and predictions will be written.") 448 | 449 | parser.add_argument("--pretrained_ckpt_path", default=None, type=str, required=False, 450 | help="pretrained_passage_encoder_paramaters") 451 | parser.add_argument("--test_ckpt_path", default=None, type=str, required=False, 452 | help="trained_dul_encoder_paramaters") 453 | 454 | parser.add_argument("--load_small", default=False, type=str2bool, required=False, 455 | help="whether to load just a small portion of data during development") 456 | parser.add_argument("--num_workers", default=4, type=int, required=False, 457 | help="number of workers for dataloader") 458 | 459 | # training 460 | parser.add_argument("--do_train", default=True, type=str2bool, 461 | help="Whether to run training.") 462 | parser.add_argument("--do_test", default=True, type=str2bool, 463 | help="Whether to run eval on the test set.") 464 | # parameters 465 | parser.add_argument("--prepend_answers", default=False, type=str2bool, 466 | help="Whether to prepend answers.") 467 | 468 | 469 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 470 | help="Batch size per GPU/CPU for training.") 471 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 472 | help="Batch size per GPU/CPU for evaluation.") 473 | parser.add_argument("--per_gpu_test_batch_size", default=2, type=int, 474 | help="Batch size per GPU/CPU for evaluation.") 475 | 476 | parser.add_argument("--learning_rate", default=1e-4, type=float, 477 | help="The initial learning rate for Adam.") 478 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 479 | help="Number of updates steps to accumulate before performing a backward/update pass.") 480 | parser.add_argument("--weight_decay", default=0.0, type=float, 481 | help="Weight decay if we apply some.") 482 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 483 | help="Epsilon for Adam optimizer.") 484 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 485 | help="Max gradient norm.") 486 | parser.add_argument("--num_train_epochs", default=40, type=float, 487 | help="Total number of training epochs to perform.") 488 | parser.add_argument("--max_steps", default=-1, type=int, 489 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.") 490 | parser.add_argument("--warmup_steps", default=0, type=int, 491 | help="Linear warmup over warmup_steps.") 492 | parser.add_argument("--warmup_portion", default=0.1, type=float, 493 | help="Linear warmup over warmup_steps (=t_total * warmup_portion). 
494 | parser.add_argument('--do_lower_case', type=str2bool, default=True,
495 |                     help="tokenizer do_lower_case")
496 | 
497 | parser.add_argument('--top_k', type=int, default=5,
498 |                     help="the number of retrieved passages")
499 | parser.add_argument('--beam_size', type=int, default=5,
500 |                     help="the beam size used for constrained generation")
501 | 
502 | parser.add_argument('--query_max_seq_length', type=int, default=384,
503 |                     help="maximum query sequence length")
504 | parser.add_argument('--target_max_seq_length', type=int, default=64,
505 |                     help="maximum target (title) sequence length")
506 | 
507 | parser.add_argument('--logging_steps', type=int, default=10,
508 |                     help="Log every X update steps.")
509 | parser.add_argument('--save_steps', type=int, default=-1,
510 |                     help="Save checkpoint every X update steps.")
511 | 
512 | parser.add_argument("--no_cuda", default=False, type=str2bool,
513 |                     help="Do not use CUDA even when it is available")
514 | parser.add_argument('--overwrite_output_dir', default=True, type=str2bool,
515 |                     help="Overwrite the content of the output directory")
516 | parser.add_argument('--seed', type=int, default=42,
517 |                     help="random seed for initialization")
518 | 
519 | parser.add_argument("--local_rank", type=int, default=-1,
520 |                     help="local_rank for distributed training on gpus")
521 | parser.add_argument('--fp16', default=False, type=str2bool,
522 |                     help="Whether to use 16-bit (mixed) precision")
523 | 
524 | 
525 | parser.add_argument("--model_type", default="t5-base",
526 |                     type=str, required=False,
527 |                     help="the type of pretrained model")
528 | 
529 | args, unknown = parser.parse_known_args()
530 | 
531 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
532 |     raise ValueError(
533 |         "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
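# Illustrative sketch of how a title trie is typically plugged into HuggingFace's
# generate() via prefix_allowed_tokens_fn (the actual hookup lives in modeling.py;
# this helper is for exposition only and assumes the Trie from utils.py exposes a
# GENRE-style get(prefix) returning the allowed next token ids):
def make_prefix_allowed_tokens_fn(trie):
    def allowed_tokens(batch_id, input_ids):
        # only tokens that extend the current prefix toward a valid title survive
        return trie.get(input_ids.tolist())
    return allowed_tokens
# e.g. model.generate(..., num_beams=args.beam_size,
#                     prefix_allowed_tokens_fn=make_prefix_allowed_tokens_fn(decoder_trie))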
534 | 
535 | if args.local_rank == -1 or args.no_cuda:
536 |     device = torch.device(
537 |         "cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu")
538 | else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
539 |     torch.cuda.set_device(args.local_rank)
540 |     device = torch.device("cuda", args.local_rank)
541 |     torch.distributed.init_process_group(backend='nccl')
542 | 
543 | args.device = device
544 | 
545 | # Setup logging
546 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
547 |                     datefmt='%m/%d/%Y %H:%M:%S',
548 |                     level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
549 | logger.warning("Process rank: %s, device: %s, distributed training: %s, 16-bits training: %s",
550 |                args.local_rank, device, bool(args.local_rank != -1), args.fp16)
551 | 
552 | 
553 | 
554 | # Set seed
555 | set_seed(args)
556 | 
557 | # Load pretrained model and tokenizer
558 | if args.local_rank not in [-1, 0]:
559 |     # Make sure only the first process in distributed training will download model & vocab
560 |     torch.distributed.barrier()
561 | 
562 | 
563 | tokenizer = T5Tokenizer.from_pretrained(args.model_type,
564 |                                         do_lower_case=args.do_lower_case,
565 |                                         cache_dir=args.cache_dir)
566 | 
567 | 
568 | with open(args.trie_dict, 'rb') as f:
569 |     decoder_trie = Trie.load_from_dict(pickle.load(f))
570 | logger.info("decoder_trie len %s", decoder_trie.len)
571 | 
572 | model = Generative_Retrieval(args)
573 | 
574 | if args.pretrained_ckpt_path is not None:
575 |     model.load_state_dict(torch.load(args.pretrained_ckpt_path, map_location=torch.device('cpu')))
576 |     logger.info("load checkpoint from %s", args.pretrained_ckpt_path)
577 | 
578 | 
579 | 
580 | 
581 | 
582 | 
583 | passage_corpus = load_dataset('json', data_files=args.corpus_path, split="train")
584 | logger.info("passage_corpus info %s", passage_corpus)
585 | 
586 | 
587 | title2idx = {}
588 | with open(args.corpus_path, 'r') as f:
589 |     num = 0
590 |     for line in tqdm(f):
591 |         line = json.loads(line)
592 | 
593 |         title2idx[line['title']] = num
594 |         num += 1
595 | print('len title2idx', len(title2idx))
596 | 
597 | 
598 | if args.local_rank == 0:
599 |     # The first process has finished downloading; release the other processes from the barrier
600 |     torch.distributed.barrier()
601 | 
602 | model.to(args.device)
603 | 
604 | 
605 | 
606 | 
607 | 
608 | logger.info("Training/evaluation parameters %s", args)
609 | if args.do_train:
610 |     global_step, tr_loss, global_step_list = train(
611 |         args, model, tokenizer)
612 |     logger.info(" global_step = %s, average loss = %s",
613 |                 global_step, tr_loss)
614 | 
615 |     if args.local_rank != -1:
616 |         torch.distributed.barrier()
617 | 
618 |     # do eval
619 |     if args.local_rank in [-1, 0]:
620 |         tb_writer = SummaryWriter(os.path.join(args.output_dir, 'logs'))
621 | 
622 |     max_acc = -0.1
623 |     best_global_step = 0
624 | 
625 |     for global_step in global_step_list:
626 | 
627 |         model = Generative_Retrieval(args)
628 |         args.test_ckpt_path = os.path.join(
629 |             args.output_dir, 'checkpoint-{}'.format(global_step)) + "model.pt"
630 | 
631 |         model.load_state_dict(torch.load(args.test_ckpt_path, map_location=torch.device('cpu')))
632 | 
633 |         model.to(args.device)
634 | 
635 | 
636 |         logger.info("test the checkpoint from %s", args.test_ckpt_path)
637 |         acc = evaluate_dev(args, model, tokenizer)
638 |         if args.local_rank in [-1, 0]:
639 |             tb_writer.add_scalar('acc_dev', acc, global_step)
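        # keep the step with the highest dev accuracy; its checkpoint is reloaded for the final test below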
640 |         if acc > max_acc:
641 |             max_acc = acc
642 |             best_global_step = global_step
643 |     logger.info("max_acc = %s", str(max_acc))
644 |     logger.info("best_global_step = %s", str(best_global_step))
645 | 
646 | 
647 | if args.do_test and args.local_rank in [-1, 0]:
648 |     model = Generative_Retrieval(args)
649 |     if args.do_train:
650 |         args.test_ckpt_path = os.path.join(
651 |             args.output_dir, 'checkpoint-{}'.format(best_global_step)) + "model.pt"
652 |         model.load_state_dict(torch.load(args.test_ckpt_path, map_location=torch.device('cpu')))
653 |         model.to(args.device)
654 |     else:
655 |         model.load_state_dict(torch.load(args.test_ckpt_path, map_location=torch.device('cpu')))
656 |         model.to(args.device)
657 |     logger.info("test the checkpoint from %s", args.test_ckpt_path)
658 | 
659 | 
660 |     args.eval_batch_size = args.per_gpu_test_batch_size
661 |     # build the test dataset and dataloader once here, instead of on every evaluation call
662 |     DatasetClass = FinetuningDataset
663 |     eva_dataset = DatasetClass(args.test_file, tokenizer,
664 |                                args.load_small,
665 |                                query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers)
666 |     eval_sampler = SequentialSampler(eva_dataset)
667 | 
668 |     eval_dataloader = DataLoader(
669 |         eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers)
670 | 
671 | 
672 |     # Evaluate on test set
673 |     output_dict = evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader)
674 |     output_dict_file = os.path.join(
675 |         args.output_dir, 'output_dict_file-test.json')
676 |     if not os.path.exists(args.output_dir):
677 |         os.makedirs(args.output_dir)
678 |     with open(output_dict_file, 'w') as json_file:
679 |         json.dump(output_dict, json_file)
680 | 
681 | 
682 |     # eva_dataset = DatasetClass(args.dev_file, tokenizer,
683 |     #                            args.load_small,
684 |     #                            query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers)
685 |     # eval_sampler = SequentialSampler(eva_dataset)
686 | 
687 |     # eval_dataloader = DataLoader(
688 |     #     eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers)
689 | 
690 | 
691 |     # # Evaluate on dev set
692 |     # output_dict = evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader)
693 |     # output_dict_file = os.path.join(
694 |     #     args.output_dir, 'output_dict_file-dev.json')
695 |     # if not os.path.exists(args.output_dir):
696 |     #     os.makedirs(args.output_dir)
697 |     # with open(output_dict_file, 'w') as json_file:
698 |     #     json.dump(output_dict, json_file)
699 | 
700 | 
701 |     # eva_dataset = DatasetClass(args.train_file, tokenizer,
702 |     #                            args.load_small,
703 |     #                            query_max_seq_length=args.query_max_seq_length, target_max_seq_length=args.target_max_seq_length, prepend_answers=args.prepend_answers)
704 |     # eval_sampler = SequentialSampler(eva_dataset)
705 | 
706 |     # eval_dataloader = DataLoader(
707 |     #     eva_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, num_workers=args.num_workers)
708 | 
709 | 
710 |     # # Evaluate on train set
711 |     # output_dict = evaluate_test(args, model, tokenizer, eva_dataset, eval_dataloader)
712 |     # output_dict_file = os.path.join(
713 |     #     args.output_dir, 'output_dict_file-train.json')
714 |     # if not os.path.exists(args.output_dir):
715 |     #     os.makedirs(args.output_dir)
716 |     # with open(output_dict_file, 'w') as json_file:
717 |     #     json.dump(output_dict, json_file)
--------------------------------------------------------------------------------