├── .gitignore ├── FiD ├── LiT5-Distill.py ├── LiT5-Score.py └── src │ ├── data.py │ ├── model.py │ ├── modeling_t5.py │ └── options.py ├── LICENSE ├── LiT5-Distill.sh ├── LiT5-Score.sh ├── README.md ├── requirements.txt ├── runs └── .gitkeep └── topics ├── msmarco-dl19-bm25.jsonl ├── msmarco-dl19-spladepp.jsonl ├── msmarco-dl20-bm25.jsonl └── msmarco-dl20-spladepp.jsonl /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | -------------------------------------------------------------------------------- /FiD/LiT5-Distill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | torch.manual_seed(0) 5 | transformers.set_seed(0) 6 | 7 | import numpy as np 8 | from torch.utils.data import DataLoader, SequentialSampler 9 | 10 | from src.options import Options 11 | import src.data 12 | import src.model 13 | 14 | import copy 15 | import json 16 | from typing import List, Union, Dict, Any, Tuple 17 | 18 | def evaluate(model, dataset, dataloader, tokenizer, opt): 19 | generated_permutations = [] 20 | 21 | with torch.no_grad(): 22 | for i, batch in enumerate(dataloader): 23 | (idx, passage_ids, passage_mask, query) = batch 24 | passage_ids = passage_ids.contiguous().view(passage_ids.size(0), -1) 25 | passage_mask = passage_mask.contiguous().view(passage_mask.size(0), -1) 26 | 27 | outputs = model.generate( 28 | input_ids=passage_ids.cuda(), 29 | attention_mask=passage_mask.cuda(), 30 | max_length=opt.answer_maxlength, 31 | do_sample=False 32 | ) 33 | 34 | for k, o in enumerate(outputs): 35 | output = tokenizer.decode(o, skip_special_tokens=True) 36 | generated_permutations.append(output) 37 | return generated_permutations 38 | 39 | def clean_response(response: str) -> str: 40 | new_response = "" 41 | for c in response: 42 | if not c.isdigit(): 43 | new_response += " " 44 | else: 45 | new_response += c 46 | new_response = new_response.strip() 47 | return new_response 48 | def remove_duplicate(response: List[int]) -> List[int]: 49 | new_response = [] 50 | for c in response: 51 | if c not in new_response: 52 | new_response.append(c) 53 | return new_response 54 | 55 | 56 | if __name__ == "__main__": 57 | options = Options() 58 | options.add_reader_options() 59 | options.add_eval_options() 60 | opt = options.parse() 61 | 62 | tokenizer = transformers.T5Tokenizer.from_pretrained(opt.model_path, return_dict=False, legacy=False, use_fast=True) 63 | 64 | collator_function = src.data.Collator(opt.text_maxlength, tokenizer, batch_size=opt.batch_size, n_passages=opt.n_passages, suffix= " Relevance Ranking: ") 65 | eval_examples = src.data.load_data(opt.eval_data) 66 | 67 | model_class = src.model.FiD 68 | model = model_class.from_pretrained(opt.model_path).cuda().eval() 69 | 70 | if opt.bfloat16: 71 | model = model.bfloat16() 72 | 73 | for query in eval_examples: 74 | query['ctxs'] = query['ctxs'][:opt.n_rerank_passages] 75 | 76 | stride = opt.stride 77 | window_size = opt.n_passages 78 | 79 | print("Start Inference") 80 | for passes in range(opt.n_passes): 81 | for window_start_idx in range(opt.n_rerank_passages - window_size, -1, -stride): 82 | eval_dataset = src.data.Dataset( 83 | eval_examples, 84 | opt.n_passages, 85 | start_pos=window_start_idx, 86 | question_prefix='Search Query:', 87 | passage_prefix='Passage:', 88 | passage_numbering=True 89 | ) 90 | print('Reranking passages:', window_start_idx, 'to', window_start_idx+window_size) 91 | 92 | eval_sampler = 
SequentialSampler(eval_dataset) 93 | eval_dataloader = DataLoader( 94 | eval_dataset, 95 | sampler=eval_sampler, 96 | batch_size=opt.batch_size, 97 | num_workers=4, 98 | collate_fn=collator_function 99 | ) 100 | 101 | generated_permutations = evaluate(model, eval_dataset, eval_dataloader, tokenizer, opt) 102 | 103 | for i in range(len(eval_examples)): 104 | query_dict = eval_examples[i] 105 | permutation = generated_permutations[i] 106 | 107 | resort_passages = copy.deepcopy(query_dict['ctxs'][window_start_idx:window_start_idx+window_size]) 108 | if len(resort_passages) > 0: 109 | response = clean_response(permutation) 110 | response = [int(x) - 1 for x in response.split()] 111 | response = remove_duplicate(response) 112 | original_rank = [tt for tt in range(len(resort_passages))] 113 | response = [ss for ss in response if ss in original_rank] 114 | response = response + [tt for tt in original_rank if tt not in response] 115 | for j, x in enumerate(response): 116 | query_dict['ctxs'][j + window_start_idx] = resort_passages[x] 117 | 118 | with open(opt.runfile_path + '.' + str(passes) + '.trec', 'w') as f: 119 | for query in eval_examples: 120 | rank = 1 121 | for passage in query['ctxs']: 122 | if 'docid' in passage.keys(): 123 | f.write(" ".join([query['id'], "Q0", str(passage['docid']), str(rank), str(1/rank), "RankFiD\n"])) 124 | rank+=1 125 | 126 | 127 | -------------------------------------------------------------------------------- /FiD/LiT5-Score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | 4 | torch.manual_seed(0) 5 | transformers.set_seed(0) 6 | 7 | import numpy as np 8 | from torch.utils.data import DataLoader, SequentialSampler 9 | 10 | from src.options import Options 11 | import src.data 12 | import src.model 13 | 14 | import copy 15 | import json 16 | from typing import List, Union, Dict, Any, Tuple 17 | 18 | def evaluate(model, dataset, dataloader, tokenizer, opt): 19 | if opt.write_crossattention_scores: 20 | model.overwrite_forward_crossattention() 21 | model.reset_score_storage() 22 | 23 | with torch.no_grad(): 24 | for i, batch in enumerate(dataloader): 25 | (idx, passage_ids, passage_mask, query) = batch 26 | passage_ids = passage_ids.contiguous().view(passage_ids.size(0), -1) 27 | passage_mask = passage_mask.contiguous().view(passage_mask.size(0), -1) 28 | 29 | if opt.write_crossattention_scores: 30 | model.reset_score_storage() 31 | 32 | outputs = model.generate( 33 | input_ids=passage_ids.cuda(), 34 | attention_mask=passage_mask.cuda(), 35 | max_length=opt.answer_maxlength, 36 | do_sample=False 37 | ) 38 | 39 | # need to zero out scores after EOS token. This is needed when batching results in sequences with different lengths. 
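# (Illustrative note: the `token == 1` check below assumes T5's EOS token id of 1; for a
# generated sequence such as [0, 1820, 1, 0, 0] the loop records a length of 2, i.e. everything
# before the first EOS, including the leading decoder-start token. With a different tokenizer,
# tokenizer.eos_token_id would be the more general check.)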
40 | output_sequence_lengths = [] 41 | for output in outputs: 42 | length = 0 43 | for token in output: 44 | if token == 1: # EOS token 45 | break 46 | length += 1 47 | output_sequence_lengths.append(length) 48 | 49 | if opt.write_crossattention_scores: 50 | query_mask_reader = ( 51 | tokenizer.batch_encode_plus( 52 | query, 53 | max_length=opt.text_maxlength, 54 | padding="longest", 55 | truncation=True, 56 | return_tensors="pt", 57 | add_special_tokens=False, 58 | )["attention_mask"] 59 | .bool() 60 | .cuda() 61 | ) 62 | 63 | crossattention_scores = model.get_crossattention_scores(opt.n_passages, 64 | mask=passage_mask.cuda(), 65 | ids=passage_ids.cuda(), 66 | mask_query=query_mask_reader.cuda(), 67 | output_sequence_lengths=output_sequence_lengths) 68 | 69 | for k, o in enumerate(outputs): 70 | example = dataset.data[idx[k]] 71 | if opt.write_crossattention_scores: 72 | for j in range(min(len(example['ctxs']), opt.n_passages)): 73 | for key in crossattention_scores: 74 | example['ctxs'][j][key] = crossattention_scores[key][k, j].item() 75 | 76 | 77 | 78 | if __name__ == "__main__": 79 | options = Options() 80 | options.add_reader_options() 81 | options.add_eval_options() 82 | opt = options.parse() 83 | 84 | tokenizer = transformers.T5Tokenizer.from_pretrained(opt.model_path, return_dict=False, legacy=False, use_fast=True) 85 | 86 | collator_function = src.data.Collator(opt.text_maxlength, tokenizer, batch_size=opt.batch_size, n_passages=opt.n_passages) 87 | 88 | eval_examples = src.data.load_data( 89 | opt.eval_data, 90 | ) 91 | eval_dataset = src.data.Dataset( 92 | eval_examples, 93 | opt.n_passages, 94 | start_pos=0, 95 | ) 96 | 97 | eval_sampler = SequentialSampler(eval_dataset) 98 | eval_dataloader = DataLoader( 99 | eval_dataset, 100 | sampler=eval_sampler, 101 | batch_size=opt.batch_size, 102 | num_workers=4, 103 | collate_fn=collator_function 104 | ) 105 | 106 | model_class = src.model.FiD 107 | model = model_class.from_pretrained(opt.model_path, from_flax=False).cuda().eval() 108 | 109 | if opt.bfloat16: 110 | model = model.bfloat16() 111 | 112 | print("Start Inference") 113 | 114 | evaluate(model, eval_dataset, eval_dataloader, tokenizer, opt) 115 | 116 | with open(opt.runfile_path, 'w') as f: 117 | for query in eval_dataset.data: 118 | sort_passages = [] 119 | for passage in query['ctxs']: 120 | if 'docid' in passage.keys() and opt.sort_key in passage.keys(): 121 | sort_passages.append(passage) 122 | sort_passages = sorted(sort_passages, key=lambda x: x[opt.sort_key], reverse=True) 123 | 124 | sum_of_scores = 0 # used to normalize scores 125 | for passage in sort_passages: 126 | sum_of_scores += passage[opt.sort_key] 127 | 128 | rank = 1 129 | for passage in sort_passages: 130 | f.write(" ".join([query['id'], "Q0", str(passage['docid']), str(rank), str(passage[opt.sort_key]/sum_of_scores), "ScoreFiD\n"])) 131 | rank+=1 132 | 133 | -------------------------------------------------------------------------------- /FiD/src/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import json 4 | import numpy as np 5 | from .options import Options 6 | 7 | class Dataset(torch.utils.data.Dataset): 8 | def __init__(self, 9 | data, 10 | n_passages=None, 11 | start_pos=0, 12 | question_prefix='question:', 13 | passage_prefix='context:', 14 | passage_numbering=False): 15 | self.data = data 16 | self.n_passages = n_passages 17 | self.start_pos = start_pos 18 | self.question_prefix = question_prefix 19 | 
self.passage_prefix = passage_prefix 20 | self.passage_numbering = passage_numbering 21 | 22 | def __len__(self): 23 | return len(self.data) 24 | 25 | def __getitem__(self, index): 26 | example = self.data[index] 27 | question = self.question_prefix + " " + example['question'] 28 | 29 | if 'ctxs' in example and self.n_passages is not None: 30 | # add dummy contexts when there are not enough 31 | while len(example['ctxs']) < self.start_pos+self.n_passages: 32 | example['ctxs'].append({'text': ""}) 33 | 34 | contexts = np.array(example['ctxs'][self.start_pos:self.start_pos+self.n_passages]) 35 | 36 | if self.passage_numbering: 37 | f = self.passage_prefix + " [{}] {}" 38 | passages = [] 39 | passage_id = 1 40 | for c in contexts: 41 | passages.append(f.format(passage_id, c['text'])) 42 | passage_id+=1 43 | else: 44 | f = self.passage_prefix + " {}" 45 | passages = np.array([f.format(c['text']) for c in contexts]) 46 | 47 | else: 48 | passages = None 49 | return { 50 | 'index' : index, 51 | 'question' : question, 52 | 'passages' : passages, 53 | } 54 | 55 | def encode_passages(batch_text_passages, tokenizer, max_length, batch_size, n_passages): 56 | passage_ids, passage_masks = [], [] 57 | for k, text_passages in enumerate(batch_text_passages): 58 | p = tokenizer.batch_encode_plus( 59 | text_passages, 60 | max_length=max_length, 61 | padding='max_length', 62 | return_tensors='pt', 63 | truncation=True 64 | ) 65 | passage_ids.append(p['input_ids'][None]) 66 | passage_masks.append(p['attention_mask'][None]) 67 | 68 | passage_ids = torch.cat(passage_ids, dim=0) 69 | passage_masks = torch.cat(passage_masks, dim=0) 70 | return passage_ids, passage_masks.bool() 71 | 72 | class Collator(object): 73 | def __init__(self, text_maxlength, tokenizer, answer_maxlength=32, batch_size=1, n_passages=100, suffix=''): 74 | self.tokenizer = tokenizer 75 | self.text_maxlength = text_maxlength 76 | self.answer_maxlength = answer_maxlength 77 | self.batch_size = batch_size 78 | self.n_passages = n_passages 79 | self.suffix = suffix 80 | 81 | def __call__(self, batch): 82 | index = torch.tensor([ex['index'] for ex in batch]) 83 | 84 | def append_question(example): 85 | if example['passages'] is None: 86 | return [example['question']] 87 | return [example['question'] + " " + t + self.suffix for t in example['passages']] 88 | text_passages = [append_question(example) for example in batch] 89 | query = [example['question'] for example in batch] 90 | passage_ids, passage_masks = encode_passages(text_passages, 91 | self.tokenizer, 92 | self.text_maxlength, 93 | self.batch_size, 94 | self.n_passages) 95 | 96 | return (index, passage_ids, passage_masks, query) 97 | 98 | def load_data(data_path): 99 | if data_path.endswith('.jsonl'): 100 | data = open(data_path, 'r') 101 | elif data_path.endswith('.json'): 102 | with open(data_path, 'r') as fin: 103 | data = json.load(fin) 104 | examples = [] 105 | for k, example in enumerate(data): 106 | if data_path is not None and data_path.endswith('.jsonl'): 107 | example = json.loads(example) 108 | if not 'id' in example: 109 | example['id'] = k 110 | examples.append(example) 111 | 112 | if data_path.endswith('.jsonl'): 113 | data.close() 114 | 115 | return examples -------------------------------------------------------------------------------- /FiD/src/model.py: -------------------------------------------------------------------------------- 1 | import types 2 | import torch 3 | import transformers 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from torch.nn 
import CrossEntropyLoss 7 | import numpy as np 8 | import copy 9 | 10 | from src.options import Options 11 | 12 | options = Options() 13 | options.add_reader_options() 14 | options.add_eval_options() 15 | opt = options.parse() 16 | 17 | if opt.write_crossattention_scores: 18 | from src.modeling_t5 import T5ForConditionalGeneration, T5Stack 19 | else: 20 | from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration, T5Stack 21 | 22 | 23 | class FiDStack(T5Stack): 24 | def __init__(self, config, embed_tokens=None): 25 | super().__init__(config, embed_tokens=embed_tokens) 26 | 27 | def forward( 28 | self, 29 | input_ids=None, 30 | attention_mask=None, 31 | encoder_hidden_states=None, 32 | encoder_attention_mask=None, 33 | inputs_embeds=None, 34 | head_mask=None, 35 | cross_attn_head_mask=None, 36 | past_key_values=None, 37 | use_cache=None, 38 | output_attentions=None, 39 | output_hidden_states=None, 40 | return_dict=None, 41 | ): 42 | if not self.is_decoder: 43 | input_ids = input_ids.view(input_ids.size(0) * self.config.n_passages, -1) 44 | attention_mask = attention_mask.view(attention_mask.size(0) * self.config.n_passages, -1) 45 | 46 | output = super().forward( 47 | input_ids=input_ids, 48 | attention_mask=attention_mask, 49 | encoder_hidden_states=encoder_hidden_states, 50 | encoder_attention_mask=encoder_attention_mask, 51 | inputs_embeds=inputs_embeds, 52 | head_mask=head_mask, 53 | cross_attn_head_mask=cross_attn_head_mask, 54 | past_key_values=past_key_values, 55 | use_cache=use_cache, 56 | output_attentions=output_attentions, 57 | output_hidden_states=output_hidden_states, 58 | return_dict=return_dict, 59 | ) 60 | 61 | if not self.is_decoder: 62 | bsz = input_ids.size(0) // self.config.n_passages 63 | if not return_dict: 64 | last_hidden_states = output[0] 65 | last_hidden_state = last_hidden_states.view(bsz, -1, last_hidden_states.size(-1)) 66 | output = tuple( 67 | last_hidden_state, 68 | *output[1:], 69 | ) 70 | else: 71 | last_hidden_state = output.last_hidden_state 72 | output.last_hidden_state = last_hidden_state.view(bsz, -1, last_hidden_state.size(-1)) 73 | 74 | return output 75 | 76 | 77 | class FiD(T5ForConditionalGeneration): 78 | _keys_to_ignore_on_load_missing = [ 79 | r"encoder\.embed_tokens\.weight", 80 | r"decoder\.embed_tokens\.weight", 81 | r"lm_head\.weight", 82 | ] 83 | _keys_to_ignore_on_load_unexpected = [ 84 | r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", 85 | ] 86 | 87 | def __init__(self, config): 88 | super().__init__(config) 89 | self.model_dim = config.d_model 90 | 91 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 92 | 93 | self.config.n_passages = opt.n_passages 94 | self.config.bsz = opt.batch_size 95 | 96 | encoder_config = copy.deepcopy(config) 97 | encoder_config.is_decoder = False 98 | encoder_config.use_cache = False 99 | encoder_config.is_encoder_decoder = False 100 | 101 | self.encoder = FiDStack(encoder_config, self.shared) 102 | 103 | decoder_config = copy.deepcopy(config) 104 | decoder_config.is_decoder = True 105 | decoder_config.is_encoder_decoder = False 106 | decoder_config.use_cache = True 107 | decoder_config.num_layers = config.num_decoder_layers 108 | self.decoder = FiDStack(decoder_config, self.shared) 109 | 110 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 111 | 112 | # Initialize weights and apply final processing 113 | self.post_init() 114 | 115 | # Model parallel 116 | self.model_parallel = False 117 | self.device_map = None 118 | 119 | 
def set_checkpoint(self, use_checkpoint): 120 | """ 121 | Enable or disable checkpointing in the encoder. 122 | See https://pytorch.org/docs/stable/checkpoint.html 123 | """ 124 | for mod in self.encoder.encoder.block: 125 | mod.use_checkpoint = use_checkpoint 126 | 127 | def reset_score_storage(self): 128 | """ 129 | Reset score storage, only used when cross-attention scores are saved 130 | to train a retriever. 131 | """ 132 | for mod in self.decoder.block: 133 | mod.layer[1].EncDecAttention.normalized_score_storage = None 134 | 135 | @torch.no_grad() 136 | def get_crossattention_scores(self, n_passages, mask, ids, mask_query=None, output_sequence_lengths=[]): 137 | """ 138 | Cross-attention scores are aggregated to obtain a single scalar per 139 | passage. This scalar can be seen as a similarity score between the 140 | question and the input passage. It is obtained by averaging the 141 | cross-attention scores obtained on the first decoded token over heads, 142 | layers, and tokens of the input passage. 143 | 144 | More details in Distilling Knowledge from Reader to Retriever: 145 | https://arxiv.org/abs/2012.04584. 146 | """ 147 | norms = [] 148 | for mod in self.decoder.block: 149 | norms.append(mod.layer[1].EncDecAttention.normalized_score_storage) 150 | norms = torch.stack(norms) 151 | 152 | output = {} 153 | self.aggregate_value(norms, mask, n_passages, ids, mask_query, output, prefix="norms", output_sequence_lengths=output_sequence_lengths) 154 | return output 155 | 156 | def aggregate_value(self, scores, mask, n_passages, ids, mask_query=None, output={}, prefix="", output_sequence_lengths=[]): 157 | n_layers, bsz, n_tokens, total_tokens = scores.size() 158 | 159 | ids = ids.view(bsz, n_passages, -1) 160 | scores = scores.view(n_layers, bsz, n_tokens, n_passages, -1) 161 | mask = mask.view(bsz, n_passages, -1) 162 | scores = scores.masked_fill(~mask[None, :, None], 0.0) 163 | 164 | scores = scores.sum(dim=[0]) 165 | 166 | scores_woquery = None 167 | # Compute scores based on scores without query 168 | if not mask_query is None: 169 | output[f"{prefix}woquery"] = self.get_woquery_score(scores, mask_query, mask, n_layers, output_sequence_lengths=output_sequence_lengths) 170 | 171 | return output 172 | 173 | 174 | def get_woquery_score(self, scores, mask_query, mask, n_layers, output_sequence_lengths): 175 | if scores.size(-1) > mask_query.size(-1): 176 | zero_padding = torch.zeros( 177 | [mask_query.size(0), scores.size(-1) - mask_query.size(-1)], device=mask_query.device, dtype=torch.bool 178 | ) 179 | mask_query = torch.cat([mask_query, zero_padding], dim=-1) 180 | mask_query = mask * (~mask_query[:, None]) 181 | scores_woquery = scores.masked_fill(~mask_query[:, None], 0.0) 182 | 183 | ntokens_woquery = 256 * n_layers 184 | 185 | # zero out scores after EOS token. This is needed when batching results in sequences with different lengths. 186 | for i in range(len(scores_woquery)): 187 | scores_woquery[i, output_sequence_lengths[i]:, :, :] = 0 188 | 189 | scores_woquery = scores_woquery.sum(dim=[1, 3]) 190 | return scores_woquery / ntokens_woquery 191 | 192 | def overwrite_forward_crossattention(self): 193 | """ 194 | Replace cross-attention forward function, only used to save 195 | cross-attention scores. 
196 | """ 197 | for mod in self.decoder.block: 198 | xattn = mod.layer[1].EncDecAttention 199 | xattn.forward = types.MethodType(cross_attention_forward, xattn) 200 | 201 | def create_crossattention_storage(self): 202 | for mod in self.decoder.block: 203 | xattn = mod.layer[1].EncDecAttention 204 | xattn.normalized_score_storage = None 205 | 206 | def cross_attention_forward( 207 | self, 208 | hidden_states, 209 | mask=None, 210 | key_value_states=None, 211 | position_bias=None, 212 | past_key_value=None, 213 | layer_head_mask=None, 214 | query_length=None, 215 | use_cache=False, 216 | output_attentions=False, 217 | ): 218 | """ 219 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 220 | """ 221 | # Input is (batch_size, seq_length, dim) 222 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) 223 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) 224 | 225 | batch_size, seq_length = hidden_states.shape[:2] 226 | real_seq_length = seq_length 227 | 228 | if past_key_value is not None: 229 | assert ( 230 | len(past_key_value) == 2 231 | ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" 232 | real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length 233 | 234 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] 235 | 236 | def shape(states): 237 | """projection""" 238 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 239 | 240 | def unshape(states): 241 | """reshape""" 242 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) 243 | 244 | def project(hidden_states, proj_layer, key_value_states, past_key_value): 245 | """projects hidden states correctly to key/query states""" 246 | if key_value_states is None: 247 | # self-attn 248 | # (batch_size, n_heads, seq_length, dim_per_head) 249 | hidden_states = shape(proj_layer(hidden_states)) 250 | elif past_key_value is None: 251 | # cross-attn 252 | # (batch_size, n_heads, seq_length, dim_per_head) 253 | hidden_states = shape(proj_layer(key_value_states)) 254 | 255 | if past_key_value is not None: 256 | if key_value_states is None: 257 | # self-attn 258 | # (batch_size, n_heads, key_length, dim_per_head) 259 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2) 260 | else: 261 | # cross-attn 262 | hidden_states = past_key_value 263 | return hidden_states 264 | 265 | # get query states 266 | query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) 267 | 268 | # get key/value states 269 | key_states = project( 270 | hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None 271 | ) 272 | value_states = project( 273 | hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None 274 | ) 275 | 276 | # compute scores 277 | scores = torch.matmul( 278 | query_states, key_states.transpose(3, 2) 279 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 280 | 281 | 282 | if position_bias is None: 283 | if not self.has_relative_attention_bias: 284 | position_bias = torch.zeros( 285 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype 286 | ) 287 | if self.gradient_checkpointing and self.training: 288 | position_bias.requires_grad = True 289 
| else: 290 | position_bias = self.compute_bias(real_seq_length, key_length) 291 | 292 | # if key and values are already calculated 293 | # we want only the last query position bias 294 | if past_key_value is not None: 295 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :] 296 | 297 | if mask is not None: 298 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) 299 | 300 | scores += position_bias 301 | 302 | attn_weights = nn.functional.softmax(scores.float(), dim=-1) # .type_as(scores) 303 | 304 | if hasattr(self, "normalized_score_storage"): 305 | with torch.no_grad(): 306 | self.normalized_score_storage = ( 307 | (torch.norm(value_states.float(), dim=-1)[:, :, None] * attn_weights).detach().mean(dim=1) 308 | ) 309 | 310 | attn_weights = nn.functional.dropout(attn_weights.type_as(scores), p=self.dropout, training=self.training) 311 | 312 | # Mask heads if we want to 313 | if layer_head_mask is not None: 314 | attn_weights = attn_weights * layer_head_mask 315 | 316 | attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) 317 | attn_output = self.o(attn_output) 318 | 319 | present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None 320 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) 321 | 322 | if output_attentions: 323 | outputs = outputs + (attn_weights,) 324 | return outputs -------------------------------------------------------------------------------- /FiD/src/modeling_t5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch T5 model. 
""" 16 | 17 | 18 | import copy 19 | import math 20 | import os 21 | import warnings 22 | 23 | import torch 24 | from torch import nn 25 | from torch.nn import CrossEntropyLoss 26 | from torch.utils.checkpoint import checkpoint 27 | 28 | from transformers.activations import ACT2FN 29 | from transformers.file_utils import ( 30 | DUMMY_INPUTS, 31 | DUMMY_MASK, 32 | add_start_docstrings, 33 | add_start_docstrings_to_model_forward, 34 | is_torch_fx_proxy, 35 | replace_return_docstrings, 36 | ) 37 | from transformers.modeling_outputs import ( 38 | BaseModelOutput, 39 | BaseModelOutputWithPastAndCrossAttentions, 40 | Seq2SeqLMOutput, 41 | Seq2SeqModelOutput, 42 | ) 43 | from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer 44 | from transformers.utils import logging 45 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 46 | from transformers.models.t5.configuration_t5 import T5Config 47 | 48 | 49 | logger = logging.get_logger(__name__) 50 | 51 | _CONFIG_FOR_DOC = "T5Config" 52 | _TOKENIZER_FOR_DOC = "T5Tokenizer" 53 | _CHECKPOINT_FOR_DOC = "t5-small" 54 | 55 | #################################################### 56 | # This dict contains ids and associated url 57 | # for the pretrained weights provided with the models 58 | #################################################### 59 | T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ 60 | "t5-small", 61 | "t5-base", 62 | "t5-large", 63 | "t5-3b", 64 | "t5-11b", 65 | # See all T5 models at https://huggingface.co/models?filter=t5 66 | ] 67 | 68 | 69 | #################################################### 70 | # This is a conversion method from TF 1.0 to PyTorch 71 | # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 72 | #################################################### 73 | def load_tf_weights_in_t5(model, config, tf_checkpoint_path): 74 | """Load tf checkpoints in a pytorch model.""" 75 | try: 76 | import re 77 | 78 | import numpy as np 79 | import tensorflow as tf 80 | except ImportError: 81 | logger.error( 82 | "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " 83 | "https://www.tensorflow.org/install/ for installation instructions." 
84 | ) 85 | raise 86 | tf_path = os.path.abspath(tf_checkpoint_path) 87 | logger.info(f"Converting TensorFlow checkpoint from {tf_path}") 88 | # Load weights from TF model 89 | init_vars = tf.train.list_variables(tf_path) 90 | names = [] 91 | tf_weights = {} 92 | for name, shape in init_vars: 93 | logger.info(f"Loading TF weight {name} with shape {shape}") 94 | array = tf.train.load_variable(tf_path, name) 95 | names.append(name) 96 | tf_weights[name] = array 97 | 98 | for txt_name in names: 99 | name = txt_name.split("/") 100 | # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v 101 | # which are not required for using pretrained model 102 | if any( 103 | n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"] 104 | for n in name 105 | ): 106 | logger.info(f"Skipping {'/'.join(name)}") 107 | tf_weights.pop(txt_name, None) 108 | continue 109 | if "_slot_" in name[-1]: 110 | logger.info(f"Skipping {'/'.join(name)}") 111 | tf_weights.pop(txt_name, None) 112 | continue 113 | pointer = model 114 | array = tf_weights[txt_name] 115 | 116 | for m_name in name: 117 | if re.fullmatch(r"[A-Za-z]+_\d+", m_name): 118 | scope_names = re.split(r"_(\d+)", m_name) 119 | else: 120 | scope_names = [m_name] 121 | if scope_names[0] in ["kernel", "scale", "embedding"]: 122 | pointer = getattr(pointer, "weight") 123 | elif scope_names[0] == "self_attention": 124 | pointer = getattr(pointer, "layer") 125 | pointer = pointer[0] 126 | elif scope_names[0] == "enc_dec_attention": 127 | pointer = getattr(pointer, "layer") 128 | pointer = pointer[1] 129 | elif scope_names[0] == "dense_relu_dense": 130 | pointer = getattr(pointer, "layer") 131 | pointer = pointer[2] 132 | elif scope_names[0] == "rms_norm": 133 | if hasattr(pointer, "layer_norm"): 134 | pointer = getattr(pointer, "layer_norm") 135 | elif hasattr(pointer, "final_layer_norm"): 136 | pointer = getattr(pointer, "final_layer_norm") 137 | elif scope_names[0] == "scale": 138 | pointer = getattr(pointer, "weight") 139 | elif scope_names[0] == "output_bias" or scope_names[0] == "beta": 140 | pointer = getattr(pointer, "bias") 141 | elif scope_names[0] == "squad": 142 | pointer = getattr(pointer, "classifier") 143 | elif scope_names[0] == "decoder" and name[1] == "logits": 144 | continue 145 | elif scope_names[0] == "logits": 146 | pointer = getattr(pointer, "lm_head") 147 | elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): 148 | pointer = getattr(pointer, f"wi_{scope_names[1]}") 149 | continue 150 | else: 151 | try: 152 | pointer = getattr(pointer, scope_names[0]) 153 | except AttributeError: 154 | logger.info(f"Skipping {'/'.join(name)}") 155 | continue 156 | if len(scope_names) >= 2: 157 | num = int(scope_names[1]) 158 | pointer = pointer[num] 159 | if scope_names[0] not in ["kernel", "scale", "embedding"]: 160 | pointer = getattr(pointer, "weight") 161 | if scope_names[0] != "embedding": 162 | logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") 163 | array = np.transpose(array) 164 | try: 165 | assert ( 166 | pointer.shape == array.shape 167 | ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" 168 | except AssertionError as e: 169 | e.args += (pointer.shape, array.shape) 170 | raise 171 | logger.info(f"Initialize PyTorch weight {name}") 172 | pointer.data = torch.from_numpy(array.astype(np.float32)) 173 | tf_weights.pop(txt_name, None) 174 | 175 | logger.info(f"Weights not copied to PyTorch model: 
{', '.join(tf_weights.keys())}.") 176 | return model 177 | 178 | 179 | #################################################### 180 | # PyTorch Models are constructed by sub-classing 181 | # - torch.nn.Module for the layers and 182 | # - PreTrainedModel for the models (it-self a sub-class of nn.Module) 183 | #################################################### 184 | PARALLELIZE_DOCSTRING = r""" 185 | This is an experimental feature and is a subject to change at a moment's notice. 186 | 187 | Uses a device map to distribute attention modules of the model across several devices. If no device map is given, 188 | it will evenly distribute blocks across all devices. 189 | 190 | Args: 191 | device_map (`Dict[int, list]`, optional, defaults to None): 192 | A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always 193 | automatically mapped to the first device (for esoteric reasons). That means that the first device should 194 | have fewer attention modules mapped to it than other devices. For reference, the t5 models have the 195 | following number of attention modules: 196 | 197 | - t5-small: 6 198 | - t5-base: 12 199 | - t5-large: 24 200 | - t5-3b: 24 201 | - t5-11b: 24 202 | 203 | Example: 204 | 205 | ```python 206 | # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: 207 | model = T5ForConditionalGeneration.from_pretrained('t5-3b') 208 | device_map = {0: [0, 1, 2], 209 | 210 | 1: [3, 4, 5, 6, 7, 8, 9], 211 | 2: [10, 11, 12, 13, 14, 15, 16], 212 | 3: [17, 18, 19, 20, 21, 22, 23]} 213 | model.parallelize(device_map) 214 | ``` 215 | """ 216 | DEPARALLELIZE_DOCSTRING = r""" 217 | Moves the model to cpu from a model parallel state. 218 | 219 | Example: 220 | 221 | ```python 222 | # On a 4 GPU machine with t5-3b: 223 | model = T5ForConditionalGeneration.from_pretrained('t5-3b') 224 | device_map = {0: [0, 1, 2], 225 | 226 | 1: [3, 4, 5, 6, 7, 8, 9], 227 | 2: [10, 11, 12, 13, 14, 15, 16], 228 | 3: [17, 18, 19, 20, 21, 22, 23]} 229 | model.parallelize(device_map) # Splits the model across several devices 230 | model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() 231 | ``` 232 | """ 233 | 234 | 235 | class T5LayerNorm(nn.Module): 236 | def __init__(self, hidden_size, eps=1e-6): 237 | """ 238 | Construct a layernorm module in the T5 style No bias and no subtraction of mean. 
239 | """ 240 | super().__init__() 241 | self.weight = nn.Parameter(torch.ones(hidden_size)) 242 | self.variance_epsilon = eps 243 | 244 | def forward(self, hidden_states): 245 | # layer norm should always be calculated in float32 246 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 247 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 248 | 249 | # convert into half-precision if necessary 250 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 251 | hidden_states = hidden_states.to(self.weight.dtype) 252 | 253 | return self.weight * hidden_states 254 | 255 | 256 | class T5DenseReluDense(nn.Module): 257 | def __init__(self, config): 258 | super().__init__() 259 | self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) 260 | self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) 261 | self.dropout = nn.Dropout(config.dropout_rate) 262 | 263 | def forward(self, hidden_states): 264 | hidden_states = self.wi(hidden_states) 265 | hidden_states = nn.functional.relu(hidden_states) 266 | hidden_states = self.dropout(hidden_states) 267 | hidden_states = self.wo(hidden_states) 268 | # hidden_states = torch.clamp(hidden_states, -1000, 1000) 269 | return hidden_states 270 | 271 | 272 | class T5DenseGatedGeluDense(nn.Module): 273 | def __init__(self, config): 274 | super().__init__() 275 | self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) 276 | self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) 277 | self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) 278 | self.dropout = nn.Dropout(config.dropout_rate) 279 | self.gelu_act = ACT2FN["gelu_new"] 280 | 281 | def forward(self, hidden_states): 282 | hidden_gelu = self.wi_0(hidden_states) 283 | hidden_gelu = self.gelu_act(hidden_gelu.float()).type_as(hidden_states) 284 | hidden_linear = self.wi_1(hidden_states) 285 | hidden_states = hidden_gelu * hidden_linear 286 | hidden_states = self.dropout(hidden_states) 287 | hidden_states = self.wo(hidden_states) 288 | # hidden_states = torch.clamp(hidden_states, -1000, 1000) 289 | return hidden_states 290 | 291 | 292 | class T5LayerFF(nn.Module): 293 | def __init__(self, config): 294 | super().__init__() 295 | if config.feed_forward_proj == "relu": 296 | self.DenseReluDense = T5DenseReluDense(config) 297 | elif config.feed_forward_proj == "gated-gelu": 298 | self.DenseReluDense = T5DenseGatedGeluDense(config) 299 | else: 300 | raise ValueError( 301 | f"{self.config.feed_forward_proj} is not supported. 
Choose between `relu` and `gated-gelu`" 302 | ) 303 | 304 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 305 | self.dropout = nn.Dropout(config.dropout_rate) 306 | 307 | def forward(self, hidden_states): 308 | forwarded_states = self.layer_norm(hidden_states) 309 | forwarded_states = self.DenseReluDense(forwarded_states) 310 | hidden_states = hidden_states + self.dropout(forwarded_states) 311 | return hidden_states 312 | 313 | 314 | class T5Attention(nn.Module): 315 | def __init__(self, config: T5Config, has_relative_attention_bias=False): 316 | super().__init__() 317 | self.is_decoder = config.is_decoder 318 | self.has_relative_attention_bias = has_relative_attention_bias 319 | 320 | self.relative_attention_num_buckets = config.relative_attention_num_buckets 321 | self.d_model = config.d_model 322 | self.key_value_proj_dim = config.d_kv 323 | self.n_heads = config.num_heads 324 | self.dropout = config.dropout_rate 325 | self.inner_dim = self.n_heads * self.key_value_proj_dim 326 | 327 | # Mesh TensorFlow initialization to avoid scaling before softmax 328 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 329 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) 330 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) 331 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) 332 | 333 | if self.has_relative_attention_bias: 334 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) 335 | self.pruned_heads = set() 336 | self.gradient_checkpointing = False 337 | 338 | def prune_heads(self, heads): 339 | if len(heads) == 0: 340 | return 341 | heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads) 342 | # Prune linear layers 343 | self.q = prune_linear_layer(self.q, index) 344 | self.k = prune_linear_layer(self.k, index) 345 | self.v = prune_linear_layer(self.v, index) 346 | self.o = prune_linear_layer(self.o, index, dim=1) 347 | # Update hyper params 348 | self.n_heads = self.n_heads - len(heads) 349 | self.inner_dim = self.key_value_proj_dim * self.n_heads 350 | self.pruned_heads = self.pruned_heads.union(heads) 351 | 352 | @staticmethod 353 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 354 | """ 355 | Adapted from Mesh Tensorflow: 356 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 357 | 358 | Translate relative position to a bucket number for relative attention. The relative position is defined as 359 | memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to 360 | position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for 361 | small absolute relative_position and larger buckets for larger absolute relative_positions. All relative 362 | positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
363 | This should allow for more graceful generalization to longer sequences than the model has been trained on 364 | 365 | Args: 366 | relative_position: an int32 Tensor 367 | bidirectional: a boolean - whether the attention is bidirectional 368 | num_buckets: an integer 369 | max_distance: an integer 370 | 371 | Returns: 372 | a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) 373 | """ 374 | relative_buckets = 0 375 | if bidirectional: 376 | num_buckets //= 2 377 | relative_buckets += (relative_position > 0).to(torch.long) * num_buckets 378 | relative_position = torch.abs(relative_position) 379 | else: 380 | relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) 381 | # now relative_position is in the range [0, inf) 382 | 383 | # half of the buckets are for exact increments in positions 384 | max_exact = num_buckets // 2 385 | is_small = relative_position < max_exact 386 | 387 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 388 | relative_postion_if_large = max_exact + ( 389 | torch.log(relative_position.float() / max_exact) 390 | / math.log(max_distance / max_exact) 391 | * (num_buckets - max_exact) 392 | ).to(torch.long) 393 | relative_postion_if_large = torch.min( 394 | relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) 395 | ) 396 | 397 | relative_buckets += torch.where(is_small, relative_position, relative_postion_if_large) 398 | return relative_buckets 399 | 400 | def compute_bias(self, query_length, key_length): 401 | """Compute binned relative position bias""" 402 | context_position = torch.arange( 403 | query_length, dtype=torch.long, device=self.relative_attention_bias.weight.device 404 | )[:, None] 405 | memory_position = torch.arange(key_length, dtype=torch.long, device=self.relative_attention_bias.weight.device)[ 406 | None, : 407 | ] 408 | relative_position = memory_position - context_position # shape (query_length, key_length) 409 | relative_position_bucket = self._relative_position_bucket( 410 | relative_position, # shape (query_length, key_length) 411 | bidirectional=(not self.is_decoder), 412 | num_buckets=self.relative_attention_num_buckets, 413 | ) 414 | values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) 415 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) 416 | return values 417 | 418 | def forward( 419 | self, 420 | hidden_states, 421 | mask=None, 422 | key_value_states=None, 423 | position_bias=None, 424 | past_key_value=None, 425 | layer_head_mask=None, 426 | query_length=None, 427 | use_cache=False, 428 | output_attentions=False, 429 | ): 430 | """ 431 | Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 432 | """ 433 | # Input is (batch_size, seq_length, dim) 434 | # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) 435 | # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) 436 | batch_size, seq_length = hidden_states.shape[:2] 437 | 438 | real_seq_length = seq_length 439 | 440 | if past_key_value is not None: 441 | assert ( 442 | len(past_key_value) == 2 443 | ), f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" 444 | real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length 445 | 446 | key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] 447 | 448 | def shape(states): 449 | """projection""" 450 | return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) 451 | 452 | def unshape(states): 453 | """reshape""" 454 | return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) 455 | 456 | def project(hidden_states, proj_layer, key_value_states, past_key_value): 457 | """projects hidden states correctly to key/query states""" 458 | if key_value_states is None: 459 | # self-attn 460 | # (batch_size, n_heads, seq_length, dim_per_head) 461 | hidden_states = shape(proj_layer(hidden_states)) 462 | elif past_key_value is None: 463 | # cross-attn 464 | # (batch_size, n_heads, seq_length, dim_per_head) 465 | hidden_states = shape(proj_layer(key_value_states)) 466 | 467 | if past_key_value is not None: 468 | if key_value_states is None: 469 | # self-attn 470 | # (batch_size, n_heads, key_length, dim_per_head) 471 | hidden_states = torch.cat([past_key_value, hidden_states], dim=2) 472 | else: 473 | # cross-attn 474 | hidden_states = past_key_value 475 | return hidden_states 476 | 477 | # get query states 478 | query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) 479 | 480 | # get key/value states 481 | key_states = project( 482 | hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None 483 | ) 484 | value_states = project( 485 | hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None 486 | ) 487 | 488 | # compute scores 489 | scores = torch.matmul( 490 | query_states, key_states.transpose(3, 2) 491 | ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 492 | 493 | if position_bias is None: 494 | if not self.has_relative_attention_bias: 495 | position_bias = torch.zeros( 496 | (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype 497 | ) 498 | if self.gradient_checkpointing and self.training: 499 | position_bias.requires_grad = True 500 | else: 501 | position_bias = self.compute_bias(real_seq_length, key_length) 502 | 503 | # if key and values are already calculated 504 | # we want only the last query position bias 505 | if past_key_value is not None: 506 | position_bias = position_bias[:, :, -hidden_states.size(1) :, :] 507 | 508 | if mask is not None: 509 | position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) 510 | 511 | scores += position_bias 512 | attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( 513 | scores 514 | ) # (batch_size, n_heads, seq_length, key_length) 515 | attn_weights = nn.functional.dropout( 516 | attn_weights, p=self.dropout, training=self.training 517 | ) # (batch_size, n_heads, seq_length, key_length) 518 | 519 | # Mask heads if we want to 520 | if layer_head_mask is not None: 521 | attn_weights = attn_weights * layer_head_mask 522 | 523 | attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) 524 | attn_output = self.o(attn_output) 525 | 526 | present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None 527 | outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) 528 | 529 | if 
output_attentions: 530 | outputs = outputs + (attn_weights,) 531 | return outputs 532 | 533 | 534 | class T5LayerSelfAttention(nn.Module): 535 | def __init__(self, config, has_relative_attention_bias=False): 536 | super().__init__() 537 | self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) 538 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 539 | self.dropout = nn.Dropout(config.dropout_rate) 540 | 541 | def forward( 542 | self, 543 | hidden_states, 544 | attention_mask=None, 545 | position_bias=None, 546 | layer_head_mask=None, 547 | past_key_value=None, 548 | use_cache=False, 549 | output_attentions=False, 550 | ): 551 | normed_hidden_states = self.layer_norm(hidden_states) 552 | attention_output = self.SelfAttention( 553 | normed_hidden_states, 554 | mask=attention_mask, 555 | position_bias=position_bias, 556 | layer_head_mask=layer_head_mask, 557 | past_key_value=past_key_value, 558 | use_cache=use_cache, 559 | output_attentions=output_attentions, 560 | ) 561 | hidden_states = hidden_states + self.dropout(attention_output[0]) 562 | outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them 563 | return outputs 564 | 565 | 566 | class T5LayerCrossAttention(nn.Module): 567 | def __init__(self, config): 568 | super().__init__() 569 | self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) 570 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 571 | self.dropout = nn.Dropout(config.dropout_rate) 572 | 573 | def forward( 574 | self, 575 | hidden_states, 576 | key_value_states, 577 | attention_mask=None, 578 | position_bias=None, 579 | layer_head_mask=None, 580 | past_key_value=None, 581 | use_cache=False, 582 | query_length=None, 583 | output_attentions=False, 584 | ): 585 | normed_hidden_states = self.layer_norm(hidden_states) 586 | attention_output = self.EncDecAttention( 587 | normed_hidden_states, 588 | mask=attention_mask, 589 | key_value_states=key_value_states, 590 | position_bias=position_bias, 591 | layer_head_mask=layer_head_mask, 592 | past_key_value=past_key_value, 593 | use_cache=use_cache, 594 | query_length=query_length, 595 | output_attentions=output_attentions, 596 | ) 597 | layer_output = hidden_states + self.dropout(attention_output[0]) 598 | outputs = (layer_output,) + attention_output[1:] # add attentions if we output them 599 | return outputs 600 | 601 | 602 | class T5Block(nn.Module): 603 | def __init__(self, config, has_relative_attention_bias=False): 604 | super().__init__() 605 | self.is_decoder = config.is_decoder 606 | self.layer = nn.ModuleList() 607 | self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) 608 | if self.is_decoder: 609 | self.layer.append(T5LayerCrossAttention(config)) 610 | 611 | self.layer.append(T5LayerFF(config)) 612 | 613 | def forward( 614 | self, 615 | hidden_states, 616 | attention_mask=None, 617 | position_bias=None, 618 | encoder_hidden_states=None, 619 | encoder_attention_mask=None, 620 | encoder_decoder_position_bias=None, 621 | layer_head_mask=None, 622 | cross_attn_layer_head_mask=None, 623 | past_key_value=None, 624 | use_cache=False, 625 | output_attentions=False, 626 | return_dict=True, 627 | ): 628 | 629 | if past_key_value is not None: 630 | assert self.is_decoder, "Only decoder can use `past_key_values`" 631 | expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 632 | 633 | if len(past_key_value) != 
expected_num_past_key_values: 634 | raise ValueError( 635 | f"There should be {expected_num_past_key_values} past states. " 636 | f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" 637 | f"Got {len(past_key_value)} past key / value states" 638 | ) 639 | 640 | self_attn_past_key_value = past_key_value[:2] 641 | cross_attn_past_key_value = past_key_value[2:] 642 | else: 643 | self_attn_past_key_value, cross_attn_past_key_value = None, None 644 | 645 | self_attention_outputs = self.layer[0]( 646 | hidden_states, 647 | attention_mask=attention_mask, 648 | position_bias=position_bias, 649 | layer_head_mask=layer_head_mask, 650 | past_key_value=self_attn_past_key_value, 651 | use_cache=use_cache, 652 | output_attentions=output_attentions, 653 | ) 654 | hidden_states, present_key_value_state = self_attention_outputs[:2] 655 | attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights 656 | 657 | # clamp inf values to enable fp16 training 658 | # if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 659 | if torch.isinf(hidden_states).any(): 660 | print("a") 661 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 662 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 663 | 664 | do_cross_attention = self.is_decoder and encoder_hidden_states is not None 665 | if do_cross_attention: 666 | # the actual query length is unknown for cross attention 667 | # if using past key value states. Need to inject it here 668 | if present_key_value_state is not None: 669 | query_length = present_key_value_state[0].shape[2] 670 | else: 671 | query_length = None 672 | 673 | cross_attention_outputs = self.layer[1]( 674 | hidden_states, 675 | key_value_states=encoder_hidden_states, 676 | attention_mask=encoder_attention_mask, 677 | position_bias=encoder_decoder_position_bias, 678 | layer_head_mask=cross_attn_layer_head_mask, 679 | past_key_value=cross_attn_past_key_value, 680 | query_length=query_length, 681 | use_cache=use_cache, 682 | output_attentions=output_attentions, 683 | ) 684 | hidden_states = cross_attention_outputs[0] 685 | 686 | # clamp inf values to enable fp16 training 687 | # if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 688 | if torch.isinf(hidden_states).any(): 689 | print("b") 690 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 691 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 692 | 693 | # Combine self attn and cross attn key value states 694 | if present_key_value_state is not None: 695 | present_key_value_state = present_key_value_state + cross_attention_outputs[1] 696 | 697 | # Keep cross-attention outputs and relative position weights 698 | attention_outputs = attention_outputs + cross_attention_outputs[2:] 699 | 700 | # Apply Feed Forward layer 701 | hidden_states = self.layer[-1](hidden_states) 702 | 703 | # clamp inf values to enable fp16 training 704 | # if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): 705 | if torch.isinf(hidden_states).any(): 706 | print(f"c {torch.linalg.norm(hidden_states).item()}") 707 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 708 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 709 | 710 | outputs = (hidden_states,) 711 | 712 | if use_cache: 713 | outputs = outputs + (present_key_value_state,) + attention_outputs 714 | else: 715 | outputs = outputs + attention_outputs 716 
| 717 | return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 718 | 719 | 720 | class T5PreTrainedModel(PreTrainedModel): 721 | """ 722 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 723 | models. 724 | """ 725 | 726 | config_class = T5Config 727 | load_tf_weights = load_tf_weights_in_t5 728 | base_model_prefix = "transformer" 729 | is_parallelizable = True 730 | supports_gradient_checkpointing = True 731 | 732 | @property 733 | def dummy_inputs(self): 734 | input_ids = torch.tensor(DUMMY_INPUTS) 735 | input_mask = torch.tensor(DUMMY_MASK) 736 | dummy_inputs = { 737 | "decoder_input_ids": input_ids, 738 | "input_ids": input_ids, 739 | "decoder_attention_mask": input_mask, 740 | } 741 | return dummy_inputs 742 | 743 | def _init_weights(self, module): 744 | """Initialize the weights""" 745 | factor = self.config.initializer_factor # Used for testing weights initialization 746 | if isinstance(module, T5LayerNorm): 747 | module.weight.data.fill_(factor * 1.0) 748 | elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): 749 | # Mesh TensorFlow embeddings initialization 750 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 751 | module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) 752 | elif isinstance(module, T5DenseReluDense): 753 | # Mesh TensorFlow FF initialization 754 | # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 755 | # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 756 | module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) 757 | if hasattr(module.wi, "bias") and module.wi.bias is not None: 758 | module.wi.bias.data.zero_() 759 | module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) 760 | if hasattr(module.wo, "bias") and module.wo.bias is not None: 761 | module.wo.bias.data.zero_() 762 | elif isinstance(module, T5DenseGatedGeluDense): 763 | module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) 764 | if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: 765 | module.wi_0.bias.data.zero_() 766 | module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) 767 | if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: 768 | module.wi_1.bias.data.zero_() 769 | module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) 770 | if hasattr(module.wo, "bias") and module.wo.bias is not None: 771 | module.wo.bias.data.zero_() 772 | elif isinstance(module, T5Attention): 773 | # Mesh TensorFlow attention initialization to avoid scaling before softmax 774 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 775 | d_model = self.config.d_model 776 | key_value_proj_dim = self.config.d_kv 777 | n_heads = self.config.num_heads 778 | module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) 779 | module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) 780 | module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) 781 | module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * 
key_value_proj_dim) ** -0.5)) 782 | if module.has_relative_attention_bias: 783 | module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) 784 | 785 | def _set_gradient_checkpointing(self, module, value=False): 786 | if isinstance(module, (T5Attention, T5Stack)): 787 | module.gradient_checkpointing = value 788 | 789 | def _shift_right(self, input_ids): 790 | decoder_start_token_id = self.config.decoder_start_token_id 791 | pad_token_id = self.config.pad_token_id 792 | 793 | assert ( 794 | decoder_start_token_id is not None 795 | ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" 796 | 797 | # shift inputs to the right 798 | if is_torch_fx_proxy(input_ids): 799 | # Item assignment is not supported natively for proxies. 800 | shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) 801 | shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) 802 | else: 803 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 804 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 805 | shifted_input_ids[..., 0] = decoder_start_token_id 806 | 807 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 808 | # replace possible -100 values in labels by `pad_token_id` 809 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 810 | 811 | assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" 812 | 813 | return shifted_input_ids 814 | 815 | 816 | class T5Stack(T5PreTrainedModel): 817 | def __init__(self, config, embed_tokens=None): 818 | super().__init__(config) 819 | 820 | self.embed_tokens = embed_tokens 821 | self.is_decoder = config.is_decoder 822 | 823 | self.block = nn.ModuleList( 824 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] 825 | ) 826 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 827 | self.dropout = nn.Dropout(config.dropout_rate) 828 | 829 | # Initialize weights and apply final processing 830 | self.post_init() 831 | # Model parallel 832 | self.model_parallel = False 833 | self.device_map = None 834 | self.gradient_checkpointing = False 835 | 836 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 837 | def parallelize(self, device_map=None): 838 | # Check validity of device_map 839 | self.device_map = ( 840 | get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map 841 | ) 842 | assert_device_map(self.device_map, len(self.block)) 843 | self.model_parallel = True 844 | self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) 845 | self.last_device = "cuda:" + str(max(self.device_map.keys())) 846 | # Load onto devices 847 | for k, v in self.device_map.items(): 848 | for layer in v: 849 | cuda_device = "cuda:" + str(k) 850 | self.block[layer] = self.block[layer].to(cuda_device) 851 | 852 | # Set embed_tokens to first layer 853 | self.embed_tokens = self.embed_tokens.to(self.first_device) 854 | # Set final layer norm to last device 855 | self.final_layer_norm = self.final_layer_norm.to(self.last_device) 856 | 857 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 858 | def deparallelize(self): 859 | self.model_parallel = False 860 | self.device_map = None 861 | self.first_device = "cpu" 862 | self.last_device = "cpu" 863 | for i in 
range(len(self.block)): 864 | self.block[i] = self.block[i].to("cpu") 865 | self.embed_tokens = self.embed_tokens.to("cpu") 866 | self.final_layer_norm = self.final_layer_norm.to("cpu") 867 | torch.cuda.empty_cache() 868 | 869 | def get_input_embeddings(self): 870 | return self.embed_tokens 871 | 872 | def set_input_embeddings(self, new_embeddings): 873 | self.embed_tokens = new_embeddings 874 | 875 | def forward( 876 | self, 877 | input_ids=None, 878 | attention_mask=None, 879 | encoder_hidden_states=None, 880 | encoder_attention_mask=None, 881 | inputs_embeds=None, 882 | head_mask=None, 883 | cross_attn_head_mask=None, 884 | past_key_values=None, 885 | use_cache=None, 886 | output_attentions=None, 887 | output_hidden_states=None, 888 | return_dict=None, 889 | ): 890 | # Model parallel 891 | if self.model_parallel: 892 | torch.cuda.set_device(self.first_device) 893 | self.embed_tokens = self.embed_tokens.to(self.first_device) 894 | use_cache = use_cache if use_cache is not None else self.config.use_cache 895 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 896 | output_hidden_states = ( 897 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 898 | ) 899 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 900 | 901 | if input_ids is not None and inputs_embeds is not None: 902 | err_msg_prefix = "decoder_" if self.is_decoder else "" 903 | raise ValueError( 904 | f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" 905 | ) 906 | elif input_ids is not None: 907 | input_shape = input_ids.size() 908 | input_ids = input_ids.view(-1, input_shape[-1]) 909 | elif inputs_embeds is not None: 910 | input_shape = inputs_embeds.size()[:-1] 911 | else: 912 | err_msg_prefix = "decoder_" if self.is_decoder else "" 913 | raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") 914 | 915 | if inputs_embeds is None: 916 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 917 | inputs_embeds = self.embed_tokens(input_ids) 918 | 919 | batch_size, seq_length = input_shape 920 | 921 | # required mask seq length can be calculated via length of past 922 | mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length 923 | 924 | if use_cache is True: 925 | assert self.is_decoder, f":obj:`use_cache` can only be set to `True` if {self} is used as a decoder" 926 | 927 | if attention_mask is None: 928 | attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) 929 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 930 | encoder_seq_length = encoder_hidden_states.shape[1] 931 | encoder_attention_mask = torch.ones( 932 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 933 | ) 934 | 935 | # initialize past_key_values with `None` if past does not exist 936 | if past_key_values is None: 937 | past_key_values = [None] * len(self.block) 938 | 939 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 940 | # ourselves in which case we just need to make it broadcastable to all heads. 
941 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) 942 | 943 | # If a 2D or 3D attention mask is provided for the cross-attention 944 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 945 | if self.is_decoder and encoder_hidden_states is not None: 946 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 947 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 948 | if encoder_attention_mask is None: 949 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) 950 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 951 | else: 952 | encoder_extended_attention_mask = None 953 | 954 | # Prepare head mask if needed 955 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 956 | cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) 957 | present_key_value_states = () if use_cache else None 958 | all_hidden_states = () if output_hidden_states else None 959 | all_attentions = () if output_attentions else None 960 | all_cross_attentions = () if (output_attentions and self.is_decoder) else None 961 | position_bias = None 962 | encoder_decoder_position_bias = None 963 | 964 | hidden_states = self.dropout(inputs_embeds) 965 | 966 | for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): 967 | layer_head_mask = head_mask[i] 968 | cross_attn_layer_head_mask = cross_attn_head_mask[i] 969 | # Model parallel 970 | if self.model_parallel: 971 | torch.cuda.set_device(hidden_states.device) 972 | # Ensure that attention_mask is always on the same device as hidden_states 973 | if attention_mask is not None: 974 | attention_mask = attention_mask.to(hidden_states.device) 975 | if position_bias is not None: 976 | position_bias = position_bias.to(hidden_states.device) 977 | if encoder_hidden_states is not None: 978 | encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) 979 | if encoder_extended_attention_mask is not None: 980 | encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) 981 | if encoder_decoder_position_bias is not None: 982 | encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) 983 | if layer_head_mask is not None: 984 | layer_head_mask = layer_head_mask.to(hidden_states.device) 985 | if cross_attn_layer_head_mask is not None: 986 | cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) 987 | if output_hidden_states: 988 | all_hidden_states = all_hidden_states + (hidden_states,) 989 | 990 | if self.gradient_checkpointing and self.training: 991 | if use_cache: 992 | logger.warn( 993 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
994 | ) 995 | use_cache = False 996 | 997 | def create_custom_forward(module): 998 | def custom_forward(*inputs): 999 | return tuple(module(*inputs, use_cache, output_attentions)) 1000 | 1001 | return custom_forward 1002 | 1003 | layer_outputs = checkpoint( 1004 | create_custom_forward(layer_module), 1005 | hidden_states, 1006 | extended_attention_mask, 1007 | position_bias, 1008 | encoder_hidden_states, 1009 | encoder_extended_attention_mask, 1010 | encoder_decoder_position_bias, 1011 | layer_head_mask, 1012 | cross_attn_layer_head_mask, 1013 | None, # past_key_value is always None with gradient checkpointing 1014 | ) 1015 | else: 1016 | layer_outputs = layer_module( 1017 | hidden_states, 1018 | attention_mask=extended_attention_mask, 1019 | position_bias=position_bias, 1020 | encoder_hidden_states=encoder_hidden_states, 1021 | encoder_attention_mask=encoder_extended_attention_mask, 1022 | encoder_decoder_position_bias=encoder_decoder_position_bias, 1023 | layer_head_mask=layer_head_mask, 1024 | cross_attn_layer_head_mask=cross_attn_layer_head_mask, 1025 | past_key_value=past_key_value, 1026 | use_cache=use_cache, 1027 | output_attentions=output_attentions, 1028 | ) 1029 | 1030 | # layer_outputs is a tuple with: 1031 | # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) 1032 | if use_cache is False: 1033 | layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] 1034 | 1035 | hidden_states, present_key_value_state = layer_outputs[:2] 1036 | 1037 | # We share the position biases between the layers - the first layer store them 1038 | # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), 1039 | # (cross-attention position bias), (cross-attention weights) 1040 | position_bias = layer_outputs[2] 1041 | if self.is_decoder and encoder_hidden_states is not None: 1042 | encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] 1043 | # append next layer key value states 1044 | if use_cache: 1045 | present_key_value_states = present_key_value_states + (present_key_value_state,) 1046 | 1047 | if output_attentions: 1048 | all_attentions = all_attentions + (layer_outputs[3],) 1049 | if self.is_decoder: 1050 | all_cross_attentions = all_cross_attentions + (layer_outputs[5],) 1051 | 1052 | # Model Parallel: If it's the last layer for that device, put things on the next device 1053 | if self.model_parallel: 1054 | for k, v in self.device_map.items(): 1055 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 1056 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 1057 | 1058 | hidden_states = self.final_layer_norm(hidden_states) 1059 | hidden_states = self.dropout(hidden_states) 1060 | 1061 | # Add last layer 1062 | if output_hidden_states: 1063 | all_hidden_states = all_hidden_states + (hidden_states,) 1064 | 1065 | if not return_dict: 1066 | return tuple( 1067 | v 1068 | for v in [ 1069 | hidden_states, 1070 | present_key_value_states, 1071 | all_hidden_states, 1072 | all_attentions, 1073 | all_cross_attentions, 1074 | ] 1075 | if v is not None 1076 | ) 1077 | return BaseModelOutputWithPastAndCrossAttentions( 1078 | last_hidden_state=hidden_states, 1079 | past_key_values=present_key_value_states, 1080 | hidden_states=all_hidden_states, 1081 | attentions=all_attentions, 1082 | cross_attentions=all_cross_attentions, 1083 | ) 1084 | 1085 | 1086 | T5_START_DOCSTRING = r""" 1087 | 1088 | The T5 model was proposed in 
[Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, 1089 | Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a text-to-text 1090 | denoising generative setting. 1091 | 1092 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic 1093 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 1094 | pruning heads etc.) 1095 | 1096 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) 1097 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 1098 | general usage and behavior. 1099 | 1100 | Parameters: 1101 | config ([`T5Config`]): Model configuration class with all the parameters of the model. 1102 | Initializing with a config file does not load the weights associated with the model, only the 1103 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model 1104 | weights. 1105 | """ 1106 | 1107 | T5_INPUTS_DOCSTRING = r""" 1108 | Args: 1109 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1110 | Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you 1111 | should be able to pad the inputs on both the right and the left. 1112 | 1113 | Indices can be obtained using [`T5Tokenizer`]. See 1114 | [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for 1115 | detail. 1116 | 1117 | [What are input IDs?](../glossary#input-ids) 1118 | 1119 | To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). 1120 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 1121 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1122 | 1123 | - 1 for tokens that are **not masked**, 1124 | - 0 for tokens that are **masked**. 1125 | 1126 | [What are attention masks?](../glossary#attention-mask) 1127 | decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 1128 | Indices of decoder input sequence tokens in the vocabulary. 1129 | 1130 | Indices can be obtained using [`T5Tokenizer`]. See 1131 | [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for 1132 | details. 1133 | 1134 | [What are decoder input IDs?](../glossary#decoder-input-ids) 1135 | 1136 | T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If 1137 | `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 1138 | `past_key_values`). 1139 | 1140 | To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training). 1141 | decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 1142 | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will 1143 | also be used by default. 1144 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 1145 | Mask to nullify selected heads of the self-attention modules in the encoder. 
Mask values selected in `[0, 1]`: 1146 | 1147 | - 1 indicates the head is **not masked**, 1148 | - 0 indicates the head is **masked**. 1149 | 1150 | decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 1151 | Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, 1]`: 1152 | 1153 | - 1 indicates the head is **not masked**, 1154 | - 0 indicates the head is **masked**. 1155 | 1156 | cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 1157 | Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in 1158 | `[0, 1]`: 1159 | 1160 | - 1 indicates the head is **not masked**, 1161 | - 0 indicates the head is **masked**. 1162 | 1163 | encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): 1164 | Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: 1165 | *attentions*) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a 1166 | sequence of hidden states at the output of the last layer of the encoder. Used in the cross-attention of 1167 | the decoder. 1168 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 1169 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 1170 | 1171 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` 1172 | (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` 1173 | instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 1174 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1175 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 1176 | This is useful if you want more control over how to convert `input_ids` indices into associated 1177 | vectors than the model's internal embedding lookup matrix. 1178 | decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): 1179 | Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded 1180 | representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` 1181 | have to be input (see `past_key_values`). This is useful if you want more control over how to convert 1182 | `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 1183 | 1184 | If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` 1185 | takes the value of `inputs_embeds`. 1186 | 1187 | use_cache (`bool`, *optional*): 1188 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up 1189 | decoding (see `past_key_values`). 1190 | 1191 | output_attentions (`bool`, *optional*): 1192 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 1193 | tensors for more detail. 1194 | output_hidden_states (`bool`, *optional*): 1195 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 1196 | more detail. 
1197 | return_dict (`bool`, *optional*): 1198 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 1199 | """ 1200 | 1201 | T5_ENCODER_INPUTS_DOCSTRING = r""" 1202 | Args: 1203 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1204 | Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you 1205 | should be able to pad the inputs on both the right and the left. 1206 | 1207 | Indices can be obtained using [`T5Tokenizer`]. See 1208 | [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for 1209 | detail. 1210 | 1211 | To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). 1212 | attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 1213 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1214 | 1215 | - 1 for tokens that are **not masked**, 1216 | - 0 for tokens that are **masked**. 1217 | 1218 | [What are attention masks?](../glossary#attention-mask) 1219 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 1220 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 1221 | 1222 | - 1 indicates the head is **not masked**, 1223 | - 0 indicates the head is **masked**. 1224 | 1225 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1226 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 1227 | This is useful if you want more control over how to convert `input_ids` indices into associated 1228 | vectors than the model's internal embedding lookup matrix. 1229 | output_attentions (`bool`, *optional*): 1230 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 1231 | tensors for more detail. 1232 | output_hidden_states (`bool`, *optional*): 1233 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 1234 | more detail. 1235 | return_dict (`bool`, *optional*): 1236 | Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. 1237 | """ 1238 | 1239 | # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 1240 | __HEAD_MASK_WARNING_MSG = """ 1241 | The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, 1242 | `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. 1243 | If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, 1244 | num_heads)`. 
1245 | """ 1246 | 1247 | 1248 | @add_start_docstrings( 1249 | "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.", 1250 | T5_START_DOCSTRING, 1251 | ) 1252 | class T5Model(T5PreTrainedModel): 1253 | _keys_to_ignore_on_load_missing = [ 1254 | r"encoder\.embed_tokens\.weight", 1255 | r"decoder\.embed_tokens\.weight", 1256 | ] 1257 | _keys_to_ignore_on_load_unexpected = [ 1258 | r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", 1259 | ] 1260 | 1261 | def __init__(self, config: T5Config): 1262 | super().__init__(config) 1263 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 1264 | 1265 | encoder_config = copy.deepcopy(config) 1266 | encoder_config.is_decoder = False 1267 | encoder_config.use_cache = False 1268 | encoder_config.is_encoder_decoder = False 1269 | self.encoder = T5Stack(encoder_config, self.shared) 1270 | 1271 | decoder_config = copy.deepcopy(config) 1272 | decoder_config.is_decoder = True 1273 | decoder_config.is_encoder_decoder = False 1274 | decoder_config.num_layers = config.num_decoder_layers 1275 | self.decoder = T5Stack(decoder_config, self.shared) 1276 | 1277 | # Initialize weights and apply final processing 1278 | self.post_init() 1279 | 1280 | # Model parallel 1281 | self.model_parallel = False 1282 | self.device_map = None 1283 | 1284 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 1285 | def parallelize(self, device_map=None): 1286 | self.device_map = ( 1287 | get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) 1288 | if device_map is None 1289 | else device_map 1290 | ) 1291 | assert_device_map(self.device_map, len(self.encoder.block)) 1292 | self.encoder.parallelize(self.device_map) 1293 | self.decoder.parallelize(self.device_map) 1294 | self.model_parallel = True 1295 | 1296 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 1297 | def deparallelize(self): 1298 | self.encoder.deparallelize() 1299 | self.decoder.deparallelize() 1300 | self.encoder = self.encoder.to("cpu") 1301 | self.decoder = self.decoder.to("cpu") 1302 | self.model_parallel = False 1303 | self.device_map = None 1304 | torch.cuda.empty_cache() 1305 | 1306 | def get_input_embeddings(self): 1307 | return self.shared 1308 | 1309 | def set_input_embeddings(self, new_embeddings): 1310 | self.shared = new_embeddings 1311 | self.encoder.set_input_embeddings(new_embeddings) 1312 | self.decoder.set_input_embeddings(new_embeddings) 1313 | 1314 | def get_encoder(self): 1315 | return self.encoder 1316 | 1317 | def get_decoder(self): 1318 | return self.decoder 1319 | 1320 | def _prune_heads(self, heads_to_prune): 1321 | """ 1322 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 1323 | class PreTrainedModel 1324 | """ 1325 | for layer, heads in heads_to_prune.items(): 1326 | self.encoder.layer[layer].attention.prune_heads(heads) 1327 | 1328 | @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) 1329 | @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) 1330 | def forward( 1331 | self, 1332 | input_ids=None, 1333 | attention_mask=None, 1334 | decoder_input_ids=None, 1335 | decoder_attention_mask=None, 1336 | head_mask=None, 1337 | decoder_head_mask=None, 1338 | cross_attn_head_mask=None, 1339 | encoder_outputs=None, 1340 | past_key_values=None, 1341 | inputs_embeds=None, 1342 | decoder_inputs_embeds=None, 1343 | use_cache=None, 1344 | output_attentions=None, 1345 | output_hidden_states=None, 1346 | return_dict=None, 1347 | ): 1348 | r""" 1349 | Returns: 1350 | 1351 | Example: 1352 | 1353 | ```python 1354 | >>> from transformers import T5Tokenizer, T5Model 1355 | 1356 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') 1357 | >>> model = T5Model.from_pretrained('t5-small') 1358 | 1359 | >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 1360 | >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 1361 | 1362 | >>> # forward pass 1363 | >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) 1364 | >>> last_hidden_states = outputs.last_hidden_state 1365 | ```""" 1366 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1367 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1368 | 1369 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 1370 | if head_mask is not None and decoder_head_mask is None: 1371 | if self.config.num_layers == self.config.num_decoder_layers: 1372 | warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) 1373 | decoder_head_mask = head_mask 1374 | 1375 | # Encode if needed (training, first prediction pass) 1376 | if encoder_outputs is None: 1377 | encoder_outputs = self.encoder( 1378 | input_ids=input_ids, 1379 | attention_mask=attention_mask, 1380 | inputs_embeds=inputs_embeds, 1381 | head_mask=head_mask, 1382 | output_attentions=output_attentions, 1383 | output_hidden_states=output_hidden_states, 1384 | return_dict=return_dict, 1385 | ) 1386 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 1387 | encoder_outputs = BaseModelOutput( 1388 | last_hidden_state=encoder_outputs[0], 1389 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1390 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1391 | ) 1392 | 1393 | hidden_states = encoder_outputs[0] 1394 | 1395 | if self.model_parallel: 1396 | torch.cuda.set_device(self.decoder.first_device) 1397 | # Set device for model parallelism 1398 | if self.model_parallel: 1399 | torch.cuda.set_device(self.decoder.first_device) 1400 | hidden_states = hidden_states.to(self.decoder.first_device) 1401 | if decoder_input_ids is not None: 1402 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) 1403 | if attention_mask is not None: 1404 | attention_mask = attention_mask.to(self.decoder.first_device) 1405 | if decoder_attention_mask is not None: 1406 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) 1407 | 1408 | # Decode 1409 
| decoder_outputs = self.decoder( 1410 | input_ids=decoder_input_ids, 1411 | attention_mask=decoder_attention_mask, 1412 | inputs_embeds=decoder_inputs_embeds, 1413 | past_key_values=past_key_values, 1414 | encoder_hidden_states=hidden_states, 1415 | encoder_attention_mask=attention_mask, 1416 | head_mask=decoder_head_mask, 1417 | cross_attn_head_mask=cross_attn_head_mask, 1418 | use_cache=use_cache, 1419 | output_attentions=output_attentions, 1420 | output_hidden_states=output_hidden_states, 1421 | return_dict=return_dict, 1422 | ) 1423 | 1424 | if not return_dict: 1425 | return decoder_outputs + encoder_outputs 1426 | 1427 | return Seq2SeqModelOutput( 1428 | last_hidden_state=decoder_outputs.last_hidden_state, 1429 | past_key_values=decoder_outputs.past_key_values, 1430 | decoder_hidden_states=decoder_outputs.hidden_states, 1431 | decoder_attentions=decoder_outputs.attentions, 1432 | cross_attentions=decoder_outputs.cross_attentions, 1433 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 1434 | encoder_hidden_states=encoder_outputs.hidden_states, 1435 | encoder_attentions=encoder_outputs.attentions, 1436 | ) 1437 | 1438 | 1439 | @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING) 1440 | class T5ForConditionalGeneration(T5PreTrainedModel): 1441 | _keys_to_ignore_on_load_missing = [ 1442 | r"encoder\.embed_tokens\.weight", 1443 | r"decoder\.embed_tokens\.weight", 1444 | r"lm_head\.weight", 1445 | ] 1446 | _keys_to_ignore_on_load_unexpected = [ 1447 | r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", 1448 | ] 1449 | 1450 | def __init__(self, config): 1451 | super().__init__(config) 1452 | self.model_dim = config.d_model 1453 | 1454 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 1455 | 1456 | encoder_config = copy.deepcopy(config) 1457 | encoder_config.is_decoder = False 1458 | encoder_config.use_cache = False 1459 | encoder_config.is_encoder_decoder = False 1460 | self.encoder = T5Stack(encoder_config, self.shared) 1461 | 1462 | decoder_config = copy.deepcopy(config) 1463 | decoder_config.is_decoder = True 1464 | decoder_config.is_encoder_decoder = False 1465 | decoder_config.num_layers = config.num_decoder_layers 1466 | self.decoder = T5Stack(decoder_config, self.shared) 1467 | 1468 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 1469 | 1470 | # Initialize weights and apply final processing 1471 | self.post_init() 1472 | 1473 | # Model parallel 1474 | self.model_parallel = False 1475 | self.device_map = None 1476 | 1477 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 1478 | def parallelize(self, device_map=None): 1479 | self.device_map = ( 1480 | get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) 1481 | if device_map is None 1482 | else device_map 1483 | ) 1484 | assert_device_map(self.device_map, len(self.encoder.block)) 1485 | self.encoder.parallelize(self.device_map) 1486 | self.decoder.parallelize(self.device_map) 1487 | self.lm_head = self.lm_head.to(self.decoder.first_device) 1488 | self.model_parallel = True 1489 | 1490 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 1491 | def deparallelize(self): 1492 | self.encoder.deparallelize() 1493 | self.decoder.deparallelize() 1494 | self.encoder = self.encoder.to("cpu") 1495 | self.decoder = self.decoder.to("cpu") 1496 | self.lm_head = self.lm_head.to("cpu") 1497 | self.model_parallel = False 1498 | self.device_map = None 1499 | torch.cuda.empty_cache() 1500 | 1501 | def 
get_input_embeddings(self): 1502 | return self.shared 1503 | 1504 | def set_input_embeddings(self, new_embeddings): 1505 | self.shared = new_embeddings 1506 | self.encoder.set_input_embeddings(new_embeddings) 1507 | self.decoder.set_input_embeddings(new_embeddings) 1508 | 1509 | def set_output_embeddings(self, new_embeddings): 1510 | self.lm_head = new_embeddings 1511 | 1512 | def get_output_embeddings(self): 1513 | return self.lm_head 1514 | 1515 | def get_encoder(self): 1516 | return self.encoder 1517 | 1518 | def get_decoder(self): 1519 | return self.decoder 1520 | 1521 | @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) 1522 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1523 | def forward( 1524 | self, 1525 | input_ids=None, 1526 | attention_mask=None, 1527 | decoder_input_ids=None, 1528 | decoder_attention_mask=None, 1529 | head_mask=None, 1530 | decoder_head_mask=None, 1531 | cross_attn_head_mask=None, 1532 | encoder_outputs=None, 1533 | past_key_values=None, 1534 | inputs_embeds=None, 1535 | decoder_inputs_embeds=None, 1536 | labels=None, 1537 | use_cache=None, 1538 | output_attentions=None, 1539 | output_hidden_states=None, 1540 | return_dict=None, 1541 | ): 1542 | r""" 1543 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1544 | Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for 1545 | labels in `[0, ..., config.vocab_size]` 1546 | 1547 | Returns: 1548 | 1549 | Examples: 1550 | 1551 | ```python 1552 | >>> from transformers import T5Tokenizer, T5ForConditionalGeneration 1553 | 1554 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') 1555 | >>> model = T5ForConditionalGeneration.from_pretrained('t5-small') 1556 | 1557 | >>> # training 1558 | >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids 1559 | >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids 1560 | >>> outputs = model(input_ids=input_ids, labels=labels) 1561 | >>> loss = outputs.loss 1562 | >>> logits = outputs.logits 1563 | 1564 | >>> # inference 1565 | >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 1566 | >>> outputs = model.generate(input_ids) 1567 | >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) 1568 | >>> # studies have shown that owning a dog is good for you.
1569 | ```""" 1570 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1571 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1572 | 1573 | # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask 1574 | if head_mask is not None and decoder_head_mask is None: 1575 | if self.config.num_layers == self.config.num_decoder_layers: 1576 | warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) 1577 | decoder_head_mask = head_mask 1578 | 1579 | # Encode if needed (training, first prediction pass) 1580 | if encoder_outputs is None: 1581 | # Convert encoder inputs in embeddings if needed 1582 | encoder_outputs = self.encoder( 1583 | input_ids=input_ids, 1584 | attention_mask=attention_mask, 1585 | inputs_embeds=inputs_embeds, 1586 | head_mask=head_mask, 1587 | output_attentions=output_attentions, 1588 | output_hidden_states=output_hidden_states, 1589 | return_dict=return_dict, 1590 | ) 1591 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 1592 | encoder_outputs = BaseModelOutput( 1593 | last_hidden_state=encoder_outputs[0], 1594 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1595 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1596 | ) 1597 | 1598 | hidden_states = encoder_outputs[0] 1599 | 1600 | if self.model_parallel: 1601 | torch.cuda.set_device(self.decoder.first_device) 1602 | 1603 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 1604 | # get decoder inputs from shifting lm labels to the right 1605 | decoder_input_ids = self._shift_right(labels) 1606 | 1607 | # Set device for model parallelism 1608 | if self.model_parallel: 1609 | torch.cuda.set_device(self.decoder.first_device) 1610 | hidden_states = hidden_states.to(self.decoder.first_device) 1611 | if decoder_input_ids is not None: 1612 | decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) 1613 | if attention_mask is not None: 1614 | attention_mask = attention_mask.to(self.decoder.first_device) 1615 | if decoder_attention_mask is not None: 1616 | decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) 1617 | 1618 | # Decode 1619 | decoder_outputs = self.decoder( 1620 | input_ids=decoder_input_ids, 1621 | attention_mask=decoder_attention_mask, 1622 | inputs_embeds=decoder_inputs_embeds, 1623 | past_key_values=past_key_values, 1624 | encoder_hidden_states=hidden_states, 1625 | encoder_attention_mask=attention_mask, 1626 | head_mask=decoder_head_mask, 1627 | cross_attn_head_mask=cross_attn_head_mask, 1628 | use_cache=use_cache, 1629 | output_attentions=output_attentions, 1630 | output_hidden_states=output_hidden_states, 1631 | return_dict=return_dict, 1632 | ) 1633 | 1634 | sequence_output = decoder_outputs[0] 1635 | 1636 | # Set device for model parallelism 1637 | if self.model_parallel: 1638 | torch.cuda.set_device(self.encoder.first_device) 1639 | self.lm_head = self.lm_head.to(self.encoder.first_device) 1640 | sequence_output = sequence_output.to(self.lm_head.weight.device) 1641 | 1642 | if self.config.tie_word_embeddings: 1643 | # Rescale output before projecting on vocab 1644 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 1645 | sequence_output = sequence_output * (self.model_dim**-0.5) 1646 | 1647 | lm_logits = self.lm_head(sequence_output) 1648 | 1649 | loss = None 1650 | if labels is not None: 1651 | 
loss_fct = CrossEntropyLoss(ignore_index=-100) 1652 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 1653 | # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 1654 | 1655 | if not return_dict: 1656 | output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs 1657 | return ((loss,) + output) if loss is not None else output 1658 | 1659 | return Seq2SeqLMOutput( 1660 | loss=loss, 1661 | logits=lm_logits, 1662 | past_key_values=decoder_outputs.past_key_values, 1663 | decoder_hidden_states=decoder_outputs.hidden_states, 1664 | decoder_attentions=decoder_outputs.attentions, 1665 | cross_attentions=decoder_outputs.cross_attentions, 1666 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 1667 | encoder_hidden_states=encoder_outputs.hidden_states, 1668 | encoder_attentions=encoder_outputs.attentions, 1669 | ) 1670 | 1671 | def prepare_inputs_for_generation( 1672 | self, 1673 | input_ids, 1674 | past=None, 1675 | attention_mask=None, 1676 | head_mask=None, 1677 | decoder_head_mask=None, 1678 | cross_attn_head_mask=None, 1679 | use_cache=None, 1680 | encoder_outputs=None, 1681 | **kwargs, 1682 | ): 1683 | 1684 | # cut decoder_input_ids if past is used 1685 | if past is not None: 1686 | input_ids = input_ids[:, -1:] 1687 | 1688 | return { 1689 | "decoder_input_ids": input_ids, 1690 | "past_key_values": past, 1691 | "encoder_outputs": encoder_outputs, 1692 | "attention_mask": attention_mask, 1693 | "head_mask": head_mask, 1694 | "decoder_head_mask": decoder_head_mask, 1695 | "cross_attn_head_mask": cross_attn_head_mask, 1696 | "use_cache": use_cache, 1697 | } 1698 | 1699 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1700 | return self._shift_right(labels) 1701 | 1702 | def _reorder_cache(self, past, beam_idx): 1703 | # if decoder past is not included in output 1704 | # speedy decoding is disabled and no need to reorder 1705 | if past is None: 1706 | logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") 1707 | return past 1708 | 1709 | reordered_decoder_past = () 1710 | for layer_past_states in past: 1711 | # get the correct batch idx from layer past batch dim 1712 | # batch dim of `past` is at 2nd position 1713 | reordered_layer_past_states = () 1714 | for layer_past_state in layer_past_states: 1715 | # need to set correct `past` for each of the four key / value states 1716 | reordered_layer_past_states = reordered_layer_past_states + ( 1717 | layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), 1718 | ) 1719 | 1720 | assert reordered_layer_past_states[0].shape == layer_past_states[0].shape 1721 | assert len(reordered_layer_past_states) == len(layer_past_states) 1722 | 1723 | reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) 1724 | return reordered_decoder_past 1725 | 1726 | 1727 | @add_start_docstrings( 1728 | "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", 1729 | T5_START_DOCSTRING, 1730 | ) 1731 | class T5EncoderModel(T5PreTrainedModel): 1732 | authorized_missing_keys = [ 1733 | r"encoder\.embed_tokens\.weight", 1734 | ] 1735 | 1736 | def __init__(self, config: T5Config): 1737 | super().__init__(config) 1738 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 1739 | 1740 | encoder_config = copy.deepcopy(config) 1741 | encoder_config.use_cache = False 1742 | encoder_config.is_encoder_decoder = 
False 1743 | self.encoder = T5Stack(encoder_config, self.shared) 1744 | 1745 | # Initialize weights and apply final processing 1746 | self.post_init() 1747 | 1748 | # Model parallel 1749 | self.model_parallel = False 1750 | self.device_map = None 1751 | 1752 | @add_start_docstrings(PARALLELIZE_DOCSTRING) 1753 | def parallelize(self, device_map=None): 1754 | self.device_map = ( 1755 | get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) 1756 | if device_map is None 1757 | else device_map 1758 | ) 1759 | assert_device_map(self.device_map, len(self.encoder.block)) 1760 | self.encoder.parallelize(self.device_map) 1761 | self.model_parallel = True 1762 | 1763 | @add_start_docstrings(DEPARALLELIZE_DOCSTRING) 1764 | def deparallelize(self): 1765 | self.encoder.deparallelize() 1766 | self.encoder = self.encoder.to("cpu") 1767 | self.model_parallel = False 1768 | self.device_map = None 1769 | torch.cuda.empty_cache() 1770 | 1771 | def get_input_embeddings(self): 1772 | return self.shared 1773 | 1774 | def set_input_embeddings(self, new_embeddings): 1775 | self.shared = new_embeddings 1776 | self.encoder.set_input_embeddings(new_embeddings) 1777 | 1778 | def get_encoder(self): 1779 | return self.encoder 1780 | 1781 | def _prune_heads(self, heads_to_prune): 1782 | """ 1783 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 1784 | class PreTrainedModel 1785 | """ 1786 | for layer, heads in heads_to_prune.items(): 1787 | self.encoder.layer[layer].attention.prune_heads(heads) 1788 | 1789 | @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) 1790 | @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC) 1791 | def forward( 1792 | self, 1793 | input_ids=None, 1794 | attention_mask=None, 1795 | head_mask=None, 1796 | inputs_embeds=None, 1797 | output_attentions=None, 1798 | output_hidden_states=None, 1799 | return_dict=None, 1800 | ): 1801 | r""" 1802 | Returns: 1803 | 1804 | Example: 1805 | 1806 | ```python 1807 | >>> from transformers import T5Tokenizer, T5EncoderModel 1808 | >>> tokenizer = T5Tokenizer.from_pretrained('t5-small') 1809 | >>> model = T5EncoderModel.from_pretrained('t5-small') 1810 | >>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1 1811 | >>> outputs = model(input_ids=input_ids) 1812 | >>> last_hidden_states = outputs.last_hidden_state 1813 | ```""" 1814 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1815 | 1816 | encoder_outputs = self.encoder( 1817 | input_ids=input_ids, 1818 | attention_mask=attention_mask, 1819 | inputs_embeds=inputs_embeds, 1820 | head_mask=head_mask, 1821 | output_attentions=output_attentions, 1822 | output_hidden_states=output_hidden_states, 1823 | return_dict=return_dict, 1824 | ) 1825 | 1826 | return encoder_outputs -------------------------------------------------------------------------------- /FiD/src/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | class Options(): 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | self.initialize_parser() 9 | 10 | def add_eval_options(self): 11 | self.parser.add_argument('--write_crossattention_scores', action='store_true', 12 | help='take relevance cross-attention scores from model') 13 | 
self.parser.add_argument('--bfloat16', action='store_true', 14 | help='Run model inference in bfloat16') 15 | self.parser.add_argument('--stride', type=int, default=1) 16 | self.parser.add_argument('--n_rerank_passages', type=int, default=1) 17 | self.parser.add_argument('--sort_key', type=str, default='none') 18 | self.parser.add_argument('--n_passes', type=int, default=1) 19 | 20 | def add_reader_options(self): 21 | self.parser.add_argument('--eval_data', type=str, default='none', help='path of eval data') 22 | self.parser.add_argument('--text_maxlength', type=int, default=150, 23 | help='maximum number of tokens in text segments (query+passage)') 24 | self.parser.add_argument('--answer_maxlength', type=int, default=-1, 25 | help='maximum number of tokens to generate') 26 | self.parser.add_argument('--n_passages', type=int, default=1) 27 | 28 | 29 | def initialize_parser(self): 30 | # basic parameters 31 | self.parser.add_argument('--runfile_path', type=str, help='.trec runfiles are saved here') 32 | self.parser.add_argument('--model_path', type=str, default='none', help='path for model') 33 | 34 | # dataset parameters 35 | self.parser.add_argument("--batch_size", default=1, type=int, 36 | help="Batch size per GPU/CPU") 37 | 38 | 39 | def parse(self): 40 | opt = self.parser.parse_args() 41 | return opt 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 
40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | Copyright 2019-2021 Pyserini authors 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 
-------------------------------------------------------------------------------- /LiT5-Distill.sh: -------------------------------------------------------------------------------- 1 | # Uncomment the first-stage you wish to test 2 | #firststage=spladepp 3 | firststage=bm25 4 | 5 | # Uncomment the model you wish to test 6 | model=castorini/LiT5-Distill-base; batchsize=260; windowsize=20 7 | #model=castorini/LiT5-Distill-base-v2; batchsize=68; windowsize=100 8 | 9 | #model=castorini/LiT5-Distill-large; batchsize=120; windowsize=20 10 | #model=castorini/LiT5-Distill-large-v2; batchsize=22; windowsize=100 11 | 12 | #model=castorini/LiT5-Distill-xl; batchsize=36; windowsize=20 13 | #model=castorini/LiT5-Distill-xl-v2; batchsize=12; windowsize=100 14 | 15 | total_n_rerank_passages=100 16 | stride=10 17 | n_passes=1 18 | 19 | for topics in 'dl19' 'dl20'; do 20 | runfile_path="runs/run.${topics}_${firststage}_${model//\//}" 21 | 22 | python3 FiD/LiT5-Distill.py \ 23 | --model_path $model \ 24 | --eval_data "topics/msmarco-${topics}-${firststage}.jsonl" \ 25 | --batch_size $batchsize \ 26 | --n_passages $windowsize \ 27 | --runfile_path $runfile_path \ 28 | --text_maxlength 150 \ 29 | --answer_maxlength 140 \ 30 | --stride $stride \ 31 | --n_rerank_passages $total_n_rerank_passages \ 32 | --bfloat16 \ 33 | --n_passes $n_passes 34 | 35 | for ((i = 0 ; i < n_passes ; i++ )); do 36 | python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 ${topics}-passage ${runfile_path}.${i}.trec 37 | done 38 | done 39 | -------------------------------------------------------------------------------- /LiT5-Score.sh: -------------------------------------------------------------------------------- 1 | sort_key=normswoquery 2 | 3 | # Uncomment the first-stage you wish to test 4 | #firststage=spladepp 5 | firststage=bm25 6 | 7 | # Uncomment the model you wish to test 8 | model=castorini/LiT5-Score-base; batchsize=30 9 | #model=castorini/LiT5-Score-large; batchsize=10 10 | #model=castorini/LiT5-Score-xl; batchsize=4 11 | 12 | for topics in 'dl19' 'dl20'; do 13 | runfile_path="runs/run.${topics}_${firststage}_${model//\//}.trec" 14 | 15 | python3 FiD/LiT5-Score.py \ 16 | --model_path $model \ 17 | --eval_data "topics/msmarco-${topics}-${firststage}.jsonl" \ 18 | --batch_size $batchsize \ 19 | --n_passages 100 \ 20 | --runfile_path $runfile_path \ 21 | --text_maxlength 150 \ 22 | --answer_maxlength 20 \ 23 | --write_crossattention_scores \ 24 | --sort_key $sort_key \ 25 | --bfloat16 26 | 27 | python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 ${topics}-passage $runfile_path 28 | done 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LiT5 (List-in-T5) Reranking 2 | 3 | ## RankLLM 4 | We have integrated LiT5 into [RankLLM](https://github.com/castorini/rank_llm), which is actively maintained and includes additional improvements. We highly recommend using RankLLM. 5 | 6 | ## 📟 Instructions 7 | 8 | We provide the scripts and data necessary to reproduce reranking results for [LiT5-Distill](LiT5-Distill.sh) and [LiT5-Score](LiT5-Score.sh) on DL19 and DL20 for BM25 and SPLADE++ ED first-stage retrieval. Note you may need to change the batchsize depending on your VRAM. We have observed that results may change slightly when the batchsize is changed. [This is a known issue when running inference in bfloat16](https://github.com/huggingface/transformers/issues/25921). 
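To make the window-related parameters in `LiT5-Distill.sh` concrete, here is a small sketch of the reranking schedule implied by the defaults above (`total_n_rerank_passages=100`, `windowsize=20`, `stride=10`, `n_passes=1`). This is only an illustration, not part of the repository: it assumes, as is typical for sliding-window listwise rerankers, that windows are processed from the bottom of the candidate list upward, while the actual traversal is implemented in `FiD/LiT5-Distill.py`.

```
# Illustration only (not part of the repository): the window boundaries implied by the
# default LiT5-Distill.sh settings, assuming a bottom-up sweep over the candidate list.
windowsize=20; stride=10; total_n_rerank_passages=100
for start in $(seq $((total_n_rerank_passages - windowsize)) -$stride 0); do
    echo "rerank candidates $start to $((start + windowsize - 1))"
done
# prints: 80 to 99, 70 to 89, ..., 0 to 19
```

Each pass produces its own run file (`${runfile_path}.${i}.trec`), which the script then scores with pyserini's `trec_eval` at nDCG@10.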
Additionally, you may need to remove the --bfloat16 option from the scripts if your GPU does not support it. 9 | 10 | Note, the v2 LiT5-Distill models support reranking up to 100 passages at once. 11 | 12 | ## Models 13 | 14 | The following is a table of our models hosted on HuggingFace: 15 | 16 | | Model Name | Hugging Face Identifier/Link | 17 | |-----------------------|--------------------------------------------------------------------------------------------| 18 | | LiT5-Distill-base | [castorini/LiT5-Distill-base](https://huggingface.co/castorini/LiT5-Distill-base) | 19 | | LiT5-Distill-large | [castorini/LiT5-Distill-large](https://huggingface.co/castorini/LiT5-Distill-large) | 20 | | LiT5-Distill-xl | [castorini/LiT5-Distill-xl](https://huggingface.co/castorini/LiT5-Distill-xl) | 21 | | LiT5-Distill-base-v2 | [castorini/LiT5-Distill-base-v2](https://huggingface.co/castorini/LiT5-Distill-base-v2) | 22 | | LiT5-Distill-large-v2 | [castorini/LiT5-Distill-large-v2](https://huggingface.co/castorini/LiT5-Distill-large-v2) | 23 | | LiT5-Distill-xl-v2 | [castorini/LiT5-Distill-xl-v2](https://huggingface.co/castorini/LiT5-Distill-xl-v2) | 24 | | LiT5-Score-base | [castorini/LiT5-Score-base](https://huggingface.co/castorini/LiT5-Score-base) | 25 | | LiT5-Score-large | [castorini/LiT5-Score-large](https://huggingface.co/castorini/LiT5-Score-large) | 26 | | LiT5-Score-xl | [castorini/LiT5-Score-xl](https://huggingface.co/castorini/LiT5-Score-xl) | 27 | 28 | 29 | ## Expected Results 30 | 31 | This table shows the expected results for reranking with BM25 first-stage retrieval 32 | 33 | ### DL19 34 | | Model Name | nDCG@10 | 35 | |-----------------------|---------| 36 | | LiT5-Distill-base | 71.7 | 37 | | LiT5-Distill-large | 72.7 | 38 | | LiT5-Distill-xl | 72.3 | 39 | | LiT5-Distill-base-v2 | 71.7 | 40 | | LiT5-Distill-large-v2 | 73.3 | 41 | | LiT5-Distill-xl-v2 | 73.0 | 42 | | LiT5-Score-base | 68.9 | 43 | | LiT5-Score-large | 72.0 | 44 | | LiT5-Score-xl | 70.0 | 45 | 46 | ### DL20 47 | | Model Name | nDCG@10 | 48 | |-----------------------|---------| 49 | | LiT5-Distill-base | 68.0 | 50 | | LiT5-Distill-large | 70.0 | 51 | | LiT5-Distill-xl | 71.8 | 52 | | LiT5-Distill-base-v2 | 66.7 | 53 | | LiT5-Distill-large-v2 | 69.8 | 54 | | LiT5-Distill-xl-v2 | 73.7 | 55 | | LiT5-Score-base | 66.2 | 56 | | LiT5-Score-large | 67.8 | 57 | | LiT5-Score-xl | 65.7 | 58 | 59 | This table shows the expected results for reranking with SPLADE++ ED first-stage retrieval 60 | 61 | ### DL19 62 | | Model Name | nDCG@10 | 63 | |-----------------------|---------| 64 | | LiT5-Distill-base | 74.6 | 65 | | LiT5-Distill-large | 76.8 | 66 | | LiT5-Distill-xl | 76.8 | 67 | | LiT5-Distill-base-v2 | 78.3 | 68 | | LiT5-Distill-large-v2 | 80.0 | 69 | | LiT5-Distill-xl-v2 | 78.5 | 70 | | LiT5-Score-base | 68.4 | 71 | | LiT5-Score-large | 68.7 | 72 | | LiT5-Score-xl | 69.0 | 73 | 74 | ### DL20 75 | | Model Name | nDCG@10 | 76 | |-----------------------|---------| 77 | | LiT5-Distill-base | 74.1 | 78 | | LiT5-Distill-large | 76.5 | 79 | | LiT5-Distill-xl | 76.7 | 80 | | LiT5-Distill-base-v2 | 75.1 | 81 | | LiT5-Distill-large-v2 | 76.6 | 82 | | LiT5-Distill-xl-v2 | 80.4 | 83 | | LiT5-Score-base | 68.5 | 84 | | LiT5-Score-large | 73.1 | 85 | | LiT5-Score-xl | 71.0 | 86 | 87 | ## ✨ References 88 | 89 | If you use LiT5, please cite the following paper: 90 | [[2312.16098] Scaling Down, LiTting Up: Efficient Zero-Shot Listwise Reranking with Seq2seq Encoder-Decoder Models](https://arxiv.org/abs/2312.16098) 91 | 92 | ``` 93 | 
@ARTICLE{tamber2023scaling, 94 | title = {Scaling Down, LiTting Up: Efficient Zero-Shot Listwise Reranking with Seq2seq Encoder-Decoder Models}, 95 | author = {Manveer Singh Tamber and Ronak Pradeep and Jimmy Lin}, 96 | year = {2023}, 97 | journal = {arXiv preprint arXiv:2312.16098} 98 | } 99 | ``` 100 | 101 | ## 🙏 Acknowledgments 102 | 103 | This repository borrows code from the original [FiD repository](https://github.com/facebookresearch/FiD), the [atlas repository](https://github.com/facebookresearch/atlas), and the [RankLLM repository](https://github.com/castorini/rank_llm)! 104 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.3.0 2 | transformers==4.40.2 3 | pyserini 4 | faiss-cpu -------------------------------------------------------------------------------- /runs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/castorini/LiT5/6f6aca2547e59499f1f58f49ce92f88ff2d6cf16/runs/.gitkeep --------------------------------------------------------------------------------
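For completeness, a minimal way to set up an environment and reproduce a run with the files above might look like the following. This is a sketch rather than part of the repository; it assumes a CUDA-capable GPU (the scripts enable `--bfloat16` by default) and a working Python 3 installation.

```
pip install -r requirements.txt   # pinned torch/transformers plus pyserini and faiss-cpu
bash LiT5-Distill.sh              # rerank DL19/DL20 BM25 candidates, write runs/*.trec, report nDCG@10
bash LiT5-Score.sh                # same workflow for the LiT5-Score models
```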