├── .gitignore ├── FiD ├── README.md ├── ac_scheduler.py ├── data.py ├── data_ac.py ├── evaluation.py ├── fidt5.py ├── fidt5_ac.py ├── options.py ├── slurm.py ├── t5blocks.py ├── test.py ├── test_ac.py ├── test_ac_scheduler.py ├── test_retrieval_acc.py ├── test_retriever_baseline.py ├── train.py ├── train_ac.py ├── train_ac_scheduler.py └── util.py ├── README.md ├── requirements.txt └── scripts ├── batch_eval_retrieval_acc.sh ├── batch_eval_scheduler_nq.sh ├── batch_eval_scheduler_trivia.sh ├── download_data.sh ├── download_models.sh ├── eval_ac_nq_single.sh ├── eval_nq_batch.sh ├── eval_nq_single.sh ├── eval_retrieval_acc.sh ├── eval_scheduler_nq.sh ├── eval_scheduler_trivia.sh ├── eval_trivia.sh ├── eval_trivia_batch.sh ├── init.sh ├── train_ac_nq.sh ├── train_ac_scheduler_nq.sh ├── train_ac_scheduler_trivia.sh ├── train_ac_trivia.sh └── train_nq.sh /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /FiD/README.md: -------------------------------------------------------------------------------- 1 | # Fusion-in-Decoder 2 | 3 | ## Dependencies 4 | 5 | - Python 3 6 | - [NumPy](http://www.numpy.org/) 7 | - [PyTorch](http://pytorch.org/) (currently tested on version 1.6.0) 8 | - [Transformers](http://huggingface.co/transformers/) (version 3.0.2, unlikely to work with a different version) 9 | 10 | ### Download data 11 | 12 | ### Train 13 | 14 | [`train.py`](train.py) provides the code for training a model from scratch. An example usage of the script with some options is given below: 15 | 16 | ```shell 17 | python train.py \ 18 | --use_checkpointing \ 19 | --train_data_path $tp \ 20 | --dev_data_path $dp \ 21 | --model_size base \ 22 | --per_gpu_batch_size 4 \ 23 | --n_context 10 \ 24 | --name my_experiment \ 25 | --checkpoint_dir checkpoint \ 26 | --eval_freq 500 27 | ``` 28 | 29 | ### Test 30 | 31 | [`test.py`](test.py) provides the script to evaluate the performance of the model. An example usage of the script is provided below. 32 | 33 | ```shell 34 | python test.py \ 35 | --model_path my_model_path \ 36 | --test_data_path my_test_data.json \ 37 | --model_size base \ 38 | --per_gpu_batch_size 4 \ 39 | --n_context 10 \ 40 | --name my_test \ 41 | --checkpoint_dir checkpoint 42 | ``` 43 | 44 | ### Data format 45 | 46 | The expected data format is a list of entry examples, where each entry example is a dictionary containing 47 | - `id`: example id, optional 48 | - `question`: question text 49 | - `target`: answer used for model training, if not given, the target is randomly sampled from the 'answers' list 50 | - `answers`: list of answer text for evaluation, also used for training if target is not given 51 | - `ctxs`: a list of passages where each item is a dictionary containing 52 | - `title`: article title 53 | - `text`: passage text 54 | 55 | Entry example: 56 | ``` 57 | { 58 | 'id': '0', 59 | 'question': 'What element did Marie Curie name after her native land?', 60 | 'target': 'Polonium', 61 | 'answers': ['Polonium', 'Po (chemical element)', 'Po'], 62 | 'ctxs': [ 63 | { 64 | "title": "Marie Curie", 65 | "text": "them on visits to Poland. She named the first chemical element that she discovered in 1898 \"polonium\", after her native country. 
Marie Curie died in 1934, aged 66, at a sanatorium in Sancellemoz (Haute-Savoie), France, of aplastic anemia from exposure to radiation in the course of her scientific research and in the course of her radiological work at field hospitals during World War I. Maria Sk\u0142odowska was born in Warsaw, in Congress Poland in the Russian Empire, on 7 November 1867, the fifth and youngest child of well-known teachers Bronis\u0142awa, \"n\u00e9e\" Boguska, and W\u0142adys\u0142aw Sk\u0142odowski. The elder siblings of Maria" 66 | }, 67 | { 68 | "title": "Marie Curie", 69 | "text": "was present in such minute quantities that they would eventually have to process tons of the ore. In July 1898, Curie and her husband published a joint paper announcing the existence of an element which they named \"polonium\", in honour of her native Poland, which would for another twenty years remain partitioned among three empires (Russian, Austrian, and Prussian). On 26 December 1898, the Curies announced the existence of a second element, which they named \"radium\", from the Latin word for \"ray\". In the course of their research, they also coined the word \"radioactivity\". To prove their discoveries beyond any" 70 | } 71 | ] 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /FiD/ac_scheduler.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | 7 | from torch.distributions.utils import probs_to_logits 8 | from torch.distributions import Categorical 9 | 10 | Tower = namedtuple("Tower", "has_answer_logit, layer, rank") 11 | 12 | LARGE_POS = 1000 13 | LARGE_NEG = -1e5 14 | 15 | 16 | class BaseScheduler(nn.Module): 17 | def __init__(self, config): 18 | super().__init__() 19 | self.config = config 20 | self.init_priorities = nn.Parameter(self._get_init_priorities()) 21 | 22 | def _get_init_priorities(self) -> torch.Tensor: 23 | """Initialize the initial priorities for all passages.""" 24 | # # Uniform 25 | # self.init_priorities = nn.Parameter(torch.Tensor(self.config.scheduler_n_context)) 26 | # bound = 1 / math.sqrt(self.config.scheduler_n_context) 27 | # self.init_priorities.data.uniform_(-bound, bound) 28 | # 29 | # # Zeros 30 | # self.init_priorities = nn.Parameter(torch.Tensor(self.config.scheduler_n_context)) 31 | # self.init_priorities.data.fill_(0.) 32 | 33 | # Heuristic 34 | init_probs = torch.tensor([0.5 / (i + 1) for i in range(self.config.scheduler_n_context)]) 35 | return probs_to_logits(init_probs, is_binary=True) # shape: [n_passages] 36 | 37 | def forward(self, has_answer_logits, layers, ranks): 38 | """ 39 | Args: 40 | has_answer_logits (Tensor): float Tensor with shape [bsz, n_passages] 41 | layers (Tensor): int Tensor with shape [bsz, n_passages] 42 | ranks (Tensor): int Tensor with shape [bsz, n_passages] 43 | 44 | Returns: 45 | priorities (Tensor): float Tensor with shape [bsz, n_passages] 46 | """ 47 | raise NotImplementedError 48 | 49 | def act(self, all_has_answer_logits, layer_indices, masks, greedy=False, **kwargs): 50 | """ Take an action given the current status of the skyline. 
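The "skyline" records, for each passage ("tower"), the index of the deepest encoder layer computed so far; an entry of -1 means the passage has not been processed yet, in which case the learned initial priority is used in place of a has_answer logit. Each call selects one passage per batch element to advance by one more layer.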
51 | 52 | Args: 53 | all_has_answer_logits (torch.Tensor): float Tensor with shape [bsz, n_passages, num_layers] 54 | layer_indices (torch.Tensor): int Tensor with shape [bsz, n_passages] 55 | masks (torch.Tensor): float Tensor with shape [bsz, n_passages] 56 | greedy (bool): True if act greedily 57 | **kwargs: 58 | 59 | Returns: 60 | action (torch.Tensor): int Tensor with shape [bsz] that indicates the actions chosen 61 | log_probs (torch.Tensor): float Tensor with shape [bsz] that indicates the log-prob of the actions 62 | """ 63 | bsz, n_passages, num_layers = all_has_answer_logits.shape 64 | device = all_has_answer_logits.device 65 | 66 | init_priors = self.init_priorities[:n_passages].unsqueeze(0).expand(bsz, -1) # [bsz, n_passages] 67 | all_has_answer_logits_with_init = torch.cat((init_priors.unsqueeze(-1), all_has_answer_logits), -1) 68 | # shape: [bsz, n_passages, num_layers + 1] 69 | 70 | layer_tensor = layer_indices + 1 # shape: [bsz, n_passages], range=[0, num_layers] 71 | rank_tensor = torch.arange(n_passages, device=device).unsqueeze(0).expand(bsz, -1) # shape: [bsz, n_passages] 72 | 73 | # Collect the has_answer logits for each tower (including the initial layers), shape: [bsz, n_passages] 74 | has_answer_logits = all_has_answer_logits_with_init.gather(2, layer_tensor.unsqueeze(2)).squeeze(2) 75 | # has_answer_logits = torch.where(layer_indices < 0, init_priors, has_answer_logits) 76 | 77 | priorities = self.forward(has_answer_logits, layer_tensor, rank_tensor) # shape: [bsz, n_passages] 78 | 79 | # Apply the mask to avoid choosing the maximum towers again 80 | priorities = priorities + (1. - masks) * LARGE_NEG 81 | 82 | if greedy: # select the max priority during evaluation 83 | action = priorities.argmax(-1) # shape: [bsz] 84 | log_prob = -torch.ones(bsz, device=device, requires_grad=False) 85 | else: 86 | m = Categorical(logits=priorities) 87 | action = m.sample() # shape: [bsz] 88 | log_prob = m.log_prob(action) # shape: [bsz] 89 | 90 | return action, log_prob 91 | 92 | 93 | class TopScheduler(BaseScheduler): 94 | def forward(self, has_answer_logits, layers, ranks): 95 | priorities = has_answer_logits * 0. 
- ranks.float() # shape: [bsz, n_passages] 96 | return priorities 97 | 98 | 99 | class DummyScheduler(BaseScheduler): 100 | """Simple scheduler that is solely based on has_answer_prob.""" 101 | 102 | def __init__(self, config): 103 | super().__init__(config) 104 | self.weight = nn.Parameter(torch.tensor(1.0)) 105 | 106 | def forward(self, has_answer_logits, layers, ranks): 107 | priorities = self.weight * has_answer_logits # shape: [bsz, n_passages] 108 | return priorities 109 | 110 | 111 | class SimpleScheduler(BaseScheduler): 112 | """Scheduler that exploits has_answer_prob, rank and layer as input.""" 113 | 114 | def __init__(self, config): 115 | super().__init__(config) 116 | self.weight = nn.Parameter(torch.tensor(1.0)) 117 | 118 | self.layer_embeddings = nn.Embedding(config.num_layers + 1, 1) 119 | self.rank_embeddings = nn.Embedding(config.scheduler_n_context, 1) 120 | 121 | def forward(self, has_answer_logits, layers, ranks): 122 | # Compute the offsets 123 | layer_emb = self.layer_embeddings(layers) # shape: [bsz, n_passages, 1] 124 | rank_emb = self.rank_embeddings(ranks) # shape: [bsz, n_passages, 1] 125 | offsets = torch.squeeze(layer_emb + rank_emb, -1) # shape: [bsz, n_passages] 126 | 127 | priorities = self.weight * has_answer_logits + offsets # shape: [bsz, n_passages] 128 | return priorities 129 | 130 | 131 | class MLPScheduler(BaseScheduler): 132 | """Scheduler that uses MLP to integrate has_answer_prob, rank and layer as input.""" 133 | 134 | def __init__(self, config): 135 | super().__init__(config) 136 | self.weight = nn.Parameter(torch.tensor(1.0)) 137 | 138 | embed_size = config.scheduler_embed_size 139 | self.layer_embeddings = nn.Embedding(config.num_layers + 1, embed_size) 140 | self.rank_embeddings = nn.Embedding(config.scheduler_n_context, embed_size) 141 | 142 | # MLP 143 | hidden_size = config.scheduler_hidden_size 144 | self.dense0 = nn.Linear(embed_size * 2 + 1, hidden_size) 145 | self.act_fn = F.relu 146 | self.dense1 = nn.Linear(hidden_size, 1, bias=False) # a single offset logit per passage (the gated variants output 2: offset + gate) 147 | 148 | def forward(self, has_answer_logits, layers, ranks): 149 | """ 150 | Args: 151 | has_answer_logits (Tensor): float Tensor with shape [bsz, n_passages] 152 | layers (Tensor): int Tensor with shape [bsz, n_passages] 153 | ranks (Tensor): int Tensor with shape [bsz, n_passages] 154 | 155 | Returns: 156 | priorities (Tensor): float Tensor with shape [bsz, n_passages] 157 | """ 158 | 159 | layer_emb = self.layer_embeddings(layers) # shape: [bsz, n_passages, embed_size] 160 | rank_emb = self.rank_embeddings(ranks) # shape: [bsz, n_passages, embed_size] 161 | 162 | mlp_input = torch.cat( 163 | (has_answer_logits.unsqueeze(-1), layer_emb, rank_emb), -1 164 | ) # shape: [bsz, n_passages, embed_size * 2 + 1] 165 | mlp_output = self.dense1(self.act_fn(self.dense0(mlp_input))) # shape: [bsz, n_passages, 1] 166 | 167 | offset_logit = torch.squeeze(mlp_output, -1) # shape: [bsz, n_passages] 168 | priorities = self.weight * has_answer_logits + offset_logit # shape: [bsz, n_passages] 169 | 170 | return priorities # shape: [bsz, n_passages] 171 | 172 | 173 | class GatedMLPScheduler(BaseScheduler): 174 | """Scheduler that uses MLP to integrate has_answer_prob (gated), rank and layer.""" 175 | 176 | def __init__(self, config): 177 | super().__init__(config) 178 | self.weight = nn.Parameter(torch.tensor(1.0)) 179 | 180 | embed_size = config.scheduler_embed_size 181 | self.layer_embeddings = nn.Embedding(config.num_layers + 1, embed_size) 182 | self.rank_embeddings = nn.Embedding(config.scheduler_n_context, 
embed_size) 183 | 184 | # MLP 185 | hidden_size = config.scheduler_hidden_size 186 | self.dense0 = nn.Linear(embed_size * 2 + 1, hidden_size) 187 | self.act_fn = F.relu 188 | self.dense1 = nn.Linear(hidden_size, 2, bias=False) 189 | 190 | def forward(self, has_answer_logits, layers, ranks): 191 | """ 192 | Args: 193 | has_answer_logits (Tensor): float Tensor with shape [bsz, n_passages] 194 | layers (Tensor): int Tensor with shape [bsz, n_passages] 195 | ranks (Tensor): int Tensor with shape [bsz, n_passages] 196 | 197 | Returns: 198 | priorities (Tensor): float Tensor with shape [bsz, n_passages] 199 | """ 200 | 201 | layer_emb = self.layer_embeddings(layers) # shape: [bsz, n_passages, embed_size] 202 | rank_emb = self.rank_embeddings(ranks) # shape: [bsz, n_passages, embed_size] 203 | 204 | mlp_input = torch.cat( 205 | (has_answer_logits.unsqueeze(-1), layer_emb, rank_emb), -1 206 | ) # shape: [bsz, n_passages, embed_size * 2 + 1] 207 | mlp_output = self.dense1(self.act_fn(self.dense0(mlp_input))) # shape: [bsz, n_passages, 2] 208 | 209 | offset_logit, gate_logit = torch.unbind(mlp_output, dim=-1) 210 | gate = torch.sigmoid(gate_logit) # shape: [bsz, n_passages] 211 | priorities = self.weight * gate * has_answer_logits + offset_logit # shape: [bsz, n_passages] 212 | 213 | return priorities # shape: [bsz, n_passages] 214 | 215 | 216 | class GatedMLPSchedulerWithPosition(GatedMLPScheduler): 217 | """Scheduler that uses MLP to integrate has_answer_prob (gated), rank and layer, plus normalized layer/rank position features.""" 218 | 219 | def __init__(self, config): 220 | super().__init__(config) 221 | self.weight = nn.Parameter(torch.tensor(1.0)) 222 | 223 | embed_size = config.scheduler_embed_size 224 | self.layer_embeddings = nn.Embedding(config.num_layers + 1, embed_size) 225 | self.rank_embeddings = nn.Embedding(config.scheduler_n_context, embed_size) 226 | 227 | # MLP 228 | hidden_size = config.scheduler_hidden_size 229 | self.dense0 = nn.Linear(embed_size * 2 + 4, hidden_size) 230 | self.act_fn = F.relu 231 | self.dense1 = nn.Linear(hidden_size, 2, bias=False) 232 | 233 | def forward(self, has_answer_logits, layers, ranks): 234 | """ 235 | Args: 236 | has_answer_logits (Tensor): float Tensor with shape [bsz, n_passages] 237 | layers (Tensor): int Tensor with shape [bsz, n_passages] 238 | ranks (Tensor): int Tensor with shape [bsz, n_passages] 239 | 240 | Returns: 241 | priorities (Tensor): float Tensor with shape [bsz, n_passages] 242 | """ 243 | 244 | layer_emb = self.layer_embeddings(layers) # shape: [bsz, n_passages, embed_size] 245 | rank_emb = self.rank_embeddings(ranks) # shape: [bsz, n_passages, embed_size] 246 | 247 | # Additional features 248 | layers_feat = (layers.float() / self.config.num_layers).unsqueeze(-1) # shape: [bsz, n_passages, 1] 249 | ranks_feat = (ranks.float() / self.config.scheduler_n_context).unsqueeze(-1) # shape: [bsz, n_passages, 1] 250 | 251 | mlp_input = torch.cat( 252 | (has_answer_logits.unsqueeze(-1), layers_feat, ranks_feat, layers_feat + ranks_feat, 253 | layer_emb, rank_emb), -1 254 | ) # shape: [bsz, n_passages, embed_size * 2 + 4] 255 | mlp_output = self.dense1(self.act_fn(self.dense0(mlp_input))) # shape: [bsz, n_passages, 2] 256 | 257 | offset_logit, gate_logit = torch.unbind(mlp_output, dim=-1) 258 | gate = torch.sigmoid(gate_logit) # shape: [bsz, n_passages] 259 | priorities = self.weight * gate * has_answer_logits + offset_logit # shape: [bsz, n_passages] 260 | 261 | return priorities # shape: [bsz, n_passages] 262 | 263 | 264 | SchedulerMapping = { 265 | "top": TopScheduler, 266 | 
"dummy": DummyScheduler, 267 | "simple": SimpleScheduler, 268 | "mlp": MLPScheduler, 269 | "gated_mlp": GatedMLPScheduler, 270 | "gated_mlp_pos": GatedMLPSchedulerWithPosition, 271 | } 272 | 273 | 274 | def get_scheduler(config): 275 | """Construct a scheduler from the config (default: None)""" 276 | if hasattr(config, "scheduler_type"): 277 | try: 278 | scheduler = SchedulerMapping[config.scheduler_type](config) 279 | except KeyError: 280 | raise KeyError(f"Invalid scheduler_type: {config.scheduler_type}") 281 | else: 282 | scheduler = None 283 | return scheduler 284 | 285 | 286 | def run_ac_scheduler( 287 | hidden_states, 288 | attention_mask, 289 | has_answer_outputs, 290 | ac_scheduler: BaseScheduler, 291 | budget: int, 292 | num_passages_retained: int, 293 | is_training: bool = True, 294 | ): 295 | """ 296 | 297 | Args: 298 | hidden_states (torch.Tensor): float Tensor with shape [bsz (B), n_passages (N), plen (L), d_model (D)] 299 | attention_mask (torch.Tensor): float Tensor with shape [B, N, L] 300 | has_answer_outputs (torch.Tensor): float Tensor with shape [B, N, num_layers] 301 | ac_scheduler (BaseScheduler): 302 | budget (int): 303 | num_passages_retained (int): 304 | is_training (bool): 305 | 306 | Returns: 307 | 308 | """ 309 | bsz, n_passages, plen, _ = hidden_states.shape 310 | num_layers = has_answer_outputs.shape[2] 311 | if budget > num_layers * n_passages: 312 | raise ValueError(f"budget={budget} should be small than num_layers * n_passages={num_layers * n_passages}") 313 | device = hidden_states.device 314 | 315 | # Run the AC prioritization algorithm 316 | all_actions, all_log_probs = [], [] 317 | skyline = -torch.ones((bsz, n_passages), dtype=torch.long, device=device) # -1 indicates the initial state 318 | tower_masks = torch.ones((bsz, n_passages), dtype=torch.float, device=device) # 1.->active, 0.->inactive 319 | for step in range(budget): 320 | actions, action_log_probs = ac_scheduler.act( 321 | has_answer_outputs, 322 | skyline, 323 | masks=tower_masks, 324 | greedy=not is_training 325 | ) # shape: [bsz], [bsz] 326 | all_actions.append(actions) 327 | all_log_probs.append(action_log_probs) 328 | 329 | # Update the selected towers in the skyline 330 | for i, action in enumerate(actions): 331 | new_layer = skyline[i, action].item() + 1 # increment the layer 332 | if new_layer < num_layers - 1: 333 | skyline[i, action] = new_layer 334 | elif new_layer == num_layers - 1: # reaches the last layer 335 | skyline[i, action] = new_layer 336 | tower_masks[i, action] = 0. # mask the tower to avoid choosing it again. 
337 | else: 338 | raise ValueError("Selected the tower that is at maximum height.") 339 | 340 | actions = torch.stack(all_actions, 1) # shape: [bsz, budget] 341 | log_probs = torch.stack(all_log_probs, 1) # shape: [bsz, budget] 342 | 343 | # Find the highest towers (passages), shape: [bsz, num_passages_retained] 344 | retained_passages = skyline.argsort(dim=1, descending=True)[:, :num_passages_retained] 345 | 346 | # Update the skyline: forward the highest towers to their last layer 347 | skyline.scatter_(1, retained_passages, num_layers - 1) # shape: [bsz, n_passages] 348 | 349 | # Acquire the hidden_states and attention_masks for the retained passages 350 | retained_hidden_states, retained_attention_masks = [], [] 351 | for bi in range(bsz): 352 | # TODO (jimmycode): find a more efficient implementation for this indexing operation 353 | cur_retained_hidden_states = torch.cat( 354 | [hidden_states[bi, retained_passages[bi, pj]] for pj in range(num_passages_retained)], 0 355 | ) # shape: [num_passages_retained * p_len, d_model] 356 | retained_hidden_states.append(cur_retained_hidden_states) 357 | 358 | cur_retained_attention_masks = torch.cat( 359 | [attention_mask[bi, retained_passages[bi, pj]] for pj in range(num_passages_retained)], 0 360 | ) # shape: [num_passages_retained * p_len] 361 | retained_attention_masks.append(cur_retained_attention_masks) 362 | hidden_states = torch.stack(retained_hidden_states, 0) # shape: [bsz, num_passages_retained * p_len, d_model] 363 | attention_mask = torch.stack(retained_attention_masks, 0) # shape: [bsz, num_passages_retained * p_len] 364 | 365 | return hidden_states, attention_mask, (actions, log_probs, skyline, retained_passages) 366 | 367 | 368 | def compute_REINFORCE_loss(has_answer_labels, actions, log_probs, step_cost, discount=1.0): 369 | """ Compute the REINFORCE loss: 1) evaluate rewards and returns 2) compute the loss """ 370 | action_labels = torch.gather(has_answer_labels, 1, actions) # shape: [bsz, budget] 371 | immediate_rewards = action_labels - step_cost # shape: [bsz, budget] 372 | 373 | # Calculate return values 374 | # return_values = immediate_rewards.flip(1).cumsum(1).flip(1) # shape: [bsz, budget] 375 | all_imd_rewards = immediate_rewards.flip(1).unbind(1) 376 | all_returns, acc = [], None 377 | for ir in all_imd_rewards: 378 | cur_return = ir if acc is None else ir + acc * discount # shape: [bsz] 379 | all_returns.append(cur_return) 380 | acc = cur_return 381 | return_values = torch.stack(all_returns, 1).flip(1) # shape: [bsz, budget] 382 | 383 | loss = -torch.sum(log_probs * return_values) # add negative to maximize 384 | sum_reward = torch.mean(torch.sum(immediate_rewards, 1)) # the sum of all immediate rewards (for logging) 385 | 386 | return loss, sum_reward 387 | -------------------------------------------------------------------------------- /FiD/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import json 4 | 5 | 6 | class QAExample(): 7 | def __init__(self, id, question, answers, target=None, titles=None, contexts=None): 8 | self.id = id 9 | self.question = question 10 | self.answers = answers 11 | self.target = target 12 | self.titles = titles 13 | self.contexts = contexts 14 | 15 | 16 | class Dataset(torch.utils.data.Dataset): 17 | def __init__(self, data, n_context, tokenizer, max_passage_length=250, no_title=False): 18 | self.data = data 19 | self.n_context = n_context 20 | self.tokenizer = tokenizer 21 | self.max_passage_length = 
max_passage_length 22 | self.no_title = no_title 23 | self.question_prefix = 'question:' 24 | self.title_prefix = 'title:' 25 | self.context_prefix = 'context:' 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | def __getitem__(self, index): 31 | example = self.data[index] 32 | question = example.question 33 | if example.target is None: 34 | target = random.choice(example.answers) 35 | else: 36 | target = example.target 37 | 38 | titles = example.titles[:self.n_context] 39 | contexts = example.contexts[:self.n_context] 40 | 41 | passages = [] 42 | if len(contexts) == 0: 43 | to_concatenate = [self.question_prefix, question] 44 | text = ' '.join(to_concatenate) 45 | passages.append(text) 46 | for i in range(min(self.n_context, len(contexts))): 47 | c = contexts[i] 48 | t = titles[i] 49 | to_concatenate = [self.question_prefix, question] 50 | if c is not None: 51 | if not self.no_title: 52 | to_concatenate += [self.title_prefix, t] 53 | to_concatenate += [self.context_prefix, c] 54 | text = ' '.join(to_concatenate) 55 | passages.append(text) 56 | 57 | return {'index': index, 'question': question, 'target': target, 'passages': passages} 58 | 59 | def get_example(self, index): 60 | return self.data[index] 61 | 62 | 63 | class Collator(object): 64 | def __init__(self, opt, tokenizer): 65 | self.tokenizer = tokenizer 66 | self.max_passage_length = opt.max_passage_length 67 | self.model_type = opt.model_type 68 | 69 | def __call__(self, batch): 70 | index = torch.tensor([ex['index'] for ex in batch]) 71 | question = [ex['question'] for ex in batch] 72 | if self.model_type == 'bart': 73 | target = [ex['target'] for ex in batch] 74 | else: 75 | target = [ex['target'] + ' ' for ex in batch] 76 | target = self.tokenizer.batch_encode_plus(target, pad_to_max_length=True, return_tensors="pt") 77 | target_ids, target_mask = target["input_ids"], target["attention_mask"] 78 | 79 | batch_text_passages = [ex['passages'] for ex in batch] 80 | batch_encoded_passages = [] 81 | 82 | max_context_length = 0 83 | for k, text_passages in enumerate(batch_text_passages): 84 | encoded_passages = [] 85 | for text_p in text_passages: 86 | encoded_p = self.tokenizer.encode(text_p) 87 | if len(encoded_p) > self.max_passage_length: 88 | encoded_p = encoded_p[:self.max_passage_length] 89 | max_context_length = max(max_context_length, len(encoded_p)) 90 | encoded_passages.append(encoded_p) 91 | batch_encoded_passages.append(encoded_passages) 92 | max_context_length = min(max_context_length, self.max_passage_length) 93 | 94 | batch_passage_ids, batch_passage_masks = [], [] 95 | for k, encoded_passages in enumerate(batch_encoded_passages): 96 | p_ids, p_masks = [], [] 97 | for p in encoded_passages: 98 | plen = len(p) 99 | c = torch.cat((torch.tensor(p), torch.zeros(max_context_length - plen).long()), dim=0) # shape: [L] 100 | p_ids.append(c) 101 | m = torch.cat((torch.ones(plen).bool(), torch.zeros(max_context_length - plen).bool()), 102 | dim=0) # shape: [L] 103 | p_masks.append(m) 104 | p_ids = torch.stack(p_ids, dim=0) # shape: [N, L], N is the number of passages 105 | p_masks = torch.stack(p_masks, dim=0) # shape: [N, L] 106 | batch_passage_ids.append(p_ids) 107 | batch_passage_masks.append(p_masks) 108 | 109 | batch_passage_ids = torch.stack(batch_passage_ids, dim=0) # shape: [B, N, L], B is the batch size 110 | batch_passage_masks = torch.stack(batch_passage_masks, dim=0) # shape: [B, N, L] 111 | 112 | return index, target_ids, target_mask, batch_passage_ids, batch_passage_masks 113 | 114 | 115 | def 
load_data(data_path, global_rank=-1, world_size=-1, n_context=None): 116 | with open(data_path, "r") as f: 117 | data = json.load(f) 118 | 119 | examples = [] 120 | for k, example in enumerate(data): 121 | if global_rank > -1 and not k % world_size == global_rank: 122 | continue 123 | if 'id' in example: 124 | id = example['id'] 125 | else: 126 | id = k 127 | if 'target' in example: 128 | target = example['target'] 129 | else: 130 | target = None 131 | answers = example['answers'] 132 | question = example['question'] 133 | titles, contexts = [], [] 134 | if 'ctxs' in example: 135 | ctxs = example['ctxs'] 136 | if n_context is not None: 137 | ctxs = ctxs[:n_context] 138 | for i, c in enumerate(ctxs): 139 | titles.append(c['title']) 140 | contexts.append(c['text']) 141 | ex = QAExample(id=id, question=question, answers=answers, target=target, titles=titles, contexts=contexts) 142 | examples.append(ex) 143 | 144 | del data 145 | return examples 146 | -------------------------------------------------------------------------------- /FiD/data_ac.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import json 4 | 5 | 6 | class QAExample(): 7 | def __init__(self, id, question, answers, target=None, titles=None, contexts=None): 8 | self.id = id 9 | self.question = question 10 | self.answers = answers 11 | self.target = target 12 | self.titles = titles 13 | self.contexts = contexts 14 | 15 | 16 | class Dataset(torch.utils.data.Dataset): 17 | def __init__(self, data, n_context, tokenizer, max_passage_length=250, no_title=False): 18 | self.data = data 19 | self.n_context = n_context 20 | self.tokenizer = tokenizer 21 | self.max_passage_length = max_passage_length 22 | self.no_title = no_title 23 | self.question_prefix = 'question:' 24 | self.title_prefix = 'title:' 25 | self.context_prefix = 'context:' 26 | 27 | def __len__(self): 28 | return len(self.data) 29 | 30 | def __getitem__(self, index): 31 | example = self.data[index] 32 | question = example.question 33 | if example.target is None: 34 | target = random.choice(example.answers) 35 | else: 36 | target = example.target 37 | 38 | titles = example.titles[:self.n_context] 39 | contexts = example.contexts[:self.n_context] 40 | 41 | passages = [] 42 | if len(contexts) == 0: 43 | to_concatenate = [self.question_prefix, question] 44 | text = ' '.join(to_concatenate) 45 | passages.append(text) 46 | for i in range(min(self.n_context, len(contexts))): 47 | c = contexts[i] 48 | t = titles[i] 49 | to_concatenate = [self.question_prefix, question] 50 | if c is not None: 51 | if not self.no_title: 52 | to_concatenate += [self.title_prefix, t] 53 | to_concatenate += [self.context_prefix, c] 54 | text = ' '.join(to_concatenate) 55 | passages.append(text) 56 | 57 | return {'index': index, 'question': question, 'target': target, 'passages': passages} 58 | 59 | def get_example(self, index): 60 | return self.data[index] 61 | 62 | 63 | class Collator(object): 64 | def __init__(self, opt, tokenizer): 65 | self.tokenizer = tokenizer 66 | self.max_passage_length = opt.max_passage_length 67 | self.model_type = opt.model_type 68 | 69 | def __call__(self, batch): 70 | index = torch.tensor([ex['index'] for ex in batch]) 71 | question = [ex['question'] for ex in batch] 72 | if self.model_type == 'bart': 73 | target = [ex['target'] for ex in batch] 74 | else: 75 | target = [ex['target'] + ' ' for ex in batch] 76 | target = self.tokenizer.batch_encode_plus(target, pad_to_max_length=True, 
return_tensors="pt") 77 | target_ids, target_mask = target["input_ids"], target["attention_mask"] 78 | 79 | batch_text_passages = [ex['passages'] for ex in batch] 80 | batch_encoded_passages = [] 81 | 82 | # Encode the passages 83 | max_context_length = 0 84 | for k, text_passages in enumerate(batch_text_passages): 85 | encoded_passages = [] 86 | for text_p in text_passages: 87 | encoded_p = self.tokenizer.encode(text_p) 88 | if len(encoded_p) > self.max_passage_length: 89 | encoded_p = encoded_p[:self.max_passage_length] 90 | max_context_length = max(max_context_length, len(encoded_p)) 91 | encoded_passages.append(encoded_p) 92 | batch_encoded_passages.append(encoded_passages) 93 | max_context_length = min(max_context_length, self.max_passage_length) 94 | 95 | # Pad the passages to maximum length 96 | batch_passage_ids, batch_passage_masks = [], [] 97 | for k, encoded_passages in enumerate(batch_encoded_passages): 98 | p_ids, p_masks = [], [] 99 | for p in encoded_passages: 100 | plen = len(p) 101 | c = torch.cat((torch.tensor(p), torch.zeros(max_context_length - plen).long()), dim=0) # shape: [L] 102 | p_ids.append(c) 103 | m = torch.cat((torch.ones(plen).bool(), torch.zeros(max_context_length - plen).bool()), 104 | dim=0) # shape: [L] 105 | p_masks.append(m) 106 | p_ids = torch.stack(p_ids, dim=0) # shape: [N, L], N is the number of passages 107 | p_masks = torch.stack(p_masks, dim=0) # shape: [N, L] 108 | batch_passage_ids.append(p_ids) 109 | batch_passage_masks.append(p_masks) 110 | 111 | batch_passage_ids = torch.stack(batch_passage_ids, dim=0) # shape: [B, N, L], B is the batch size 112 | batch_passage_masks = torch.stack(batch_passage_masks, dim=0) # shape: [B, N, L] 113 | 114 | # Get the has_answer labels for training the adaptive computation mechanisms 115 | batch_answers = [ex['target'] for ex in batch] 116 | batch_has_answer_labels = [] 117 | for text_passages, answer in zip(batch_text_passages, batch_answers): 118 | has_answer_labels = [] 119 | for text_p in text_passages: 120 | has_answer_labels.append(1. if answer in text_p else 0.) 
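# Note: this label is a raw, case-sensitive substring check of the (possibly sampled) target string against the formatted passage text (question/title/context prefixes included); unlike evaluation.has_answer, no normalization or tokenization is applied here.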
121 | batch_has_answer_labels.append(has_answer_labels) 122 | batch_has_answer_labels = torch.tensor(batch_has_answer_labels) # shape: [B, N] 123 | 124 | return index, target_ids, target_mask, batch_passage_ids, batch_passage_masks, batch_has_answer_labels 125 | 126 | 127 | def load_data(data_path, global_rank=-1, world_size=-1, n_context=None): 128 | with open(data_path, "r") as f: 129 | data = json.load(f) 130 | 131 | examples = [] 132 | for k, example in enumerate(data): 133 | if global_rank > -1 and not k % world_size == global_rank: 134 | continue 135 | if 'id' in example: 136 | id = example['id'] 137 | else: 138 | id = k 139 | if 'target' in example: 140 | target = example['target'] 141 | else: 142 | target = None 143 | answers = example['answers'] 144 | question = example['question'] 145 | titles, contexts = [], [] 146 | if 'ctxs' in example: 147 | ctxs = example['ctxs'] 148 | if n_context is not None: 149 | ctxs = ctxs[:n_context] 150 | for i, c in enumerate(ctxs): 151 | titles.append(c['title']) 152 | contexts.append(c['text']) 153 | ex = QAExample(id=id, question=question, answers=answers, target=target, titles=titles, contexts=contexts) 154 | examples.append(ex) 155 | 156 | del data 157 | return examples 158 | -------------------------------------------------------------------------------- /FiD/evaluation.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import unicodedata 4 | import regex 5 | import copy 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def normalize_answer(s): 12 | def remove_articles(text): 13 | return re.sub(r'\b(a|an|the)\b', ' ', text) 14 | 15 | def white_space_fix(text): 16 | return ' '.join(text.split()) 17 | 18 | def remove_punc(text): 19 | exclude = set(string.punctuation) 20 | return ''.join(ch for ch in text if ch not in exclude) 21 | 22 | def lower(text): 23 | return text.lower() 24 | 25 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 26 | 27 | 28 | def exact_match_score(prediction, ground_truth): 29 | return normalize_answer(prediction) == normalize_answer(ground_truth) 30 | 31 | 32 | def ems(prediction, ground_truths): 33 | return max([exact_match_score(prediction, gt) for gt in ground_truths]) 34 | 35 | 36 | class Tokens(object): 37 | """A class to represent a list of tokenized text.""" 38 | TEXT = 0 39 | TEXT_WS = 1 40 | SPAN = 2 41 | POS = 3 42 | LEMMA = 4 43 | NER = 5 44 | 45 | def __init__(self, data, annotators, opts=None): 46 | self.data = data 47 | self.annotators = annotators 48 | self.opts = opts or {} 49 | 50 | def __len__(self): 51 | """The number of tokens.""" 52 | return len(self.data) 53 | 54 | def slice(self, i=None, j=None): 55 | """Return a view of the list of tokens from [i, j).""" 56 | new_tokens = copy.copy(self) 57 | new_tokens.data = self.data[i: j] 58 | return new_tokens 59 | 60 | def untokenize(self): 61 | """Returns the original text (with whitespace reinserted).""" 62 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 63 | 64 | def words(self, uncased=False): 65 | """Returns a list of the text of each token 66 | Args: 67 | uncased: lower cases text 68 | """ 69 | if uncased: 70 | return [t[self.TEXT].lower() for t in self.data] 71 | else: 72 | return [t[self.TEXT] for t in self.data] 73 | 74 | def offsets(self): 75 | """Returns a list of [start, end) character offsets of each token.""" 76 | return [t[self.SPAN] for t in self.data] 77 | 78 | def pos(self): 79 | """Returns a list of part-of-speech 
tags of each token. 80 | Returns None if this annotation was not included. 81 | """ 82 | if 'pos' not in self.annotators: 83 | return None 84 | return [t[self.POS] for t in self.data] 85 | 86 | def lemmas(self): 87 | """Returns a list of the lemmatized text of each token. 88 | Returns None if this annotation was not included. 89 | """ 90 | if 'lemma' not in self.annotators: 91 | return None 92 | return [t[self.LEMMA] for t in self.data] 93 | 94 | def entities(self): 95 | """Returns a list of named-entity-recognition tags of each token. 96 | Returns None if this annotation was not included. 97 | """ 98 | if 'ner' not in self.annotators: 99 | return None 100 | return [t[self.NER] for t in self.data] 101 | 102 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 103 | """Returns a list of all ngrams from length 1 to n. 104 | Args: 105 | n: upper limit of ngram length 106 | uncased: lower cases text 107 | filter_fn: user function that takes in an ngram list and returns 108 | True or False to keep or not keep the ngram 109 | as_string: return the ngram as a string vs list 110 | """ 111 | 112 | def _skip(gram): 113 | if not filter_fn: 114 | return False 115 | return filter_fn(gram) 116 | 117 | words = self.words(uncased) 118 | ngrams = [(s, e + 1) 119 | for s in range(len(words)) 120 | for e in range(s, min(s + n, len(words))) 121 | if not _skip(words[s:e + 1])] 122 | 123 | # Concatenate into strings 124 | if as_strings: 125 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 126 | 127 | return ngrams 128 | 129 | def entity_groups(self): 130 | """Group consecutive entity tokens with the same NER tag.""" 131 | entities = self.entities() 132 | if not entities: 133 | return None 134 | non_ent = self.opts.get('non_ent', 'O') 135 | groups = [] 136 | idx = 0 137 | while idx < len(entities): 138 | ner_tag = entities[idx] 139 | # Check for entity tag 140 | if ner_tag != non_ent: 141 | # Chomp the sequence 142 | start = idx 143 | while (idx < len(entities) and entities[idx] == ner_tag): 144 | idx += 1 145 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 146 | else: 147 | idx += 1 148 | return groups 149 | 150 | 151 | class Tokenizer(object): 152 | """Base tokenizer class. 153 | Tokenizers implement tokenize, which should return a Tokens class. 154 | """ 155 | 156 | def tokenize(self, text): 157 | raise NotImplementedError 158 | 159 | def shutdown(self): 160 | pass 161 | 162 | def __del__(self): 163 | self.shutdown() 164 | 165 | 166 | class SimpleTokenizer(Tokenizer): 167 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 168 | NON_WS = r'[^\p{Z}\p{C}]' 169 | 170 | def __init__(self, **kwargs): 171 | """ 172 | Args: 173 | annotators: None or empty set (only tokenizes). 174 | """ 175 | self._regexp = regex.compile( 176 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 177 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 178 | ) 179 | if len(kwargs.get('annotators', {})) > 0: 180 | logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 181 | (type(self).__name__, kwargs.get('annotators'))) 182 | self.annotators = set() 183 | 184 | def tokenize(self, text): 185 | data = [] 186 | matches = [m for m in self._regexp.finditer(text)] 187 | for i in range(len(matches)): 188 | # Get text 189 | token = matches[i].group() 190 | 191 | # Get whitespace 192 | span = matches[i].span() 193 | start_ws = span[0] 194 | if i + 1 < len(matches): 195 | end_ws = matches[i + 1].span()[0] 196 | else: 197 | end_ws = span[1] 198 | 199 | # Format data 200 | data.append(( 201 | token, 202 | text[start_ws: end_ws], 203 | span, 204 | )) 205 | return Tokens(data, self.annotators) 206 | 207 | 208 | def _normalize(text): 209 | return unicodedata.normalize("NFD", text) 210 | 211 | 212 | def has_answer(answers, text, tokenizer) -> bool: 213 | """Check if a document contains an answer string. 214 | If `match_type` is string, token matching is done between the text and answer. 215 | If `match_type` is regex, we search the whole text with the regex. 216 | """ 217 | text = _normalize(text) 218 | 219 | # Answer is a list of possible strings 220 | text = tokenizer.tokenize(text).words(uncased=True) 221 | 222 | for single_answer in answers: 223 | single_answer = _normalize(single_answer) 224 | single_answer = tokenizer.tokenize(single_answer) 225 | single_answer = single_answer.words(uncased=True) 226 | 227 | for i in range(0, len(text) - len(single_answer) + 1): 228 | if single_answer == text[i: i + len(single_answer)]: 229 | return True 230 | 231 | return False 232 | -------------------------------------------------------------------------------- /FiD/fidt5.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch T5 model. 
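Adapted for Fusion-in-Decoder (see the "Change (FiD)" comments below): the encoder reshapes its [batch, n_passages * passage_len] input so that each passage is encoded independently, concatenates the per-passage representations back together for the decoder's cross-attention, and can optionally use gradient checkpointing via the `checkpoint` flag.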
""" 16 | 17 | import copy 18 | import logging 19 | import warnings 20 | 21 | import torch 22 | import torch.utils.checkpoint 23 | from torch import nn 24 | from torch.nn import CrossEntropyLoss 25 | 26 | from transformers.configuration_t5 import T5Config 27 | from transformers.file_utils import DUMMY_INPUTS, DUMMY_MASK 28 | from transformers.modeling_t5 import T5LayerNorm, T5DenseReluDense, T5LayerFF, T5Attention, T5LayerSelfAttention, \ 29 | T5LayerCrossAttention, T5Model, T5ForConditionalGeneration, load_tf_weights_in_t5 30 | from transformers.modeling_utils import PreTrainedModel 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | class T5Block(nn.Module): 36 | def __init__(self, config, has_relative_attention_bias=False): 37 | super().__init__() 38 | self.is_decoder = config.is_decoder 39 | self.layer = nn.ModuleList() 40 | self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) 41 | if self.is_decoder: 42 | self.layer.append(T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias)) 43 | 44 | self.layer.append(T5LayerFF(config)) 45 | 46 | def forward( 47 | self, 48 | hidden_states, 49 | attention_mask=None, 50 | position_bias=None, 51 | encoder_hidden_states=None, 52 | encoder_attention_mask=None, 53 | encoder_decoder_position_bias=None, 54 | head_mask=None, 55 | past_key_value_state=None, 56 | use_cache=False, 57 | output_attentions=False, 58 | ): 59 | 60 | if past_key_value_state is not None: 61 | assert self.is_decoder, "Only decoder can use `past_key_value_states`" 62 | expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4 63 | 64 | error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format( 65 | expected_num_past_key_value_states, 66 | "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "", 67 | len(past_key_value_state), 68 | ) 69 | assert len(past_key_value_state) == expected_num_past_key_value_states, error_message 70 | 71 | self_attn_past_key_value_state = past_key_value_state[:2] 72 | cross_attn_past_key_value_state = past_key_value_state[2:] 73 | else: 74 | self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None 75 | 76 | self_attention_outputs = self.layer[0]( 77 | hidden_states, 78 | attention_mask=attention_mask, 79 | position_bias=position_bias, 80 | head_mask=head_mask, 81 | past_key_value_state=self_attn_past_key_value_state, 82 | use_cache=use_cache, 83 | output_attentions=output_attentions, 84 | ) 85 | hidden_states, present_key_value_state = self_attention_outputs[:2] 86 | attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights 87 | 88 | if self.is_decoder and encoder_hidden_states is not None: 89 | # the actual query length is unknown for cross attention 90 | # if using past key value states. 
Need to inject it here 91 | if present_key_value_state is not None: 92 | query_length = present_key_value_state[0].shape[2] 93 | else: 94 | query_length = None 95 | 96 | cross_attention_outputs = self.layer[1]( 97 | hidden_states, 98 | kv=encoder_hidden_states, 99 | attention_mask=encoder_attention_mask, 100 | position_bias=encoder_decoder_position_bias, 101 | head_mask=head_mask, 102 | past_key_value_state=cross_attn_past_key_value_state, 103 | query_length=query_length, 104 | use_cache=use_cache, 105 | output_attentions=output_attentions, 106 | ) 107 | hidden_states = cross_attention_outputs[0] 108 | # Combine self attn and cross attn key value states 109 | if present_key_value_state is not None: 110 | present_key_value_state = present_key_value_state + cross_attention_outputs[1] 111 | 112 | # Keep cross-attention outputs and relative position weights 113 | attention_outputs = attention_outputs + cross_attention_outputs[2:] 114 | 115 | # Apply Feed Forward layer 116 | hidden_states = self.layer[-1](hidden_states) 117 | outputs = (hidden_states,) 118 | 119 | # Add attentions if we output them 120 | outputs = outputs + (present_key_value_state,) + attention_outputs 121 | return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) 122 | 123 | 124 | class T5PreTrainedModel(PreTrainedModel): 125 | """ An abstract class to handle weights initialization and 126 | a simple interface for downloading and loading pretrained models. 127 | """ 128 | 129 | config_class = T5Config 130 | load_tf_weights = load_tf_weights_in_t5 131 | base_model_prefix = "transformer" 132 | 133 | @property 134 | def dummy_inputs(self): 135 | input_ids = torch.tensor(DUMMY_INPUTS) 136 | input_mask = torch.tensor(DUMMY_MASK) 137 | dummy_inputs = { 138 | "decoder_input_ids": input_ids, 139 | "input_ids": input_ids, 140 | "decoder_attention_mask": input_mask, 141 | } 142 | return dummy_inputs 143 | 144 | def _init_weights(self, module): 145 | """ Initialize the weights """ 146 | factor = self.config.initializer_factor # Used for testing weights initialization 147 | if isinstance(module, T5LayerNorm): 148 | module.weight.data.fill_(factor * 1.0) 149 | elif isinstance(module, (T5Model, T5ForConditionalGeneration, FiDT5)): # Change (FiD): added FiDT5 150 | # Mesh TensorFlow embeddings initialization 151 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 152 | module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) 153 | elif isinstance(module, T5DenseReluDense): 154 | # Mesh TensorFlow FF initialization 155 | # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 156 | # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 157 | module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) 158 | if hasattr(module.wi, "bias") and module.wi.bias is not None: 159 | module.wi.bias.data.zero_() 160 | module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) 161 | if hasattr(module.wo, "bias") and module.wo.bias is not None: 162 | module.wo.bias.data.zero_() 163 | elif isinstance(module, T5Attention): 164 | # Mesh TensorFlow attention initialization to avoid scaling before softmax 165 | # See 
https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 166 | d_model = self.config.d_model 167 | d_kv = self.config.d_kv 168 | n_heads = self.config.num_heads 169 | module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * d_kv) ** -0.5)) 170 | module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) 171 | module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5)) 172 | module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * d_kv) ** -0.5)) 173 | if module.has_relative_attention_bias: 174 | module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) 175 | 176 | def _shift_right(self, input_ids): 177 | decoder_start_token_id = self.config.decoder_start_token_id 178 | pad_token_id = self.config.pad_token_id 179 | 180 | assert ( 181 | decoder_start_token_id is not None 182 | ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information" 183 | 184 | # shift inputs to the right 185 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 186 | shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() 187 | shifted_input_ids[..., 0] = decoder_start_token_id 188 | 189 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 190 | # replace possible -100 values in labels by `pad_token_id` 191 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 192 | 193 | assert torch.all(shifted_input_ids >= 0).item(), "Verify that `labels` has only positive values and -100" 194 | 195 | return shifted_input_ids 196 | 197 | 198 | class T5Stack(T5PreTrainedModel): 199 | def __init__(self, config, embed_tokens=None): 200 | super().__init__(config) 201 | 202 | self.embed_tokens = embed_tokens 203 | self.is_decoder = config.is_decoder 204 | 205 | self.block = nn.ModuleList( 206 | [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] 207 | ) 208 | self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 209 | self.dropout = nn.Dropout(config.dropout_rate) 210 | 211 | self.init_weights() 212 | self.checkpoint = False # Change (FiD): flag for gradient checkpointing (during training) 213 | 214 | def get_input_embeddings(self): 215 | return self.embed_tokens 216 | 217 | def get_output_embeddings(self): 218 | return self.embed_tokens 219 | 220 | def set_input_embeddings(self, new_embeddings): 221 | self.embed_tokens = new_embeddings 222 | 223 | def forward( 224 | self, 225 | input_ids=None, 226 | attention_mask=None, 227 | encoder_hidden_states=None, 228 | encoder_attention_mask=None, 229 | inputs_embeds=None, 230 | head_mask=None, 231 | past_key_value_states=None, 232 | use_cache=None, 233 | output_attentions=None, 234 | output_hidden_states=None, 235 | ): 236 | if not self.is_decoder: # Change (FiD): encoder needs to reshape the inputs 237 | bsz, tc = input_ids.shape 238 | plen = tc // self.n_passages 239 | input_ids = input_ids.view(bsz * self.n_passages, plen) 240 | attention_mask = attention_mask.view(bsz * self.n_passages, plen) 241 | 242 | use_cache = use_cache if use_cache is not None else self.config.use_cache 243 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 244 | output_hidden_states = ( 245 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 246 | ) 247 | 248 | if 
input_ids is not None and inputs_embeds is not None: 249 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 250 | elif input_ids is not None: 251 | input_shape = input_ids.size() 252 | input_ids = input_ids.view(-1, input_shape[-1]) 253 | elif inputs_embeds is not None: 254 | input_shape = inputs_embeds.size()[:-1] 255 | else: 256 | if self.is_decoder: 257 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 258 | else: 259 | raise ValueError("You have to specify either input_ids or inputs_embeds") 260 | 261 | if inputs_embeds is None: 262 | assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" 263 | inputs_embeds = self.embed_tokens(input_ids) 264 | 265 | batch_size, seq_length = input_shape 266 | 267 | if past_key_value_states is not None: 268 | assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format( 269 | input_shape, (batch_size, 1) 270 | ) 271 | # required mask seq length can be calculated via length of past 272 | # key value states and seq_length = 1 for the last token 273 | mask_seq_length = past_key_value_states[0][0].shape[2] + seq_length 274 | else: 275 | mask_seq_length = seq_length 276 | 277 | if attention_mask is None: 278 | attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) 279 | if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: 280 | encoder_seq_length = encoder_hidden_states.shape[1] 281 | encoder_attention_mask = torch.ones( 282 | batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long 283 | ) 284 | 285 | # initialize past_key_value_states with `None` if past does not exist 286 | if past_key_value_states is None: 287 | past_key_value_states = [None] * len(self.block) 288 | 289 | # ourselves in which case we just need to make it broadcastable to all heads. 
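# (The truncated comment above appears to be inherited from the upstream HuggingFace T5 code. get_extended_attention_mask expands the [batch, seq_len] mask to [batch, 1, 1, seq_len], or to a [batch, 1, seq_len, seq_len] causal mask for the decoder, with 0 for visible positions and a large negative value for masked ones, so it broadcasts across all attention heads.)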
290 | extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) 291 | 292 | if self.is_decoder and encoder_attention_mask is not None: 293 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 294 | else: 295 | encoder_extended_attention_mask = None 296 | 297 | # Prepare head mask if needed 298 | head_mask = self.get_head_mask(head_mask, self.config.num_layers) 299 | present_key_value_states = () 300 | all_hidden_states = () 301 | all_attentions = () 302 | position_bias = None 303 | encoder_decoder_position_bias = None 304 | 305 | hidden_states = self.dropout(inputs_embeds) 306 | 307 | for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)): 308 | if output_hidden_states: 309 | all_hidden_states = all_hidden_states + (hidden_states,) 310 | 311 | if not self.is_decoder and self.checkpoint: # Change (FiD): encoder with gradient checkpointing 312 | hidden_states = hidden_states.contiguous() 313 | extended_attention_mask = extended_attention_mask.contiguous() 314 | layer_outputs = torch.utils.checkpoint.checkpoint( 315 | layer_module, 316 | hidden_states, 317 | extended_attention_mask, 318 | position_bias, 319 | # encoder_hidden_states, 320 | # encoder_extended_attention_mask, 321 | # encoder_decoder_position_bias, 322 | # head_mask[i], 323 | # past_key_value_state, 324 | ) 325 | else: 326 | layer_outputs = layer_module( 327 | hidden_states, 328 | attention_mask=extended_attention_mask, 329 | position_bias=position_bias, 330 | encoder_hidden_states=encoder_hidden_states, 331 | encoder_attention_mask=encoder_extended_attention_mask, 332 | encoder_decoder_position_bias=encoder_decoder_position_bias, 333 | head_mask=head_mask[i], 334 | past_key_value_state=past_key_value_state, 335 | use_cache=use_cache, 336 | output_attentions=output_attentions, 337 | ) 338 | # layer_outputs is a tuple with: 339 | # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) 340 | hidden_states, present_key_value_state = layer_outputs[:2] 341 | 342 | if i == 0: 343 | # We share the position biases between the layers - the first layer store them 344 | # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) 345 | position_bias = layer_outputs[3 if output_attentions else 2] 346 | if self.is_decoder and encoder_hidden_states is not None: 347 | encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3] 348 | # append next layer key value states 349 | present_key_value_states = present_key_value_states + (present_key_value_state,) 350 | 351 | if output_attentions: 352 | all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now 353 | 354 | hidden_states = self.final_layer_norm(hidden_states) 355 | hidden_states = self.dropout(hidden_states) 356 | 357 | if not self.is_decoder: # Change (FiD): reshape output 358 | hidden_states = hidden_states.view(bsz, self.n_passages * plen, -1) 359 | 360 | # Add last layer 361 | if output_hidden_states: 362 | all_hidden_states = all_hidden_states + (hidden_states,) 363 | 364 | outputs = (hidden_states,) 365 | if use_cache is True: 366 | assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) 367 | outputs = outputs + (present_key_value_states,) 368 | if 
output_hidden_states: 369 | outputs = outputs + (all_hidden_states,) 370 | if output_attentions: 371 | outputs = outputs + (all_attentions,) 372 | return outputs # last-layer hidden state, (presents,) (all hidden states), (all attentions) 373 | 374 | 375 | class FiDT5(T5PreTrainedModel): 376 | def __init__(self, config): 377 | super().__init__(config) 378 | self.model_dim = config.d_model 379 | 380 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 381 | 382 | encoder_config = copy.deepcopy(config) 383 | encoder_config.use_cache = False 384 | self.encoder = T5Stack(encoder_config, self.shared) 385 | 386 | decoder_config = copy.deepcopy(config) 387 | decoder_config.is_decoder = True 388 | self.decoder = T5Stack(decoder_config, self.shared) 389 | 390 | self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) 391 | 392 | self.init_weights() 393 | 394 | def get_input_embeddings(self): 395 | return self.shared 396 | 397 | def set_input_embeddings(self, new_embeddings): 398 | self.shared = new_embeddings 399 | self.encoder.set_input_embeddings(new_embeddings) 400 | self.decoder.set_input_embeddings(new_embeddings) 401 | 402 | def get_output_embeddings(self): 403 | return self.lm_head 404 | 405 | def get_encoder(self): 406 | return self.encoder 407 | 408 | def get_decoder(self): 409 | return self.decoder 410 | 411 | def forward( 412 | self, 413 | input_ids=None, 414 | attention_mask=None, 415 | encoder_outputs=None, 416 | decoder_input_ids=None, 417 | decoder_attention_mask=None, 418 | decoder_past_key_value_states=None, 419 | use_cache=None, 420 | labels=None, 421 | inputs_embeds=None, 422 | decoder_inputs_embeds=None, 423 | head_mask=None, 424 | output_attentions=None, 425 | output_hidden_states=None, 426 | **kwargs 427 | ): 428 | 429 | if "lm_labels" in kwargs: 430 | warnings.warn( 431 | "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", 432 | DeprecationWarning, 433 | ) 434 | labels = kwargs.pop("lm_labels") 435 | assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." 436 | 437 | use_cache = use_cache if use_cache is not None else self.config.use_cache 438 | 439 | # Encode if needed (training, first prediction pass) 440 | if encoder_outputs is None: 441 | # Convert encoder inputs in embeddings if needed 442 | encoder_outputs = self.encoder( 443 | input_ids=input_ids, 444 | attention_mask=attention_mask, 445 | inputs_embeds=inputs_embeds, 446 | head_mask=head_mask, 447 | output_attentions=output_attentions, 448 | output_hidden_states=output_hidden_states, 449 | ) 450 | 451 | hidden_states = encoder_outputs[0] 452 | 453 | if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: 454 | # get decoder inputs from shifting lm labels to the right 455 | decoder_input_ids = self._shift_right(labels) 456 | 457 | # If decoding with past key value states, only the last tokens 458 | # should be given as an input 459 | if decoder_past_key_value_states is not None: 460 | assert labels is None, "Decoder should not use cached key value states when training." 
461 | if decoder_input_ids is not None: 462 | decoder_input_ids = decoder_input_ids[:, -1:] 463 | if decoder_inputs_embeds is not None: 464 | decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] 465 | 466 | # Decode 467 | decoder_outputs = self.decoder( 468 | input_ids=decoder_input_ids, 469 | attention_mask=decoder_attention_mask, 470 | inputs_embeds=decoder_inputs_embeds, 471 | past_key_value_states=decoder_past_key_value_states, 472 | encoder_hidden_states=hidden_states, 473 | encoder_attention_mask=attention_mask, 474 | head_mask=head_mask, 475 | use_cache=use_cache, 476 | output_attentions=output_attentions, 477 | output_hidden_states=output_hidden_states, 478 | ) 479 | 480 | # insert decoder past at right place 481 | # to speed up decoding 482 | if use_cache is True: 483 | past = ((encoder_outputs, decoder_outputs[1]),) 484 | decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:] 485 | 486 | sequence_output = decoder_outputs[0] 487 | # Rescale output before projecting on vocab 488 | # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 489 | sequence_output = sequence_output * (self.model_dim ** -0.5) 490 | lm_logits = self.lm_head(sequence_output) 491 | 492 | decoder_outputs = (lm_logits,) + decoder_outputs[1:] # Add hidden states and attention if they are here 493 | if labels is not None: 494 | loss_fct = CrossEntropyLoss(ignore_index=-100) 495 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) 496 | # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 497 | decoder_outputs = (loss,) + decoder_outputs 498 | 499 | return decoder_outputs + encoder_outputs 500 | 501 | def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs): 502 | assert past is not None, "past has to be defined for encoder_outputs" 503 | 504 | encoder_outputs, decoder_past_key_value_states = past 505 | 506 | return { 507 | "decoder_input_ids": input_ids, 508 | "decoder_past_key_value_states": decoder_past_key_value_states, 509 | "encoder_outputs": encoder_outputs, 510 | "attention_mask": attention_mask, 511 | "use_cache": use_cache, 512 | } 513 | 514 | def _reorder_cache(self, past, beam_idx): 515 | # if decoder past is not included in output 516 | # speedy decoding is disabled and no need to reorder 517 | if past[1] is None: 518 | logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") 519 | return past 520 | 521 | decoder_past = past[1] 522 | past = (past[0],) 523 | reordered_decoder_past = () 524 | for layer_past_states in decoder_past: 525 | # get the correct batch idx from layer past batch dim 526 | # batch dim of `past` is at 2nd position 527 | reordered_layer_past_states = () 528 | for layer_past_state in layer_past_states: 529 | # need to set correct `past` for each of the four key / value states 530 | reordered_layer_past_states = reordered_layer_past_states + ( 531 | layer_past_state.index_select(0, beam_idx), 532 | ) 533 | 534 | assert reordered_layer_past_states[0].shape == layer_past_states[0].shape 535 | assert len(reordered_layer_past_states) == len(layer_past_states) 536 | 537 | reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) 538 | return past + (reordered_decoder_past,) 539 | -------------------------------------------------------------------------------- /FiD/options.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | class Options(): 6 | def __init__(self): 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | self.parser = self.initialize(parser) 9 | 10 | def initialize(self, parser): 11 | # basic parameters 12 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment') 13 | parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint/', help='models are saved here') 14 | parser.add_argument('--model_path', type=str, default='none', help='path for retraining') 15 | parser.add_argument('--train_data_path', type=str, default='none', help='path of train data') 16 | parser.add_argument('--dev_data_path', type=str, default='none', help='path of dev data') 17 | parser.add_argument('--dev_data_size', type=int, default=-1, help='subsample dev data to speedup evaluation') 18 | parser.add_argument('--test_data_path', type=str, default='none', help='path of test data') 19 | parser.add_argument('--model_type', type=str, default='t5') 20 | parser.add_argument('--model_size', type=str, default='base') 21 | parser.add_argument('--write_results', action='store_true', help='save test results') 22 | 23 | # dataset parameters 24 | parser.add_argument("--per_gpu_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 25 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Accumulation step.") 26 | parser.add_argument('--no_title', action='store_true', help='article titles not included in passages') 27 | parser.add_argument('--n_context', type=int, default=1) 28 | parser.add_argument('--total_step', type=int, default=10000) 29 | parser.add_argument('--reload_step', type=int, default=-1, help='reload model at step ') 30 | parser.add_argument('--max_passage_length', type=int, default=250, 31 | help='maximum number of tokens in the passages (question included)') 32 | parser.add_argument('--checkpointing_encoder', action='store_true', help='trades memory for compute') 33 | parser.add_argument('--checkpointing_decoder', action='store_true', help='trades memory for compute') 34 | 35 | # training parameters 36 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 37 | parser.add_argument('--adam_epsilon', type=float, default=1e-8, help='epsilon for Adam optimizer') 38 | parser.add_argument('--warmup_step', type=int, default=0, help='number of warmup steps') 39 | parser.add_argument('--clip', type=float, default=1., help='gradient clipping') 40 | parser.add_argument('--log_freq', type=int, default=10, 41 | help='log model loss every steps during training') 42 | parser.add_argument('--eval_freq', type=int, default=500, 43 | help='evaluate model every steps during training') 44 | parser.add_argument('--eval_print_freq', type=int, default=1000, 45 | help='print intermediate results of evaluation every steps') 46 | parser.add_argument('--save_freq', type=int, default=1000, 47 | help='save model every steps during training') 48 | 49 | parser.add_argument("--local_rank", type=int, default=-1, 50 | help="For distributed training: local_rank") 51 | parser.add_argument("--master_port", type=int, default=-1, 52 | help="Master port (for multi-node SLURM jobs)") 53 | parser.add_argument('--seed', type=int, default=0, help="random seed for initialization") 54 | parser.add_argument('--global_rank', type=int, default=-1) 55 | 
parser.add_argument('--world_size', type=int, default=-1) 56 | parser.add_argument('--is_master', action='store_true') 57 | parser.add_argument('--fp16', action='store_true') 58 | parser.add_argument('--fp16_opt_level', type=str, default="O1", 59 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 60 | "See details at https://nvidia.github.io/apex/amp.html") 61 | 62 | # AC (Adaptive Computation) has_answer_heads parameters 63 | parser.add_argument('--has_answer_pool_type', type=str, default="none", 64 | help='pooling type of has_answer_heads') 65 | 66 | # AC (Adaptive Computation) Scheduler model parameters 67 | parser.add_argument('--scheduler_type', type=str, default="none", 68 | help='type of the AC scheduler (default: none)') 69 | parser.add_argument('--scheduler_n_context', type=int, default=1, 70 | help='maximum number of context for the AC scheduler') 71 | parser.add_argument('--scheduler_embed_size', type=int, default=10, 72 | help='embedding size of the AC MLPScheduler') 73 | parser.add_argument('--scheduler_hidden_size', type=int, default=10, 74 | help='hidden size of the AC MLPScheduler') 75 | 76 | # AC (Adaptive Computation) train/inference parameters 77 | parser.add_argument('--freeze_fid_params', action='store_true', help='freeze the FiD parameters') 78 | parser.add_argument('--freeze_has_answer_heads', action='store_true', 79 | help='freeze the has_answer_heads parameters (used when training the AC scheduler)') 80 | parser.add_argument('--use_bce_loss', action='store_true', 81 | help='train the has_answer_heads with Binary Cross-entropy loss') 82 | parser.add_argument('--use_rl_loss', action='store_true', 83 | help='train the scheduler with REINFORCE loss') 84 | parser.add_argument('--budget', type=int, default=None, help='budget number of passage layer') 85 | parser.add_argument('--num_passages_retained', type=int, default=None, 86 | help='number of passages retained after AC') 87 | parser.add_argument('--step_cost', type=float, default=0., 88 | help='cost per step when training the scheduler with REINFORCE') 89 | parser.add_argument('--discount', type=float, default=1., 90 | help='discount factor when training the scheduler with REINFORCE') 91 | 92 | return parser 93 | 94 | def print_options(self, opt): 95 | message = '' 96 | for k, v in sorted(vars(opt).items()): 97 | comment = '' 98 | default = self.parser.get_default(k) 99 | if v != default: 100 | comment = '\t[default: %s]' % str(default) 101 | message += '{:>40}: {:<40}{}\n'.format(str(k), str(v), comment) 102 | 103 | expr_dir = os.path.join(opt.checkpoint_dir, opt.name) 104 | model_dir = os.path.join(expr_dir, 'models') 105 | if not os.path.exists(model_dir): 106 | os.makedirs(os.path.join(expr_dir, 'models')) 107 | file_name = os.path.join(expr_dir, 'opt.txt') 108 | with open(file_name, 'wt') as opt_file: 109 | opt_file.write(message) 110 | opt_file.write('\n') 111 | 112 | def parse(self): 113 | opt, _ = self.parser.parse_known_args() 114 | opt = self.parser.parse_args() 115 | return opt 116 | -------------------------------------------------------------------------------- /FiD/slurm.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | import os 3 | import sys 4 | import torch 5 | import socket 6 | import signal 7 | import subprocess 8 | 9 | 10 | logger = getLogger() 11 | 12 | def sig_handler(signum, frame): 13 | logger.warning("Signal handler called with signal " + str(signum)) 14 | prod_id = 
int(os.environ['SLURM_PROCID']) 15 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 16 | if prod_id == 0: 17 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 18 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 19 | else: 20 | logger.warning("Not the master process, no need to requeue.") 21 | sys.exit(-1) 22 | 23 | 24 | def term_handler(signum, frame): 25 | logger.warning("Signal handler called with signal " + str(signum)) 26 | logger.warning("Bypassing SIGTERM.") 27 | 28 | 29 | def init_signal_handler(): 30 | """ 31 | Handle signals sent by SLURM for time limit / pre-emption. 32 | """ 33 | signal.signal(signal.SIGUSR1, sig_handler) 34 | signal.signal(signal.SIGTERM, term_handler) 35 | #logger.warning("Signal handler installed.") 36 | 37 | 38 | def init_distributed_mode(params): 39 | """ 40 | Handle single and multi-GPU / multi-node / SLURM jobs. 41 | Initialize the following variables: 42 | - n_nodes 43 | - node_id 44 | - local_rank 45 | - global_rank 46 | - world_size 47 | """ 48 | params.is_slurm_job = 'SLURM_JOB_ID' in os.environ 49 | #print("SLURM job: %s" % str(params.is_slurm_job)) 50 | 51 | # SLURM job 52 | if params.is_slurm_job: 53 | 54 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 55 | 56 | SLURM_VARIABLES = [ 57 | 'SLURM_JOB_ID', 58 | 'SLURM_JOB_NODELIST', 'SLURM_JOB_NUM_NODES', 'SLURM_NTASKS', 'SLURM_TASKS_PER_NODE', 59 | 'SLURM_MEM_PER_NODE', 'SLURM_MEM_PER_CPU', 60 | 'SLURM_NODEID', 'SLURM_PROCID', 'SLURM_LOCALID', 'SLURM_TASK_PID' 61 | ] 62 | 63 | PREFIX = "%i - " % int(os.environ['SLURM_PROCID']) 64 | for name in SLURM_VARIABLES: 65 | value = os.environ.get(name, None) 66 | #print(PREFIX + "%s: %s" % (name, str(value))) 67 | 68 | # # job ID 69 | # params.job_id = os.environ['SLURM_JOB_ID'] 70 | 71 | # number of nodes / node ID 72 | params.n_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) 73 | params.node_id = int(os.environ['SLURM_NODEID']) 74 | 75 | # local rank on the current node / global rank 76 | params.local_rank = int(os.environ['SLURM_LOCALID']) 77 | params.global_rank = int(os.environ['SLURM_PROCID']) 78 | 79 | # number of processes / GPUs per node 80 | params.world_size = int(os.environ['SLURM_NTASKS']) 81 | params.n_gpu_per_node = params.world_size // params.n_nodes 82 | 83 | # define master address and master port 84 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 85 | params.master_addr = hostnames.split()[0].decode('utf-8') 86 | assert 10001 <= params.master_port <= 20000 or params.world_size == 1 87 | #print(PREFIX + "Master address: %s" % params.master_addr) 88 | #print(PREFIX + "Master port : %i" % params.master_port) 89 | 90 | # set environment variables for 'env://' 91 | os.environ['MASTER_ADDR'] = params.master_addr 92 | os.environ['MASTER_PORT'] = str(params.master_port) 93 | os.environ['WORLD_SIZE'] = str(params.world_size) 94 | os.environ['RANK'] = str(params.global_rank) 95 | 96 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 97 | elif params.local_rank != -1: 98 | 99 | assert params.master_port == -1 100 | 101 | # read environment variables 102 | params.global_rank = int(os.environ['RANK']) 103 | params.world_size = int(os.environ['WORLD_SIZE']) 104 | params.n_gpu_per_node = int(os.environ['NGPU']) 105 | 106 | # number of nodes / node ID 107 | params.n_nodes = params.world_size // params.n_gpu_per_node 108 | params.node_id = params.global_rank // params.n_gpu_per_node 
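# Worked example (hypothetical values): if the launch script exports NGPU=4 and the
# job spans two nodes, then WORLD_SIZE=8; the process with RANK=5 computes
# n_nodes = 8 // 4 = 2 and node_id = 5 // 4 = 1, i.e. it runs on the second node.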
109 | 110 | # local job (single GPU) 111 | else: 112 | assert params.local_rank == -1 113 | assert params.master_port == -1 114 | params.n_nodes = 1 115 | params.node_id = 0 116 | params.local_rank = 0 117 | params.global_rank = 0 118 | params.world_size = 1 119 | params.n_gpu_per_node = 1 120 | 121 | # sanity checks 122 | assert params.n_nodes >= 1 123 | assert 0 <= params.node_id < params.n_nodes 124 | assert 0 <= params.local_rank <= params.global_rank < params.world_size 125 | assert params.world_size == params.n_nodes * params.n_gpu_per_node 126 | 127 | # define whether this is the master process / if we are in distributed mode 128 | params.is_master = params.node_id == 0 and params.local_rank == 0 129 | params.multi_node = params.n_nodes > 1 130 | params.multi_gpu = params.world_size > 1 131 | 132 | # summary 133 | PREFIX = "%i - " % params.global_rank 134 | #print(PREFIX + "Number of nodes: %i" % params.n_nodes) 135 | #print(PREFIX + "Node ID : %i" % params.node_id) 136 | #print(PREFIX + "Local rank : %i" % params.local_rank) 137 | #print(PREFIX + "Global rank : %i" % params.global_rank) 138 | #print(PREFIX + "World size : %i" % params.world_size) 139 | #print(PREFIX + "GPUs per node : %i" % params.n_gpu_per_node) 140 | #print(PREFIX + "Master : %s" % str(params.is_master)) 141 | #print(PREFIX + "Multi-node : %s" % str(params.multi_node)) 142 | #print(PREFIX + "Multi-GPU : %s" % str(params.multi_gpu)) 143 | #print(PREFIX + "Hostname : %s" % socket.gethostname()) 144 | 145 | # set GPU device 146 | torch.cuda.set_device(params.local_rank) 147 | 148 | # initialize multi-GPU 149 | if params.multi_gpu: 150 | 151 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 152 | # 'env://' will read these environment variables: 153 | # MASTER_PORT - required; has to be a free port on machine with rank 0 154 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 155 | # WORLD_SIZE - required; can be set either here, or in a call to init function 156 | # RANK - required; can be set either here, or in a call to init function 157 | 158 | #print("Initializing PyTorch distributed ...") 159 | torch.distributed.init_process_group( 160 | init_method='env://', 161 | backend='nccl', 162 | ) 163 | -------------------------------------------------------------------------------- /FiD/t5blocks.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License 15 | import copy 16 | import logging 17 | import math 18 | import os 19 | import warnings 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | import torch.utils.checkpoint 24 | from torch import nn 25 | from torch.nn import CrossEntropyLoss 26 | from transformers.configuration_t5 import T5Config 27 | 28 | 29 | class T5LayerNorm(nn.Module): 30 | def __init__(self, hidden_size, eps=1e-6): 31 | """ Construct a layernorm module in the T5 style 32 | No bias and no substraction of mean. 33 | """ 34 | super().__init__() 35 | self.weight = nn.Parameter(torch.ones(hidden_size)) 36 | self.variance_epsilon = eps 37 | 38 | def forward(self, x): 39 | # layer norm should always be calculated in float32 40 | variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True) 41 | x = x / torch.sqrt(variance + self.variance_epsilon) 42 | 43 | if self.weight.dtype == torch.float16: 44 | x = x.to(torch.float16) 45 | return self.weight * x 46 | 47 | 48 | class T5DenseReluDense(nn.Module): 49 | def __init__(self, config): 50 | super().__init__() 51 | self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) 52 | self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) 53 | self.dropout = nn.Dropout(config.dropout_rate) 54 | 55 | def forward(self, hidden_states): 56 | h = self.wi(hidden_states) 57 | h = F.relu(h) 58 | h = self.dropout(h) 59 | h = self.wo(h) 60 | return h 61 | 62 | 63 | class T5LayerFF(nn.Module): 64 | def __init__(self, config): 65 | super().__init__() 66 | self.DenseReluDense = T5DenseReluDense(config) 67 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 68 | self.dropout = nn.Dropout(config.dropout_rate) 69 | 70 | def forward(self, hidden_states): 71 | norm_x = self.layer_norm(hidden_states) 72 | y = self.DenseReluDense(norm_x) 73 | layer_output = hidden_states + self.dropout(y) 74 | return layer_output 75 | 76 | 77 | class T5Attention(nn.Module): 78 | def __init__(self, config, has_relative_attention_bias=False): 79 | super().__init__() 80 | self.is_decoder = config.is_decoder 81 | self.has_relative_attention_bias = has_relative_attention_bias 82 | 83 | self.relative_attention_num_buckets = config.relative_attention_num_buckets 84 | self.d_model = config.d_model 85 | self.d_kv = config.d_kv 86 | self.n_heads = config.num_heads 87 | self.dropout = config.dropout_rate 88 | self.inner_dim = self.n_heads * self.d_kv 89 | 90 | # Mesh TensorFlow initialization to avoid scaling before softmax 91 | self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) 92 | self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) 93 | self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) 94 | self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) 95 | 96 | if self.has_relative_attention_bias: 97 | self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) 98 | self.pruned_heads = set() 99 | 100 | def prune_heads(self, heads): 101 | if len(heads) == 0: 102 | return 103 | heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, self.d_kv, self.pruned_heads) 104 | # Prune linear layers 105 | self.q = prune_linear_layer(self.q, index) 106 | self.k = prune_linear_layer(self.k, index) 107 | self.v = prune_linear_layer(self.v, index) 108 | self.o = prune_linear_layer(self.o, index, dim=1) 109 | # Update hyper params 110 | self.n_heads = self.n_heads - len(heads) 111 | self.inner_dim = self.d_kv * 
self.n_heads 112 | self.pruned_heads = self.pruned_heads.union(heads) 113 | 114 | @staticmethod 115 | def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): 116 | """ 117 | Adapted from Mesh Tensorflow: 118 | https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 119 | Translate relative position to a bucket number for relative attention. 120 | The relative position is defined as memory_position - query_position, i.e. 121 | the distance in tokens from the attending position to the attended-to 122 | position. If bidirectional=False, then positive relative positions are 123 | invalid. 124 | We use smaller buckets for small absolute relative_position and larger buckets 125 | for larger absolute relative_positions. All relative positions >=max_distance 126 | map to the same bucket. All relative positions <=-max_distance map to the 127 | same bucket. This should allow for more graceful generalization to longer 128 | sequences than the model has been trained on. 129 | Args: 130 | relative_position: an int32 Tensor 131 | bidirectional: a boolean - whether the attention is bidirectional 132 | num_buckets: an integer 133 | max_distance: an integer 134 | Returns: 135 | a Tensor with the same shape as relative_position, containing int32 136 | values in the range [0, num_buckets) 137 | """ 138 | ret = 0 139 | n = -relative_position 140 | if bidirectional: 141 | num_buckets //= 2 142 | ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets 143 | n = torch.abs(n) 144 | else: 145 | n = torch.max(n, torch.zeros_like(n)) 146 | # now n is in the range [0, inf) 147 | 148 | # half of the buckets are for exact increments in positions 149 | max_exact = num_buckets // 2 150 | is_small = n < max_exact 151 | 152 | # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance 153 | val_if_large = max_exact + ( 154 | torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) 155 | ).to(torch.long) 156 | val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) 157 | 158 | ret += torch.where(is_small, n, val_if_large) 159 | return ret 160 | 161 | def compute_bias(self, qlen, klen): 162 | """ Compute binned relative position bias """ 163 | context_position = torch.arange(qlen, dtype=torch.long)[:, None] 164 | memory_position = torch.arange(klen, dtype=torch.long)[None, :] 165 | relative_position = memory_position - context_position # shape (qlen, klen) 166 | rp_bucket = self._relative_position_bucket( 167 | relative_position, # shape (qlen, klen) 168 | bidirectional=not self.is_decoder, 169 | num_buckets=self.relative_attention_num_buckets, 170 | ) 171 | rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device) 172 | values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) 173 | values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, qlen, klen) 174 | return values 175 | 176 | def forward( 177 | self, 178 | input, 179 | mask=None, 180 | kv=None, 181 | position_bias=None, 182 | past_key_value_state=None, 183 | head_mask=None, 184 | query_length=None, 185 | use_cache=False, 186 | output_attentions=False, 187 | ): 188 | """ 189 | Self-attention (if kv is None) or attention over source sentence (provided by kv). 
190 | """ 191 | # Input is (bs, qlen, dim) 192 | # Mask is (bs, klen) (non-causal) or (bs, klen, klen) 193 | # past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head) 194 | bs, qlen, dim = input.size() 195 | 196 | if past_key_value_state is not None: 197 | assert self.is_decoder is True, "Encoder cannot cache past key value states" 198 | assert ( 199 | len(past_key_value_state) == 2 200 | ), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format( 201 | len(past_key_value_state) 202 | ) 203 | real_qlen = qlen + past_key_value_state[0].shape[2] if query_length is None else query_length 204 | else: 205 | real_qlen = qlen 206 | 207 | if kv is None: 208 | klen = real_qlen 209 | else: 210 | klen = kv.size(1) 211 | 212 | def shape(x): 213 | """ projection """ 214 | return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2) 215 | 216 | def unshape(x): 217 | """ compute context """ 218 | return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim) 219 | 220 | q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head) 221 | 222 | if kv is None: 223 | k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head) 224 | v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head) 225 | elif past_key_value_state is None: 226 | k = v = kv 227 | k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head) 228 | v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head) 229 | 230 | if past_key_value_state is not None: 231 | if kv is None: 232 | k_, v_ = past_key_value_state 233 | k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head) 234 | v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head) 235 | else: 236 | k, v = past_key_value_state 237 | 238 | if self.is_decoder and use_cache is True: 239 | present_key_value_state = ((k, v),) 240 | else: 241 | present_key_value_state = (None,) 242 | 243 | scores = torch.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen) 244 | 245 | if position_bias is None: 246 | if not self.has_relative_attention_bias: 247 | raise ValueError("No position_bias provided and no weights to compute position_bias") 248 | position_bias = self.compute_bias(real_qlen, klen) 249 | 250 | # if key and values are already calculated 251 | # we want only the last query position bias 252 | if past_key_value_state is not None: 253 | position_bias = position_bias[:, :, -1:, :] 254 | 255 | if mask is not None: 256 | position_bias = position_bias + mask # (bs, n_heads, qlen, klen) 257 | 258 | scores += position_bias 259 | weights = F.softmax(scores.float(), dim=-1).type_as(scores) # (bs, n_heads, qlen, klen) 260 | weights = F.dropout(weights, p=self.dropout, training=self.training) # (bs, n_heads, qlen, klen) 261 | 262 | # Mask heads if we want to 263 | if head_mask is not None: 264 | weights = weights * head_mask 265 | 266 | context = torch.matmul(weights, v) # (bs, n_heads, qlen, dim_per_head) 267 | context = unshape(context) # (bs, qlen, dim) 268 | 269 | context = self.o(context) 270 | 271 | outputs = (context,) + present_key_value_state 272 | 273 | if output_attentions: 274 | outputs = outputs + (weights,) 275 | if self.has_relative_attention_bias: 276 | outputs = outputs + (position_bias,) 277 | return outputs 278 | 279 | 280 | class T5LayerSelfAttention(nn.Module): 281 | def __init__(self, config, has_relative_attention_bias=False): 282 | super().__init__() 283 | self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) 284 | self.layer_norm = T5LayerNorm(config.d_model, 
eps=config.layer_norm_epsilon) 285 | self.dropout = nn.Dropout(config.dropout_rate) 286 | 287 | def forward( 288 | self, 289 | hidden_states, 290 | attention_mask=None, 291 | position_bias=None, 292 | head_mask=None, 293 | past_key_value_state=None, 294 | use_cache=False, 295 | output_attentions=False, 296 | ): 297 | norm_x = self.layer_norm(hidden_states) 298 | attention_output = self.SelfAttention( 299 | norm_x, 300 | mask=attention_mask, 301 | position_bias=position_bias, 302 | head_mask=head_mask, 303 | past_key_value_state=past_key_value_state, 304 | use_cache=use_cache, 305 | output_attentions=output_attentions, 306 | ) 307 | y = attention_output[0] 308 | layer_output = hidden_states + self.dropout(y) 309 | outputs = (layer_output,) + attention_output[1:] # add attentions if we output them 310 | return outputs 311 | 312 | 313 | class T5LayerCrossAttention(nn.Module): 314 | def __init__(self, config, has_relative_attention_bias=False): 315 | super().__init__() 316 | self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) 317 | self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) 318 | self.dropout = nn.Dropout(config.dropout_rate) 319 | 320 | def forward( 321 | self, 322 | hidden_states, 323 | kv, 324 | attention_mask=None, 325 | position_bias=None, 326 | head_mask=None, 327 | past_key_value_state=None, 328 | use_cache=False, 329 | query_length=None, 330 | output_attentions=False, 331 | ): 332 | norm_x = self.layer_norm(hidden_states) 333 | attention_output = self.EncDecAttention( 334 | norm_x, 335 | mask=attention_mask, 336 | kv=kv, 337 | position_bias=position_bias, 338 | head_mask=head_mask, 339 | past_key_value_state=past_key_value_state, 340 | use_cache=use_cache, 341 | query_length=query_length, 342 | output_attentions=output_attentions, 343 | ) 344 | y = attention_output[0] 345 | layer_output = hidden_states + self.dropout(y) 346 | outputs = (layer_output,) + attention_output[1:] # add attentions if we output them 347 | return outputs 348 | -------------------------------------------------------------------------------- /FiD/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import transformers 5 | import logging 6 | import data 7 | import util 8 | from fidt5 import FiDT5 9 | import numpy as np 10 | from pathlib import Path 11 | import torch.distributed as dist 12 | from options import Options 13 | from torch.utils.data import DataLoader, SequentialSampler 14 | import evaluation 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def evaluate(model, dataset, dataloader, tokenizer, opt): 20 | model.eval() 21 | if hasattr(model, "module"): 22 | model = model.module 23 | total = 0 24 | ems = [] 25 | 26 | fw = None 27 | if opt.write_results: 28 | write_path = os.path.join(opt.checkpoint_dir, opt.name, 'test_results') 29 | fw = open(os.path.join(write_path, '%d.txt' % opt.global_rank), 'w') 30 | 31 | with torch.no_grad(): 32 | for i, batch in enumerate(dataloader): 33 | idx, answer_ids, answer_mask, context_ids, context_mask = batch 34 | # answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 35 | model.encoder.n_passages = context_ids.size(1) 36 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 37 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 38 | 39 | outputs = model.generate( 40 | input_ids=context_ids, 41 | attention_mask=context_mask, 42 | max_length=50, 43 | ) 
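# The generated token ids are decoded back into answer strings below and scored with
# exact match (evaluation.ems) against each example's gold answer list; when
# --write_results is set, per-example predictions are also written to a per-rank
# file under test_results/.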
44 | 45 | for k, o in enumerate(outputs): 46 | ans = tokenizer.decode(o, skip_special_tokens=True) 47 | example = dataset.get_example(idx[k]) 48 | question = example.question 49 | gold = example.answers 50 | id = example.id 51 | ems_score = evaluation.ems(ans, gold) 52 | ems.append(ems_score) 53 | 54 | if fw is not None: 55 | fw.write(str(id) + "\t" + ans + '\n') 56 | 57 | total += 1 58 | 59 | if (i + 1) % opt.eval_print_freq == 0: 60 | logger.warning(f"{opt.global_rank}, {i + 1} / {len(dataloader)} -- average = {np.mean(ems):.3f}") 61 | 62 | logger.warning(f"{opt.global_rank}, total {total} -- average = {np.mean(ems):.3f}") 63 | if opt.world_size > 1 and not opt.local_rank == -1: 64 | torch.distributed.barrier() 65 | score, total = util.weighted_average(np.mean(ems), total, opt) 66 | logger.info('total number of example %d' % total) 67 | return score 68 | 69 | 70 | if __name__ == "__main__": 71 | options = Options() 72 | opt = options.parse() 73 | opt.train_batch_size = opt.per_gpu_batch_size 74 | logger.info("Distributed training") 75 | opt.is_master = True 76 | 77 | 78 | model_name = 't5-' + opt.model_size 79 | model_class = FiDT5 80 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name, return_dict=False) 81 | 82 | collator_function = data.Collator(opt, tokenizer) 83 | test_examples = data.load_data(opt.test_data_path) 84 | test_dataset = data.Dataset(test_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 85 | 86 | test_sampler = SequentialSampler(test_dataset) 87 | test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=opt.per_gpu_batch_size, 88 | shuffle=False, num_workers=4, collate_fn=collator_function) 89 | 90 | dir_path = os.path.join(opt.checkpoint_dir, opt.name) 91 | directory_exists = os.path.exists(dir_path) 92 | if opt.world_size > 1 and not opt.local_rank == -1: 93 | torch.distributed.barrier() 94 | os.makedirs(dir_path, exist_ok=True) 95 | if opt.write_results: 96 | os.makedirs(os.path.join(dir_path, 'test_results'), exist_ok=True) 97 | if not directory_exists and opt.is_master: 98 | options.print_options(opt) 99 | if opt.world_size > 1 and not opt.local_rank == -1: 100 | torch.distributed.barrier() 101 | file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 102 | stdout_handler = logging.StreamHandler(sys.stdout) 103 | handlers = [file_handler, stdout_handler] 104 | logging.basicConfig( 105 | datefmt="%m/%d/%Y %H:%M:%S", 106 | level=logging.INFO if opt.is_master else logging.WARN, 107 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 108 | handlers=handlers, 109 | ) 110 | 111 | model = model_class.from_pretrained(opt.model_path) 112 | 113 | # model = model_class.from_pretrained('t5-large') 114 | # quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) 115 | # quantized_dict = torch.load('pretrained_models/nq_large_dpr_int8/checkpoint/best_dev/pytorch_model.bin') 116 | # quantized_model.load_state_dict(quantized_dict) 117 | # qm = list(quantized_model.modules()) 118 | # qml=list(filter(lambda x: type(x) == torch.nn.quantized.dynamic.modules.linear.Linear, qm)) 119 | # counter = 0 120 | # with torch.no_grad(): 121 | # for mod in model.modules(): 122 | # if type(mod) == torch.nn.Linear: 123 | # mod.weight.copy_(torch.dequantize(qml[counter].weight())) 124 | # counter += 1 125 | 126 | model = model.cuda() 127 | 128 | logger.info("Start eval") 129 | ems = evaluate(model, test_dataset, test_dataloader, tokenizer, opt) 130 | 
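# During evaluation each rank wrote its predictions to test_results/<global_rank>.txt;
# util.write_output below presumably merges those per-rank files into the single
# final_output.json for the run.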
131 | if opt.write_results and opt.is_master: 132 | print(opt.is_master) 133 | glob_path = Path(opt.checkpoint_dir) / opt.name / 'test_results' 134 | write_path = Path(opt.checkpoint_dir) / opt.name / 'final_output.json' 135 | util.write_output(glob_path, write_path) 136 | 137 | logger.info("EM %.6f" % (ems)) 138 | -------------------------------------------------------------------------------- /FiD/test_ac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import transformers 5 | import logging 6 | import data 7 | import util 8 | from fidt5_ac import ACFiDT5 9 | import numpy as np 10 | from pathlib import Path 11 | import torch.distributed as dist 12 | from options import Options 13 | from torch.utils.data import DataLoader, SequentialSampler 14 | import evaluation 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def evaluate(model, dataset, dataloader, tokenizer, opt): 20 | model.eval() 21 | if hasattr(model, "module"): 22 | model = model.module 23 | 24 | # Set AC parameters 25 | model.encoder.budget = opt.budget 26 | model.encoder.num_passages_retained = opt.num_passages_retained 27 | logger.warning(f"budget = {opt.budget}, num_passages_retained = {opt.num_passages_retained}") 28 | 29 | total = 0 30 | ems = [] 31 | 32 | fw = None 33 | if opt.write_results: 34 | write_path = os.path.join(opt.checkpoint_dir, opt.name, 'test_results') 35 | fw = open(os.path.join(write_path, '%d.txt' % opt.global_rank), 'w') 36 | 37 | with torch.no_grad(): 38 | for i, batch in enumerate(dataloader): 39 | idx, answer_ids, answer_mask, context_ids, context_mask = batch 40 | # answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 41 | model.encoder.n_passages = context_ids.size(1) 42 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 43 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 44 | 45 | outputs = model.generate( 46 | input_ids=context_ids, 47 | attention_mask=context_mask, 48 | max_length=50, 49 | ) 50 | 51 | for k, o in enumerate(outputs): 52 | ans = tokenizer.decode(o, skip_special_tokens=True) 53 | example = dataset.get_example(idx[k]) 54 | question = example.question 55 | gold = example.answers 56 | id = example.id 57 | ems_score = evaluation.ems(ans, gold) 58 | ems.append(ems_score) 59 | 60 | if fw is not None: 61 | fw.write(str(id) + "\t" + ans + '\n') 62 | 63 | total += 1 64 | 65 | if (i + 1) % opt.eval_print_freq == 0: 66 | logger.warning(f"{opt.global_rank}, {i + 1} / {len(dataloader)} -- average = {np.mean(ems):.3f}") 67 | 68 | logger.warning(f"{opt.global_rank}, total {total} -- average = {np.mean(ems):.3f}") 69 | if opt.world_size > 1 and not opt.local_rank == -1: 70 | torch.distributed.barrier() 71 | score, total = util.weighted_average(np.mean(ems), total, opt) 72 | logger.info('total number of example %d' % total) 73 | return score 74 | 75 | 76 | if __name__ == "__main__": 77 | options = Options() 78 | opt = options.parse() 79 | opt.train_batch_size = opt.per_gpu_batch_size 80 | logger.info("Distributed training") 81 | opt.is_master = True 82 | 83 | dir_path = os.path.join(opt.checkpoint_dir, opt.name) 84 | 85 | model_name = 't5-' + opt.model_size 86 | model_class = ACFiDT5 87 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name, return_dict=False) 88 | 89 | collator_function = data.Collator(opt, tokenizer) 90 | test_examples = data.load_data(opt.test_data_path) 91 | test_dataset = data.Dataset(test_examples, opt.n_context, 
tokenizer, opt.max_passage_length, opt.no_title) 92 | 93 | test_sampler = SequentialSampler(test_dataset) 94 | test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=opt.per_gpu_batch_size, 95 | shuffle=False, num_workers=4, collate_fn=collator_function) 96 | 97 | directory_exists = os.path.exists(dir_path) 98 | if opt.world_size > 1 and not opt.local_rank == -1: 99 | torch.distributed.barrier() 100 | os.makedirs(dir_path, exist_ok=True) 101 | if opt.write_results: 102 | os.makedirs(os.path.join(dir_path, 'test_results'), exist_ok=True) 103 | if not directory_exists and opt.is_master: 104 | options.print_options(opt) 105 | if opt.world_size > 1 and not opt.local_rank == -1: 106 | torch.distributed.barrier() 107 | file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 108 | stdout_handler = logging.StreamHandler(sys.stdout) 109 | handlers = [file_handler, stdout_handler] 110 | logging.basicConfig( 111 | datefmt="%m/%d/%Y %H:%M:%S", 112 | level=logging.INFO if opt.is_master else logging.WARN, 113 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 114 | handlers=handlers, 115 | ) 116 | 117 | model = model_class.from_pretrained(opt.model_path) 118 | 119 | # model = model_class.from_pretrained('t5-large') 120 | # quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) 121 | # quantized_dict = torch.load('pretrained_models/nq_large_dpr_int8/checkpoint/best_dev/pytorch_model.bin') 122 | # quantized_model.load_state_dict(quantized_dict) 123 | # qm = list(quantized_model.modules()) 124 | # qml=list(filter(lambda x: type(x) == torch.nn.quantized.dynamic.modules.linear.Linear, qm)) 125 | # counter = 0 126 | # with torch.no_grad(): 127 | # for mod in model.modules(): 128 | # if type(mod) == torch.nn.Linear: 129 | # mod.weight.copy_(torch.dequantize(qml[counter].weight())) 130 | # counter += 1 131 | 132 | model = model.cuda() 133 | 134 | logger.info("Start eval") 135 | ems = evaluate(model, test_dataset, test_dataloader, tokenizer, opt) 136 | 137 | if opt.write_results and opt.is_master: 138 | print(opt.is_master) 139 | glob_path = Path(opt.checkpoint_dir) / opt.name / 'test_results' 140 | write_path = Path(opt.checkpoint_dir) / opt.name / 'final_output.json' 141 | util.write_output(glob_path, write_path) 142 | 143 | logger.info("EM %.6f" % (ems)) 144 | -------------------------------------------------------------------------------- /FiD/test_ac_scheduler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import transformers 5 | import logging 6 | import util 7 | import numpy as np 8 | from pathlib import Path 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader, SequentialSampler 11 | import evaluation 12 | 13 | from options import Options 14 | from fidt5_ac import ACFiDT5, T5Config 15 | import data_ac as data 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def evaluate(model, dataset, dataloader, tokenizer, opt): 21 | model.eval() 22 | if hasattr(model, "module"): 23 | model = model.module 24 | 25 | total = 0 26 | ems = [] 27 | all_layer_cost = [] 28 | 29 | fw = None 30 | if opt.write_results: 31 | write_path = os.path.join(opt.checkpoint_dir, opt.name, 'test_results') 32 | fw = open(os.path.join(write_path, '%d.txt' % opt.global_rank), 'w') 33 | 34 | with torch.no_grad(): 35 | for i, batch in enumerate(dataloader): 36 | idx, answer_ids, answer_mask, 
context_ids, context_mask, has_answer_labels = batch 37 | model.encoder.n_passages = context_ids.size(1) 38 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 39 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 40 | 41 | outputs, layer_cost = model.generate( 42 | input_ids=context_ids, 43 | attention_mask=context_mask, 44 | max_length=50, 45 | ) 46 | 47 | for k, o in enumerate(outputs): 48 | ans = tokenizer.decode(o, skip_special_tokens=True) 49 | example = dataset.get_example(idx[k]) 50 | question = example.question 51 | gold = example.answers 52 | id = example.id 53 | ems_score = evaluation.ems(ans, gold) 54 | ems.append(ems_score) 55 | 56 | if fw is not None: 57 | fw.write(f"{id}\t{question}\t{ans}\n") 58 | 59 | total += 1 60 | 61 | for c in layer_cost: 62 | all_layer_cost.append(c.item()) 63 | 64 | if (i + 1) % opt.eval_print_freq == 0: 65 | logger.warning(f"{opt.global_rank}, {i + 1} / {len(dataloader)} -- average = {np.mean(ems):.3f}") 66 | 67 | logger.warning(f"{opt.global_rank}, total {total} -- average = {np.mean(ems):.3f}") 68 | if opt.world_size > 1 and not opt.local_rank == -1: 69 | torch.distributed.barrier() 70 | score, total = util.weighted_average(np.mean(ems), total, opt) 71 | logger.info('total number of example %d' % total) 72 | logger.info(f"average EM = {score:.5f}") 73 | avg_layer_cost = np.mean(all_layer_cost) 74 | logger.info(f"average layer cost = {avg_layer_cost:.3f}") 75 | 76 | # write result 77 | with open(os.path.join(opt.checkpoint_dir, "all_results"), "a") as f: 78 | f.write(f"budget = {opt.budget}, num_passages_retained = {opt.num_passages_retained}, " 79 | f"layer cost = {avg_layer_cost}, EM = {score}\n") 80 | 81 | return score 82 | 83 | 84 | if __name__ == "__main__": 85 | options = Options() 86 | opt = options.parse() 87 | opt.train_batch_size = opt.per_gpu_batch_size 88 | logger.info("Distributed training") 89 | opt.is_master = True 90 | 91 | model_name = 't5-' + opt.model_size 92 | model_class = ACFiDT5 93 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name, return_dict=False) 94 | 95 | collator_function = data.Collator(opt, tokenizer) 96 | test_examples = data.load_data(opt.test_data_path, n_context=opt.n_context) 97 | test_dataset = data.Dataset(test_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 98 | 99 | test_sampler = SequentialSampler(test_dataset) 100 | test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=opt.per_gpu_batch_size, 101 | shuffle=False, num_workers=4, collate_fn=collator_function) 102 | 103 | dir_path = os.path.join(opt.checkpoint_dir, opt.name) 104 | directory_exists = os.path.exists(dir_path) 105 | if opt.world_size > 1 and not opt.local_rank == -1: 106 | torch.distributed.barrier() 107 | os.makedirs(dir_path, exist_ok=True) 108 | if opt.write_results: 109 | os.makedirs(os.path.join(dir_path, 'test_results'), exist_ok=True) 110 | if not directory_exists and opt.is_master: 111 | options.print_options(opt) 112 | if opt.world_size > 1 and not opt.local_rank == -1: 113 | torch.distributed.barrier() 114 | 115 | file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 116 | stdout_handler = logging.StreamHandler(sys.stdout) 117 | handlers = [file_handler, stdout_handler] 118 | logging.basicConfig( 119 | datefmt="%m/%d/%Y %H:%M:%S", 120 | level=logging.INFO if opt.is_master else logging.WARN, 121 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 122 | handlers=handlers, 123 | ) 124 | 125 | # 
Load model 126 | logger.info("Loading %s" % opt.model_path) 127 | config = T5Config.from_pretrained(opt.model_path) 128 | 129 | # # Update AC-FiD config with arguments 130 | # if opt.has_answer_pool_type != "none": 131 | # config.has_answer_pool_type = opt.has_answer_pool_type 132 | # if opt.scheduler_type != "none": 133 | # config.scheduler_type = opt.scheduler_type 134 | # config.scheduler_n_context = opt.scheduler_n_context 135 | # config.scheduler_embed_size = opt.scheduler_embed_size 136 | # config.scheduler_hidden_size = opt.scheduler_hidden_size 137 | 138 | model = model_class.from_pretrained(opt.model_path, config=config) 139 | model = model.cuda() 140 | logger.info("Model loaded from %s" % opt.model_path) 141 | logger.info("Model config %s", str(config)) 142 | 143 | # Set model training configs (a hack around) which are only used during training 144 | model.encoder.checkpoint = False 145 | model.decoder.checkpoint = False 146 | model.encoder.n_passages = opt.n_context 147 | model.freeze_fid_params = False # config for training has_answer_heads with FiD parameters froze 148 | 149 | if opt.n_context > config.scheduler_n_context: 150 | raise ValueError(f"n_context can not exceed scheduler_n_context={config.scheduler_n_context}") 151 | 152 | # Set the parameters of the AC scheduler 153 | model.encoder.budget = opt.budget # config for training/evaluating AC scheduler 154 | model.encoder.num_passages_retained = opt.num_passages_retained # config for training/evaluating AC scheduler 155 | model.freeze_has_answer_heads = False 156 | model.step_cost = 0. 157 | model.discount = 1. 158 | logger.warning(f"budget = {opt.budget}, num_passages_retained = {opt.num_passages_retained}") 159 | 160 | logger.info("Start eval") 161 | ems = evaluate(model, test_dataset, test_dataloader, tokenizer, opt) 162 | 163 | if opt.write_results and opt.is_master: 164 | print(opt.is_master) 165 | glob_path = Path(opt.checkpoint_dir) / opt.name / 'test_results' 166 | write_path = Path(opt.checkpoint_dir) / opt.name / 'final_output.json' 167 | util.write_output(glob_path, write_path) 168 | 169 | logger.info("EM %.6f" % (ems)) 170 | -------------------------------------------------------------------------------- /FiD/test_retrieval_acc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import transformers 5 | import logging 6 | import util 7 | import numpy as np 8 | from pathlib import Path 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader, SequentialSampler 11 | import evaluation 12 | 13 | from options import Options 14 | from fidt5_ac import ACFiDT5, T5Config 15 | import data_ac as data 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | from evaluation import SimpleTokenizer, has_answer 20 | 21 | simple_tokenizer = SimpleTokenizer() 22 | 23 | 24 | def evaluate(model, dataset, dataloader, tokenizer, opt): 25 | model.eval() 26 | if hasattr(model, "module"): 27 | model = model.module 28 | 29 | all_accuracies = [] 30 | 31 | with torch.no_grad(): 32 | for i, batch in enumerate(dataloader): 33 | idx, answer_ids, answer_mask, context_ids, context_mask, has_answer_labels = batch 34 | model.encoder.n_passages = context_ids.size(1) 35 | answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 36 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 37 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 38 | decoder_input_ids = None 39 | has_answer_labels = None 40 | 
# labels = answer_ids.masked_fill(~answer_mask, -100) 41 | labels = None 42 | 43 | inputs = { 44 | 'input_ids': context_ids, 45 | 'attention_mask': context_mask, 46 | 'decoder_attention_mask': answer_mask, 47 | 'decoder_input_ids': decoder_input_ids, 48 | 'labels': labels, 49 | 'has_answer_labels': has_answer_labels, 50 | } 51 | outputs = model(**inputs) 52 | scheduler_outputs = outputs[-2] 53 | actions, log_probs, all_skylines, retained_passages = scheduler_outputs 54 | 55 | # retained_passages: [bsz, num_passages_retained] 56 | for j, psg_ranks in enumerate(retained_passages): 57 | answer_acc = 0 # 1 if the selected top-k passages contain the answer, 0 otherwise 58 | example = dataset.get_example(idx[j]) 59 | answers = example.answers 60 | for k, rank in enumerate(psg_ranks): 61 | rank = rank.item() 62 | context = example.contexts[rank] 63 | if has_answer(answers, context, simple_tokenizer): 64 | answer_acc = 1 65 | break 66 | all_accuracies.append(answer_acc) 67 | 68 | accuracy = np.mean(all_accuracies) 69 | 70 | logger.info('total number of example %d' % len(all_accuracies)) 71 | logger.info(f"top-k retrieval accuracy = {accuracy:.5f}") 72 | 73 | # write result 74 | with open(os.path.join(opt.checkpoint_dir, "retrieval_acc"), "a") as f: 75 | f.write(f"budget = {opt.budget}, num_passages_retained = {opt.num_passages_retained}, " 76 | f"accuracy = {accuracy}\n") 77 | 78 | return accuracy 79 | 80 | 81 | if __name__ == "__main__": 82 | options = Options() 83 | opt = options.parse() 84 | opt.train_batch_size = opt.per_gpu_batch_size 85 | logger.info("Distributed training") 86 | opt.is_master = True 87 | 88 | model_name = 't5-' + opt.model_size 89 | model_class = ACFiDT5 90 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name, return_dict=False) 91 | 92 | collator_function = data.Collator(opt, tokenizer) 93 | test_examples = data.load_data(opt.test_data_path, n_context=opt.n_context) 94 | test_dataset = data.Dataset(test_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 95 | 96 | test_sampler = SequentialSampler(test_dataset) 97 | test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=opt.per_gpu_batch_size, 98 | shuffle=False, num_workers=4, collate_fn=collator_function) 99 | 100 | dir_path = os.path.join(opt.checkpoint_dir, opt.name) 101 | directory_exists = os.path.exists(dir_path) 102 | if opt.world_size > 1 and not opt.local_rank == -1: 103 | torch.distributed.barrier() 104 | os.makedirs(dir_path, exist_ok=True) 105 | if opt.write_results: 106 | os.makedirs(os.path.join(dir_path, 'test_results'), exist_ok=True) 107 | if not directory_exists and opt.is_master: 108 | options.print_options(opt) 109 | if opt.world_size > 1 and not opt.local_rank == -1: 110 | torch.distributed.barrier() 111 | 112 | file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 113 | stdout_handler = logging.StreamHandler(sys.stdout) 114 | handlers = [file_handler, stdout_handler] 115 | logging.basicConfig( 116 | datefmt="%m/%d/%Y %H:%M:%S", 117 | level=logging.INFO if opt.is_master else logging.WARN, 118 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 119 | handlers=handlers, 120 | ) 121 | 122 | # Load model 123 | logger.info("Loading %s" % opt.model_path) 124 | config = T5Config.from_pretrained(opt.model_path) 125 | 126 | # # Update AC-FiD config with arguments 127 | # if opt.has_answer_pool_type != "none": 128 | # config.has_answer_pool_type = opt.has_answer_pool_type 129 | # if opt.scheduler_type != 
"none": 130 | # config.scheduler_type = opt.scheduler_type 131 | # config.scheduler_n_context = opt.scheduler_n_context 132 | # config.scheduler_embed_size = opt.scheduler_embed_size 133 | # config.scheduler_hidden_size = opt.scheduler_hidden_size 134 | 135 | model = model_class.from_pretrained(opt.model_path, config=config) 136 | model = model.cuda() 137 | logger.info("Model loaded from %s" % opt.model_path) 138 | logger.info("Model config %s", str(config)) 139 | 140 | # Set model training configs (a hack around) which are only used during training 141 | model.encoder.checkpoint = False 142 | model.decoder.checkpoint = False 143 | model.encoder.n_passages = opt.n_context 144 | model.freeze_fid_params = True # config for training has_answer_heads with FiD parameters froze 145 | 146 | if opt.n_context > config.scheduler_n_context: 147 | raise ValueError(f"n_context can not exceed scheduler_n_context={config.scheduler_n_context}") 148 | 149 | # Set the parameters of the AC scheduler 150 | model.encoder.budget = opt.budget # config for training/evaluating AC scheduler 151 | model.encoder.num_passages_retained = opt.num_passages_retained # config for training/evaluating AC scheduler 152 | model.freeze_has_answer_heads = True 153 | model.step_cost = 0. 154 | model.discount = 1. 155 | logger.warning(f"budget = {opt.budget}, num_passages_retained = {opt.num_passages_retained}") 156 | 157 | logger.info("Start eval") 158 | accuracy = evaluate(model, test_dataset, test_dataloader, tokenizer, opt) 159 | 160 | logger.info("accuracy %.6f" % (accuracy)) 161 | -------------------------------------------------------------------------------- /FiD/test_retriever_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import transformers 5 | import logging 6 | import util 7 | import numpy as np 8 | from pathlib import Path 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader, SequentialSampler 11 | 12 | import evaluation 13 | from options import Options 14 | from fidt5_ac import ACFiDT5, T5Config 15 | import data_ac as data 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | from evaluation import SimpleTokenizer, has_answer 20 | 21 | simple_tokenizer = SimpleTokenizer() 22 | 23 | 24 | def evaluate(dataset, dataloader, opt): 25 | all_accuracies = [] 26 | 27 | for i, batch in enumerate(dataloader): 28 | idx, answer_ids, answer_mask, context_ids, context_mask, has_answer_labels = batch 29 | # model.encoder.n_passages = context_ids.size(1) 30 | # answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 31 | # context_ids = context_ids.cuda().view(context_ids.size(0), -1) 32 | # context_mask = context_mask.cuda().view(context_ids.size(0), -1) 33 | # decoder_input_ids = None 34 | # has_answer_labels = None 35 | # # labels = answer_ids.masked_fill(~answer_mask, -100) 36 | # labels = None 37 | # 38 | # inputs = { 39 | # 'input_ids': context_ids, 40 | # 'attention_mask': context_mask, 41 | # 'decoder_attention_mask': answer_mask, 42 | # 'decoder_input_ids': decoder_input_ids, 43 | # 'labels': labels, 44 | # 'has_answer_labels': has_answer_labels, 45 | # } 46 | # outputs = model(**inputs) 47 | # scheduler_outputs = outputs[-2] 48 | # actions, log_probs, all_skylines, retained_passages = scheduler_outputs 49 | 50 | # retained_passages: [bsz, num_passages_retained] 51 | for j, index in enumerate(idx): 52 | answer_acc = 0 # 1 if the selected top-k passages contain the answer, 0 
otherwise 53 | example = dataset.get_example(index) 54 | answers = example.answers 55 | for k in range(opt.num_passages_retained): 56 | context = example.contexts[k] 57 | if has_answer(answers, context, simple_tokenizer): 58 | answer_acc = 1 59 | break 60 | all_accuracies.append(answer_acc) 61 | 62 | accuracy = np.mean(all_accuracies) 63 | 64 | logger.info('total number of example %d' % len(all_accuracies)) 65 | logger.info(f"top-k retrieval accuracy = {accuracy:.5f}") 66 | 67 | # # write result 68 | # with open(os.path.join(opt.checkpoint_dir, "retrieval_acc"), "a") as f: 69 | # f.write(f"budget = {opt.budget}, num_passages_retained = {opt.num_passages_retained}, " 70 | # f"accuracy = {accuracy}\n") 71 | 72 | return accuracy 73 | 74 | 75 | if __name__ == "__main__": 76 | options = Options() 77 | opt = options.parse() 78 | opt.train_batch_size = opt.per_gpu_batch_size 79 | logger.info("Distributed training") 80 | opt.is_master = True 81 | 82 | model_name = 't5-' + opt.model_size 83 | # model_class = ACFiDT5 84 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name, return_dict=False) 85 | 86 | collator_function = data.Collator(opt, tokenizer) 87 | test_examples = data.load_data(opt.test_data_path, n_context=opt.n_context) 88 | test_dataset = data.Dataset(test_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 89 | 90 | test_sampler = SequentialSampler(test_dataset) 91 | test_dataloader = DataLoader(test_dataset, sampler=test_sampler, batch_size=opt.per_gpu_batch_size, 92 | shuffle=False, num_workers=4, collate_fn=collator_function) 93 | 94 | # dir_path = os.path.join(opt.checkpoint_dir, opt.name) 95 | # directory_exists = os.path.exists(dir_path) 96 | # if opt.world_size > 1 and not opt.local_rank == -1: 97 | # torch.distributed.barrier() 98 | # os.makedirs(dir_path, exist_ok=True) 99 | # if opt.write_results: 100 | # os.makedirs(os.path.join(dir_path, 'test_results'), exist_ok=True) 101 | # if not directory_exists and opt.is_master: 102 | # options.print_options(opt) 103 | # if opt.world_size > 1 and not opt.local_rank == -1: 104 | # torch.distributed.barrier() 105 | 106 | # file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 107 | # stdout_handler = logging.StreamHandler(sys.stdout) 108 | # handlers = [file_handler, stdout_handler] 109 | logging.basicConfig( 110 | datefmt="%m/%d/%Y %H:%M:%S", 111 | level=logging.INFO if opt.is_master else logging.WARN, 112 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 113 | # handlers=handlers, 114 | ) 115 | 116 | # Load model 117 | # logger.info("Loading %s" % opt.model_path) 118 | # config = T5Config.from_pretrained(opt.model_path) 119 | 120 | # # Update AC-FiD config with arguments 121 | # if opt.has_answer_pool_type != "none": 122 | # config.has_answer_pool_type = opt.has_answer_pool_type 123 | # if opt.scheduler_type != "none": 124 | # config.scheduler_type = opt.scheduler_type 125 | # config.scheduler_n_context = opt.scheduler_n_context 126 | # config.scheduler_embed_size = opt.scheduler_embed_size 127 | # config.scheduler_hidden_size = opt.scheduler_hidden_size 128 | 129 | # model = model_class.from_pretrained(opt.model_path, config=config) 130 | # model = model.cuda() 131 | # logger.info("Model loaded from %s" % opt.model_path) 132 | # logger.info("Model config %s", str(config)) 133 | 134 | # # Set model training configs (a hack around) which are only used during training 135 | # model.encoder.checkpoint = False 136 | # model.decoder.checkpoint = False 
137 | # model.encoder.n_passages = opt.n_context 138 | # model.freeze_fid_params = True # config for training has_answer_heads with FiD parameters froze 139 | # 140 | # if opt.n_context > config.scheduler_n_context: 141 | # raise ValueError(f"n_context can not exceed scheduler_n_context={config.scheduler_n_context}") 142 | # 143 | # # Set the parameters of the AC scheduler 144 | # model.encoder.budget = opt.budget # config for training/evaluating AC scheduler 145 | # model.encoder.num_passages_retained = opt.num_passages_retained # config for training/evaluating AC scheduler 146 | # model.freeze_has_answer_heads = True 147 | # model.step_cost = 0. 148 | # model.discount = 1. 149 | logger.warning(f"budget = {opt.budget}, num_passages_retained = {opt.num_passages_retained}") 150 | 151 | logger.info("Start eval") 152 | accuracy = evaluate(test_dataset, test_dataloader, opt) 153 | 154 | logger.info("accuracy %.6f" % (accuracy)) 155 | -------------------------------------------------------------------------------- /FiD/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | import torch 5 | import transformers 6 | # import slurm 7 | import logging 8 | import util 9 | import numpy as np 10 | import torch.distributed as dist 11 | from torch.utils.tensorboard import SummaryWriter 12 | from options import Options 13 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 14 | from fidt5 import FiDT5 15 | import evaluation 16 | import data 17 | from tqdm.auto import tqdm 18 | 19 | logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR) 20 | logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR) 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def train_evaluate(model, optimizer, scheduler, global_step, 25 | train_dataset, dev_dataset, opt, collator_function, best_dev_em): 26 | if opt.is_master: 27 | tb_logger = SummaryWriter(os.path.join(opt.checkpoint_dir, opt.name)) 28 | 29 | train_sampler = (RandomSampler(train_dataset) if opt.local_rank == -1 or opt.world_size == 1 30 | else DistributedSampler(train_dataset)) 31 | dev_sampler = SequentialSampler(dev_dataset) 32 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, 33 | batch_size=opt.per_gpu_batch_size, drop_last=True, num_workers=20, 34 | collate_fn=collator_function) 35 | dev_dataloader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=opt.per_gpu_batch_size, 36 | drop_last=True, num_workers=20, collate_fn=collator_function) 37 | 38 | loss, curr_loss = 0.0, 0.0 39 | epoch = 1 40 | model.train() 41 | while global_step < opt.total_step: 42 | if opt.world_size > 1: 43 | train_sampler.set_epoch(epoch) 44 | epoch += 1 45 | for i, batch in tqdm(enumerate(train_dataloader)): 46 | global_step += 1 47 | idx, answer_ids, answer_mask, context_ids, context_mask = batch 48 | answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 49 | labels = answer_ids.masked_fill(~answer_mask, -100) 50 | if hasattr(model, "module"): 51 | model.module.encoder.n_passages = context_ids.size(1) 52 | else: 53 | model.encoder.n_passages = context_ids.size(1) 54 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 55 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 56 | decoder_input_ids = None 57 | 58 | model.zero_grad() 59 | inputs = { 60 | 'input_ids': context_ids, 61 | 'attention_mask': context_mask, 62 | 
'decoder_attention_mask': answer_mask, 63 | 'decoder_input_ids': decoder_input_ids, 64 | 'labels': labels, 65 | } 66 | train_loss = model(**inputs)[0] 67 | train_loss.backward() 68 | util.clip_gradients(model, opt.clip) 69 | optimizer.step() 70 | 71 | scheduler.step() 72 | 73 | train_loss = util.average_master(train_loss, opt) 74 | curr_loss += train_loss.item() 75 | 76 | if global_step % opt.eval_freq == 0: 77 | dev_em = evaluate(model, dev_dataset, dev_dataloader, tokenizer, opt) 78 | if opt.is_master: 79 | tb_logger.add_scalar("Evaluation", dev_em, global_step) 80 | if dev_em > best_dev_em: 81 | best_dev_em = dev_em 82 | if opt.is_master: 83 | model_to_save = model.module if hasattr(model, "module") else model 84 | util.save(model_to_save, optimizer, scheduler, global_step, best_dev_em, opt, dir_path, 85 | 'best_dev') 86 | model.train() 87 | if opt.is_master and global_step % opt.eval_freq == 0: 88 | logger.info( 89 | f"{global_step} / {opt.total_step} -- train = {curr_loss / opt.eval_freq:.3f} | evaluation = {100 * dev_em:.2f}EM | lr = {scheduler.get_last_lr()[0]:.5f}" 90 | ) 91 | tb_logger.add_scalar("Training", curr_loss / (opt.eval_freq), global_step) 92 | curr_loss = 0 93 | 94 | if opt.is_master and global_step % (50 * opt.eval_freq) == 0: 95 | model_to_save = model.module if hasattr(model, "module") else model 96 | util.save(model_to_save, optimizer, scheduler, global_step, best_dev_em, opt, dir_path, 97 | f"step-{global_step}") 98 | if global_step > opt.total_step: 99 | break 100 | 101 | 102 | def evaluate(model, dataset, dataloader, tokenizer, opt): 103 | model.eval() 104 | if hasattr(model, "module"): 105 | model = model.module 106 | total = 0 107 | ems = [] 108 | with torch.no_grad(): 109 | for i, batch in enumerate(dataloader): 110 | idx, answer_ids, answer_mask, context_ids, context_mask = batch 111 | if hasattr(model, "module"): 112 | model.module.encoder.n_passages = context_ids.size(1) 113 | else: 114 | model.encoder.n_passages = context_ids.size(1) 115 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 116 | context_mask = context_mask.cuda().view(context_mask.size(0), -1) 117 | 118 | outputs = model.generate( 119 | input_ids=context_ids, 120 | attention_mask=context_mask, 121 | max_length=50, 122 | ) 123 | 124 | for k, o in enumerate(outputs): 125 | ans = tokenizer.decode(o, skip_special_tokens=True) 126 | gold = dataset.get_example(idx[k]).answers 127 | ems_score = evaluation.ems(ans, gold) 128 | total += 1 129 | ems.append(ems_score) 130 | if opt.is_master and (i + 1) % opt.eval_print_freq == 0: 131 | logger.info(f"{i + 1} / {len(dataloader)} -- average = {100 * np.mean(ems):.2f}EM") 132 | 133 | score, total = util.weighted_average(np.mean(ems), total, opt) 134 | return score 135 | 136 | 137 | if __name__ == "__main__": 138 | options = Options() 139 | opt = options.parse() 140 | torch.manual_seed(opt.seed) 141 | # slurm.init_distributed_mode(opt) 142 | # slurm.init_signal_handler() 143 | opt.train_batch_size = opt.per_gpu_batch_size * max(1, opt.world_size) 144 | logger.info("Distributed training") 145 | 146 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 147 | 148 | dir_path = os.path.join(opt.checkpoint_dir, opt.name) 149 | 150 | model_name = 't5-' + opt.model_size 151 | model_class = FiDT5 152 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name) 153 | 154 | collator_function = data.Collator(opt, tokenizer) 155 | 156 | train_examples = data.load_data(opt.train_data_path, n_context=opt.n_context) 157 | 
train_dataset = data.Dataset(train_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 158 | dev_examples = data.load_data(opt.dev_data_path, global_rank=opt.global_rank, 159 | # use the global rank and world size attibutes to split the dev set on multiple gpus 160 | world_size=opt.world_size, 161 | n_context=opt.n_context) 162 | dev_dataset = data.Dataset(dev_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 163 | 164 | directory_exists = os.path.exists(dir_path) 165 | if opt.world_size > 1 and not opt.local_rank == -1: 166 | torch.distributed.barrier() 167 | os.makedirs(dir_path, exist_ok=True) 168 | if not directory_exists and opt.is_master: 169 | options.print_options(opt) 170 | if opt.world_size > 1 and not opt.local_rank == -1: 171 | torch.distributed.barrier() 172 | file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 173 | stdout_handler = logging.StreamHandler(sys.stdout) 174 | handlers = [file_handler, stdout_handler] 175 | logging.basicConfig( 176 | datefmt="%m/%d/%Y %H:%M:%S", 177 | level=logging.INFO if opt.is_master else logging.WARN, 178 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 179 | handlers=handlers, 180 | ) 181 | 182 | if opt.world_size > 1 and not opt.local_rank == -1: 183 | torch.distributed.barrier() 184 | 185 | global_step = 0 186 | best_dev_em = 0. 187 | 188 | if not directory_exists and opt.model_path == "none": 189 | model = model_class.from_pretrained(model_name) 190 | model = model.to(0 if opt.local_rank == -1 else opt.local_rank) 191 | optimizer, scheduler = util.set_optim(opt, model) 192 | elif opt.model_path == "none": 193 | model, optimizer, scheduler, opt_checkpoint, global_step, best_dev_em = util.restore_epoch( 194 | model_class, dir_path, opt, reset_params=False, name="latest", 195 | ) 196 | logger.info("Model loaded from %s" % dir_path) 197 | else: 198 | model, optimizer, scheduler, opt_checkpoint, global_step, best_dev_em = util.restore_epoch( 199 | model_class, opt.model_path, opt, reset_params=True, name="latest", 200 | ) 201 | logger.info("Model loaded from %s" % opt.model_path) 202 | 203 | if opt.checkpointing_encoder: 204 | model.encoder.checkpoint = True 205 | if opt.checkpointing_decoder: 206 | model.decoder.checkpoint = True 207 | model.encoder.n_passages = opt.n_context 208 | 209 | if opt.world_size > 1 and opt.local_rank != -1: 210 | model = torch.nn.parallel.DistributedDataParallel( 211 | model, 212 | device_ids=[opt.local_rank], 213 | output_device=opt.local_rank, 214 | find_unused_parameters=False, 215 | ) 216 | 217 | logger.info("Start training") 218 | train_evaluate(model, optimizer, scheduler, global_step, 219 | train_dataset, dev_dataset, opt, collator_function, best_dev_em) 220 | -------------------------------------------------------------------------------- /FiD/train_ac.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import sys 4 | from collections import defaultdict 5 | import random 6 | import torch 7 | import transformers 8 | # import slurm 9 | import logging 10 | import util 11 | import numpy as np 12 | from tqdm.auto import tqdm 13 | 14 | import torch.distributed as dist 15 | from torch.utils.tensorboard import SummaryWriter 16 | from options import Options 17 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 18 | # import evaluation 19 | 20 | # ACFiD specific 21 | from fidt5_ac import ACFiDT5, T5Config 
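The `evaluate()` function in `train.py` above decodes each generated answer and scores it with `evaluation.ems` against the gold answers. As a point of reference, a typical SQuAD-style exact-match score looks like the sketch below; the repository's implementation in `evaluation.py` may normalize answers differently.

```python
# SQuAD-style exact match: lowercase, strip punctuation and articles, collapse whitespace.
# Generic sketch only, not necessarily identical to evaluation.ems.
import re
import string

def normalize_answer(s):
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())

def exact_match(prediction, gold_answers):
    return float(any(normalize_answer(prediction) == normalize_answer(g) for g in gold_answers))

print(exact_match("The Polonium.", ["Polonium", "Po"]))  # 1.0
```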
22 | import data_ac as data 23 | 24 | logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR) 25 | logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR) 26 | logger = logging.getLogger(__name__) 27 | 28 | # Initialise wandb 29 | try: 30 | import wandb 31 | 32 | wandb.ensure_configured() 33 | if wandb.api.api_key is None: 34 | _has_wandb = False 35 | wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.") 36 | else: 37 | _has_wandb = False if os.getenv("WANDB_DISABLED") else True 38 | except (ImportError, AttributeError): 39 | _has_wandb = False 40 | 41 | 42 | def log_scalar(name, value, step): 43 | tb_logger.add_scalar(name, value, step) 44 | if _has_wandb: 45 | wandb.log({name: value, "step": step}) 46 | 47 | 48 | def train_evaluate(model, optimizer, scheduler, global_step, 49 | train_dataset, dev_dataset, opt, collator_function, best_metric): 50 | train_sampler = (RandomSampler(train_dataset) if opt.local_rank == -1 or opt.world_size == 1 51 | else DistributedSampler(train_dataset)) 52 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, 53 | batch_size=opt.per_gpu_batch_size, drop_last=True, num_workers=3, 54 | collate_fn=collator_function) 55 | 56 | dev_sampler = SequentialSampler(dev_dataset) 57 | dev_dataloader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=opt.per_gpu_batch_size, 58 | drop_last=True, num_workers=3, collate_fn=collator_function) 59 | 60 | # Freeze the FiD parameters and only train AC part. 61 | trainable_np = list(model.named_parameters()) 62 | if opt.freeze_fid_params: 63 | new_np = [] 64 | for n, p in trainable_np: 65 | if n.startswith("encoder.has_answer_heads") or n.startswith("ac_scheduler"): 66 | p.requires_grad = True 67 | new_np.append((n, p)) 68 | else: 69 | p.requires_grad = False 70 | trainable_np = new_np 71 | 72 | if opt.freeze_has_answer_heads: 73 | new_np = [] 74 | for n, p in trainable_np: 75 | if n.startswith("encoder.has_answer_heads"): 76 | p.requires_grad = False 77 | else: 78 | p.requires_grad = True 79 | new_np.append((n, p)) 80 | trainable_np = new_np 81 | 82 | trainable_parameters = [p for n, p in trainable_np] 83 | # Prepare optimizer and schedule (linear warmup and decay) 84 | if optimizer is None or scheduler is None: 85 | optimizer = torch.optim.Adam(trainable_parameters, lr=opt.lr) 86 | scheduler = util.FixedScheduler(optimizer) 87 | 88 | # fp16 89 | if opt.fp16: 90 | try: 91 | from apex import amp 92 | except ImportError: 93 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 94 | 95 | model, optimizer = amp.initialize(model, optimizer, opt_level=opt.fp16_opt_level) 96 | 97 | # Distributed training 98 | if opt.world_size > 1 and opt.local_rank != -1: 99 | model = torch.nn.parallel.DistributedDataParallel( 100 | model, 101 | device_ids=[opt.local_rank], 102 | output_device=opt.local_rank, 103 | find_unused_parameters=False, 104 | ) 105 | 106 | # Train! 107 | logger.info("***** Running training *****") 108 | logger.info(" Num examples = %d", len(train_dataset)) 109 | # logger.info(" Num Epochs = %d", args.num_train_epochs) 110 | logger.info(" Instantaneous batch size per GPU = %d", opt.per_gpu_batch_size) 111 | logger.info( 112 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 113 | opt.train_batch_size * opt.gradient_accumulation_steps 114 | ) 115 | logger.info(" Gradient Accumulation steps = %d", opt.gradient_accumulation_steps) 116 | logger.info(" Total optimization steps = %d", opt.total_step) 117 | logger.info(" Total number of training epochs = %f", 118 | opt.total_step * opt.train_batch_size * opt.gradient_accumulation_steps / len(train_dataset)) 119 | 120 | loss, curr_loss = 0.0, 0.0 121 | epoch = 0 122 | step = 0 123 | model.train() 124 | model.zero_grad() 125 | while global_step < opt.total_step: 126 | epoch += 1 127 | if opt.world_size > 1: 128 | train_sampler.set_epoch(epoch) 129 | for i, batch in tqdm(enumerate(train_dataloader)): 130 | step += 1 131 | 132 | # Process the inputs 133 | idx, answer_ids, answer_mask, context_ids, context_mask, has_answer_labels = batch 134 | answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 135 | has_answer_labels = has_answer_labels.cuda() 136 | labels = answer_ids.masked_fill(~answer_mask, -100) 137 | if hasattr(model, "module"): 138 | model.module.encoder.n_passages = context_ids.size(1) 139 | else: 140 | model.encoder.n_passages = context_ids.size(1) 141 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 142 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 143 | decoder_input_ids = None 144 | 145 | inputs = { 146 | 'input_ids': context_ids, 147 | 'attention_mask': context_mask, 148 | 'decoder_attention_mask': answer_mask, 149 | 'decoder_input_ids': decoder_input_ids, 150 | 'labels': labels, 151 | 'has_answer_labels': has_answer_labels, 152 | } 153 | 154 | # Run the model 155 | outputs = model(**inputs) 156 | train_loss = outputs[0] 157 | train_loss = util.average_master(train_loss, opt) 158 | 159 | if opt.gradient_accumulation_steps > 1: 160 | train_loss = train_loss / opt.gradient_accumulation_steps 161 | 162 | if opt.fp16: 163 | with amp.scale_loss(train_loss, optimizer) as scaled_loss: 164 | scaled_loss.backward() 165 | else: 166 | train_loss.backward() 167 | 168 | curr_loss += train_loss.item() 169 | if step % opt.gradient_accumulation_steps == 0: 170 | # util.clip_gradients(model, opt.clip) 171 | if opt.fp16: 172 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), opt.clip) 173 | else: 174 | torch.nn.utils.clip_grad_norm_(trainable_parameters, opt.clip) 175 | 176 | optimizer.step() 177 | scheduler.step() # Update learning rate schedule 178 | model.zero_grad() 179 | global_step += 1 180 | 181 | if opt.is_master and global_step % opt.log_freq == 0: 182 | logger.info( 183 | f"{global_step} / {opt.total_step} -- train loss = {curr_loss / opt.log_freq:.3f}" 184 | f" | lr = {scheduler.get_last_lr()[0]:.5f}" 185 | ) 186 | log_scalar("Train/Loss", curr_loss / opt.log_freq, global_step) 187 | curr_loss = 0. 
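The reshaping applied to `context_ids` and `context_mask` in the loop above is the Fusion-in-Decoder input layout: per-passage token ids of shape `[bsz, n_passages, passage_len]` are flattened to `[bsz, n_passages * passage_len]`, and the encoder is told `n_passages` so it can split the passages apart again internally. A small standalone sketch of that shape arithmetic (shapes are illustrative; the actual splitting lives in `fidt5.py` / `fidt5_ac.py`):

```python
# Illustrative shapes only; the real per-passage encoding logic lives in fidt5.py / fidt5_ac.py.
import torch

bsz, n_passages, passage_len = 2, 3, 5
context_ids = torch.randint(0, 100, (bsz, n_passages, passage_len))
context_mask = torch.ones_like(context_ids)

flat_ids = context_ids.view(context_ids.size(0), -1)      # [2, 15], as in the training loop
flat_mask = context_mask.view(context_mask.size(0), -1)   # [2, 15]

# Inside the encoder, each passage is typically processed independently by undoing the view:
per_passage = flat_ids.view(bsz * n_passages, passage_len)  # [6, 5], one row per passage
print(flat_ids.shape, flat_mask.shape, per_passage.shape)
```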
188 | 189 | if global_step % opt.eval_freq == 0: 190 | results = evaluate(model, dev_dataset, dev_dataloader, tokenizer, opt) 191 | dev_f1 = results["avg_f1"] # use average F1 (across all layers) as evaluation metric 192 | if opt.is_master: 193 | logger.info(f"{global_step} / {opt.total_step} -- dev evaluation = {100 * dev_f1:.2f} F1") 194 | for k, v in results.items(): 195 | log_scalar(f"Dev/{k}", v, global_step) 196 | 197 | if dev_f1 > best_metric: 198 | best_metric = dev_f1 199 | if opt.is_master: 200 | model_to_save = model.module if hasattr(model, "module") else model 201 | util.save(model_to_save, optimizer, scheduler, global_step, best_metric, opt, dir_path, 202 | 'best_dev') 203 | model.train() 204 | 205 | if opt.is_master and global_step % opt.save_freq == 0: 206 | model_to_save = model.module if hasattr(model, "module") else model 207 | util.save(model_to_save, optimizer, scheduler, global_step, best_metric, opt, dir_path, 208 | f"step-{global_step}") 209 | if global_step > opt.total_step: 210 | break 211 | 212 | 213 | def evaluate(model, dataset, dataloader, tokenizer, opt): 214 | model.eval() 215 | if hasattr(model, "module"): 216 | model = model.module 217 | 218 | num_layers = model.encoder.config.num_layers 219 | all_results = [defaultdict(list) for _ in range(num_layers)] 220 | 221 | with torch.no_grad(): 222 | for i, batch in enumerate(dataloader): 223 | idx, answer_ids, answer_mask, context_ids, context_mask, has_answer_labels = batch 224 | answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 225 | has_answer_labels = has_answer_labels.cuda() 226 | labels = answer_ids.masked_fill(~answer_mask, -100) 227 | if hasattr(model, "module"): 228 | model.module.encoder.n_passages = context_ids.size(1) 229 | else: 230 | model.encoder.n_passages = context_ids.size(1) 231 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 232 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 233 | decoder_input_ids = None 234 | 235 | inputs = { 236 | 'input_ids': context_ids, 237 | 'attention_mask': context_mask, 238 | 'decoder_attention_mask': answer_mask, 239 | 'decoder_input_ids': decoder_input_ids, 240 | 'labels': labels, 241 | 'has_answer_labels': has_answer_labels, 242 | } 243 | outputs = model(**inputs) 244 | all_has_answer_outputs = outputs[-1] # Tuple[Tensor], shape: [bsz, n_passages] 245 | 246 | count = torch.numel(has_answer_labels) 247 | for layer_idx, logits in enumerate(all_has_answer_outputs): 248 | correct = torch.sum(torch.eq((logits.sigmoid() > 0.5).float(), has_answer_labels)).item() 249 | all_results[layer_idx]["acc"].append((correct, count)) 250 | 251 | predictions = (logits.sigmoid() > 0.5).float() 252 | true_positive = torch.sum(predictions * has_answer_labels).item() 253 | pred_positive = torch.sum(predictions).item() 254 | gt_positive = torch.sum(has_answer_labels).item() 255 | 256 | all_results[layer_idx]["prec"].append((true_positive, pred_positive)) 257 | all_results[layer_idx]["recall"].append((true_positive, gt_positive)) 258 | 259 | final_results = {} 260 | for idx, results in enumerate(all_results): 261 | for metric, values in results.items(): 262 | value_list, count_list = zip(*values) 263 | final_results[f"layer{idx}/{metric}"] = sum(value_list) / max(sum(count_list), 1) 264 | 265 | all_f1 = [] 266 | for idx in range(num_layers): 267 | prec = final_results[f"layer{idx}/prec"] 268 | recall = final_results[f"layer{idx}/recall"] 269 | f1 = 2 * prec * recall / max(prec + recall, 1e-5) 270 | final_results[f"layer{idx}/f1"] = 
f1 271 | all_f1.append(f1) 272 | 273 | average_f1 = np.mean(all_f1) 274 | final_results["avg_f1"] = average_f1 275 | return final_results 276 | 277 | 278 | if __name__ == "__main__": 279 | options = Options() 280 | opt = options.parse() 281 | torch.manual_seed(opt.seed) 282 | # slurm.init_distributed_mode(opt) 283 | # slurm.init_signal_handler() 284 | opt.train_batch_size = opt.per_gpu_batch_size * max(1, opt.world_size) 285 | logger.info("Distributed training") 286 | 287 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 288 | 289 | dir_path = os.path.join(opt.checkpoint_dir, opt.name) 290 | 291 | model_name = 't5-' + opt.model_size 292 | model_class = ACFiDT5 293 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name) 294 | 295 | collator_function = data.Collator(opt, tokenizer) 296 | 297 | train_examples = data.load_data(opt.train_data_path, n_context=opt.n_context) 298 | train_dataset = data.Dataset(train_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 299 | dev_examples = data.load_data(opt.dev_data_path, global_rank=opt.global_rank, 300 | # use the global rank and world size attibutes to split the dev set on multiple gpus 301 | world_size=opt.world_size, 302 | n_context=opt.n_context) 303 | if opt.dev_data_size > 0: 304 | random.seed(opt.seed) 305 | dev_examples = random.sample(dev_examples, opt.dev_data_size) 306 | # dev_examples = dev_examples[:opt.dev_data_size] 307 | dev_dataset = data.Dataset(dev_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 308 | 309 | directory_exists = os.path.exists(dir_path) 310 | if opt.world_size > 1 and not opt.local_rank == -1: 311 | torch.distributed.barrier() 312 | os.makedirs(dir_path, exist_ok=True) 313 | if not directory_exists and opt.is_master: 314 | options.print_options(opt) 315 | if opt.world_size > 1 and not opt.local_rank == -1: 316 | torch.distributed.barrier() 317 | file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 318 | stdout_handler = logging.StreamHandler(sys.stdout) 319 | handlers = [file_handler, stdout_handler] 320 | logging.basicConfig( 321 | datefmt="%m/%d/%Y %H:%M:%S", 322 | level=logging.INFO if opt.is_master else logging.WARN, 323 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 324 | handlers=handlers, 325 | ) 326 | 327 | if opt.world_size > 1 and not opt.local_rank == -1: 328 | torch.distributed.barrier() 329 | 330 | if opt.is_master: 331 | tb_logger = SummaryWriter(os.path.join(opt.checkpoint_dir, opt.name)) 332 | # Setup wandb 333 | if opt.is_master and _has_wandb: 334 | opt_dict = vars(opt) 335 | # os.makedirs(os.path.join(dir_path, "wandb")) 336 | wandb.init(project="ACFiD", name=opt.name, dir=os.path.join(dir_path), config=opt_dict) 337 | 338 | global_step = 0 339 | best_metric = 0. 
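The per-layer metrics aggregated by `evaluate()` above boil down to three counts per layer: true positives, predicted positives, and gold positives. A small standalone sketch of the same precision / recall / F1 computation (the example counts are illustrative):

```python
# Mirrors the aggregation in evaluate(): ratios are guarded with max(..., 1) and the
# F1 denominator with max(prec + recall, 1e-5), as in the code above.
def precision_recall_f1(true_positive, pred_positive, gt_positive):
    prec = true_positive / max(pred_positive, 1)
    recall = true_positive / max(gt_positive, 1)
    f1 = 2 * prec * recall / max(prec + recall, 1e-5)
    return prec, recall, f1

print(precision_recall_f1(true_positive=8, pred_positive=10, gt_positive=16))
# (0.8, 0.5, 0.6153846153846154)
```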
340 | 341 | if not directory_exists and opt.model_path == "none": 342 | model = model_class.from_pretrained(model_name) 343 | model = model.to(0 if opt.local_rank == -1 else opt.local_rank) 344 | # optimizer, scheduler = util.set_optim(opt, model) 345 | optimizer, scheduler = None, None 346 | elif opt.model_path == "none": # directory exists, but model_path is none 347 | model, optimizer, scheduler, opt_checkpoint, global_step, best_metric = util.restore_epoch( 348 | model_class, dir_path, opt, reset_params=False, name="latest", 349 | ) 350 | logger.info("Model loaded from %s" % dir_path) 351 | else: # model_path is given 352 | logger.info("Loading %s" % opt.model_path) 353 | # model, optimizer, scheduler = util.load_model(model_class, opt.model_path, opt) 354 | config = T5Config.from_pretrained(opt.model_path) 355 | 356 | # Update config with arguments 357 | if opt.has_answer_pool_type != "none": 358 | config.has_answer_pool_type = opt.has_answer_pool_type 359 | if opt.scheduler_type != "none": 360 | config.scheduler_type = opt.scheduler_type 361 | config.scheduler_n_context = opt.scheduler_n_context 362 | config.scheduler_embed_size = opt.scheduler_embed_size 363 | 364 | model = model_class.from_pretrained(opt.model_path, config=config) 365 | model = model.to(0 if opt.local_rank == -1 else opt.local_rank) 366 | optimizer, scheduler = None, None 367 | logger.info("Model loaded from %s" % opt.model_path) 368 | logger.info("Model config %s", str(config)) 369 | 370 | # Set model training configs (a hack around) which are only used during training 371 | model.encoder.checkpoint = opt.checkpointing_encoder 372 | model.decoder.checkpoint = opt.checkpointing_decoder 373 | model.encoder.n_passages = opt.n_context 374 | model.freeze_fid_params = opt.freeze_fid_params # config for training has_answer_heads with FiD parameters froze 375 | 376 | # Training the scheduler 377 | model.encoder.budget = None # set to None to disable the scheduler 378 | model.encoder.num_passages_retained = None # set to None to disable the scheduler 379 | model.freeze_has_answer_heads = False 380 | model.use_bce_loss = True 381 | model.use_rl_loss = False 382 | # model.step_cost = opt.step_cost 383 | # model.discount = opt.discount 384 | 385 | # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. 386 | # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will 387 | # remove the need for this code, but it is still valid. 
388 | if opt.fp16: 389 | try: 390 | import apex 391 | 392 | apex.amp.register_half_function(torch, "einsum") 393 | except ImportError: 394 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 395 | 396 | logger.info("Start training") 397 | train_evaluate(model, optimizer, scheduler, global_step, 398 | train_dataset, dev_dataset, opt, collator_function, best_metric) 399 | -------------------------------------------------------------------------------- /FiD/train_ac_scheduler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import torch 5 | import transformers 6 | import logging 7 | import util 8 | import numpy as np 9 | from tqdm.auto import tqdm 10 | 11 | import torch.distributed as dist 12 | from torch.utils.tensorboard import SummaryWriter 13 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 14 | 15 | # FiD related 16 | from options import Options 17 | # import evaluation 18 | # import slurm 19 | 20 | # ACFiD specific 21 | from fidt5_ac import ACFiDT5, T5Config 22 | import data_ac as data 23 | 24 | logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR) 25 | logging.getLogger('transformers.tokenization_utils_base').setLevel(logging.ERROR) 26 | logger = logging.getLogger(__name__) 27 | 28 | # Initialise wandb 29 | try: 30 | import wandb 31 | 32 | wandb.ensure_configured() 33 | if wandb.api.api_key is None: 34 | _has_wandb = False 35 | wandb.termwarn("W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable.") 36 | else: 37 | _has_wandb = False if os.getenv("WANDB_DISABLED") else True 38 | except (ImportError, AttributeError): 39 | _has_wandb = False 40 | 41 | 42 | def log_scalar(name, value, step): 43 | tb_logger.add_scalar(name, value, step) 44 | if _has_wandb: 45 | wandb.log({name: value, "step": step}) 46 | 47 | 48 | def train_evaluate(model, optimizer, scheduler, global_step, 49 | train_dataset, dev_dataset, opt, collator_function, best_metric): 50 | train_sampler = (RandomSampler(train_dataset) if opt.local_rank == -1 or opt.world_size == 1 51 | else DistributedSampler(train_dataset)) 52 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, 53 | batch_size=opt.per_gpu_batch_size, drop_last=True, num_workers=3, 54 | collate_fn=collator_function) 55 | 56 | dev_sampler = SequentialSampler(dev_dataset) 57 | dev_dataloader = DataLoader(dev_dataset, sampler=dev_sampler, batch_size=opt.per_gpu_batch_size, 58 | shuffle=False, num_workers=3, collate_fn=collator_function) 59 | 60 | # Freeze the FiD parameters and only train AC part. 
61 | trainable_np = list(model.named_parameters()) 62 | if opt.freeze_fid_params: 63 | new_np = [] 64 | for n, p in trainable_np: 65 | if n.startswith("encoder.has_answer_heads") or n.startswith("ac_scheduler"): 66 | p.requires_grad = True 67 | new_np.append((n, p)) 68 | else: 69 | p.requires_grad = False 70 | trainable_np = new_np 71 | 72 | if opt.freeze_has_answer_heads: 73 | new_np = [] 74 | for n, p in trainable_np: 75 | if n.startswith("encoder.has_answer_heads"): 76 | p.requires_grad = False 77 | else: 78 | p.requires_grad = True 79 | new_np.append((n, p)) 80 | trainable_np = new_np 81 | 82 | trainable_parameters = [p for n, p in trainable_np] 83 | # Prepare optimizer and schedule (linear warmup and decay) 84 | if optimizer is None or scheduler is None: 85 | optimizer = torch.optim.Adam(trainable_parameters, lr=opt.lr) 86 | # optimizer = torch.optim.SGD(trainable_parameters, lr=opt.lr) 87 | scheduler = util.FixedScheduler(optimizer) 88 | 89 | # fp16 90 | if opt.fp16: 91 | try: 92 | from apex import amp 93 | except ImportError: 94 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 95 | 96 | model, optimizer = amp.initialize(model, optimizer, opt_level=opt.fp16_opt_level) 97 | 98 | # Distributed training 99 | if opt.world_size > 1 and opt.local_rank != -1: 100 | model = torch.nn.parallel.DistributedDataParallel( 101 | model, 102 | device_ids=[opt.local_rank], 103 | output_device=opt.local_rank, 104 | find_unused_parameters=False, 105 | ) 106 | 107 | # Train! 108 | logger.info("***** Running training *****") 109 | logger.info(" Num examples = %d", len(train_dataset)) 110 | # logger.info(" Num Epochs = %d", args.num_train_epochs) 111 | logger.info(" Instantaneous batch size per GPU = %d", opt.per_gpu_batch_size) 112 | logger.info( 113 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 114 | opt.train_batch_size * opt.gradient_accumulation_steps 115 | ) 116 | logger.info(" Gradient Accumulation steps = %d", opt.gradient_accumulation_steps) 117 | logger.info(" Total optimization steps = %d", opt.total_step) 118 | logger.info(" Total number of training epochs = %f", 119 | opt.total_step * opt.train_batch_size * opt.gradient_accumulation_steps / len(train_dataset)) 120 | 121 | loss, curr_loss, curr_reward = 0.0, 0.0, 0.0 122 | epoch = 0 123 | step = 0 124 | model.train() 125 | model.zero_grad() 126 | while global_step < opt.total_step: 127 | epoch += 1 128 | if opt.world_size > 1: 129 | train_sampler.set_epoch(epoch) 130 | for i, batch in tqdm(enumerate(train_dataloader)): 131 | step += 1 132 | 133 | # Process the inputs 134 | idx, answer_ids, answer_mask, context_ids, context_mask, has_answer_labels = batch 135 | answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 136 | has_answer_labels = has_answer_labels.cuda() 137 | labels = answer_ids.masked_fill(~answer_mask, -100) 138 | if hasattr(model, "module"): 139 | model.module.encoder.n_passages = context_ids.size(1) 140 | else: 141 | model.encoder.n_passages = context_ids.size(1) 142 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 143 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 144 | decoder_input_ids = None 145 | 146 | inputs = { 147 | 'input_ids': context_ids, 148 | 'attention_mask': context_mask, 149 | 'decoder_attention_mask': answer_mask, 150 | 'decoder_input_ids': decoder_input_ids, 151 | 'labels': labels, 152 | 'has_answer_labels': has_answer_labels, 153 | } 154 | 155 | # Run the model 156 | outputs = model(**inputs) 157 | train_loss = outputs[0] 158 | train_loss = util.average_master(train_loss, opt) 159 | train_reward = outputs[1] 160 | train_reward = util.average_master(train_reward, opt) 161 | 162 | if opt.gradient_accumulation_steps > 1: 163 | train_loss = train_loss / float(opt.gradient_accumulation_steps) 164 | train_reward = train_reward / float(opt.gradient_accumulation_steps) 165 | 166 | if opt.fp16: 167 | with amp.scale_loss(train_loss, optimizer) as scaled_loss: 168 | scaled_loss.backward() 169 | else: 170 | train_loss.backward() 171 | 172 | curr_loss += train_loss.item() 173 | curr_reward += train_reward.item() 174 | if step % opt.gradient_accumulation_steps == 0: 175 | # util.clip_gradients(model, opt.clip) 176 | if opt.fp16: 177 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), opt.clip) 178 | else: 179 | torch.nn.utils.clip_grad_norm_(trainable_parameters, opt.clip) 180 | 181 | optimizer.step() 182 | scheduler.step() # Update learning rate schedule 183 | model.zero_grad() 184 | global_step += 1 185 | 186 | if opt.is_master and global_step % opt.log_freq == 0: 187 | logger.info( 188 | f"{global_step} / {opt.total_step} -- train loss = {curr_loss / opt.log_freq}" 189 | f" | train reward = {curr_reward / opt.log_freq:.3f}" 190 | f" | lr = {scheduler.get_last_lr()[0]}" 191 | ) 192 | log_scalar("Train/Loss", curr_loss / opt.log_freq, global_step) 193 | log_scalar("Train/Reward", curr_reward / opt.log_freq, global_step) 194 | curr_loss = 0. 195 | curr_reward = 0. 
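The scheduler options configured in `__main__` below (`step_cost`, `discount`, `use_rl_loss`) and the `train_reward` tracked in the loop above point to a policy-gradient style objective. For intuition only, a generic REINFORCE sketch with a per-step cost and a discounted terminal reward is given here; the actual scheduler loss is implemented inside `fidt5_ac.py` and may differ from this sketch.

```python
# Generic REINFORCE-style loss with a per-step cost and a discounted terminal reward.
# Purely illustrative; not the repository's actual AC-scheduler objective.
import torch
from torch.distributions import Categorical

def policy_gradient_loss(logits, actions, terminal_reward, step_cost=0.0, discount=1.0):
    # logits: [T, n_actions] per-step scores; actions: [T] chosen action ids.
    T = logits.size(0)
    returns = torch.tensor([terminal_reward * discount ** (T - 1 - t) - step_cost * (T - t)
                            for t in range(T)])
    log_probs = Categorical(logits=logits).log_prob(actions)
    return -(log_probs * returns).sum()

logits = torch.randn(3, 5, requires_grad=True)
actions = torch.tensor([1, 4, 0])
loss = policy_gradient_loss(logits, actions, terminal_reward=1.0, step_cost=0.1, discount=0.9)
loss.backward()
print(loss.item())
```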
196 | 197 | if global_step % opt.eval_freq == 0: 198 | results = evaluate(model, dev_dataset, dev_dataloader, tokenizer, opt) 199 | dev_reward = results["reward"] # use reward as evaluation metric 200 | if opt.is_master: 201 | logger.info(f"{global_step} / {opt.total_step} -- dev reward = {dev_reward:.2f}") 202 | for k, v in results.items(): 203 | log_scalar(f"Dev/{k}", v, global_step) 204 | 205 | if dev_reward > best_metric: 206 | best_metric = dev_reward 207 | if opt.is_master: 208 | model_to_save = model.module if hasattr(model, "module") else model 209 | util.save(model_to_save, optimizer, scheduler, global_step, best_metric, opt, dir_path, 210 | 'best_dev') 211 | model.train() 212 | 213 | if opt.is_master and global_step % opt.save_freq == 0: 214 | model_to_save = model.module if hasattr(model, "module") else model 215 | util.save(model_to_save, optimizer, scheduler, global_step, best_metric, opt, dir_path, 216 | f"step-{global_step}") 217 | if global_step > opt.total_step: 218 | break 219 | 220 | 221 | def evaluate(model, dataset, dataloader, tokenizer, opt): 222 | model.eval() 223 | if hasattr(model, "module"): 224 | model = model.module 225 | 226 | all_rewards = [] 227 | 228 | with torch.no_grad(): 229 | for i, batch in enumerate(dataloader): 230 | idx, answer_ids, answer_mask, context_ids, context_mask, has_answer_labels = batch 231 | answer_ids, answer_mask = answer_ids.cuda(), answer_mask.bool().cuda() 232 | has_answer_labels = has_answer_labels.cuda() 233 | labels = answer_ids.masked_fill(~answer_mask, -100) 234 | if hasattr(model, "module"): 235 | model.module.encoder.n_passages = context_ids.size(1) 236 | else: 237 | model.encoder.n_passages = context_ids.size(1) 238 | context_ids = context_ids.cuda().view(context_ids.size(0), -1) 239 | context_mask = context_mask.cuda().view(context_ids.size(0), -1) 240 | decoder_input_ids = None 241 | 242 | inputs = { 243 | 'input_ids': context_ids, 244 | 'attention_mask': context_mask, 245 | 'decoder_attention_mask': answer_mask, 246 | 'decoder_input_ids': decoder_input_ids, 247 | 'labels': labels, 248 | 'has_answer_labels': has_answer_labels, 249 | } 250 | outputs = model(**inputs) 251 | mean_rewards = outputs[1].item() 252 | all_rewards.append(mean_rewards) 253 | 254 | final_results = {"reward": np.mean(all_rewards)} 255 | return final_results 256 | 257 | 258 | if __name__ == "__main__": 259 | options = Options() 260 | opt = options.parse() 261 | torch.manual_seed(opt.seed) 262 | # slurm.init_distributed_mode(opt) 263 | # slurm.init_signal_handler() 264 | opt.train_batch_size = opt.per_gpu_batch_size * max(1, opt.world_size) 265 | logger.info("Distributed training") 266 | 267 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 268 | 269 | dir_path = os.path.join(opt.checkpoint_dir, opt.name) 270 | 271 | model_name = 't5-' + opt.model_size 272 | model_class = ACFiDT5 273 | tokenizer = transformers.T5Tokenizer.from_pretrained(model_name) 274 | 275 | collator_function = data.Collator(opt, tokenizer) 276 | 277 | train_examples = data.load_data(opt.train_data_path, n_context=opt.n_context) 278 | train_dataset = data.Dataset(train_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 279 | dev_examples = data.load_data(opt.dev_data_path, global_rank=opt.global_rank, 280 | # use the global rank and world size attibutes to split the dev set on multiple gpus 281 | world_size=opt.world_size, 282 | n_context=opt.n_context) 283 | if opt.dev_data_size > 0: 284 | random.seed(opt.seed) 285 | dev_examples = 
random.sample(dev_examples, opt.dev_data_size) 286 | # dev_examples = dev_examples[:opt.dev_data_size] 287 | dev_dataset = data.Dataset(dev_examples, opt.n_context, tokenizer, opt.max_passage_length, opt.no_title) 288 | 289 | directory_exists = os.path.exists(dir_path) 290 | if opt.world_size > 1 and not opt.local_rank == -1: 291 | torch.distributed.barrier() 292 | os.makedirs(dir_path, exist_ok=True) 293 | if not directory_exists and opt.is_master: 294 | options.print_options(opt) 295 | if opt.world_size > 1 and not opt.local_rank == -1: 296 | torch.distributed.barrier() 297 | file_handler = logging.FileHandler(filename=os.path.join(dir_path, "run.log")) 298 | stdout_handler = logging.StreamHandler(sys.stdout) 299 | handlers = [file_handler, stdout_handler] 300 | logging.basicConfig( 301 | datefmt="%m/%d/%Y %H:%M:%S", 302 | level=logging.INFO if opt.is_master else logging.WARN, 303 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 304 | handlers=handlers, 305 | ) 306 | 307 | if opt.world_size > 1 and not opt.local_rank == -1: 308 | torch.distributed.barrier() 309 | 310 | if opt.is_master: 311 | tb_logger = SummaryWriter(os.path.join(opt.checkpoint_dir, opt.name)) 312 | # Setup wandb 313 | if opt.is_master and _has_wandb: 314 | opt_dict = vars(opt) 315 | # os.makedirs(os.path.join(dir_path, "wandb")) 316 | wandb.init(project="ACFiD", name=opt.name, dir=os.path.join(dir_path), config=opt_dict) 317 | 318 | global_step = 0 319 | best_metric = 0. 320 | 321 | if not directory_exists and opt.model_path == "none": 322 | model = model_class.from_pretrained(model_name) 323 | model = model.to(0 if opt.local_rank == -1 else opt.local_rank) 324 | # optimizer, scheduler = util.set_optim(opt, model) 325 | optimizer, scheduler = None, None 326 | elif opt.model_path == "none": # directory exists, but model_path is none 327 | model, optimizer, scheduler, opt_checkpoint, global_step, best_metric = util.restore_epoch( 328 | model_class, dir_path, opt, reset_params=False, name="latest", 329 | ) 330 | logger.info("Model loaded from %s" % dir_path) 331 | else: # model_path is given 332 | logger.info("Loading %s" % opt.model_path) 333 | # model, optimizer, scheduler = util.load_model(model_class, opt.model_path, opt) 334 | config = T5Config.from_pretrained(opt.model_path) 335 | 336 | # Update config with arguments 337 | if opt.has_answer_pool_type != "none": 338 | config.has_answer_pool_type = opt.has_answer_pool_type 339 | if opt.scheduler_type != "none": 340 | config.scheduler_type = opt.scheduler_type 341 | config.scheduler_n_context = opt.scheduler_n_context 342 | config.scheduler_embed_size = opt.scheduler_embed_size 343 | config.scheduler_hidden_size = opt.scheduler_hidden_size 344 | 345 | model = model_class.from_pretrained(opt.model_path, config=config) 346 | model = model.to(0 if opt.local_rank == -1 else opt.local_rank) 347 | optimizer, scheduler = None, None 348 | logger.info("Model loaded from %s" % opt.model_path) 349 | logger.info("Model config %s", str(config)) 350 | 351 | # Set model training configs (a hack around) which are only used during training 352 | model.encoder.checkpoint = opt.checkpointing_encoder 353 | model.decoder.checkpoint = opt.checkpointing_decoder 354 | model.encoder.n_passages = opt.n_context 355 | model.freeze_fid_params = opt.freeze_fid_params # config for training has_answer_heads with FiD parameters froze 356 | 357 | # Training the scheduler 358 | model.encoder.budget = opt.budget # config for training/evaluating AC scheduler 359 | 
model.encoder.num_passages_retained = opt.num_passages_retained # config for training/evaluating AC scheduler 360 | model.freeze_has_answer_heads = opt.freeze_has_answer_heads 361 | model.step_cost = opt.step_cost 362 | model.discount = opt.discount 363 | # Losses 364 | model.use_bce_loss = opt.use_bce_loss 365 | model.use_rl_loss = opt.use_rl_loss 366 | 367 | # Before we do anything with models, we want to ensure that we get fp16 execution of torch.einsum if args.fp16 is set. 368 | # Otherwise it'll default to "promote" mode, and we'll get fp32 operations. Note that running `--fp16_opt_level="O2"` will 369 | # remove the need for this code, but it is still valid. 370 | if opt.fp16: 371 | try: 372 | import apex 373 | 374 | apex.amp.register_half_function(torch, "einsum") 375 | except ImportError: 376 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 377 | 378 | logger.info("Start training") 379 | train_evaluate(model, optimizer, scheduler, global_step, 380 | train_dataset, dev_dataset, opt, collator_function, best_metric) 381 | -------------------------------------------------------------------------------- /FiD/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import sys 5 | import logging 6 | import torch.distributed as dist 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def copy_dir(source, target): 12 | if os.path.exists(target): 13 | assert os.path.isdir(target) 14 | shutil.rmtree(target) 15 | 16 | shutil.copytree(source, target) 17 | 18 | 19 | def save(model, optimizer, scheduler, step, best_dev_em, opt, dir_path, name): 20 | path = os.path.join(dir_path, "checkpoint") 21 | epoch_path = os.path.join(path, name) # "step-%s" % step) 22 | os.makedirs(epoch_path, exist_ok=True) 23 | model.save_pretrained(epoch_path) 24 | 25 | # Save optimizer states 26 | fp = os.path.join(epoch_path, "optimizer.pth.tar") 27 | checkpoint = { 28 | "step": step, 29 | "optimizer": optimizer.state_dict(), 30 | "scheduler": scheduler.state_dict(), 31 | "opt": opt, 32 | "best_dev_em": best_dev_em, 33 | } 34 | torch.save(checkpoint, fp) 35 | 36 | latest_path = os.path.join(path, "latest") 37 | copy_dir(epoch_path, latest_path) 38 | 39 | 40 | def restore_epoch(model_class, dir_path, opt, name, reset_params=False): 41 | epoch_path = os.path.join(dir_path, "checkpoint", name) # str(epoch)) 42 | epoch_path = os.path.realpath(epoch_path) 43 | optimizer_path = os.path.join(epoch_path, "optimizer.pth.tar") 44 | logger.info("Loading %s" % epoch_path) 45 | model = model_class.from_pretrained(epoch_path) # , map_location="cuda:"+str(opt.local_rank)) 46 | logger.info("loading checkpoint %s" % optimizer_path) 47 | 48 | local_rank = 0 if opt.local_rank == -1 else opt.local_rank 49 | checkpoint = torch.load(optimizer_path, map_location="cuda:" + str(local_rank)) 50 | opt_checkpoint = checkpoint["opt"] 51 | step = checkpoint["step"] 52 | best_dev_em = checkpoint["best_dev_em"] 53 | if not reset_params: 54 | optimizer, scheduler = set_optim(opt_checkpoint, model) 55 | scheduler.load_state_dict(checkpoint["scheduler"]) 56 | optimizer.load_state_dict(checkpoint["optimizer"]) 57 | else: 58 | optimizer, scheduler = set_optim(opt, model) 59 | 60 | model = model.to(local_rank) 61 | return model, optimizer, scheduler, opt_checkpoint, step, best_dev_em 62 | 63 | 64 | def load_model(model_class, model_path, opt): 65 | logger.info("Loading %s" % model_path) 66 | model = 
model_class.from_pretrained(model_path) # , map_location="cuda:"+str(opt.local_rank)) 67 | 68 | local_rank = 0 if opt.local_rank == -1 else opt.local_rank 69 | model = model.to(local_rank) 70 | optimizer, scheduler = set_optim(opt, model) 71 | 72 | return model, optimizer, scheduler 73 | 74 | 75 | ############ OPTIM 76 | 77 | 78 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR): 79 | def __init__( 80 | self, optimizer, warmup_steps, t_total, min_ratio, fixed_lr, last_epoch=-1 81 | ): 82 | self.warmup_steps = warmup_steps 83 | self.t_total = t_total 84 | self.min_ratio = min_ratio 85 | self.fixed_lr = fixed_lr 86 | super(WarmupLinearScheduler, self).__init__( 87 | optimizer, self.lr_lambda, last_epoch=last_epoch 88 | ) 89 | 90 | def lr_lambda(self, step): 91 | if step < self.warmup_steps: 92 | return (1 - self.min_ratio) * float(step) / float( 93 | max(1, self.warmup_steps) 94 | ) + self.min_ratio 95 | return 1.0 96 | 97 | if self.fixed_lr: 98 | return 1.0 99 | 100 | return max( 101 | 0.0, 102 | 1.0 103 | + float((self.min_ratio - 1) * (step - self.warmup_steps)) 104 | / float(max(1.0, self.t_total - self.warmup_steps)), 105 | ) 106 | 107 | 108 | class FixedScheduler(torch.optim.lr_scheduler.LambdaLR): 109 | def __init__(self, optimizer, last_epoch=-1): 110 | super(FixedScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 111 | 112 | def lr_lambda(self, step): 113 | return 1.0 114 | 115 | 116 | def clip_gradients(model, clip): 117 | for p in list(filter(lambda p: p.grad is not None, model.parameters())): 118 | clip_coef = clip / (p.grad.data.norm(2) + 1e-6) 119 | if clip_coef < 1: 120 | p.grad.data.mul_(clip_coef) 121 | 122 | 123 | # def set_optim(opt, model): 124 | # #cache_p, model_p = [], [] 125 | # #for n, p in model.named_parameters(): 126 | # # if 'cache' not in n: 127 | # # model_p.append(p) 128 | # # else: 129 | # # cache_p.append(p) 130 | # if opt.optim == "adam": 131 | # optimizer = torch.optim.Adam( 132 | # model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps 133 | # ) 134 | # elif opt.optim == "adagrad": 135 | # optimizer = torch.optim.Adagrad(model.parameters(), lr=opt.lr) 136 | # elif opt.optim == "adafactor": 137 | # optimizer = fairseq.optim.adafactor.Adafactor(model.parameters(), lr=opt.lr, relative_step=False) 138 | # elif opt.optim == "sgd": 139 | # optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr) 140 | # if opt.scheduler == 'linear': 141 | # scheduler = WarmupLinearScheduler(optimizer, opt.warmup, t_total=opt.t_total, min_ratio=opt.min_lr/opt.lr, fixed_lr=opt.fixed_lr) 142 | # elif opt.scheduler == 'fixed': 143 | # scheduler = FixedScheduler(optimizer) 144 | # return optimizer, scheduler 145 | 146 | 147 | def set_optim(opt, model): 148 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr) 149 | scheduler = FixedScheduler(optimizer) 150 | return optimizer, scheduler 151 | 152 | 153 | def _get_grad_requiring_params(model): 154 | nb_parameters = 0 155 | grad_requiring_params = [] 156 | for param in model.parameters(): 157 | if param.requires_grad: 158 | nb_parameters += param.numel() 159 | grad_requiring_params.append(param) 160 | return grad_requiring_params 161 | 162 | 163 | def print_parameters(net, log_dir, verbose=False): 164 | file_name = os.path.join(log_dir, "opt.txt") 165 | num_params = 0 166 | for param in net.parameters(): 167 | num_params += param.numel() 168 | message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6) 169 | print(message) 170 | if verbose: 171 | 
print(net) 172 | sys.stdout.flush() 173 | with open(file_name, "a") as log_file: 174 | log_file.write(message + "\n") 175 | with open(file_name, "a") as log_file: 176 | log_file.write(str(net) + "\n") 177 | 178 | 179 | def average_master(x, opt): 180 | if opt.world_size > 1: 181 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 182 | if opt.is_master: 183 | x = x / opt.world_size 184 | return x 185 | 186 | 187 | def sum_master(x, opt): 188 | if opt.world_size > 1: 189 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 190 | return x 191 | 192 | 193 | def weighted_average(x, count, opt): 194 | local_rank = 0 if opt.local_rank == -1 else opt.local_rank 195 | t_loss = torch.tensor([x * count], device="cuda:" + str(local_rank)) 196 | t_total = torch.tensor([count], device="cuda:" + str(local_rank)) 197 | t_loss = sum_master(t_loss, opt) 198 | t_total = sum_master(t_total, opt) 199 | return (t_loss / t_total).item(), t_total.item() 200 | 201 | 202 | def write_output(glob_path, output_path): 203 | files = list(glob_path.glob('*.txt')) 204 | files.sort() 205 | with open(output_path, 'w') as outfile: 206 | for path in files: 207 | with open(path, 'r') as f: 208 | lines = f.readlines() 209 | for line in lines: 210 | outfile.write(line) 211 | path.unlink() 212 | glob_path.rmdir() 213 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Passage Encoder for Open-domain Question Answering 2 | 3 | 4 | 5 | Source code repository of our ACL-IJCNLP 2021 paper "Training Adaptive Computation for Open-Domain Question Answering with Computational Constraints". 6 | 7 | This is based on [Fusion-in-Decoder (FiD)](https://arxiv.org/abs/2007.01282) work, so the two projects share some dependencies and datasets. Refer to [FiD](https://github.com/facebookresearch/FiD) repository for more details regarding downloading the data and checkpoints. 
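As a possible quick start, the helper scripts under `scripts/` download the preprocessed data and the pretrained FiD checkpoints and wrap the evaluation entry points. The checkpoint path in the last command below is illustrative, and the budget follows the top-k × encoder-layers heuristic used in `scripts/batch_eval_scheduler_nq.sh` (12 layers per passage for the base model):

```shell
# Download the preprocessed NQ / TriviaQA data and the pretrained FiD checkpoints.
bash scripts/download_data.sh
bash scripts/download_models.sh

# Evaluate an AC-scheduler checkpoint on NaturalQuestions.
# Positional arguments: N_CTX BSZ SIZE MODEL BUDGET TOPK
# (checkpoint path is illustrative; BUDGET = TOPK * 12 for the base model)
bash scripts/eval_scheduler_nq.sh 40 4 base checkpoint/my_ac_experiment/checkpoint/best_dev 120 10
```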
8 | 9 | ## Dependencies 10 | 11 | - Python 3 12 | - [NumPy](http://www.numpy.org/) 13 | - [PyTorch](http://pytorch.org/) (currently tested on version 1.6.0) 14 | - [Transformers](http://huggingface.co/transformers/) (version 3.0.2, unlikely to work with a different version) 15 | 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | tqdm 3 | sentencepiece 4 | -------------------------------------------------------------------------------- /scripts/batch_eval_retrieval_acc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME="$(pwd)" 4 | 5 | DATA=$1 6 | MODEL=$2 7 | SIZE=$3 8 | N_CTX=$4 9 | 10 | BSZ=4 11 | #BUDGET=$5 12 | #TOPK=$6 13 | 14 | for k in 5 10 20; do 15 | TOPK=$k 16 | if [ "$SIZE" = "base" ]; then 17 | BUDGET=$(($TOPK * 12)) 18 | else 19 | BUDGET=$(($TOPK * 24)) 20 | fi 21 | echo "budget $BUDGET topk $TOPK" 22 | 23 | ./scripts/eval_retrieval_acc.sh $DATA $N_CTX $BSZ $SIZE $MODEL $BUDGET $TOPK 24 | cd "$HOME" 25 | done 26 | -------------------------------------------------------------------------------- /scripts/batch_eval_scheduler_nq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME="$(pwd)" 4 | 5 | NQ_DATA="${HOME}/data/preprocessed_data/nq/nq_dpr_test.json" 6 | DATA="nq" 7 | 8 | MODEL=$1 9 | SIZE=$2 10 | N_CTX=$3 11 | BSZ=4 12 | #BUDGET=$5 13 | #TOPK=$6 14 | 15 | for k in 5 10 20 30 40; do 16 | TOPK=$k 17 | if [ "$SIZE" = "base" ]; then 18 | BUDGET=$(($TOPK * 12)) 19 | else 20 | BUDGET=$(($TOPK * 24)) 21 | fi 22 | echo "budget $BUDGET topk $TOPK" 23 | 24 | ./scripts/eval_scheduler_nq.sh $N_CTX $BSZ $SIZE $MODEL $BUDGET $TOPK 25 | cd "$HOME" 26 | done 27 | -------------------------------------------------------------------------------- /scripts/batch_eval_scheduler_trivia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME="$(pwd)" 4 | 5 | MODEL=$1 6 | SIZE=$2 7 | N_CTX=$3 8 | BSZ=4 9 | #BUDGET=$5 10 | #TOPK=$6 11 | 12 | for k in 5 10 20 30 40; do 13 | TOPK=$k 14 | if [ "$SIZE" = "base" ]; then 15 | BUDGET=$(($TOPK * 12)) 16 | else 17 | BUDGET=$(($TOPK * 24)) 18 | fi 19 | echo "budget $BUDGET topk $TOPK" 20 | 21 | ./scripts/eval_scheduler_trivia.sh $N_CTX $BSZ $SIZE $MODEL $BUDGET $TOPK 22 | cd "$HOME" 23 | done 24 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT_DIR=$(pwd) 4 | 5 | DATA_DIR="$ROOT_DIR/data" 6 | mkdir -p "$DATA_DIR" 7 | PROC_DATA="$DATA_DIR"/preprocessed_data 8 | mkdir -p "$PROC_DATA" 9 | 10 | # Download retrieved passages for NaturalQuestions 11 | cd $ROOT_DIR 12 | mkdir -p "$PROC_DATA"/nq 13 | cd $PROC_DATA/nq 14 | 15 | if [[ ! -f nq_dpr_train.json ]]; then 16 | wget http://dl.fbaipublicfiles.com/FiD/preprocessed_data/nq/nq_dpr_train.json.xz 17 | xz --decompress nq_dpr_train.json.xz 18 | fi 19 | #6e86173809c6b2f8390f9dd20631c001 nq/nq_dpr_train.json 20 | 21 | if [[ ! -f nq_dpr_dev.json ]]; then 22 | wget http://dl.fbaipublicfiles.com/FiD/preprocessed_data/nq/nq_dpr_dev.json.xz 23 | xz --decompress nq_dpr_dev.json.xz 24 | fi 25 | #9e09fba3a450bebf86c7706b181a7491 nq/nq_dpr_dev.json 26 | 27 | if [[ ! 
-f nq_dpr_test.json ]]; then 28 | wget http://dl.fbaipublicfiles.com/FiD/preprocessed_data/nq/nq_dpr_test.json.xz 29 | xz --decompress nq_dpr_test.json.xz 30 | fi 31 | #06850dd776b473818129665344470664 nq/nq_dpr_test.json 32 | 33 | # Download retrieved passages for TriviaQA 34 | mkdir -p "$PROC_DATA"/trivia 35 | cd "$PROC_DATA"/trivia 36 | 37 | if [[ ! -f trivia_dpr_train.json ]]; then 38 | wget http://dl.fbaipublicfiles.com/FiD/preprocessed_data/trivia/trivia_dpr_train.json.xz 39 | xz --decompress trivia_dpr_train.json.xz 40 | fi 41 | #dd12dddd006ec9c35894e0a2d188f9d6 trivia/trivia_dpr_train.json 42 | 43 | if [[ ! -f trivia_dpr_dev.json ]]; then 44 | wget http://dl.fbaipublicfiles.com/FiD/preprocessed_data/trivia/trivia_dpr_dev.json.xz 45 | xz --decompress trivia_dpr_dev.json.xz 46 | fi 47 | #2128d1f3aafb35c61c62855a37d63d0f trivia/trivia_dpr_dev.json 48 | 49 | if [[ ! -f trivia_dpr_test.json ]]; then 50 | wget http://dl.fbaipublicfiles.com/FiD/preprocessed_data/trivia/trivia_dpr_test.json.xz 51 | xz --decompress trivia_dpr_test.json.xz 52 | fi 53 | #22c74e02c429b1fb777c4f62967e6756 trivia/trivia_dpr_test.json 54 | 55 | cd $ROOT_DIR -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT_DIR=$PWD 4 | 5 | MODEL_DIR=$ROOT_DIR/pretrained_models 6 | mkdir -p "$MODEL_DIR" 7 | 8 | for NAME in "nq_base_dpr" "nq_large_dpr" "triviaqa_base_dpr" "triviaqa_large_dpr"; do 9 | mkdir -p "$MODEL_DIR"/${NAME} 10 | cd $MODEL_DIR/${NAME} 11 | 12 | if [[ ! -f pytorch_model.bin ]]; then 13 | wget http://dl.fbaipublicfiles.com/FiD/pretrained_models/${NAME}/pytorch_model.bin 14 | fi 15 | if [[ ! -f config.json ]]; then 16 | wget http://dl.fbaipublicfiles.com/FiD/pretrained_models/${NAME}/config.json 17 | fi 18 | 19 | cd $ROOT_DIR 20 | done 21 | -------------------------------------------------------------------------------- /scripts/eval_ac_nq_single.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | NQ_DATA="${HOME}/data/preprocessed_data/nq/nq_dpr_test.json" 6 | DATA="nq" 7 | 8 | N_CTX=$1 9 | BSZ=$2 10 | SIZE=$3 11 | MODEL=$4 12 | BUDGET=$5 13 | TOPK=$6 14 | 15 | NAME="${DATA}_${SIZE}_nctx=${N_CTX}_budget=${BUDGET}_topk=${TOPK}" 16 | 17 | python FiD/test_ac.py \ 18 | --model_path $MODEL \ 19 | --test_data_path $NQ_DATA \ 20 | --model_size $SIZE \ 21 | --per_gpu_batch_size $BSZ \ 22 | --n_context $N_CTX \ 23 | --name $NAME \ 24 | --checkpoint_dir "/tmp/${NAME}" \ 25 | --budget $BUDGET \ 26 | --num_passages_retained $TOPK 27 | -------------------------------------------------------------------------------- /scripts/eval_nq_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda activate FiD 4 | cd FiD 5 | 6 | NQ_DATA="../data/preprocessed_data/nq/nq_dpr_test.json" 7 | DATA="nq" 8 | SIZE=$1 9 | MODEL="../pretrained_models/nq_${SIZE}_dpr/" 10 | 11 | BSZ=2 12 | 13 | for N_CTX in 6 12 22 26 30 35 45; do 14 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 15 | done 16 | -------------------------------------------------------------------------------- /scripts/eval_nq_single.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | NQ_DATA="${HOME}/data/preprocessed_data/nq/nq_dpr_test.json" 6 | DATA="nq" 7 | 8 | N_CTX=$1 9 | BSZ=$2 10 | SIZE=$3 11 | #MODEL="${HOME}/pretrained_models/nq_${SIZE}_dpr/" 12 | MODEL=$4 13 | 14 | NAME="${DATA}_${SIZE}_nctx=${N_CTX}" 15 | python FiD/test.py \ 16 | --model_path $MODEL \ 17 | --test_data_path $NQ_DATA \ 18 | --model_size $SIZE \ 19 | --per_gpu_batch_size $BSZ \ 20 | --n_context $N_CTX \ 21 | --name $NAME \ 22 | --checkpoint_dir "$MODEL/eval" \ 23 | --write_results \ 24 | --is_master 25 | -------------------------------------------------------------------------------- /scripts/eval_retrieval_acc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | DATA=$1 6 | DATA_PATH="${HOME}/data/preprocessed_data/${DATA}/${DATA}_dpr_test.json" 7 | 8 | N_CTX=$2 9 | BSZ=$3 10 | SIZE=$4 11 | MODEL=$5 12 | BUDGET=$6 13 | TOPK=$7 14 | 15 | NAME="${DATA}_${SIZE}_nctx=${N_CTX}_budget=${BUDGET}_topk=${TOPK}" 16 | 17 | python FiD/test_retrieval_acc.py \ 18 | --model_path $MODEL \ 19 | --checkpoint_dir "$MODEL/eval" \ 20 | --name $NAME \ 21 | --test_data_path $DATA_PATH \ 22 | --model_size $SIZE \ 23 | --per_gpu_batch_size $BSZ \ 24 | --n_context $N_CTX \ 25 | --budget $BUDGET \ 26 | --num_passages_retained $TOPK \ 27 | --write_results \ 28 | --is_master 29 | 30 | # --checkpoint_dir $MODEL \ 31 | -------------------------------------------------------------------------------- /scripts/eval_scheduler_nq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | NQ_DATA="${HOME}/data/preprocessed_data/nq/nq_dpr_test.json" 6 | DATA="nq" 7 | 8 | N_CTX=$1 9 | BSZ=$2 10 | SIZE=$3 11 | MODEL=$4 12 | BUDGET=$5 13 | TOPK=$6 14 | 15 | NAME="${DATA}_${SIZE}_nctx=${N_CTX}_budget=${BUDGET}_topk=${TOPK}" 16 | 17 | python FiD/test_ac_scheduler.py \ 18 | --model_path $MODEL \ 19 | --checkpoint_dir "$MODEL/eval" \ 20 | --name $NAME \ 21 | --test_data_path $NQ_DATA \ 22 | --model_size $SIZE \ 23 | --per_gpu_batch_size $BSZ \ 24 | --n_context $N_CTX \ 25 | --budget $BUDGET \ 26 | --num_passages_retained $TOPK \ 27 | --write_results \ 28 | --is_master 29 | 30 | # --checkpoint_dir $MODEL \ 31 | -------------------------------------------------------------------------------- /scripts/eval_scheduler_trivia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | TRIVIA_DATA="${HOME}/data/preprocessed_data/trivia/trivia_dpr_test.json" 6 | DATA="trivia" 7 | 8 | N_CTX=$1 9 | BSZ=$2 10 | SIZE=$3 11 | MODEL=$4 12 | BUDGET=$5 13 | TOPK=$6 14 | 15 | NAME="${DATA}_${SIZE}_nctx=${N_CTX}_budget=${BUDGET}_topk=${TOPK}" 16 | 17 | python FiD/test_ac_scheduler.py \ 18 | --model_path $MODEL \ 19 | --checkpoint_dir "$MODEL/eval" \ 20 | --name $NAME \ 21 | --test_data_path $TRIVIA_DATA \ 22 | --model_size $SIZE \ 23 | --per_gpu_batch_size $BSZ \ 24 | --n_context $N_CTX \ 25 | --budget $BUDGET \ 26 | --num_passages_retained $TOPK \ 27 | --write_results \ 28 | --is_master 29 | 30 | # --checkpoint_dir $MODEL \ 31 | -------------------------------------------------------------------------------- /scripts/eval_trivia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda activate FiD 4 | cd FiD 5 | 6 | NQ_DATA="../data/preprocessed_data/trivia/trivia_dpr_test.json" 7 | DATA="trivia" 8 | SIZE="base" 9 | 
MODEL="../pretrained_models/triviaqa_base_dpr/" 10 | 11 | BSZ=10 12 | N_CTX=5 13 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 14 | N_CTX=10 15 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 16 | N_CTX=20 17 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 18 | BSZ=5 19 | N_CTX=40 20 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 21 | N_CTX=50 22 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 23 | BSZ=2 24 | N_CTX=80 25 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 26 | N_CTX=100 27 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 28 | 29 | SIZE="large" 30 | MODEL="../pretrained_models/triviaqa_large_dpr/" 31 | 32 | BSZ=8 33 | N_CTX=5 34 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 35 | N_CTX=10 36 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 37 | BSZ=4 38 | N_CTX=20 39 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 40 | BSZ=2 41 | N_CTX=40 42 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 43 | N_CTX=50 44 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 45 | BSZ=1 46 | N_CTX=80 47 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 48 | N_CTX=100 49 | python test.py --model_path $MODEL --test_data_path $NQ_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 50 | -------------------------------------------------------------------------------- /scripts/eval_trivia_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda activate FiD 4 | cd FiD 5 | 6 | TRIVIA_DATA="../data/preprocessed_data/trivia/trivia_dpr_test.json" 7 | 
DATA="trivia" 8 | SIZE=$1 9 | MODEL="../pretrained_models/triviaqa_${SIZE}_dpr/" 10 | 11 | BSZ=2 12 | 13 | for N_CTX in 6 12 22 26 30 35 45; do 14 | echo "${SIZE} ${DATA} ${N_CTX}" 15 | python test.py --model_path $MODEL --test_data_path $TRIVIA_DATA --model_size $SIZE --per_gpu_batch_size $BSZ --n_context $N_CTX --name "${DATA}_${SIZE}_nctx=${N_CTX}" --checkpoint_dir ../checkpoint/ 16 | done 17 | -------------------------------------------------------------------------------- /scripts/init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT_DIR=$(pwd) 4 | echo $ROOT_DIR 5 | 6 | pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 7 | 8 | # Install transformers 9 | #pip install transformers==3.0.2 10 | #pip install transformers==3.1.0 11 | 12 | cd $ROOT_DIR/etc/transformers-v3.0.2 13 | #cd $ROOT_DIR/etc/transformers-v3.1.0 14 | pip install -e . 15 | cd $ROOT_DIR 16 | 17 | ## Install Apex 18 | #cd $ROOT_DIR/etc/apex 19 | #pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 20 | 21 | # Install wandb 22 | pip install wandb 23 | 24 | cd $ROOT_DIR 25 | pip install -r requirements.txt 26 | -------------------------------------------------------------------------------- /scripts/train_ac_nq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | NQ_DATA="${HOME}/data/preprocessed_data/nq" 6 | 7 | # Copy the entire data folder to avoid corrupting the original data files (an issue in Colab) 8 | TMP_DATA=$(mktemp -d -t data-XXXXXXXXXX) 9 | cp -r $NQ_DATA/* $TMP_DATA 10 | echo "Finished copying data from ${NQ_DATA} to ${TMP_DATA}" 11 | NQ_DATA=$TMP_DATA 12 | #md5sum $NQ_DATA/* 13 | 14 | DATA="nq" 15 | #SIZE="base" 16 | SIZE=$3 17 | NOW=$(date '+%Y%m%d-%H-%M-%S') 18 | CKPT="${HOME}/checkpoints" 19 | MODEL="${HOME}/pretrained_models/nq_${SIZE}_dpr/" 20 | 21 | N_CTX=$1 22 | POOL=$2 23 | NAME="acfid-${SIZE}-${DATA}_${POOL}_${NOW}" 24 | 25 | python FiD/train_ac.py \ 26 | --model_path $MODEL \ 27 | --train_data_path $NQ_DATA/nq_dpr_train.json \ 28 | --dev_data_path $NQ_DATA/nq_dpr_dev.json \ 29 | --dev_data_size 2000 \ 30 | --model_size $SIZE \ 31 | --per_gpu_batch_size $4 \ 32 | --gradient_accumulation_steps $5 \ 33 | --n_context $N_CTX \ 34 | --name "${NAME}" \ 35 | --checkpoint_dir $CKPT \ 36 | --lr 1e-4 \ 37 | --log_freq 50 \ 38 | --eval_freq 5000 \ 39 | --save_freq 5000 \ 40 | --total_step 20000 \ 41 | --is_master \ 42 | --freeze_fid_params \ 43 | --has_answer_pool_type $POOL 44 | # --fp16 45 | # --checkpointing_encoder 46 | 47 | rm -fr $NQ_DATA 48 | -------------------------------------------------------------------------------- /scripts/train_ac_scheduler_nq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | NQ_DATA="${HOME}/data/preprocessed_data/nq" 6 | 7 | # Copy the entire data folder to avoid corrupting the original data files (an issue in Colab) 8 | TMP_DATA=$(mktemp -d -t data-XXXXXXXXXX) 9 | cp -r $NQ_DATA/* $TMP_DATA 10 | echo "Finished copying data from ${NQ_DATA} to ${TMP_DATA}" 11 | NQ_DATA=$TMP_DATA 12 | #md5sum $NQ_DATA/* 13 | 14 | DATA="nq" 15 | NOW=$(date '+%Y%m%d-%H-%M-%S') 16 | CKPT="${HOME}/checkpoints" 17 | 18 | MODEL=$1 19 | SIZE=$2 20 | BSZ=$3 21 | ACC=$4 22 | LR=$5 23 | DISC=$6 24 | 25 | N_CTX=$7 26 | BUDGET=$8 27 | COST=$9 28 | TYPE=${10} 29 | 
EMBED=${11}  # --scheduler_embed_size 30 | HID=${12}  # --scheduler_hidden_size 31 | K=5  # --num_passages_retained 32 | EXTRA=${@:13}  # remaining args ($13 onward) are appended to the python command 33 | echo $EXTRA 34 | 35 | NAME="acfid-${SIZE}-${DATA}_scheduler_${NOW}" 36 | 37 | python3 FiD/train_ac_scheduler.py \ 38 | --model_path $MODEL \ 39 | --train_data_path $NQ_DATA/nq_dpr_train.json \ 40 | --dev_data_path $NQ_DATA/nq_dpr_dev.json \ 41 | --dev_data_size 2000 \ 42 | --model_size $SIZE \ 43 | --per_gpu_batch_size $BSZ \ 44 | --gradient_accumulation_steps $ACC \ 45 | --n_context $N_CTX \ 46 | --name "${NAME}" \ 47 | --checkpoint_dir $CKPT \ 48 | --lr $LR \ 49 | --log_freq 10 \ 50 | --eval_freq 100 \ 51 | --save_freq 100 \ 52 | --total_step 10000 \ 53 | --is_master \ 54 | --freeze_fid_params \ 55 | --scheduler_type $TYPE \ 56 | --scheduler_n_context $N_CTX \ 57 | --scheduler_embed_size $EMBED \ 58 | --scheduler_hidden_size $HID \ 59 | --budget $BUDGET \ 60 | --num_passages_retained $K \ 61 | --step_cost $COST \ 62 | --discount $DISC \ 63 | $EXTRA 64 | # --freeze_has_answer_heads \ 65 | # --fp16 66 | # --checkpointing_encoder 67 | 68 | rm -fr $NQ_DATA 69 | -------------------------------------------------------------------------------- /scripts/train_ac_scheduler_trivia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | TRIVIA_DATA="${HOME}/data/preprocessed_data/trivia" 6 | 7 | ## Copy the entire data folder to avoid corrupting the original data files (an issue in Colab) 8 | #TMP_DATA=$(mktemp -d -t data-XXXXXXXXXX) 9 | #cp -r $TRIVIA_DATA/* $TMP_DATA 10 | #echo "Finished copying data from ${TRIVIA_DATA} to ${TMP_DATA}" 11 | #TRIVIA_DATA=$TMP_DATA 12 | ##md5sum $TRIVIA_DATA/* 13 | 14 | DATA="trivia" 15 | NOW=$(date '+%Y%m%d-%H-%M-%S') 16 | CKPT="${HOME}/checkpoints" 17 | 18 | MODEL=$1  # positional args $1 through ${12} are the same as in train_ac_scheduler_nq.sh 19 | SIZE=$2 20 | BSZ=$3 21 | ACC=$4 22 | LR=$5 23 | DISC=$6 24 | 25 | N_CTX=$7 26 | BUDGET=$8 27 | COST=$9 28 | TYPE=${10} 29 | EMBED=${11} 30 | HID=${12} 31 | K=5 32 | EXTRA=${@:13} 33 | echo $EXTRA 34 | 35 | NAME="acfid-${SIZE}-${DATA}_scheduler_${NOW}" 36 | 37 | python3 FiD/train_ac_scheduler.py \ 38 | --model_path $MODEL \ 39 | --train_data_path $TRIVIA_DATA/trivia_dpr_train.json \ 40 | --dev_data_path $TRIVIA_DATA/trivia_dpr_dev.json \ 41 | --dev_data_size 2000 \ 42 | --model_size $SIZE \ 43 | --per_gpu_batch_size $BSZ \ 44 | --gradient_accumulation_steps $ACC \ 45 | --n_context $N_CTX \ 46 | --name "${NAME}" \ 47 | --checkpoint_dir $CKPT \ 48 | --lr $LR \ 49 | --log_freq 10 \ 50 | --eval_freq 100 \ 51 | --save_freq 100 \ 52 | --total_step 5000 \ 53 | --is_master \ 54 | --freeze_fid_params \ 55 | --freeze_has_answer_heads \ 56 | --scheduler_type $TYPE \ 57 | --scheduler_n_context $N_CTX \ 58 | --scheduler_embed_size $EMBED \ 59 | --scheduler_hidden_size $HID \ 60 | --budget $BUDGET \ 61 | --num_passages_retained $K \ 62 | --step_cost $COST \ 63 | --discount $DISC \ 64 | --use_rl_loss \ 65 | $EXTRA 66 | # --fp16 67 | # --checkpointing_encoder 68 | 69 | #rm -fr $TRIVIA_DATA 70 | -------------------------------------------------------------------------------- /scripts/train_ac_trivia.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | TRIVIA_DATA="${HOME}/data/preprocessed_data/trivia" 6 | 7 | ## Copy the entire data folder to avoid corrupting the original data files (an issue in Colab) 8 | #TMP_DATA=$(mktemp -d -t data-XXXXXXXXXX) 9 | #cp -r $TRIVIA_DATA/* $TMP_DATA 10 | #echo "Finished copying data from ${TRIVIA_DATA} to ${TMP_DATA}" 11 | #TRIVIA_DATA=$TMP_DATA 12 | ##md5sum $TRIVIA_DATA/*
13 | 14 | DATA="trivia" 15 | N_CTX=$1 16 | POOL=$2 17 | SIZE=$3 18 | BSZ=$4 19 | ACC=$5 20 | 21 | NOW=$(date '+%Y%m%d-%H-%M-%S') 22 | CKPT="${HOME}/checkpoints" 23 | MODEL="${HOME}/pretrained_models/triviaqa_${SIZE}_dpr/" 24 | 25 | NAME="acfid-${SIZE}-${DATA}_${POOL}_${NOW}" 26 | 27 | python FiD/train_ac.py \ 28 | --model_path $MODEL \ 29 | --train_data_path $TRIVIA_DATA/trivia_dpr_train.json \ 30 | --dev_data_path $TRIVIA_DATA/trivia_dpr_dev.json \ 31 | --dev_data_size 2000 \ 32 | --model_size $SIZE \ 33 | --per_gpu_batch_size $BSZ \ 34 | --gradient_accumulation_steps $ACC \ 35 | --n_context $N_CTX \ 36 | --name "${NAME}" \ 37 | --checkpoint_dir $CKPT \ 38 | --lr 1e-4 \ 39 | --log_freq 50 \ 40 | --eval_freq 5000 \ 41 | --save_freq 5000 \ 42 | --total_step 20000 \ 43 | --is_master \ 44 | --freeze_fid_params \ 45 | --has_answer_pool_type $POOL 46 | # --fp16 47 | # --checkpointing_encoder 48 | 49 | #rm -fr $TRIVIA_DATA 50 | -------------------------------------------------------------------------------- /scripts/train_nq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | HOME=$(pwd) 4 | 5 | NQ_DATA="${HOME}/data/preprocessed_data/nq" 6 | 7 | # Copy the entire data folder to avoid corrupting the original data files (an issue in Colab) 8 | TMP_DATA=$(mktemp -d -t data-XXXXXXXXXX) 9 | cp -r $NQ_DATA/* $TMP_DATA 10 | echo "Finished copying data from ${NQ_DATA} to ${TMP_DATA}" 11 | NQ_DATA=$TMP_DATA 12 | md5sum $NQ_DATA/* 13 | 14 | DATA="nq" 15 | SIZE="base" 16 | NOW=$(date '+%Y%m%d-%H-%M-%S') 17 | NAME="fid-${SIZE}-${DATA}_${NOW}" 18 | CKPT="${HOME}/checkpoints/${NAME}" 19 | 20 | python FiD/train.py \ 21 | --train_data_path $NQ_DATA/nq_dpr_train.json \ 22 | --dev_data_path $NQ_DATA/nq_dpr_dev.json \ 23 | --model_size $SIZE \ 24 | --per_gpu_batch_size 2 \ 25 | --n_context 10 \ 26 | --name "${NAME}" \ 27 | --checkpoint_dir $CKPT \ 28 | --lr 1e-4 \ 29 | --eval_freq 1000 \ 30 | --eval_print_freq 1000 \ 31 | --save_freq 1000 \ 32 | --total_step 50000 \ 33 | --is_master 34 | # --checkpointing_encoder 35 | 36 | rm -fr $NQ_DATA 37 | --------------------------------------------------------------------------------