├── path
├── run_eval.sh
├── scripts
│   ├── gen_wow_json.sh
│   ├── gen_dialdoc_json.sh
│   ├── gen_data_cls.sh
│   ├── train.sh
│   └── eval.sh
├── config.py
├── models
│   ├── __init__.py
│   ├── model_builder.py
│   ├── loss.py
│   ├── perturbation.py
│   ├── hf_models.py
│   └── reader.py
├── data_utils
│   ├── __init__.py
│   ├── reader_dataset.py
│   ├── data_class.py
│   ├── utils.py
│   ├── doc2dial_reader.py
│   └── data_collator.py
├── run.sh
├── setup.sh
├── utils
│   ├── utils.py
│   ├── checkpoint.py
│   ├── model_utils.py
│   ├── dist_utils.py
│   ├── sampler.py
│   └── options.py
├── download_hf_model.py
├── eval.py
├── gen_data.py
├── README.md
├── environment.yml
├── prepro
│   ├── create_dialdoc_json.py
│   └── create_wow_json.py
└── train_reader.py

/path:
--------------------------------------------------------------------------------
 1 | parent_dir=./dialdoc
 2 | 
--------------------------------------------------------------------------------
/run_eval.sh:
--------------------------------------------------------------------------------
 1 | 
 2 | dataname=$1
 3 | ckpt=$2
 4 | output_path=$3
 5 | 
 6 | bash scripts/eval.sh $dataname $ckpt $output_path
 7 | 
--------------------------------------------------------------------------------
/scripts/gen_wow_json.sh:
--------------------------------------------------------------------------------
 1 | . path
 2 | 
 3 | python ./prepro/create_wow_json.py \
 4 |     --cache_dir "${parent_dir}/data/"
 5 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
 1 | AGENT=''
 2 | USER=''
 3 | PARENT_TITLE=''
 4 | TEXT=''
 5 | 
 6 | TOKENS=[AGENT, USER, PARENT_TITLE, TEXT]
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
 1 | from .reader import Reader, compute_loss
 2 | from .hf_models import HFBertEncoder
 3 | from .perturbation import SmartPerturbation
 4 | from .loss import LOSS
 5 | 
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from .doc2dial_reader import Doc2DialReader
 2 | from .data_class import ReaderSample, ReaderPassage, SpanPrediction
 3 | from .data_collator import DataCollator
 4 | from .reader_dataset import ReaderDataset
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
 1 | set -e
 2 | 
 3 | data_name=$1
 4 | 
 5 | python download_hf_model.py --data_name=$data_name
 6 | 
 7 | bash scripts/gen_${data_name}_json.sh
 8 | bash scripts/gen_data_cls.sh ${data_name}
 9 | bash scripts/train.sh ${data_name}
10 | 
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
 1 | #! /usr/bin/bash -xe
 2 | 
 3 | . path
 4 | 
 5 | cache=$parent_dir/cache
 6 | exp=$parent_dir/exp
 7 | data=$parent_dir/data
 8 | pretrained_models=$parent_dir/pretrained_models
 9 | 
10 | mkdir -p $cache
11 | mkdir -p $exp
12 | mkdir -p $data
13 | mkdir -p $pretrained_models
14 | 
--------------------------------------------------------------------------------
/scripts/gen_dialdoc_json.sh:
--------------------------------------------------------------------------------
 1 | . path
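# Converts the raw doc2dial dialogue files into the JSON format consumed by
# gen_data.py; the "{}" placeholder in --filepath/--outfile is presumably
# filled in by create_dialdoc_json.py based on the --dtype argument.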
 2 | 
 3 | python ./prepro/create_dialdoc_json.py \
 4 |     --dtype train \
 5 |     --filepath "${parent_dir}/raw_data/doc2dial_dial_{}.json" \
 6 |     --outfile "${parent_dir}/data/{}.json"
 7 | 
 8 | python ./prepro/create_dialdoc_json.py \
 9 |     --dtype validation \
10 |     --filepath "${parent_dir}/raw_data/doc2dial_dial_{}.json" \
11 |     --outfile "${parent_dir}/data/{}.json"
12 | 
--------------------------------------------------------------------------------
/models/model_builder.py:
--------------------------------------------------------------------------------
 1 | from transformers import AutoTokenizer
 2 | 
 3 | from utils.model_utils import get_optimizer
 4 | 
 5 | from .hf_models import HFBertEncoder
 6 | from .reader import Reader
 7 | 
 8 | 
 9 | def get_bert_reader_components(args, inference_only: bool = False, **kwargs):
10 | 
11 |     tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_cfg)
12 |     encoder = HFBertEncoder.init_encoder(args, len(tokenizer))
13 | 
14 |     reader = Reader(args, encoder)
15 | 
16 |     optimizer = (
17 |         get_optimizer(
18 |             reader,
19 |             learning_rate=args.learning_rate,
20 |             adam_eps=args.adam_eps,
21 |             weight_decay=args.weight_decay,
22 |         )
23 |         if not inference_only
24 |         else None
25 |     )
26 | 
27 |     return tokenizer, reader, optimizer
28 | 
--------------------------------------------------------------------------------
/scripts/gen_data_cls.sh:
--------------------------------------------------------------------------------
 1 | . path
 2 | 
 3 | hf_model_name=bert-base-uncased
 4 | data_name=$1
 5 | 
 6 | 
 7 | IFS=- read hf_model_type the_rest <<< "$hf_model_name"
 8 | 
 9 | pretrained_model_dir=${parent_dir}/pretrained_models/${hf_model_name}
10 | base_dir=${parent_dir}
11 | if [ $data_name == "dialdoc" ]; then
12 |     max_seq_len=512
13 | else
14 |     max_seq_len=384
15 | fi
16 | 
17 | cache_dir=${base_dir}/cache/cls_${hf_model_type}
18 | [ ! -d $pretrained_model_dir ] && echo "$pretrained_model_dir does not exist!" && exit 1
19 | mkdir -p $cache_dir
20 | 
21 | python gen_data.py \
22 |     --pretrained_model_dir $pretrained_model_dir \
23 |     --input_dir ${base_dir}/data \
24 |     --data_name $data_name \
25 |     --output_dir $cache_dir \
26 |     --max_seq_len $max_seq_len \
27 |     --max_history_len 128 \
28 |     --max_num_spans_per_passage 50 \
29 |     --use_cls_span_start
30 | 
--------------------------------------------------------------------------------
/data_utils/reader_dataset.py:
--------------------------------------------------------------------------------
 1 | import glob
 2 | import logging
 3 | import os
 4 | import pickle
 5 | 
 6 | import torch
 7 | 
 8 | from utils import dist_utils
 9 | 
10 | 
11 | logger = logging.getLogger()
12 | 
13 | class ReaderDataset(torch.utils.data.Dataset):
14 |     def __init__(self, data_dir):
15 |         paths = glob.glob(os.path.join(data_dir, '*'))
16 | 
17 |         if dist_utils.is_local_master():
18 |             logger.info(f"Data dir: {data_dir}")
19 |             logger.info(f"Data paths: {paths}")
20 | 
21 |         assert paths, "No Data files found."
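        # The rest of __init__ below loads every pickled shard produced by
        # gen_data.py into memory and converts each sample's numpy arrays to
        # torch tensors via ReaderSample.to_tensor(), so a full data split is
        # held in RAM.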
22 | data = [] 23 | for path in paths: 24 | with open(path, "rb") as f: 25 | data.extend(pickle.load(f)) 26 | 27 | if dist_utils.is_local_master(): 28 | logger.info(f"Total data size: {len(data)}") 29 | 30 | for d in data: 31 | d.to_tensor() 32 | self.data = data 33 | 34 | def __getitem__(self, idx): 35 | return self.data[idx] 36 | 37 | def __len__(self): 38 | return len(self.data) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from .dist_utils import is_local_master 5 | 6 | logger = logging.getLogger() 7 | 8 | def print_section_bar(string, width=80): 9 | if is_local_master(): 10 | logger.info('='*width) 11 | half_width = (width - len(string)) // 2 12 | logger.info(' '*half_width + string) 13 | logger.info('='*width) 14 | 15 | 16 | def convert_to_at_k(lst): 17 | # we evaluate @k so if the k-1 span matches, then k, k+1, k+2 18 | # will also be regarded as match. 19 | at_k = [-1] * len(lst) 20 | for i in range(len(lst)): 21 | if lst[i] == -1: 22 | break 23 | # sum until itself 24 | at_k[i] = min(sum(lst[:i+1]), 1) 25 | return at_k 26 | 27 | 28 | def softlink(target, link_name): 29 | temp_link = link_name + '.new' 30 | try: 31 | os.remove(temp_link) 32 | except OSError: 33 | pass 34 | os.symlink(target, temp_link) 35 | os.rename(temp_link, link_name) 36 | 37 | 38 | def print_args(args): 39 | print_section_bar('*'*16 + 'CONFIGURATION' + '*'*16) 40 | for key, val in sorted(vars(args).items()): 41 | logger.info(f'{key:<30} --> {val}') 42 | -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import logging 5 | import collections 6 | 7 | from .dist_utils import is_local_master 8 | 9 | logger = logging.getLogger() 10 | 11 | CheckpointState = collections.namedtuple( 12 | "CheckpointState", 13 | [ 14 | "model_dict", 15 | "optimizer_dict", 16 | "scheduler_dict", 17 | "amp_dict", 18 | "offset", 19 | "epoch", 20 | "global_step", 21 | "encoder_params", 22 | ], 23 | ) 24 | 25 | 26 | def get_saved_checkpoints(args, file_prefix): 27 | cp_paths = [] 28 | if args.output_dir: 29 | cp_paths = glob.glob( 30 | os.path.join(args.output_dir, file_prefix + "*")) 31 | 32 | if len(cp_paths) > 0: 33 | cp_paths = sorted(cp_paths, key=os.path.getctime) 34 | return cp_paths 35 | 36 | 37 | def load_states_from_checkpoint(model_file): 38 | if is_local_master(): 39 | logger.info(f"Reading saved model from {model_file}") 40 | state_dict = torch.load(model_file, map_location="cpu") 41 | if is_local_master(): 42 | logger.info(f"model_state_dict keys {state_dict.keys()}") 43 | return CheckpointState(**state_dict) 44 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | . 
path 2 | 3 | export WORLD_SIZE=2 4 | 5 | output_dir=${parent_dir}/exp 6 | mkdir -p $output_dir 7 | 8 | data_name=$1 9 | if [ $data_name == "dialdoc" ]; then 10 | max_seq_len=512 11 | passages_per_question=8 12 | max_answer_length=5 13 | hist_loss_weight=1.0 14 | else 15 | max_seq_len=384 16 | passages_per_question=10 17 | max_answer_length=1 # one sentence only 18 | hist_loss_weight=0.5 19 | fi 20 | 21 | python -m torch.distributed.launch --nproc_per_node $WORLD_SIZE train_reader.py \ 22 | --pretrained_model_cfg ${parent_dir}/pretrained_models/bert-base-uncased \ 23 | --seed 42 \ 24 | --learning_rate 3e-5 \ 25 | --eval_step 1000 \ 26 | --do_lower_case \ 27 | --eval_top_docs 20 \ 28 | --warmup_steps 1000 \ 29 | --max_seq_len ${max_seq_len} \ 30 | --batch_size 2 \ 31 | --passages_per_question ${passages_per_question} \ 32 | --num_train_epochs 20 \ 33 | --dev_batch_size 4 \ 34 | --max_answer_length ${max_answer_length} \ 35 | --passages_per_question_predict 20 \ 36 | --train_file ${parent_dir}/cache/cls_bert/train \ 37 | --dev_file ${parent_dir}/cache/cls_bert/dev \ 38 | --output_dir $output_dir \ 39 | --gradient_accumulation_steps 1 \ 40 | --ignore_token_type \ 41 | --decision_function 1 \ 42 | --hist_loss_weight ${hist_loss_weight} \ 43 | --fp16 \ 44 | --fp16_opt_level O2 \ 45 | --data_name ${data_name} \ 46 | --adv_loss_type js \ 47 | --adv_loss_weight 5 \ 48 | --use_z_attn 49 | 50 | -------------------------------------------------------------------------------- /download_hf_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from dataclasses import dataclass, field 4 | from transformers import AutoModel, AutoTokenizer, HfArgumentParser 5 | 6 | 7 | @dataclass 8 | class Arguments: 9 | """ 10 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 11 | """ 12 | 13 | model_name: str = field( 14 | default="bert-base-uncased", 15 | metadata={"help": ""} 16 | ) 17 | data_name: str = field( 18 | default="dialdoc", 19 | metadata={"help": ""} 20 | ) 21 | 22 | parser = HfArgumentParser((Arguments)) 23 | args = parser.parse_args_into_dataclasses()[0] 24 | args.output_parent_dir = f"./{args.data_name}/pretrained_models" 25 | output_dir = os.path.join(args.output_parent_dir, args.model_name) 26 | if os.path.exists(output_dir): 27 | print(f'{output_dir} exists!') 28 | exit() 29 | 30 | print('='*100) 31 | print(' '*10, f'Download {args.model_name} model and tokenizer') 32 | print('='*100) 33 | m = AutoModel.from_pretrained(args.model_name) 34 | t = AutoTokenizer.from_pretrained(args.model_name) 35 | 36 | print('='*100) 37 | print(' '*10, f'Save {args.model_name} model and tokenizer') 38 | print('='*100) 39 | m.save_pretrained(output_dir) 40 | t.save_pretrained(output_dir) 41 | 42 | config_file = os.path.join(output_dir, "config.json") 43 | with open(config_file) as fin: 44 | cfg = json.loads(fin.read()) 45 | cfg["return_dict"] = False 46 | with open(config_file, "w") as fout: 47 | fout.write(json.dumps(cfg, indent=2)) 48 | 49 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | . 
path 2 | 3 | export WORLD_SIZE=1 4 | 5 | output_dir=${parent_dir}/exp 6 | mkdir -p $output_dir 7 | 8 | data_name=$1 9 | checkpoint_path=$2 10 | prediction_results_path=$3 11 | 12 | if [ $data_name == "dialdoc" ]; then 13 | max_seq_len=512 14 | passages_per_question=8 15 | max_answer_length=5 16 | hist_loss_weight=1.0 17 | else 18 | max_seq_len=384 19 | passages_per_question=10 20 | max_answer_length=1 # one sentence only 21 | hist_loss_weight=0.5 22 | fi 23 | 24 | python -m torch.distributed.launch --nproc_per_node $WORLD_SIZE train_reader.py \ 25 | --pretrained_model_cfg ${parent_dir}/pretrained_models/bert-base-uncased \ 26 | --checkpoint_file $checkpoint_path \ 27 | --prediction_results_file $prediction_results_path \ 28 | --seed 42 \ 29 | --learning_rate 3e-5 \ 30 | --eval_step 1000 \ 31 | --do_lower_case \ 32 | --eval_top_docs 20 \ 33 | --warmup_steps 1000 \ 34 | --max_seq_len ${max_seq_len} \ 35 | --batch_size 2 \ 36 | --passages_per_question ${passages_per_question} \ 37 | --num_train_epochs 20 \ 38 | --dev_batch_size 4 \ 39 | --max_answer_length ${max_answer_length} \ 40 | --passages_per_question_predict 20 \ 41 | --dev_file ${parent_dir}/cache/cls_bert/dev \ 42 | --output_dir $output_dir \ 43 | --gradient_accumulation_steps 1 \ 44 | --ignore_token_type \ 45 | --decision_function 1 \ 46 | --hist_loss_weight ${hist_loss_weight} \ 47 | --fp16 \ 48 | --fp16_opt_level O2 \ 49 | --data_name ${data_name} \ 50 | --adv_loss_type js \ 51 | --adv_loss_weight 5 \ 52 | --use_z_attn 53 | 54 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from collections import defaultdict, Counter 5 | import numpy as np 6 | import re 7 | import string 8 | import argparse 9 | 10 | 11 | def normalize_answer(s): 12 | """Lower text and remove punctuation, articles and extra whitespace.""" 13 | 14 | def remove_articles(text): 15 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 16 | return re.sub(regex, " ", text) 17 | 18 | def white_space_fix(text): 19 | return " ".join(text.split()) 20 | 21 | def remove_punc(text): 22 | exclude = set(string.punctuation) 23 | return "".join(ch for ch in text if ch not in exclude) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 29 | 30 | 31 | def get_tokens(s): 32 | if not s: 33 | return [] 34 | return normalize_answer(s).split() 35 | 36 | 37 | def compute_exact(a_gold, a_pred): 38 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 39 | 40 | 41 | def compute_f1(a_gold, a_pred): 42 | gold_toks = get_tokens(a_gold) 43 | pred_toks = get_tokens(a_pred) 44 | common = Counter(gold_toks) & Counter(pred_toks) 45 | num_same = sum(common.values()) 46 | if len(gold_toks) == 0 or len(pred_toks) == 0: 47 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 48 | return int(gold_toks == pred_toks) 49 | if num_same == 0: 50 | return 0 51 | precision = 1.0 * num_same / len(pred_toks) 52 | recall = 1.0 * num_same / len(gold_toks) 53 | f1 = (2 * precision * recall) / (precision + recall) 54 | return f1 -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def stable_kl(logit, target, epsilon=1e-6, reduce=True): 6 | logit = 
logit.view(-1, logit.size(-1)).float() 7 | target = target.view(-1, target.size(-1)).float() 8 | bs = logit.size(0) 9 | p = F.log_softmax(logit, 1).exp() 10 | y = F.log_softmax(target, 1).exp() 11 | rp = -(1.0/(p + epsilon) -1 + epsilon).detach().log() 12 | ry = -(1.0/(y + epsilon) -1 + epsilon).detach().log() 13 | if reduce: 14 | return (p* (rp - ry) * 2).sum() / bs 15 | else: 16 | return (p* (rp - ry) * 2).sum() 17 | 18 | 19 | def sym_kl(input, target, reduction='batchmean'): 20 | input = input.float() 21 | target = target.float() 22 | left = F.kl_div( 23 | F.log_softmax(input, dim=-1, dtype=torch.float32), 24 | F.softmax(target.detach(), dim=-1, dtype=torch.float32), 25 | reduction=reduction, 26 | ) 27 | 28 | right = F.kl_div( 29 | F.log_softmax(target, dim=-1, dtype=torch.float32), 30 | F.softmax(input.detach(), dim=-1, dtype=torch.float32), 31 | reduction=reduction, 32 | ) 33 | loss = left + right 34 | return loss 35 | 36 | 37 | def ns_sym_kl(input, target, reduction='batchmean'): 38 | input = input.float() 39 | target = target.float() 40 | loss = stable_kl(input, target.detach()) + \ 41 | stable_kl(target, input.detach()) 42 | return loss 43 | 44 | 45 | def js(input, target, reduction='batchmean'): 46 | input = input.float() 47 | target = target.float() 48 | m = F.softmax(target.detach(), dim=-1, dtype=torch.float32) + \ 49 | F.softmax(input.detach(), dim=-1, dtype=torch.float32) 50 | m = 0.5 * m 51 | loss = F.kl_div(F.log_softmax(input, dim=-1, dtype=torch.float32), m, reduction=reduction) + \ 52 | F.kl_div(F.log_softmax(target, dim=-1, dtype=torch.float32), m, reduction=reduction) 53 | return loss 54 | 55 | 56 | def hl(input, target, reduction='batchmean'): 57 | input = input.float() 58 | target = target.float() 59 | si = F.softmax(target.detach(), dim=-1, dtype=torch.float32).sqrt_() 60 | st = F.softmax(input.detach(), dim=-1, dtype=torch.float32).sqrt_() 61 | loss = F.mse_loss(si, st) 62 | return loss 63 | 64 | 65 | LOSS = { 66 | 'sym_kl': sym_kl, 67 | 'ns_sym_kl': ns_sym_kl, 68 | 'js': js, 69 | 'hl': hl, 70 | } -------------------------------------------------------------------------------- /gen_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from data_utils import Doc2DialReader 3 | from transformers import AutoTokenizer 4 | from config import TOKENS 5 | 6 | def main(args): 7 | 8 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_dir) 9 | tokenizer.add_special_tokens( 10 | { 11 | 'additional_special_tokens': TOKENS, 12 | } 13 | ) 14 | 15 | reader = Doc2DialReader( 16 | args, 17 | args.input_dir, 18 | args.output_dir, 19 | tokenizer, 20 | args.max_seq_len, 21 | args.max_history_len, 22 | args.max_num_spans_per_passage, 23 | args.num_sample_per_file, 24 | ) 25 | 26 | reader.convert_json_to_finetune_pkl('train') 27 | reader.convert_json_to_finetune_pkl('dev') 28 | 29 | if args.data_name == 'wow': 30 | reader.convert_json_to_finetune_pkl('test') 31 | reader.convert_json_to_finetune_pkl('dev_unseen') 32 | reader.convert_json_to_finetune_pkl('test_unseen') 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument( 38 | '--data_name', 39 | required=True, 40 | type=str, 41 | choices=['dialdoc', 'wow'], 42 | help='', 43 | ) 44 | parser.add_argument( 45 | '--pretrained_model_dir', 46 | required=True, 47 | type=str, 48 | help='', 49 | ) 50 | parser.add_argument( 51 | '--input_dir', 52 | required=True, 53 | type=str, 54 | help='', 55 | ) 56 | 
parser.add_argument( 57 | '--output_dir', 58 | required=True, 59 | type=str, 60 | help='', 61 | ) 62 | parser.add_argument( 63 | '--max_seq_len', 64 | required=True, 65 | type=int, 66 | help='', 67 | ) 68 | parser.add_argument( 69 | '--max_history_len', 70 | required=True, 71 | type=int, 72 | help='', 73 | ) 74 | parser.add_argument( 75 | '--max_num_spans_per_passage', 76 | required=True, 77 | type=int, 78 | help='', 79 | ) 80 | parser.add_argument( 81 | '--num_sample_per_file', 82 | default=1000, 83 | type=int, 84 | help='', 85 | ) 86 | parser.add_argument( 87 | '--use_cls_span_start', 88 | action='store_true', 89 | help='', 90 | ) 91 | 92 | args = parser.parse_args() 93 | 94 | main(args) 95 | -------------------------------------------------------------------------------- /data_utils/data_class.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import collections 3 | from typing import List 4 | 5 | class ReaderPassage: 6 | """ 7 | Container to collect and cache all Q&A passages related attributes before generating the reader input 8 | """ 9 | 10 | def __init__( 11 | self, 12 | id=None, 13 | text: List[str] = None, 14 | type: List[int] = None, 15 | title: str = None, 16 | position: int = None, 17 | has_answer: bool = None, 18 | answers_spans: List = None, 19 | history_has_answers: List[bool]=None, 20 | history_answers_spans: List[list]=None, 21 | ): 22 | self.id = id 23 | self.span_texts = text 24 | self.span_types = type 25 | self.title = title 26 | self.position = position 27 | self.has_answer = has_answer 28 | 29 | # index pair indicating the start/end span id of each answer 30 | self.answers_spans = answers_spans 31 | 32 | # passage token ids 33 | self.sequence_ids = None 34 | self.sequence_type_ids = None 35 | self.question_boundaries = None 36 | 37 | # indices of cls tokens in sequence_ids 38 | self.clss = None 39 | self.ends = None 40 | # mask of clss, where the first two and padded cls tokens are 0 41 | self.mask_cls = None 42 | 43 | self.history_has_answers = history_has_answers 44 | self.history_answers_spans = history_answers_spans 45 | 46 | self.dialog_act_id = None 47 | self.history_dialog_act_ids = None 48 | 49 | def to_tensor(self): 50 | self.sequence_ids = torch.from_numpy(self.sequence_ids) 51 | self.sequence_type_ids = torch.from_numpy(self.sequence_type_ids) 52 | self.clss = torch.from_numpy(self.clss) 53 | self.ends = torch.from_numpy(self.ends) 54 | self.mask_cls = torch.from_numpy(self.mask_cls) 55 | self.question_boundaries = torch.from_numpy(self.question_boundaries) 56 | 57 | 58 | class ReaderSample: 59 | """ 60 | Container to collect all Q&A passages data per singe question 61 | """ 62 | 63 | def __init__( 64 | self, 65 | question: str, 66 | answers: List, 67 | id=None, 68 | positive_passages: List[ReaderPassage] = [], 69 | negative_passages: List[ReaderPassage] = [], 70 | passages: List[ReaderPassage] = [], 71 | ): 72 | self.id = id 73 | self.question = question 74 | self.answers = answers 75 | self.positive_passages = positive_passages 76 | self.negative_passages = negative_passages 77 | self.passages = passages 78 | 79 | @property 80 | def all_passages(self): 81 | return self.passages + self.negative_passages + self.positive_passages 82 | 83 | def to_tensor(self): 84 | for p in self.all_passages: 85 | p.to_tensor() 86 | 87 | 88 | SpanPrediction = collections.namedtuple( 89 | "SpanPrediction", 90 | [ 91 | "prediction_text", 92 | "span_score", 93 | "relevance_score", 94 | "passage_index", 95 | 
"passage_text", 96 | ], 97 | ) 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DIALKI 2 | 3 | 4 | This repo provides the training and inference code for the DIALKI model in our EMNLP 2021 paper: [DIALKI: Knowledge Identification in Conversational Systems through Dialogue-Document Contextualization](https://arxiv.org/abs/2109.04673) 5 | 6 | 7 | ## Setup 8 | The code has been tested on CUDA 11.0+. 9 | 1. Run `conda env create -f environment.yml` and `conda activate dialki` 10 | 2. To train on doc2dial dataset, first create a folder `./dialdoc` and put original data files from [here](https://github.com/doc2dial/sharedtask-dialdoc2021/tree/master/data/doc2dial/v1.0.1) into a subfolder `./dialdoc/raw_data`. 11 | 3. If you want to train on wow instead, skip step 2. Create a folder `./wow` and change the path variable in `./path` file to `./wow`. 12 | 4. Run `bash setup.sh`. 13 | 14 | ## Data Preparation and Training 15 | The default parameters were used to run on 2 NVIDIA Quadro Q6000 GPUs. Each training process took about 18 hours for 20 epochs (default). 16 | 1. Simply run `bash run.sh dialdoc` or `bash run.sh wow` depending on which dataset you want to run. 17 | 18 | #### Some important parameters to change if not enough memory for training: 19 | Setting `--adv_loss_weight=0.0` in `scripts/train.sh` disables the posterior regularization, which helps save memory during training, but at the cost of model performance. `--passages_per_question` can also be set smaller to save memory. Setting `--decision_function=0` disables the knowledge contextualization component. 20 | 21 | ## Inference and Evaluation 22 | After you finish training, run `bash run_eval.sh [dataname] [checkpoint_path] [inference_output_path]` to run inference. `dataname` can be either `dialdoc` or `wow`. The checkpoint_path can be either the best model from your training or [our provided model](https://drive.google.com/drive/folders/1iuEtWgb16r3JNaB8NKRQ8VUQjW3pHvvi?usp=sharing) for each dataset. `inference_output_path` is where you want the inference results to be saved. The console will print out the evaluation results during inference. 23 | 24 | Currently, the inference will run for dev set by default. If you want to change to test sets (note that you need to contact dialdoc authors to get their test set), go to `script/eval.sh` and change the `--dev_file` path. 25 | 26 | 27 | ## Cite 28 | ``` 29 | @inproceedings{wu-etal-2021-dialki, 30 | title = "{DIALKI}: Knowledge Identification in Conversational Systems through Dialogue-Document Contextualization", 31 | author = "Wu, Zeqiu and 32 | Lu, Bo-Ru and 33 | Hajishirzi, Hannaneh and 34 | Ostendorf, Mari", 35 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing", 36 | month = nov, 37 | year = "2021", 38 | address = "Online and Punta Cana, Dominican Republic", 39 | publisher = "Association for Computational Linguistics", 40 | url = "https://aclanthology.org/2021.emnlp-main.140", 41 | pages = "1852--1863", 42 | abstract = "Identifying relevant knowledge to be used in conversational systems that are grounded in long documents is critical to effective response generation. We introduce a knowledge identification model that leverages the document structure to provide dialogue-contextualized passage encodings and better locate knowledge relevant to the conversation. 
An auxiliary loss captures the history of dialogue-document connections. We demonstrate the effectiveness of our model on two document-grounded conversational datasets and provide analyses showing generalization to unseen documents and long dialogue contexts.", 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /data_utils/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from transformers import BertTokenizer, BertTokenizerFast 3 | 4 | from .data_class import SpanPrediction 5 | 6 | def is_word_head(tokenizer, token): 7 | # TODO including all models that are not using BPE tokenization 8 | if isinstance(tokenizer, (BertTokenizerFast, BertTokenizer)): 9 | return not token.startswith('##') 10 | return token.startswith('Ġ') 11 | 12 | 13 | def get_word_idxs(tokenizer, tokens, party_tokens, dont_mask_words): 14 | 15 | word_idxs = [] 16 | curr_word_idx = 0 17 | prev_t = None 18 | # for whole word masking 19 | for t in tokens: 20 | if t in dont_mask_words: 21 | word_idxs.append(-1) 22 | # Handling : is for BPE tokenizer. Remember to add : if you use BPE 23 | # tokenizer, such as RoBertaTokenizer 24 | elif t == ':' and prev_t in party_tokens: 25 | word_idxs.append(-1) 26 | elif is_word_head(tokenizer, t): 27 | curr_word_idx += 1 28 | word_idxs.append(curr_word_idx) 29 | else: 30 | word_idxs.append(curr_word_idx) 31 | prev_t = t 32 | 33 | assert len(tokens) == len(word_idxs) 34 | 35 | return word_idxs 36 | 37 | 38 | def start_end_finder(start_logits, end_logits, max_answer_length, span_type, mask_cls): 39 | scores = [] 40 | for (i, s) in enumerate(start_logits): 41 | for (j, e) in enumerate(end_logits[i : i + max_answer_length]): 42 | if mask_cls[i] != 0 and mask_cls[i + j] != 0: 43 | if span_type and span_type[i+j-1] != span_type[i-1]: 44 | break 45 | scores.append(((i, i + j), s + e)) 46 | 47 | scores = sorted(scores, key=lambda x: x[1], reverse=True) 48 | 49 | chosen_span_intervals = [] 50 | for (start_index, end_index), score in scores: 51 | assert start_index <= end_index 52 | length = end_index - start_index + 1 53 | assert length <= max_answer_length 54 | 55 | if any( 56 | [ 57 | start_index <= prev_start_index <= prev_end_index <= end_index 58 | or prev_start_index <= start_index <= end_index <= prev_end_index 59 | for (prev_start_index, prev_end_index) in chosen_span_intervals 60 | ] 61 | ): 62 | continue 63 | 64 | chosen_span_intervals.append((start_index, end_index)) 65 | yield start_index, end_index, score 66 | 67 | yield -1, -1, -1 68 | 69 | 70 | def get_best_spans( 71 | start_logits: List, 72 | end_logits: List, 73 | max_answer_length: int, 74 | passage_idx: int, 75 | span_text: str, 76 | span_type: str, 77 | mask_cls: List, 78 | relevance_score: float, 79 | top_spans: int = 1, 80 | ) -> List[SpanPrediction]: 81 | """ 82 | Finds the best answer span for the extractive Q&A model 83 | """ 84 | 85 | best_spans = [] 86 | for start_index, end_index, score in start_end_finder(start_logits, end_logits, max_answer_length, span_type, mask_cls): 87 | if start_index == -1 and end_index == -1: 88 | break 89 | 90 | predicted_answer = ' '.join(span_text[start_index-1:end_index]) # offset the question and title segment 91 | 92 | best_spans.append( 93 | SpanPrediction( 94 | predicted_answer, score, relevance_score, passage_idx, ' '.join(span_text) 95 | ) 96 | ) 97 | 98 | if len(best_spans) == top_spans: 99 | break 100 | return best_spans 101 | 
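# Illustrative usage sketch with invented toy values, showing how
# get_best_spans consumes per-CLS start/end logits. mask_cls index 0 is the
# question/title CLS and is masked out; span_text holds only the document
# spans, offset by one relative to the CLS indices. Run from the repo root
# with `python -m data_utils.utils`, assuming the package's other imports
# resolve in your environment.
if __name__ == "__main__":
    start_logits = [0.0, 2.0, 0.5, 0.1]
    end_logits = [0.0, 0.3, 1.5, 0.2]
    mask_cls = [0, 1, 1, 1]
    span_text = [
        "You need form DL-44.",
        "Bring two proofs of residency.",
        "A replacement fee applies.",
    ]
    best = get_best_spans(
        start_logits,
        end_logits,
        max_answer_length=2,
        passage_idx=0,
        span_text=span_text,
        span_type=None,      # no per-span type constraint in this toy example
        mask_cls=mask_cls,
        relevance_score=0.9,
        top_spans=1,
    )
    # Expected output: the first two document spans joined together, since
    # CLS positions 1-2 carry the highest combined start/end logits.
    print(best[0].prediction_text)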
-------------------------------------------------------------------------------- /models/perturbation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft. All rights reserved. 2 | import torch 3 | import logging 4 | from .loss import stable_kl 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | def generate_noise(embed, mask, epsilon=1e-5): 9 | noise = embed.data.new(embed.size()).normal_(0, 1) * epsilon 10 | noise.detach() 11 | noise.requires_grad_() 12 | return noise 13 | 14 | 15 | class SmartPerturbation(): 16 | def __init__( 17 | self, 18 | epsilon=1e-6, 19 | step_size=1e-3, 20 | noise_var=1e-5, 21 | norm_p='inf', 22 | k=1, 23 | norm_level=0 24 | ): 25 | super(SmartPerturbation, self).__init__() 26 | self.epsilon = epsilon 27 | # eta 28 | self.step_size = step_size 29 | self.k = k 30 | # sigma 31 | self.noise_var = noise_var 32 | self.norm_p = norm_p 33 | self.norm_level = norm_level > 0 34 | 35 | 36 | def _norm_grad(self, grad, eff_grad=None, sentence_level=False): 37 | if self.norm_p == 'l2': 38 | if sentence_level: 39 | direction = grad / (torch.norm(grad, dim=(-2, -1), keepdim=True) + self.epsilon) 40 | else: 41 | direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + self.epsilon) 42 | elif self.norm_p == 'l1': 43 | direction = grad.sign() 44 | else: 45 | if sentence_level: 46 | direction = grad / (grad.abs().max((-2, -1), keepdim=True)[0] + self.epsilon) 47 | else: 48 | direction = grad / (grad.abs().max(-1, keepdim=True)[0] + self.epsilon) 49 | eff_direction = eff_grad / (grad.abs().max(-1, keepdim=True)[0] + self.epsilon) 50 | return direction, eff_direction 51 | 52 | def forward( 53 | self, 54 | model, 55 | logits, 56 | batch, 57 | global_step, 58 | calc_logits_keys, 59 | ): 60 | # init delta 61 | embed = model(batch, global_step, fwd_type='get_embs') 62 | noise = generate_noise(embed, batch['attention_mask'], epsilon=self.noise_var) 63 | for step in range(0, self.k): 64 | adv_logits, _ = model(batch, global_step, fwd_type='inputs_embeds', inputs_embeds=embed+noise, end_task_only=True) 65 | adv_loss = 0 66 | for k in calc_logits_keys: 67 | adv_loss += stable_kl(adv_logits[k], logits[k].detach(), reduce=False) 68 | 69 | delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True, retain_graph=False) 70 | norm = delta_grad.norm() 71 | if (torch.isnan(norm) or torch.isinf(norm)): 72 | return {}, -1, -1 73 | eff_delta_grad = delta_grad * self.step_size 74 | delta_grad = noise + delta_grad * self.step_size 75 | noise, eff_noise = self._norm_grad(delta_grad, eff_grad=eff_delta_grad, sentence_level=self.norm_level) 76 | noise = noise.detach() 77 | noise.requires_grad_() 78 | 79 | # save memory 80 | del adv_logits 81 | 82 | adv_logits, _ = model(batch, global_step, fwd_type='inputs_embeds', inputs_embeds=embed+noise, end_task_only=True) 83 | 84 | for k in list(adv_logits.keys()): 85 | if k in calc_logits_keys: 86 | adv_logits[f'adv_{k}'] = adv_logits.pop(k) 87 | else: 88 | adv_logits.pop(k) 89 | 90 | return adv_logits, embed.detach().abs().mean(), eff_noise.detach().abs().mean() 91 | #adv_lc = self.loss_map[task_id] 92 | #adv_loss = adv_lc(logits, adv_logits, ignore_index=-1) 93 | #return adv_loss, embed.detach().abs().mean(), eff_noise.detach().abs().mean() 94 | -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import List 4 | 
5 | import torch 6 | from torch import nn 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from transformers.optimization import AdamW 9 | 10 | logger = logging.getLogger() 11 | 12 | def setup_for_distributed_mode( 13 | model: nn.Module, 14 | optimizer: torch.optim.Optimizer, 15 | device: object, 16 | n_gpu: int = 1, 17 | local_rank: int = -1, 18 | fp16: bool = False, 19 | fp16_opt_level: str = "O1", 20 | ) -> (nn.Module, torch.optim.Optimizer): 21 | model.to(device) 22 | if fp16: 23 | try: 24 | import apex 25 | from apex import amp 26 | 27 | apex.amp.register_half_function(torch, "einsum") 28 | except ImportError: 29 | raise ImportError( 30 | "Please install apex from https://www.github.com/nvidia/apex to use fp16 training." 31 | ) 32 | 33 | logger.info(f"Use apex opt level = {fp16_opt_level}") 34 | model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) 35 | 36 | if n_gpu > 1: 37 | model = torch.nn.DataParallel(model) 38 | 39 | if local_rank != -1: 40 | model = torch.nn.parallel.DistributedDataParallel( 41 | model, 42 | device_ids=[local_rank], 43 | output_device=local_rank, 44 | find_unused_parameters=True, 45 | ) 46 | return model, optimizer 47 | 48 | 49 | def move_to_device(sample, device): 50 | if len(sample) == 0: 51 | return {} 52 | 53 | def _move_to_device(maybe_tensor, device): 54 | if torch.is_tensor(maybe_tensor): 55 | return maybe_tensor.to(device) 56 | elif isinstance(maybe_tensor, dict): 57 | return { 58 | key: _move_to_device(value, device) 59 | for key, value in maybe_tensor.items() 60 | } 61 | elif isinstance(maybe_tensor, list): 62 | return [_move_to_device(x, device) for x in maybe_tensor] 63 | elif isinstance(maybe_tensor, tuple): 64 | return [_move_to_device(x, device) for x in maybe_tensor] 65 | else: 66 | return maybe_tensor 67 | 68 | return _move_to_device(sample, device) 69 | 70 | 71 | def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1): 72 | """Create a schedule with a learning rate that decreases linearly after 73 | linearly increasing during a warmup period. 
74 | """ 75 | 76 | def lr_lambda(current_step): 77 | if current_step < warmup_steps: 78 | return float(current_step) / float(max(1, warmup_steps)) 79 | return max( 80 | 0.0, 81 | float(training_steps - current_step) 82 | / float(max(1, training_steps - warmup_steps)), 83 | ) 84 | 85 | return LambdaLR(optimizer, lr_lambda, last_epoch) 86 | 87 | 88 | def init_weights(modules: List): 89 | for module in modules: 90 | if isinstance(module, (nn.Linear, nn.Embedding)): 91 | module.weight.data.normal_(mean=0.0, std=0.02) 92 | elif isinstance(module, nn.LayerNorm): 93 | module.bias.data.zero_() 94 | module.weight.data.fill_(1.0) 95 | if isinstance(module, nn.Linear) and module.bias is not None: 96 | module.bias.data.zero_() 97 | 98 | 99 | def get_model_obj(model: nn.Module): 100 | return model.module if hasattr(model, "module") else model 101 | 102 | 103 | def get_optimizer(model, learning_rate=0e-5, adam_eps=1e-8, weight_decay=0.0): 104 | no_decay = ["bias", "LayerNorm.weight"] 105 | 106 | optimizer_grouped_parameters = [ 107 | { 108 | "params": [ 109 | p 110 | for n, p in model.named_parameters() 111 | if not any(nd in n for nd in no_decay) 112 | ], 113 | "weight_decay": weight_decay, 114 | }, 115 | { 116 | "params": [ 117 | p 118 | for n, p in model.named_parameters() 119 | if any(nd in n for nd in no_decay) 120 | ], 121 | "weight_decay": 0.0, 122 | }, 123 | ] 124 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps) 125 | return optimizer 126 | 127 | 128 | def batched_index_select(tensor, dim, index): 129 | assert (index >= 0).all() 130 | views = [tensor.shape[0]] + \ 131 | [1 if i != dim else -1 for i in range(1, len(tensor.size()))] 132 | expanse = list(tensor.size()) 133 | expanse[0] = -1 134 | expanse[dim] = -1 135 | index = index.view(views).expand(expanse) 136 | return torch.gather(tensor, dim, index) 137 | 138 | 139 | def init_weights(modules): 140 | for module in modules: 141 | if isinstance(module, (nn.Linear, nn.Embedding)): 142 | module.weight.data.normal_(mean=0.0, std=0.02) 143 | elif isinstance(module, nn.LayerNorm): 144 | module.bias.data.zero_() 145 | module.weight.data.fill_(1.0) 146 | if isinstance(module, nn.Linear) and module.bias is not None: 147 | module.bias.data.zero_() 148 | -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.distributed as dist 5 | from utils import model_utils 6 | 7 | 8 | def get_rank(): 9 | if not dist.is_available(): 10 | return -1 11 | if not dist.is_initialized(): 12 | return -1 13 | return dist.get_rank() 14 | 15 | 16 | def get_world_size(): 17 | if not dist.is_available(): 18 | return 1 19 | if not dist.is_initialized(): 20 | return 1 21 | return dist.get_world_size() 22 | 23 | 24 | def get_device(): 25 | if not dist.is_available() or not dist.is_initialized(): 26 | return torch.device("cuda", 0) 27 | else: 28 | return torch.device("cuda", get_rank()) 29 | 30 | 31 | def is_local_master(): 32 | return get_rank() in [-1, 0] 33 | 34 | 35 | def get_default_group(): 36 | return dist.group.WORLD 37 | 38 | 39 | def all_reduce(tensor, group=None): 40 | if group is None: 41 | group = get_default_group() 42 | return dist.all_reduce(tensor, group=group) 43 | 44 | 45 | def all_gather_list(data, group=None, max_size=16384): 46 | """Gathers arbitrary data from all nodes into a list. 
47 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python 48 | data. Note that *data* must be picklable. 49 | Args: 50 | data (Any): data from the local worker to be gathered on other workers 51 | group (optional): group of the collective 52 | """ 53 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 54 | 55 | enc = pickle.dumps(data) 56 | enc_size = len(enc) 57 | 58 | if enc_size + SIZE_STORAGE_BYTES > max_size: 59 | raise ValueError( 60 | f'encoded data exceeds max_size, this can be fixed by increasing ' 61 | f'buffer size: {enc_size}') 62 | 63 | rank = get_rank() 64 | world_size = get_world_size() 65 | buffer_size = max_size * world_size 66 | 67 | if not hasattr(all_gather_list, '_buffer') or \ 68 | all_gather_list._buffer.numel() < buffer_size: 69 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 70 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 71 | 72 | buffer = all_gather_list._buffer 73 | buffer.zero_() 74 | cpu_buffer = all_gather_list._cpu_buffer 75 | 76 | assert enc_size < 256 ** SIZE_STORAGE_BYTES, ( 77 | f'Encoded object size should be less than {256 ** SIZE_STORAGE_BYTES} ' 78 | f'bytes') 79 | 80 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') 81 | 82 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 83 | cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 84 | 85 | start = rank * max_size 86 | size = enc_size + SIZE_STORAGE_BYTES 87 | buffer[start: start + size].copy_(cpu_buffer[:size]) 88 | 89 | all_reduce(buffer, group=group) 90 | 91 | try: 92 | result = [] 93 | for i in range(world_size): 94 | out_buffer = buffer[i * max_size: (i + 1) * max_size] 95 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') 96 | if size > 0: 97 | result.append( 98 | pickle.loads( 99 | bytes( 100 | out_buffer[SIZE_STORAGE_BYTES: size+SIZE_STORAGE_BYTES].tolist()))) 101 | return result 102 | except pickle.UnpicklingError: 103 | raise Exception( 104 | 'Unable to unpickle data from other workers. all_gather_list requires all ' 105 | 'workers to enter the function together, so this error usually indicates ' 106 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 107 | 'sync if one of them runs out of memory, or if there are other conditions ' 108 | 'in your training script that can cause one worker to finish an epoch ' 109 | 'while other workers are still iterating over their portions of the data.') 110 | 111 | 112 | def all_gather(data, to_cpu=True): 113 | world_size = get_world_size() 114 | if world_size == 1: 115 | data = torch.tensor(data) 116 | if to_cpu: 117 | data = data.cpu() 118 | return [data] 119 | 120 | device = get_device() 121 | 122 | if not torch.is_tensor(data): 123 | data = torch.Tensor(data) 124 | data = data.to(device) 125 | 126 | rest_size = data.size()[1:] 127 | 128 | local_size = torch.LongTensor([data.size(0)]).to(device) 129 | size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)] 130 | dist.all_gather(size_list, local_size) 131 | size_list = [int(size.item()) for size in size_list] 132 | 133 | # +1 for a weird thing happen when local size == max(size_list). 
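    # That is: every rank pads its local tensor up to a shared max_size so that
    # dist.all_gather sees identically shaped tensors, and the extra +1 keeps
    # the padding tensor below non-empty even on the rank that already holds
    # the largest tensor (local_size == max(size_list)).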
134 | max_size = max(size_list) + 1 135 | 136 | tensor_list = [] 137 | for _ in size_list: 138 | tensor_list.append(torch.zeros(size=(max_size,)+rest_size).to(device)) 139 | 140 | padding = torch.zeros(size=(max_size-local_size,)+rest_size).to(device) 141 | tensor = torch.cat((data, padding), dim=0) 142 | 143 | dist.all_gather(tensor_list, tensor) 144 | 145 | data_list = [] 146 | for size, tensor in zip(size_list, tensor_list): 147 | data_list.append(tensor[:size]) 148 | 149 | if to_cpu: 150 | data_list = model_utils.move_to_device(data_list, 'cpu') 151 | 152 | return data_list -------------------------------------------------------------------------------- /models/hf_models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | from torch import nn 5 | import transformers as tfs 6 | from transformers.models.bert import modeling_bert 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class HFBertEncoder(modeling_bert.BertModel): 13 | def __init__(self, config, coordinator_config, args): 14 | modeling_bert.BertModel.__init__(self, config) 15 | assert config.hidden_size > 0, "Encoder hidden_size can't be zero" 16 | 17 | if args.use_coordinator: 18 | self.coordinator = modeling_bert.BertEncoder(coordinator_config) 19 | # assumes number of docs does not exceed 30 20 | self.doc_pos_embeddings = nn.Embedding(30, config.hidden_size) 21 | self.coord_head_mask = [ 22 | None for i in range(coordinator_config.num_hidden_layers)] 23 | else: 24 | self.coordinator = None 25 | 26 | if args.projection_dim != 0: 27 | self.encode_proj = nn.Linear( 28 | config.hidden_size, args.projection_dim) 29 | else: 30 | self.encode_proj = None 31 | 32 | self.activation = nn.Tanh() 33 | self.init_weights() 34 | 35 | @property 36 | def output_dim(self): 37 | if self.encode_proj: 38 | return self.encode_proj.out_features 39 | return self.config.hidden_size 40 | 41 | @classmethod 42 | def init_encoder(cls, args, vocab_size): 43 | 44 | cfg = tfs.AutoConfig.from_pretrained(args.pretrained_model_cfg) 45 | coordinator_cfg = modeling_bert.BertConfig.from_pretrained( 46 | 'bert-base-uncased') 47 | coordinator_cfg.num_hidden_layers = args.coordinator_layers 48 | coordinator_cfg.num_attention_heads = args.coordinator_heads 49 | 50 | dropout = args.dropout if hasattr(args, 'dropout') else 0.0 51 | if dropout != 0: 52 | cfg.attention_probs_dropout_prob = dropout 53 | cfg.hidden_dropout_prob = dropout 54 | coordinator_cfg.attention_probs_dropout_prob = dropout 55 | coordinator_cfg.hidden_dropout_prob = dropout 56 | 57 | encoder = cls.from_pretrained( 58 | args.pretrained_model_cfg, 59 | config=cfg, 60 | coordinator_config=coordinator_cfg, 61 | args=args) 62 | 63 | if cfg.vocab_size != vocab_size: 64 | logger.info(f"Resize embedding from {cfg.vocab_size} to {vocab_size}") 65 | encoder.resize_token_embeddings(vocab_size) 66 | 67 | # Hacky way to duplicate position embeddings. 
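        # BERT ships only 512 learned position embeddings; when max_seq_len
        # exceeds 512, a larger embedding table is allocated and the pretrained
        # 512 rows are tiled block by block to cover the extra positions,
        # rather than training new position embeddings from scratch.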
68 | if args.max_seq_len > 512: 69 | my_pos_embeddings = nn.Embedding(args.max_seq_len, cfg.hidden_size) 70 | my_pos_embeddings.weight.data[:512] = encoder.embeddings.position_embeddings.weight.data 71 | n_assigned = 512 72 | while n_assigned < args.max_seq_len: 73 | next_n_assigned = min(n_assigned+512, args.max_seq_len) 74 | my_pos_embeddings.weight.data[n_assigned:next_n_assigned] = encoder.embeddings.position_embeddings.weight.data[:next_n_assigned-n_assigned,:] 75 | n_assigned = next_n_assigned 76 | encoder.embeddings.position_embeddings = my_pos_embeddings 77 | 78 | if args.num_token_types > 2: 79 | my_type_embeddings = nn.Embedding( 80 | args.num_token_types, cfg.hidden_size) 81 | my_type_embeddings.weight.data[:] = encoder.embeddings.token_type_embeddings.weight.data[0][None,:].repeat(args.num_token_types,1) 82 | encoder.embeddings.token_type_embeddings = my_type_embeddings 83 | 84 | encoder.embeddings.register_buffer( 85 | "position_ids", torch.arange(args.max_seq_len).expand(1, -1)) 86 | 87 | return encoder 88 | 89 | def forward( 90 | self, 91 | N, 92 | M, 93 | input_ids, 94 | token_type_ids, 95 | position_ids, 96 | attention_mask, 97 | inputs_embeds=None, 98 | ): 99 | if self.config.output_hidden_states: 100 | sequence_output, pooled_output, hidden_states = super().forward( 101 | input_ids=input_ids, 102 | token_type_ids=token_type_ids, 103 | attention_mask=attention_mask, 104 | inputs_embeds=inputs_embeds, 105 | ) 106 | else: 107 | hidden_states = None 108 | sequence_output, pooled_output = super().forward( 109 | input_ids=input_ids, 110 | token_type_ids=token_type_ids, 111 | attention_mask=attention_mask, 112 | inputs_embeds=inputs_embeds, 113 | ) 114 | # (N * M, 1, hidden_size). 115 | pooled_output = sequence_output[:, 0, :] 116 | pooled_output = self.coordinate( 117 | N, M, pooled_output.unsqueeze(1), position_ids) 118 | return sequence_output, pooled_output, hidden_states 119 | 120 | def coordinate(self, N, M, pooled_output, position_ids): 121 | # (N * M, l, hidden_size). 122 | hidden_size = pooled_output.size(-1) 123 | L = pooled_output.size(1) 124 | if self.coordinator: 125 | # (N * M, l, hidden_size). => (l, N * M, hidden_size). 126 | pooled_output = pooled_output.transpose(0, 1) 127 | # (l, N * M, hidden_size). => (l*N, M, hidden_size). 128 | pooled_output = pooled_output.view(-1, M, hidden_size) 129 | 130 | if self.doc_pos_embeddings: 131 | # (N, M) => (N, M, hidden_size). => (l*N, M, hidden_size). 132 | doc_position_embeddings = self.doc_pos_embeddings( 133 | position_ids).repeat(L, 1, 1) 134 | 135 | # (l*N, M, hidden_size). => (l*N, M, hidden_size). 136 | pooled_output = pooled_output + doc_position_embeddings 137 | 138 | pooled_output = self.coordinator( 139 | pooled_output, head_mask=self.coord_head_mask)[0] 140 | # (l*N, M, hidden_size). => (l, N * M, hidden_size). 141 | pooled_output = pooled_output.view(-1, N * M, hidden_size) 142 | # (l, N * M, hidden_size). => (N * M, l, hidden_size). 
143 | pooled_output = pooled_output.transpose(0, 1) 144 | if self.encode_proj: 145 | pooled_output = self.activation(self.encode_proj(pooled_output)) 146 | return pooled_output 147 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dialki 2 | channels: 3 | - pytorch 4 | - comet_ml 5 | - anaconda 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=1_gnu 11 | - aiohttp=3.7.4=py38h497a2fe_0 12 | - async-timeout=3.0.1=py_1000 13 | - backcall=0.2.0=py_0 14 | - blas=1.0=mkl 15 | - boto=2.49.0=py_0 16 | - boto3=1.17.34=pyhd8ed1ab_0 17 | - botocore=1.20.34=pyhd8ed1ab_0 18 | - brotlipy=0.7.0=py38h497a2fe_1001 19 | - bz2file=0.98=py_0 20 | - c-ares=1.17.1=h7f98852_1 21 | - ca-certificates=2020.10.14=0 22 | - cachetools=4.2.1=pyhd8ed1ab_0 23 | - catalogue=2.0.1=py38h578d9bd_2 24 | - certifi=2020.6.20=py38_0 25 | - cffi=1.14.5=py38ha65f79e_0 26 | - chardet=4.0.0=py38h578d9bd_1 27 | - click=7.1.2=pyh9f0ad1d_0 28 | - colorama=0.4.4=pyh9f0ad1d_0 29 | - comet_ml=3.8.1=py38 30 | - configobj=5.0.6=py38_1 31 | - cryptography=3.4.6=py38ha5dfef3_0 32 | - cudatoolkit=11.0.221=h6bb024c_0 33 | - cxxfilt=0.2.2=py38h709712a_2 34 | - cymem=2.0.5=py38h709712a_1 35 | - cython-blis=0.7.4=py38h5c078b8_0 36 | - dataclasses=0.8=pyhc8e2a94_1 37 | - decorator=4.4.2=py_0 38 | - dulwich=0.20.21=py38h497a2fe_0 39 | - everett=1.0.3=pyhd3deb0d_0 40 | - filelock=3.0.12=pyh9f0ad1d_0 41 | - freetype=2.10.4=h5ab3b9f_0 42 | - google-api-core=1.25.1=pyhd8ed1ab_0 43 | - google-auth=1.24.0=pyhd3deb0d_0 44 | - google-cloud-core=1.5.0=pyhd3deb0d_0 45 | - google-cloud-storage=1.19.0=py_0 46 | - google-crc32c=1.1.2=py38h8838a9a_0 47 | - google-resumable-media=1.2.0=pyhd3deb0d_0 48 | - googleapis-common-protos=1.52.0=py38h578d9bd_1 49 | - grpcio=1.36.1=py38hdd6454d_0 50 | - idna=2.10=pyh9f0ad1d_0 51 | - importlib-metadata=2.0.0=py_1 52 | - importlib_metadata=2.0.0=1 53 | - iniconfig=1.1.1=pyh9f0ad1d_0 54 | - intel-openmp=2020.2=254 55 | - ipython=7.18.1=py38h5ca1d4c_0 56 | - ipython_genutils=0.2.0=py38_0 57 | - jedi=0.18.0=py38h06a4308_1 58 | - jinja2=2.11.3=pyh44b312d_0 59 | - jmespath=0.10.0=pyh9f0ad1d_0 60 | - joblib=1.0.1=pyhd8ed1ab_0 61 | - jpeg=9b=h024ee3a_2 62 | - jsonschema=3.2.0=py_2 63 | - lcms2=2.11=h396b838_0 64 | - ld_impl_linux-64=2.33.1=h53a641e_7 65 | - libcrc32c=1.1.1=h9c3ff4c_2 66 | - libffi=3.3=he6710b0_2 67 | - libgcc-ng=9.3.0=h2828fa1_18 68 | - libgomp=9.3.0=h2828fa1_18 69 | - libpng=1.6.37=hbc83047_0 70 | - libprotobuf=3.15.6=h780b84a_0 71 | - libstdcxx-ng=9.3.0=h6de172a_18 72 | - libtiff=4.2.0=h3942068_0 73 | - libuv=1.40.0=h7b6447c_0 74 | - libwebp-base=1.2.0=h27cfd23_0 75 | - lz4-c=1.9.3=h2531618_0 76 | - markupsafe=1.1.1=py38h497a2fe_3 77 | - mkl=2020.2=256 78 | - mkl-service=2.3.0=py38he904b0f_0 79 | - mkl_fft=1.3.0=py38h54f3939_0 80 | - mkl_random=1.1.1=py38h0573a6f_0 81 | - more-itertools=8.7.0=pyhd8ed1ab_0 82 | - multidict=5.1.0=py38h497a2fe_1 83 | - murmurhash=1.0.5=py38h709712a_0 84 | - ncurses=6.2=he6710b0_1 85 | - ninja=1.10.2=py38hff7bd54_0 86 | - nltk=3.5=py_0 87 | - numpy=1.19.2=py38h54aff64_0 88 | - numpy-base=1.19.2=py38hfa32c7d_0 89 | - nvidia-apex=0.1=py38h92f0514_3 90 | - nvidia-ml=7.352.0=py_0 91 | - olefile=0.46=py_0 92 | - openssl=1.1.1k=h7f98852_0 93 | - packaging=20.9=pyh44b312d_0 94 | - parso=0.8.0=py_0 95 | - pathy=0.4.0=pyhd8ed1ab_0 96 | - pexpect=4.8.0=py38_0 97 | - pickleshare=0.7.5=py38_1000 
98 | - pillow=8.1.2=py38he98fc37_0 99 | - pip=21.0.1=py38h06a4308_0 100 | - pluggy=0.13.1=py38h578d9bd_4 101 | - preshed=3.0.5=py38h709712a_0 102 | - prompt-toolkit=3.0.8=py_0 103 | - protobuf=3.15.6=py38h709712a_0 104 | - psutil=5.7.2=py38h7b6447c_0 105 | - ptyprocess=0.6.0=py38_0 106 | - py=1.10.0=pyhd3deb0d_0 107 | - pyasn1=0.4.8=py_0 108 | - pyasn1-modules=0.2.7=py_0 109 | - pycparser=2.20=pyh9f0ad1d_2 110 | - pydantic=1.7.3=py38h497a2fe_1 111 | - pygments=2.7.1=py_0 112 | - pyopenssl=20.0.1=pyhd8ed1ab_0 113 | - pyparsing=2.4.7=pyh9f0ad1d_0 114 | - pyrsistent=0.17.3=py38h7b6447c_0 115 | - pysocks=1.7.1=py38h578d9bd_3 116 | - pytest=6.2.2=py38h578d9bd_0 117 | - python=3.8.8=hdb3f193_4 118 | - python-dateutil=2.8.1=py_0 119 | - python-decouple=3.4=pyhd3deb0d_0 120 | - python_abi=3.8=1_cp38 121 | - pytz=2021.1=pyhd8ed1ab_0 122 | - pyyaml=5.4.1=py38h497a2fe_0 123 | - readline=8.1=h27cfd23_0 124 | - requests=2.25.1=pyhd3deb0d_0 125 | - requests-toolbelt=0.9.1=py_0 126 | - rsa=4.7.2=pyh44b312d_0 127 | - s3transfer=0.3.6=pyhd8ed1ab_0 128 | - sacremoses=0.0.43=pyh9f0ad1d_0 129 | - setuptools=52.0.0=py38h06a4308_0 130 | - shellingham=1.4.0=pyh44b312d_0 131 | - six=1.15.0=py38h06a4308_0 132 | - smart_open=2.2.1=pyh9f0ad1d_0 133 | - spacy=3.0.5=py38hc10631b_0 134 | - spacy-legacy=3.0.1=pyhd8ed1ab_0 135 | - sqlite=3.35.2=hdfb4753_0 136 | - srsly=2.4.0=py38h709712a_2 137 | - thinc=8.0.2=py38hc10631b_1 138 | - tk=8.6.10=hbc83047_0 139 | - tokenizers=0.10.1=py38hb63a372_0 140 | - toml=0.10.2=pyhd8ed1ab_0 141 | - torchaudio=0.7.2=py38 142 | - torchvision=0.8.2=py38_cu110 143 | - traitlets=5.0.5=py_0 144 | - transformers=4.4.2=pyhd8ed1ab_0 145 | - typer=0.3.2=pyhd8ed1ab_0 146 | - typing-extensions=3.7.4.3=0 147 | - typing_extensions=3.7.4.3=py_0 148 | - wasabi=0.8.2=pyh44b312d_0 149 | - wcwidth=0.2.5=py_0 150 | - websocket-client=0.57.0=py38_2 151 | - wheel=0.36.2=pyhd3eb1b0_0 152 | - wrapt=1.12.1=py38h7b6447c_1 153 | - wurlitzer=2.0.1=py38_0 154 | - xz=5.2.5=h7b6447c_0 155 | - yaml=0.2.5=h516909a_0 156 | - yarl=1.6.3=py38h497a2fe_1 157 | - zipp=3.3.1=py_0 158 | - zlib=1.2.11=h7b6447c_3 159 | - zstd=1.4.5=h9ceee32_0 160 | - pip: 161 | - absl-py==1.0.0 162 | - alabaster==0.7.12 163 | - antlr4-python3-runtime==4.8 164 | - attrs==20.2.0 165 | - babel==2.9.1 166 | - coloredlogs==15.0.1 167 | - colorlog==6.6.0 168 | - datasets==1.16.1 169 | - dill==0.3.4 170 | - docformatter==1.4 171 | - docutils==0.15.2 172 | - emoji==1.6.1 173 | - fairscale==0.4.3 174 | - flake8==4.0.1 175 | - flake8-bugbear==21.11.29 176 | - fsspec==2021.11.1 177 | - gitdb==4.0.9 178 | - gitdb2==4.0.2 179 | - gitpython==3.1.24 180 | - google-auth-oauthlib==0.4.6 181 | - huggingface-hub==0.1.2 182 | - humanfriendly==10.0 183 | - hydra-core==1.1.1 184 | - imagesize==1.3.0 185 | - importlib-resources==5.4.0 186 | - iopath==0.1.9 187 | - jsonlines==2.0.0 188 | - markdown==3.3.4 189 | - markdown-it-py==0.5.8 190 | - mccabe==0.6.1 191 | - multiprocess==0.70.12.2 192 | - myst-parser==0.12.10 193 | - oauthlib==3.1.1 194 | - omegaconf==2.1.1 195 | - pandas==1.3.4 196 | - parlai==1.5.1 197 | - portalocker==2.3.2 198 | - py-gfm==1.0.2 199 | - py-rouge==1.1 200 | - pyarrow==6.0.1 201 | - pycodestyle==2.8.0 202 | - pyflakes==2.4.0 203 | - pytest-datadir==1.3.1 204 | - pytest-regressions==2.2.0 205 | - pyzmq==22.3.0 206 | - regex==2021.11.10 207 | - requests-mock==1.9.3 208 | - requests-oauthlib==1.3.0 209 | - scikit-learn==1.0.1 210 | - scipy==1.7.3 211 | - sh==1.14.2 212 | - smmap==5.0.0 213 | - snowballstemmer==2.2.0 214 | - sphinx==2.2.2 215 | - 
sphinx-autodoc-typehints==1.10.3 216 | - sphinx-rtd-theme==1.0.0 217 | - sphinxcontrib-applehelp==1.0.2 218 | - sphinxcontrib-devhelp==1.0.2 219 | - sphinxcontrib-htmlhelp==2.0.0 220 | - sphinxcontrib-jsmath==1.0.1 221 | - sphinxcontrib-qthelp==1.0.3 222 | - sphinxcontrib-serializinghtml==1.1.5 223 | - subword-nmt==0.3.7 224 | - tensorboard==2.7.0 225 | - tensorboard-data-server==0.6.1 226 | - tensorboard-plugin-wit==1.8.0 227 | - tensorboardx==2.4.1 228 | - threadpoolctl==3.0.0 229 | - torch==1.10.0 230 | - torchtext==0.11.0 231 | - tornado==6.1 232 | - tqdm==4.62.3 233 | - unidecode==1.3.2 234 | - untokenize==0.1.1 235 | - urllib3==1.26.7 236 | - websocket-server==0.6.2 237 | - werkzeug==2.0.2 238 | - xxhash==2.0.2 239 | -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Iterator 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.utils.data import Sampler, Dataset 7 | 8 | 9 | def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None): 10 | """ 11 | Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of 12 | similar lengths. To do this, the indices are: 13 | 14 | - randomly permuted 15 | - grouped in mega-batches of size :obj:`mega_batch_mult * batch_size` 16 | - sorted by length in each mega-batch 17 | 18 | The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of 19 | maximum length placed first, so that an OOM happens sooner rather than later. 20 | """ 21 | # Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller. 22 | if mega_batch_mult is None: 23 | mega_batch_mult = min(len(lengths) // (batch_size * 4), 50) 24 | # Just in case, for tiny datasets 25 | if mega_batch_mult == 0: 26 | mega_batch_mult = 1 27 | 28 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 29 | indices = torch.randperm(len(lengths), generator=generator) 30 | megabatch_size = mega_batch_mult * batch_size 31 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 32 | megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches] 33 | 34 | # The rest is to get the biggest batch first. 35 | # Since each megabatch is sorted by descending length, the longest element is the first 36 | megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches] 37 | max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item() 38 | # Switch to put the longest element in first position 39 | megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0] 40 | 41 | return sum(megabatches, []) 42 | 43 | 44 | class DistributedSampler(Sampler[int]): 45 | r"""Sampler that restricts data loading to a subset of the dataset. 46 | 47 | It is especially useful in conjunction with 48 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 49 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 50 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 51 | original dataset that is exclusive to it. 52 | 53 | .. note:: 54 | Dataset is assumed to be of constant size. 55 | 56 | Args: 57 | dataset: Dataset used for sampling. 
58 | num_replicas (int, optional): Number of processes participating in 59 | distributed training. By default, :attr:`world_size` is retrieved from the 60 | current distributed group. 61 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 62 | By default, :attr:`rank` is retrieved from the current distributed 63 | group. 64 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 65 | indices. 66 | seed (int, optional): random seed used to shuffle the sampler if 67 | :attr:`shuffle=True`. This number should be identical across all 68 | processes in the distributed group. Default: ``0``. 69 | drop_last (bool, optional): if ``True``, then the sampler will drop the 70 | tail of the data to make it evenly divisible across the number of 71 | replicas. If ``False``, the sampler will add extra indices to make 72 | the data evenly divisible across the replicas. Default: ``False``. 73 | 74 | .. warning:: 75 | In distributed mode, calling the :meth:`set_epoch` method at 76 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 77 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 78 | the same ordering will be always used. 79 | 80 | Example:: 81 | 82 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 83 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 84 | ... sampler=sampler) 85 | >>> for epoch in range(start_epoch, n_epochs): 86 | ... if is_distributed: 87 | ... sampler.set_epoch(epoch) 88 | ... train(loader) 89 | """ 90 | 91 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, 92 | rank: Optional[int] = None, shuffle: bool = True, 93 | seed: int = 0, drop_last: bool = False) -> None: 94 | if num_replicas is None: 95 | if not dist.is_available(): 96 | raise RuntimeError("Requires distributed package to be available") 97 | num_replicas = dist.get_world_size() 98 | if rank is None: 99 | if not dist.is_available(): 100 | raise RuntimeError("Requires distributed package to be available") 101 | rank = dist.get_rank() 102 | if rank >= num_replicas or rank < 0: 103 | raise ValueError( 104 | "Invalid rank {}, rank should be in the interval" 105 | " [0, {}]".format(rank, num_replicas - 1)) 106 | self.dataset = dataset 107 | self.num_replicas = num_replicas 108 | self.rank = rank 109 | self.epoch = 0 110 | self.drop_last = drop_last 111 | if self.drop_last: 112 | self.num_samples = len(self.dataset) // self.num_replicas # type: ignore 113 | else: 114 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore 115 | self.total_size = self.num_samples * self.num_replicas 116 | self.shuffle = shuffle 117 | self.seed = seed 118 | self.offset = 0 119 | 120 | def __iter__(self) -> Iterator[int]: 121 | if self.shuffle: 122 | # deterministically shuffle based on epoch and seed 123 | g = torch.Generator() 124 | g.manual_seed(self.seed + self.epoch) 125 | indices = torch.randperm(len(self.dataset), generator=g) # type: ignore 126 | else: 127 | indices = range(len(self.dataset)) # type: ignore 128 | 129 | while (self.offset + self.num_replicas) <= len(indices): 130 | yield int(indices[self.offset + self.rank]) 131 | self.offset += self.num_replicas 132 | 133 | if not self.drop_last: 134 | # find the number of samples remaining 135 | num_rem = len(indices) % self.num_replicas 136 | if num_rem: 137 | if self.rank < num_rem: 138 | yield int(indices[self.offset + self.rank]) 139 | else: 140 | # wraparound, but mod in the case of 
self.rank >= len(indices) 141 | yield int(indices[self.rank % len(indices)]) 142 | 143 | @property 144 | def current_offset(self): 145 | return self.offset 146 | 147 | def set_offset(self, offset): 148 | self.offset = offset 149 | 150 | def __len__(self) -> int: 151 | return self.num_samples 152 | 153 | def set_epoch(self, epoch: int) -> None: 154 | r""" 155 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 156 | use a different random ordering for each epoch. Otherwise, the next iteration of this 157 | sampler will yield the same ordering. 158 | 159 | Args: 160 | epoch (int): Epoch number. 161 | """ 162 | self.epoch = epoch 163 | 164 | 165 | class SequentialDistributedSampler(Sampler): 166 | """ 167 | Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end. 168 | Even though we only use this sampler for eval and predict (no training), which means that the model params won't 169 | have to be synced (i.e. will not hang for synchronization even if varied number of forward passes), we still add 170 | extra samples to the sampler to make it evenly divisible (like in `DistributedSampler`) to make it easy to `gather` 171 | or `reduce` resulting tensors at the end of the loop. 172 | """ 173 | 174 | def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None): 175 | if num_replicas is None: 176 | if not dist.is_available(): 177 | raise RuntimeError("Requires distributed package to be available") 178 | num_replicas = dist.get_world_size() 179 | if rank is None: 180 | if not dist.is_available(): 181 | raise RuntimeError("Requires distributed package to be available") 182 | rank = dist.get_rank() 183 | self.dataset = dataset 184 | self.num_replicas = num_replicas 185 | self.rank = rank 186 | num_samples = len(self.dataset) 187 | # Add extra samples to make num_samples a multiple of batch_size if passed 188 | if batch_size is not None: 189 | self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size 190 | else: 191 | self.num_samples = int(math.ceil(num_samples / num_replicas)) 192 | self.total_size = self.num_samples * self.num_replicas 193 | self.batch_size = batch_size 194 | 195 | def __iter__(self): 196 | indices = list(range(len(self.dataset))) 197 | 198 | # add extra samples to make it evenly divisible 199 | indices += indices[: (self.total_size - len(indices))] 200 | assert ( 201 | len(indices) == self.total_size 202 | ), f"Indices length {len(indices)} and total size {self.total_size} mismatched" 203 | 204 | # subsample 205 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 206 | assert ( 207 | len(indices) == self.num_samples 208 | ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched" 209 | 210 | return iter(indices) 211 | 212 | def __len__(self): 213 | return self.num_samples -------------------------------------------------------------------------------- /prepro/create_dialdoc_json.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import os 4 | import re 5 | import sys 6 | import random 7 | from collections import defaultdict, Counter 8 | import argparse 9 | import nltk 10 | from transformers import BertTokenizer 11 | import string 12 | 13 | DEFAULT_TYPE_ID = 0 14 | USER_TYPE_ID = 1 15 | AGENT_TYPE_ID = 2 16 | 17 | DA_ID_MAP = {"respond_solution": 0, "query_condition": 1, "respond_solution_positive": 2, 
"respond_solution_negative": 3, 18 | "query_solution": 4, "response_negative": 5, "response_positive": 6} 19 | 20 | def _get_answers_rc(references, spans, doc_text, spid2passagelocation): 21 | """Obtain the grounding annotation for a given dialogue turn""" 22 | if not references: 23 | return [] 24 | start, end = -1, -1 25 | ls_sp = [] 26 | secid2count = defaultdict(int) 27 | for ele in references: 28 | sp_id = ele["sp_id"] 29 | secid2count[spid2passagelocation[sp_id][0]] += 1 30 | 31 | secid = sorted(secid2count.items(), key=lambda item: -item[1])[0][0] 32 | 33 | start_spid = None 34 | end_spid = None 35 | for ele in references: 36 | sp_id = ele["sp_id"] 37 | 38 | start_sp, end_sp = spans[sp_id]["start_sp"], spans[sp_id]["end_sp"] 39 | if start == -1 or start > start_sp: 40 | start = start_sp 41 | start_spid = sp_id 42 | if end < end_sp: 43 | end = end_sp 44 | end_spid = sp_id 45 | ls_sp.append(doc_text[start_sp:end_sp]) 46 | answer = doc_text[start:end] 47 | ans_spids = [str(id) for id in range(int(start_spid),int(end_spid)+1)] if (start_spid and end_spid) else [] 48 | return [' '.join(answer.strip().split())], secid, ans_spids 49 | 50 | def _load_doc_data_rc(filepath): 51 | doc_filepath = os.path.join(os.path.dirname(filepath), "doc2dial_doc.json") 52 | with open(doc_filepath, encoding="utf-8") as f: 53 | data = json.load(f)["doc_data"] 54 | return data 55 | 56 | def _get_section_info(doc): 57 | # assumes that doc spans and sections (passages) are put in order in the doc file 58 | id2text = defaultdict(list) 59 | id2spid = defaultdict(list) 60 | id2ptitles = defaultdict(list) 61 | ids_in_order = [] 62 | for sp_id in doc["spans"]: 63 | title = doc["spans"][sp_id]['title'] 64 | text_sp = doc["spans"][sp_id]["text_sp"].strip() 65 | 66 | id2ptitles[title] += [ptitle["text"] for ptitle in doc["spans"][sp_id]["parent_titles"]] 67 | id2text[title] += [text_sp] 68 | id2spid[title] += [sp_id] 69 | if title not in ids_in_order: 70 | ids_in_order += [title] 71 | secid2text = {} 72 | secid2spid = {} 73 | for id in id2text: 74 | texts = [k for k in id2text[id]] 75 | prefix = f'' 76 | for ptitle in id2ptitles[id]: 77 | prefix = f' {ptitle} {prefix}' 78 | texts = [prefix] + texts 79 | secid2text[id] = texts 80 | secid2spid[id] = [-1] + [k for k in id2spid[id]] 81 | return secid2text, secid2spid, ids_in_order 82 | 83 | def _get_spid2type(doc): 84 | spid2type = {} 85 | cur_type = 1 86 | prev_id_sec = None 87 | for i, sp_id in enumerate(doc["spans"]): 88 | id_sec = doc["spans"][sp_id]["id_sec"] 89 | if prev_id_sec and id_sec != prev_id_sec: 90 | cur_type = 1 - cur_type 91 | spid2type[sp_id] = cur_type 92 | prev_id_sec = id_sec 93 | 94 | return spid2type 95 | 96 | def _get_passage_type_ids(args, spids, spid2type): 97 | type_id_list = [] 98 | for spid in spids: 99 | if spid == -1: 100 | type_id_list += [DEFAULT_TYPE_ID] 101 | else: 102 | sp_type = spid2type[spid] 103 | type_id_list += [sp_type] 104 | return type_id_list 105 | 106 | 107 | def _process_sections(args, doc_id, secid2text, secid2spid, ids_in_order, spid2type): 108 | ctxs = [] 109 | spid2passagelocation = {} 110 | secid2position = {} 111 | for i, secid in enumerate(ids_in_order): 112 | assert len(secid2text[secid]) == len(secid2spid[secid]) 113 | secid2position[secid] = i 114 | type_id_list = _get_passage_type_ids(args, secid2spid[secid], spid2type) 115 | ctx = {'id': secid, "title": doc_id, "position": i, "text": secid2text[secid], "type": type_id_list, "has_answer": False} 116 | for j, spid in enumerate(secid2spid[secid]): 117 | if spid == -1: 
118 | continue 119 | spid2passagelocation[spid] = (secid, j) 120 | ctxs += [ctx] 121 | return ctxs, spid2passagelocation, secid2position 122 | 123 | 124 | def process_doc(args, doc, doc_id): 125 | secid2text, secid2spid, ids_in_order = _get_section_info(doc) 126 | spid2type = _get_spid2type(doc) 127 | ctxs, spid2passagelocation, secid2position = _process_sections(args, doc_id, secid2text, secid2spid, ids_in_order, spid2type) 128 | return ctxs, spid2passagelocation, secid2position 129 | 130 | 131 | def process_ctxs_dict(args, ctxs, idx, turnid2ans_spids, spid2passagelocation, secid2position, docid, ans_secid): 132 | new_ctxs = [dict(ctx) for ctx in ctxs] 133 | for c in new_ctxs: 134 | c['type'] = list(c['type']) 135 | history_answers_spans = [[[] for i in range(idx+1)] for j in range(len(new_ctxs))] 136 | 137 | for tid in sorted(turnid2ans_spids.keys(), reverse=True): 138 | for spid in turnid2ans_spids[tid]: 139 | secid, type_idx = spid2passagelocation[spid] 140 | history_answers_spans[secid2position[secid]][idx - tid] += [type_idx] 141 | 142 | for i, c in enumerate(new_ctxs): 143 | if c['id'] == ans_secid: 144 | c["has_answer"] = True 145 | new_ctxs[i]['history_answers_spans'] = [[sorted(ans_sp)[0], sorted(ans_sp)[-1]] if len(ans_sp)>0 else None for ans_sp in history_answers_spans[i]] 146 | new_ctxs[i]['history_has_answers'] = [True if len(ans_sp)>0 else False for ans_sp in history_answers_spans[i]] 147 | return new_ctxs 148 | 149 | def get_question_type_ids(args, question): 150 | question_type_id_list = [] 151 | for turn in question: 152 | if turn.startswith(''): 153 | question_type_id_list += [USER_TYPE_ID] 154 | else: 155 | question_type_id_list += [AGENT_TYPE_ID] 156 | return question_type_id_list 157 | 158 | def find_answers_spans(cur_ctxs, ans_spids, spid2passagelocation): 159 | ans_string = '' 160 | for ctx in cur_ctxs: 161 | if ctx['has_answer']: 162 | indices = [] 163 | for spid in ans_spids: 164 | indices += [spid2passagelocation[spid][1]] 165 | indices = sorted(indices) 166 | ctx['answers_spans'] = [(indices[0], indices[-1])] 167 | ans_string = ' '.join(ctx['text'][indices[0]:indices[-1]+1]) 168 | else: 169 | ctx['answers_spans'] = None 170 | return cur_ctxs, ans_string 171 | 172 | 173 | def update_turnid2ans_spids(args, idx, turn, doc, turnid2ans_spids, spid2passagelocation): 174 | _, _, ans_spids = _get_answers_rc(turn["references"], doc["spans"], doc["doc_text"], spid2passagelocation) 175 | turnid2ans_spids[idx] = list(ans_spids) 176 | return turnid2ans_spids 177 | 178 | 179 | def is_valid_turn(idx, dial_turns, turn): 180 | return turn["role"] != "agent" and idx + 1 < len(dial_turns) and dial_turns[idx + 1]["role"] == "agent" 181 | 182 | 183 | def main(args): 184 | dtype = args.dtype 185 | filepath = args.filepath.format(dtype) 186 | if dtype == 'validation': 187 | dtype = 'dev' 188 | outfile = args.outfile.format(dtype) 189 | 190 | 191 | doc_data = _load_doc_data_rc(filepath) 192 | qas = [] 193 | with open(filepath, encoding="utf-8") as f: 194 | dial_data = json.load(f)["dial_data"] 195 | for domain, d_doc_dials in dial_data.items(): 196 | for doc_id, dials in d_doc_dials.items(): 197 | doc = doc_data[domain][doc_id] 198 | ctxs, spid2passagelocation, secid2position = process_doc(args, doc, doc_id) 199 | 200 | for dial in dials: 201 | all_prev_utterances = [] 202 | all_prev_dialacts = [] 203 | turnid2ans_spids = {} 204 | dial_turns = dial["turns"] 205 | for idx, turn in enumerate(dial_turns): 206 | all_prev_utterances.append("<{}>: {}".format(turn["role"], turn["utterance"])) 
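# The dialogue-act id of each turn is collected in parallel with all_prev_utterances;
# DA_ID_MAP converts the raw act label to an integer, and both lists are reversed below
# so that the most recent turn comes first in the resulting example.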
207 | all_prev_dialacts.append(DA_ID_MAP[turn["da"]]) 208 | 209 | id_ = "{}_{}".format(dial["dial_id"], turn["turn_id"]) 210 | 211 | turnid2ans_spids = update_turnid2ans_spids(args, idx, turn, doc, turnid2ans_spids, spid2passagelocation) 212 | 213 | if not is_valid_turn(idx, dial_turns, turn): 214 | continue 215 | turn_to_predict = dial_turns[idx + 1] 216 | 217 | 218 | question = list(reversed(all_prev_utterances)) 219 | dialog_act = list(reversed(all_prev_dialacts)) 220 | answers, ans_secid, ans_spids = _get_answers_rc(turn_to_predict["references"], doc["spans"], doc["doc_text"], spid2passagelocation) 221 | cur_ctxs = process_ctxs_dict(args, ctxs, idx, turnid2ans_spids, spid2passagelocation, secid2position, doc_id, ans_secid) 222 | cur_ctxs, ans_string = find_answers_spans(cur_ctxs, ans_spids, spid2passagelocation) 223 | question_type_id_list = get_question_type_ids(args, question) 224 | qa = { 225 | "id": id_, 226 | "question": question, 227 | "question_type": question_type_id_list, 228 | "history_dialog_act": dialog_act, 229 | "dialog_act": DA_ID_MAP[turn_to_predict["da"]], 230 | "answers": answers, 231 | "ctxs": cur_ctxs, 232 | "domain": domain, 233 | } 234 | qas += [qa] 235 | with open(outfile, 'w', encoding="utf-8") as fout: 236 | fout.write(json.dumps(qas, indent=4)) 237 | 238 | 239 | if __name__ == "__main__": 240 | parser = argparse.ArgumentParser() 241 | 242 | parser.add_argument( 243 | "--dtype", 244 | required=True, 245 | type=str, 246 | help="train or validation or test", 247 | ) 248 | parser.add_argument( 249 | "--filepath", 250 | type=str, 251 | default='./raw_data/v1.0.1/doc2dial_dial_{}.json', 252 | help="file of the input dial data", 253 | ) 254 | parser.add_argument( 255 | "--outfile", 256 | type=str, 257 | default='', 258 | help="file of the output data", 259 | ) 260 | 261 | args = parser.parse_args() 262 | 263 | 264 | main(args) 265 | 266 | 267 | -------------------------------------------------------------------------------- /data_utils/doc2dial_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import json 5 | import pickle 6 | import logging 7 | import concurrent.futures 8 | import numpy as np 9 | from collections import defaultdict 10 | from tqdm import tqdm 11 | 12 | from typing import List 13 | 14 | import torch 15 | 16 | from .data_class import ReaderSample, ReaderPassage, SpanPrediction 17 | from .utils import get_word_idxs 18 | from config import TOKENS, AGENT, USER 19 | 20 | 21 | logger = logging.getLogger() 22 | 23 | 24 | class Doc2DialReader: 25 | def __init__(self, args, input_dir, output_dir, tokenizer, max_seq_len, 26 | max_history_len, max_num_spans_per_passage, 27 | num_sample_per_file=1000): 28 | 29 | assert max_history_len < max_seq_len, \ 30 | f'max_history_len {max_history_len} must shorter than max_seq_len {max_seq_len}' 31 | 32 | self.args = args 33 | self.input_dir = input_dir 34 | self.output_dir = output_dir 35 | self.tokenizer = tokenizer 36 | self.max_seq_len = max_seq_len 37 | self.max_history_len = max_history_len 38 | self.max_num_spans_per_passage = max_num_spans_per_passage 39 | self.num_sample_per_file = num_sample_per_file 40 | 41 | self.dont_mask_words = { 42 | tokenizer.cls_token, 43 | tokenizer.sep_token, 44 | } 45 | self.dont_mask_words.update(set(TOKENS)) 46 | self.party_tokens = {AGENT, USER} 47 | 48 | 49 | @staticmethod 50 | def load_data(split, input_dir): 51 | input_path = os.path.join(input_dir, f'{split}.json') 52 | with 
open(input_path, 'r', encoding='utf-8') as f: 53 | samples = json.load(f) 54 | return samples 55 | 56 | 57 | def convert_json_to_finetune_pkl(self, split): 58 | self.convert_json_to_pkl(split, self.preprocess_chunk_for_finetune) 59 | 60 | 61 | def convert_json_to_pkl(self, split, callback): 62 | output_dir = os.path.join(self.output_dir, split) 63 | os.makedirs(output_dir, exist_ok=True) 64 | 65 | # NOTE 66 | # do not modify global variable to prevent from errors 67 | # read-only 68 | global global_samples 69 | global_samples = Doc2DialReader.load_data(split.replace('_span',''), self.input_dir) 70 | 71 | chunks = [] 72 | for i in range(math.ceil(len(global_samples) / self.num_sample_per_file)): 73 | chunks.append((i, i*self.num_sample_per_file, (i+1)*self.num_sample_per_file)) 74 | 75 | with concurrent.futures.ProcessPoolExecutor() as executor: 76 | feature2chunk = {executor.submit(callback, split, chunk): chunk for chunk in chunks} 77 | iterator = tqdm( 78 | concurrent.futures.as_completed(feature2chunk), 79 | total=len(chunks), 80 | desc=f'Preprocess {split:>5s} data from {self.input_dir}:', 81 | ) 82 | 83 | finish = 0 84 | no_positive_passages = 0 85 | for feature in iterator: 86 | chunk_idx, _, _ = feature2chunk[feature] 87 | try: 88 | npp = feature.result() 89 | iterator.set_description(f'Preprocess {split:>5s} data from {self.input_dir}: chunk {chunk_idx:>5d} finished!') 90 | finish += 1 91 | no_positive_passages += npp 92 | except Exception as e: 93 | sys.exit(f'[Error]: {e}. {feature.result()}') 94 | 95 | logger.info(f'# of samples = {len(global_samples)}') 96 | logger.info(f'no positive_passages = {no_positive_passages}') 97 | logger.info(f'lost answer % = {no_positive_passages / len(global_samples) * 100:.2f}') 98 | return len(chunks) 99 | 100 | 101 | def preprocess_chunk_for_finetune(self, split, chunk): 102 | chunk_idx, start, end = chunk 103 | results = [] 104 | no_positive_passages = 0 105 | is_train = True if split == 'train' else False 106 | for sample in global_samples[start:end]: 107 | sample = self.preprocess_sample(sample, is_train) 108 | if sample is None: 109 | no_positive_passages += 1 110 | continue 111 | results.append(sample) 112 | 113 | output_path = os.path.join(self.output_dir, f'{split}/{chunk_idx}.pkl') 114 | with open(output_path, mode='wb') as f: 115 | pickle.dump(results, f) 116 | 117 | return no_positive_passages 118 | 119 | 120 | def preprocess_sample(self, sample, is_train=True): 121 | q = sample["question"] 122 | q_type = sample["question_type"] 123 | 124 | positive_passages, negative_passages = Doc2DialReader.select_passages(sample, is_train) 125 | # create concatenated sequence ids for each passage and adjust answer spans 126 | positive_passages = [ 127 | self.create_passage(s, q, q_type) for s in positive_passages 128 | ] 129 | 130 | negative_passages = [ 131 | self.create_passage(s, q, q_type) for s in negative_passages 132 | ] 133 | 134 | for passage in positive_passages: 135 | num_history_questons = len(passage.question_boundaries) 136 | passage.dialog_act_id = sample['dialog_act'] 137 | passage.history_dialog_act_ids = sample['history_dialog_act'][:num_history_questons] 138 | 139 | # no positive 140 | if is_train and any(not p.has_answer for p in positive_passages): 141 | return None 142 | 143 | if is_train: 144 | return ReaderSample( 145 | q, 146 | sample["answers"], 147 | id=sample["id"], 148 | positive_passages=positive_passages, 149 | negative_passages=negative_passages, 150 | ) 151 | else: 152 | return ReaderSample( 153 | q, 154 | 
sample["answers"], 155 | id=sample["id"], 156 | passages=negative_passages, 157 | ) 158 | 159 | 160 | def create_passage(self, passage, questions, question_types): 161 | 162 | """ 163 | history question 164 | """ 165 | # 0 is for the first CLS token, so we start from 1 166 | question_boundaries = [1] 167 | question_tokens, question_type_ids, history_question_lens = [], [], [] 168 | for idx, q in enumerate(questions): 169 | tokens = self.tokenizer.tokenize(q) 170 | type_ids = [question_types[idx]] * len(tokens) 171 | 172 | question_tokens.extend(tokens) 173 | question_type_ids.extend(type_ids) 174 | 175 | history_question_lens.append(len(tokens)) 176 | question_boundaries.append(question_boundaries[-1] + len(tokens)) 177 | 178 | # -2 is for CLS 179 | if sum(history_question_lens) >= self.max_history_len-2: 180 | break 181 | 182 | # -2 is for CLS 183 | question_tokens = question_tokens[:self.max_history_len-2] 184 | question_type_ids = question_type_ids[:self.max_history_len-2] 185 | question_boundaries[-1] = min(question_boundaries[-1], self.max_history_len-2) 186 | num_history_questions = len(history_question_lens) 187 | 188 | passage.history_answers_spans = passage.history_answers_spans[:num_history_questions] 189 | passage.history_has_answers = passage.history_has_answers[:num_history_questions] 190 | passage.question_boundaries = np.array(list(zip(question_boundaries[:-1], question_boundaries[1:]))) 191 | 192 | """ 193 | title 194 | """ 195 | title_tokens = self.tokenizer.tokenize(passage.title) 196 | title_type_ids = len(title_tokens) * [0] 197 | 198 | history_and_title_tokens = [self.tokenizer.cls_token] + question_tokens \ 199 | + [self.tokenizer.sep_token] + title_tokens + [self.tokenizer.sep_token] 200 | history_and_title_type_ids = [0] + question_type_ids + [0] + title_type_ids + [0] 201 | 202 | # -1 for the last SEP 203 | assert len(history_and_title_tokens) < self.max_seq_len-1, \ 204 | "No space for passage tokens" 205 | 206 | 207 | """ 208 | passage (spans) 209 | """ 210 | shift = len(history_and_title_tokens) 211 | passage_tokens, passage_type_ids = [], [] 212 | clss = [shift] # clss[0] is for the dummy span of or 213 | ends = [] 214 | for i, span in enumerate(passage.span_texts): 215 | if self.args.use_cls_span_start: 216 | span_tokens = [self.tokenizer.cls_token] + self.tokenizer.tokenize(span) 217 | else: 218 | span_tokens = [self.tokenizer.sep_token] + self.tokenizer.tokenize(span) 219 | span_type_ids = [passage.span_types[i]] * len(span_tokens) 220 | next_cls_pos = clss[-1] + len(span_tokens) 221 | 222 | ends.append(next_cls_pos-1) 223 | clss.append(next_cls_pos) 224 | 225 | passage_tokens.extend(span_tokens) 226 | passage_type_ids.extend(span_type_ids) 227 | 228 | # -1 for the last SEP 229 | final_tokens = (history_and_title_tokens + passage_tokens)[:self.max_seq_len-1] + [self.tokenizer.sep_token] 230 | final_type_ids = (history_and_title_type_ids + passage_type_ids)[:self.max_seq_len-1] + [0] 231 | 232 | 233 | assert len(final_tokens) == len(final_type_ids) 234 | 235 | passage.sequence_ids = np.array(self.tokenizer.convert_tokens_to_ids(final_tokens)) 236 | passage.sequence_type_ids = np.array(final_type_ids) 237 | passage.word_idxs = np.array(get_word_idxs(self.tokenizer, final_tokens, self.party_tokens, self.dont_mask_words)) 238 | 239 | # the last item of clss is a dummy 240 | clss = clss[:-1] 241 | 242 | # (0, shift-1) is for history_and_title 243 | clss = [0] + clss 244 | ends = [shift-1] + ends 245 | 246 | # make sure all ends are less than max length 247 | ends 
= list(filter(lambda idx: idx < self.max_seq_len-1, ends)) 248 | clss = clss[:len(ends)] 249 | 250 | num_spans = min(len(clss), self.max_num_spans_per_passage) 251 | clss = clss[:num_spans] + [-1] * (self.max_num_spans_per_passage - num_spans) 252 | ends = ends[:num_spans] + [-1] * (self.max_num_spans_per_passage - num_spans) 253 | clss = np.array(clss) 254 | ends = np.array(ends) 255 | mask_cls = 1 - (clss == -1) 256 | mask_cls[0] = 0 # 1st CLS (before history) 257 | if self.args.data_name == 'dialdoc': 258 | mask_cls[1] = 0 # 2nd CLS (before ) 259 | clss[clss == -1] = 0 260 | ends[ends == -1] = 0 261 | 262 | passage.clss = clss 263 | passage.ends = ends 264 | passage.mask_cls = mask_cls 265 | 266 | if passage.has_answer: 267 | # +1 for CLS offset 268 | passage.answers_spans = [(s[0]+1, s[1]+1) for s in passage.answers_spans if s[1] + 1 < num_spans] 269 | passage.has_answer = (len(passage.answers_spans) > 0) 270 | 271 | 272 | """ 273 | history spans 274 | """ 275 | for i, s in enumerate(passage.history_answers_spans): 276 | if not s: 277 | continue 278 | 279 | # +1 for CLS offset 280 | s = [s[0]+1, s[1]+1] 281 | if not (s[1] < num_spans): 282 | s = None 283 | passage.history_answers_spans[i] = s 284 | passage.history_has_answers[i] = (s is not None) 285 | 286 | return passage 287 | 288 | 289 | @staticmethod 290 | def select_passages(sample, is_train): 291 | answers = sample["answers"] 292 | 293 | ctxs = [ReaderPassage(**ctx) for ctx in sample["ctxs"]] 294 | 295 | if is_train: 296 | positive_passages = list(filter(lambda ctx: ctx.has_answer, ctxs)) 297 | negative_passages = list(filter(lambda ctx: not ctx.has_answer, ctxs)) 298 | else: 299 | positive_passages = [] 300 | negative_passages = ctxs 301 | 302 | return positive_passages, negative_passages -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import socket 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from models import loss 11 | from utils import dist_utils 12 | 13 | logger = logging.getLogger() 14 | 15 | 16 | def add_data_params(parser: argparse.ArgumentParser): 17 | parser.add_argument( 18 | '--do_lower_case', 19 | action='store_true', 20 | help=('Whether to lower case the input text. True for uncased models, ' 21 | 'False for cased models.')) 22 | parser.add_argument( 23 | '--max_num_answers', 24 | default=2, 25 | type=int, 26 | help='Max amount of answer spans to marginalize per single passage') 27 | parser.add_argument( 28 | '--passages_per_question', 29 | type=int, 30 | default=2, 31 | help='Total amount of positive and negative passages per question') 32 | parser.add_argument( 33 | '--passages_per_question_predict', 34 | type=int, 35 | default=50, 36 | help=('Total amount of positive and negative passages per question for ' 37 | 'evaluation')) 38 | parser.add_argument( 39 | '--max_answer_length', 40 | default=5, 41 | type=int, 42 | help=('The maximum length of an answer (in spans) that can be ' 43 | 'generated. 
This is needed because the start and end predictions ' 44 | 'are not conditioned on one another.')) 45 | parser.add_argument( 46 | '--special_attention', 47 | action='store_true', 48 | help=('using special attention to limit the range a question or a ' 49 | 'passage can attend to')) 50 | parser.add_argument( 51 | '--passage_attend_history', 52 | action='store_true', 53 | help=('tokens in a passage can attend to history question. ' 54 | '(information leak?)')) 55 | parser.add_argument( 56 | '--data_name', 57 | required=True, 58 | type=str, 59 | choices=['dialdoc', 'wow'], 60 | help='The name of the dataset.') 61 | 62 | 63 | def add_encoder_params(parser: argparse.ArgumentParser): 64 | """Common parameters to initialize an encoder-based model.""" 65 | 66 | parser.add_argument( 67 | '--pretrained_model_cfg', 68 | default=None, 69 | type=str, 70 | help='Path of the pre-trained model.') 71 | parser.add_argument( 72 | '--checkpoint_file', 73 | default=None, 74 | type=str, 75 | help='Trained checkpoint file to initialize the model.') 76 | parser.add_argument( 77 | '--projection_dim', 78 | default=0, 79 | type=int, 80 | help='Extra linear layer on top of standard bert/roberta encoder.') 81 | parser.add_argument( 82 | '--max_seq_len', 83 | type=int, 84 | default=512, 85 | help='Max length of the encoder input sequence.') 86 | parser.add_argument( 87 | '--dropout', 88 | default=0.1, 89 | type=float, 90 | help='') 91 | parser.add_argument( 92 | '--use_coordinator', 93 | action='store_true', 94 | help=('Whether to use a coordinator to contexualize passages with ' 95 | 'other passage vector')) 96 | parser.add_argument( 97 | '--coordinator_layers', 98 | default=1, 99 | type=int, 100 | help='Number of hidden layers for the passage coordinator') 101 | parser.add_argument( 102 | '--coordinator_heads', 103 | default=3, 104 | type=int, 105 | help='Number of attention heads for the passage coordinator') 106 | parser.add_argument( 107 | '--num_token_types', 108 | default=10, 109 | type=int, 110 | help='Number of possiblen token types') 111 | parser.add_argument( 112 | '--ignore_token_type', 113 | action='store_true', 114 | help='Whether to ignore token types or not') 115 | parser.add_argument( 116 | '--compute_da_loss', 117 | action='store_true', 118 | help='Whether to jointly train dialog act prediction or not') 119 | parser.add_argument( 120 | '--decision_function', 121 | type=int, 122 | default=0, 123 | help='Which decision function to use for calculating loss') 124 | parser.add_argument( 125 | '--hist_loss_weight', 126 | type=float, 127 | default=1.0, 128 | help='weight of history loss') 129 | parser.add_argument( 130 | '--user2agent_loss_weight', 131 | default=0, 132 | type=float, 133 | help=('predict a history agent span based on the previous user ' 134 | 'question if > 0')) 135 | parser.add_argument( 136 | '--span_marker', 137 | action='store_true', 138 | help='mark spans used in history') 139 | parser.add_argument( 140 | '--skip_mark_last_user', 141 | action='store_true', 142 | help=('skip add mark embeddings of the last user turn to span ' 143 | 'embeddings')) 144 | parser.add_argument( 145 | '--marker_after_steps', 146 | default=0, 147 | type=int, 148 | help='not using marker in the begining of the training process') 149 | parser.add_argument( 150 | '--use_z_attn', 151 | action='store_true', 152 | help='') 153 | 154 | 155 | def add_f_div_regularization_params(parser: argparse.ArgumentParser): 156 | parser.add_argument( 157 | '--adv_epsilon', 158 | default=1e-6, 159 | type=float, 160 | help='for adv 
training') 161 | parser.add_argument( 162 | '--adv_step_size', 163 | default=1e-5, 164 | type=float, 165 | help='for adv training') 166 | parser.add_argument( 167 | '--adv_noise_var', 168 | default=1e-5, 169 | type=float, 170 | help='for adv training') 171 | parser.add_argument( 172 | '--adv_norm_p', 173 | default='inf', 174 | type=str, 175 | help='for adv training') 176 | parser.add_argument( 177 | '--adv_norm_level', 178 | default=0, 179 | type=int, 180 | help='for adv training') 181 | parser.add_argument( 182 | '--adv_k', 183 | default=1, 184 | type=int, 185 | help='for adv training') 186 | parser.add_argument( 187 | '--adv_calc_logits_keys', 188 | nargs='+', 189 | default=['start', 'end', 'relevance'], 190 | help='for adv training') 191 | parser.add_argument( 192 | '--adv_loss_weight', 193 | default=0.0, 194 | type=float, 195 | help='for adv training') 196 | parser.add_argument( 197 | '--adv_loss_type', 198 | default='hl', 199 | choices=loss.LOSS.keys(), 200 | type=str, 201 | help='for adv training') 202 | 203 | 204 | def add_training_params(parser: argparse.ArgumentParser): 205 | """Common parameters for training.""" 206 | parser.add_argument( 207 | '--train_file', 208 | default=None, 209 | type=str, 210 | help='File pattern for the train set.') 211 | parser.add_argument( 212 | '--dev_file', 213 | default=None, 214 | type=str, 215 | help='File pattern for the dev set.') 216 | parser.add_argument( 217 | '--batch_size', 218 | default=2, 219 | type=int, 220 | help='Amount of questions per batch.') 221 | parser.add_argument( 222 | '--dev_batch_size', 223 | type=int, 224 | default=4, 225 | help='amount of questions per batch for dev set validation.') 226 | parser.add_argument( 227 | '--seed', 228 | type=int, 229 | default=0, 230 | help='random seed for initialization and dataset shuffling.') 231 | parser.add_argument( 232 | '--adam_eps', 233 | default=1e-8, 234 | type=float, 235 | help='Epsilon for Adam optimizer.') 236 | parser.add_argument( 237 | '--adam_betas', 238 | default='(0.9, 0.999)', 239 | type=str, 240 | help='Betas for Adam optimizer.') 241 | parser.add_argument( 242 | '--max_grad_norm', 243 | default=1.0, 244 | type=float, 245 | help='Max gradient norm.') 246 | parser.add_argument( 247 | '--log_batch_step', 248 | default=100, 249 | type=int, 250 | help='Number of steps to log during training.') 251 | parser.add_argument( 252 | '--train_rolling_loss_step', 253 | default=100, 254 | type=int, 255 | help='Number of steps of interval to save traning loss.') 256 | parser.add_argument( 257 | '--weight_decay', 258 | default=0.0, 259 | type=float, 260 | help='Weight decay for optimizer.') 261 | parser.add_argument( 262 | '--learning_rate', 263 | default=1e-5, 264 | type=float, 265 | help='Learning rate.') 266 | parser.add_argument( 267 | '--warmup_steps', 268 | default=100, 269 | type=int, 270 | help='Linear warmup over warmup_steps.') 271 | parser.add_argument( 272 | '--gradient_accumulation_steps', 273 | type=int, 274 | default=1, 275 | help='Number of update steps to accumulate before updating parameters.') 276 | parser.add_argument( 277 | '--num_train_epochs', 278 | default=3.0, 279 | type=float, 280 | help='Total number of training epochs to perform.') 281 | parser.add_argument( 282 | '--auto_resume', 283 | action='store_true', 284 | help='Auto resume from latest checkpoint') 285 | parser.add_argument( 286 | '--save_checkpoint_every_minutes', 287 | type=int, 288 | default=15, 289 | help='Save a checkpoint every x minutes') 290 | parser.add_argument( 291 | '--topk_em', 292 | 
type=int, 293 | default=2, 294 | help='Topk checkpoints according to EM metrics.') 295 | parser.add_argument( 296 | '--topk_f1', 297 | type=int, 298 | default=2, 299 | help='Topk checkpoints according to F1 metrics.') 300 | parser.add_argument( 301 | '--best_metric', 302 | type=str, 303 | choices=['em', 'f1'], 304 | help='Take the best model based on EM or F1 scores.') 305 | parser.add_argument( 306 | '--eval_step', 307 | default=2000, 308 | type=int, 309 | help='Batch steps to run validation and save checkpoint.') 310 | parser.add_argument( 311 | '--eval_top_docs', 312 | nargs='+', 313 | type=int, 314 | help=('Top retrival passages thresholds to analyze prediction results ' 315 | 'for')) 316 | parser.add_argument( 317 | '--checkpoint_filename_prefix', 318 | type=str, 319 | default='dialki', 320 | help='Checkpoint filename prefix.') 321 | parser.add_argument( 322 | '--output_dir', 323 | required=True, 324 | type=str, 325 | help='Output directory for checkpoints.') 326 | parser.add_argument( 327 | '--inference_only', 328 | action='store_true', 329 | help='Inference only.') 330 | parser.add_argument( 331 | '--prediction_results_file', 332 | type=str, 333 | help='Path to a file to write prediction results to') 334 | 335 | 336 | def add_cuda_params(parser: argparse.ArgumentParser): 337 | parser.add_argument( 338 | '--local_rank', 339 | type=int, 340 | default=-1, 341 | help='The parameter for distributed training.') 342 | parser.add_argument( 343 | '--fp16', 344 | action='store_true', 345 | help='Whether to use 16-bit float precision instead of 32-bit.') 346 | parser.add_argument( 347 | '--fp16_opt_level', 348 | type=str, 349 | default='O1', 350 | help=('For fp16: Apex AMP optimization level selected.' 351 | 'See details at https://nvidia.github.io/apex/amp.html.')) 352 | 353 | 354 | def get_encoder_checkpoint_params_names(): 355 | return [ 356 | 'do_lower_case', 357 | 'pretrained_model_cfg', 358 | 'projection_dim', 359 | 'max_seq_len', 360 | ] 361 | 362 | 363 | def get_encoder_params_state(args): 364 | """ 365 | Selects the param values to be saved in a checkpoint, so that a trained 366 | model faile can be used for downstream tasks without the need to specify 367 | these parameter again. 368 | 369 | Return: Dict of params to memorize in a checkpoint. 370 | """ 371 | params_to_save = get_encoder_checkpoint_params_names() 372 | 373 | r = {} 374 | for param in params_to_save: 375 | r[param] = getattr(args, param) 376 | return r 377 | 378 | 379 | def set_encoder_params_from_state(state, args): 380 | if not state: 381 | return 382 | params_to_save = get_encoder_checkpoint_params_names() 383 | 384 | override_params = [ 385 | (param, state[param]) 386 | for param in params_to_save 387 | if param in state and state[param] 388 | ] 389 | for param, value in override_params: 390 | if param == "pretrained_model_cfg": 391 | continue 392 | if hasattr(args, param): 393 | if dist_utils.is_local_master(): 394 | logger.warning( 395 | f'Overriding args parameter value from checkpoint state. 
' 396 | f'{param = }, {value = }') 397 | setattr(args, param, value) 398 | return args 399 | 400 | 401 | def set_seed(args): 402 | seed = args.seed 403 | random.seed(seed) 404 | np.random.seed(seed) 405 | torch.manual_seed(seed) 406 | if args.n_gpu > 0: 407 | torch.cuda.manual_seed_all(seed) 408 | 409 | 410 | def next_free_port(port=10123, max_port=65535): 411 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 412 | while port <= max_port: 413 | try: 414 | sock.bind(('', port)) 415 | sock.close() 416 | return port 417 | except OSError: 418 | port += 1 419 | raise IOError('no free ports') 420 | 421 | 422 | def setup_args_gpu(args): 423 | """ 424 | Setup arguments CUDA, GPU & distributed training. 425 | """ 426 | 427 | word_size = os.environ.get('WORLD_SIZE') 428 | word_size = int(word_size) if word_size else 1 429 | args.distributed_world_size = word_size 430 | local_rank = args.local_rank 431 | 432 | if local_rank == -1: 433 | # Single-node multi-gpu (or cpu) mode. 434 | if torch.cuda.is_available(): 435 | device = 'cuda' 436 | else: 437 | device = 'cpu' 438 | device = torch.device(device) 439 | n_gpu = args.n_gpu = torch.cuda.device_count() 440 | else: 441 | # Distributed mode. 442 | torch.cuda.set_device(args.local_rank) 443 | device = torch.device('cuda', args.local_rank) 444 | master_port = next_free_port() 445 | dist_init_method = f'tcp://localhost:{master_port}' 446 | torch.distributed.init_process_group( 447 | backend='nccl', 448 | init_method=dist_init_method, 449 | rank=args.local_rank, 450 | world_size=word_size) 451 | n_gpu = args.n_gpu = 1 452 | args.device = device 453 | 454 | if dist_utils.is_local_master(): 455 | logger.info( 456 | f'Initialized host {socket.gethostname()}' 457 | f'{local_rank = } {device = } {n_gpu = } {word_size = }' 458 | f'16-bits training: {args.fp16}') 459 | -------------------------------------------------------------------------------- /prepro/create_wow_json.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple, Union, Iterator 2 | 3 | import json 4 | import os 5 | from collections import namedtuple 6 | import random 7 | import colorlog 8 | from operator import itemgetter 9 | import argparse 10 | 11 | from collections import defaultdict, Counter 12 | 13 | import numpy as np 14 | from tqdm import tqdm 15 | from parlai.core.dict import DictionaryAgent 16 | from parlai.core.worlds import create_task 17 | 18 | 19 | PARLAI_KNOWLEDGE_SEPARATOR = '__knowledge__' 20 | BERT_KNOWLEDGE_SEPARATOR = '_ _ knowledge _ _' 21 | 22 | _PARLAI_PAD = '__null__' 23 | _PARLAI_GO = '__start__' 24 | _PARLAI_EOS = '__end__' 25 | _PARLAI_UNK = '__unk__' 26 | 27 | 28 | DEFAULT_TYPE_ID = 0 29 | USER_TYPE_ID = 1 30 | AGENT_TYPE_ID = 2 31 | 32 | 33 | class WowDatasetReader: 34 | 35 | def __init__(self, 36 | cache_dir: str = None) -> None: 37 | self._datapath = cache_dir 38 | 39 | 40 | def _load_and_preprocess_all(self, mode: str): 41 | """ 42 | As default, it returns the following action dict: 43 | { 44 | 'id': 'wizard_of_wikipedia' 45 | 'text': chosen_topic\n # if first example in episode 46 | last_apprentice_message\n # if possible 47 | wizard_message # if --label-type is 'chosen_sent' 48 | 'knowledge': title_1 sentence_1\n 49 | . 50 | . 51 | . 
52 | title_m sentence_n # all knowledge available to wizard 53 | 'labels': [title_checked sentence_checked] # default 54 | OR 55 | [wizard_response] # if --label-type set to 'response' 56 | 'label_candidates': knowledge + [no_passages_used no_passages_used] 57 | OR 58 | 100 response candidates # if 'validation' or 'test' 59 | 'chosen_topic': chosen_topic as untokenized string 60 | 'checked_sentence': checked sentence if wizard, else None # if --include_checked_sentence 61 | 'title': title of checked sentence # if --include_checked_sentence 62 | --> if not exists, then checked_sentence = title = 'no_passages_used' 63 | 'episode_done': (Boolean) whether episode is done or not 64 | } 65 | """ 66 | 67 | parlai_opt = self._get_parlai_opt([ 68 | '--task', 'wizard_of_wikipedia:generator:topic_split' if 'unseen' in mode else 'wizard_of_wikipedia:generator:random_split', 69 | '--datatype', '{}:stream'.format(mode.split('_')[0]) if 'unseen' in mode else f'{mode}:stream', # 'train' for shuffled data and 'train:stream' for unshuffled data 70 | '--datapath', self._datapath, 71 | # dict_XXX will not be used if we use bert tokenizer 72 | '--dict_lower', 'True', 73 | '--dict_tokenizer', 'bpe', 74 | '--dict_file', f"{self._datapath}/wow.dict", 75 | '--dict_textfields', "text,labels,chosen_topic,checked_sentence,knowledge,title", # For retrieval mode, use "text,labels" 76 | # By following author's code. For retrieval mode, use 250004 77 | # Also, note that this is the size of bpehelper dictionary. 78 | # So, final dictionary can be larger than this one 79 | # And, don't convert special tokens to index with txt2vec method, you must use tok2ind 80 | '--dict_maxtokens', '30000', 81 | '--dict_nulltoken', _PARLAI_PAD, 82 | '--dict_starttoken', _PARLAI_GO, 83 | '--dict_endtoken', _PARLAI_EOS, 84 | '--dict_unktoken', _PARLAI_UNK, 85 | '--include_knowledge_separator', 'True', # include speical __knowledge__ token between title and passage 86 | '--include_checked_sentence', 'True', 87 | '--label_type', 'response', # choices = ['response', 'chosen_sent'] 88 | ]) 89 | # As a default, world use "WizardDialogKnowledgeTeacher" 90 | agent = DictionaryAgent(parlai_opt) 91 | world = create_task(parlai_opt, agent) 92 | num_examples = world.num_examples() 93 | num_episodes = world.num_episodes() 94 | 95 | episodes = [] 96 | for _ in range(num_episodes): 97 | examples = [] 98 | while True: 99 | world.parley() 100 | example = world.acts[0] 101 | examples.append(example) 102 | if world.episode_done(): 103 | episodes.append(examples) 104 | break 105 | 106 | return self._preprocess_episodes(episodes, mode) 107 | 108 | def _get_parlai_opt(self, options: List[str] = [], print_args=False): 109 | from parlai.scripts.build_dict import setup_args 110 | parser = setup_args() 111 | opt = parser.parse_args(options, print_args=print_args) 112 | return opt 113 | 114 | def _get_preprocessed_fname(self, mode): 115 | if self._datapath: 116 | return os.path.join(self._datapath, f'{mode}_episodes.json') 117 | else: 118 | return None 119 | 120 | def _preprocess_episodes(self, episodes, mode): 121 | 122 | colorlog.info("Preprocess wizard of wikipedia dataset") 123 | 124 | new_episodes = [] 125 | for episode_num, episode in enumerate(tqdm(episodes, ncols=70)): 126 | new_examples = [] 127 | for example_num, example in enumerate(episode): 128 | context = example['text'] 129 | if mode == "train": 130 | response = example['labels'][0] 131 | else: 132 | response = example['eval_labels'][0] 133 | chosen_topic = example['chosen_topic'] 134 | 135 | # Set up 
knowledge 136 | checked_knowledge = example['title'] + ' __knowledge__ ' + example['checked_sentence'] 137 | knowledge_sentences = [k for k in example['knowledge'].rstrip().split('\n')] 138 | assert "no_passages_used __knowledge__ no_passages_used" in knowledge_sentences 139 | 140 | for idx, k in enumerate(knowledge_sentences): 141 | if k == checked_knowledge: 142 | break 143 | else: 144 | # Sometimes, knowledge does not include checked_sentnece 145 | idx = None 146 | colorlog.warning("Knowledge does not include checked sentence.") 147 | if idx is None: 148 | knowledge_sentences += [checked_knowledge] 149 | 150 | new_example = {'context': context, 151 | 'response': response, 152 | 'chosen_topic': chosen_topic, 153 | 'knowledge_sentences': knowledge_sentences, 154 | 'chosen_knowledge': checked_knowledge, 155 | 'episode_num': episode_num, 156 | 'example_num': example_num} 157 | new_examples.append(new_example) 158 | new_episodes.append(new_examples) 159 | 160 | return new_episodes 161 | 162 | 163 | 164 | class WoWPassage: 165 | 166 | def __init__(self, title, position, sentences): 167 | self.id = "" # put a dummy one 168 | self.title = title 169 | self.position = position 170 | self.text = sentences 171 | self.type = [0] * len(sentences) # put a dummy one 172 | self.has_answer = None 173 | self.history_answers_spans = None 174 | self.history_has_answers = None 175 | self.answers_spans = None 176 | 177 | 178 | class WoWDataSample: 179 | 180 | def __init__(self, conv_id, turn_id, chosen_sentence): 181 | self.id = f'conv_{conv_id}_turn_{turn_id}' 182 | 183 | self.question = None 184 | 185 | self.question_type = None # list of 1s and 2s indicating previous turn roles 186 | self.history_dialog_act = None 187 | self.dialog_act = 0 # put a dummy one 188 | self.answers = [chosen_sentence] 189 | 190 | self.ctxs = None 191 | 192 | 193 | def set_question(self, question, question_type): 194 | self.question = question 195 | self.question_type = question_type 196 | assert len(question_type) == len(question) 197 | self. 
history_dialog_act = [0] * len(question) # put a dummy one 198 | 199 | 200 | def add_passage(self, title, position, sentences, has_answer, chosen_sentence, history_has_answers, history_sentences): 201 | if self.ctxs is None: 202 | self.ctxs = [] 203 | 204 | passage = WoWPassage(title, position, sentences) 205 | passage.has_answer = has_answer 206 | if has_answer: 207 | sentence_id = sentences.index(chosen_sentence) 208 | assert sentence_id >= 0 209 | passage.answers_spans = [[sentence_id, sentence_id]] 210 | 211 | passage.history_has_answers = history_has_answers 212 | history_answers_spans = [] 213 | for i, hist_has_ans in enumerate(history_has_answers): 214 | if not hist_has_ans: 215 | history_answers_spans += [None] 216 | else: 217 | sentence_id = sentences.index(history_sentences[i]) 218 | assert sentence_id >= 0 219 | history_answers_spans += [[sentence_id, sentence_id]] 220 | passage.history_answers_spans = history_answers_spans 221 | 222 | self.ctxs += [vars(passage)] 223 | 224 | 225 | class WoWDialog(): 226 | 227 | def __init__(self, turns): 228 | 229 | self.dialid = turns[0]['episode_num'] 230 | self.chosen_topic = turns[0]['chosen_topic'] 231 | self.utterances = turns 232 | 233 | def read_samples(self): 234 | 235 | samples = [] 236 | 237 | prev_turns = [] 238 | prev_type = [] 239 | prev_sentences = [] 240 | prev_passages = [] 241 | 242 | perc_found_prev_answers = [] 243 | 244 | for i, turn in enumerate(self.utterances): 245 | 246 | assert i == turn['example_num'] 247 | assert self.dialid == turn['episode_num'] 248 | 249 | user_turn = turn['context'] 250 | agent_turn = turn['response'] 251 | 252 | prev_passages += [None] 253 | prev_sentences += [None] 254 | prev_type += [USER_TYPE_ID] 255 | if i == 0: 256 | # the first user turn includes topic at the beginning 257 | topic = user_turn[:len(self.chosen_topic)] 258 | assert topic == self.chosen_topic 259 | # original topic and text separator is '\n' 260 | assert user_turn == topic or user_turn[len(self.chosen_topic)] == '\n' 261 | text = user_turn[len(self.chosen_topic):] 262 | if text: 263 | prev_turns += [f"{topic} {text.strip()}"] 264 | else: 265 | prev_turns += [f"{topic}"] 266 | else: 267 | prev_turns += [f" {user_turn.strip()}"] 268 | 269 | 270 | chosen_passage, chosen_sentence = turn['chosen_knowledge'].split(' __knowledge__ ') 271 | assert turn['chosen_knowledge'] in turn['knowledge_sentences'] 272 | # make data sample 273 | sample = WoWDataSample(self.dialid, i, chosen_sentence) 274 | sample.set_question(list(reversed(prev_turns)), list(reversed(prev_type))) 275 | 276 | 277 | title2sentences = defaultdict(list) 278 | all_k_sents = turn["knowledge_sentences"] 279 | prev_title = None 280 | titles = [] 281 | for j, sent in enumerate(all_k_sents): 282 | try: 283 | title, k_sent = sent.split(' __knowledge__ ') 284 | except: 285 | continue 286 | 287 | title2sentences[title] += [k_sent] 288 | prev_title = title 289 | if title not in titles: 290 | titles += [title] 291 | 292 | 293 | total_has_answer = 0 294 | total_prev_has_answer = [0] * len(prev_turns) 295 | for position, title in enumerate(titles): 296 | 297 | sentences = title2sentences[title] 298 | 299 | has_answer = (chosen_passage == title and chosen_sentence in sentences) 300 | total_has_answer += int(has_answer) 301 | 302 | prev_has_answers = [(prev_p == title and prev_sent in sentences) for prev_p, prev_sent in zip(prev_passages, prev_sentences)] 303 | total_prev_has_answer = [total_prev_has_answer[i]+int(ans) for i, ans in enumerate(prev_has_answers)] 304 | 305 | 
sample.add_passage(title, 306 | position, 307 | sentences, 308 | has_answer, 309 | chosen_sentence, 310 | list(reversed(prev_has_answers)), 311 | list(reversed(prev_sentences))) 312 | 313 | assert total_has_answer == 1 314 | 315 | n_found_prev_answers = sum([total_prev_has_answer[i] for i, ans in enumerate(total_prev_has_answer) if prev_type[i] == AGENT_TYPE_ID]) 316 | n_prev_agent_turns = len([t for t in prev_type if t == AGENT_TYPE_ID]) 317 | perc_found_prev_answers += [n_found_prev_answers*1.0/(n_prev_agent_turns+1e-9)] 318 | 319 | samples += [sample] 320 | 321 | prev_passages += [chosen_passage] 322 | prev_sentences += [chosen_sentence] 323 | prev_type += [AGENT_TYPE_ID] 324 | prev_turns += [f" {agent_turn.strip()}"] 325 | 326 | return samples, perc_found_prev_answers 327 | 328 | 329 | def write_examples(dialogues, outdir, split): 330 | dialogues = [WoWDialog(d) for d in dialogues] 331 | qas = [] 332 | perc_found_prev_answers_list = [] 333 | for dial in dialogues: 334 | samples, perc_found_prev_answers = dial.read_samples() 335 | perc_found_prev_answers_list += perc_found_prev_answers 336 | qas += [vars(sample) for sample in samples] 337 | 338 | print(f'{split} examples: {len(qas)}') 339 | print(f'percentage of previous answers can be found in current candidate passages: {sum(perc_found_prev_answers_list)/len(perc_found_prev_answers_list)}') 340 | with open(os.path.join(outdir, f'{split}.json'), 'w', encoding="utf-8") as fout: 341 | fout.write(json.dumps(qas, indent=4)) 342 | 343 | 344 | def main(args): 345 | 346 | reader = WowDatasetReader(args.cache_dir) 347 | dialogues = reader._load_and_preprocess_all('train') 348 | write_examples(dialogues, args.cache_dir, 'train') 349 | 350 | dialogues = reader._load_and_preprocess_all('valid') 351 | write_examples(dialogues, args.cache_dir, 'dev') 352 | 353 | dialogues = reader._load_and_preprocess_all('valid_unseen') 354 | write_examples(dialogues, args.cache_dir, 'dev_unseen') 355 | 356 | dialogues = reader._load_and_preprocess_all('test') 357 | write_examples(dialogues, args.cache_dir, 'test') 358 | 359 | dialogues = reader._load_and_preprocess_all('test_unseen') 360 | write_examples(dialogues, args.cache_dir, 'test_unseen') 361 | 362 | 363 | if __name__ == "__main__": 364 | parser = argparse.ArgumentParser() 365 | 366 | parser.add_argument( 367 | "--cache_dir", 368 | type=str, 369 | default='', 370 | help="directory path of the output data", 371 | ) 372 | args = parser.parse_args() 373 | 374 | 375 | main(args) 376 | -------------------------------------------------------------------------------- /data_utils/data_collator.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import torch 4 | import numpy as np 5 | 6 | import config 7 | 8 | 9 | def _pad(target, fill_value, pad_len, dim=0): 10 | if pad_len == 0: 11 | return target 12 | size = list(target.size()) 13 | size[dim] = pad_len 14 | pad = torch.full(size, fill_value) 15 | return torch.cat([target, pad], dim=dim) 16 | 17 | 18 | class DataCollator: 19 | def __init__( 20 | self, 21 | data_name, 22 | tokenizer, 23 | max_seq_len, 24 | max_num_answers, 25 | max_num_passages_per_questions, 26 | special_attention, 27 | passage_attend_history, 28 | is_train, 29 | shuffle): 30 | self.tokenizer = tokenizer 31 | self.max_seq_len = max_seq_len 32 | self.max_num_answers = max_num_answers 33 | self.max_num_passages_per_questions = max_num_passages_per_questions 34 | self.special_attention = special_attention 35 | 
self.passage_attend_history = passage_attend_history 36 | self.is_train = is_train 37 | self.shuffle = shuffle 38 | self.data_name = data_name 39 | 40 | 41 | def __call__(self, samples): 42 | seq_lens, num_spans, num_passages, num_history_turns, num_user_turns = \ 43 | [], [], [], [], [] 44 | for sample in samples: 45 | seq_lens.extend([len(p.sequence_ids) for p in sample.all_passages]) 46 | num_passages.append(len(sample.all_passages)) 47 | num_spans.extend([len(p.clss) for p in sample.all_passages]) 48 | num_history_turns.extend( 49 | [len(p.question_boundaries) for p in sample.all_passages]) 50 | 51 | user_token_id = self.tokenizer.convert_tokens_to_ids(config.USER) 52 | user_idxs = ( 53 | sample.all_passages[0].sequence_ids == user_token_id).nonzero( 54 | as_tuple=True)[0] 55 | num_user_turns.append(user_idxs.nelement()) 56 | 57 | max_seq_len = max(seq_lens) 58 | assert self.max_seq_len >= max_seq_len, \ 59 | (f"max_seq_len ({max_seq_len}) > global max_seq_len ({self.max_seq_len})." 60 | f"Check preprocessing or data dir") 61 | 62 | max_num_spans = max(num_spans) 63 | passages_per_question = min( 64 | max(num_passages), self.max_num_passages_per_questions) 65 | max_num_history_turns = max(num_history_turns) 66 | max_num_user_turns = max(num_user_turns) 67 | 68 | 69 | ret = collections.defaultdict(list) 70 | for sample in samples: 71 | positive_ctxs = sample.positive_passages 72 | if self.is_train: 73 | negative_ctxs = sample.negative_passages 74 | else: 75 | negative_ctxs = sample.passages 76 | 77 | r = self._preprocess_sample( 78 | positive_ctxs, 79 | negative_ctxs, 80 | max_seq_len, 81 | max_num_spans, 82 | max_num_history_turns, 83 | max_num_user_turns, 84 | passages_per_question, 85 | ) 86 | for k, v in r.items(): 87 | ret[k].append(v) 88 | 89 | for k, v in ret.items(): 90 | ret[k] = torch.stack(v) 91 | 92 | ret['samples'] = samples 93 | 94 | return ret 95 | 96 | 97 | def _preprocess_sample( 98 | self, 99 | positives, 100 | negatives, 101 | max_seq_len, 102 | max_num_spans, 103 | max_num_history_turns, 104 | max_num_user_turns, 105 | passages_per_question, 106 | ): 107 | 108 | def _get_answers_tensor(spans): 109 | starts = [span[0] for span in spans] 110 | ends = [span[1] for span in spans] 111 | 112 | starts_tensor = torch.full( 113 | (passages_per_question, self.max_num_answers), 114 | -1, dtype=torch.long) 115 | starts_tensor[0, :len(starts)] = torch.tensor(starts) 116 | 117 | ends_tensor = torch.full( 118 | (passages_per_question, self.max_num_answers), 119 | -1, dtype=torch.long) 120 | ends_tensor[0, :len(ends)] = torch.tensor(ends) 121 | 122 | return starts_tensor, ends_tensor 123 | 124 | # select one positive 125 | if positives: 126 | if self.shuffle: 127 | positive_idx = np.random.choice(len(positives)) 128 | else: 129 | positive_idx = 0 130 | positive = positives[positive_idx] 131 | num_positives = 1 132 | else: 133 | num_positives = 0 134 | 135 | # select negatives 136 | negative_idxs = range(len(negatives)) 137 | if self.shuffle: 138 | negative_idxs = np.random.permutation(negative_idxs) 139 | negative_idxs = negative_idxs[:passages_per_question-num_positives] 140 | negatives = [negatives[i] for i in negative_idxs] 141 | 142 | 143 | if self.is_train: 144 | passages = [positive] + negatives 145 | else: 146 | passages = negatives 147 | 148 | num_history_turns = len(passages[0].question_boundaries) 149 | 150 | """ 151 | get labels 152 | """ 153 | ret = {} 154 | if self.is_train: 155 | ret['answer_starts'], ret['answer_ends'] = _get_answers_tensor( 156 | 
positive.answers_spans) 157 | history_da_label = torch.full( 158 | (passages_per_question, max_num_history_turns), 159 | -1, dtype=torch.long) 160 | history_da_label[0, :num_history_turns] = torch.tensor( 161 | positive.history_dialog_act_ids) 162 | ret['history_da_label'] = history_da_label 163 | 164 | da_label = torch.full( 165 | (passages_per_question,), -1, dtype=torch.long) 166 | da_label[0] = positive.dialog_act_id 167 | ret['da_label'] = da_label 168 | 169 | if self.data_name == 'dialdoc': 170 | user_token_id = self.tokenizer.convert_tokens_to_ids(config.USER) 171 | agent_token_id = self.tokenizer.convert_tokens_to_ids(config.AGENT) 172 | user_idxs = (passages[0].sequence_ids == user_token_id).nonzero( 173 | as_tuple=True)[0] 174 | agent_idxs = (passages[0].sequence_ids == agent_token_id).nonzero( 175 | as_tuple=True)[0] 176 | 177 | user_idxs = [(u, 'u') for u in user_idxs] 178 | agent_idxs = [(a, 'a') for a in agent_idxs] 179 | party_idxs = sorted(agent_idxs + user_idxs, key=lambda x: x[0]) 180 | assert party_idxs[0][1] == 'u', \ 181 | 'Make sure the first index is from user.' 182 | user_idxs = [] 183 | user_starts = [] 184 | for i, p in enumerate(party_idxs[1:], start=1): 185 | if p[1] == 'u' and party_idxs[i-1][1] == 'a': 186 | user_idxs.append(i) 187 | user_starts.append(p[0]) 188 | 189 | # +1 for the first user turn (current user turn). 190 | num_user_turns = len(user_idxs) + 1 191 | 192 | 193 | ret2 = collections.defaultdict(list) 194 | 195 | # [1] is the first start position of the current user turn. 196 | if self.data_name == 'dialdoc': 197 | ret2['user_starts'] = [ 198 | _pad(torch.tensor([1] + user_starts), 199 | -1, max_num_user_turns - num_user_turns) 200 | for _ in range(len(passages))] 201 | 202 | history_relevance = torch.full( 203 | (max_num_history_turns,), -1, dtype=torch.long) 204 | if self.data_name == 'dialdoc': 205 | user2agent_relevance = torch.full( 206 | (max_num_user_turns,), -1, dtype=torch.long) 207 | 208 | if self.data_name == 'dialdoc': 209 | if self.is_train: 210 | user2agent_relevance[0] = 0 211 | 212 | for i, p in enumerate(passages): 213 | # history 214 | h_answer_starts, h_answer_ends = [], [] 215 | for j, (span, has_answer) in enumerate( 216 | zip(p.history_answers_spans, p.history_has_answers)): 217 | if has_answer: 218 | h_answer_starts.append(span[0]) 219 | h_answer_ends.append(span[1]) 220 | history_relevance[j] = i 221 | else: 222 | h_answer_starts.append(-1) 223 | h_answer_ends.append(-1) 224 | h_answer_starts = torch.tensor(h_answer_starts, dtype=torch.long) 225 | h_answer_ends = torch.tensor(h_answer_ends, dtype=torch.long) 226 | h_answer_starts = _pad( 227 | h_answer_starts, -1, max_num_history_turns - num_history_turns) 228 | h_answer_ends = _pad( 229 | h_answer_ends, -1, max_num_history_turns - num_history_turns) 230 | ret2['history_answer_starts'].append(h_answer_starts) 231 | ret2['history_answer_ends'].append(h_answer_ends) 232 | 233 | if self.data_name == 'dialdoc': 234 | if i == 0 and self.is_train: 235 | user2agent_answer_starts = [positive.answers_spans[0][0]] 236 | user2agent_answer_ends = [positive.answers_spans[0][1]] 237 | else: 238 | user2agent_answer_starts = [-1] 239 | user2agent_answer_ends = [-1] 240 | 241 | for j, u_i in enumerate(user_idxs, start=1): 242 | has_answer = p.history_has_answers[u_i-1] 243 | if has_answer: 244 | span = p.history_answers_spans[u_i-1] 245 | user2agent_answer_starts.append(span[0]) 246 | user2agent_answer_ends.append(span[1]) 247 | user2agent_relevance[j] = i 248 | else: 249 | 
user2agent_answer_starts.append(-1) 250 | user2agent_answer_ends.append(-1) 251 | user2agent_answer_starts = torch.tensor( 252 | user2agent_answer_starts, dtype=torch.long) 253 | user2agent_answer_ends = torch.tensor( 254 | user2agent_answer_ends, dtype=torch.long) 255 | user2agent_answer_starts = _pad( 256 | user2agent_answer_starts, 257 | -1, max_num_user_turns - num_user_turns) 258 | user2agent_answer_ends = _pad( 259 | user2agent_answer_ends, 260 | -1, max_num_user_turns - num_user_turns) 261 | ret2['user2agent_answer_starts'].append(user2agent_answer_starts) 262 | ret2['user2agent_answer_ends'].append(user2agent_answer_ends) 263 | 264 | 265 | ret['history_relevance'] = history_relevance 266 | if self.data_name == 'dialdoc': 267 | ret['user2agent_relevance'] = user2agent_relevance 268 | 269 | # Gets inputs. 270 | for i, p in enumerate(passages): 271 | seq_len = len(p.sequence_ids) 272 | pad_len = max_seq_len - seq_len 273 | ret2['input_ids'].append( 274 | _pad(p.sequence_ids, self.tokenizer.pad_token_id, pad_len)) 275 | ret2['type_ids'].append(_pad(p.sequence_type_ids, 0, pad_len)) 276 | ret2['passage_positions'].append( 277 | torch.tensor(p.position, dtype=torch.long)) 278 | 279 | pad_len = max_num_spans - len(p.clss) 280 | ret2['clss'].append(_pad(p.clss, 0, pad_len)) 281 | ret2['ends'].append(_pad(p.ends, 0, pad_len)) 282 | ret2['mask_cls'].append(_pad(p.mask_cls, 0, pad_len)) 283 | 284 | pad_len = max_num_history_turns - len(p.question_boundaries) 285 | ret2['question_boundaries'].append(_pad(p.question_boundaries, -1, pad_len)) 286 | 287 | 288 | # Gets the special attention mask. 289 | if self.special_attention: 290 | user_token_id = self.tokenizer.convert_tokens_to_ids(config.USER) 291 | agent_token_id = self.tokenizer.convert_tokens_to_ids(config.AGENT) 292 | user_idxs = (ret2['input_ids'][-1] == user_token_id).nonzero( 293 | as_tuple=True)[0] 294 | agent_idxs = (ret2['input_ids'][-1] == agent_token_id).nonzero( 295 | as_tuple=True)[0] 296 | 297 | # 0 is for CLS 298 | special_idxs = [torch.tensor([0]), user_idxs] 299 | # agent_idx can be empty 300 | if agent_idxs.nelement() != 0: 301 | special_idxs.append(agent_idxs) 302 | text_token_id = self.tokenizer.convert_tokens_to_ids(config.TEXT) 303 | parent_title_token_id = self.tokenizer.convert_tokens_to_ids( 304 | config.PARENT_TITLE) 305 | text_idx = (ret2['input_ids'][-1] == text_token_id).nonzero( 306 | as_tuple=True)[0] 307 | parent_title_idx = ( 308 | ret2['input_ids'][-1] == parent_title_token_id).nonzero( 309 | as_tuple=True)[0] 310 | if parent_title_idx.nelement() == 0: 311 | passage_first_idx = text_idx 312 | else: 313 | passage_first_idx = parent_title_idx 314 | 315 | assert passage_first_idx.nelement() != 0, \ 316 | (f'text_idx = {text_idx.nelement()}, ' 317 | f'parent_title_idx = {parent_title_idx.nelement()}.') 318 | 319 | special_idxs.append(passage_first_idx) 320 | special_idxs = torch.sort(torch.cat(special_idxs))[0] 321 | 322 | attention_mask = [] 323 | for j in range(len(special_idxs)): 324 | m = torch.arange(max_seq_len) 325 | 326 | # passage can attend history, so start from 0 327 | if (j == len(special_idxs) - 1 328 | and self.passage_attend_history): 329 | m = ((m >= 0) & (m < seq_len)).long() 330 | else: 331 | m = ((m >= special_idxs[j]) & (m < seq_len)).long() 332 | 333 | # we want tokens within the range have the same 334 | # attention_mask, so we use repeat. 
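# Illustrative sketch of the mask built below (toy numbers, not from the source):
# suppose seq_len = 8, max_seq_len = 10 and special_idxs = [0, 2, 5]
# ([CLS], the user token, the passage-start token). Rows for tokens 0-1 then
# attend positions 0..7, rows for tokens 2-4 attend 2..7, and rows for tokens
# 5-7 attend 5..7, or 0..7 when passage_attend_history is set. The repeat below
# copies one such row for every token of its segment; torch.cat plus the zero
# rows appended afterwards yields one (max_seq_len x max_seq_len) mask per passage.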
335 | if j < len(special_idxs)-1: 336 | repeat = special_idxs[j+1] - special_idxs[j] 337 | else: 338 | repeat = seq_len - special_idxs[j] 339 | 340 | m = m.repeat(repeat, 1) 341 | attention_mask.append(m) 342 | 343 | attention_mask.append( 344 | torch.zeros( 345 | (max_seq_len - seq_len, max_seq_len),dtype=torch.long)) 346 | ret2['attention_mask'].append(torch.cat(attention_mask)) 347 | 348 | name2fill_value = { 349 | 'input_ids': self.tokenizer.pad_token_id, 350 | 'type_ids': 0, 351 | 'passage_positions': 29, 352 | 'question_boundaries': -1, 353 | 'attention_mask': 0, # for special attention 354 | 'user_starts': -1, 355 | } 356 | 357 | pad_len = passages_per_question - len(passages) 358 | for k, v in ret2.items(): 359 | v = torch.stack(v) 360 | fill_value = name2fill_value.get(k, -1) 361 | ret2[k] = _pad(v, fill_value, pad_len) 362 | 363 | if not self.special_attention: 364 | assert 'attention_mask' not in ret2 365 | ret2['attention_mask'] = ( 366 | ret2['input_ids'] != self.tokenizer.pad_token_id).long() 367 | 368 | ret.update(ret2) 369 | 370 | return ret -------------------------------------------------------------------------------- /models/reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor as T 6 | 7 | from data_utils import utils as d_utils 8 | from models import loss 9 | from models import perturbation 10 | from utils import model_utils 11 | 12 | logger = logging.getLogger() 13 | 14 | 15 | def _pad_to_len(seq: T, pad_id: int, max_len: int): 16 | s_len = seq.size(0) 17 | if s_len > max_len: 18 | return seq[0: max_len] 19 | return torch.cat([seq, torch.Tensor().new_full( 20 | (max_len - s_len,), pad_id, dtype=torch.long).to(seq.device)], dim=0) 21 | 22 | 23 | class Reader(nn.Module): 24 | 25 | def __init__(self, args, encoder): 26 | super(Reader, self).__init__() 27 | self.args = args 28 | self.encoder = encoder 29 | 30 | hidden_size = encoder.config.hidden_size 31 | self.hidden_size = hidden_size 32 | 33 | if args.compute_da_loss: 34 | # TODO: replace hard-coded da label number 35 | self.da_classifier = nn.Linear(hidden_size, 7) 36 | self.segment_transform_da = nn.Linear(hidden_size, hidden_size) 37 | self.history_da_classifier = nn.Linear(hidden_size, 7) 38 | model_utils.init_weights( 39 | [self.da_classifier, 40 | self.history_da_classifier, 41 | self.segment_transform_da]) 42 | 43 | to_init = [] 44 | if args.hist_loss_weight > 0: 45 | self.segment_transform_start = nn.Linear(hidden_size, hidden_size) 46 | self.segment_transform_end = nn.Linear(hidden_size, hidden_size) 47 | self.segment_transform_relevance = nn.Linear( 48 | hidden_size, hidden_size) 49 | self.history_qa_classifier = nn.Linear(hidden_size, 1) 50 | 51 | to_init.extend( 52 | [ 53 | self.history_qa_classifier, 54 | self.segment_transform_start, 55 | self.segment_transform_end, 56 | self.segment_transform_relevance, 57 | ] 58 | ) 59 | 60 | if args.user2agent_loss_weight > 0: 61 | self.user2agent_transform_start = nn.Linear( 62 | hidden_size, hidden_size) 63 | self.user2agent_transform_end = nn.Linear(hidden_size, hidden_size) 64 | self.user2agent_transform_relevance = nn.Linear( 65 | hidden_size, hidden_size) 66 | self.user2agent_qa_classifier = nn.Linear(hidden_size, 1) 67 | 68 | to_init.extend( 69 | [ 70 | self.user2agent_transform_start, 71 | self.user2agent_transform_end, 72 | self.user2agent_transform_relevance, 73 | self.user2agent_qa_classifier, 74 | ] 75 | ) 76 | 77 | if 
args.span_marker: 78 | assert args.decision_function == 3 79 | self.marker_embs = nn.Embedding(2, hidden_size) 80 | to_init.append(self.marker_embs) 81 | 82 | if to_init: 83 | model_utils.init_weights(to_init) 84 | 85 | if (args.decision_function == 0 86 | or args.decision_function == 2 87 | or args.decision_function == 3): 88 | self.qa_outputs = nn.Linear(hidden_size, 2) 89 | self.qa_classifier = nn.Linear(hidden_size, 1) 90 | 91 | elif args.decision_function == 1: 92 | self.W_k_u = nn.Linear(hidden_size, hidden_size) 93 | self.W_v_u = nn.Linear(hidden_size, hidden_size) 94 | 95 | self.W_k_a = nn.Linear(hidden_size, hidden_size) 96 | self.W_v_a = nn.Linear(hidden_size, hidden_size) 97 | 98 | model_utils.init_weights([self.W_k_u, self.W_v_u, 99 | self.W_k_a, self.W_v_a]) 100 | self.dropout = nn.Dropout(args.dropout) 101 | self.softmax = nn.Softmax(dim=-1) 102 | 103 | if self.args.use_z_attn: 104 | self.W_u_a = nn.Linear(hidden_size, hidden_size) 105 | self.W_u_u = nn.Linear(hidden_size, hidden_size) 106 | model_utils.init_weights([self.W_u_u, self.W_u_a]) 107 | 108 | self.qa_classifier = nn.Linear(hidden_size, 1) 109 | 110 | self.qa_outputs = nn.Linear(3 * hidden_size, 2) 111 | 112 | else: 113 | raise NotImplementedError( 114 | f'unknown decision function {args.decision_function}') 115 | 116 | model_utils.init_weights([self.qa_outputs, self.qa_classifier]) 117 | 118 | self._setup_adv_training() 119 | 120 | def _setup_adv_training(self): 121 | self.adv_teacher = None 122 | if self.args.adv_loss_weight > 0: 123 | self.adv_teacher = perturbation.SmartPerturbation( 124 | self.args.adv_epsilon, 125 | self.args.adv_step_size, 126 | self.args.adv_noise_var, 127 | self.args.adv_norm_p, 128 | self.args.adv_k, 129 | norm_level=self.args.adv_norm_level) 130 | 131 | 132 | def forward( 133 | self, 134 | batch, 135 | global_step, 136 | fwd_type='input_ids', 137 | inputs_embeds=None, 138 | end_task_only=False): 139 | 140 | # Notations: 141 | # (1) N - number of questions in a batch. 142 | # (2) M - number of passages per questions. 143 | # (3) L - sequence length. 144 | N, M, L = batch['input_ids'].size() 145 | N, M, Ls = batch['clss'].size() 146 | 147 | # TODO: check question_broundaries when not training 148 | question_boundaries = batch['question_boundaries'].view(N * M, -1, 2) 149 | if 'user_starts' in batch: 150 | user_starts = batch['user_starts'].view(N * M, -1, 1) 151 | 152 | input_ids = batch['input_ids'].view(N*M, L) 153 | passage_positions = batch['passage_positions'].view(N, M) 154 | 155 | if batch['attention_mask'].dim() == 4: 156 | attention_mask = batch['attention_mask'].view(N*M, L, L) 157 | else: 158 | attention_mask = batch['attention_mask'].view(N*M, L) 159 | clss = batch['clss'].view(N*M, Ls) 160 | ends = batch['ends'].view(N*M, Ls) 161 | mask_cls = batch['mask_cls'].view(N*M, Ls) 162 | type_ids = batch['type_ids'].view(N*M, L) 163 | 164 | if self.args.ignore_token_type: 165 | encoder_type_ids = None 166 | else: 167 | encoder_type_ids = batch['type_ids'].view(N*M, L) 168 | 169 | if fwd_type == 'get_embs': 170 | # Gets embeddings only. 171 | assert inputs_embeds is None 172 | return self.encoder.embeddings(input_ids, encoder_type_ids) 173 | else: 174 | # Skips input_ids to inputs_embeds. 
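# Note on the three fwd_type modes handled here (interpretation, hedged):
# 'input_ids' is the ordinary forward pass, 'get_embs' (above) returns the raw
# token embeddings without running the encoder stack, and 'inputs_embeds'
# re-runs the encoder on externally supplied embeddings; presumably this is the
# hook SmartPerturbation uses to forward adversarially perturbed embeddings.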
175 | if fwd_type == 'inputs_embeds': 176 | assert inputs_embeds is not None 177 | _input_ids = None 178 | elif fwd_type == 'input_ids': 179 | _input_ids = input_ids 180 | else: 181 | raise ValueError(f'fwd_type = {fwd_type} is not available') 182 | 183 | sequence_output, _pooled_output, _hidden_states = self.encoder( 184 | N, 185 | M, 186 | _input_ids, 187 | encoder_type_ids, 188 | passage_positions, 189 | attention_mask, 190 | inputs_embeds) 191 | 192 | # TODO: use batched_index_select 193 | span_start_embs = sequence_output[ 194 | torch.arange(sequence_output.size(0)).unsqueeze(1), clss] 195 | 196 | logits = {} 197 | if self.args.decision_function == 0: 198 | logits.update( 199 | self._forward_df0( 200 | N, 201 | M, 202 | sequence_output, 203 | _pooled_output, 204 | span_start_embs, 205 | question_boundaries, 206 | passage_positions, 207 | end_task_only)) 208 | elif self.args.decision_function == 1: 209 | logits.update( 210 | self._forward_df1( 211 | N, 212 | M, 213 | sequence_output, 214 | _pooled_output, 215 | type_ids, 216 | span_start_embs, 217 | question_boundaries, 218 | passage_positions, 219 | end_task_only)) 220 | elif self.args.decision_function == 2: 221 | logits.update( 222 | self._forward_df2( 223 | N, 224 | M, 225 | sequence_output, 226 | _pooled_output, 227 | span_start_embs, 228 | question_boundaries, 229 | passage_positions, 230 | user_starts, 231 | end_task_only)) 232 | elif self.args.decision_function == 3: 233 | if self.args.adv_loss_weight > 0: 234 | raise NotImplementedError 235 | 236 | logits.update( 237 | self._forward_df3( 238 | N, 239 | M, 240 | sequence_output, 241 | _pooled_output, 242 | span_start_embs, 243 | question_boundaries, 244 | passage_positions, 245 | mask_cls, 246 | global_step)) 247 | 248 | logits['start'] = logits['start'].view(N, M, Ls) 249 | logits['end'] = logits['end'].view(N, M, Ls) 250 | logits['relevance'] = logits['relevance'].view(N, M) 251 | 252 | others = {} 253 | if not end_task_only and self.args.adv_loss_weight > 0: 254 | adv_logits, emb_val, eff_perturb = self.adv_forward( 255 | batch, 256 | logits, 257 | global_step, 258 | self.args.adv_calc_logits_keys) 259 | logits.update(adv_logits) 260 | others['emb_val'] = emb_val 261 | others['eff_perturb'] = eff_perturb 262 | 263 | return logits, others 264 | 265 | def adv_forward(self, batch, logits, global_step, calc_logits_keys): 266 | assert self.adv_teacher is not None 267 | adv_logits, emb_val, eff_perturb = self.adv_teacher.forward( 268 | self, 269 | logits, 270 | batch, 271 | global_step, 272 | calc_logits_keys) 273 | return adv_logits, emb_val, eff_perturb 274 | 275 | def _forward_df0( 276 | self, 277 | N, 278 | M, 279 | sequence_output, 280 | _pooled_output, 281 | span_start_embs, 282 | question_boundaries, 283 | passage_positions, 284 | end_task_only): 285 | 286 | logits = {} 287 | 288 | if self.training and not end_task_only: 289 | start = self._get_question_boundary_start(question_boundaries) 290 | question_segments = model_utils.batched_index_select( 291 | sequence_output, 1, start) 292 | logits.update( 293 | self._calc_history_logits( 294 | N, M, sequence_output, span_start_embs, 295 | start, passage_positions)) 296 | logits.update( 297 | self._calc_da_logits(question_segments, sequence_output)) 298 | 299 | logits['start'], logits['end'] = self.qa_outputs( 300 | span_start_embs).split(1, dim=-1) 301 | logits['relevance'] = self.qa_classifier(_pooled_output) 302 | 303 | return logits 304 | 305 | def _forward_df2( 306 | self, 307 | N, 308 | M, 309 | sequence_output, 310 | 
_pooled_output, 311 | span_start_embs, 312 | question_boundaries, 313 | passage_positions, 314 | user_starts, 315 | end_task_only): 316 | logits = self._forward_df0(N, M, sequence_output, _pooled_output, 317 | span_start_embs, question_boundaries, 318 | passage_positions, end_task_only) 319 | 320 | if self.training and not end_task_only: 321 | # _get_question_boundary_start takes two idxs 322 | start = self._get_question_boundary_start(user_starts) 323 | logits.update( 324 | self._calc_user2agent_logits(N, M, sequence_output, 325 | span_start_embs, start, 326 | passage_positions)) 327 | 328 | return logits 329 | 330 | def _forward_df3( 331 | self, 332 | N, 333 | M, 334 | sequence_output, 335 | _pooled_output, 336 | span_start_embs, 337 | question_boundaries, 338 | passage_positions, 339 | mask_cls, 340 | global_step, 341 | ): 342 | logits = {} 343 | 344 | assert self.args.hist_loss_weight > 0 345 | 346 | start = self._get_question_boundary_start(question_boundaries) 347 | question_segments = model_utils.batched_index_select( 348 | sequence_output, 1, start) 349 | logits.update( 350 | self._calc_history_logits(N, M, sequence_output, span_start_embs, 351 | start, passage_positions)) 352 | logits.update(self._calc_da_logits(question_segments, sequence_output)) 353 | 354 | Ls = mask_cls.size(-1) 355 | mask_cls = mask_cls.view(N, M, Ls) 356 | max_num_history_questions = logits['history_relevance'].size(1) 357 | device = mask_cls.device 358 | 359 | history_start_logits = logits['history_start'].view(N, M, -1, Ls) 360 | history_end_logits = logits['history_end'].view(N, M, -1, Ls) 361 | 362 | # before: history_relevance_logits size = (N*M, max_num_history_questions, 1) 363 | # after: history_relevance_logits size = (N*max_num_history_questions, M, 1) 364 | history_relevance_logits = torch.cat([t.transpose(0, 1) for t in logits['history_relevance'].split(M, dim=0)], dim=0) 365 | 366 | # idxs size = (N*max_num_history_questions, M, 1) 367 | _, idxs = torch.sort(history_relevance_logits, dim=1, descending=True) 368 | top1 = idxs[:, 0].view(N, max_num_history_questions) 369 | 370 | marker_idxs = torch.zeros((N, M, Ls), dtype=torch.long) 371 | 372 | if global_step >= self.args.marker_after_steps: 373 | for n in range(N): 374 | for hq in range(max_num_history_questions): 375 | if self.args.skip_mark_last_user and hq == 0: 376 | continue 377 | passage_idx = top1[n][hq] 378 | p_start_logits = history_start_logits[n, passage_idx].tolist() 379 | p_end_logits = history_end_logits[n, passage_idx].tolist() 380 | start_index, end_index, _ = next(d_utils.start_end_finder(p_start_logits, p_end_logits, self.args.max_answer_length, None, mask_cls[n, passage_idx])) 381 | if start_index != -1 and end_index != -1: 382 | h = torch.arange(Ls) 383 | h = ((h >= start_index) & (h <= end_index)).long() 384 | marker_idxs[n, passage_idx] += h 385 | 386 | marker_idxs.clamp_(0, 1) 387 | marker_idxs = marker_idxs.to(device) 388 | 389 | # mark_embs size = (N, M, Ls, hidden_size) 390 | marker_embs = self.marker_embs(marker_idxs).view(N*M, Ls, -1) 391 | 392 | logits['start'], logits['end'] = self.qa_outputs( 393 | span_start_embs+marker_embs).split(1, dim=-1) 394 | logits['relevance'] = self.qa_classifier(_pooled_output) 395 | 396 | return logits 397 | 398 | def _forward_df1( 399 | self, 400 | N, 401 | M, 402 | sequence_output, 403 | _pooled_output, 404 | type_ids, 405 | span_start_embs, 406 | question_boundaries, 407 | passage_positions, 408 | end_task_only, 409 | ): 410 | 411 | logits = {} 412 | 413 | start =
self._get_question_boundary_start(question_boundaries) 414 | question_types = torch.gather(type_ids, 1, start) 415 | # question_segments size = (N*M, max_num_history_questions, 768) 416 | question_segments = model_utils.batched_index_select( 417 | sequence_output, 1, start) 418 | 419 | if self.training and not end_task_only: 420 | logits.update( 421 | self._calc_history_logits(N, M, sequence_output, 422 | span_start_embs, start, 423 | passage_positions)) 424 | logits.update( 425 | self._calc_da_logits(question_segments, sequence_output)) 426 | 427 | 428 | ### start next turn logits calculation 429 | span_embs_user = self._calc_ctx_span_emb_by_role_history( 430 | question_types, question_segments, _pooled_output, 431 | span_start_embs, is_agent=False) 432 | span_embs_agent = self._calc_ctx_span_emb_by_role_history( 433 | question_types, question_segments, _pooled_output, 434 | span_start_embs, is_agent=True) 435 | 436 | s = torch.cat( 437 | (span_start_embs, span_embs_user, span_embs_agent), dim=-1) 438 | 439 | logits['start'], logits['end'] = self.qa_outputs(s).split(1, dim=-1) 440 | 441 | logits['relevance'] = self.qa_classifier(_pooled_output) 442 | 443 | return logits 444 | 445 | 446 | def _calc_ctx_span_emb_by_role_history( 447 | self, 448 | question_types, 449 | question_segments, 450 | _pooled_output, 451 | span_start_embs, 452 | is_agent=True 453 | ): 454 | # Calculates question_segments and start_mask for user or agent. 455 | # role_segments size = (N*M, max_role_questions, 768). 456 | # role_start_index size = (N*M, max_role_questions) 457 | # list of index tensor of role turns for each passage. 458 | qtype = 2 if is_agent else 1 459 | 460 | role_start_index = [ 461 | (question_types[i, :] == qtype).nonzero(as_tuple=True)[0] 462 | for i in range(question_segments.size(0))] 463 | max_role_questions = max([t.size(0) for t in role_start_index]) 464 | role_start_index = torch.stack( 465 | [_pad_to_len(t, -1, max_role_questions) for t in role_start_index], 466 | dim=0) 467 | role_start_mask = role_start_index == -1 468 | role_start_index = role_start_index.masked_fill(role_start_mask, 0) 469 | role_segments = model_utils.batched_index_select( 470 | question_segments, 1, role_start_index) 471 | 472 | # k: (N*M, max_n_spans, hid) 473 | k = span_start_embs 474 | _, max_n_spans, _ = k.size() 475 | v = k 476 | n_questions = max_role_questions 477 | if not is_agent: 478 | n_questions = min(max_role_questions, 2) 479 | mask = ~role_start_mask 480 | 481 | for i in range(n_questions): 482 | 483 | idx = n_questions - i - 1 484 | # u: (N*M, 1, hid) 485 | u = role_segments[:, idx:idx+1, :] 486 | extended_pooled = _pooled_output.repeat(1, max_n_spans, 1) 487 | extended_u = u.repeat(1, max_n_spans, 1) 488 | 489 | if is_agent: 490 | if self.args.use_z_attn: 491 | _v = torch.relu(self.W_k_a(self.dropout(k)) 492 | + self.W_v_a(self.dropout(extended_pooled)) 493 | + self.W_u_a(self.dropout(extended_u))) 494 | else: 495 | _v = torch.relu(self.W_k_a(self.dropout(k)) 496 | + self.W_v_a(self.dropout(extended_pooled))) 497 | else: 498 | if self.args.use_z_attn: 499 | _v = torch.relu(self.W_k_u(self.dropout(k)) 500 | + self.W_v_u(self.dropout(extended_pooled)) 501 | + self.W_u_u(self.dropout(extended_u))) 502 | else: 503 | _v = torch.relu(self.W_k_u(self.dropout(k)) 504 | + self.W_v_u(self.dropout(extended_pooled))) 505 | 506 | # q = 1; g: (N*M, 1, max_n_spans) 507 | if self.args.use_z_attn: 508 | g = torch.sigmoid( 509 | torch.einsum('bqh,bsh->bqs', u, extended_pooled) 510 | + torch.einsum('bqh,bsh->bqs', u, 
k)) 511 | else: 512 | g = torch.sigmoid( 513 | torch.einsum('bqh,bsh->bqs', u, extended_pooled)) 514 | # set g to be 0 on padded turns 515 | g = g.mul( 516 | mask.type(g.type())[:, idx:idx+1, None].repeat(1, 1, max_n_spans)) 517 | g = g.squeeze(1).unsqueeze(-1).repeat(1, 1, self.hidden_size) 518 | v = v + g.mul(_v) 519 | v = v / torch.abs(torch.norm(v, dim=(2)))[:,:,None] 520 | return v 521 | 522 | 523 | def _calc_da_logits(self, question_segments, sequence_output): 524 | 525 | logits = {} 526 | if self.args.compute_da_loss: 527 | da_segment = torch.relu(self.segment_transform_da(question_segments)) 528 | # (N * M, max_num_history_questions, 7) 529 | logits['history_da'] = self.history_da_classifier(da_segment) 530 | # (N * M, 7) 531 | logits['da'] = self.da_classifier(sequence_output[:, 0, :]) 532 | 533 | return logits 534 | 535 | 536 | def _calc_history_logits( 537 | self, 538 | N, 539 | M, 540 | sequence_output, 541 | span_start_embs, 542 | question_boundary_start, 543 | passage_positions 544 | ): 545 | if self.args.hist_loss_weight == 0: 546 | return {} 547 | # question_segments size = (N*M, max_num_history_questions, 768). 548 | question_segments = model_utils.batched_index_select( 549 | sequence_output, 1, question_boundary_start) 550 | 551 | logits = {} 552 | # einsum notations: 553 | # (1) b: NM. 554 | # (2) q: max_num_history_questions. 555 | # (3) h: 768. 556 | # (4) s: max_seq_len. 557 | start_segments = self.segment_transform_start(question_segments) 558 | logits['history_start'] = torch.einsum( 559 | 'bqh,bsh->bqs', start_segments, span_start_embs) 560 | 561 | end_segments = self.segment_transform_end(question_segments) 562 | logits['history_end'] = torch.einsum( 563 | 'bqh,bsh->bqs', end_segments, span_start_embs) 564 | 565 | # segments size = (N*M, max_num_history_questions, 768) 566 | # segment_transform_relevance size = (768, 768) 567 | relevance_segment = torch.relu( 568 | self.segment_transform_relevance(question_segments)) 569 | relevance_segment = self.encoder.coordinate( 570 | N, M, relevance_segment, passage_positions) 571 | 572 | # history_qa_classifier size = (768, 1) 573 | # history_relevance_logits size = (N*M, max_num_history_questions, 1) 574 | logits['history_relevance'] = self.history_qa_classifier( 575 | relevance_segment) 576 | 577 | return logits 578 | 579 | def _calc_user2agent_logits( 580 | self, 581 | N, 582 | M, 583 | sequence_output, 584 | span_start_embs, 585 | question_boundary_start, 586 | passage_positions, 587 | ): 588 | 589 | if self.args.user2agent_loss_weight == 0: 590 | return {} 591 | # question_segments size = (N*M, max_num_history_questions, 768) 592 | question_segments = model_utils.batched_index_select( 593 | sequence_output, 1, question_boundary_start) 594 | 595 | logits = {} 596 | # einsum notations: 597 | # (1) b: NM. 598 | # (2) q: max_num_history_questions. 599 | # (3) h: 768. 600 | # (4) s: max_seq_len. 
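# Worked shape example for the einsums below (toy sizes): with N*M = 2,
# q = 3 history questions, s = 5 gathered span-start positions and h = 768,
#   start_segments:  (2, 3, 768)
#   span_start_embs: (2, 5, 768)
#   torch.einsum('bqh,bsh->bqs', start_segments, span_start_embs) -> (2, 3, 5),
# i.e. one start (or end) logit per (history question, span start) pair.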
601 | start_segments = self.user2agent_transform_start(question_segments) 602 | logits['user2agent_start'] = torch.einsum( 603 | 'bqh,bsh->bqs', start_segments, span_start_embs) 604 | 605 | end_segments = self.user2agent_transform_end(question_segments) 606 | logits['user2agent_end'] = torch.einsum( 607 | 'bqh,bsh->bqs', end_segments, span_start_embs) 608 | 609 | # segments size = (N*M, max_num_history_questions, 768) 610 | # segment_transform_relevance size = (768, 768) 611 | relevance_segment = torch.relu( 612 | self.user2agent_transform_relevance(question_segments)) 613 | relevance_segment = self.encoder.coordinate( 614 | N, M, relevance_segment, passage_positions) 615 | 616 | # history_qa_classifier size = (768, 1) 617 | # history_relevance_logits size = (N*M, max_num_history_questions, 1) 618 | logits['user2agent_relevance'] = self.user2agent_qa_classifier( 619 | relevance_segment) 620 | 621 | return logits 622 | 623 | 624 | def _get_question_boundary_start(self, question_boundaries): 625 | # start size = (N*M, max_num_history_questions) 626 | # sequence_output size = (N*M, max_num_history_questions, 768) 627 | start = question_boundaries[:, :, 0] 628 | start_mask = start == -1 629 | start = start.masked_fill(start_mask, 0) 630 | return start 631 | 632 | 633 | def compute_loss(args, logits, batch, others): 634 | 635 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 636 | 637 | N, M, L = batch['input_ids'].size() 638 | device = logits['start'].device 639 | dtype = logits['start'].dtype 640 | 641 | answer_starts = batch['answer_starts'].view(N * M, -1) 642 | answer_ends = batch['answer_ends'].view(N * M, -1) 643 | 644 | # (N*M) * Ls. 645 | start_logits = logits['start'].view(N * M, -1) 646 | end_logits = logits['end'].view(N * M, -1) 647 | 648 | # Next turn loss 649 | relevance_labels = torch.zeros(N, dtype=torch.long).to(device) 650 | passage_loss = loss_fct(logits['relevance'], relevance_labels) 651 | 652 | num_answers = answer_starts.size(1) 653 | start_loss = loss_fct(start_logits.repeat(num_answers, 1), answer_starts.T.reshape(-1)) 654 | end_loss = loss_fct(end_logits.repeat(num_answers, 1), answer_ends.T.reshape(-1)) 655 | span_loss = start_loss + end_loss 656 | 657 | 658 | # Dialog act (da) loss. 659 | if args.compute_da_loss: 660 | da_logits = logits['da'] 661 | history_da_logits = logits['history_da'] 662 | 663 | da_logits = da_logits.view(N * M, 7) 664 | # da_label: (N, M) 665 | da_loss = loss_fct(da_logits, batch['da_label'].view(-1)) 666 | history_da_logits = history_da_logits.view(-1, 7) 667 | # da_label: (N, M, max_turns) 668 | history_da_loss = loss_fct(history_da_logits, batch['history_da_label'].view(-1)) 669 | 670 | # History loss. 
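# Sketch of the shapes used in this block (inferred from the comments below):
# the span losses flatten (N*M, Q, Ls) logits against (N, M, Q) start/end targets;
# the relevance loss first regroups the (N*M, Q, 1) logits into (N*Q, M) via
# split/transpose/cat, so each history question is scored across the M candidate
# passages, with batch['history_relevance'] giving the index of the passage that
# contains that turn's answer (-1 entries are ignored by the loss).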
671 | if args.hist_loss_weight > 0: 672 | max_num_spans = logits['history_start'].size(2) 673 | # history_start_logits size = (N*M, max_num_history_questions, Ls) 674 | history_start_loss = loss_fct( 675 | logits['history_start'].view(-1, max_num_spans), 676 | batch['history_answer_starts'].view(-1)) 677 | history_end_loss = loss_fct( 678 | logits['history_end'].view(-1, max_num_spans), 679 | batch['history_answer_ends'].view(-1)) 680 | 681 | # history_relevance_logits size = (N*M, max_num_history_questions, 1) 682 | # history_relevance size = (N, max_num_history_questions) 683 | max_num_history_questions = logits['history_relevance'].size(1) 684 | history_relevance_logits = torch.cat( 685 | [t.transpose(0, 1) 686 | for t in logits['history_relevance'].split(M, dim=0)], dim=0) 687 | history_passage_loss = loss_fct(history_relevance_logits.view( 688 | N*max_num_history_questions, M), batch['history_relevance'].view(-1)) 689 | 690 | history_span_loss = history_start_loss + history_end_loss 691 | 692 | # user2agent. 693 | if args.user2agent_loss_weight > 0: 694 | max_num_spans = logits['user2agent_start'].size(2) 695 | # history_start_logits size = (N*M, max_num_history_questions, Ls) 696 | user2agent_start_loss = loss_fct(logits['user2agent_start'].view( 697 | -1, max_num_spans), batch['user2agent_answer_starts'].view(-1)) 698 | user2agent_end_loss = loss_fct(logits['user2agent_end'].view( 699 | -1, max_num_spans), batch['user2agent_answer_ends'].view(-1)) 700 | 701 | # history_relevance_logits size = (N*M, max_num_history_questions, 1) 702 | # history_relevance size = (N, max_num_history_questions) 703 | max_num_user_questions = logits['user2agent_relevance'].size(1) 704 | user2agent_relevance_logits = torch.cat( 705 | [t.transpose(0, 1) 706 | for t in logits['user2agent_relevance'].split(M, dim=0)], dim=0) 707 | user2agent_passage_loss = loss_fct(user2agent_relevance_logits.view( 708 | N*max_num_user_questions, M), batch['user2agent_relevance'].view(-1)) 709 | 710 | user2agent_span_loss = user2agent_start_loss + user2agent_end_loss 711 | 712 | # Adv loss. 
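# Interpretation (hedged): this is a consistency term between the clean logits
# and the logits recomputed from perturbed embeddings in Reader.adv_forward;
# args.adv_loss_type selects the divergence from models.loss.LOSS, and the term
# is zeroed whenever the perturbation step reports a negative emb_val.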
713 | if args.adv_loss_weight > 0: 714 | adv_loss = {} 715 | for k in args.adv_calc_logits_keys: 716 | if others['emb_val'] >= 0: 717 | adv_loss[f'adv_{k}'] = loss.LOSS[ 718 | args.adv_loss_type](logits[k], logits[f'adv_{k}']) 719 | else: 720 | adv_loss[f'adv_{k}'] = torch.tensor(0, dtype=dtype).to(device) 721 | # Total 722 | losses = { 723 | "start": start_loss, 724 | "end": end_loss, 725 | "span": span_loss, 726 | "passage": passage_loss} 727 | total_loss = span_loss + passage_loss 728 | 729 | if args.hist_loss_weight > 0: 730 | losses["history_span"] = history_span_loss 731 | losses["history_passage"] = history_passage_loss 732 | history_loss = history_span_loss + history_passage_loss 733 | total_loss += args.hist_loss_weight * history_loss 734 | 735 | if args.user2agent_loss_weight > 0: 736 | losses["user2agent_span"] = user2agent_span_loss 737 | losses["user2agent_passage"] = user2agent_passage_loss 738 | user2agent_loss = user2agent_span_loss + user2agent_passage_loss 739 | total_loss += args.user2agent_loss_weight * user2agent_loss 740 | 741 | if args.compute_da_loss: 742 | losses["da"] = da_loss 743 | losses["history_da"] = history_da_loss 744 | total_loss += (da_loss + history_da_loss) 745 | 746 | if args.adv_loss_weight > 0: 747 | losses.update(adv_loss) 748 | total_loss += args.adv_loss_weight * sum(adv_loss.values()) 749 | 750 | losses["total"] = total_loss 751 | return losses 752 | -------------------------------------------------------------------------------- /train_reader.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import os 4 | from typing import List 5 | import time 6 | import heapq 7 | 8 | import argparse 9 | import glob 10 | import logging 11 | import math 12 | import numpy as np 13 | import torch 14 | import transformers as tfs 15 | 16 | import config 17 | from data_utils import data_collator, reader_dataset 18 | from data_utils import utils as du 19 | from data_utils import data_class 20 | import eval 21 | import models 22 | from models import loss 23 | from utils import checkpoint 24 | from utils import dist_utils 25 | from utils import model_utils 26 | from utils import options 27 | from utils import sampler 28 | from utils import utils 29 | 30 | try: 31 | from apex import amp 32 | except: 33 | pass 34 | 35 | logger = logging.getLogger() 36 | logger.setLevel(logging.INFO) 37 | if logger.hasHandlers(): 38 | logger.handlers.clear() 39 | console = logging.StreamHandler() 40 | logger.addHandler(console) 41 | 42 | ReaderQuestionPredictions = collections.namedtuple( 43 | 'ReaderQuestionPredictions', 44 | ['id', 'gold_answers', 'passage_spans', 'passage_answers']) 45 | 46 | 47 | class ReaderTrainer(object): 48 | 49 | def __init__(self, args): 50 | 51 | utils.print_section_bar('Initializing components for training') 52 | 53 | self.topk_spans = 3 54 | self.topk_passages = 3 55 | self.topk_cp_info_filename = 'topk_cp_info.json' 56 | 57 | checkpoint_file = None 58 | if args.checkpoint_file: 59 | checkpoint_file = args.checkpoint_file 60 | else: 61 | # eval mode: get the best model automatically 62 | if args.train_file is None: 63 | checkpoint_file = os.path.join( 64 | args.output_dir, f'best_{args.best_metric}') 65 | checkpoint_file = os.path.join( 66 | args.output_dir, 67 | os.path.basename(os.readlink(checkpoint_file))) 68 | elif args.auto_resume: 69 | try: 70 | utils.print_section_bar( 71 | 'Auto resume from the latest checkpoint') 72 | if dist_utils.is_local_master(): 73 | 
logger.info(f'Checkpoint files {self.latest_cp_path}') 74 | if os.path.exists(self.latest_cp_path): 75 | checkpoint_file = self.latest_cp_path 76 | except Exception as e: 77 | logger.info(f'[Error] {e}') 78 | pass 79 | 80 | saved_state = None 81 | if checkpoint_file: 82 | assert os.path.exists(checkpoint_file), \ 83 | f'model does not exist! {checkpoint_file}' 84 | if dist_utils.is_local_master(): 85 | utils.print_section_bar('Restore from checkpoint') 86 | logger.info(f'Checkpoint files {checkpoint_file}') 87 | saved_state = checkpoint.load_states_from_checkpoint(checkpoint_file) 88 | options.set_encoder_params_from_state( 89 | saved_state.encoder_params, args) 90 | 91 | tokenizer = tfs.AutoTokenizer.from_pretrained( 92 | args.pretrained_model_cfg) 93 | tokenizer.add_special_tokens( 94 | {'additional_special_tokens': config.TOKENS}) 95 | 96 | encoder = models.HFBertEncoder.init_encoder(args, len(tokenizer)) 97 | reader = models.Reader(args, encoder) 98 | 99 | if args.inference_only: 100 | optimizer = None 101 | else: 102 | optimizer = model_utils.get_optimizer( 103 | reader, 104 | learning_rate=args.learning_rate, 105 | adam_eps=args.adam_eps, 106 | weight_decay=args.weight_decay) 107 | 108 | self.reader, self.optimizer = model_utils.setup_for_distributed_mode( 109 | reader, 110 | optimizer, 111 | args.device, 112 | args.n_gpu, 113 | args.local_rank, 114 | args.fp16, 115 | args.fp16_opt_level) 116 | 117 | self.start_epoch = 0 118 | self.start_offset = 0 119 | self.global_step = 0 120 | self.args = args 121 | 122 | if args.train_file is not None: 123 | self.topk_cp_info = self._load_topk_checkpoints_info() 124 | 125 | if saved_state: 126 | self._load_saved_state(saved_state) 127 | 128 | self.saved_state = saved_state 129 | self.tokenizer = tokenizer 130 | 131 | @property 132 | def latest_softlink(self): 133 | return os.path.join(self.args.output_dir, 'latest') 134 | 135 | @property 136 | def latest_cp_path(self): 137 | try: 138 | return os.path.join( 139 | self.args.output_dir, 140 | os.readlink(self.latest_softlink)) 141 | except: 142 | pass 143 | return '' 144 | 145 | def get_train_dataloader(self, train_dataset, shuffle=True, offset=0): 146 | if torch.distributed.is_initialized(): 147 | train_sampler = sampler.DistributedSampler( 148 | train_dataset, 149 | num_replicas=self.args.distributed_world_size, 150 | rank=self.args.local_rank, 151 | shuffle=shuffle) 152 | train_sampler.set_offset(offset) 153 | else: 154 | assert self.args.local_rank == -1 155 | train_sampler = torch.utils.data.RandomSampler(train_dataset) 156 | 157 | train_data_collator = data_collator.DataCollator( 158 | self.args.data_name, 159 | self.tokenizer, 160 | self.args.max_seq_len, 161 | self.args.max_num_answers, 162 | self.args.passages_per_question, 163 | self.args.special_attention, 164 | self.args.passage_attend_history, 165 | is_train=True, 166 | shuffle=True) 167 | 168 | dataloader = torch.utils.data.DataLoader( 169 | train_dataset, 170 | batch_size=self.args.batch_size, 171 | pin_memory=True, 172 | sampler=train_sampler, 173 | num_workers=0, 174 | collate_fn=train_data_collator, 175 | drop_last=False) 176 | 177 | return dataloader 178 | 179 | def get_eval_data_loader(self, eval_dataset): 180 | if torch.distributed.is_initialized(): 181 | eval_sampler = sampler.SequentialDistributedSampler( 182 | eval_dataset, 183 | num_replicas=self.args.distributed_world_size, 184 | rank=self.args.local_rank) 185 | else: 186 | assert self.args.local_rank == -1 187 | eval_sampler = 
torch.utils.data.SequentialSampler(eval_dataset) 188 | 189 | eval_data_collator = data_collator.DataCollator( 190 | self.args.data_name, 191 | self.tokenizer, 192 | self.args.max_seq_len, 193 | self.args.max_num_answers, 194 | self.args.passages_per_question_predict, 195 | self.args.special_attention, 196 | self.args.passage_attend_history, 197 | is_train=False, 198 | shuffle=False) 199 | 200 | dataloader = torch.utils.data.DataLoader( 201 | eval_dataset, 202 | batch_size=self.args.dev_batch_size, 203 | pin_memory=True, 204 | sampler=eval_sampler, 205 | num_workers=0, 206 | collate_fn=eval_data_collator, 207 | drop_last=False) 208 | 209 | return dataloader 210 | 211 | def run_train(self): 212 | args = self.args 213 | 214 | train_dataset = reader_dataset.ReaderDataset(args.train_file) 215 | train_dataloader = self.get_train_dataloader( 216 | train_dataset, 217 | shuffle=True, 218 | offset=self.start_offset) 219 | 220 | updates_per_epoch = math.ceil( 221 | len(train_dataloader) / args.gradient_accumulation_steps) 222 | total_updates = updates_per_epoch * args.num_train_epochs 223 | 224 | dataloader_steps = self.start_offset // ( 225 | args.distributed_world_size * args.batch_size) 226 | updated_steps = (dataloader_steps // 227 | args.gradient_accumulation_steps) + ( 228 | self.start_epoch * updates_per_epoch) 229 | remaining_updates = total_updates - updated_steps 230 | 231 | # global_step is added per dataloader step. 232 | calc_global_step = (self.start_epoch * len(train_dataloader) + 233 | dataloader_steps) 234 | 235 | assert self.global_step == calc_global_step, \ 236 | (f'global step = {self.global_step}, ' 237 | f'calc global step = {calc_global_step}') 238 | 239 | self.scheduler = model_utils.get_schedule_linear( 240 | self.optimizer, 241 | warmup_steps=args.warmup_steps, 242 | training_steps=total_updates, 243 | last_epoch=self.global_step-1) 244 | 245 | if self.saved_state: 246 | if self.saved_state.scheduler_dict: 247 | if dist_utils.is_local_master(): 248 | logger.info(f'Loading scheduler state ...') 249 | self.scheduler.load_state_dict(self.saved_state.scheduler_dict) 250 | 251 | utils.print_section_bar('Training') 252 | if dist_utils.is_local_master(): 253 | logger.info(f'Total updates = {total_updates}') 254 | logger.info( 255 | f'Updates per epoch (/gradient accumulation) = ' 256 | f'{updates_per_epoch}') 257 | logger.info( 258 | f'Steps per epoch (dataloader) = {len(train_dataloader)}') 259 | logger.info( 260 | f'Gradient accumulation steps = ' 261 | f'{args.gradient_accumulation_steps}') 262 | logger.info( 263 | f'Start offset of the epoch {self.start_epoch} (dataset) = ' 264 | f'step {self.start_offset}') 265 | logger.info( 266 | f'Updated step of the epoch {self.start_epoch} (dataloader) = ' 267 | f'step {updated_steps}') 268 | logger.info( 269 | f'Total remaining updates = {remaining_updates}') 270 | 271 | # Starts training here. 
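# Worked example of the resume arithmetic above (hypothetical numbers): with
# len(train_dataloader) = 1000, gradient_accumulation_steps = 4,
# num_train_epochs = 3, distributed_world_size * batch_size = 16,
# start_epoch = 1 and start_offset = 320:
#   updates_per_epoch = ceil(1000 / 4) = 250, total_updates = 750
#   dataloader_steps  = 320 // 16 = 20
#   updated_steps     = 20 // 4 + 1 * 250 = 255, remaining_updates = 495
#   calc_global_step  = 1 * 1000 + 20 = 1020 (global_step counts dataloader steps).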
272 | for epoch in range(self.start_epoch, int(args.num_train_epochs)): 273 | utils.print_section_bar(f'Epoch {epoch}') 274 | 275 | if isinstance(train_dataloader.sampler, sampler.DistributedSampler): 276 | train_dataloader.sampler.set_epoch(epoch) 277 | 278 | self._train_epoch(epoch, train_dataloader) 279 | 280 | if isinstance(train_dataloader.sampler, sampler.DistributedSampler): 281 | train_dataloader.sampler.set_offset(0) 282 | 283 | utils.print_section_bar('Training finished.') 284 | if dist_utils.is_local_master(): 285 | best_em = -self.topk_cp_info['EM'][0][0] 286 | best_f1 = -self.topk_cp_info['F1'][0][0] 287 | best_em_path = self.topk_cp_info['EM'][0][1] 288 | best_f1_path = self.topk_cp_info['F1'][0][1] 289 | logger.info(f'Best EM {best_em * 100:.2f} path = {best_em_path}') 290 | logger.info(f'Best F1 {best_f1 * 100:.2f} path = {best_f1_path}') 291 | 292 | return 293 | 294 | def validate_and_save(self, epoch, offset): 295 | curr_em, curr_f1 = self.validate() 296 | 297 | args = self.args 298 | 299 | if dist_utils.is_local_master(): 300 | cp_path = self._save_checkpoint(epoch, offset, dry_run=True) 301 | 302 | # Uses min heap, so add a negative 303 | if curr_em > -self.topk_cp_info['EM'][0][0]: 304 | logger.info(f'New best EM {curr_em*100:.2f} on dev') 305 | self._save_checkpoint(epoch, offset) 306 | utils.softlink( 307 | cp_path, 308 | os.path.join(args.output_dir, 'best_em')) 309 | 310 | if curr_f1 > -self.topk_cp_info['F1'][0][0]: 311 | logger.info(f'New best F1 {curr_f1*100:.2f} on dev') 312 | self._save_checkpoint(epoch, offset) 313 | utils.softlink( 314 | cp_path, 315 | os.path.join(args.output_dir, 'best_f1')) 316 | 317 | heapq.heappush(self.topk_cp_info['EM'], (-curr_em, cp_path)) 318 | heapq.heappush(self.topk_cp_info['F1'], (-curr_f1, cp_path)) 319 | tmp = [] 320 | for _ in range(min(args.topk_em, len(self.topk_cp_info['EM']))): 321 | heapq.heappush(tmp, heapq.heappop(self.topk_cp_info['EM'])) 322 | self.topk_cp_info['EM'] = tmp 323 | tmp = [] 324 | for _ in range(min(args.topk_f1, len(self.topk_cp_info['F1']))): 325 | heapq.heappush(tmp, heapq.heappop(self.topk_cp_info['F1'])) 326 | self.topk_cp_info['F1']= tmp 327 | 328 | best_em = -self.topk_cp_info['EM'][0][0] 329 | best_f1 = -self.topk_cp_info['F1'][0][0] 330 | best_em_path = self.topk_cp_info['EM'][0][1] 331 | best_f1_path = self.topk_cp_info['F1'][0][1] 332 | logger.info(f'Curr EM {curr_em * 100:.2f}') 333 | logger.info(f'Curr F1 {curr_f1 * 100:.2f}') 334 | logger.info(f'Best EM {best_em * 100:.2f} path = {best_em_path}') 335 | logger.info(f'Best F1 {best_f1 * 100:.2f} path = {best_f1_path}') 336 | 337 | self._save_topk_checkpoints_info() 338 | 339 | all_saved_cps = checkpoint.get_saved_checkpoints( 340 | args, args.checkpoint_filename_prefix) 341 | keep_cps = (set(c[1] for c in self.topk_cp_info['EM']) 342 | | set(c[1] for c in self.topk_cp_info['F1'])) 343 | keep_cps = set( 344 | [os.path.join(args.output_dir, cp) for cp in keep_cps]) 345 | if self.latest_cp_path: 346 | keep_cps.update([self.latest_cp_path]) 347 | for cp in all_saved_cps: 348 | if cp not in keep_cps: 349 | os.remove(cp) 350 | 351 | def validate(self): 352 | if dist_utils.is_local_master(): 353 | logger.info('Validation ...') 354 | 355 | args = self.args 356 | topk_passages = self.topk_passages 357 | eval_dataset = reader_dataset.ReaderDataset(args.dev_file) 358 | eval_dataloader = self.get_eval_data_loader(eval_dataset) 359 | 360 | all_results = [] 361 | validate_batch_times = [] 362 | for step, batch in enumerate(eval_dataloader): 363 | 
self.reader.eval() 364 | step += 1 365 | 366 | if step % 100 == 0 and dist_utils.is_local_master(): 367 | logger.info( 368 | f'Eval step {step} / {len(eval_dataloader)}; ' 369 | f'eval time per batch = {np.mean(validate_batch_times):.2f}') 370 | 371 | batch = model_utils.move_to_device(batch, args.device) 372 | 373 | start_time = time.time() 374 | if args.local_rank != -1: 375 | # Uses DDP. 376 | with self.reader.no_sync(), torch.no_grad(): 377 | logits, _ = self.reader( 378 | batch, self.global_step, end_task_only=True) 379 | else: 380 | # Uses a single GPU. 381 | with torch.no_grad(): 382 | logits, _ = self.reader( 383 | batch, self.global_step, end_task_only=True) 384 | end_time = time.time() 385 | validate_batch_times.append(end_time - start_time) 386 | 387 | batch_predictions = self._get_best_prediction( 388 | logits['start'], 389 | logits['end'], 390 | logits['relevance'], 391 | batch['samples']) 392 | 393 | all_results.extend(batch_predictions) 394 | 395 | # Deletes output of the current iteration to save memory. 396 | del logits 397 | 398 | all_passage_f1s = [] 399 | all_passage_at_k = [] 400 | all_passage_em_at_k = [] 401 | for pred in all_results: 402 | # we only have a single answer 403 | gold_answer = pred.gold_answers[0] 404 | 405 | passage_at_k = utils.convert_to_at_k(pred.passage_answers) 406 | passage_em_at_k = [] 407 | for topk_passage_idx in range(topk_passages): 408 | 409 | spans = pred.passage_spans[topk_passage_idx] 410 | 411 | ems = [] 412 | for s_i, s in enumerate(spans): 413 | if s is not None: 414 | em = eval.compute_exact(gold_answer, s.prediction_text) 415 | em = {True: 1, False: 0}[em] 416 | ems.append(em) 417 | else: 418 | # an empty span will be counted as an empty string 419 | ems.append(0) 420 | 421 | if topk_passage_idx == 0: 422 | f1 = 0.0 423 | if spans and spans[0] is not None: 424 | f1 = eval.compute_f1( 425 | gold_answer, spans[0].prediction_text) 426 | all_passage_f1s.append(f1) 427 | 428 | em_at_k = utils.convert_to_at_k(ems) 429 | passage_em_at_k.append(em_at_k) 430 | 431 | all_passage_at_k.append(passage_at_k) 432 | all_passage_em_at_k.append(passage_em_at_k) 433 | 434 | # Gathers results from other GPUs 435 | limit = len(eval_dataset) 436 | all_passage_at_k = torch.cat( 437 | dist_utils.all_gather(all_passage_at_k)).int().numpy()[:limit] 438 | all_passage_em_at_k = torch.cat( 439 | dist_utils.all_gather(all_passage_em_at_k)).int().numpy()[:limit] 440 | all_passage_f1s = torch.cat( 441 | dist_utils.all_gather(all_passage_f1s)).float().numpy()[:limit] 442 | 443 | avg_passage_at_k = np.ma.masked_where( 444 | all_passage_at_k == -1, all_passage_at_k).mean(axis=0).tolist() 445 | avg_passage_em_at_k = np.ma.masked_where( 446 | all_passage_em_at_k == -1, all_passage_em_at_k).mean(axis=0).tolist() 447 | avg_passage_f1 = float(np.mean(all_passage_f1s)) 448 | 449 | if dist_utils.is_local_master(): 450 | passage_at_k_str = '' 451 | passage_at_k_dic = {} 452 | for k_i, acc in enumerate(avg_passage_at_k): 453 | n = f'Passage@{k_i+1}' 454 | if acc is None: 455 | passage_at_k_str += f'{n} = 0.0; ' 456 | passage_at_k_dic[n] = 0.0 457 | else: 458 | passage_at_k_str += f'{n} = {acc * 100:.2f}; ' 459 | passage_at_k_dic[n] = acc 460 | 461 | passage_em_at_k_str = '' 462 | passage_em_at_k_dic = {} 463 | for p, em_at_k in enumerate(avg_passage_em_at_k): 464 | n_p = f'Rank {p+1} passage' 465 | passage_em_at_k_str += f'{n_p}: ' 466 | for k_i, acc in enumerate(em_at_k): 467 | n_em = f'EM@{k_i+1}' 468 | passage_em_at_k_str += f'{n_em} = {acc * 100:.2f}; ' 469 | 
passage_em_at_k_dic[f'{n_p} {n_em}'] = acc 470 | passage_em_at_k_str += '\n' 471 | 472 | logger.info(f'eval_top_docs = {args.eval_top_docs[0]}') 473 | logger.info(f'F1 = {avg_passage_f1*100:.2f}') 474 | logger.info(passage_at_k_str) 475 | logger.info(passage_em_at_k_str) 476 | 477 | # Gathers numerical data from other GPUs. 478 | # Strings should be obtained directly from eval_dataset. 479 | if args.prediction_results_file: 480 | self._save_predictions( 481 | args.prediction_results_file, all_results) 482 | 483 | return avg_passage_em_at_k[0][0], avg_passage_f1 484 | 485 | 486 | def _train_epoch(self, epoch, train_dataloader): 487 | args = self.args 488 | epoch_loss = 0 489 | rolling_train_losses = collections.defaultdict(int) 490 | rolling_train_others = collections.defaultdict(int) 491 | 492 | step_offset = 0 493 | # For restoring from a checkpoint. 494 | if train_dataloader.sampler.current_offset != 0: 495 | step_offset += (train_dataloader.sampler.current_offset // 496 | (args.distributed_world_size 497 | * args.batch_size)) 498 | 499 | train_batch_times = [] 500 | start_time = time.time() 501 | for step, batch in enumerate(train_dataloader, start=step_offset): 502 | self.reader.train() 503 | step += 1 504 | 505 | batch_start_time = time.time() 506 | if step % args.gradient_accumulation_steps != 0 \ 507 | and args.local_rank != -1: 508 | with self.reader.no_sync(): 509 | losses, others = self._training_step(batch) 510 | else: 511 | losses, others = self._training_step(batch) 512 | batch_end_time = time.time() 513 | train_batch_times.append(batch_end_time - batch_start_time) 514 | 515 | self.global_step += 1 516 | 517 | # Saves latest checkpoint every X minutes. 518 | if dist_utils.is_local_master(): 519 | now_time = time.time() 520 | time_diff = now_time - start_time 521 | # Converts seconds to minutes. 
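# Reading of the check below (illustrative): with save_checkpoint_every_minutes = 30,
# the condition holds once time_diff falls in the 30-60 minute window; start_time
# is then reset, so at most one "latest" checkpoint is written per interval.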
522 | if time_diff // \ 523 | (60 * args.save_checkpoint_every_minutes) == 1: 524 | logger.info( 525 | f'Save checkpoint every ' 526 | f'{args.save_checkpoint_every_minutes} minutes.') 527 | dataset_offset = (step 528 | * args.distributed_world_size 529 | * args.batch_size) 530 | cp_path = self._save_checkpoint(epoch, dataset_offset) 531 | if self.latest_cp_path: 532 | os.remove(self.latest_cp_path) 533 | utils.softlink(cp_path, self.latest_softlink) 534 | start_time = now_time 535 | 536 | ''' 537 | record loss 538 | ''' 539 | epoch_loss += losses['total'] 540 | for k, loss in losses.items(): 541 | rolling_train_losses[k] += loss 542 | for k, other in others.items(): 543 | # other could be -1 if adv_loss not applicable 544 | rolling_train_others[k] += max(other, 0) 545 | 546 | ''' 547 | parameters update 548 | ''' 549 | if (step - step_offset) % args.gradient_accumulation_steps == 0: 550 | if args.max_grad_norm > 0: 551 | if args.fp16: 552 | torch.nn.utils.clip_grad_norm_( 553 | amp.master_params(self.optimizer), args.max_grad_norm 554 | ) 555 | else: 556 | torch.nn.utils.clip_grad_norm_( 557 | self.reader.parameters(), args.max_grad_norm 558 | ) 559 | 560 | self.scheduler.step() 561 | self.optimizer.step() 562 | self.reader.zero_grad() 563 | 564 | if self.global_step % args.log_batch_step == 0: 565 | lr = self.optimizer.param_groups[0]['lr'] 566 | if dist_utils.is_local_master(): 567 | avg_batch_time = np.mean(train_batch_times) 568 | logger.info( 569 | f'Epoch: {epoch}: ' 570 | f'Step: {step}/{len(train_dataloader)}; ' 571 | f'Global_step={self.global_step}; ' 572 | f'lr={lr:.3e}; ' 573 | f'train time per batch = {avg_batch_time:.2f}') 574 | 575 | if (step - step_offset) % args.train_rolling_loss_step == 0: 576 | 577 | log_str = (f'Avg. loss and other in the recent ' 578 | f'{args.train_rolling_loss_step} batches: \n') 579 | for k, loss in rolling_train_losses.items(): 580 | loss /= args.train_rolling_loss_step 581 | loss = torch.cat( 582 | dist_utils.all_gather([loss])).mean().numpy() 583 | log_str += (' -' + f'{k:>21} loss: {loss:.4f}\n') 584 | 585 | for k, other in rolling_train_others.items(): 586 | other /= args.train_rolling_loss_step 587 | other = torch.cat( 588 | dist_utils.all_gather([other])).mean().numpy() 589 | log_str += (' -' + f'{k:>20} other: {other:.8f}\n') 590 | 591 | if dist_utils.is_local_master(): 592 | logger.info(f'Train: global step = {self.global_step}; ' 593 | f'step = {step}') 594 | logger.info(log_str) 595 | 596 | rolling_train_losses = collections.defaultdict(int) 597 | rolling_train_others = collections.defaultdict(int) 598 | 599 | if self.global_step % args.eval_step == 0: 600 | if dist_utils.is_local_master(): 601 | logger.info( 602 | f'Validation: Epoch: {epoch} ' 603 | f'Step: {step}/{len(train_dataloader)}') 604 | dataset_offset = (step 605 | * args.distributed_world_size 606 | * args.batch_size) 607 | self.validate_and_save(epoch, dataset_offset) 608 | 609 | epoch_loss = epoch_loss / len(train_dataloader) 610 | 611 | if dist_utils.is_local_master(): 612 | logger.info(f'Avg. 
total Loss of epoch {epoch} ={epoch_loss:.3f}') 613 | 614 | def _save_topk_checkpoints_info(self): 615 | dic = {} 616 | dic['EM'] = [(-c[0], c[1])for c in self.topk_cp_info['EM']] 617 | dic['F1'] = [(-c[0], c[1])for c in self.topk_cp_info['F1']] 618 | path = os.path.join( 619 | self.args.output_dir, self.topk_cp_info_filename) 620 | with open(path, 'w') as f: 621 | json.dump(dic, f, indent=4) 622 | 623 | def _load_topk_checkpoints_info(self): 624 | path = os.path.join( 625 | self.args.output_dir, self.topk_cp_info_filename) 626 | if os.path.exists(path): 627 | with open(path, 'r') as f: 628 | dic = json.load(f) 629 | dic['EM'] = [(-c[0], c[1])for c in dic['EM']] 630 | dic['F1'] = [(-c[0], c[1])for c in dic['F1']] 631 | else: 632 | dic = {} 633 | dic['EM'] = [(0, '')] 634 | dic['F1'] = [(0, '')] 635 | return dic 636 | 637 | def _save_checkpoint(self, epoch, offset, dry_run=False): 638 | cp_path = os.path.join( 639 | self.args.output_dir, 640 | '.'.join( 641 | [ 642 | self.args.checkpoint_filename_prefix, 643 | str(epoch), 644 | str(offset), 645 | str(self.global_step), 646 | ] 647 | ) 648 | ) 649 | 650 | if dry_run: 651 | return os.path.basename(cp_path) 652 | # file already saved! 653 | if os.path.exists(cp_path): 654 | return os.path.basename(cp_path) 655 | 656 | logger.info(f'Saved checkpoint to {cp_path}') 657 | 658 | model_to_save = model_utils.get_model_obj(self.reader) 659 | meta_params = options.get_encoder_params_state(self.args) 660 | state = checkpoint.CheckpointState( 661 | model_to_save.state_dict(), 662 | self.optimizer.state_dict(), 663 | self.scheduler.state_dict(), 664 | amp.state_dict(), 665 | offset, 666 | epoch, 667 | self.global_step, 668 | meta_params, 669 | ) 670 | 671 | torch.save(state._asdict(), cp_path) 672 | return os.path.basename(cp_path) 673 | 674 | def _load_saved_state(self, saved_state: checkpoint.CheckpointState): 675 | epoch = saved_state.epoch 676 | offset = saved_state.offset 677 | global_step = saved_state.global_step 678 | if offset == 0: # epoch has been completed 679 | epoch += 1 680 | 681 | if dist_utils.is_local_master(): 682 | logger.info( 683 | f'Loading checkpoint @' 684 | f'epoch = {epoch}, ' 685 | f'offset = {offset}, ' 686 | f'global_step = {global_step}, ' 687 | ) 688 | self.start_epoch = epoch 689 | self.start_offset = offset 690 | self.global_step = global_step 691 | 692 | model_to_load = model_utils.get_model_obj(self.reader) 693 | if saved_state.model_dict: 694 | if dist_utils.is_local_master(): 695 | logger.info('Loading model weights from saved state ...') 696 | if self.args.train_file is None: 697 | model_to_load.load_state_dict( 698 | saved_state.model_dict, strict=False) 699 | else: 700 | model_to_load.load_state_dict(saved_state.model_dict) 701 | 702 | 703 | if self.args.train_file is not None: 704 | if saved_state.optimizer_dict: 705 | if dist_utils.is_local_master(): 706 | logger.info('Loading saved optimizer state ...') 707 | self.optimizer.load_state_dict(saved_state.optimizer_dict) 708 | 709 | if self.args.auto_resume: 710 | self.optimizer.state = {} 711 | amp.load_state_dict(saved_state.amp_dict) 712 | 713 | def _get_best_prediction( 714 | self, 715 | start_logits, 716 | end_logits, 717 | relevance_logits, 718 | samples_batch: List[data_class.ReaderSample] 719 | ) -> List[ReaderQuestionPredictions]: 720 | 721 | args = self.args 722 | topk_spans = self.topk_spans 723 | topk_passages = self.topk_passages 724 | passage_thresholds = self.args.eval_top_docs 725 | 726 | max_answer_length = args.max_answer_length 727 | 
727 | questions_num, passages_per_question = relevance_logits.size() 728 | 729 | _, idxs = torch.sort(relevance_logits, dim=1, descending=True) 730 | 731 | batch_results = [] 732 | max_num_passages = passage_thresholds[0] 733 | for q in range(questions_num): 734 | sample = samples_batch[q] 735 | 736 | non_empty_passages_num = len(sample.passages) 737 | 738 | passage_spans = [] 739 | passage_answers = [] 740 | for p in range(passages_per_question): 741 | 742 | # Need the top-k passages, but some candidates are skipped 743 | # because they are empty. 744 | if len(passage_spans) == topk_passages: 745 | break 746 | 747 | passage_idx = idxs[q, p].item() 748 | 749 | if passage_idx >= max_num_passages: 750 | continue 751 | 752 | # An empty passage was selected, so skip it. 753 | if passage_idx >= non_empty_passages_num: 754 | continue 755 | 756 | reader_passage = sample.passages[passage_idx] 757 | sequence_ids = reader_passage.sequence_ids 758 | sequence_len = sequence_ids.size(0) 759 | 760 | # Assumes question & title information is at the beginning of the sequence. 761 | 762 | p_start_logits = start_logits[q, passage_idx].tolist() 763 | p_end_logits = end_logits[q, passage_idx].tolist() 764 | best_spans = du.get_best_spans( 765 | p_start_logits, 766 | p_end_logits, 767 | max_answer_length, 768 | passage_idx, 769 | reader_passage.span_texts, 770 | reader_passage.span_types, 771 | reader_passage.mask_cls.tolist(), 772 | relevance_logits[q, passage_idx].item(), 773 | top_spans=10) 774 | 775 | best_spans = best_spans[:topk_spans] 776 | best_spans += [None] * (topk_spans - len(best_spans)) 777 | passage_spans.append(best_spans) 778 | assert len(passage_spans[-1]) == topk_spans 779 | 780 | passage_answers.append( 781 | int(reader_passage.has_answer)) 782 | 783 | # Pad missing passages: -1 for the answer labels and 784 | # [None] * topk_spans for the spans. 785 | passage_answers += [-1] * (topk_passages - len(passage_answers)) 786 | passage_spans += [[None]*topk_spans] * (topk_passages - len(passage_spans)) 787 | assert len(passage_answers) == topk_passages 788 | 789 | batch_results.append( 790 | ReaderQuestionPredictions( 791 | sample.id, sample.answers, passage_spans, passage_answers)) 792 | return batch_results 793 | 794 | def _training_step(self, batch): 795 | args = self.args 796 | batch = model_utils.move_to_device(batch, args.device) 797 | logits, others = self.reader(batch, self.global_step) 798 | 799 | losses = models.compute_loss(args, logits, batch, others) 800 | 801 | losses = {k: loss.mean() for k, loss in losses.items()} 802 | 803 | if args.fp16: 804 | with amp.scale_loss(losses['total'], self.optimizer) as scaled_loss: 805 | scaled_loss.backward() 806 | else: 807 | losses['total'].backward() 808 | 809 | return {k: v.item() for k, v in losses.items()}, others 810 | 811 | def _get_preprocessed_filepaths(self, data_files: List, is_train: bool): 812 | serialized_files = [fn for fn in data_files if fn.endswith('.pkl')] 813 | if serialized_files: 814 | return serialized_files 815 | 816 | assert len(data_files) == 1, \ 817 | 'Only single-file preprocessing is supported.' 818 | 819 | # The data may have been serialized and cached before; 820 | # try to find cached files in the same directory.
821 | def _find_cached_files(path: str): 822 | dir_path, base_name = os.path.split(path) 823 | base_name = base_name.replace('.json', '') 824 | out_file_prefix = os.path.join(dir_path, base_name) 825 | out_file_pattern = out_file_prefix + '*.pkl' 826 | return glob.glob(out_file_pattern), out_file_prefix 827 | 828 | serialized_files, _ = _find_cached_files(data_files[0]) 829 | 830 | assert serialized_files, \ 831 | 'Run the preprocessing code before training.' 832 | 833 | logger.info('Found preprocessed files: %s', serialized_files) 834 | return serialized_files 835 | 836 | def _save_predictions( 837 | self, 838 | out_file: str, 839 | prediction_results: List[ReaderQuestionPredictions]): 840 | 841 | logger.info(f'Saving prediction results to {out_file}') 842 | 843 | with open(out_file, 'w', encoding='utf-8') as f: 844 | save_results = [] 845 | for r in prediction_results: 846 | 847 | result = {'question': r.id, 'gold_answers': r.gold_answers} 848 | passage_preds = [] 849 | for p_topk, (spans, p_ans) in enumerate( 850 | zip(r.passage_spans, r.passage_answers)): 851 | span_preds = [] 852 | for span_topk, span in enumerate(spans): 853 | span_preds.append( 854 | { 855 | 'span_topk': span_topk, 856 | 'prediction': { 857 | 'text': (span.prediction_text 858 | if span is not None else None), 859 | 'score': (span.span_score 860 | if span is not None else None), 861 | 'relevance_score': ( 862 | span.relevance_score 863 | if span is not None else None), 864 | } 865 | }) 866 | passage_pred = { 867 | 'passage_topk': p_topk, 868 | 'passage_answer': p_ans, 869 | 'passage_idx': (spans[0].passage_index 870 | if spans[0] is not None else None), 871 | 'passage': (spans[0].passage_text 872 | if spans[0] is not None else None), 873 | 'spans': span_preds, 874 | } 875 | passage_preds.append(passage_pred) 876 | result['predictions'] = passage_preds 877 | save_results.append(result) 878 | json.dump(save_results, f, indent=4) 879 | 880 | 881 | def main(): 882 | parser = argparse.ArgumentParser() 883 | 884 | options.add_encoder_params(parser) 885 | options.add_f_div_regularization_params(parser) 886 | options.add_cuda_params(parser) 887 | options.add_training_params(parser) 888 | options.add_data_params(parser) 889 | args = parser.parse_args() 890 | if args.passage_attend_history: 891 | assert args.special_attention, \ 892 | 'passage_attend_history requires special_attention to be enabled.' 893 | 894 | assert os.path.exists(args.pretrained_model_cfg), \ 895 | (f'{args.pretrained_model_cfg} doesn\'t exist. ' 896 | f'Please download the HuggingFace model first (see download_hf_model.py).') 897 | options.setup_args_gpu(args) 898 | # Make sure the random seed is fixed; 899 | # set_seed must be called after setup_args_gpu. 900 | options.set_seed(args) 901 | 902 | if dist_utils.is_local_master(): 903 | utils.print_args(args) 904 | 905 | trainer = ReaderTrainer(args) 906 | 907 | if args.train_file is not None: 908 | trainer.run_train() 909 | elif args.dev_file is not None: 910 | logger.info('No train file specified. Running validation.') 911 | trainer.validate() 912 | else: 913 | logger.warning( 914 | 'Neither train_file nor (checkpoint_file & dev_file) parameters ' 915 | 'are specified. Nothing to do.') 916 | 917 | 918 | if __name__ == '__main__': 919 | main() 920 | --------------------------------------------------------------------------------
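For reference, `ReaderTrainer._save_predictions` writes a JSON list with one record per question: each record has `question` (the sample id), `gold_answers`, and a `predictions` list with one entry per retrieved passage (ordered by the model's relevance score), each carrying its top spans. Below is a minimal sketch for inspecting such a file; the file name `predictions.json` is a placeholder (the actual output path depends on how validation/eval is configured), and padded entries may be `None` when fewer passages or spans were available.

```python
# Minimal sketch: inspect the prediction JSON written by _save_predictions.
# Assumes the predictions were saved to "predictions.json" (hypothetical path).
import json

with open('predictions.json', 'r', encoding='utf-8') as f:
    results = json.load(f)

for r in results[:5]:
    top_passage = r['predictions'][0]   # passages are ordered by relevance score
    top_span = top_passage['spans'][0]   # spans are assumed to be ordered best-first
    print('question id :', r['question'])
    print('gold answers:', r['gold_answers'])
    print('passage idx :', top_passage['passage_idx'])
    if top_span['prediction']['text'] is not None:
        print('top span    :', top_span['prediction']['text'],
              f"(score: {top_span['prediction']['score']})")
    else:
        print('top span    : <no span predicted for this passage>')
    print('-' * 40)
```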