├── img ├── animation.gif └── affiliation.jpeg ├── dpr_scale ├── conf │ ├── task │ │ ├── transform │ │ │ ├── bert_transform_default.yaml │ │ │ └── hf_transform.yaml │ │ ├── optim │ │ │ ├── madgrad.yaml │ │ │ ├── lamb.yaml │ │ │ └── adamw.yaml │ │ ├── lm.yaml │ │ ├── context_encoder_cfg │ │ │ └── hf_encoder.yaml │ │ └── query_encoder_cfg │ │ │ └── hf_encoder.yaml │ ├── datamodule │ │ ├── generate_query_emb.yaml │ │ ├── generate.yaml │ │ └── default.yaml │ ├── trainer │ │ ├── gpu_1_host.yaml │ │ └── slurm.yaml │ ├── checkpoint_callback │ │ ├── periodic_checkpoint.yaml │ │ └── default.yaml │ ├── __init__.py │ ├── lm.yaml │ └── config.py ├── __init__.py ├── optim │ ├── __init__.py │ └── madgrad.py ├── task │ ├── __init__.py │ └── all_gather.py ├── datamodule │ ├── __init__.py │ ├── corpus.py │ ├── lm.py │ └── utils.py ├── models │ ├── __init__.py │ └── hf_encoder.py ├── transforms │ ├── __init__.py │ └── lm_transform.py ├── checkpoint_callback │ ├── __init__.py │ └── periodic_checkpoint.py ├── generate_lm_embeddings.py ├── main.py └── utils │ └── utils.py ├── .gitignore ├── requirements.txt ├── task ├── utils_eval.py ├── task.py ├── create_lama_uhn.py └── load_data.py ├── CONTRIBUTING.md ├── scripts ├── download_corpus.sh ├── train_debug.sh ├── train.sh ├── save_embeddings.sh ├── download_data.sh ├── util_clm.py ├── demo.py ├── prompt.py ├── create_table.py └── clm_prompt.py ├── preprocess ├── mask_spans.py ├── utils.py ├── concat_files.py ├── process_cc_news.py ├── process_wiki.py └── utils_span.py ├── config └── roberta_stopwords.txt ├── npm ├── model.py ├── searcher.py ├── npm_single.py └── npm.py ├── CODE_OF_CONDUCT.md ├── train.md └── README.md /img/animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/NPM/HEAD/img/animation.gif -------------------------------------------------------------------------------- /img/affiliation.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/NPM/HEAD/img/affiliation.jpeg -------------------------------------------------------------------------------- /dpr_scale/conf/task/transform/bert_transform_default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: dpr_scale.transforms.hf_bert.BertTransform 3 | -------------------------------------------------------------------------------- /dpr_scale/conf/task/transform/hf_transform.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: dpr_scale.transforms.hf_transform.HFTransform 3 | max_seq_len: 256 4 | -------------------------------------------------------------------------------- /dpr_scale/conf/task/optim/madgrad.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: "madgrad.MADGRAD" 3 | lr: 1.0e-03 4 | eps: 1.0e-06 5 | weight_decay: 0 6 | momentum: 0.9 7 | -------------------------------------------------------------------------------- /dpr_scale/conf/task/lm.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: dpr_scale.task.mlm_task.MaskedLanguageModelingTask 3 | warmup_steps: null 4 | pretrained_checkpoint_path: 5 | -------------------------------------------------------------------------------- /dpr_scale/conf/task/optim/lamb.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: "torch_optimizer.Lamb" 3 | lr: 1.0e-05 4 | betas: 5 | - 0.9 6 | - 0.999 7 | eps: 1.0e-08 8 | weight_decay: 0 9 | -------------------------------------------------------------------------------- /dpr_scale/conf/datamodule/generate_query_emb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | datamodule: 3 | _target_: dpr_scale.datamodule.dpr.DenseRetrieverQueriesDataModule 4 | test_batch_size: 128 5 | -------------------------------------------------------------------------------- /dpr_scale/conf/trainer/gpu_1_host.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | gpus: 8 3 | max_epochs: 40 4 | num_nodes: 1 5 | num_sanity_val_steps: 0 6 | log_every_n_steps: 10 7 | gradient_clip_val: 2.0 8 | -------------------------------------------------------------------------------- /dpr_scale/conf/task/context_encoder_cfg/hf_encoder.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: dpr_scale.models.hf_encoder.Encoder 3 | model_path: roberta-large 4 | initialize: true 5 | dropout: 0.1 6 | -------------------------------------------------------------------------------- /dpr_scale/conf/task/query_encoder_cfg/hf_encoder.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: dpr_scale.models.hf_encoder.Encoder 3 | model_path: roberta-large 4 | initialize: true 5 | dropout: 0.1 6 | -------------------------------------------------------------------------------- /dpr_scale/conf/task/optim/adamw.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: "torch.optim.AdamW" 3 | lr: 1.0e-03 4 | betas: 5 | - 0.9 6 | - 0.999 7 | eps: 1.0e-08 8 | weight_decay: 0 9 | amsgrad: false 10 | -------------------------------------------------------------------------------- /dpr_scale/conf/checkpoint_callback/periodic_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: dpr_scale.checkpoint_callback.periodic_checkpoint.PeriodicCheckpoint 3 | every: 100 4 | save_last: true 5 | verbose: true 6 | -------------------------------------------------------------------------------- /dpr_scale/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dpr_scale/conf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dpr_scale/optim/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dpr_scale/task/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dpr_scale/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dpr_scale/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dpr_scale/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /dpr_scale/checkpoint_callback/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /dpr_scale/conf/datamodule/generate.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | datamodule: 3 | _target_: dpr_scale.datamodule.dpr.DenseRetrieverPassagesDataModule 4 | test_path: "/private/home/vladk/data/wikipedia/wiki_passages/psgs_w100.tsv" 5 | test_batch_size: 128 6 | use_title: True 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.pyc 3 | *outputs/* 4 | *multirun* 5 | *vscode* 6 | Makefile 7 | *tmp* 8 | *.html 9 | *.out 10 | *.err 11 | *.log 12 | *.json 13 | *.npy 14 | my* 15 | task_data 16 | core 17 | data 18 | save 19 | corpus 20 | train_corpus 21 | deleted_files 22 | 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.1.0.dev3 2 | hydra-submitit-launcher==1.0.1 3 | pyyaml 4 | pathos 5 | torch==1.12.0 6 | pytorch-lightning==1.6.4 7 | fairscale==0.4.6 8 | transformers==4.21.3 9 | datasets==2.3.2 10 | tqdm 11 | prettytable 12 | pyserini 13 | faiss-gpu==1.7.2 14 | gdown 15 | -------------------------------------------------------------------------------- /dpr_scale/conf/checkpoint_callback/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | monitor: valid_loss 4 | mode: min # this means that model_selection_metric should be maximized 5 | save_last: true 6 | verbose: true 7 | filename: checkpoint_best 8 | save_top_k: 3 9 | -------------------------------------------------------------------------------- /dpr_scale/conf/datamodule/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | _target_: dpr_scale.datamodule.dpr.DenseRetrieverJsonlDataModule 3 | train_path: "/private/home/barlaso/repos/DPR/data/retriever/nq-train.jsonl" 4 | val_path: "/private/home/barlaso/repos/DPR/data/retriever/nq-dev.jsonl" 5 | test_path: "/private/home/barlaso/repos/DPR/data/retriever/nq-dev.jsonl" 6 | batch_size: 2 7 | -------------------------------------------------------------------------------- /dpr_scale/conf/trainer/slurm.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /hydra/launcher: submitit_slurm 4 | 5 | trainer: 6 | gpus: 8 7 | num_nodes: 1 8 | max_epochs: 25 9 | num_sanity_val_steps: 0 10 | log_every_n_steps: 10 11 | gradient_clip_val: 2.0 12 | accumulate_grad_batches: 1 13 | strategy: ddp_sharded 14 | precision: 16 15 | 16 | hydra: 17 | launcher: 18 | gpus_per_node: ${trainer.gpus} 19 | tasks_per_node: ${trainer.gpus} 20 | nodes: ${trainer.num_nodes} 21 | timeout_min: 4260 22 | cpus_per_task: 10 23 | mem_gb: 256 24 | constraint: volta32gb 25 | partition: learnlab 26 | sweep: 27 | dir: /checkpoint/${env:USER}/hydra_outputs/${hydra.launcher.name}/${now:%Y-%m-%d-%H%M%S} 28 | -------------------------------------------------------------------------------- /task/utils_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import re 7 | import string 8 | 9 | def normalize_answer(s): 10 | """Lower text and remove punctuation, articles and extra whitespace.""" 11 | def remove_articles(text): 12 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 13 | return re.sub(regex, ' ', text) 14 | def white_space_fix(text): 15 | return ' '.join(text.split()) 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | def lower(text): 20 | return text.lower() 21 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 22 | 23 | -------------------------------------------------------------------------------- /dpr_scale/conf/lm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - config 3 | - override task: lm 4 | - override checkpoint_callback: periodic_checkpoint 5 | 6 | 7 | task: 8 | optim: 9 | lr: 2e-4 10 | warmup_steps: 1237 11 | 12 | datamodule: 13 | _target_: dpr_scale.datamodule.lm.LanguageModelingJsonlDataModule 14 | train_path: cc_net_gpt_tokenized/en_head_train_v0.1_sample0.3_shard*.jsonl+enwiki_gpt_tokenized/en_head_train.jsonl 15 | val_path: cc_net_gpt_tokenized/en_head_debug.jsonl 16 | test_path: cc_net_gpt_tokenized/en_head_debug.jsonl 17 | batch_size: 2 18 | 19 | trainer: 20 | gpus: 8 21 | num_nodes: 1 22 | max_epochs: 2 23 | num_sanity_val_steps: 0 24 | log_every_n_steps: 10 25 | gradient_clip_val: 2.0 26 | precision: 16 27 | strategy: ddp_sharded 28 | limit_val_batches: 0 29 | 30 | checkpoint_callback: 31 | save_weights_only: true 32 | every: 10000 33 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to NPM 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | TBD 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | ## Coding Style 29 | * 2 spaces for indentation rather than tabs 30 | * 120 character line length 31 | * ... 32 | 33 | ## License 34 | By contributing to NPM toolkit, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /dpr_scale/conf/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # (c) Facebook, Inc. and its affiliates. Confidential and proprietary. 7 | 8 | from dataclasses import dataclass, field 9 | from typing import List, Any 10 | 11 | from hydra.core.config_store import ConfigStore 12 | from omegaconf import MISSING 13 | 14 | defaults = [ 15 | "_self_", 16 | {"task": "dpr"}, 17 | # Model 18 | #{"task/model": "hf_model"}, 19 | {"task/query_encoder_cfg": "hf_encoder"}, 20 | {"task/context_encoder_cfg": "hf_encoder"}, 21 | # Transform 22 | {"task/transform": "hf_transform"}, 23 | # Optim 24 | {"task/optim": "adamw"}, 25 | # Data 26 | {"datamodule": "default"}, 27 | # Trainer 28 | {"trainer": "gpu_1_host"}, 29 | # Trainer callbacks 30 | {"checkpoint_callback": "default"}, 31 | ] 32 | 33 | 34 | @dataclass 35 | class MainConfig: 36 | defaults: List[Any] = field(default_factory=lambda: defaults) 37 | task: Any = MISSING 38 | datamodule: Any = MISSING 39 | trainer: Any = MISSING 40 | test_only: bool = False 41 | checkpoint_callback: Any = MISSING 42 | 43 | cs = ConfigStore.instance() 44 | 45 | cs.store(name="config", node=MainConfig) 46 | -------------------------------------------------------------------------------- /dpr_scale/checkpoint_callback/periodic_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | from pathlib import Path 9 | 10 | from weakref import proxy 11 | 12 | class PeriodicCheckpoint(ModelCheckpoint): 13 | def __init__(self, dirpath=None, filename=None, monitor=None, verbose=False, 14 | save_weights_only=True, save_last=None, every: int = 100): 15 | super().__init__(dirpath=dirpath, filename=filename, 16 | monitor=monitor, verbose=verbose, 17 | save_weights_only=save_weights_only, 18 | save_last=save_last) 19 | self.every = every 20 | 21 | def on_train_batch_end( 22 | self, trainer: pl.Trainer, pl_module: pl.LightningModule, *args, **kwargs 23 | ): 24 | #print ("on_train_batch_end called at step=%d" % pl_module.global_step) 25 | if pl_module.global_rank==0 and pl_module.global_step % self.every == 0: 26 | assert self.dirpath is not None 27 | current = Path(self.dirpath) / f"latest-{pl_module.global_step}.ckpt" 28 | trainer.save_checkpoint(current, self.save_weights_only) 29 | self._last_global_step_saved = trainer.global_step 30 | 31 | print (current) 32 | -------------------------------------------------------------------------------- /scripts/download_corpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
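#
# Usage (inferred from the branches below; not an official synopsis):
#   bash scripts/download_corpus.sh closed      # Wikipedia shard 0 + the closed-task corpora (Shi et al. 2022)
#   bash scripts/download_corpus.sh enwiki      # full English Wikipedia corpus
#   bash scripts/download_corpus.sh new-enwiki  # Wikipedia 2022 corpus
# All archives are downloaded and extracted under ./corpus.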
7 | 8 | mkdir -p corpus 9 | 10 | if [[ $1 == "closed" ]] ; then 11 | # Download Wikipedia corpus (released by us) 12 | 13 | if [[ -f "corpus/enwiki/0.npy" ]] ; then 14 | echo "enwiki-0 already downloaded" 15 | else 16 | wget https://dl.fbaipublicfiles.com/NPM/enwiki-0.tar.gz -O corpus/enwiki-0.tar.gz 17 | tar -xf corpus/enwiki-0.tar.gz -C corpus && rm -f corpus/enwiki-0.tar.gz 18 | 19 | wget https://dl.fbaipublicfiles.com/NPM/CC-BY-SA-4.0 -O corpus/enwiki/LICENSE 20 | 21 | fi 22 | 23 | # Download rest of the corpus data (released by Shi et al. 2022) 24 | wget https://dl.fbaipublicfiles.com/NPM/corpus.tar.gz -O corpus.tar.gz 25 | tar -xf corpus.tar.gz -C corpus && rm -f corpus.tar.gz 26 | 27 | fi 28 | 29 | if [[ $1 == "enwiki" ]] ; then 30 | # Download Wikipedia corpus (released by us) 31 | wget https://dl.fbaipublicfiles.com/NPM/enwiki.tar.gz -O corpus/enwiki.tar.gz 32 | tar -xf corpus/enwiki.tar.gz -C corpus && rm -f corpus/enwiki.tar.gz 33 | 34 | fi 35 | 36 | if [[ $1 == "new-enwiki" ]] ; then 37 | # Download Wikipedia 2022 corpus (released by us) 38 | 39 | wget https://dl.fbaipublicfiles.com/NPM/new-enwiki.tar.gz -O corpus/new-enwiki.tar.gz 40 | tar -xf corpus/new-enwiki.tar.gz -C corpus && rm -f corpus/new-enwiki.tar.gz 41 | wget https://dl.fbaipublicfiles.com/NPM/CC-BY-SA-4.0 -O corpus/new-enwiki/LICENSE 42 | 43 | fi 44 | 45 | -------------------------------------------------------------------------------- /dpr_scale/models/hf_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python3 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | 9 | import torch.nn as nn 10 | from dpr_scale.utils.utils import PathManager 11 | 12 | # @manual=//python/wheel/transformers3:transformers3 13 | from transformers import RobertaForMaskedLM, AutoModelForMaskedLM, AutoConfig 14 | 15 | class Encoder(nn.Module): 16 | def __init__( 17 | self, 18 | model_path: str = "roberta-large", 19 | initialize: bool = True, 20 | dropout: float = 0.1, 21 | num_hidden_layers = None, 22 | hidden_size = None, 23 | vocab_size = None, 24 | ): 25 | super().__init__() 26 | # remove recursive argument which is not supported now 27 | local_model_path = PathManager.get_local_path(model_path) 28 | 29 | cfg = AutoConfig.from_pretrained(local_model_path) 30 | cfg.attention_probs_dropout_prob = dropout 31 | cfg.hidden_dropout_prob = dropout 32 | if num_hidden_layers is not None: 33 | cfg.num_hidden_layers = num_hidden_layers 34 | if hidden_size is not None: 35 | cfg.hidden_size = hidden_size 36 | if vocab_size is not None: 37 | cfg.vocab_size = vocab_size 38 | 39 | if initialize: 40 | self.transformer = AutoModelForMaskedLM.from_pretrained(local_model_path, config=cfg) 41 | print ("Initializing from", local_model_path) 42 | else: 43 | self.transformer = RobertaForMaskedLM(config=cfg) 44 | 45 | def forward(self, tokens): 46 | return self.transformer(**tokens, return_dict=True) 47 | 48 | def add_adapter(self, name, config): 49 | self.transformer.add_adapter(name, config) 50 | 51 | 52 | -------------------------------------------------------------------------------- /scripts/train_debug.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | SAVE_DIR=$1 9 | DO_PHRASE=$2 10 | LR=$3 11 | BS=$4 12 | MR=$5 13 | SPAN=$6 14 | P=$7 15 | 16 | init=true 17 | model_type=roberta-base 18 | nd=0.01 19 | wm=4000 20 | num_nodes=1 21 | gpus=1 22 | clip=2.0 23 | msb=true 24 | cmr=0.0 25 | emp=true # make sure masked tokens have positives 26 | ns=half #false #half # how to select negatives 27 | 28 | SAVE_DIR=${SAVE_DIR}/debug-LR-${LR}_BS-${BS}_MR-${MR} 29 | 30 | if [[ $SPAN == "uniform" ]] ; then 31 | train_path=$(pwd)/train_corpus/enwiki/BS${BS}_shard0.jsonl 32 | else 33 | SAVE_DIR=${SAVE_DIR}_P-${P} 34 | train_path=$(pwd)/train_corpus/enwiki/BS${BS}_shard0_mr${MR}_p${P}.jsonl 35 | if [[ $DO_PHRASE == "true" ]] ; then 36 | SPAN="span-merge" 37 | fi 38 | fi 39 | 40 | echo "$train_path" 41 | 42 | PYTHONPATH=. python dpr_scale/main.py -m \ 43 | --config-name=lm.yaml \ 44 | trainer.num_nodes=${num_nodes} \ 45 | trainer.gpus=${gpus} \ 46 | datamodule.batch_size=1 \ 47 | task.optim.lr=${LR} \ 48 | task.optim.weight_decay=0.01 \ 49 | task.warmup_steps=${wm} \ 50 | task.query_encoder_cfg.initialize=${init} \ 51 | task.query_encoder_cfg.model_path=${model_type} \ 52 | +task.do_phrase=${DO_PHRASE} \ 53 | datamodule.train_path="${train_path}" \ 54 | datamodule.val_path=null \ 55 | datamodule.test_path=null \ 56 | +datamodule.bidirectional=true \ 57 | +datamodule.masking_ratio=${MR} \ 58 | +datamodule.enforce_masking_positives=${emp} \ 59 | +datamodule.masking=${SPAN} \ 60 | +task.task_type=contrastive \ 61 | +task.contrastive_maskout_same_block=${msb} \ 62 | +task.contrastive_negative_selection=${ns} \ 63 | +task.contrastive_context_masking_ratio=${cmr} \ 64 | trainer.max_epochs=8 \ 65 | trainer.gradient_clip_val=${clip} 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /preprocess/mask_spans.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
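#
# Pre-computes span masks for the sharded training corpus: every
# BS{batch_size}_shard*.jsonl under --data_dir that does not yet have a
# companion file is handed to mask_spans() (from utils_span) in parallel,
# which -- as far as can be inferred from find_files() below -- produces a
# *_mr{mr}_p{p}.jsonl version of the shard. Example invocation with the
# repository defaults:
#   python preprocess/mask_spans.py --data_dir train_corpus --mr 0.15 --p 0.5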
6 | import os 7 | import math 8 | import time 9 | import numpy as np 10 | import json 11 | from tqdm import tqdm 12 | from functools import partial 13 | from multiprocessing import Pool 14 | 15 | from utils_span import mask_spans 16 | 17 | def main(): 18 | import argparse 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--data_dir", type=str, default="train_corpus") 21 | parser.add_argument("--mr", type=float, default=0.15) 22 | parser.add_argument("--p", type=float, default=0.5) 23 | parser.add_argument("--batch_size", type=int, default=16) 24 | parser.add_argument("--num_shards", type=int, default=10) 25 | 26 | args = parser.parse_args() 27 | 28 | ext = "_mr{}_p{}.jsonl".format(args.mr, args.p) 29 | 30 | def find_files(out_dir): 31 | if os.path.isdir(out_dir): 32 | return sorted([fn for sub_dir in os.listdir(out_dir) for fn in find_files(os.path.join(out_dir, sub_dir))]) 33 | 34 | fn = out_dir 35 | if fn.split("/")[-1].startswith("BS{}_shard".format(args.batch_size)) and fn.endswith(".jsonl"): 36 | if fn.endswith(ext): 37 | return [] 38 | if os.path.exists(fn.replace(".jsonl", ext)): 39 | return [] 40 | return [fn] 41 | 42 | return [] 43 | 44 | filenames = find_files(args.data_dir) 45 | filenames = [fn for fn in filenames if fn.split(".")[-2][-3] not in ["6", "7", "8"]] 46 | print ("Start span masking for %d files" % len(filenames)) 47 | f = partial(mask_spans, 48 | masking_ratio=args.mr, 49 | p=args.p) 50 | 51 | tot = 0 52 | with Pool(min(len(filenames), 80)) as p: 53 | for _ in p.imap(f, filenames): 54 | tot += 1 55 | 56 | 57 | if __name__=='__main__': 58 | main() 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /preprocess/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from transformers import RobertaTokenizer 3 | tokenizer = RobertaTokenizer.from_pretrained("roberta-large") 4 | mask_id = tokenizer.mask_token_id 5 | 6 | def create_blocks_from_plain_text(sentences, doc_idx, max_seq_length=256): 7 | input_ids = tokenizer._batch_encode_plus(sentences)["input_ids"] 8 | assert type(input_ids)==list 9 | 10 | curr_input_ids_list = [[]] 11 | for tokens in input_ids: 12 | 13 | if mask_id in tokens: 14 | # sometimes, the raw text contains [MASK]. in this case, we skip. 
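# (presumably because a literal mask token in the source text would be
# indistinguishable from the tokens that are masked intentionally later on)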
15 | continue 16 | 17 | if len(tokens) + len(curr_input_ids_list[-1]) <= max_seq_length: 18 | curr_input_ids_list[-1] += tokens 19 | elif len(tokens) <= max_seq_length: 20 | curr_input_ids_list.append(tokens) 21 | else: 22 | while len(tokens) > max_seq_length: 23 | th = max_seq_length-len(curr_input_ids_list[-1]) 24 | curr_input_ids_list[-1] += tokens[:th] 25 | tokens = tokens[th:] 26 | curr_input_ids_list.append([]) 27 | if len(tokens)>0: 28 | curr_input_ids_list[-1] += tokens 29 | 30 | output_lines = [] 31 | n_tokens = [] 32 | for block_idx, _input_ids in enumerate(curr_input_ids_list): 33 | assert 0 Tensor: 19 | ctx.group = group 20 | 21 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] 22 | 23 | torch.distributed.all_gather(gathered_tensor, tensor, group=group) 24 | gathered_tensor = torch.stack(gathered_tensor, dim=0) 25 | 26 | return gathered_tensor 27 | 28 | @staticmethod 29 | def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]: 30 | grad_output = torch.cat(grad_output) 31 | torch.distributed.all_reduce(grad_output, 32 | op=torch.distributed.ReduceOp.SUM, 33 | async_op=False, 34 | group=ctx.group) 35 | 36 | return grad_output[torch.distributed.get_rank()], None 37 | 38 | # from https://github.com/vlkit/vlkit/blob/master/vlkit/ops/distributed.py 39 | class AllGatherGrad2(torch.autograd.Function): 40 | """ 41 | all_gather with gradient back-propagation 42 | """ 43 | @staticmethod 44 | def forward(ctx, tensor_list, tensor): 45 | torch.distributed.all_gather(tensor_list, tensor) 46 | return tuple(tensor_list) 47 | 48 | @staticmethod 49 | def backward(ctx, *grad_list): 50 | grad_list = list(grad_list) 51 | rank = torch.distributed.get_rank() 52 | 53 | #print ("all gather 1") 54 | dist_ops = [ 55 | torch.distributed.reduce(grad_list[i], i, async_op=True) for i in range(torch.distributed.get_world_size()) 56 | ] 57 | 58 | #print ("all gather 2") 59 | for op in dist_ops: 60 | op.wait() 61 | 62 | #print ("all gather 3") 63 | return None, grad_list[rank] 64 | 65 | my_all_gather = AllGatherGrad.apply 66 | my_all_gather2 = AllGatherGrad2.apply 67 | 68 | 69 | -------------------------------------------------------------------------------- /preprocess/concat_files.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
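#
# Merges the many small CC-News shard files (BS{batch_size}_shard*{ext}) under
# <data_dir>/cc_news into --num_shards larger BS{batch_size}_batchshard{i}{ext}
# files, deleting the originals once written. Example invocation (assumed
# typical usage, matching the span-masking defaults):
#   python preprocess/concat_files.py --data_dir train_corpus --mr 0.15 --p 0.5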
6 | import os 7 | import math 8 | import time 9 | import numpy as np 10 | import json 11 | 12 | def concat_files(filenames, out_file): 13 | start_time = time.time() 14 | n_files = 0 15 | n_lines = 0 16 | assert not os.path.exists(out_file) 17 | print ("Starting %s" % out_file) 18 | with open(out_file, "a+") as f_w: 19 | for filename in filenames: 20 | if not os.path.exists(filename): 21 | continue 22 | with open(filename, "r") as f: 23 | for line in f: 24 | f_w.write(line) 25 | n_lines += 1 26 | n_files += 1 27 | print ("Finish saving %d lines at %s from %d files (%dmin)" % ( 28 | n_lines, out_file, n_files, (time.time()-start_time)/60)) 29 | for filename in filenames: 30 | if os.path.exists(filename): 31 | os.remove(filename) 32 | print ("Finish deleting %d files (%dmin)" % (n_files, (time.time()-start_time)/60)) 33 | 34 | def main(): 35 | import argparse 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("--data_dir", type=str, default="train_corpus") 38 | parser.add_argument("--mr", type=float, default=None) 39 | parser.add_argument("--p", type=float, default=None) 40 | 41 | parser.add_argument("--batch_size", type=int, default=16) 42 | parser.add_argument("--num_shards", type=int, default=10) 43 | 44 | args = parser.parse_args() 45 | 46 | if args.mr is None and args.p is None: 47 | ext = ".jsonl" 48 | else: 49 | ext = "_mr{}_p{}.jsonl".format(args.mr, args.p) 50 | 51 | def find_files(out_dir): 52 | if os.path.isdir(out_dir): 53 | return sorted([fn for sub_dir in os.listdir(out_dir) for fn in find_files(os.path.join(out_dir, sub_dir))]) 54 | 55 | fn = out_dir 56 | if fn.split("/")[-1].startswith("BS{}_shard".format(args.batch_size)) and fn.endswith(ext): 57 | return [fn] 58 | return [] 59 | 60 | filenames = find_files(os.path.join(args.data_dir, "cc_news")) 61 | n_files_per_shard = math.ceil(len(filenames) / args.num_shards) 62 | for batch_idx in range(args.num_shards): 63 | curr_filenames = filenames[batch_idx*n_files_per_shard:(batch_idx+1)*n_files_per_shard] 64 | concat_files(curr_filenames, 65 | os.path.join(args.data_dir, "cc_news", 66 | "BS{}_batchshard{}".format(args.batch_size, batch_idx) + ext)) 67 | print ("Finish %d batches" % args.num_shards) 68 | 69 | 70 | if __name__=='__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
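#
# Positional arguments (inferred from the assignments below):
#   $1 SAVE_DIR  $2 DO_PHRASE  $3 LR  $4 BS  $5 MR  $6 SPAN  $7 P
# Example with hypothetical values:
#   bash scripts/train.sh save/npm true 1e-4 16 0.15 span 0.5
# Launches multi-node training (4 nodes x 8 GPUs, roberta-large) through
# Hydra's SLURM launcher using the lm.yaml config.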
7 | 8 | 9 | SAVE_DIR=$1 10 | DO_PHRASE=$2 11 | LR=$3 12 | BS=$4 13 | MR=$5 14 | SPAN=$6 15 | P=$7 16 | 17 | init=true 18 | wd=0.01 19 | wm=4000 20 | model_type=roberta-large 21 | num_nodes=4 22 | gpus=8 23 | clip=2.0 24 | msb=true 25 | cmr=0.0 26 | emp=true # make sure masked tokens have positives 27 | ns=half #false #half # how to select negatives 28 | 29 | SAVE_DIR=${SAVE_DIR}/LR-${LR}_BS-${BS}_MR-${MR} 30 | 31 | if [[ $SPAN == "uniform" ]] ; then 32 | train_path=$(pwd)/train_corpus/cc_news/BS${BS}_batchshard0.jsonl 33 | train_path=${train_path}+$(pwd)/train_corpus/enwiki/BS${BS}_shard0.jsonl 34 | for i in {1..9} ; do \ 35 | train_path=${train_path}+$(pwd)/train_corpus/cc_news/BS${BS}_batchshard0.jsonl 36 | train_path=${train_path}+$(pwd)/train_corpus/enwiki/BS${BS}_shard0.jsonl 37 | done 38 | 39 | else 40 | SAVE_DIR=${SAVE_DIR}_P-${P} 41 | train_path=$(pwd)/train_corpus/cc_news/BS${BS}_batchshard0_mr${MR}_p${P}.jsonl 42 | train_path=${train_path}+$(pwd)/train_corpus/enwiki/BS${BS}_shard0_mr${MR}_p${P}.jsonl 43 | for i in {1..9} ; do \ 44 | train_path=${train_path}+$(pwd)/train_corpus/cc_news/BS${BS}_batchshard0_mr${MR}_p${P}.jsonl 45 | train_path=${train_path}+$(pwd)/train_corpus/enwiki/BS${BS}_shard0_mr${MR}_p${P}.jsonl 46 | done 47 | if [[ $DO_PHRASE == "true" ]] ; then 48 | SPAN="span-merge" 49 | fi 50 | fi 51 | 52 | echo "$train_path" 53 | 54 | HYDRA_FULL_ERROR=1 PYTHONPATH=. python dpr_scale/main.py -m \ 55 | --config-name=lm.yaml \ 56 | trainer.num_nodes=${num_nodes} \ 57 | trainer.gpus=${gpus} \ 58 | datamodule.batch_size=1 \ 59 | task.optim.lr=${LR} \ 60 | task.optim.weight_decay=0.01 \ 61 | task.warmup_steps=${wm} \ 62 | task.query_encoder_cfg.initialize=${init} \ 63 | task.query_encoder_cfg.model_path=${model_type} \ 64 | +task.do_phrase=${DO_PHRASE} \ 65 | datamodule.train_path="${train_path}" \ 66 | datamodule.val_path=null \ 67 | datamodule.test_path=null \ 68 | +datamodule.bidirectional=true \ 69 | +datamodule.masking_ratio=${MR} \ 70 | +datamodule.enforce_masking_positives=${emp} \ 71 | +datamodule.masking=${SPAN} \ 72 | +task.task_type=contrastive \ 73 | +task.contrastive_maskout_same_block=${msb} \ 74 | +task.contrastive_negative_selection=${ns} \ 75 | +task.contrastive_context_masking_ratio=${cmr} \ 76 | trainer.max_epochs=8 \ 77 | trainer.gradient_clip_val=${clip} \ 78 | trainer=slurm \ 79 | hydra.launcher.name=${SAVE_DIR} \ 80 | hydra.sweep.dir=${SAVE_DIR} 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /npm/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import os 7 | import torch 8 | from transformers import AutoModelForMaskedLM, AutoTokenizer 9 | 10 | class SingleModel(object): 11 | 12 | def __init__(self, checkpoint_path): 13 | if checkpoint_path in ["npm", "npm-single"]: 14 | checkpoint_path = "facebook/" + checkpoint_path 15 | 16 | is_registered = checkpoint_path.startswith("roberta-") or checkpoint_path.startswith("facebook/") 17 | 18 | if is_registered: 19 | self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path) 20 | self.model = AutoModelForMaskedLM.from_pretrained(checkpoint_path) 21 | print ("Loaded from HF Hub:", checkpoint_path) 22 | else: 23 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/npm") 24 | # loading from trained checkpoint 25 | state_dict = torch.load(checkpoint_path) 26 | if "state_dict" in state_dict: 27 | state_dict = state_dict["state_dict"] 28 | encoder_state_dict = {".".join(k.split(".")[2:]): v for k, v in state_dict.items()} 29 | self.model = AutoModelForMaskedLM.from_pretrained("facebook/npm", state_dict=encoder_state_dict) 30 | print ("Loaded from a local checkpoint:", checkpoint_path) 31 | 32 | self.model.cuda() 33 | self.model.eval() 34 | self.cnt = 0 35 | 36 | def forward(self, input_ids, idx): 37 | if self.cnt < 3: 38 | print (self.tokenizer.decode(input_ids[0])) 39 | 40 | outputs = self.model(input_ids, output_hidden_states=True, return_dict=True) 41 | logits = outputs.logits[0, idx, :] 42 | query = outputs["hidden_states"][-1][:, idx, :] 43 | self.cnt += 1 44 | 45 | return logits, query 46 | 47 | class Model(SingleModel): 48 | 49 | def forward(self, input_ids, idx): 50 | assert len(input_ids)==1 51 | assert input_ids[0, idx]==self.tokenizer.mask_token_id 52 | new_input_ids = torch.cat([input_ids[:, :idx+1], input_ids[:, idx:]], -1) 53 | assert len(new_input_ids[0])==len(input_ids[0])+1 54 | assert new_input_ids[0, idx]==new_input_ids[0, idx+1]==self.tokenizer.mask_token_id 55 | 56 | if self.cnt < 3: 57 | print (self.tokenizer.decode(new_input_ids[0])) 58 | 59 | outputs = self.model(new_input_ids, output_hidden_states=True, return_dict=True) 60 | logits = outputs.logits[0, idx, :] 61 | start_query = outputs["hidden_states"][-1][:, idx, :] 62 | end_query = outputs["hidden_states"][-1][:, idx+1, :] 63 | self.cnt += 1 64 | 65 | return logits, (start_query, end_query) 66 | 67 | 68 | -------------------------------------------------------------------------------- /dpr_scale/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # (c) Facebook, Inc. and its affiliates. Confidential and proprietary. 
7 | import os 8 | import hydra 9 | from dpr_scale.conf.config import MainConfig 10 | 11 | from omegaconf import OmegaConf 12 | from pytorch_lightning.callbacks import LearningRateMonitor 13 | from pytorch_lightning.trainer import Trainer 14 | 15 | 16 | """ 17 | Sample commands: 18 | Default: $ buck run //deeplearning/projects/dpr-scale:main 19 | 20 | For debugging Hydra: 21 | $ HYDRA_FULL_ERROR=1 buck run //deeplearning/projects/dpr-scale:main -- --info 22 | """ 23 | 24 | @hydra.main(config_path="conf", config_name="config") 25 | def main(cfg: MainConfig): 26 | print(OmegaConf.to_yaml(cfg)) 27 | #cnt_files(cfg.datamodule.train_path) 28 | #return 29 | 30 | if cfg.test_only: 31 | 32 | def do_test(ckpt_path): 33 | assert os.path.exists(ckpt_path), ckpt_path 34 | cfg.task.pretrained_checkpoint_path = ckpt_path 35 | 36 | task = hydra.utils.instantiate(cfg.task, _recursive_=False) 37 | transform = hydra.utils.instantiate(cfg.task.transform) 38 | datamodule = hydra.utils.instantiate(cfg.datamodule) #, transform=transform) 39 | checkpoint_callback = hydra.utils.instantiate(cfg.checkpoint_callback) 40 | lr_monitor = LearningRateMonitor(logging_interval='step') 41 | trainer = Trainer(**cfg.trainer, callbacks=[checkpoint_callback, lr_monitor]) 42 | 43 | trainer.test( 44 | model=task, 45 | ckpt_path=ckpt_path, 46 | verbose=True, 47 | datamodule=datamodule, 48 | ) 49 | 50 | ckpt_path = cfg.task.pretrained_checkpoint_path 51 | if "+" in ckpt_path: 52 | segments = ckpt_path.split("+") 53 | assert len(segments)>=3 54 | for ckpt_idx in range(1, len(segments)-1): 55 | ckpt_path = segments[0] + segments[ckpt_idx] + segments[-1] 56 | do_test(ckpt_path) 57 | else: 58 | do_test(ckpt_path) 59 | 60 | else: 61 | task = hydra.utils.instantiate(cfg.task, _recursive_=False) 62 | 63 | #assert cfg.task.model.model_path == cfg.task.transform.model_path 64 | transform = None #hydra.utils.instantiate(cfg.task.transform) 65 | datamodule = hydra.utils.instantiate(cfg.datamodule) #, transform=transform) 66 | checkpoint_callback = hydra.utils.instantiate(cfg.checkpoint_callback) 67 | lr_monitor = LearningRateMonitor(logging_interval='step') 68 | trainer = Trainer(**cfg.trainer, callbacks=[checkpoint_callback, lr_monitor]) 69 | 70 | trainer.fit(task, datamodule=datamodule) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /dpr_scale/datamodule/corpus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python3 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 8 | 9 | import json 10 | import mmap 11 | 12 | import torch 13 | from dpr_scale.transforms.lm_transform import LanguageModelingTransform 14 | from dpr_scale.utils.utils import ( 15 | ContiguousDistributedSampler, 16 | ContiguousDistributedSamplerForTest, 17 | PathManager, 18 | maybe_add_title, 19 | ) 20 | from pytorch_lightning import LightningDataModule 21 | 22 | from dpr_scale.datamodule.utils import CorpusDataset 23 | 24 | class CorpusDataModule(LightningDataModule): 25 | def __init__( 26 | self, 27 | train_path: str, 28 | val_path: str, 29 | test_path: str, 30 | batch_size: int = 2, 31 | val_batch_size: int = 0, # defaults to batch_size 32 | test_batch_size: int = 0, # defaults to val_batch_size 33 | ): 34 | super().__init__() 35 | self.batch_size = batch_size 36 | self.num_workers = 1 37 | self.val_batch_size = val_batch_size if val_batch_size else batch_size 38 | self.test_batch_size = ( 39 | test_batch_size if test_batch_size else self.val_batch_size 40 | ) 41 | 42 | self.datasets = { 43 | "train": CorpusDataset(train_path), 44 | "valid": CorpusDataset(val_path), 45 | "test": CorpusDataset(test_path), 46 | } 47 | 48 | def train_dataloader(self): 49 | return torch.utils.data.DataLoader( 50 | self.datasets["train"], 51 | shuffle=False, 52 | batch_size=self.batch_size, 53 | num_workers=self.num_workers, 54 | collate_fn=self.collate_eval, 55 | ) 56 | 57 | def val_dataloader(self): 58 | return torch.utils.data.DataLoader( 59 | self.datasets["valid"], 60 | shuffle=False, 61 | batch_size=self.val_batch_size, 62 | num_workers=self.num_workers, 63 | collate_fn=self.collate_eval, 64 | ) 65 | 66 | def test_dataloader(self): 67 | return torch.utils.data.DataLoader( 68 | self.datasets["test"], 69 | shuffle=False, 70 | batch_size=self.test_batch_size, 71 | num_workers=self.num_workers, 72 | collate_fn=self.collate_test, 73 | ) 74 | 75 | def collate_eval(self, batch): 76 | return self.collate(batch, "eval") 77 | 78 | def collate_test(self, batch): 79 | return self.collate(batch, "test") 80 | 81 | def collate_train(self, batch): 82 | return self.collate(batch, "train") 83 | 84 | def collate(self, batch, stage): 85 | return {"input_ids": torch.LongTensor([b["input_ids"] for b in batch]), 86 | "attention_mask": torch.LongTensor([b["attention_mask"] for b in batch]), 87 | "is_valid": torch.LongTensor([b["is_valid"] for b in batch]), 88 | } 89 | 90 | 91 | -------------------------------------------------------------------------------- /scripts/save_embeddings.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | model_name=$1 9 | corpus=$2 10 | open=$3 11 | bs=$4 12 | 13 | 14 | out=$(pwd)/save/${model_name} 15 | ctx_embeddings_dir=${out}/dstore/${corpus} 16 | 17 | if [[ $open == "true" ]] ; then 18 | if [[ $corpus == "enwiki-"* ]] ; then 19 | 20 | arr=(${corpus//-/ }) 21 | data_path=$(pwd)/corpus/enwiki/${arr[1]}.npy 22 | if [[ -f "${ctx_embeddings_dir}/embeddings_wo_stopwords.float16.npy" ]] ; then 23 | echo "embeddings already saved" 24 | else 25 | PYTHONPATH=. 
python \ 26 | dpr_scale/generate_lm_embeddings.py -m \ 27 | --config-name lm.yaml \ 28 | datamodule._target_=dpr_scale.datamodule.corpus.CorpusDataModule \ 29 | datamodule.batch_size=$bs \ 30 | datamodule.train_path=null \ 31 | datamodule.val_path=null \ 32 | datamodule.test_path=${data_path} \ 33 | trainer.num_nodes=1 \ 34 | trainer.precision=32 \ 35 | trainer.gpus=1 \ 36 | task.query_encoder_cfg.model_path=facebook/${model_name} \ 37 | +task.ctx_embeddings_dir=${ctx_embeddings_dir} \ 38 | +task.stopwords_dir=$(pwd)/config \ 39 | +task.task_type="contrastive" \ 40 | +task.remove_stopwords=true \ 41 | trainer=slurm \ 42 | hydra.launcher.name=npm-${corpus} \ 43 | hydra.sweep.dir=${out} \ 44 | hydra.launcher.partition=devlab \ 45 | hydra.launcher.cpus_per_task=5 46 | fi 47 | else 48 | echo "corpus has to be enwiki-* (currently, ${corpus})" 49 | exit 50 | fi 51 | 52 | else 53 | if [[ $corpus == "enwiki-0" ]] ; then 54 | data_path=$(pwd)/corpus/enwiki/0.npy 55 | else 56 | data_path=$(pwd)/corpus/${corpus}/text.npy 57 | fi 58 | 59 | if [[ -f "${ctx_embeddings_dir}/embeddings.float16.npy" ]] ; then 60 | echo "embeddings already saved" 61 | else 62 | PYTHONPATH=. python \ 63 | dpr_scale/generate_lm_embeddings.py -m \ 64 | --config-name lm.yaml \ 65 | datamodule._target_=dpr_scale.datamodule.corpus.CorpusDataModule \ 66 | datamodule.batch_size=$bs \ 67 | datamodule.train_path=null \ 68 | datamodule.val_path=null \ 69 | datamodule.test_path=${data_path} \ 70 | trainer.num_nodes=1 \ 71 | trainer.precision=32 \ 72 | trainer.gpus=1 \ 73 | task.query_encoder_cfg.model_path=facebook/${model_name} \ 74 | +task.ctx_embeddings_dir=${ctx_embeddings_dir} \ 75 | +task.task_type="contrastive" \ 76 | +task.remove_stopwords=false \ 77 | #trainer=slurm \ 78 | #hydra.launcher.name=npm-${corpus} \ 79 | #hydra.sweep.dir=${out} \ 80 | #hydra.launcher.partition=devlab \ 81 | #hydra.launcher.cpus_per_task=5 82 | fi 83 | 84 | fi 85 | -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | mkdir -p data 9 | 10 | if [[ $1 == "closed" ]] ; then 11 | # Download evaluation datasets 12 | 13 | # SST-2 (script provided by Holtzman, West et al. 2022) 14 | 15 | mkdir data/sst2 16 | wget https://raw.githubusercontent.com/prrao87/fine-grained-sentiment/master/data/sst/sst_dev.txt -O data/sst2/dev.tsv 17 | wget https://raw.githubusercontent.com/prrao87/fine-grained-sentiment/master/data/sst/sst_test.txt -O data/sst2/test.tsv 18 | wget https://raw.githubusercontent.com/prrao87/fine-grained-sentiment/master/data/sst/sst_train.txt -O data/sst2/train.tsv 19 | 20 | # AGN (data provided by Zhao et al. 2021) 21 | 22 | mkdir data/agn/ 23 | wget https://github.com/tonyzhaozh/few-shot-learning/raw/main/data/agnews/train.csv -O data/agn/dev.csv 24 | 25 | # MR and CR (data provided by Gao et al. 2021) 26 | 27 | wget https://nlp.cs.princeton.edu/projects/lm-bff/datasets.tar 28 | tar xvf datasets.tar 29 | mv original/mr data/mr 30 | mv original/cr data/cr 31 | mv original/subj data/subj 32 | rm -f datasets.tar 33 | rm -rf original 34 | 35 | # Download fuzzy verbalizers (released by Shi et al. 
2022) 36 | if [[ -d "data/fuzzy_verbalizers" ]] ; then 37 | echo "fuzzy_verbalizers already downloaded" 38 | else 39 | gdown 1aRlqMnNyJbgkMm6vgpfEGonukuX_eZEk -O data/ 40 | unzip data/fuzzy_verbalizers.zip -d data/ && rm -f data/fuzzy_verbalizers.zip 41 | fi 42 | 43 | 44 | fi 45 | 46 | if [[ $1 == "open" ]] ; then 47 | # LAMA (data provided by Petroni et al. 2019 and Zhong et al. 2022) 48 | 49 | wget https://dl.fbaipublicfiles.com/LAMA/data.zip 50 | unzip data.zip -d data/lama 51 | rm data.zip 52 | 53 | wget https://nlp.cs.princeton.edu/projects/optiprompt/data.tar.gz 54 | tar -xf data.tar.gz -C data/lama 55 | rm -f data.tar.gz 56 | 57 | python task/create_lama_uhn.py --srcdir data/lama/data/Google_RE 58 | 59 | rm -rf data/lama/data/ConceptNet 60 | rm -rf data/lama/data/Squad 61 | rm -rf data/lama/data/autoprompt_data 62 | rm -rf data/lama/data/cmp_lms_data 63 | 64 | # NQ (data provided by Lee et al. 2019 (re-formatted by Min et al. 2020)) 65 | 66 | wget -P data/nq https://nlp.cs.washington.edu/ambigqa/data/nqopen-test.json 67 | wget -P data/nq https://nlp.cs.washington.edu/ambigqa/data/test_id2answers.json 68 | 69 | # KAMEL (data provided by Kalo and Fichtel, 2022) 70 | 71 | wget -O kamel.zip https://github.com/JanKalo/KAMEL/blob/master/data/kamel.zip?raw=true 72 | unzip kamel.zip -d data/kamel 73 | wget -P data/kamel https://raw.githubusercontent.com/JanKalo/KAMEL/master/question-templates.csv 74 | 75 | rm -rf kamel.zip 76 | rm -rf data/kamel/__MACOSX 77 | 78 | # Entity translation (released by us) 79 | 80 | wget https://dl.fbaipublicfiles.com/NPM/entity_translation.tar.gz -O data/entity_translation.tar.gz 81 | tar -xf data/entity_translation.tar.gz -C data && rm -f data/entity_translation.tar.gz 82 | 83 | fi 84 | 85 | if [[ $1 == "templama" ]] ; then 86 | # TempLAMA (changed and unchanged) released by us 87 | 88 | wget https://dl.fbaipublicfiles.com/NPM/templama.tar.gz -O data/templama.tar.gz 89 | tar -xf data/templama.tar.gz -C data && rm -f data/templama.tar.gz 90 | 91 | fi 92 | 93 | 94 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /task/task.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
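#
# Task wraps a single evaluation dataset: it loads examples via load_data(),
# attaches fuzzy verbalizers for the classification tasks (sst2/mr/rt/cr/amazon,
# agn, yahoo), and can subsample n_samples examples -- stratified by language
# for entity_translation and by answer n-gram length for the lama-* datasets.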
6 | import os 7 | import numpy as np 8 | from collections import defaultdict 9 | 10 | #from task.data_loaders import * 11 | from task.load_data import load_data, load_fuzzy_verbalizer 12 | 13 | class Task(object): 14 | 15 | def __init__(self, dataname, data_dir, n_samples=0): 16 | examples = load_data(dataname, data_dir) 17 | 18 | if dataname in ["sst2", "mr", "rt", "cr", "amazon"]: 19 | self.label2syn = load_fuzzy_verbalizer(os.path.join(data_dir, "fuzzy_verbalizers/sst2.txt")) 20 | elif dataname=="agn": 21 | self.label2syn = load_fuzzy_verbalizer(os.path.join(data_dir, "fuzzy_verbalizers/agn.txt")) 22 | elif dataname=="yahoo": 23 | self.label2syn = load_fuzzy_verbalizer(os.path.join(data_dir, "fuzzy_verbalizers/yahoo.json")) 24 | else: 25 | self.label2syn = None 26 | 27 | if dataname.startswith("lama-"): 28 | for i, ex in enumerate(examples): 29 | examples[i]["ngram"] = min(4, np.min([len(a) for a in ex["tokenized_answers"]])) 30 | 31 | if n_samples: 32 | np.random.seed(0) 33 | examples_sample = [] 34 | 35 | if dataname=="entity_translation": 36 | examples_dict = defaultdict(list) 37 | for ex in examples: 38 | examples_dict[ex["lang"]].append(ex) 39 | for lang, curr_examples in examples_dict.items(): 40 | indices = np.random.permutation(range(len(curr_examples)))[:n_samples // 3] 41 | examples_sample += [curr_examples[i] for i in indices] 42 | np.random.shuffle(examples_sample) 43 | elif dataname.startswith("lama-"): 44 | examples_dict = defaultdict(list) 45 | for ex in examples: 46 | examples_dict[ex["ngram"]].append(ex) 47 | for n, curr_examples in examples_dict.items(): 48 | indices = np.random.permutation(range(len(curr_examples)))[:n_samples // 3] 49 | examples_sample += [curr_examples[i] for i in indices] 50 | np.random.shuffle(examples_sample) 51 | else: 52 | for i in np.random.permutation(range(len(examples)))[:n_samples]: 53 | examples_sample.append(examples[i]) 54 | 55 | examples = examples_sample 56 | 57 | self.dataname = dataname 58 | self.examples = examples 59 | 60 | if dataname.startswith("lama-"): 61 | self.ngrams = [ex["ngram"] for ex in examples] 62 | elif dataname=="entity_translation": 63 | self.ngrams = [ex["lang"] for ex in examples] 64 | else: 65 | self.ngrams = None 66 | 67 | if dataname.startswith("lama-"): 68 | self.is_question = False 69 | elif dataname in ["kamel", "nq", "triviaqa"]: 70 | self.is_question = True 71 | else: 72 | self.is_question = False 73 | 74 | def __str__(self): 75 | return "Task: " + self.dataname 76 | 77 | def __len__(self): 78 | return len(self.examples) 79 | 80 | if __name__=='__main__': 81 | data_dir = "data" 82 | #for dataname in ["agn", "yahoo", "subj", "rte", "sst2", "mr", "rt", "cr", "amazon"]: 83 | # task = Task(dataname, data_dir) 84 | 85 | for dataname in ["lama-trex", 86 | "lama-google_re", 87 | #"kamel", "triviaqa", "nq" 88 | ]: 89 | task = Task(dataname, data_dir) 90 | assert "input" in task.examples[0] and "answers" in task.examples[0], (dataname, task.examples[0]) 91 | 92 | 93 | -------------------------------------------------------------------------------- /dpr_scale/datamodule/lm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python3 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 8 | 9 | import json 10 | import mmap 11 | 12 | import torch 13 | from dpr_scale.transforms.lm_transform import LanguageModelingTransform 14 | from dpr_scale.utils.utils import ( 15 | ContiguousDistributedSampler, 16 | ContiguousDistributedSamplerForTest, 17 | PathManager, 18 | maybe_add_title, 19 | ) 20 | from pytorch_lightning import LightningDataModule 21 | 22 | from dpr_scale.datamodule.utils import MemoryMappedDataset 23 | 24 | class LanguageModelingJsonlDataModule(LightningDataModule): 25 | def __init__( 26 | self, 27 | train_path: str, 28 | val_path: str, 29 | test_path: str, 30 | batch_size: int = 2, 31 | val_batch_size: int = 0, # defaults to batch_size 32 | test_batch_size: int = 0, # defaults to val_batch_size 33 | bidirectional: bool = True, 34 | masking: str = None, 35 | masking_ratio: float = 0.0, 36 | enforce_masking_positives: bool = False, 37 | num_workers: int = 0, # increasing this bugs out right now 38 | max_cnt: int = -1, 39 | ): 40 | super().__init__() 41 | self.batch_size = batch_size 42 | self.val_batch_size = val_batch_size if val_batch_size else batch_size 43 | self.test_batch_size = ( 44 | test_batch_size if test_batch_size else self.val_batch_size 45 | ) 46 | 47 | _path = train_path if train_path is not None else test_path 48 | 49 | self.dpr_transform = LanguageModelingTransform( 50 | bidirectional=bidirectional, 51 | masking=masking, 52 | masking_ratio=masking_ratio, 53 | enforce_masking_positives=enforce_masking_positives 54 | ) 55 | self.num_workers = num_workers 56 | self.datasets = { 57 | "train": MemoryMappedDataset(train_path, max_cnt=max_cnt), 58 | "valid": MemoryMappedDataset(val_path, max_cnt=max_cnt), 59 | "test": MemoryMappedDataset(test_path, max_cnt=max_cnt), 60 | } 61 | 62 | def train_dataloader(self): 63 | sampler = None 64 | if ( 65 | self.trainer 66 | and hasattr(self.trainer, "world_size") 67 | and self.trainer.world_size > 1 68 | ): 69 | sampler = ContiguousDistributedSampler( 70 | self.datasets["train"], num_replicas_per_node=self.trainer.gpus 71 | ) 72 | 73 | return torch.utils.data.DataLoader( 74 | self.datasets["train"], 75 | batch_size=self.batch_size, 76 | num_workers=self.num_workers, 77 | collate_fn=self.collate_train, 78 | sampler=sampler, 79 | ) 80 | 81 | def val_dataloader(self): 82 | return torch.utils.data.DataLoader( 83 | self.datasets["valid"], 84 | shuffle=False, 85 | batch_size=self.val_batch_size, 86 | num_workers=self.num_workers, 87 | collate_fn=self.collate_eval, 88 | ) 89 | 90 | def test_dataloader(self): 91 | return torch.utils.data.DataLoader( 92 | self.datasets["test"], 93 | shuffle=False, 94 | batch_size=self.test_batch_size, 95 | num_workers=self.num_workers, 96 | collate_fn=self.collate_test, 97 | ) 98 | 99 | def collate_eval(self, batch): 100 | return self.collate(batch, "eval") 101 | 102 | def collate_test(self, batch): 103 | return self.collate(batch, "test") 104 | 105 | def collate_train(self, batch): 106 | return self.collate(batch, "train") 107 | 108 | def collate(self, batch, stage): 109 | #print ("datamodule collate called") 110 | return self.dpr_transform(batch, stage) 111 | 112 | -------------------------------------------------------------------------------- /preprocess/process_cc_news.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | import os 7 | import gzip 8 | import json 9 | import time 10 | import argparse 11 | import numpy as np 12 | 13 | from functools import partial 14 | from collections import Counter, defaultdict 15 | from utils import create_blocks_from_plain_text 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--in_dir", 20 | type=str, 21 | default="/datasets01/CC-NEWS/022719/json") 22 | parser.add_argument("--out_dir", 23 | type=str, 24 | default="train_corpus/cc_news/") 25 | 26 | parser.add_argument("--batch_size", type=int, default=16) 27 | parser.add_argument("--max_seq_length", type=int, default=256) 28 | parser.add_argument("--num_shards", type=int, default=883) 29 | parser.add_argument("--save_flatten_data", action="store_true") 30 | parser.add_argument("--save_nested_data", action="store_true") 31 | 32 | args = parser.parse_args() 33 | 34 | assert args.save_flatten_data or args.save_nested_data 35 | 36 | if not os.path.exists(args.out_dir): 37 | os.makedirs(args.out_dir) 38 | 39 | grouped_filenames = defaultdict(list) 40 | for fn_idx, fn in enumerate(sorted(os.listdir(args.in_dir), reverse=True)): 41 | date = fn.split("-")[-2][:8] 42 | grouped_filenames[date].append(os.path.join(args.in_dir, fn)) 43 | 44 | filenames = list(enumerate(sorted(list(grouped_filenames.values()), reverse=True))) 45 | print ("We will only process %d out of %d shards (dates)" % (args.num_shards, len(filenames))) 46 | 47 | start_time = time.time() 48 | n_blocks = [] 49 | 50 | from multiprocessing import Pool 51 | with Pool(min(args.num_shards, 60)) as p: 52 | for curr_n_blocks in p.imap(partial(process_file, args=args), filenames[:args.num_shards]): 53 | n_blocks.append(curr_n_blocks) 54 | print ("Finish processing %d/%d files (%.1fM blocks, %dmin)" % ( 55 | len(n_blocks), 56 | args.num_shards, 57 | np.sum(n_blocks) / 1000000, 58 | (time.time() - start_time)/60 59 | )) 60 | 61 | def process_file(pair, args): 62 | fn_idx, filenames = pair 63 | doc_id = str(fn_idx) 64 | lines = [] 65 | print ("Start reading %d files for idx=%s" % (len(filenames), doc_id)) 66 | for fn in filenames: 67 | with gzip.open(fn, "r") as f: 68 | for line in f: 69 | dp = json.loads(line) 70 | 71 | lang = dp["language"] 72 | if lang!="en": 73 | continue 74 | 75 | title = dp["title"] 76 | text = dp["text"] 77 | 78 | if text is None: 79 | continue 80 | 81 | if title is not None: 82 | text = title.strip() + ". 
" + text.strip() 83 | 84 | lines.append(text) 85 | 86 | if len(lines)==0: 87 | return 0 88 | 89 | outputs, n_tokens = create_blocks_from_plain_text(lines, doc_idx=doc_id, max_seq_length=args.max_seq_length) 90 | print ("Saving %dK tokens, %d output sequences from %d text lines for idx=%s" % ( 91 | np.sum(n_tokens)/1000, len(outputs), len(lines), doc_id)) 92 | 93 | if args.save_flatten_data: 94 | out_file = os.path.join(args.out_dir, "flatten_shard{}.jsonl".format(doc_id)) 95 | with open(out_file, "w") as f: 96 | for dp in outputs: 97 | f.write(json.dumps(dp)+"\n") 98 | 99 | if args.save_nested_data: 100 | out_file = os.path.join(args.out_dir, "BS{}_shard{}.jsonl".format(args.batch_size, doc_id)) 101 | with open(out_file, "w") as f: 102 | for idx in range(len(outputs) // args.batch_size): 103 | curr_data = outputs[idx*args.batch_size:(idx+1)*args.batch_size] 104 | assert len(curr_data)==args.batch_size 105 | grouped_dp = {} 106 | for k in curr_data[0]: 107 | v = [dp[k] for dp in curr_data] 108 | assert len(v)==args.batch_size 109 | grouped_dp[k] = v 110 | f.write(json.dumps(grouped_dp)+"\n") 111 | 112 | return len(outputs) 113 | 114 | if __name__=='__main__': 115 | main() 116 | 117 | 118 | -------------------------------------------------------------------------------- /npm/searcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | import subprocess 8 | import numpy as np 9 | from pyserini.search.lucene import LuceneSearcher 10 | 11 | class BM25Searcher(object): 12 | 13 | def __init__(self, data_dir, index_dir): 14 | if not os.path.exists(index_dir): 15 | self.build_index(data_dir, index_dir) 16 | self.searcher = LuceneSearcher(index_dir) 17 | print ("Loaded BM25 index from %s" % index_dir) 18 | 19 | def build_index(self, data_dir, index_dir): 20 | print ("Start building index for %s at %s" % (data_dir, index_dir)) 21 | command = """python -m pyserini.index.lucene \ 22 | --collection JsonCollection \ 23 | --input %s \ 24 | --index %s \ 25 | --generator DefaultLuceneDocumentGenerator \ 26 | --threads 1""" % (data_dir, index_dir) 27 | ret_code = subprocess.run([command], 28 | shell=True, 29 | #stdout=subprocess.DEVNULL, 30 | #stderr=subprocess.STDOUT 31 | ) 32 | if ret_code.returncode != 0: 33 | print("Failed to build the index") 34 | exit() 35 | else: 36 | print("Successfully built the index") 37 | 38 | def search(self, questions, k=3, is_question=False): 39 | is_batch = type(questions)==list 40 | if not is_batch: 41 | questions = [questions] 42 | 43 | results = [] 44 | for question in questions: 45 | if is_question: 46 | question = question[:question.index("?")+1].strip() 47 | else: 48 | question = question.replace("", "_____") 49 | ids = [int(hit.docid) for hit in self.searcher.search(question, k=k)] 50 | assert len(set(ids))==len(ids), (len(results), ids) 51 | results.append(ids) 52 | 53 | if not is_batch: 54 | return results[0] 55 | return results 56 | 57 | def batch_search(self, input_): 58 | from task.task import Task 59 | ''' 60 | input_ can be either 61 | - an instance of the class `Task` 62 | - a list of integers: a list of block indices you will be restricted to 63 | - a list of strings: a list of inputs, if these are all you will use, so that a list of 64 | block indices can be computed offline 65 | - a dictionary: 
string->a list of integers, precomputed BM25 block indices 66 | - True: meaning you will use restricted search but on the fly; this will load all the embeddings 67 | - False or None: you will not use restricted search 68 | ''' 69 | 70 | def _flatten(ids): 71 | if type(ids)!=list: 72 | return ids 73 | return [_id for _ids in ids for _id in _ids] 74 | 75 | if type(input_)==list and isinstance(input_[0], Task): 76 | restricted_dict = {} 77 | restricted = set() 78 | for task in input_: 79 | restricted_inputs = [ex["input"] for ex in task.examples] 80 | retrieved_ids = self.search(restricted_inputs, is_question=task.is_question) 81 | restricted_dict.update({ 82 | _input: _ids for _input, _ids in zip(restricted_inputs, retrieved_ids) 83 | }) 84 | restricted |= set(_flatten(retrieved_ids)) 85 | elif isinstance(input_, Task): 86 | task = input_ 87 | restricted_inputs = [ex["input"] for ex in task.examples] 88 | retrieved_ids = self.search(restricted_inputs, is_question=task.is_question) 89 | restricted_dict = {_input: _ids for _input, _ids in zip(restricted_inputs, retrieved_ids)} 90 | restricted = _flatten(retrieved_ids) 91 | restricted = set(restricted) 92 | elif type(input_)==list and np.all([type(r)==str for r in input_]): 93 | retrieved_ids = self.search(input_) 94 | restricted_dict = {_input: _ids for _input, _ids in zip(input_, retrieved_ids)} 95 | restricted = _flatten(retrieved_ids) 96 | restricted = set(restricted) 97 | elif type(input_)==list and np.all([type(r)==int for r in input_]): 98 | restricted = set(input_) 99 | restricted_dict = {} 100 | elif type(input_)==dict: 101 | restricted_dict = {k: v.copy() for k, v in input_.items()} 102 | restricted = set(_flatten(restricted_dict.values())) 103 | else: 104 | restricted = True 105 | restricted_dict = {} 106 | 107 | return restricted, restricted_dict 108 | 109 | 110 | -------------------------------------------------------------------------------- /train.md: -------------------------------------------------------------------------------- 1 | # Training NPM 2 | 3 | This is a guideline for training the NPM model. The training code is largely based on [facebookresearch/dpr-scale](https://github.com/facebookresearch/dpr-scale). 4 | 5 | ## Content 6 | 7 | 1. [Prepare Training Data](#prepare-training-data) 8 | * [Preprocessing](#preprocessing) 9 | * [Span Masking](#span-masking) 10 | * [Uniform Masking](#uniform-masking) 11 | 2. [Training](#training) 12 | * [Debugging locally](#debugging-locally): see this if you want to do a test run before running the entire pipeline. 13 | 3. [Evaluation](#evaluation) 14 | 15 | ## Prepare Training Data 16 | 17 | ### Preprocessing 18 | 19 | #### Wikipedia 20 | You need a Wikipedia file that follows the format of [the KILT knowledge base](https://github.com/facebookresearch/KILT). Run 21 | ```bash 22 | python3 preprocess/process_wiki.py \ 23 | --in_path {a_json_file_in_kilt_format} \ 24 | --save_nested_data \ 25 | --shard_data 26 | ``` 27 | This will save `train_corpus/enwiki/text_shard[0-9].jsonl` (the sharded raw text files) and `train_corpus/enwiki/BS16_shard[0-9].jsonl` (preprocessed files). 28 | 29 | #### CC News 30 | You need CC News data in a specific format. Please see `process_file` in `preprocess/process_cc_news.py` for the expected data format (a minimal sketch is shown below), or modify the function to read the data file you have. 
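For reference, `process_file` reads gzipped JSONL files, one article per line, and keeps English articles whose `text` is not `None`. A minimal sketch of such a record (the file name and field values below are illustrative):

```python
import gzip, json

record = {
    "language": "en",                        # articles with any other language code are skipped
    "title": "Example headline",             # may be None; otherwise it is prepended to the text
    "text": "Body of the news article ...",  # articles with text == None are skipped
}
# Shards are grouped by date, parsed from the second-to-last "-"-separated field of the file name.
with gzip.open("CC-NEWS-20190226224620-00000.json.gz", "wt") as f:
    f.write(json.dumps(record) + "\n")
```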
31 | ```bash 32 | python3 preprocess/process_cc_news.py \ 33 | --in_dir {a_dir_containing_json_files} \ 34 | --save_nested_data 35 | ``` 36 | This will save `train_corpus/cc_news/BS16_shard*.jsonl` (preprocessed files). 37 | 38 | Note: by default, we are using `--batch_size 16`, which is good for training with 32GB GPUs. If you are using GPUs with smaller/larger memory, please modify it accordingly. It is highly recommended to use the largest possible batch size. 39 | 40 | ### Span Masking 41 | 42 | To save the data with span masking, run the following: 43 | ```bash 44 | python3 preprocess/mask_spans.py --mr 0.15 --p 0.5 45 | ``` 46 | 47 | In the case of CC News, if the number of shards is larger than 10, the training script may not work. Therefore, we run the following to merge files so that the number of shards is 10. 48 | ```bash 49 | python3 preprocess/concat_files.py --mr 0.15 --p 0.5 50 | ``` 51 | 52 | When you are done, the following files are ready to be used for training. 53 | ```bash 54 | train_corpus 55 | /enwiki 56 | /BS16_shard[0-9]_mr0.15_p0.5.jsonl 57 | /cc_news 58 | /BS16_shard[0-9]_mr0.15_p0.5.jsonl 59 | ``` 60 | 61 | 62 | ### Uniform Masking 63 | 64 | You can optionally use uniform masking instead of span masking if you are interested in NPM-single (a variant of NPM that retrieves tokens instead of phrases). If you want to explore uniform masking, skip `preprocess/mask_spans.py`. You still need to concat files via `python3 preprocess/concat_files.py`. 65 | 66 | When you are done, the following files are ready to be used for training. 67 | ```bash 68 | train_corpus 69 | /enwiki 70 | /BS16_shard[0-9].jsonl 71 | /cc_news 72 | /BS16_shard[0-9].jsonl 73 | ``` 74 | 75 | ## Training 76 | 77 | To train NPM with span masking, run 78 | ```bash 79 | bash scripts/train.sh {save_dir} true 3e-05 16 0.15 span 0.5 80 | ``` 81 | The arguments indicate, in order: the save dir, whether it is a phrase retrieval model, the learning rate, batch size, masking ratio, masking strategy, and p (a hyperparameter for span masking). 82 | 83 | By default, we use 32 GPUs (4 nodes, 8 GPUs/node), each with 32GB memory. We use slurm and [hydra](https://github.com/facebookresearch/hydra) to run training. To run training with different configurations, see the command in `scripts/train.sh`. 84 | 85 | You can use tensorboard to monitor training: `tensorboard --logdir {save_dir}`. 86 | 87 | To train NPM-single with span masking, run 88 | ``` 89 | bash scripts/train.sh {save_dir} false 3e-05 16 0.15 span 0.5 90 | ``` 91 | 92 | To train NPM-single with uniform masking, run 93 | ``` 94 | bash scripts/train.sh {save_dir} false 3e-05 16 0.15 uniform 95 | ``` 96 | 97 | ### Debugging Locally 98 | If you want a training run on a subset of the data with one local GPU (instead of using slurm and hydra), simply run `scripts/train_debug.sh` instead of `scripts/train.sh` with the same arguments as in the [Training section](#training). 99 | 100 | This uses RoBERTa-base instead of RoBERTa-large, and can work with >=9GB of GPU memory. 101 | 102 | Note: This only uses the first shard of English Wikipedia (no CC-News), so if you have not started preprocessing and want to do a test run first, you can preprocess English Wikipedia only and leave CC-News for later. 103 | 104 | ## Evaluation 105 | Evaluation can be done by following the guidelines for inference in the main [README](README.md). 106 | 107 | * Checkpoints are saved every 10,000 training steps. 
You can find them under `{save_dir}/{hyperparam_settings}/0/lightning_logs/version_{slurm_id}/checkpoints`. 108 | * When saving embeddings, specify `+task.checkpoint_path=${checkpoint_path}` 109 | * When running `python -m scripts.prompt`, specify `--checkpoint_path ${checkpoint_path}` 110 | 111 | 112 | 113 | 114 | 115 | -------------------------------------------------------------------------------- /scripts/util_clm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | 8 | def assert_all_approx_close(a, b, rtol, atol, count): 9 | 10 | idx = torch.isclose(a.float(), b.float(), rtol, atol) 11 | sumval = (idx==0).sum().item() 12 | if sumval > count: 13 | print(f'Too many values not close: assert {sumval} < {count}') 14 | try: 15 | torch.testing.assert_allclose(a, b, rtol, atol) 16 | except Exception as e: 17 | print(e) 18 | 19 | 20 | def get_memory_footprint(model, return_buffers=True): 21 | """ 22 | Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. 23 | Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the 24 | PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 25 | Arguments: 26 | return_buffers (`bool`, *optional*, defaults to `True`): 27 | Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers 28 | are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch 29 | norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 30 | """ 31 | mem = sum([param.nelement() * param.element_size() for param in model.parameters()]) 32 | if return_buffers: 33 | mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) 34 | mem = mem + mem_bufs 35 | return mem 36 | 37 | 38 | def ـreplace_linear_with_int8linear(model, modules_to_not_convert="lm_head"): 39 | for name, module in model.named_children(): 40 | ـreplace_linear_with_int8linear(module, modules_to_not_convert) 41 | 42 | if isinstance(module, torch.nn.Linear) and name != modules_to_not_convert: 43 | model._modules[name] = QuantizedLinearInt8(linear_layer=module) 44 | return 45 | 46 | 47 | class QuantizedLinearInt8(torch.nn.Module): 48 | ''' 49 | A simple but effictive implmenetion of Int8 quantization for linear layers. 50 | The weights are quantized and stored as Int8, which saves ~50% of the gpu memory. 51 | During the forwared pass, the weights are de-quantized back to fp16 to do multiplication. 52 | Pros: 53 | - saves ~50% of the gpu memory 54 | - accurate quantization because only the weights are quantized, and the weights don't suffer 55 | from the "outliers" issue mentioned in the LLM.int8 paper; only the activations do. 56 | - high precision results beacuse the multiplication is done in fp16 57 | - much faster than LLM.int8 58 | Cons: 59 | - a bit slower because of the added computation of dequantization in each forward pass. In practice, the slowdown 60 | is not large because in the generation application, gpu utilization is not very high. 
61 | ''' 62 | def __init__(self, linear_layer): 63 | super().__init__() 64 | self.bias = linear_layer.bias 65 | 66 | weight_bit_width = 8 67 | weight = linear_layer.weight 68 | 69 | self.weight_scale = torch.nn.Parameter( 70 | (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half(), 71 | ) 72 | # print(self.weight_scale.max().item(), self.weight_scale.min().item(), self.weight_scale.mean().item()) 73 | # if self.weight_scale.max().item() > 0.002: 74 | # print(self.weight_scale.max().item()) 75 | self.weight = torch.nn.Parameter( 76 | torch.round(weight.float() / self.weight_scale[:, None]).char(), 77 | requires_grad=False 78 | ) 79 | 80 | def forward(self, x): 81 | weight = self.weight.half() * self.weight_scale[:, None] 82 | return torch.nn.functional.linear(x, weight, self.bias) 83 | 84 | 85 | def convert_model_to_int8_on_gpu(model, device): 86 | """ 87 | Quantize a model to int8 and move it to GPU using a simple method. 88 | """ 89 | if 'cuda' not in device: 90 | raise ValueError(f"Target device should be a gpu. Device {device} is not supported") 91 | 92 | model.half() 93 | 94 | memory_before_quantization = get_memory_footprint(model) # without lm_head 95 | 96 | ـreplace_linear_with_int8linear(model) # replace `Linear` with `QuantizedLinearInt8` 97 | 98 | model.to(device=device) 99 | memory_after_quantization = get_memory_footprint(model) # without lm_head 100 | 101 | saving = round(100 * memory_after_quantization/memory_before_quantization) 102 | memory_before_quantization = round(memory_before_quantization / 2**30, 2) # rounding for printing 103 | memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing 104 | 105 | print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)') 106 | return model 107 | -------------------------------------------------------------------------------- /dpr_scale/datamodule/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python3 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 8 | 9 | import json 10 | import mmap 11 | import os 12 | import time 13 | import glob 14 | import numpy as np 15 | import pickle as pkl 16 | 17 | from functools import partial 18 | 19 | import torch 20 | from pytorch_lightning import LightningDataModule 21 | 22 | from dpr_scale.utils.utils import PathManager 23 | from pathos.multiprocessing import ProcessingPool as Pool 24 | from joblib import Parallel, delayed 25 | from concurrent.futures import ThreadPoolExecutor 26 | 27 | def _initialize(path, header=False, max_cnt=-1): 28 | mm = {} 29 | offset_dict = {} 30 | count = 0 31 | local_path = PathManager.get_local_path(path) 32 | file = open(local_path, mode="r") 33 | mm = mmap.mmap(file.fileno(), 0, prot=mmap.PROT_READ) 34 | offset_dict[count] = mm.tell() 35 | if header: 36 | line = mm.readline() 37 | line = mm.readline() 38 | while line: 39 | count += 1 40 | offset = mm.tell() 41 | offset_dict[count] = offset 42 | line = mm.readline() 43 | if max_cnt > -1 and max_cnt == count: 44 | break 45 | #if count % 100000 == 0: 46 | # print("Finished reading %.1fM lines from %s" % (count/1000000, path)) 47 | return path, mm, offset_dict, count 48 | 49 | class MemoryMappedDataset(torch.utils.data.Dataset): 50 | """ 51 | A memory mapped dataset. 52 | """ 53 | def __init__(self, path, header=False, max_cnt = -1): 54 | self.mm = {} 55 | self.offset_dict = {} 56 | self.count = 0 57 | 58 | if path is None: 59 | return 60 | 61 | paths = [_path for path in path.split("+") for _path in glob.glob(path)] 62 | 63 | print ("Start reading %d files" % len(paths)) 64 | if len(paths)==0: 65 | max_cnt_per_path = 0 66 | else: 67 | max_cnt_per_path = max_cnt // len(paths) if max_cnt > -1 else -1 68 | start_time = time.time() 69 | 70 | path_idx = 0 71 | func = partial(_initialize, header=header, max_cnt=max_cnt_per_path) 72 | 73 | with ThreadPoolExecutor() as threads: 74 | for path, mm, offset_dict, count in threads.map(func, paths): 75 | self.mm[path_idx] = mm 76 | self.offset_dict.update({self.count + count: (path_idx, offset) 77 | for count, offset in offset_dict.items()}) 78 | self.count += count 79 | print ("Finish reading %s (path_idx=%d, %.1fmin)" % ( 80 | "/".join(path.split("/")[-2:]), 81 | path_idx, 82 | (time.time()-start_time)/60)) 83 | path_idx += 1 84 | 85 | print ("Final # of lines = %.3fM (%dmin)" % ( 86 | self.count / 1000000, (time.time()-start_time)/60)) 87 | 88 | def __len__(self): 89 | return self.count 90 | 91 | def process_line(self, line): 92 | return line 93 | 94 | def __getitem__(self, index): 95 | path_idx, offset = self.offset_dict[index] 96 | self.mm[path_idx].seek(offset) 97 | line = self.mm[path_idx].readline() 98 | return self.process_line(line) 99 | 100 | class CorpusDataset(torch.utils.data.Dataset): 101 | def __init__(self, path): 102 | 103 | if path is None: 104 | self.tot = 0 105 | else: 106 | self.all_input_ids = np.load(path) 107 | self.block_idx_to_token_idx = np.load(path.replace(".npy", "_blocks.npy")) 108 | 109 | if os.path.exists(path.replace(".npy", "_valid.pkl")): 110 | with open(path.replace(".npy", "_valid.pkl"), "rb") as f: 111 | self.valid_candidates = pkl.load(f) 112 | assert len(self.block_idx_to_token_idx)==len(self.valid_candidates) 113 | else: 114 | self.valid_candidates = None 115 | 116 | self.tot = len(self.block_idx_to_token_idx) 117 | self.block_idx_to_token_idx = np.concatenate( 118 | [self.block_idx_to_token_idx, [len(self.all_input_ids)]]) 119 | 120 | def __len__(self): 121 | return self.tot 122 | 123 | def process_line(self, line): 124 | 
return line 125 | 126 | def __getitem__(self, index, max_seq_length=256): 127 | # start, end = self.block_idx_to_token_idx[index] 128 | start = self.block_idx_to_token_idx[index] 129 | end = self.block_idx_to_token_idx[index+1] 130 | 131 | input_ids = np.concatenate([self.all_input_ids[start:end], [0]*(max_seq_length-end+start)]) 132 | attention_mask = [1]*(end-start) + [0]*(max_seq_length-end+start) 133 | 134 | if self.valid_candidates is not None: 135 | start_valid_indices, end_valid_indices = self.valid_candidates[index] 136 | is_valid = [i in start_valid_indices or i in end_valid_indices for i in range(max_seq_length)] 137 | else: 138 | is_valid = [i for i, _id in enumerate(input_ids) if _id not in [0, 2]] 139 | 140 | return {"input_ids": input_ids, "attention_mask": attention_mask, "is_valid": is_valid} 141 | 142 | 143 | -------------------------------------------------------------------------------- /scripts/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | import json 8 | import argparse 9 | import time 10 | import numpy as np 11 | 12 | import torch 13 | from task.task import Task 14 | from npm.npm_single import NPMSingle 15 | from npm.npm import NPM 16 | from npm.dstore import DataStore 17 | from npm.model import Model, SingleModel 18 | 19 | class NPMDemo(object): 20 | 21 | def __init__(self, save_dir, setting, checkpoint_path, k, temperature, 22 | remove_stopwords, remove_stopwords_except_k, single, restricted, 23 | embs_consider_boundary, keep_uint8): 24 | start_time = time.time() 25 | dstore = DataStore(setting=setting, 26 | model_dir=os.path.join(save_dir, "dstore"), 27 | do_load_index=False, 28 | remove_stopwords=remove_stopwords, 29 | remove_stopwords_except_k=remove_stopwords_except_k, 30 | restricted=restricted, 31 | embs_consider_boundary=embs_consider_boundary, 32 | keep_uint8=keep_uint8 33 | ) 34 | model_class = SingleModel if single else Model 35 | model = model_class(checkpoint_path=checkpoint_path) 36 | print ("Finish loading the model (%dsec)" % (time.time()-start_time)) 37 | 38 | npm_class = NPMSingle if single else NPM 39 | npm = npm_class(model=model, dstore=dstore, k=k, temperature=temperature) 40 | 41 | mask = npm.get_stopword_mask() 42 | def valid_func(tokens): 43 | return np.sum(mask[tokens])==0 44 | 45 | self.npm = npm 46 | self.valid_func = valid_func 47 | 48 | def predict(self, text): 49 | if "" not in text: 50 | text = text.strip() + "." 
51 | predicted = self.npm.predict_span(text, 52 | ngram_max=10, 53 | valid_func=self.valid_func, 54 | alphas=[0.0])["a=0.0"] 55 | return self.npm.decode(predicted) 56 | 57 | def generate(self, text, num_tokens=20, num_masked_tokens=20, return_metadata=False): 58 | assert "" not in text 59 | metadata = [] 60 | for _ in range(num_tokens): 61 | input_text = text + ""*num_masked_tokens 62 | predicted = self.npm.predict_span(input_text, 63 | ngram_max=10, 64 | alphas=[0.0], 65 | return_metadata=return_metadata) 66 | if return_metadata: 67 | _, curr_metadata = predicted 68 | predicted = curr_metadata["predicted"] 69 | metadata.append(curr_metadata) 70 | else: 71 | predicted = self.npm.decode(predicted["a=0.0"]) 72 | text += predicted 73 | 74 | if return_metadata: 75 | return text, metadata 76 | return text 77 | 78 | def bm25_search(self, text, k=3): 79 | block_ids = self.npm.dstore.searcher.search(text, k=3) 80 | blocks = [self.npm.decode(self.npm.dstore.input_ids[block_id]) for block_id in block_ids] 81 | return block_ids, blocks 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--save_dir', type=str, default="save") 86 | parser.add_argument('--setting', type=str, default="enwiki") 87 | parser.add_argument('--checkpoint_path', type=str, default="npm") 88 | parser.add_argument('--k', type=int, default=4096) 89 | parser.add_argument('--temperature', type=float, default=1.0) 90 | parser.add_argument("--remove_stopwords", action="store_true") 91 | parser.add_argument("--remove_stopwords_except_k", type=int, default=None) 92 | parser.add_argument("--single", action="store_true") 93 | parser.add_argument("--restricted", action="store_true") 94 | 95 | parser.add_argument("--embs_consider_boundary", action="store_true", default=True) 96 | parser.add_argument("--keep_uint8", action="store_true") 97 | 98 | args = parser.parse_args() 99 | npm = NPMDemo(save_dir=args.save_dir, 100 | setting=args.setting, 101 | checkpoint_path=args.checkpoint_path, 102 | k=args.k, 103 | temperature=args.temperature, 104 | remove_stopwords=args.remove_stopwords, 105 | remove_stopwords_except_k=args.remove_stopwords_except_k, 106 | single=args.single, 107 | restricted=args.restricted, 108 | embs_consider_boundary=args.embs_consider_boundary, 109 | keep_uint8=args.keep_uint8) 110 | 111 | input_text = "Hagios Demetrios is located in" 112 | 113 | start_time = time.time() 114 | print (npm.predict(input_text)) 115 | print ("(Took %.2fs to predict)" % (time.time()-start_time)) 116 | 117 | start_time = time.time() 118 | print (npm.generate(input_text)) 119 | print ("(Took %.2fs to generate)" % (time.time()-start_time)) 120 | 121 | start_time = time.time() 122 | text, metadata = npm.generate("Jo Kwon is a singer who", return_metadata=True) 123 | print (text) 124 | print ("(Took %.2fs to generate)" % (time.time()-start_time)) 125 | 126 | for dic in metadata: 127 | context = dic["predicted_spans"][0][0] 128 | print ("Input:", dic["input"]) 129 | print ("Predicted:", dic["predicted"]) 130 | print (context) 131 | print ("-"*30) 132 | 133 | from IPython import embed; embed() 134 | 135 | if __name__=='__main__': 136 | main() 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /task/create_lama_uhn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Code to create LAMA-UHN, a subset of LAMA-Google-RE and LAMA-T-REx 8 | # where ``easy-to-guess'' questions are filtered out. 9 | # 10 | # Defaults parameters correspond to setup in the following paper: 11 | # 12 | # @article{poerner2019bert, 13 | # title={BERT is Not a Knowledge Base (Yet): Factual Knowledge vs. 14 | # Name-Based Reasoning in Unsupervised QA}, 15 | # author={Poerner, Nina and Waltinger, Ulli and Sch{\"u}tze, Hinrich}, 16 | # journal={arXiv preprint arXiv:1911.03681}, 17 | # year={2019} 18 | # } 19 | 20 | import torch 21 | import json 22 | import os 23 | import argparse 24 | import tqdm 25 | 26 | from transformers import BertForMaskedLM, BertTokenizer 27 | 28 | class LAMAUHNFilter: 29 | def match(self, sub_label, obj_label, relation): 30 | raise NotImplementedError() 31 | 32 | def filter(self, queries): 33 | return [query for query in queries if not self.match(query)] 34 | 35 | 36 | class PersonNameFilter(LAMAUHNFilter): 37 | TEMP = "[CLS] [X] is a common name in the following [Y] : [MASK] . [SEP]" 38 | 39 | PLACENOUNS = { 40 | "/people/person/place_of_birth": "city", 41 | "/people/deceased_person/place_of_death": "city", 42 | "P19": "city", 43 | "P20": "city", 44 | "P27": "country", 45 | "P1412": "language", 46 | "P103": "language", 47 | } 48 | 49 | def __init__(self, top_k, bert_name): 50 | super().__init__() 51 | self.do_lower_case = "uncased" in bert_name 52 | self.top_k = top_k 53 | self.tokenizer = BertTokenizer.from_pretrained( 54 | bert_name, do_lower_case=self.do_lower_case 55 | ) 56 | self.model = BertForMaskedLM.from_pretrained(bert_name) 57 | self.model.eval() 58 | 59 | def get_top_k_for_name(self, template, name): 60 | tokens = self.tokenizer.tokenize(template.replace("[X]", name)) 61 | input_ids = self.tokenizer.convert_tokens_to_ids(tokens) 62 | output = self.model(torch.tensor(input_ids).unsqueeze(0))[0][0] 63 | logits = output[tokens.index("[MASK]")].detach() 64 | top_k_ids = torch.topk(logits, k=self.top_k)[1].numpy() 65 | top_k_tokens = self.tokenizer.convert_ids_to_tokens(top_k_ids) 66 | return top_k_tokens 67 | 68 | def match(self, query): 69 | relation = query["pred"] if "pred" in query else query["predicate_id"] 70 | if not relation in self.PLACENOUNS: 71 | return False 72 | 73 | sub_label, obj_label = query["sub_label"], query["obj_label"] 74 | if self.do_lower_case: 75 | obj_label = obj_label.lower() 76 | sub_label = sub_label.lower() 77 | 78 | template = self.TEMP.replace("[Y]", self.PLACENOUNS[relation]) 79 | for name in sub_label.split(): 80 | if obj_label in self.get_top_k_for_name(template, name): 81 | return True 82 | return False 83 | 84 | 85 | class StringMatchFilter(LAMAUHNFilter): 86 | def __init__(self, do_lower_case): 87 | self.do_lower_case = do_lower_case 88 | 89 | def match(self, query): 90 | sub_label, obj_label = query["sub_label"], query["obj_label"] 91 | if self.do_lower_case: 92 | sub_label = sub_label.lower() 93 | obj_label = obj_label.lower() 94 | return obj_label in sub_label 95 | 96 | 97 | def main(args): 98 | srcdir = args.srcdir 99 | assert os.path.isdir(srcdir) 100 | srcdir = srcdir.rstrip("/") 101 | tgtdir = srcdir + "_UHN" 102 | if not os.path.exists(tgtdir): 103 | os.mkdir(tgtdir) 104 | 105 | uhn_filters = [] 106 | if "string_match" in args.filters: 107 | uhn_filters.append( 108 | StringMatchFilter(do_lower_case=args.string_match_do_lowercase) 109 | ) 110 | if "person_name" in 
args.filters: 111 | uhn_filters.append( 112 | PersonNameFilter( 113 | bert_name=args.person_name_bert, top_k=args.person_name_top_k 114 | ) 115 | ) 116 | for filename in tqdm.tqdm(sorted(os.listdir(srcdir))): 117 | infile = os.path.join(srcdir, filename) 118 | outfile = os.path.join(tgtdir, filename) 119 | 120 | with open(infile) as handle: 121 | queries = [json.loads(line) for line in handle] 122 | 123 | for uhn_filter in uhn_filters: 124 | queries = uhn_filter.filter(queries) 125 | 126 | with open(outfile, "w") as handle: 127 | for query in queries: 128 | handle.write(json.dumps(query) + "\n") 129 | 130 | 131 | if __name__ == "__main__": 132 | argparser = argparse.ArgumentParser() 133 | 134 | argparser.add_argument( 135 | "--srcdir", 136 | required=True, 137 | type=str, 138 | help="Source directory. Should be Google_RE or TREx_alpaca.", 139 | ) 140 | argparser.add_argument( 141 | "--filters", 142 | nargs="+", 143 | type=str, 144 | default=("string_match", "person_name"), 145 | choices=("string_match", "person_name"), 146 | help="Filters to be applied: string_match, person_name or both.", 147 | ) 148 | argparser.add_argument( 149 | "--person_name_top_k", 150 | default=3, 151 | type=int, 152 | help="Parameter k for person name filter.", 153 | ) 154 | argparser.add_argument( 155 | "--person_name_bert", 156 | default="bert-base-cased", 157 | type=str, 158 | help="BERT version to use for person name filter.", 159 | ) 160 | argparser.add_argument( 161 | "--no_string_match_do_lowercase", 162 | default=True, 163 | action="store_false", 164 | dest="string_match_do_lowercase", 165 | help="Set flag to disable lowercasing in string match filter", 166 | ) 167 | args = argparser.parse_args() 168 | 169 | print(args) 170 | main(args) 171 | -------------------------------------------------------------------------------- /scripts/prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
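# Example invocation (illustrative; the flags correspond to the argparse options defined in main()
# below and assume the `enwiki` datastore and the `npm` checkpoint have been set up as described
# in the main README):
#   python -m scripts.prompt --eval_dataset lama-trex --corpus_data enwiki \
#       --checkpoint_path npm --save_dir save
# Predictions are written under {save_dir}/results/, one JSON line per example.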
6 | import os 7 | import json 8 | import argparse 9 | import time 10 | 11 | import torch 12 | from task.task import Task 13 | from npm.npm_single import NPMSingle 14 | from npm.npm import NPM 15 | from npm.dstore import DataStore, DataStoreUnion 16 | from npm.model import Model, SingleModel 17 | 18 | def main(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--eval_dataset', type=str, default="all") 21 | parser.add_argument('--corpus_data', type=str, default=None) 22 | parser.add_argument('--checkpoint_path', type=str, default="npm") 23 | parser.add_argument('--save_dir', type=str, default="save") 24 | 25 | parser.add_argument('--k', type=int, default=4096) 26 | parser.add_argument('--temperature', type=float, default=1.0) 27 | parser.add_argument('--n_samples', type=int, default=3000) 28 | parser.add_argument("--remove_stopwords", action="store_true") 29 | parser.add_argument("--remove_stopwords_except_k", type=int, default=None) 30 | 31 | parser.add_argument("--single", action="store_true") 32 | parser.add_argument("--open", action="store_true") 33 | parser.add_argument("--restricted", action="store_true") 34 | 35 | # for ablations 36 | parser.add_argument("--load_all_embs", action="store_true", default=True) 37 | parser.add_argument("--embs_consider_boundary", action="store_true", default=True) 38 | parser.add_argument("--keep_uint8", action="store_true") 39 | parser.add_argument("--debug", action="store_true") 40 | 41 | args = parser.parse_args() 42 | print (args) 43 | 44 | if args.restricted and not args.load_all_embs: 45 | tasks = [] 46 | for eval_dataset in args.eval_dataset.split("+"): 47 | task = Task(eval_dataset, "data", n_samples=args.n_samples) 48 | tasks.append(task) 49 | 50 | start_time = time.time() 51 | 52 | if args.corpus_data is None: 53 | dstore = None 54 | else: 55 | dstore_class = DataStoreUnion if "+" in args.corpus_data else DataStore 56 | dstore = dstore_class(setting=args.corpus_data, 57 | model_dir=os.path.join(args.save_dir, "dstore"), 58 | do_load_index=not args.restricted, 59 | remove_stopwords=args.remove_stopwords, 60 | remove_stopwords_except_k=args.remove_stopwords_except_k, 61 | restricted=(True if args.load_all_embs else tasks) if args.restricted else None, 62 | embs_consider_boundary=args.embs_consider_boundary, 63 | keep_uint8=args.keep_uint8 64 | ) 65 | print ("Finish loading the datastore (%dsec)" % (time.time()-start_time)) 66 | 67 | def add_postfix(corpus_data, postfix): 68 | return corpus_data.replace("+", postfix + "+") + postfix 69 | 70 | if args.remove_stopwords: 71 | args.corpus_data = add_postfix(args.corpus_data, ":no_stopwords") 72 | 73 | model_class = SingleModel if args.single else Model 74 | model = model_class(checkpoint_path=args.checkpoint_path) 75 | print ("Finish loading the model") 76 | 77 | if args.eval_dataset is None: 78 | return 79 | 80 | npm_class = NPMSingle if args.single else NPM 81 | npm = npm_class(model=model, 82 | dstore=dstore, 83 | k=args.k, 84 | temperature=args.temperature) 85 | 86 | for dataset_idx, eval_dataset in enumerate(args.eval_dataset.split("+")): 87 | # loading the task data 88 | if args.restricted and not args.load_all_embs: 89 | task = tasks[dataset_idx] 90 | else: 91 | task = Task(eval_dataset, "data", n_samples=args.n_samples) 92 | 93 | if args.debug: 94 | import numpy as np 95 | from task.utils_eval import normalize_answer 96 | 97 | # evaluate on a subset of examples where BM25 is successful. 
98 | if args.load_all_embs: 99 | _, restricted_dict = dstore.searcher.batch_search(task) 100 | else: 101 | restricted_dict = dstore.restricted_dict 102 | psg_id_to_raw_text = {} 103 | for psgs in restricted_dict.values(): 104 | for psg in psgs: 105 | if psg not in psg_id_to_raw_text: 106 | psg_id_to_raw_text[psg] = normalize_answer(npm.decode(dstore.input_ids[psg])) 107 | 108 | included = [] 109 | for i, ex in enumerate(task.examples): 110 | psgs = restricted_dict[ex["input"]] 111 | psgs = [psg_id_to_raw_text[psg] for psg in psgs] 112 | answers = [normalize_answer(answer) for answer in ex["answers"]] 113 | if np.any([answer in psg for answer in answers for psg in psgs]): 114 | included.append(i) 115 | print ("Evaluating %d->%d examples..." % (len(task.examples), len(included))) 116 | task.examples = [task.examples[i] for i in included] 117 | if task.ngrams is not None: 118 | task.ngrams = [task.ngrams[i] for i in included] 119 | 120 | save_dir = os.path.join(args.save_dir, "results") 121 | if not os.path.exists(save_dir): 122 | os.makedirs(save_dir) 123 | 124 | if args.open: 125 | if args.single: 126 | raise NotImplementedError("NPM Single does not support open-set tasks.") 127 | all_predictions = npm.evaluate_open(task) 128 | else: 129 | all_predictions = npm.evaluate(task) 130 | 131 | save_path = os.path.join(save_dir, "{}{}{}{}{}.txt".format( 132 | eval_dataset, 133 | "_c={}".format(args.corpus_data) if dstore is not None else "", 134 | "_k={}".format(args.k) if dstore is not None else "", 135 | "_t={}".format(args.temperature) if dstore is not None else "", 136 | "_restricted" if args.restricted else "" 137 | )) 138 | 139 | with open(save_path, "w") as f: 140 | for pred in all_predictions: 141 | f.write(json.dumps(pred)+"\n") 142 | 143 | if __name__=='__main__': 144 | main() 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /dpr_scale/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 7 | import os 8 | import copy 9 | import math 10 | from typing import List, Dict, Any 11 | 12 | import torch 13 | from torch.utils.data.distributed import DistributedSampler 14 | 15 | try: 16 | from pytext.utils.file_io import PathManager 17 | except ImportError: 18 | 19 | class DummyPathManager: 20 | def get_local_path(self, path, *args, **kwargs): 21 | return path 22 | 23 | def open(self, path, *args, **kwargs): 24 | return open(path, *args, **kwargs) 25 | 26 | PathManager = DummyPathManager() 27 | 28 | 29 | def maybe_add_title(text, title, use_title, sep_token): 30 | if use_title: 31 | return " ".join([title, sep_token, text]) 32 | else: 33 | return text 34 | 35 | 36 | class ContiguousDistributedSampler(DistributedSampler): 37 | def __init__( 38 | self, 39 | dataset, 40 | num_replicas=None, 41 | rank=None, 42 | shuffle: bool = True, 43 | seed: int = 0, 44 | drop_last: bool = False, 45 | num_replicas_per_node: int = 1, 46 | ) -> None: 47 | super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) 48 | self.num_replicas_per_node = num_replicas_per_node 49 | 50 | def __iter__(self): 51 | indices = list(range(len(self.dataset))) # type: ignore 52 | 53 | if not self.drop_last: 54 | # add extra samples to make it evenly divisible 55 | padding_size = self.total_size - len(indices) 56 | if padding_size <= len(indices): 57 | indices += indices[:padding_size] 58 | else: 59 | indices += (indices * math.ceil(padding_size / len(indices)))[ 60 | :padding_size 61 | ] 62 | else: 63 | # remove tail of data to make it evenly divisible. 64 | indices = indices[: self.total_size] 65 | assert len(indices) == self.total_size 66 | 67 | # subsample chunk 68 | chunk_size = self.num_samples * self.num_replicas_per_node 69 | node_rank = self.rank // self.num_replicas_per_node 70 | local_rank = self.rank % self.num_replicas_per_node 71 | start_idx = node_rank * chunk_size 72 | indices = indices[start_idx : start_idx + chunk_size] 73 | if self.shuffle: 74 | # deterministically shuffle 75 | g = torch.Generator() 76 | g.manual_seed(self.seed + self.epoch + node_rank) 77 | shuffle_idx = torch.randperm( 78 | len(indices), generator=g 79 | ).tolist() # type: ignore 80 | indices = [indices[idx] for idx in shuffle_idx] 81 | # subsample 82 | indices = indices[local_rank :: self.num_replicas_per_node] 83 | assert len(indices) == self.num_samples 84 | 85 | return iter(indices) 86 | 87 | 88 | class ContiguousDistributedSamplerForTest(DistributedSampler): 89 | def __iter__(self): 90 | shard_size = len(self.dataset) // self.num_replicas + 1 91 | return iter( 92 | range( 93 | self.rank * shard_size, 94 | min((self.rank + 1) * shard_size, len(self.dataset)), 95 | ) 96 | ) 97 | 98 | 99 | class WrapTransform(torch.nn.Module): 100 | def __init__(self, transform): 101 | super().__init__() 102 | self.transform = transform 103 | 104 | def forward(self, texts: List[str]) -> Dict[str, torch.Tensor]: 105 | batch: Dict[str, Any] = {"text": texts} 106 | return self.transform(batch) 107 | 108 | 109 | class ScriptEncoder(torch.nn.Module): 110 | # For scripting RobertaEncoder like classes 111 | def __init__(self, transform, encoder, quantize=False): 112 | super().__init__() 113 | self.transform = WrapTransform(transform) 114 | self.encoder = copy.deepcopy(encoder).cpu() 115 | if quantize: 116 | self.encoder = torch.quantization.quantize_dynamic( 117 | self.encoder, {torch.nn.Linear}, dtype=torch.qint8 118 | ) 119 | self.cpu() 120 | 121 | def forward(self, texts: List[str]) -> torch.Tensor: 122 | batch = 
self.transform(texts) 123 | return self.encode(batch["token_ids"]) 124 | 125 | def encode(self, model_inputs: torch.Tensor) -> torch.Tensor: 126 | return self.encoder(model_inputs) 127 | 128 | 129 | class ScriptMultiEncoder(torch.nn.Module): 130 | # For scripting an weighted ensemble of RobertaEncoder like classes 131 | def __init__(self, transform, encoders, quantize=False, weights=None): 132 | super().__init__() 133 | self.transform = WrapTransform(transform) 134 | self.encoders = torch.nn.ModuleList() 135 | self.linear = torch.nn.Linear(len(encoders), 1, bias=False, device='cpu') 136 | if weights is None: 137 | self.linear.weight.data = torch.ones( 138 | len(encoders), 1, device="cpu" 139 | ) # n_enc * 1, by default all ones 140 | else: 141 | assert len(weights) == len(encoders) 142 | self.linear.weight.data = torch.Tensor([weights], device="cpu").T 143 | for encoder in encoders: 144 | enc = copy.deepcopy(encoder).cpu() 145 | if quantize: 146 | enc = torch.quantization.quantize_dynamic( 147 | enc, {torch.nn.Linear}, dtype=torch.qint8 148 | ) 149 | self.encoders.append(enc) 150 | if quantize: 151 | self.linear = torch.quantization.quantize_dynamic( 152 | self.linear, {torch.nn.Linear}, dtype=torch.qint8 153 | ) 154 | self.cpu() 155 | 156 | def forward(self, texts: List[str]) -> torch.Tensor: 157 | batch = self.transform(texts) 158 | return self.encode(batch["token_ids"]) 159 | 160 | def encode(self, model_inputs: torch.Tensor) -> torch.Tensor: 161 | embeddings_list: List[torch.Tensor] = [] 162 | for i, encoder in enumerate(self.encoders): 163 | embeddings_list.append( 164 | self.linear.weight.data[i] * encoder(model_inputs) 165 | ) # weighted concatenation 166 | return torch.cat(embeddings_list, dim=1) # n_enc * d 167 | -------------------------------------------------------------------------------- /scripts/create_table.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
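# Example invocation (illustrative): aggregate the prediction files saved by scripts/prompt.py
# under {save_dir}/**/results/ into accuracy tables, using the flags parsed at the bottom of this file:
#   python -m scripts.create_table --save_dir save --closed
#   python -m scripts.create_table --save_dir save --open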
6 | import os 7 | import json 8 | import argparse 9 | import numpy as np 10 | 11 | from collections import defaultdict 12 | from prettytable import PrettyTable 13 | from task.task import Task 14 | from task.utils_eval import normalize_answer 15 | 16 | def load_output_file(output_file): 17 | predictions = [] 18 | with open(output_file, "r") as f: 19 | for line in f: 20 | predictions.append(json.loads(line)) 21 | return predictions 22 | 23 | def main(args): 24 | 25 | if args.closed: 26 | 27 | datasets = ["agn", "yahoo", "subj", "sst2", "mr", "rt", "cr", "amazon", "rte"] 28 | tasks = [] 29 | for dataset in datasets: 30 | tasks.append(Task(dataset, "data", n_samples=3000)) 31 | 32 | pt = PrettyTable() 33 | pt.field_names = ["Model"] + datasets 34 | pt.align["Model"] = "l" 35 | for dataset in datasets: 36 | pt.align[dataset] = "r" 37 | 38 | row = ["RoBERTa"] 39 | for dataset, task in zip(datasets, tasks): 40 | predictions = load_output_file(os.path.join(args.save_dir, "results", "{}.txt".format(dataset))) 41 | labels = [dp["label_list"][dp["label"]] for dp in task.examples] 42 | acc = np.mean(np.array(predictions)==np.array(labels)) 43 | row.append("%.1f" % (100*acc)) 44 | pt.add_row(row) 45 | 46 | for model in ["npm-single", "npm"]: 47 | row = [model] 48 | model_dir = os.path.join(args.save_dir, model, "results") 49 | for dataset, task in zip(datasets, tasks): 50 | output_files = [os.path.join(model_dir, file) for file in os.listdir(model_dir) if file.startswith(dataset+"_c=")] 51 | if len(output_files)>0 and os.path.exists(output_files[0]): 52 | assert len(output_files)==1, output_files 53 | predictions = load_output_file(output_files[0]) 54 | predictions = [p["prediction"] for p in predictions] 55 | labels = [dp["label_list"][dp["label"]] for dp in task.examples] 56 | assert len(predictions)==len(labels) 57 | acc = np.mean(np.array(predictions)==np.array(labels)) 58 | row.append("%.1f" % (100*acc)) 59 | else: 60 | row.append("-") 61 | pt.add_row(row) 62 | 63 | print (pt) 64 | 65 | if args.open: 66 | 67 | datasets = ["lama-trex", "lama-google_re", "kamel", "triviaqa", "nq"] 68 | field_names = ["trex", "trex uhn", "trex hard", "gre", "gre uhn", "kml", "tqa", "nq"] 69 | 70 | tasks = [] 71 | for dataset in datasets: 72 | tasks.append(Task(dataset, "data", n_samples=3000)) 73 | 74 | pt = PrettyTable() 75 | pt.field_names = ["Model"] + field_names 76 | pt.align["Model"] = "l" 77 | for dataset in datasets: 78 | pt.align[dataset] = "r" 79 | 80 | def compute_macro_em(accs, task, filter_func=None): 81 | acc_dict = defaultdict(list) 82 | for acc, ex, ngram in zip(accs, task.examples, task.ngrams): 83 | if filter_func is not None and not filter_func(ex): 84 | continue 85 | acc_dict[ngram].append(acc) 86 | return np.mean([np.mean(_accs) for _accs in acc_dict.values()]) 87 | 88 | def get_row(dataset, output_file): 89 | predictions = load_output_file(output_file) 90 | predictions = [p["a=0.0"] if "a=0.0" in p else p["prediction"] for p in predictions] 91 | predictions = [normalize_answer(p) for p in predictions] 92 | references = [[normalize_answer(a) for a in ex["answers"]] for ex in task.examples] 93 | assert len(predictions)==len(references) 94 | accs = [prediction in reference for prediction, reference in zip(predictions, references)] 95 | row = [] 96 | if dataset in ["lama-trex", "lama-google_re"]: 97 | row.append("%.1f" % (100*compute_macro_em(accs, task))) 98 | row.append("%.1f" % (100*compute_macro_em(accs, task, lambda x: x["is_uhn"]))) 99 | if dataset=="lama-trex": 100 | row.append("%.1f" % 
(100*compute_macro_em(accs, task, lambda x: x["is_hard"]))) 101 | else: 102 | row.append("%.1f" % (100*np.mean(accs))) 103 | return row 104 | 105 | for model_name in os.listdir(args.save_dir): 106 | if model_name.startswith("opt") or model_name.startswith("neo") or model_name.startswith("gpt"): 107 | row = [model_name] 108 | for dataset, task in zip(datasets, tasks): 109 | output_file = os.path.join(args.save_dir, model_name, "{}.jsonl".format(output_file)) 110 | if os.path.exists(output_file): 111 | row += get_row(dataset, output_file) 112 | else: 113 | if dataset=="lama-trex": 114 | n = 3 115 | elif dataset=="lama-google_re": 116 | n = 2 117 | else: 118 | n = 1 119 | for _ in range(n): 120 | row.append("-") 121 | pt.add_row(row) 122 | 123 | row = ["npm"] 124 | model_dir = os.path.join(args.save_dir, "npm-reproduced", "results") 125 | for dataset, task in zip(datasets, tasks): 126 | output_files = [os.path.join(model_dir, file) 127 | for file in os.listdir(model_dir) 128 | if file.startswith(dataset+"_c=enwiki:no_stopwords_")] 129 | if len(output_files)>0 and os.path.exists(output_files[0]): 130 | assert len(output_files)==1, output_files 131 | row += get_row(dataset, output_files[0]) 132 | else: 133 | if dataset=="lama-trex": 134 | n = 3 135 | elif dataset=="lama-google_re": 136 | n = 2 137 | else: 138 | n = 1 139 | for _ in range(n): 140 | row.append("-") 141 | 142 | pt.add_row(row) 143 | print (pt) 144 | 145 | if __name__=='__main__': 146 | 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument("--save_dir", default="save") 149 | parser.add_argument("--closed", action="store_true") 150 | parser.add_argument("--open", action="store_true") 151 | args = parser.parse_args() 152 | assert args.closed or args.open 153 | 154 | main(args) 155 | 156 | 157 | -------------------------------------------------------------------------------- /npm/npm_single.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
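For orientation, a minimal construction sketch for the class defined below, mirroring how `scripts/prompt.py` and `scripts/demo.py` wire it up. The checkpoint name, datastore setting, and directory are illustrative and assume the corresponding resources have already been built.

```python
from npm.model import SingleModel
from npm.npm_single import NPMSingle
from npm.dstore import DataStore

model = SingleModel(checkpoint_path="npm-single")   # token-level encoder
dstore = DataStore(setting="enwiki",                # reference corpus to retrieve from
                   model_dir="save/dstore",
                   do_load_index=True,
                   remove_stopwords=False,
                   remove_stopwords_except_k=None,
                   restricted=None,
                   embs_consider_boundary=True,
                   keep_uint8=False)
npm = NPMSingle(model=model, dstore=dstore, k=4096, temperature=1.0)
# npm.evaluate(task) then scores a closed-set Task (see task/task.py).
```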
6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import json 10 | import math 11 | import time 12 | import numpy as np 13 | import os 14 | import pickle as pkl 15 | import datetime 16 | import re 17 | import string 18 | 19 | from collections import defaultdict, Counter 20 | from scipy.special import softmax, log_softmax 21 | 22 | import torch.nn.functional as F 23 | 24 | class NPMSingle(object): 25 | def __init__(self, model, dstore=None, k=None, temperature=1.0): 26 | self.model = model 27 | self.k = k 28 | self.temperature = temperature 29 | self.dstore = dstore 30 | self.n_vocabs = len(self.model.tokenizer) 31 | 32 | def decode(self, ids): 33 | return self.model.tokenizer.decode(ids) 34 | 35 | def get_scores(self, queries, x): 36 | if type(queries)==np.ndarray: 37 | if type(x)==list: 38 | all_scores = np.concatenate([self.get_scores(queries, xi) for xi in x], -1) 39 | else: 40 | all_scores = np.inner(queries, x).squeeze(1) / np.sqrt(self.dstore.dimension) 41 | else: 42 | if type(x)==list: 43 | all_scores = torch.cat([self.get_scores(queries, xi) for xi in x], -1) 44 | else: 45 | all_scores = torch.inner(queries, x).squeeze(1) / np.sqrt(self.dstore.dimension) 46 | return all_scores 47 | 48 | def get_all_scores(self, queries): 49 | queries = queries.detach().cpu().numpy() 50 | all_scores, all_indices = self.dstore.search(queries, k=self.k) 51 | knn_ids = self.dstore.get_block_idx_and_token(all_indices.tolist(), token_only=True) 52 | x = self.dstore.get_embs(all_indices) 53 | all_scores = self.get_scores(queries, x) / self.temperature 54 | return all_scores, all_indices, knn_ids 55 | 56 | def get_knn_scores(self, 57 | queries, 58 | return_context=False): 59 | 60 | all_scores, all_indices, knn_ids = self.get_all_scores(queries) 61 | 62 | if return_context: 63 | sorted_all_indices = all_indices[0, np.argsort(-all_scores[0])] 64 | assert len(sorted_all_indices)==len(knn_ids[0]) 65 | 66 | k = all_scores.shape[1] 67 | assert len(knn_ids)==len(all_scores)==1 and len(knn_ids[0])==len(all_scores[0]) 68 | 69 | probs = softmax(all_scores, -1) 70 | assert len(knn_ids)==1 and probs.shape[0]==1 and len(knn_ids[0])==len(probs[0]) 71 | full_knn_scores = {} 72 | for vocab, p in zip(knn_ids[0], probs[0]): 73 | if vocab not in full_knn_scores: 74 | full_knn_scores[vocab] = 0 75 | full_knn_scores[vocab] += p 76 | 77 | prob = np.zeros((self.n_vocabs, )) 78 | for vocab, p in full_knn_scores.items(): 79 | prob[vocab] = p 80 | 81 | if return_context: 82 | def decode_func(input_ids, token_i): 83 | assert token_i < len(input_ids) 84 | if token_i==len(input_ids)-1: 85 | return self.model.tokenizer.decode(input_ids) + " " + colored("EOS", "red"), "EOS" 86 | retrieved_token = self.model.tokenizer.decode([input_ids[token_i+1]]) 87 | return self.model.tokenizer.decode(input_ids[:token_i+1]) + \ 88 | colored(retrieved_token, "red") + \ 89 | self.model.tokenizer.decode(input_ids[token_i+2:]), retrieved_token 90 | 91 | context = self.dstore.get_context(sorted_all_indices.tolist(), decode_func) 92 | assert len(context)==len(all_scores[0]) 93 | sorted_context_and_scores = sorted(zip(context, all_scores[0]), 94 | key=lambda x: -x[1]) 95 | 96 | return prob, sorted_context_and_scores 97 | 98 | return prob 99 | 100 | def predict(self, 101 | input_text, 102 | label2id, 103 | return_context=False, 104 | max_length=256): 105 | 106 | assert type(input_text)==str 107 | if "" not in input_text: 108 | input_text = input_text + "{}.".format(self.model.tokenizer.mask_token) 109 | inputs = 
self.model.tokenizer.encode_plus(input_text) 110 | input_ids = inputs["input_ids"] 111 | 112 | mask_id = self.model.tokenizer.mask_token_id 113 | idx = input_ids.index(mask_id) 114 | 115 | if len(input_ids) > max_length: 116 | input_ids = input_ids[-max_length:] 117 | assert mask_id in input_ids 118 | idx = input_ids.index(mask_id) 119 | 120 | input_ids = torch.LongTensor(input_ids).unsqueeze(0).to("cuda") 121 | 122 | with torch.no_grad(): 123 | logits, knn_queries = self.model.forward(input_ids, idx) 124 | 125 | if self.dstore is None: 126 | logits = logits.detach().cpu() 127 | prob = torch.softmax(logits, dim=-1).numpy() 128 | assert not return_context 129 | else: 130 | prob = self.get_knn_scores(knn_queries, return_context=return_context) 131 | if return_context: 132 | prob, retrieved_context = prob 133 | 134 | prob = np.array([np.sum(prob[label2id[label]]) for label in range(len(label2id))]) 135 | if return_context: 136 | return prob, retrieved_context 137 | return prob 138 | 139 | def evaluate(self, task): 140 | all_predictions = [] 141 | accs = [] 142 | 143 | examples = task.examples 144 | labels = examples[0]["label_list"] 145 | if self.dstore is not None and task.label2syn is not None: 146 | label2id = self.init_label2word_id(task.label2syn) 147 | assert np.all([v.shape[-1]==1 for v in label2id.values()]) 148 | else: 149 | labels_id = self.model.tokenizer(labels)["input_ids"] 150 | label2word = {i: [v] for i, v in enumerate(examples[0]["label_list"])} 151 | label2id = self.init_label2word_id(label2word) 152 | 153 | for ex in tqdm(examples): 154 | prob = self.predict(ex["input"], label2id) 155 | predicted_label = np.argmax(prob) 156 | accs.append(ex["label"]==predicted_label) 157 | all_predictions.append({"prediction": labels[predicted_label], "prob": prob.tolist()}) 158 | 159 | print ("%s\tAccuracy=%.1f%%" % (task, 100*np.mean(accs))) 160 | return all_predictions 161 | 162 | def init_label2word_id(self, label2synonym): 163 | label2synonym_id = {} 164 | for k, v in label2synonym.items(): 165 | synonym_id = [] 166 | for word in v: 167 | tokens = self.model.tokenizer(word)["input_ids"] 168 | assert len(tokens)==3 169 | assert (tokens[0]==0 and tokens[-1]==2) or (tokens[0]==101 and tokens[-1]==102) 170 | tokens = tokens[1:-1] 171 | assert len(tokens)==1 172 | synonym_id.append(tokens) 173 | label2synonym_id[k] = np.array(synonym_id) 174 | return label2synonym_id 175 | 176 | def get_stopword_mask(self, name="stopwords", stopwords=set()): 177 | mask = np.zeros((self.n_vocabs, )) 178 | stopwords = set() 179 | with open("config/roberta_" + name + ".txt") as f: 180 | for line in f: 181 | stopwords.add(int(line.strip())) 182 | mask[np.array(list(stopwords))] = -1e10 183 | return mask 184 | 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Nonparametric Masked Language Modeling 2 | 3 | This repo contains the original implementation of the paper "[Nonparametric Masked Language Modeling](https://arxiv.org/abs/2212.01349)". 4 | 5 |

6 | <img src="img/animation.gif"> 7 |

8 | 9 | ``` 10 | @article{ min2022nonparametric, 11 | title={ Nonparametric Masked Language Modeling }, 12 | author={ Min, Sewon and Shi, Weijia and Lewis, Mike and Chen, Xilun and Yih, Wen-tau and Hajishirzi, Hannaneh and Zettlemoyer, Luke }, 13 | year={ 2022 } 14 | } 15 | ``` 16 | 17 | Models are available on the Hugging Face Hub :hugs:! Check out [**npm**](https://huggingface.co/facebook/npm) (for phrase retrieval) and [**npm-single**](https://huggingface.co/facebook/npm-single) (for token retrieval). 18 | 19 | **We are working on a simple demo where you can download all the resources and deploy NPM on your own machine. Stay tuned!** 20 | 21 | ### Updates 22 | * **01/02/2023**: The code for training is released. See [train.md](train.md) for instructions. 23 | * **12/22/2022**: The code for inference is released. Stay tuned for the code for training. 24 | 25 | ## Content 26 | 27 | 1. [Requirements](#requirements) 28 | 2. [Download Data](#download-data) 29 | 3. [Closed-set Experiments](#closed-set-experiments) 30 | * [Baselines](#baselines-on-closed-set-tasks) 31 | * [NPM](#npm-on-closed-set-tasks) 32 | * [NPM Single](#npm-single-on-closed-set-tasks) 33 | 4. [Open-set Experiments](#open-set-experiments) 34 | * [Baselines](#baselines-on-open-set-tasks) 35 | * [NPM](#npm-on-open-set-tasks) 36 | 5. [License](#license) 37 | 6. [Contact](#contact) 38 | 39 | ## Requirements 40 | 41 | ``` 42 | conda create -n npm python=3.7 43 | conda activate npm 44 | pip3 install -r requirements.txt --user 45 | ``` 46 | 47 | If you plan to run open-set tasks, make sure to install Java as well. 48 | ```bash 49 | conda install -c conda-forge openjdk 50 | ``` 51 | 52 | Note that multi-GPU inference is not supported for now. 53 | 54 | ## Download Data 55 | Evaluation datasets and reference corpora can be downloaded via 56 | ```bash 57 | # To run evaluation on closed-set tasks 58 | bash scripts/download_data.sh closed 59 | bash scripts/download_corpus.sh closed 60 | 61 | # To run evaluation on open-set tasks 62 | bash scripts/download_data.sh open 63 | bash scripts/download_corpus.sh enwiki 64 | 65 | # To run evaluation on TempLAMA (need Wikipedia 2022) 66 | bash scripts/download_data.sh templama 67 | bash scripts/download_corpus.sh new-enwiki 68 | ``` 69 | 70 | The corpus data is required for NPM and the retrieve-and-generate baselines. If you only plan to run parametric baselines, you can skip downloading the corpus. 71 | 72 | All reference corpus files are saved under `corpus/` and evaluation datasets are saved under `data/`. 73 | 74 | ## Closed-set Experiments 75 | 76 | #### Baselines on closed-set tasks 77 | The following script runs the RoBERTa-large baseline on all 9 datasets used in the paper.
78 | ```bash 79 | python -m scripts.prompt \ 80 | --checkpoint_path roberta-large \ 81 | --eval_dataset agn+yahoo+rte+subj+sst2+mr+rt+cr+amazon \ 82 | --save_dir save/roberta \ 83 | --single 84 | ``` 85 | 86 | #### NPM on closed-set tasks 87 | 88 | ```bash 89 | # To run on AGN, Yahoo and RTE: 90 | bash scripts/save_embeddings.sh npm enwiki-0 false 320 91 | bash scripts/save_embeddings.sh npm cc_news false 320 92 | python -m scripts.prompt \ 93 | --corpus_data enwiki-0+cc_news \ 94 | --checkpoint_path npm \ 95 | --eval_dataset agn+yahoo+rte \ 96 | --temperature 5.0 \ 97 | --save_dir save/npm 98 | 99 | # To run on Subj: 100 | bash scripts/save_embeddings.sh npm subj false 320 101 | python -m scripts.prompt \ 102 | --corpus_data subj \ 103 | --checkpoint_path npm \ 104 | --eval_dataset subj \ 105 | --temperature 5.0 \ 106 | --save_dir save/npm 107 | 108 | # To run on SST-2, MR, RT, CR and Amazon: 109 | bash scripts/save_embeddings.sh npm imdb false 320 110 | bash scripts/save_embeddings.sh npm amazon false 320 111 | python -m scripts.prompt \ 112 | --corpus_data imdb+amazon \ 113 | --checkpoint_path npm \ 114 | --eval_dataset sst2+mr+rt+cr+amazon \ 115 | --temperature 5.0 \ 116 | --save_dir save/npm 117 | ``` 118 | 119 | Note that `scripts/save_embeddings.sh` takes the following arguments: 120 | - model name (`npm` or `npm-single`) 121 | - corpus name 122 | - whether it is an open-set task (`true` or `false`) 123 | - batch size (`320` works well on a 32GB GPU; if `trainer.precision=16` is used, `400` works well) 124 | Embeddings are saved under `save/{model_name}/dstore`. 125 | 126 | #### NPM Single on closed-set tasks 127 | 128 | ```bash 129 | # To run on AGN, Yahoo and RTE: 130 | bash scripts/save_embeddings.sh npm-single enwiki-0 false 320 131 | bash scripts/save_embeddings.sh npm-single cc_news false 320 132 | python -m scripts.prompt \ 133 | --corpus_data enwiki-0+cc_news \ 134 | --checkpoint_path npm-single \ 135 | --eval_dataset agn+yahoo+rte \ 136 | --temperature 5.0 \ 137 | --single \ 138 | --save_dir save/npm-single 139 | 140 | # To run on Subj: 141 | bash scripts/save_embeddings.sh npm-single subj false 320 142 | python -m scripts.prompt \ 143 | --corpus_data subj \ 144 | --checkpoint_path npm-single \ 145 | --eval_dataset subj \ 146 | --temperature 5.0 \ 147 | --single \ 148 | --save_dir save/npm-single 149 | 150 | # To run on SST-2, MR, RT, CR and Amazon: 151 | bash scripts/save_embeddings.sh npm-single imdb false 320 152 | bash scripts/save_embeddings.sh npm-single amazon false 320 153 | python -m scripts.prompt \ 154 | --corpus_data imdb+amazon \ 155 | --checkpoint_path npm-single \ 156 | --eval_dataset sst2+mr+rt+cr+amazon \ 157 | --temperature 5.0 \ 158 | --single \ 159 | --save_dir save/npm-single 160 | ``` 161 | 162 | ## Open-set Experiments 163 | 164 | #### Baselines on open-set tasks 165 | 166 | Run the following command to evaluate the causal language model baselines (T5 baselines are TBA!). 167 | 168 | ```bash 169 | python -m scripts.clm_prompt \ 170 | --eval_dataset {lama-trex|lama-google_re|kamel|triviaqa|nq|entity_translation} \ 171 | --model_name {j-6b|neo-1.3b|neo-2.7b|neox-20b|opt-1.3b|opt-2.7b|opt-6.7b|opt-13b|opt-30b|bloom-1b7|bloom-3b|bloom-7b1} \ 172 | --save_dir save 173 | ``` 174 | 175 | By default, this does not use any passages from an external corpus. Specify `--ret bm25` to use BM25 passages from Wikipedia 2019, and `--ret bm25_2022` to use BM25 passages from Wikipedia 2022 (for TempLAMA).
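For example, the following command (just one combination of the dataset and model choices listed above) evaluates the `neo-1.3b` baseline on LAMA T-REx with BM25 passages from Wikipedia 2019:

```bash
python -m scripts.clm_prompt \
    --eval_dataset lama-trex \
    --model_name neo-1.3b \
    --ret bm25 \
    --save_dir save
```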
176 | 177 | #### NPM on open-set tasks 178 | 179 | Please note that running open-set tasks requires around 70GB of RAM and 1.4TB of disk space. If you want to reduce the RAM usage, you can specify `--keep_uint8` while running `python -m scripts.prompt` below, which reduces the RAM usage from 70GB to 40GB while increasing the datastore setup time. We will explore further optimizing RAM/disk usage in a future version of the code (PRs are also welcome!). 180 | 181 | ```bash 182 | # Note that this can be executed in parallel with up to 20 GPUs. In total, it takes about 10 GPU hours and 1.4TB of disk space. 183 | for i in {0..19} ; do 184 | bash scripts/save_embeddings.sh npm enwiki-${i} true 320 185 | done 186 | 187 | # Loading the model takes about 40min and 70GB of RAM (specify `--keep_uint8` to reduce RAM usage to 40GB, which increases the model loading time to 60-80min). 188 | python -m scripts.prompt \ 189 | --corpus_data enwiki \ 190 | --checkpoint_path npm \ 191 | --eval_dataset lama-trex+lama-google_re+kamel+triviaqa+nq+entity_translation \ 192 | --save_dir save/npm \ 193 | --remove_stopwords \ 194 | --restricted \ 195 | --open 196 | ``` 197 | 198 | To evaluate on TempLAMA, use `new-enwiki` instead of `enwiki`, and use `--eval_dataset {templama|unchanged_templama}`. 199 | 200 | ## License 201 | NPM is CC-BY-NC 4.0 licensed. 202 | 203 | ## Contact 204 | 205 | Please open a GitHub issue or contact Sewon Min (`sewon@cs.washington.edu`) with any questions. 206 | 207 | 208 |

209 | <img src="img/affiliation.jpeg"> 210 |

211 | -------------------------------------------------------------------------------- /dpr_scale/optim/madgrad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python3 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | import math 9 | from collections import defaultdict 10 | from typing import Optional 11 | 12 | import torch 13 | from torch.optim.optimizer import Optimizer 14 | 15 | 16 | class MADGRAD(Optimizer): 17 | """ 18 | `MADGRAD Optimizer`: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic 19 | Optimization. 20 | Paper: https://arxiv.org/abs/2101.11075 21 | 22 | Implementation has been copied over from the original author 23 | (https://github.com/facebookresearch/madgrad/blob/master/madgrad/madgrad.py) 24 | Arguments: 25 | params (iterable): 26 | Iterable of parameters to optimize or dicts defining parameter groups. 27 | lr (float): 28 | Learning rate (default: 1e-2). 29 | momentum (float): 30 | Momentum value in the range [0,1) (default: 0.9). 31 | weight_decay (float): 32 | Weight decay, i.e. a L2 penalty (default: 0). 33 | eps (float): 34 | Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6). 35 | """ 36 | 37 | def __init__( 38 | self, 39 | params, 40 | lr: float = 1e-2, 41 | momentum: float = 0.9, 42 | weight_decay: float = 0, 43 | eps: float = 1e-6, 44 | k: int = 0 45 | ): 46 | if momentum < 0 or momentum >= 1: 47 | raise ValueError(f"Momentum {momentum} must be in the range [0,1]") 48 | if lr <= 0: 49 | raise ValueError(f"Learning rate {lr} must be positive") 50 | if weight_decay < 0: 51 | raise ValueError(f"Weight decay {weight_decay} must be non-negative") 52 | if eps < 0: 53 | raise ValueError("Eps must be non-negative") 54 | 55 | defaults = { 56 | "lr": lr, 57 | "eps": eps, 58 | "momentum": momentum, 59 | "weight_decay": weight_decay, 60 | "k": k, 61 | } 62 | 63 | self.momentum = momentum 64 | 65 | Optimizer.__init__(self, params, defaults) 66 | 67 | self.initialize_state() 68 | 69 | def initialize_state(self): 70 | for group in self.param_groups: 71 | for p in group["params"]: 72 | if p not in self.state: 73 | state = self.state[p] 74 | state["grad_sum_sq"] = torch.zeros_like(p.data).detach().cuda() 75 | state["s"] = torch.zeros_like(p.data).detach().cuda() 76 | if self.momentum != 0: 77 | state["x0"] = torch.clone(p.data).detach().cuda() 78 | 79 | @property 80 | def supports_memory_efficient_fp16(self) -> bool: 81 | return False 82 | 83 | @property 84 | def supports_flat_params(self) -> bool: 85 | return True 86 | 87 | def step(self, closure=None, **kwargs) -> Optional[float]: 88 | """Performs a single optimization step. 89 | Arguments: 90 | closure (callable, optional): A closure that reevaluates the model 91 | and returns the loss. 
92 | """ 93 | loss = None 94 | if closure is not None: 95 | loss = closure() 96 | 97 | for group in self.param_groups: 98 | eps = group["eps"] 99 | k = group["k"] 100 | lr = group["lr"] + eps 101 | decay = group["weight_decay"] 102 | momentum = group["momentum"] 103 | 104 | ck = 1 - momentum 105 | lamb = lr * math.pow(k + 1, 0.5) 106 | 107 | for p in group["params"]: 108 | if p.grad is None: 109 | continue 110 | grad = p.grad.data 111 | state = self.state[p] 112 | 113 | if momentum != 0.0 and grad.is_sparse: 114 | raise RuntimeError( 115 | "momentum != 0 is not compatible with sparse gradients" 116 | ) 117 | 118 | grad_sum_sq = state["grad_sum_sq"] 119 | s = state["s"] 120 | 121 | # Apply weight decay 122 | if decay != 0: 123 | if grad.is_sparse: 124 | raise RuntimeError( 125 | "weight_decay option is not compatible with sparse gradients" 126 | ) 127 | 128 | grad.add_(p.data, alpha=decay) 129 | 130 | if grad.is_sparse: 131 | grad = grad.coalesce() 132 | grad_val = grad._values() 133 | 134 | p_masked = p.sparse_mask(grad) 135 | grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad) 136 | s_masked = s.sparse_mask(grad) 137 | 138 | # Compute x_0 from other known quantities 139 | rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps) 140 | x0_masked_vals = p_masked._values().addcdiv( 141 | s_masked._values(), rms_masked_vals, value=1 142 | ) 143 | 144 | # Dense + sparse op 145 | grad_sq = grad * grad 146 | grad_sum_sq.add_(grad_sq, alpha=lamb) 147 | grad_sum_sq_masked.add_(grad_sq, alpha=lamb) 148 | 149 | rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps) 150 | 151 | s.add_(grad, alpha=lamb) 152 | s_masked._values().add_(grad_val, alpha=lamb) 153 | 154 | # update masked copy of p 155 | p_kp1_masked_vals = x0_masked_vals.addcdiv( 156 | s_masked._values(), rms_masked_vals, value=-1 157 | ) 158 | # Copy updated masked p to dense p using an add operation 159 | p_masked._values().add_(p_kp1_masked_vals, alpha=-1) 160 | p.data.add_(p_masked, alpha=-1) 161 | else: 162 | if momentum == 0: 163 | # Compute x_0 from other known quantities 164 | rms = grad_sum_sq.pow(1 / 3).add_(eps) 165 | x0 = p.data.addcdiv(s, rms, value=1) 166 | else: 167 | x0 = state["x0"] 168 | 169 | # Accumulate second moments 170 | grad_sum_sq.addcmul_(grad, grad, value=lamb) 171 | rms = grad_sum_sq.pow(1 / 3).add_(eps) 172 | 173 | # Update s 174 | s.data.add_(grad, alpha=lamb) 175 | 176 | # Step 177 | if momentum == 0: 178 | p.data.copy_(x0.addcdiv(s, rms, value=-1)) 179 | else: 180 | z = x0.addcdiv(s, rms, value=-1) 181 | 182 | # p is a moving average of z 183 | p.data.mul_(1 - ck).add_(z, alpha=ck) 184 | 185 | group["k"] = group["k"] + 1 186 | return loss 187 | 188 | def add_param_group(self, param_group): 189 | r"""Add a param group to the :class:`Optimizer` s `param_groups`. 190 | 191 | This can be useful when fine tuning a pre-trained network as frozen 192 | layers can be made trainable and added to the :class:`Optimizer` as 193 | training progresses. 194 | 195 | Args: 196 | param_group (dict): Specifies what Tensors should be optimized along 197 | with group specific optimization options. 
198 | """ 199 | super().add_param_group(param_group) 200 | self.initialize_state() 201 | 202 | def reset_param_groups(self): 203 | self.param_groups = [] 204 | self.state = defaultdict(dict) 205 | -------------------------------------------------------------------------------- /preprocess/process_wiki.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | import json 8 | import time 9 | import math 10 | import argparse 11 | import numpy as np 12 | 13 | from tqdm import tqdm 14 | from collections import defaultdict, Counter 15 | from functools import partial 16 | 17 | from utils import create_blocks_from_plain_text 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--in_path", 22 | type=str, 23 | default="/checkpoint/sewonmin/data/FEVER/kilt_knowledgesource.json") 24 | parser.add_argument("--out_dir", 25 | type=str, 26 | default="train_corpus/enwiki/") 27 | 28 | parser.add_argument("--batch_size", type=int, default=16) 29 | parser.add_argument("--max_seq_length", type=int, default=256) 30 | parser.add_argument("--remove_list", action="store_true") 31 | parser.add_argument("--num_shards", type=int, default=10) 32 | 33 | parser.add_argument("--shard_data", action="store_true") 34 | parser.add_argument("--save_flatten_data", action="store_true") 35 | parser.add_argument("--save_nested_data", action="store_true") 36 | 37 | args = parser.parse_args() 38 | 39 | assert args.save_flatten_data or args.save_nested_data 40 | 41 | if not os.path.isdir(args.out_dir): 42 | os.makedirs(args.out_dir) 43 | args.shard_file = os.path.join(args.out_dir, "text_shard{}.jsonl") 44 | args.flatten_file = os.path.join(args.out_dir, "flatten_shard{}.jsonl") 45 | args.nested_file = os.path.join(args.out_dir, "BS%d_shard{}.jsonl" % args.batch_size) 46 | args.tmp_file = "temp.jsonl" 47 | 48 | # shard Wikipedia 49 | if args.shard_data: 50 | shard_wiki(args) 51 | else: 52 | assert np.all([os.path.exists(args.shard_file.format(shard_idx)) 53 | for shard_idx in range(args.num_shards)]), \ 54 | "You must shard the data first. If you haven't sharded yet, specify `--shard_data`" 55 | 56 | # tokenize each shard in parallel 57 | 58 | from multiprocessing import Pool 59 | article_count, flatten_count, shard_count = 0, 0, 0 60 | with Pool(args.num_shards) as p: 61 | print ("Start tokenizing...") 62 | for _article_count, _flatten_count, _shard_count in p.imap( 63 | partial(save_each_shard, args=args), range(args.num_shards)): 64 | article_count += _article_count 65 | flatten_count += _flatten_count 66 | shard_count += _shard_count 67 | print ("Done with saving! 
article_count=%d, flatten_count=%d, shard_count=%d" % ( 68 | article_count, flatten_count, shard_count 69 | )) 70 | 71 | if os.path.exists(args.tmp_file): 72 | os.remove(args.tmp_file) 73 | 74 | def shard_wiki(args): 75 | np.random.seed(2022) 76 | 77 | all_lines = [] 78 | cnt = 0 79 | print ("Starting sharding...") 80 | 81 | start_time = time.time() 82 | with open(args.in_path, "rb") as f: 83 | for line in f: 84 | dp = json.loads(line) 85 | 86 | if "(disambiguation)" in dp["wikipedia_title"]: 87 | continue 88 | 89 | if dp["wikipedia_title"].startswith("List of"): 90 | continue 91 | 92 | if len(dp["text"])<=1: 93 | continue 94 | 95 | if len(dp["wikipedia_title"].strip()) <= 1: 96 | continue 97 | 98 | text0 = dp["text"][0].strip() 99 | text1 = dp["text"][1].strip() 100 | if text1.startswith(text0) and ( 101 | text1.endswith("refers to:") or text1.endswith("refer to:") 102 | ): 103 | continue 104 | 105 | sentences = [""] # section 106 | sentences_text = [False] # if the sentence contain text 107 | for sent in dp["text"]: 108 | is_plain_text = True 109 | 110 | if sent.startswith("Section::::"): 111 | sent = sent.split("::::")[-1] 112 | if len(sentences[-1])>0: 113 | sentences.append("") 114 | sentences_text.append(False) 115 | is_plain_text = False 116 | 117 | if sent.startswith("External links"): 118 | break 119 | 120 | if sent.startswith("BULLET::::"): 121 | if args.remove_list: 122 | continue 123 | 124 | sent = sent.split("::::")[-1] 125 | is_plain_text = False 126 | 127 | sent = sent.strip() 128 | if len(sent)==0: 129 | continue 130 | 131 | sentences[-1] += " " + sent.strip() 132 | 133 | if is_plain_text: 134 | sentences_text[-1] = True 135 | 136 | assert len(sentences)==len(sentences_text) 137 | 138 | if args.remove_list: 139 | new_sentences = [] 140 | for sentence, has_text in zip(sentences, sentences_text): 141 | if has_text and len(sentence.split())>=5: 142 | new_sentences.append(sentence.strip()) 143 | 144 | if np.sum([len(sentence.split()) for sentence in sentences]) < 50: 145 | continue 146 | 147 | sentences = new_sentences 148 | 149 | if len(sentences)==0: 150 | continue 151 | 152 | all_lines.append(json.dumps({"title": dp["wikipedia_title"], "text": sentences})) 153 | cnt += 1 154 | 155 | if len(all_lines)==1000000: 156 | save_lines(all_lines, args.shard_file, args.num_shards) 157 | print ("Sharding... Saved %.1fM lines" % (cnt / 1000000)) 158 | all_lines = [] 159 | 160 | save_lines(all_lines, args.shard_file, args.num_shards) 161 | print ("Done with sharding: Saved %.1fM lines" % (cnt / 1000000)) 162 | 163 | def save_each_shard(shard_idx, args): 164 | outputs = [] 165 | article_cnt = 0 166 | cnt = 0 167 | flatten_cnt = 0 168 | 169 | shard_file = args.shard_file.format(shard_idx) 170 | flatten_file = args.flatten_file.format(shard_idx) if args.save_flatten_data else args.tmp_file 171 | nested_file = args.nested_file.format(shard_idx) if args.save_nested_data else args.tmp_file 172 | 173 | with open(shard_file, "r") as f: 174 | with open(flatten_file, "w") as f_w_flatten: 175 | with open(nested_file, "w") as f_w: 176 | for line in f: 177 | article_cnt += 1 178 | dp = json.loads(line) 179 | title = dp["title"] 180 | sentences = dp["text"] 181 | sentences = [title + ". 
" + sent for sent in sentences] 182 | 183 | curr_outputs, _ = create_blocks_from_plain_text(sentences, doc_idx=title, max_seq_length=args.max_seq_length) 184 | if args.save_flatten_data: 185 | for o in curr_outputs: 186 | flatten_cnt += 1 187 | f_w_flatten.write(json.dumps(o)+"\n") 188 | 189 | if args.save_nested_data: 190 | outputs += curr_outputs 191 | 192 | if len(outputs) max_sequence_length - max_output_length: 62 | curr_input_ids = curr_input_ids[-(max_sequence_length - max_output_length):] 63 | curr_input_ids = torch.LongTensor([curr_input_ids]).cuda() 64 | gen_tokens = self.model.generate( 65 | curr_input_ids, 66 | max_length=curr_input_ids.shape[1]+max_output_length, 67 | ) 68 | gen = self.tokenizer.decode(gen_tokens[0, curr_input_ids.shape[-1]:]).split("\n")[0].strip() 69 | 70 | if len(generations)==0: 71 | print ("Input:", prompts[0]) 72 | print ("Prediction:", gen) 73 | 74 | generations.append(gen) 75 | 76 | assert len(generations)==len(prompts) 77 | return generations 78 | 79 | import string 80 | import re 81 | 82 | def normalize_answer(s): 83 | """Lower text and remove punctuation, articles and extra whitespace.""" 84 | def remove_articles(text): 85 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 86 | return re.sub(regex, ' ', text) 87 | def white_space_fix(text): 88 | return ' '.join(text.split()) 89 | def remove_punc(text): 90 | exclude = set(string.punctuation) 91 | return ''.join(ch for ch in text if ch not in exclude) 92 | def lower(text): 93 | return text.lower() 94 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 95 | 96 | def load_inputs(args): 97 | ret = args.ret 98 | assert ret is None or ret in ["bm25", "bm25_2022"] 99 | 100 | task = Task(args.eval_dataset, "data", n_samples=3000) #args.n_samples) 101 | 102 | if args.ret: 103 | from npm.searcher import BM25Searcher 104 | base_dir = "corpus" 105 | name = "new-enwiki" if ret=="bm25_2022" else "enwiki" 106 | data_dir = os.path.join(base_dir, name) 107 | index_dir = os.path.join(base_dir, name + "-index") 108 | searcher = BM25Searcher(data_dir, index_dir) 109 | restricted, restricted_dict = searcher.batch_search(task) 110 | 111 | text_dict = {} 112 | offset = 0 113 | for i in range(20): 114 | with open(os.path.join(data_dir, "{}.jsonl".format(i)), "r") as f: 115 | for line in f: 116 | if offset in restricted: 117 | text_dict[offset] = json.loads(line) 118 | offset += 1 119 | assert len(text_dict)==len(restricted) 120 | 121 | 122 | inputs = [] 123 | outputs = [] 124 | ngrams = [] 125 | 126 | for ex in task.examples: 127 | # use a slightly different template that is better for CLMs 128 | if args.eval_dataset in ["nq", "triviaqa", "kamel"]: 129 | question = ex["input"].split("The answer is: ")[0].strip() 130 | input_text = "Question: {}\nAnswer:".format(question) 131 | elif args.eval_dataset=="entity_translation": 132 | input_text = ex["input"].split("")[0].strip() 133 | elif args.eval_dataset.startswith("lama-"): 134 | input_text = ex["input"].replace("", "_____").strip() + " Fill in the blank. 
Answer:" 135 | else: 136 | raise NotImplementedError() 137 | 138 | if ret: 139 | for p in restricted_dict[ex["input"]]: 140 | p = text_dict[p] 141 | input_text = p["contents"].replace("", "").replace("", "").strip() + "\n" + input_text 142 | 143 | inputs.append(input_text) 144 | outputs.append(ex["answers"]) 145 | 146 | print ("Input:", inputs[0]) 147 | print ("Output:", outputs[0]) 148 | 149 | return task, inputs, outputs 150 | 151 | 152 | def calc_accuracy(task, outputs, predictions): 153 | 154 | def postprocess_prediction(text): 155 | if str(task)=="entity_translation": 156 | # this seems to be helpful in entity translation 157 | for i in range(len(text)): 158 | if text[i] in string.punctuation: 159 | text = text[:i] 160 | break 161 | return text 162 | 163 | references = [[normalize_answer(answer) for answer in output] for output in outputs] 164 | predictions = [normalize_answer(postprocess_prediction(p)) for p in predictions] 165 | accs = [prediction in reference for prediction, reference in zip(predictions, references)] 166 | 167 | if task.ngrams is not None: 168 | accs_dict = defaultdict(list) 169 | for acc, ngram in zip(accs, task.ngrams): 170 | accs_dict[ngram].append(acc) 171 | acc = np.mean([np.mean(v) for k, v in accs_dict.items()]) 172 | print ("\tMacro EM=%.1f%%" % (100*acc)) 173 | else: 174 | acc = np.mean(accs) 175 | print ("\tEM=%.1f%%" % (100*acc)) 176 | 177 | 178 | def main(): 179 | import argparse 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument("--model_name", 182 | type=str, 183 | default="neo-1.3B") 184 | parser.add_argument("--eval_dataset", 185 | type=str, 186 | default=None) 187 | parser.add_argument("--save_dir", 188 | type=str, 189 | default="save") 190 | parser.add_argument("--ret", type=str, default=None) 191 | parser.add_argument("--max_sequence_length", type=int, default=1024) 192 | parser.add_argument("--eval_only", action="store_true") 193 | args = parser.parse_args() 194 | 195 | start_time = time.time() 196 | task, inputs, outputs = load_inputs(args) 197 | print ("Took %dsec to load the data" % (time.time()-start_time)) 198 | 199 | prediction_path = os.path.join(args.save_dir, 200 | args.model_name, 201 | "{}{}.jsonl".format(args.eval_dataset, 202 | "" if args.ret is None else "_" + args.ret) 203 | ) 204 | 205 | if not args.eval_only: 206 | model = Model(args.model_name) 207 | start_time = time.time() 208 | predictions = model.generate(inputs, max_sequence_length=args.max_sequence_length) 209 | print ("Took %dsec to generate" % (time.time()-start_time)) 210 | 211 | if not os.path.exists(os.path.join(args.save_dir, args.model_name)): 212 | os.makedirs(os.path.join(args.save_dir, args.model_name)) 213 | 214 | with open(prediction_path, "w") as f: 215 | for prediction in predictions: 216 | f.write(json.dumps({"prediction": prediction}) + "\n") 217 | else: 218 | assert os.path.exists(prediction_path) 219 | with open(prediction_path, "r") as f: 220 | predictions = [] 221 | for line in f: 222 | predictions.append(json.loads(line)["prediction"]) 223 | 224 | calc_accuracy(task, outputs, predictions) 225 | 226 | 227 | if __name__=='__main__': 228 | main() 229 | -------------------------------------------------------------------------------- /dpr_scale/transforms/lm_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | #!/usr/bin/env python3 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | import random 9 | import numpy as np 10 | import hydra 11 | import torch 12 | import math 13 | import torch.nn as nn 14 | import json 15 | 16 | class LanguageModelingTransform(nn.Module): 17 | def __init__( 18 | self, 19 | bidirectional=False, 20 | masking=None, 21 | masking_ratio=0.0, 22 | preprocessed_tokenizer_type="roberta", 23 | exactly_follow_roberta=False, 24 | enforce_masking_positives=False 25 | ): 26 | super().__init__() 27 | self.bidirectional = bidirectional 28 | self.masking = masking 29 | self.masking_ratio = masking_ratio 30 | self.preprocessed_tokenizer_type = preprocessed_tokenizer_type 31 | assert self.preprocessed_tokenizer_type in ["gpt", "roberta"] 32 | 33 | self.enforce_masking_positives = enforce_masking_positives 34 | if enforce_masking_positives: 35 | assert masking and masking_ratio > 0 36 | 37 | # if True, 80% mask, 10% original, 10% replaced 38 | self.exactly_follow_roberta = exactly_follow_roberta 39 | 40 | if self.bidirectional: 41 | from transformers import RobertaTokenizer 42 | roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-large") 43 | 44 | self.tokenizer = roberta_tokenizer 45 | self.mask_id = roberta_tokenizer.mask_token_id 46 | self.bos_id = roberta_tokenizer.bos_token_id 47 | self.eos_id = roberta_tokenizer.eos_token_id 48 | self.pad_id = roberta_tokenizer.pad_token_id 49 | 50 | if masking is not None: 51 | assert masking in ["uniform", "span", "span-merge", "span-merge-two"] 52 | assert masking_ratio > 0 53 | else: 54 | assert masking_ratio == 0 55 | 56 | else: 57 | assert self.preprocessed_tokenizer_type=="gpt" 58 | assert masking is None and masking_ratio==0.0 59 | 60 | def get_mask_id(self, token): 61 | if self.exactly_follow_roberta: 62 | p = np.random.random() 63 | if p<0.1: 64 | return token 65 | if p<0.2: 66 | return np.random.choice(range(3, len(self.tokenizer)-1)) 67 | return self.mask_id 68 | 69 | def transform_to_roberta(self, tokens, attention_mask, 70 | add_bos_id=False, 71 | add_eos_id=False): 72 | 73 | assert type(tokens)==list and type(attention_mask)==list 74 | assert len(tokens)==len(attention_mask) 75 | new_input_ids, new_attention_mask = [], [] 76 | masked_input_ids = [] # input_ids with mask 77 | label_mask = [] # indicator on which tokens are those to be predicted 78 | 79 | if type(tokens[0])==list and type(attention_mask[0])==list: 80 | for token, mask in zip(tokens, attention_mask): 81 | _input_ids, _attention_mask, _masked_input_ids, _label_mask = self.transform_to_roberta(token, mask) 82 | new_input_ids.append(_input_ids) 83 | new_attention_mask.append(_attention_mask) 84 | masked_input_ids.append(_masked_input_ids) 85 | label_mask.append(_label_mask) 86 | 87 | elif self.preprocessed_tokenizer_type=="roberta": 88 | new_input_ids = tokens 89 | new_attention_mask = attention_mask 90 | 91 | # np.random.seed(2022) # for debugging 92 | 93 | for i, (token, mask) in enumerate(zip(tokens, attention_mask)): 94 | assert type(token)==type(mask)==int 95 | if mask and token not in [self.bos_id, self.eos_id] and \ 96 | self.masking=="uniform" and np.random.random()0: 148 | continue 149 | if curr_input_id>0 or np.sum(curr_attention_mask[j+1:])>0: 150 | masked_attention_mask[0][i][j] = 1 151 | 152 | if 0 in masked_attention_mask[0][i]: 153 | L = masked_attention_mask[0][i].index(0) 154 | 
assert np.all(masked_attention_mask[0][i][:L]) 155 | assert not np.any(masked_attention_mask[0][i][L:]) 156 | assert not np.any(masked_input_ids[0][i][L:]) 157 | else: 158 | assert np.all(masked_attention_mask[0][i]) 159 | 160 | assert len(masked_input_ids)==len(masked_attention_mask)==1 161 | masked_input_ids = masked_input_ids[0] 162 | masked_attention_mask = masked_attention_mask[0] 163 | 164 | _masked_input_ids, _masked_attention_mask = [], [] 165 | L = len(masked_input_ids[0]) 166 | for curr_input_ids, curr_attention_mask in zip(masked_input_ids, masked_attention_mask): 167 | offset = 0 168 | mask_cnt = curr_input_ids.count(self.mask_id) 169 | while offset < len(curr_input_ids) and self.mask_id in curr_input_ids[offset:]: 170 | idx = offset + curr_input_ids[offset:].index(self.mask_id) 171 | curr_input_ids = curr_input_ids[:idx] + [self.mask_id, self.mask_id] + curr_input_ids[idx+1:] 172 | offset = idx+2 173 | curr_attention_mask = [1] * mask_cnt + curr_attention_mask 174 | assert mask_cnt*2==curr_input_ids.count(self.mask_id) 175 | assert len(curr_input_ids)==len(curr_attention_mask)==L+mask_cnt 176 | 177 | if 0 in curr_attention_mask: 178 | l = curr_attention_mask.index(0) 179 | curr_input_ids = curr_input_ids[:l] 180 | curr_attention_mask = curr_attention_mask[:l] 181 | 182 | _masked_input_ids.append(curr_input_ids) 183 | _masked_attention_mask.append(curr_attention_mask) 184 | 185 | L = max( 186 | np.max([len(_input_ids) for _input_ids in _masked_input_ids]), 187 | np.max([np.sum(mask) for mask in attention_mask[0]])) 188 | assert L<=320 #384 189 | L = 320 #384 190 | 191 | input_ids = [ 192 | _input_ids[:min(L, np.sum(_attention_mask))] + [0]*max(0, L-np.sum(_attention_mask)) 193 | for _input_ids, _attention_mask in zip(input_ids[0], attention_mask[0]) 194 | ] 195 | attention_mask = [ 196 | _attention_mask[:min(L, np.sum(_attention_mask))] + [0]*max(0, L-np.sum(_attention_mask)) 197 | for _attention_mask in attention_mask[0] 198 | ] 199 | masked_input_ids = [ 200 | _input_ids[:L] + [0]*max(0, L-len(_input_ids)) 201 | for _input_ids in _masked_input_ids 202 | ] 203 | masked_attention_mask = [ 204 | _attention_mask[:L] + [0]*max(0, L-len(_attention_mask)) 205 | for _attention_mask in _masked_attention_mask 206 | ] 207 | 208 | input_ids = torch.LongTensor(input_ids) 209 | attention_mask = torch.LongTensor(attention_mask) 210 | masked_input_ids = torch.LongTensor(masked_input_ids) 211 | masked_attention_mask = torch.LongTensor(masked_attention_mask) 212 | labels = torch.LongTensor(flatten_labels) 213 | label_mask = masked_input_ids==self.mask_id 214 | 215 | assert len(flatten_labels)*2==torch.sum(label_mask) 216 | 217 | return { 218 | "input_ids": input_ids, 219 | "attention_mask": attention_mask, 220 | "masked_input_ids": masked_input_ids, 221 | "masked_attention_mask": masked_attention_mask, 222 | "labels": labels, 223 | "label_mask": label_mask 224 | } 225 | 226 | # span masking with each token in the span is represented with a mask 227 | elif self.masking=="span": 228 | return { 229 | "input_ids": torch.LongTensor(input_ids), 230 | "masked_input_ids": torch.LongTensor(masked_input_ids), 231 | "attention_mask": torch.LongTensor(attention_mask), 232 | } 233 | 234 | elif self.bidirectional: 235 | if self.enforce_masking_positives: 236 | assert self.masking=="uniform" 237 | 238 | input_ids = torch.LongTensor(input_ids) 239 | attention_mask = torch.LongTensor(attention_mask) 240 | assert input_ids.shape[0]==1 241 | BS = input_ids.shape[1] 242 | length = input_ids.shape[2] 243 | labels 
= torch.logical_and( 244 | input_ids.reshape(-1).unsqueeze(-1)==input_ids.reshape(-1).unsqueeze(0), 245 | attention_mask.reshape(-1).unsqueeze(-1)*attention_mask.reshape(-1).unsqueeze(0) 246 | ) 247 | maskout = torch.block_diag(*torch.ones((BS, length, length), dtype=torch.bool)) 248 | labels = torch.logical_and(labels, ~maskout) 249 | 250 | labels = torch.any(labels, -1).reshape(1, BS, length) 251 | labels = torch.logical_and( 252 | labels, 253 | torch.logical_and( 254 | input_ids!=self.bos_id, input_ids!=self.eos_id)) 255 | 256 | masked_input_ids = input_ids.clone() 257 | label_mask = torch.zeros_like(masked_input_ids) 258 | masking_ratio = torch.sum(attention_mask) * self.masking_ratio / torch.sum(labels) 259 | 260 | for i in range(BS): 261 | for j in range(length): 262 | if labels[0, i, j] and np.random.random() < masking_ratio: 263 | masked_input_ids[0, i, j] = self.mask_id 264 | label_mask[0, i, j] = 1 265 | 266 | else: 267 | input_ids, attention_mask, masked_input_ids, label_mask = self.transform_to_roberta(input_ids, attention_mask) 268 | input_ids = torch.LongTensor(input_ids) 269 | masked_input_ids = torch.LongTensor(masked_input_ids) 270 | attention_mask = torch.LongTensor(attention_mask) 271 | label_mask = torch.LongTensor(label_mask) 272 | # print (torch.mean(label_mask.float())) 273 | 274 | return { 275 | "input_ids": input_ids, 276 | "masked_input_ids": masked_input_ids, 277 | "attention_mask": attention_mask, 278 | "label_mask": label_mask, 279 | } 280 | 281 | else: 282 | return { 283 | "input_ids": torch.LongTensor(input_ids), 284 | "attention_mask": torch.LongTensor(attention_mask), 285 | } 286 | 287 | -------------------------------------------------------------------------------- /npm/npm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
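# Inference for the phrase-retrieval variant of NPM. NPM extends NPMSingle by
# encoding the masked position into separate start and end queries; candidate
# spans of up to ngram_max tokens are formed from the retrieved start/end token
# positions (respecting string boundaries when enabled), and each span is
# scored by the sum of its softmax-normalized start and end similarities,
# optionally reranked by a length bias alpha (see predict_span).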
6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | import json 10 | import math 11 | import time 12 | import numpy as np 13 | import os 14 | import pickle as pkl 15 | import datetime 16 | import re 17 | import string 18 | 19 | from collections import defaultdict, Counter 20 | from scipy.special import softmax, log_softmax 21 | 22 | from npm.npm_single import NPMSingle 23 | from task.utils_eval import normalize_answer 24 | 25 | try: 26 | from termcolor import colored 27 | except: 28 | pass 29 | import torch.nn.functional as F 30 | 31 | class NPM(NPMSingle): 32 | def get_all_scores(self, queries): 33 | start_query, end_query = queries 34 | start_scores, start_indices, start_knn_ids = super().get_all_scores(start_query) 35 | end_scores, end_indices, end_knn_ids = super().get_all_scores(end_query) 36 | 37 | all_indices = np.concatenate([start_indices, end_indices], -1) 38 | knn_ids = [start_knn_ids[0] + end_knn_ids[0]] 39 | all_scores = np.concatenate([start_scores, end_scores], -1) 40 | all_scores /= self.temperature 41 | assert len(knn_ids)==len(all_scores)==1 and len(knn_ids[0])==len(all_scores[0]) 42 | 43 | return all_scores, all_indices, knn_ids 44 | 45 | def predict_span(self, query_text, ngram_max, valid_func=None, 46 | alphas=[0.0], is_question=False, return_metadata=False): 47 | 48 | # first, obtain query emb 49 | inputs = self.model.tokenizer(query_text) 50 | input_ids = inputs["input_ids"] 51 | assert self.model.tokenizer.mask_token_id in input_ids 52 | idx = input_ids.index(self.model.tokenizer.mask_token_id) 53 | with torch.no_grad(): 54 | input_tensor = torch.LongTensor([input_ids]).cuda() 55 | _, (start_query_tensor, end_query_tensor) = self.model.forward(input_tensor, idx) 56 | start_query = start_query_tensor.detach().cpu().numpy() 57 | end_query = end_query_tensor.detach().cpu().numpy() 58 | 59 | pos2ngram = {} 60 | predictions = {} 61 | 62 | # this is a utility function that finds all possible spans 63 | # composed with the top k start indices and end indices 64 | def get_candidates(start_indices, end_indices): 65 | consider_string_boundary = self.dstore.consider_string_boundary 66 | 67 | start_triples = self.dstore._get_token_position(start_indices.tolist(), 68 | ngram_after=ngram_max) 69 | end_triples = self.dstore._get_token_position(end_indices.tolist(), 70 | ngram_before=ngram_max) 71 | 72 | all_start_indices = set() 73 | all_end_indices = set() 74 | all_start_and_end = set() 75 | 76 | for (block_idx, token_indices, vocabs), start_token_idx in zip(start_triples[0], start_indices[0]): 77 | 78 | if consider_string_boundary and token_indices[0] not in self.dstore.orig_block_idx_to_valid_start[block_idx]: 79 | continue 80 | all_start_indices.add(start_token_idx) 81 | end_token_idx = start_token_idx 82 | 83 | for j in range(len(token_indices)): 84 | 85 | is_valid_start = token_indices[j] in self.dstore.orig_block_idx_to_valid_start[block_idx] 86 | is_valid_end = token_indices[j] in self.dstore.orig_block_idx_to_valid_end[block_idx] 87 | 88 | if self.dstore.embs_consider_boundary and not (is_valid_start or is_valid_end): 89 | continue 90 | 91 | if (not consider_string_boundary) or is_valid_end: 92 | ngram = vocabs[:j+1] 93 | ngram_pos = (start_token_idx, end_token_idx) 94 | # ngram_pos = (block_idx, token_indices[0], token_indices[0]+j) 95 | # assert len(ngram)==ngram_pos[1][1]-ngram_pos[1][0]+1 96 | if valid_func is None or valid_func(ngram): 97 | if ngram_pos in pos2ngram: 98 | assert pos2ngram[ngram_pos]==ngram 99 | else: 100 | 
pos2ngram[ngram_pos] = ngram 101 | all_end_indices.add(end_token_idx) 102 | all_start_and_end.add(ngram_pos) 103 | 104 | end_token_idx += 1 105 | 106 | for (block_idx, token_indices, vocabs), end_token_idx in zip(end_triples[0], end_indices[0]): 107 | 108 | if consider_string_boundary and token_indices[-1] not in self.dstore.orig_block_idx_to_valid_end[block_idx]: 109 | continue 110 | all_end_indices.add(end_token_idx) 111 | start_token_idx = end_token_idx 112 | 113 | for j in range(len(token_indices)): 114 | 115 | is_valid_start = token_indices[-j-1] in self.dstore.orig_block_idx_to_valid_start[block_idx] 116 | is_valid_end = token_indices[-j-1] in self.dstore.orig_block_idx_to_valid_end[block_idx] 117 | 118 | if self.dstore.embs_consider_boundary and not (is_valid_start or is_valid_end): 119 | continue 120 | 121 | if (not consider_string_boundary) or is_valid_start: 122 | ngram = vocabs[-j-1:] 123 | ngram_pos = (start_token_idx, end_token_idx) 124 | # ngram_pos = (block_idx, token_indices[-1]-j, token_indices[-1]) 125 | # assert len(ngram)==ngram_pos[1][1]-ngram_pos[1][0]+1 126 | if valid_func is None or valid_func(ngram): 127 | if ngram_pos in pos2ngram: 128 | assert pos2ngram[ngram_pos]==ngram 129 | else: 130 | pos2ngram[ngram_pos] = ngram 131 | all_start_indices.add(start_token_idx) 132 | all_start_and_end.add(ngram_pos) 133 | 134 | start_token_idx -= 1 135 | 136 | return all_start_indices, all_end_indices, all_start_and_end 137 | 138 | def get_scores(start_indices, end_indices): 139 | x = self.dstore.get_embs(start_indices) 140 | x = torch.from_numpy(x).cuda() 141 | start_scores = self.get_scores(start_query_tensor, x)[0] 142 | start_scores = start_scores.detach().cpu().numpy() 143 | 144 | x = self.dstore.get_embs(end_indices) 145 | x = torch.from_numpy(x).cuda() 146 | end_scores = self.get_scores(end_query_tensor, x)[0] 147 | end_scores = end_scores.detach().cpu().numpy() 148 | 149 | return start_scores, end_scores 150 | 151 | # main code starts from here 152 | if self.dstore.restricted: 153 | # find passaages to restricted 154 | if query_text in self.dstore.restricted_dict: 155 | block_ids = self.dstore.restricted_dict[query_text] 156 | else: 157 | block_ids = self.dstore.searcher.search(query_text, is_question=is_question) 158 | self.dstore.restricted_dict[query_text] = block_ids 159 | 160 | valid_idxs = [] 161 | for block_id in block_ids: 162 | start, end = self.dstore.orig_block_idx_to_emb_token_idx[block_id:block_id+2] 163 | valid_idxs += list(range(start, end)) 164 | start_indices = np.array([valid_idxs]) 165 | end_indices = np.array([valid_idxs]) 166 | 167 | else: 168 | _, start_indices = self.dstore.search(start_query, k=self.k) 169 | _, end_indices = self.dstore.search(end_query, k=self.k) 170 | 171 | if start_indices.shape[1]==end_indices.shape[1]==0: 172 | for alpha in alphas: 173 | predictions["a={}".format(alpha)] = None 174 | return predictions 175 | 176 | if self.dstore.restricted: 177 | start_scores, end_scores = get_scores(start_indices, end_indices) 178 | _, _, all_start_and_end = get_candidates(start_indices, end_indices) 179 | 180 | all_start_indices = start_indices[0].tolist() 181 | all_end_indices = end_indices[0].tolist() 182 | all_start_scores = start_scores 183 | all_end_scores = end_scores 184 | 185 | else: 186 | all_start_indices, all_end_indices, all_start_and_end = get_candidates(start_indices, end_indices) 187 | 188 | all_start_indices = sorted(all_start_indices) 189 | all_end_indices = sorted(all_end_indices) 190 | 191 | all_start_scores, 
all_end_scores = get_scores(all_start_indices, all_end_indices) 192 | 193 | all_start_scores = softmax(all_start_scores / self.temperature, -1) 194 | all_end_scores = softmax(all_end_scores / self.temperature, -1) 195 | 196 | idx2start_score = {start_token_idx: score for start_token_idx, score 197 | in zip(all_start_indices, all_start_scores)} 198 | idx2end_score = {end_token_idx: score for end_token_idx, score 199 | in zip(all_end_indices, all_end_scores)} 200 | 201 | pos2score = {} 202 | ngram2score = defaultdict(list) 203 | 204 | # now, assign scores to possible ngrams 205 | for (start, end) in all_start_and_end: 206 | assert start in idx2start_score 207 | assert end in idx2end_score 208 | score = idx2start_score[start] + idx2end_score[end] 209 | 210 | pos2score[(start, end)] = score 211 | ngram2score[tuple(pos2ngram[(start, end)])].append(score) 212 | 213 | if len(pos2score)==len(ngram2score)==0: 214 | for alpha in alphas: 215 | predictions["a={}".format(alpha)] = None 216 | return predictions 217 | 218 | assert len(pos2score)>0 and len(ngram2score)>0 219 | 220 | for alpha in alphas: 221 | def key_func(x, alpha=alpha): 222 | return -np.sum(x[1]) * np.power(len(x[0]), alpha) 223 | 224 | top1_ngram_score_pair = min(ngram2score.items(), key=key_func) 225 | top1_ngram = list(top1_ngram_score_pair[0]) 226 | 227 | predictions["a={}".format(alpha)] = top1_ngram 228 | 229 | if return_metadata: 230 | metadata = {"input": query_text} 231 | if self.dstore.restricted: 232 | metadata["blocks"] = [self.decode(self.dstore.input_ids[block_id]) for block_id in block_ids] 233 | 234 | metadata["pos2score"] = pos2score 235 | metadata["pos2ngram"] = pos2ngram 236 | metadata["ngram2score"] = ngram2score 237 | 238 | predicted_ngram = predictions["a=0.0"] 239 | metadata["predicted"] = self.decode(predicted_ngram) 240 | predicted_spans = [] 241 | for pos, ngram in pos2ngram.items(): 242 | if ngram==predicted_ngram: 243 | 244 | block_id_s = self.dstore.token_idx_to_block_idx[pos[0]] 245 | local_id_s = self.dstore.token_idx_to_local_idx[pos[0]] 246 | block_id_e = self.dstore.token_idx_to_block_idx[pos[1]] 247 | local_id_e = self.dstore.token_idx_to_local_idx[pos[1]] 248 | assert block_id_s==block_id_e 249 | 250 | input_ids = self.dstore.input_ids[block_id_s] 251 | decoded = self.decode(input_ids[:local_id_s]) + \ 252 | colored(self.decode(input_ids[local_id_s:local_id_e+1]), "red") + \ 253 | self.decode(input_ids[local_id_e+1:]) 254 | 255 | predicted_spans.append((decoded, pos2score[pos])) 256 | 257 | metadata["predicted_spans"] = sorted(predicted_spans, key=lambda x: -x[1]) 258 | return predictions, metadata 259 | 260 | return predictions 261 | 262 | def get_query(self, input_text): 263 | inputs = self.model.tokenizer(input_text) 264 | input_ids = inputs["input_ids"] 265 | assert self.model.tokenizer.mask_token_id in input_ids 266 | idx = input_ids.index(self.model.tokenizer.mask_token_id) 267 | with torch.no_grad(): 268 | input_tensor = torch.LongTensor([input_ids]).cuda() 269 | _, query = self.model.forward(input_tensor, idx) 270 | return query 271 | 272 | def evaluate_open(self, task): 273 | all_predictions = [] 274 | mask = self.get_stopword_mask() 275 | do_restricted = self.dstore is not None and self.dstore.restricted is not None 276 | 277 | def valid_func(tokens): 278 | return np.sum(mask[tokens])==0 279 | 280 | if "translation" in str(task): 281 | alphas = [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0] 282 | ngram_max = 20 283 | else: 284 | alphas = [0.0, 0.5, 1.0] 285 | ngram_max = 10 286 | 287 | for ex in 
tqdm(task.examples): 288 | dic = self.predict_span( 289 | ex["input"], 290 | ngram_max=ngram_max, 291 | valid_func=valid_func, 292 | alphas=alphas, 293 | is_question=task.is_question, 294 | ) 295 | dic = {k: '' if v is None else self.decode(v) for k, v in dic.items()} 296 | all_predictions.append(dic) 297 | 298 | # compute accuracy 299 | references = [[normalize_answer(answer) for answer in ex["answers"]] for ex in task.examples] 300 | for k in all_predictions[0]: 301 | predictions = [normalize_answer(p[k]) for p in all_predictions] 302 | accs = [prediction in reference for prediction, reference in zip(predictions, references)] 303 | 304 | if task.ngrams is not None: 305 | accs_dict = defaultdict(list) 306 | for acc, ngram in zip(accs, task.ngrams): 307 | accs_dict[ngram].append(acc) 308 | acc = np.mean([np.mean(v) for k, v in accs_dict.items()]) 309 | print ("\t%s\tMacro EM=%.1f%%" % (k, 100*acc)) 310 | else: 311 | acc = np.mean(accs) 312 | print ("\t%s\tEM=%.1f%%" % (k, 100*acc)) 313 | 314 | return all_predictions 315 | 316 | 317 | -------------------------------------------------------------------------------- /preprocess/utils_span.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | import json 8 | import time 9 | import string 10 | import numpy as np 11 | 12 | from scipy.linalg import block_diag 13 | from collections import Counter, defaultdict 14 | 15 | from transformers import RobertaTokenizer 16 | tokenizer = RobertaTokenizer.from_pretrained("roberta-large") 17 | mask_id = tokenizer.mask_token_id 18 | 19 | idx2word = {idx: tokenizer._decode(idx) for idx in range(len(tokenizer.encoder))} 20 | punctuation = set() 21 | subwords = set() 22 | for idx, word in idx2word.items(): 23 | if word in ["!", ",", ".", ":", "?"]: 24 | punctuation.add(idx) 25 | elif word[0]!=" " and word not in string.punctuation: 26 | subwords.add(idx) 27 | 28 | def find_mapping(input_ids, attention_mask): 29 | blocks = [] 30 | for _input_ids, _attention_mask in zip(input_ids, attention_mask): 31 | blocks.append(_input_ids[:np.sum(_attention_mask)]) 32 | 33 | datastore = defaultdict(list) 34 | mapping = defaultdict(list) 35 | N = 1 36 | 37 | def construct_datastore(length_list=range(1, N+1)): 38 | for block_idx, block in enumerate(blocks): 39 | for token_idx, token in enumerate(block): 40 | for L in length_list: 41 | key = block[token_idx:token_idx+L] 42 | if 0 in key or 2 in key: 43 | continue 44 | if len(key)==L: 45 | datastore[tuple(key)].append((block_idx, token_idx)) 46 | 47 | construct_datastore() 48 | 49 | for block_idx, block in enumerate(blocks): 50 | token_idx = 0 51 | for token_idx, token in enumerate(block): 52 | if token in [0, 2]: 53 | continue 54 | # let's remove starting from subwords 55 | if (token_idx==0 or block[token_idx-1]!=0) and token in subwords: 56 | continue 57 | 58 | if token in punctuation and np.random.random()<0.9: 59 | continue 60 | 61 | pos = (block_idx, token_idx) 62 | 63 | for L in range(1, 11): 64 | key = block[token_idx:token_idx+L] 65 | if 0 in key or 2 in key or len(key)=2 and token in punctuation: 69 | break 70 | # let's remove if the next token is the subword 71 | if token_idx+L<=len(block)-1 and block[token_idx+L] in subwords: 72 | continue 73 | 74 | assert len(key)==L 75 | key = tuple(key) 76 | if key not in 
datastore: 77 | # further construct the datastore 78 | construct_datastore(length_list=range(L, 2*L)) 79 | 80 | assert key in datastore 81 | assert pos in datastore[key] 82 | 83 | candidates = [t for t in datastore[key] if block_idx!=t[0]] 84 | if len(candidates)==0: 85 | break 86 | else: 87 | mapping[pos].append((candidates, L, key)) 88 | 89 | return mapping 90 | 91 | 92 | def mask_spans(data_path, masking_ratio=0.4, p=0.2): 93 | np.random.seed(2022) 94 | start_time = time.time() 95 | out_data_path = data_path.replace(".jsonl", "_mr{}_p{}.jsonl".format(masking_ratio, p)) 96 | 97 | tot = 0 98 | line_idx = 0 99 | start_time = time.time() 100 | 101 | with open(data_path, "r") as f: 102 | with open(out_data_path, "w") as f_w: 103 | 104 | for line in f: 105 | dp = json.loads(line) 106 | mapping = find_mapping(dp["input_ids"], dp["attention_mask"]) 107 | 108 | # this is to sample spans to mask out 109 | ngram2spans = defaultdict(list) 110 | pos2dependent_keys = defaultdict(list) 111 | 112 | for (i, j), triples in mapping.items(): 113 | for (candidates, n, key) in triples: 114 | ngram2spans[min(10, n)].append((i, j, n)) 115 | for (other_i, other_j) in candidates: 116 | for k in range(n): 117 | pos2dependent_keys[(other_i, other_j+k)].append((i, j, n)) 118 | 119 | input_ids = np.array(dp["input_ids"]) 120 | attention_mask = np.array(dp["attention_mask"]) 121 | BS, length = input_ids.shape 122 | ngram2spans = {k: np.random.permutation(v).tolist() for k, v in ngram2spans.items()} 123 | 124 | masked_input_ids = input_ids.copy() 125 | mask_budget = int(np.sum(attention_mask)*masking_ratio) 126 | mask_cnts = [] 127 | 128 | finish_masking = False 129 | masked_triples = set() 130 | masked_ngram_counter = Counter() 131 | n_masks_counter = Counter() 132 | mask_list = defaultdict(list) 133 | 134 | for n in np.random.geometric(p=p, size=(mask_budget, )): 135 | if n>10: 136 | continue 137 | 138 | while True: 139 | if False: 140 | if len(all_spans)==0: 141 | finish_masking = True 142 | break 143 | i, j, ngram = all_spans[-1] 144 | all_spans = all_spans[:-1] 145 | else: 146 | while n>0 and (n not in ngram2spans or len(ngram2spans[n])==0): 147 | assert n>0 148 | n -= 1 149 | if n==0: 150 | finish_masking = True 151 | break 152 | i, j, ngram = ngram2spans[n][-1] 153 | ngram2spans[n] = ngram2spans[n][:-1] 154 | 155 | # don't mask from the same sequence too many times 156 | if n_masks_counter[i] + 1 > 64: 157 | continue 158 | 159 | # tokens-to-be-masked shouldn't be already 160 | # masked out 161 | if np.sum(masked_input_ids[i, j:j+ngram]==mask_id)>0: 162 | # print ("skipping because some of them are already masked out") 163 | continue 164 | 165 | # if the same ngram has been masked out too much, then skip 166 | freq = masked_ngram_counter[tuple(masked_input_ids[i,j:j+ngram])] 167 | if freq >= 10: 168 | # print ("skipping", tokenizer.decode(masked_input_ids[i,j:j+ngram])) 169 | continue 170 | 171 | ''' 172 | # see if ids covering this position is fine 173 | candidates, n1, _ = [triple for triple in mapping[(i, j)] if triple[1]==ngram][0] 174 | assert n1==ngram 175 | covered = False 176 | for (other_i, other_j) in candidates: 177 | if np.sum(masked_input_ids[other_i, other_j:other_j+ngram]==mask_id)==0: 178 | covered = True 179 | break 180 | if not covered: 181 | continue 182 | 183 | # see if ids covered by this position is fine 184 | not_covered_found = False 185 | dependencies = set() 186 | for k in range(ngram): 187 | dependencies |= set(pos2dependent_keys[(i, j+k)]) 188 | dependencies &= masked_triples 189 | 
190 | for (other_i, other_j, other_n) in dependencies: 191 | # let's make sure there're other ngrams that cover 192 | # (other_i, other_j, other_n) 193 | other_candidates, n1, _ = \ 194 | [triple for triple in mapping[(other_i, other_j)] 195 | if triple[1]==other_n][0] 196 | assert n1==other_n 197 | 198 | covered = 0 199 | for (another_i, another_j) in other_candidates: 200 | if another_i==i and \ 201 | len(set(range(another_j, another_j+n1)) & set(range(j, j+ngram))) > 0: 202 | pass 203 | elif np.sum(masked_input_ids[another_i, another_j:another_j+n1]==mask_id)==0: 204 | covered += 1 205 | break 206 | 207 | if covered==0: 208 | not_covered_found = True 209 | break 210 | 211 | if not_covered_found: 212 | #print ("skipping 2") 213 | continue 214 | ''' 215 | break 216 | 217 | if finish_masking: 218 | break 219 | 220 | masked_triples.add((i, j, ngram)) 221 | mask_list[i].append((j, ngram)) 222 | masked_ngram_counter[tuple(masked_input_ids[i,j:j+ngram])] += 1 223 | masked_input_ids[i,j:j+ngram] = mask_id 224 | n_masks_counter[i] += 1 225 | 226 | mask_cnts.append(ngram) 227 | # masking_ngram_list.append(ngram) 228 | if np.sum(mask_cnts) >= mask_budget: 229 | break 230 | 231 | #t2 = time.time() 232 | curr_ratio = np.sum(masked_input_ids==mask_id)/np.sum(attention_mask) 233 | 234 | if curr_ratio < 0.05: 235 | # print ("skipping because not much to mask out") 236 | continue 237 | 238 | # masking_ratio_list.append(curr_ratio) 239 | 240 | input_ids = input_ids.tolist() 241 | masked_input_ids = masked_input_ids.tolist() 242 | attention_mask = attention_mask.tolist() 243 | 244 | merged_masked_input_ids = [] 245 | merged_attention_mask = [] 246 | merged_labels = [] 247 | 248 | for i, (curr_input_ids, curr_masked_input_ids, curr_attention_mask) in enumerate(zip( 249 | input_ids, masked_input_ids, attention_mask)): 250 | 251 | curr_merged_masked_input_ids = [] 252 | curr_merged_attention_mask = [] 253 | curr_merged_labels = [] 254 | 255 | curr_mask_list = sorted(mask_list[i]) 256 | ''' 257 | offset = 0 258 | while mask_id in curr_masked_input_ids[offset:]: 259 | start_idx = offset + curr_masked_input_ids[offset:].index(mask_id) 260 | end_idx = start_idx 261 | while end_idx+1=2 and token.startswith("'")) or \ 119 | (len(token)>=3 and token=="n't"): 120 | pass 121 | elif text.endswith("www."): 122 | pass 123 | elif text.endswith(".") and token=="com": 124 | pass 125 | elif token in ["(", "["]: 126 | token = " " + token 127 | no_space_next_token = True 128 | else: 129 | token = " " + token 130 | text += token 131 | return text.replace("!.", "!").replace("?.", "?") 132 | 133 | def load_sst2(data_file): 134 | data = [] 135 | with open(data_file) as f: 136 | for line in f: 137 | l, s = line.strip().split('\t') 138 | label = int(l[-1])-3 139 | if label == 0: 140 | continue 141 | label = 1 if label > 0 else 0 142 | data.append({"input": s, "label": label}) 143 | return data 144 | 145 | def load_agn(data_file): 146 | topics = [' politics', ' sports', ' business', ' technology'] 147 | examples = [] 148 | with open(data_file) as fp: 149 | reader = csv.DictReader(fp) 150 | for row in reader: 151 | label = int(row['Class Index'])-1 152 | title = row['Title'] 153 | summary = row['Description'] 154 | input_text = f"{title} \n {summary} {topic_classification_premise}" 155 | examples.append({'label' : label, 'label_list': topics, 'input': input_text}) 156 | return examples 157 | 158 | def load_rt(): 159 | data = load_dataset("rotten_tomatoes", split="test") 160 | examples = [] 161 | for dp in data: 162 | 
        examples.append({"input": dp["text"], "label": dp["label"]})
163 |     return examples
164 | 
165 | def load_yahoo():
166 |     label_list = [" society", " science", " health", " education", " computer", " sports", " business", " entertainment", " family", " politics"]
167 |     data = load_dataset("yahoo_answers_topics", split="test")
168 | 
169 |     prompt = " " + topic_classification_premise
170 |     icl_str = ""
171 | 
172 |     examples = []
173 |     for row in data:
174 |         label = row["topic"]
175 |         title = row['question_title']
176 |         summary = row['question_content']
177 |         answer = row['best_answer']
178 |         input_text = f"title: {title} content: {summary} answer: {answer}{prompt}"
179 |         examples.append({'input': input_text, 'label' : label, 'label_list': label_list})
180 |     return examples
181 | 
182 | def load_amazon():
183 |     data = load_dataset("amazon_polarity", split="test")
184 | 
185 |     def preprocess(title, content):
186 |         if not any([title.endswith(p) for p in string.punctuation]):
187 |             title = title + "."
188 |         return title + " " + content
189 | 
190 |     examples = []
191 |     for dp in data:
192 |         examples.append({"input": preprocess(dp["title"], dp["content"]),
193 |                          "label": dp["label"]})
194 |     return examples
195 | 
196 | def load_rte():
197 |     data = load_dataset("glue", "rte", split="validation") #, verbose=False)
198 |     punctuations = set(string.punctuation)
199 |     label_list = [" Yes", " No"]
200 | 
201 |     examples = []
202 |     for i, dp in enumerate(data):
203 |         pre = dp["sentence1"]
204 |         hyp = dp["sentence2"]
205 |         label = dp["label"]
206 |         while pre[-1] in punctuations:
207 |             pre = pre[:-1]
208 |         while hyp[-1] in punctuations:
209 |             hyp = hyp[:-1]
210 |         input_text = pre + ", right?<mask>, " + hyp[0].lower() + hyp[1:] + "."
211 |         examples.append({"input": input_text, "label": label, "label_list": label_list})
212 |     return examples
213 | 
214 | def template_question(question):
215 |     question = question.strip()
216 |     if question.startswith('"') and question.endswith('"'):
217 |         question = question[1:-1].strip()
218 |     if not question.endswith("?"):
219 |         question = question + "?"
220 |     return question + " The answer is: <mask>."
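# Hedged illustration (not from the original file): assuming the "<mask>"
# placeholder reconstructed above, template_question turns a raw question into
# a RoBERTa-style cloze prompt, e.g.
#
#     >>> template_question('"who wrote the tempest"')
#     'who wrote the tempest? The answer is: <mask>.'
#
# The closed-book QA loaders below (load_triviaqa, load_nq) build their inputs
# with exactly this template.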
221 | 
222 | def load_triviaqa():
223 |     orig_data = load_dataset("trivia_qa", "unfiltered.nocontext", split="validation")
224 |     data = []
225 |     for dp in orig_data:
226 |         answers = [dp["answer"]["value"]]
227 |         for alias in dp["answer"]["aliases"] + dp["answer"]["normalized_aliases"]:
228 |             answers.append(alias)
229 |         data.append({"input": template_question(dp["question"]),
230 |                      "answers": answers})
231 |     return data
232 | 
233 | def load_nq(data_dir, split="test"):
234 |     with open(os.path.join(data_dir, "nq/nqopen-{}.json".format(split)), "r") as f:
235 |         orig_data = json.load(f)
236 |     with open(os.path.join(data_dir, "nq/{}_id2answers.json".format(split)), "r") as f:
237 |         id2answers = json.load(f)
238 |     data = []
239 |     for dp in orig_data:
240 |         data.append({"input": template_question(dp["question"]),
241 |                      "answers": id2answers[dp["id"]]})
242 |     return data
243 | 
244 | def load_lama(data_dir, subset):
245 | 
246 |     data_path = os.path.join(data_dir, "lama", "{}.jsonl".format(subset))
247 | 
248 |     if os.path.exists(data_path):
249 |         data = []
250 |         with open(data_path, "r") as f:
251 |             for line in f:
252 |                 data.append(json.loads(line))
253 | 
254 |     else:
255 |         base_dir = os.path.join(data_dir, "lama", "data")
256 |         id2hf_data = _load_hf_lama(subset)
257 | 
258 |         if subset=="trex":
259 |             data = _load_lama(os.path.join(base_dir, "LAMA-TREx"), "trex", id2hf_data)
260 |             data1 = _load_lama(os.path.join(base_dir, "LAMA-TREx_UHN"), "trex", id2hf_data)
261 |             data2 = _load_lama(os.path.join(base_dir, "LAMA-TREx-easy-hard/Hard"), "trex", id2hf_data)
262 | 
263 |             ids = set([dp["id"] for dp in data])
264 |             ids1 = set([dp["id"] for dp in data1])
265 |             ids2 = set([dp["id"] for dp in data2])
266 |             assert len(ids1-ids)==len(ids2-ids)==0
267 | 
268 |             for i, dp in enumerate(data):
269 |                 data[i]["is_uhn"] = dp["id"] in ids1
270 |                 data[i]["is_hard"] = dp["id"] in ids2
271 | 
272 |         elif subset=="google_re":
273 |             data = _load_lama(os.path.join(base_dir, "Google_RE"), "google_re", id2hf_data)
274 |             data1 = _load_lama(os.path.join(base_dir, "Google_RE_UHN"), "google_re", id2hf_data)
275 |             ids = set([dp["id"] for dp in data])
276 |             ids1 = set([dp["id"] for dp in data1])
277 |             assert len(ids1-ids)==0
278 | 
279 |             for i, dp in enumerate(data):
280 |                 data[i]["is_uhn"] = dp["id"] in ids1
281 | 
282 |         # tokenized answers are needed later for computing the macro-average and for sampling the data
283 |         from transformers import RobertaTokenizer
284 |         tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
285 |         def tokenize(answer):
286 |             if type(answer)==list:
287 |                 return [tokenize(_answer) for _answer in answer]
288 |             input_ids = tokenizer(" " + answer.strip())["input_ids"]
289 |             assert input_ids[0]==0 and input_ids[-1]==2
290 |             return input_ids[1:-1]
291 | 
292 |         for i, dp in enumerate(data):
293 |             answers = dp["answers"]
294 |             data[i]["tokenized_answers"] = tokenize(answers)
295 | 
296 |         with open(data_path, "w") as f:
297 |             for dp in data:
298 |                 f.write(json.dumps(dp)+"\n")
299 | 
300 |     return data
301 | 
302 | def _load_hf_lama(name):
303 | 
304 |     ## Load the HF version first so we can use its templates (for Google RE) and double-check that the data matches
305 | 
306 |     def convert_input(dp):
307 |         text = dp["masked_sentence"]
308 |         inputs = set()
309 |         if "template" in dp:
310 |             text = dp["template"]
311 |             text = text.replace("[Y]", "<mask>").replace(" .", ".")
312 |             text = text.replace("[X]", dp["sub_label"])
313 |             inputs.add(text)
314 |         return inputs
315 | 
316 |     dataset = load_dataset("lama", name, split="train")
317 |     id2data = defaultdict(list)
318 |     for dp in tqdm(dataset):
319 |         sub = dp["sub_label"] if "sub_label" in dp else dp["sub"]
320 |         obj = dp["obj_label"]
321 |         inputs = convert_input(dp)
322 | 
323 |         if dp["masked_sentence"].count("[MASK]")!=1:
324 |             # print ("Skipping `%s` from %s" % (dp["masked_sentence"], name))
325 |             continue
326 | 
327 |         assert all([input_.count("<mask>")==1 for input_ in inputs])
328 | 
329 |         if dp["uuid"] in id2data:
330 |             old_dp = id2data[dp["uuid"]]
331 |             assert old_dp["subject"]==sub and old_dp["object"]==obj
332 |             id2data[dp["uuid"]]["inputs"] |= inputs
333 |         else:
334 |             id2data[dp["uuid"]] = {"inputs": inputs, "subject": sub, "object": obj}
335 | 
336 |     return id2data
337 | 
338 | def _load_lama(data_dir, name, id2hf_data):
339 |     data = []
340 |     for relation in os.listdir(data_dir):
341 |         with open(os.path.join(data_dir, relation), "r") as f:
342 |             for line in f:
343 |                 dp = json.loads(line)
344 |                 subjects = set()
345 |                 objects = set()
346 |                 inputs = set()
347 | 
348 |                 subjects.add(dp["sub_label"])
349 |                 objects.add(dp["obj_label"])
350 | 
351 |                 if "masked_sentences" in dp:
352 |                     text = " ".join(dp["masked_sentences"])
353 |                     if text.count("[MASK]")!=1:
354 |                         continue
355 |                     text = text.replace("[MASK]", "<mask>").replace(" .", ".")
356 |                     inputs.add(text)
357 | 
358 |                 for e in dp["evidences"]:
359 |                     if "sub_surface" in e:
360 |                         subjects.add(e["sub_surface"])
361 |                     if "obj_surface" in e:
362 |                         objects.add(e["obj_surface"])
363 | 
364 |                 assert dp["uuid"] in id2hf_data
365 |                 if name!="google_re":
366 |                     if dp["uuid"] in id2hf_data:
367 |                         hf_dp = id2hf_data[dp["uuid"]]
368 |                         subjects.add(hf_dp["subject"])
369 |                         objects.add(hf_dp["object"])
370 |                         inputs |= hf_dp["inputs"]
371 | 
372 |                 if len(inputs)==0:
373 |                     continue
374 | 
375 |                 if name=="google_re":
376 |                     # the "template" field in the original data reads more like evidence from the
377 |                     # original sources, so we use the templates from HF instead
378 |                     inputs = id2hf_data[dp["uuid"]]["inputs"]
379 | 
380 |                     if len(inputs)==0:
381 |                         continue
382 | 
383 |                 inputs = list(inputs)
384 |                 assert len(inputs)==1
385 |                 assert inputs[0].count("<mask>")==1
386 |                 data.append({
387 |                     "id": dp["uuid"],
388 |                     "subjects": list(subjects),
389 |                     "answers": list(objects),
390 |                     "input": inputs[0]
391 |                 })
392 |     return data
393 | 
394 | def load_kamel(data_dir):
395 |     data_dir = os.path.join(data_dir, "kamel")
396 |     template_path = os.path.join(data_dir, "question-templates.csv")
397 | 
398 |     import csv
399 |     rel2question = {}
400 |     with open(template_path, "r") as f:
401 |         for rel, question in csv.reader(f):
402 |             assert "[S]" in question, question
403 |             assert question.endswith("?"), question
404 |             rel2question[rel] = question
405 | 
406 |     data = []
407 |     for rel in tqdm(os.listdir(data_dir)):
408 |         if not rel.startswith("P"):
409 |             continue
410 |         with open(os.path.join(data_dir, rel, "dev.jsonl"), "r") as f:
411 |             for line in f:
412 |                 dp = json.loads(line)
413 |                 input_text = rel2question[rel].replace("[S]", dp["sub_label"]) + " Answer: <mask>."
414 |                 objects = dp["obj_label"]
415 | 
416 |                 if any([o.endswith(".") or o.endswith("+") or o.endswith("!") for o in objects]):
417 |                     input_text = input_text[:-1]
418 | 
419 |                 data.append({"input": input_text, "answers": objects})
420 | 
421 |     return data
422 | 
423 | 
424 | 
425 | 
426 | 
--------------------------------------------------------------------------------
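A minimal usage sketch for the loaders above. This driver is illustrative only and is not part of the repository: it assumes task/load_data.py is importable as load_data, and that the LAMA data has already been placed under a local data/ directory (e.g., via the repo's download scripts); dataset names and the roberta-large tokenizer mirror the code above.

# inspect_data.py -- hypothetical driver, for illustration only
from transformers import RobertaTokenizer

import load_data  # task/load_data.py


def main():
    # closed-book QA: questions become "<question>? The answer is: <mask>." prompts
    triviaqa = load_data.load_triviaqa()
    print(triviaqa[0]["input"], "->", triviaqa[0]["answers"][:3])

    # LAMA fact completion: every input contains exactly one <mask> token
    lama = load_data.load_lama("data", "trex")  # assumes data/lama/... exists locally
    tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
    dp = lama[0]
    assert dp["input"].count("<mask>") == 1
    print(tokenizer.tokenize(dp["input"]), "->", dp["answers"])


if __name__ == "__main__":
    main()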