├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── accumulate_results.py
│   ├── models.py
│   ├── run_eval_prompts.py
│   ├── run_finetune.py
│   ├── run_optiprompt.py
│   └── utils.py
├── common_vocabs
│   ├── README.md
│   ├── common_vocab_cased.txt
│   └── common_vocab_cased_be_ro_al.txt
├── figure
│   └── optiprompt.png
├── relation_metainfo
│   ├── AutoPrompt_relations.jsonl
│   ├── LAMA_relations.jsonl
│   ├── LPAQA_relations.jsonl
│   └── README.md
├── requirements.txt
├── scripts
│   ├── download_data.sh
│   ├── run_eval_prompts.sh
│   ├── run_finetune.sh
│   └── run_optiprompt.sh
└── slides
    └── slides.pdf

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Princeton Natural Language Processing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OptiPrompt
2 |
3 | This is the PyTorch implementation of the paper [Factual Probing Is [MASK]: Learning vs. Learning to Recall](https://arxiv.org/pdf/2104.05240.pdf).
4 |
5 | We propose OptiPrompt, a simple and effective approach for factual probing. OptiPrompt optimizes prompts directly in the input embedding space. It outperforms previous prompting methods on the LAMA benchmark.
6 | Furthermore, to better interpret probing results, we propose control experiments based on probing randomly initialized models.
7 | Please check [our paper](https://arxiv.org/pdf/2104.05240.pdf) for details.
8 |
9 |
10 |
11 | ## Quick links
12 | * [Setup](#Setup)
13 |   * [Install dependencies](#Install-dependencies)
14 |   * [Download the data](#Download-the-data)
15 | * [Run OptiPrompt](#Run-optiprompt)
16 |   * [Train/evaluate OptiPrompt](#Train/evaluate-optiPrompt)
17 |   * [Run experiments on all relations](#Run-experiments-on-all-relations)
18 | * [Run Fine-tuning](#Run-fine-tuning)
19 | * [Evaluate LAMA/LPAQA/AutoPrompt prompts](#Evaluate-lamalpaqaautoprompt-prompts)
20 | * [Questions?](#questions)
21 | * [Citation](#Citation)
22 |
23 | ## Setup
24 |
25 | ### Install dependencies
26 | Our code is based on python 3.7. All experiments are run on a single GPU.
27 |
28 | Please install all the dependency packages using the following command:
29 | ```bash
30 | pip install -r requirements.txt
31 | ```
32 |
33 | ### Download the data
34 | We pack all the datasets we used in our experiments [here](https://nlp.cs.princeton.edu/projects/optiprompt/data.tar.gz). Please download it and extract the files to `./data`, or run the following command to automatically download and extract it.
35 | ```bash
36 | bash scripts/download_data.sh
37 | ```
38 |
39 | The datasets are structured as below.
40 |
41 |     data
42 |     ├── LAMA-TREx                      # The original LAMA-TREx test set (34,039 examples)
43 |     │   ├── P17.jsonl                  # Testing file for the relation `P17`
44 |     │   └── ...
45 |     ├── LAMA-TREx_UHN                  # The LAMA-TREx_UHN test set (27,102 examples)
46 |     │   ├── P17.jsonl                  # Testing file for the relation `P17`
47 |     │   └── ...
48 |     ├── LAMA-TREx-easy-hard            # The easy and hard partitions of the LAMA-TREx dataset (check the paper for details)
49 |     │   ├── Easy                       # The LAMA-easy partition (10,546 examples)
50 |     │   │   ├── P17.jsonl              # Testing file for the relation `P17`
51 |     │   │   └── ...
52 |     │   └── Hard                       # The LAMA-hard partition (23,493 examples)
53 |     │       ├── P17.jsonl              # Testing file for the relation `P17`
54 |     │       └── ...
55 |     ├── autoprompt_data                # Training data collected by AutoPrompt
56 |     │   ├── P17                        # Train/dev/test files for the relation `P17`
57 |     │   │   ├── train.jsonl            # Training examples
58 |     │   │   ├── dev.jsonl              # Development examples
59 |     │   │   └── test.jsonl             # Test examples (the same as the LAMA-TREx test set)
60 |     │   └── ...
61 |     └── cmp_lms_data                   # Training data collected by ourselves which can be used for BERT, RoBERTa, and ALBERT (we only use this dataset in Table 6 in the paper)
62 |         ├── P17                        # Train/dev/test files for the relation `P17`
63 |         │   ├── train.jsonl            # Training examples
64 |         │   ├── dev.jsonl              # Development examples
65 |         │   ├── test.jsonl             # Test examples (a subset of the LAMA-TREx test set, filtered using the common vocab of three models)
66 |         └── ...
67 |
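Each `*.jsonl` file stores one fact per line as a JSON object; the fields consumed by `code/utils.py` are `sub_label`, `obj_label`, `predicate_id`, and `uuid`. For example, a quick way to peek at a relation file (a minimal sketch, assuming the data has been downloaded and extracted as above):

```python
import json

# Inspect the first fact of relation P17
with open('data/LAMA-TREx/P17.jsonl') as f:
    fact = json.loads(f.readline())
print(fact['sub_label'], '->', fact['obj_label'])
```
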
68 | ## Run OptiPrompt
69 |
70 | ### Train/evaluate OptiPrompt
71 | You can use `code/run_optiprompt.py` to train or evaluate the prompts on a specific relation. A command template is as follows:
72 | ```bash
73 | rel=P101
74 | dir=outputs/${rel}
75 | mkdir -p ${dir}
76 |
77 | python code/run_optiprompt.py \
78 |     --relation_profile relation_metainfo/LAMA_relations.jsonl \
79 |     --relation ${rel} \
80 |     --common_vocab_filename common_vocabs/common_vocab_cased.txt \
81 |     --model_name bert-base-cased \
82 |     --do_train \
83 |     --train_data data/autoprompt_data/${rel}/train.jsonl \
84 |     --dev_data data/autoprompt_data/${rel}/dev.jsonl \
85 |     --do_eval \
86 |     --test_data data/LAMA-TREx/${rel}.jsonl \
87 |     --output_dir ${dir} \
88 |     --random_init none \
89 |     --output_predictions \
90 |     [--init_manual_template] [--num_vectors 5 | 10]
91 | ```
92 | Arguments:
93 |
94 | * `relation_profile`: the meta information for each relation, containing the manual templates.
95 | * `relation`: the relation type (e.g., `P101`) considered in this experiment.
96 | * `common_vocab_filename`: the vocabulary used to filter out facts; it should be the intersection of different models' vocabularies for a fair comparison.
97 | * `model_name`: the pre-trained model used in this experiment, e.g., `bert-base-cased`, `albert-xxlarge-v1`.
98 | * `do_train`: whether to train the prompts on a training and development set.
99 | * `do_eval`: whether to test the trained prompts on a testing set.
100 | * `{train|dev|test}_data`: the file paths of the training/development/testing datasets.
101 | * `random_init`: how to randomly initialize the model before training; there are three settings:
102 |   * `none`: use the pre-trained model, no random initialization is used;
103 |   * `embedding`: the `Rand E` control setting, where we randomly initialize the embedding layer of the model;
104 |   * `all`: the `Rand M` control setting, where we randomly initialize all the parameters of the model.
105 | * `init_manual_template`: whether to initialize the dense vectors in OptiPrompt using the manual prompts.
106 | * `num_vectors`: how many dense vectors are added in OptiPrompt (this argument is valid only when `init_manual_template` is **not** set).
107 | * `output_predictions`: whether to output top-k predictions for each testing fact (`k` is specified by `--k`).
108 |
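To make the prompt format concrete, the following sketch shows the template that `--num_vectors 5` produces (mirroring `init_template` in `code/run_optiprompt.py`); each `[Vi]` is a new token whose input embedding is trained, while `[X]` and `[Y]` are placeholders for the subject and the masked object:

```python
num_vectors = 5
template = '[X] ' + ' '.join('[V%d]' % (i + 1) for i in range(num_vectors)) + ' [Y] .'
print(template)
# [X] [V1] [V2] [V3] [V4] [V5] [Y] .

# At training/test time the subject fills [X] and [Y] becomes the mask token:
sentence = template.replace('[X]', 'Luxembourg').replace('[Y]', '[MASK]')
print(sentence)
# Luxembourg [V1] [V2] [V3] [V4] [V5] [MASK] .
```
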
109 | ### Run experiments on all relations
110 |
111 | We provide an example script (`scripts/run_optiprompt.sh`) to run OptiPrompt on all 41 relations on the LAMA benchmark. Run the following command to use it:
112 |
113 | ```bash
114 | bash scripts/run_optiprompt.sh
115 | ```
116 |
117 | The default setting of this script is to run OptiPrompt initialized with manual prompts on the pre-trained `bert-base-cased` model (no random initialization is used). The results will be stored in the `outputs` directory.
118 |
119 | Please modify the shell variables (i.e., `OUTPUTS_DIR`, `MODEL`, `RAND`) in `scripts/run_optiprompt.sh` if you want to run experiments on other settings.
120 |
121 | ## Run Fine-tuning
122 |
123 | We release the code that we used in our experiments (check Section 4 in the paper).
124 |
125 | ### Fine-tuning language models on factual probing
126 | You can use `code/run_finetune.py` to fine-tune a language model on a specific relation. A command template is as follows:
127 |
128 | ```bash
129 | rel=P101
130 | dir=outputs/${rel}
131 | mkdir -p ${dir}
132 |
133 | python code/run_finetune.py \
134 |     --relation_profile relation_metainfo/LAMA_relations.jsonl \
135 |     --relation ${rel} \
136 |     --common_vocab_filename common_vocabs/common_vocab_cased.txt \
137 |     --model_name bert-base-cased \
138 |     --do_train \
139 |     --train_data data/autoprompt_data/${rel}/train.jsonl \
140 |     --dev_data data/autoprompt_data/${rel}/dev.jsonl \
141 |     --do_eval \
142 |     --test_data data/LAMA-TREx/${rel}.jsonl \
143 |     --output_dir ${dir} \
144 |     --random_init none \
145 |     --output_predictions
146 | ```
147 |
148 | Arguments:
149 | * `relation_profile`: the meta information for each relation, containing the manual templates.
150 | * `relation`: the relation type (e.g., `P101`) considered in this experiment.
151 | * `common_vocab_filename`: the vocabulary used to filter out facts; it should be the intersection of different models' vocabularies for a fair comparison.
152 | * `model_name`: the pre-trained model used in this experiment, e.g., `bert-base-cased`, `albert-xxlarge-v1`.
153 | * `do_train`: whether to fine-tune the model on a training and development set.
154 | * `do_eval`: whether to test the fine-tuned model on a testing set.
155 | * `{train|dev|test}_data`: the file paths of the training/development/testing datasets.
156 | * `random_init`: how to randomly initialize the model before training; there are three settings:
157 |   * `none`: use the pre-trained model, no random initialization is used;
158 |   * `embedding`: the `Rand E` control setting, where we randomly initialize the embedding layer of the model;
159 |   * `all`: the `Rand M` control setting, where we randomly initialize all the parameters of the model.
160 | * `output_predictions`: whether to output top-k predictions for each testing fact (`k` is specified by `--k`).
161 |
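For reference, the two random-initialization controls map onto the model code as follows (a condensed sketch of what `code/models.py` does internally, not a standalone script):

```python
from transformers import BertForMaskedLM

mlm_model = BertForMaskedLM.from_pretrained('bert-base-cased')

# Rand M (`--random_init all`): keep the architecture, re-initialize every parameter
rand_m = BertForMaskedLM(mlm_model.config)

# Rand E (`--random_init embedding`): re-initialize only the input word embeddings
mlm_model._init_weights(mlm_model.bert.embeddings.word_embeddings)
```
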
162 | ### Run experiments on all relations
163 | We provide an example script (`scripts/run_finetune.sh`) to run fine-tuning on all 41 relations on the LAMA benchmark. Run the following command to use it:
164 |
165 | ```bash
166 | bash scripts/run_finetune.sh
167 | ```
168 |
169 | Please modify the shell variables (i.e., `OUTPUTS_DIR`, `MODEL`, `RAND`) in `scripts/run_finetune.sh` if you want to run experiments on other settings.
170 |
171 | ## Evaluate LAMA/LPAQA/AutoPrompt prompts
172 | We provide a script to evaluate prompts released in previous works (based on `code/run_finetune.py` with only `--do_eval`). Please use the following command:
173 |
174 | ```bash
175 | bash scripts/run_eval_prompts.sh {lama | lpaqa | autoprompt}
176 | ```
177 |
178 | ## Questions?
179 | If you have any questions related to the code or the paper, feel free to email Zexuan Zhong `(zzhong@cs.princeton.edu)` or Dan Friedman `(dfriedman@cs.princeton.edu)`. If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to describe the problem in detail so we can help you better and faster!
180 |
181 | ## Citation
182 | If you use our code in your research, please cite our work:
183 | ```bibtex
184 | @inproceedings{zhong2021factual,
185 |    title={Factual Probing Is [MASK]: Learning vs. Learning to Recall},
186 |    author={Zhong, Zexuan and Friedman, Dan and Chen, Danqi},
187 |    booktitle={North American Association for Computational Linguistics (NAACL)},
188 |    year={2021}
189 | }
190 | ```
191 |
--------------------------------------------------------------------------------
/code/accumulate_results.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import random
5 | import sys
6 | import logging
7 |
8 | from utils import load_file
9 |
10 | output_dir = sys.argv[1]
11 | relations = 'P1001 P101 P103 P106 P108 P127 P1303 P131 P136 P1376 P138 P140 P1412 P159 P17 P176 P178 P19 P190 P20 P264 P27 P276 P279 P30 P31 P36 P361 P364 P37 P39 P407 P413 P449 P463 P47 P495 P527 P530 P740 P937'.split()
12 |
13 | tot = 0
14 | cor = 0
15 |
16 | rel_avg = []
17 |
18 | for relation in relations:
19 |     rel_tot = 0
20 |     rel_cor = 0
21 |     samples = load_file(os.path.join(output_dir, '%s/%s_predictions.jsonl'%(relation, relation)))
22 |     for sample in samples:
23 |         if sample['obj_label'] == sample['topk'][0]['token']:
24 |             rel_cor += 1
25 |         rel_tot += 1
26 |
27 |     rel_avg.append(rel_cor / rel_tot)
28 |     tot += rel_tot
29 |     cor += rel_cor
30 |
31 |     print('%s\t%.2f\t(%d / %d)'%(relation, (rel_cor / rel_tot * 100), rel_cor, rel_tot))
32 |
33 | macro = sum(rel_avg) / len(rel_avg) * 100  # macro: mean of the per-relation accuracies
34 | micro = cor / tot * 100  # micro: overall accuracy over all facts
35 |
36 | print('Micro: %.2f\t(%d / %d)'%(micro, cor, tot))
37 | print('Macro: %.2f'%(macro))
38 |
--------------------------------------------------------------------------------
/code/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import random
5 |
6 | from transformers import BertTokenizer, BertForMaskedLM, BertConfig
7 | from transformers import AlbertTokenizer, AlbertForMaskedLM, AlbertConfig
8 | from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaConfig
9 | from transformers import AutoConfig
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | import
numpy as np 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | class Prober(): 18 | 19 | def __init__(self, args, random_init='none'): 20 | assert(random_init in ['none', 'all', 'embedding']) 21 | 22 | super().__init__() 23 | 24 | self._model_device = 'cpu' 25 | 26 | model_name = args.model_name 27 | vocab_name = model_name 28 | 29 | if args.model_dir is not None: 30 | # load bert model from file 31 | model_name = str(args.model_dir) + "/" 32 | vocab_name = model_name 33 | logger.info("loading BERT model from {}".format(model_name)) 34 | 35 | # Load pre-trained model tokenizer (vocabulary) 36 | random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | torch.cuda.manual_seed(args.seed) 39 | if torch.cuda.device_count() > 1: 40 | torch.cuda.manual_seed_all(args.seed) 41 | 42 | config = AutoConfig.from_pretrained(model_name) 43 | if isinstance(config, AlbertConfig): 44 | self.model_type = 'albert' 45 | self.tokenizer = AlbertTokenizer.from_pretrained(vocab_name) 46 | self.mlm_model = AlbertForMaskedLM.from_pretrained(model_name) 47 | if random_init == 'all': 48 | logger.info('Random initialize model...') 49 | self.mlm_model = AlbertForMaskedLM(self.mlm_model.config) 50 | self.base_model = self.mlm_model.albert 51 | elif isinstance(config, RobertaConfig): 52 | self.model_type = 'roberta' 53 | self.tokenizer = RobertaTokenizer.from_pretrained(vocab_name) 54 | self.mlm_model = RobertaForMaskedLM.from_pretrained(model_name) 55 | if random_init == 'all': 56 | logger.info('Random initialize model...') 57 | self.mlm_model = RobertaForMaskedLM(self.mlm_model.config) 58 | self.base_model = self.mlm_model.roberta 59 | elif isinstance(config, BertConfig): 60 | self.model_type = 'bert' 61 | self.tokenizer = BertTokenizer.from_pretrained(vocab_name) 62 | self.mlm_model = BertForMaskedLM.from_pretrained(model_name) 63 | if random_init == 'all': 64 | logger.info('Random initialize model...') 65 | self.mlm_model = BertForMaskedLM(self.mlm_model.config) 66 | self.base_model = self.mlm_model.bert 67 | else: 68 | raise ValueError('Model %s not supported yet!'%(model_name)) 69 | 70 | self.mlm_model.eval() 71 | 72 | if random_init == 'embedding': 73 | logger.info('Random initialize embedding layer...') 74 | self.mlm_model._init_weights(self.base_model.embeddings.word_embeddings) 75 | 76 | # original vocab 77 | self.map_indices = None 78 | self.vocab = list(self.tokenizer.get_vocab().keys()) 79 | logger.info('Vocab size: %d'%len(self.vocab)) 80 | self._init_inverse_vocab() 81 | 82 | self.MASK = self.tokenizer.mask_token 83 | self.EOS = self.tokenizer.eos_token 84 | self.CLS = self.tokenizer.cls_token 85 | self.SEP = self.tokenizer.sep_token 86 | self.UNK = self.tokenizer.unk_token 87 | # print(self.MASK, self.EOS, self.CLS, self.SEP, self.UNK) 88 | 89 | self.pad_id = self.inverse_vocab[self.tokenizer.pad_token] 90 | self.unk_index = self.inverse_vocab[self.tokenizer.unk_token] 91 | 92 | # used to output top-k predictions 93 | self.k = args.k 94 | 95 | def _cuda(self): 96 | self.mlm_model.cuda() 97 | 98 | def try_cuda(self): 99 | """Move model to GPU if one is available.""" 100 | if torch.cuda.is_available(): 101 | if self._model_device != 'cuda': 102 | logger.info('Moving model to CUDA') 103 | self._cuda() 104 | self._model_device = 'cuda' 105 | else: 106 | logger.info('No CUDA found') 107 | 108 | def init_indices_for_filter_logprobs(self, vocab_subset, logger=None): 109 | index_list = [] 110 | new_vocab_subset = [] 111 | for word in vocab_subset: 112 | tokens = self.tokenizer.tokenize(' '+word) 113 | if 
(len(tokens) == 1) and (tokens[0] != self.UNK):
114 |                 index_list.append(self.tokenizer.convert_tokens_to_ids(tokens)[0])
115 |                 new_vocab_subset.append(word)
116 |             else:
117 |                 msg = "word {} from vocab_subset not in model vocabulary!".format(word)
118 |                 if logger is not None:
119 |                     logger.warning(msg)
120 |                 else:
121 |                     logging.getLogger(__name__).warning(msg)  # the 'logger' argument is None here, so fall back to the module logger
122 |
123 |         indices = torch.as_tensor(index_list)
124 |         return indices, index_list
125 |
126 |     def _init_inverse_vocab(self):
127 |         self.inverse_vocab = {w: i for i, w in enumerate(self.vocab)}
128 |
129 |     def get_id(self, string):
130 |         tokenized_text = self.tokenizer.tokenize(string)
131 |         indexed_string = self.tokenizer.convert_tokens_to_ids(tokenized_text)
132 |         if self.map_indices is not None:
133 |             # map indices to subset of the vocabulary
134 |             indexed_string = self.convert_ids(indexed_string)
135 |
136 |         return indexed_string
137 |
138 |     def _get_input_tensors_batch_train(self, sentences_list, samples_list):
139 |         tokens_tensors_list = []
140 |         segments_tensors_list = []
141 |         masked_indices_list = []
142 |         tokenized_text_list = []
143 |         mlm_labels_tensor_list = []
144 |         mlm_label_ids = []
145 |
146 |         max_tokens = 0
147 |         for (sentences, samples) in zip(sentences_list, samples_list):
148 |             tokens_tensor, segments_tensor, masked_indices, tokenized_text, mlm_labels_tensor, mlm_label_id = self.__get_input_tensors(sentences, mlm_label=samples['obj_label'])
149 |             tokens_tensors_list.append(tokens_tensor)
150 |             segments_tensors_list.append(segments_tensor)
151 |             masked_indices_list.append(masked_indices)
152 |             tokenized_text_list.append(tokenized_text)
153 |             mlm_labels_tensor_list.append(mlm_labels_tensor)
154 |             mlm_label_ids.append(mlm_label_id)
155 |             if (tokens_tensor.shape[1] > max_tokens):
156 |                 max_tokens = tokens_tensor.shape[1]
157 |
158 |         # apply padding and concatenate tensors
159 |         # use [PAD] for tokens and 0 for segments
160 |         final_tokens_tensor = None
161 |         final_segments_tensor = None
162 |         final_attention_mask = None
163 |         final_mlm_labels_tensor = None
164 |         for tokens_tensor, segments_tensor, mlm_labels_tensor in zip(tokens_tensors_list, segments_tensors_list, mlm_labels_tensor_list):
165 |             dim_tensor = tokens_tensor.shape[1]
166 |             pad_lenght = max_tokens - dim_tensor
167 |             attention_tensor = torch.full([1,dim_tensor], 1, dtype= torch.long)
168 |             if pad_lenght>0:
169 |                 pad_1 = torch.full([1,pad_lenght], self.pad_id, dtype= torch.long)
170 |                 pad_2 = torch.full([1,pad_lenght], 0, dtype= torch.long)
171 |                 attention_pad = torch.full([1,pad_lenght], 0, dtype= torch.long)
172 |                 pad_3 = torch.full([1,pad_lenght], -100, dtype=torch.long)
173 |                 tokens_tensor = torch.cat((tokens_tensor,pad_1), dim=1)
174 |                 segments_tensor = torch.cat((segments_tensor,pad_2), dim=1)
175 |                 attention_tensor = torch.cat((attention_tensor,attention_pad), dim=1)
176 |                 mlm_labels_tensor = torch.cat((mlm_labels_tensor, pad_3), dim=1)
177 |             if final_tokens_tensor is None:
178 |                 final_tokens_tensor = tokens_tensor
179 |                 final_segments_tensor = segments_tensor
180 |                 final_attention_mask = attention_tensor
181 |                 final_mlm_labels_tensor = mlm_labels_tensor
182 |             else:
183 |                 final_tokens_tensor = torch.cat((final_tokens_tensor,tokens_tensor), dim=0)
184 |                 final_segments_tensor = torch.cat((final_segments_tensor,segments_tensor), dim=0)
185 |                 final_attention_mask = torch.cat((final_attention_mask,attention_tensor), dim=0)
186 |                 final_mlm_labels_tensor = torch.cat((final_mlm_labels_tensor,mlm_labels_tensor), dim=0)
187 |
188 |         return final_tokens_tensor,
final_segments_tensor, final_attention_mask, masked_indices_list, tokenized_text_list, final_mlm_labels_tensor, mlm_label_ids 189 | 190 | def __get_input_tensors_batch(self, sentences_list): 191 | tokens_tensors_list = [] 192 | segments_tensors_list = [] 193 | masked_indices_list = [] 194 | tokenized_text_list = [] 195 | max_tokens = 0 196 | for sentences in sentences_list: 197 | tokens_tensor, segments_tensor, masked_indices, tokenized_text = self.__get_input_tensors(sentences) 198 | tokens_tensors_list.append(tokens_tensor) 199 | segments_tensors_list.append(segments_tensor) 200 | masked_indices_list.append(masked_indices) 201 | tokenized_text_list.append(tokenized_text) 202 | if (tokens_tensor.shape[1] > max_tokens): 203 | max_tokens = tokens_tensor.shape[1] 204 | # logger.info("MAX_TOKENS: {}".format(max_tokens)) 205 | # apply padding and concatenate tensors 206 | # use [PAD] for tokens and 0 for segments 207 | final_tokens_tensor = None 208 | final_segments_tensor = None 209 | final_attention_mask = None 210 | for tokens_tensor, segments_tensor in zip(tokens_tensors_list, segments_tensors_list): 211 | dim_tensor = tokens_tensor.shape[1] 212 | pad_lenght = max_tokens - dim_tensor 213 | attention_tensor = torch.full([1,dim_tensor], 1, dtype= torch.long) 214 | if pad_lenght>0: 215 | pad_1 = torch.full([1,pad_lenght], self.pad_id, dtype= torch.long) 216 | pad_2 = torch.full([1,pad_lenght], 0, dtype= torch.long) 217 | attention_pad = torch.full([1,pad_lenght], 0, dtype= torch.long) 218 | tokens_tensor = torch.cat((tokens_tensor,pad_1), dim=1) 219 | segments_tensor = torch.cat((segments_tensor,pad_2), dim=1) 220 | attention_tensor = torch.cat((attention_tensor,attention_pad), dim=1) 221 | if final_tokens_tensor is None: 222 | final_tokens_tensor = tokens_tensor 223 | final_segments_tensor = segments_tensor 224 | final_attention_mask = attention_tensor 225 | else: 226 | final_tokens_tensor = torch.cat((final_tokens_tensor,tokens_tensor), dim=0) 227 | final_segments_tensor = torch.cat((final_segments_tensor,segments_tensor), dim=0) 228 | final_attention_mask = torch.cat((final_attention_mask,attention_tensor), dim=0) 229 | # logger.info(final_tokens_tensor) 230 | # logger.info(final_segments_tensor) 231 | # logger.info(final_attention_mask) 232 | # logger.info(final_tokens_tensor.shape) 233 | # logger.info(final_segments_tensor.shape) 234 | # logger.info(final_attention_mask.shape) 235 | return final_tokens_tensor, final_segments_tensor, final_attention_mask, masked_indices_list, tokenized_text_list 236 | 237 | def __get_input_tensors(self, sentences, mlm_label=None): 238 | 239 | if len(sentences) > 2: 240 | logger.info(sentences) 241 | raise ValueError("BERT accepts maximum two sentences in input for each data point") 242 | 243 | first_tokenized_sentence = [self.tokenizer.tokenize(token) if ((not token.startswith('[unused')) and (token != self.MASK)) else [token] for token in sentences[0].split()] 244 | first_tokenized_sentence = [item for sublist in first_tokenized_sentence for item in sublist] 245 | if self.model_type == 'roberta': 246 | first_tokenized_sentence = self.tokenizer.tokenize(sentences[0]) 247 | first_segment_id = np.zeros(len(first_tokenized_sentence), dtype=int).tolist() 248 | 249 | # add [SEP] token at the end 250 | first_tokenized_sentence.append(self.SEP) 251 | first_segment_id.append(0) 252 | 253 | if len(sentences)>1 : 254 | second_tokenized_sentece = [self.tokenizer.tokenize(token) if not token.startswith('[unused') else [token] for token in sentences[1].split()] 255 | 
second_tokenized_sentece = [item for sublist in second_tokenized_sentece for item in sublist] 256 | if self.model_type == 'roberta': 257 | second_tokenized_sentece = self.tokenizer.tokenize(sentences[1]) 258 | second_segment_id = np.full(len(second_tokenized_sentece),1, dtype=int).tolist() 259 | 260 | # add [SEP] token at the end 261 | second_tokenized_sentece.append(self.SEP) 262 | second_segment_id.append(1) 263 | 264 | tokenized_text = first_tokenized_sentence + second_tokenized_sentece 265 | segments_ids = first_segment_id + second_segment_id 266 | else: 267 | tokenized_text = first_tokenized_sentence 268 | segments_ids = first_segment_id 269 | 270 | # add [CLS] token at the beginning 271 | tokenized_text.insert(0,self.CLS) 272 | segments_ids.insert(0,0) 273 | 274 | # look for masked indices 275 | masked_indices = [] 276 | for i in range(len(tokenized_text)): 277 | token = tokenized_text[i] 278 | if token == self.MASK: 279 | masked_indices.append(i) 280 | 281 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokenized_text) 282 | 283 | # Convert inputs to PyTorch tensors 284 | tokens_tensor = torch.tensor([indexed_tokens]) 285 | segments_tensors = torch.tensor([segments_ids]) 286 | 287 | if mlm_label is None: 288 | return tokens_tensor, segments_tensors, masked_indices, tokenized_text 289 | 290 | # Handle mlm_label 291 | mlm_labels = np.full(len(tokenized_text), -100, dtype=int).tolist() 292 | tmp_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(' '+mlm_label)) 293 | assert(len(tmp_ids) == 1) 294 | mlm_labels[masked_indices[-1]] = tmp_ids[0] 295 | mlm_labels_tensor = torch.tensor([mlm_labels]) 296 | 297 | return tokens_tensor, segments_tensors, masked_indices, tokenized_text, mlm_labels_tensor, tmp_ids[0] 298 | 299 | def __get_token_ids_from_tensor(self, indexed_string): 300 | token_ids = [] 301 | if self.map_indices is not None: 302 | # map indices to subset of the vocabulary 303 | indexed_string = self.convert_ids(indexed_string) 304 | token_ids = np.asarray(indexed_string) 305 | else: 306 | token_ids = indexed_string 307 | return token_ids 308 | 309 | def get_batch_generation(self, sentences_list, logger= None, 310 | try_cuda=True): 311 | if not sentences_list: 312 | return None 313 | if try_cuda: 314 | self.try_cuda() 315 | 316 | tokens_tensor, segments_tensor, attention_mask_tensor, masked_indices_list, tokenized_text_list = self.__get_input_tensors_batch(sentences_list) 317 | 318 | if logger is not None: 319 | logger.debug("\n{}\n".format(tokenized_text_list)) 320 | 321 | with torch.no_grad(): 322 | logits = self.mlm_model( 323 | input_ids=tokens_tensor.to(self._model_device), 324 | token_type_ids=segments_tensor.to(self._model_device), 325 | attention_mask=attention_mask_tensor.to(self._model_device), 326 | ) 327 | 328 | log_probs = F.log_softmax(logits, dim=-1).cpu() 329 | 330 | token_ids_list = [] 331 | for indexed_string in tokens_tensor.numpy(): 332 | token_ids_list.append(self.__get_token_ids_from_tensor(indexed_string)) 333 | 334 | return log_probs, token_ids_list, masked_indices_list 335 | 336 | def run_batch(self, sentences_list, samples_list, try_cuda=True, training=True, filter_indices=None, index_list=None, vocab_to_common_vocab=None): 337 | if try_cuda and torch.cuda.device_count() > 0: 338 | self.try_cuda() 339 | 340 | tokens_tensor, segments_tensor, attention_mask_tensor, masked_indices_list, tokenized_text_list, mlm_labels_tensor, mlm_label_ids = self._get_input_tensors_batch_train(sentences_list, samples_list) 341 | 342 | if training: 343 
| self.mlm_model.train() 344 | loss = self.mlm_model( 345 | input_ids=tokens_tensor.to(self._model_device), 346 | token_type_ids=segments_tensor.to(self._model_device), 347 | attention_mask=attention_mask_tensor.to(self._model_device), 348 | masked_lm_labels=mlm_labels_tensor.to(self._model_device), 349 | ) 350 | loss = loss[0] 351 | else: 352 | self.mlm_model.eval() 353 | with torch.no_grad(): 354 | loss, logits = self.mlm_model( 355 | input_ids=tokens_tensor.to(self._model_device), 356 | token_type_ids=segments_tensor.to(self._model_device), 357 | attention_mask=attention_mask_tensor.to(self._model_device), 358 | masked_lm_labels=mlm_labels_tensor.to(self._model_device), 359 | ) 360 | log_probs = F.log_softmax(logits, dim=-1).cpu() 361 | 362 | if training: 363 | return loss 364 | else: 365 | # During testing, return accuracy and top-k predictions 366 | tot = log_probs.shape[0] 367 | cor = 0 368 | preds = [] 369 | topk = [] 370 | common_vocab_loss = [] 371 | 372 | for i in range(log_probs.shape[0]): 373 | masked_index = masked_indices_list[i][0] 374 | log_prob = log_probs[i][masked_index] 375 | mlm_label = mlm_label_ids[i] 376 | if filter_indices is not None: 377 | log_prob = log_prob.index_select(dim=0, index=filter_indices) 378 | pred_common_vocab = torch.argmax(log_prob) 379 | pred = index_list[pred_common_vocab] 380 | 381 | # get top-k predictions 382 | topk_preds = [] 383 | topk_log_prob, topk_ids = torch.topk(log_prob, self.k) 384 | for log_prob_i, idx in zip(topk_log_prob, topk_ids): 385 | ori_idx = index_list[idx] 386 | token = self.vocab[ori_idx] 387 | topk_preds.append({'token': token, 'log_prob': log_prob_i.item()}) 388 | topk.append(topk_preds) 389 | 390 | # compute entropy on common vocab 391 | common_logits = logits[i][masked_index].cpu().index_select(dim=0, index=filter_indices) 392 | common_log_prob = -F.log_softmax(common_logits, dim=-1) 393 | common_label_id = vocab_to_common_vocab[mlm_label] 394 | common_vocab_loss.append(common_log_prob[common_label_id].item()) 395 | else: 396 | pred = torch.argmax(log_prob) 397 | topk.append([]) 398 | if pred == mlm_labels_tensor[i][masked_index]: 399 | cor += 1 400 | preds.append(1) 401 | else: 402 | preds.append(0) 403 | 404 | return log_probs, cor, tot, preds, topk, loss, common_vocab_loss 405 | -------------------------------------------------------------------------------- /code/run_eval_prompts.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import os 4 | import random 5 | import sys 6 | import logging 7 | from tqdm import tqdm 8 | import torch 9 | 10 | from models import Prober 11 | from utils import load_vocab, load_data, batchify, save_model, evaluate, get_relation_meta 12 | 13 | from transformers import AdamW, get_linear_schedule_with_warmup 14 | 15 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 16 | datefmt='%m/%d/%Y %H:%M:%S', 17 | level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | def init_template(args, model): 21 | relation = get_relation_meta(args) 22 | return relation['template'] 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--model_name', type=str, default='bert-base-cased', help='the huggingface model name') 27 | parser.add_argument('--model_dir', type=str, default=None, help='the model directory (if not using --model_name)') 28 | parser.add_argument('--output_dir', type=str, default='output', help='the output directory to 
store trained model and prediction results')
29 |     parser.add_argument('--common_vocab_filename', type=str, default='data/common_vocab_cased.txt', help='common vocabulary of models (used to filter triples)')
30 |     parser.add_argument('--relation_profile', type=str, default='data/relations.jsonl', help='meta information of 41 relations, containing the pre-defined templates')
31 |
32 |     parser.add_argument('--test_data', type=str, default=None)
33 |     parser.add_argument('--eval_batch_size', type=int, default=8)
34 |
35 |     parser.add_argument('--seed', type=int, default=6)
36 |
37 |     parser.add_argument('--relation', type=str, required=True, help='which relation is considered in this run')
38 |     parser.add_argument('--random_init', type=str, default='none', choices=['none', 'embedding', 'all'], help='none: use the pre-trained model; embedding: randomly initialize the embedding layer of the model; all: randomly initialize the whole model')
39 |
40 |     parser.add_argument('--output_predictions', action='store_true', help='whether to output top-k predictions')
41 |     parser.add_argument('--k', type=int, default=5, help='how many predictions will be outputted')
42 |
43 |     args = parser.parse_args()
44 |
45 |     logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "eval.log"), 'w'))
46 |
47 |     logger.info(args)
48 |     n_gpu = torch.cuda.device_count()
49 |     logger.info('# GPUs: %d'%n_gpu)
50 |     if n_gpu == 0:
51 |         logger.warning('No GPU found!')
52 |
53 |     logger.info('Model: %s'%args.model_name)
54 |
55 |     random.seed(args.seed)
56 |     torch.manual_seed(args.seed)
57 |     torch.cuda.manual_seed(args.seed)
58 |     if torch.cuda.device_count() > 1:
59 |         torch.cuda.manual_seed_all(args.seed)
60 |
61 |     model = Prober(args, random_init=args.random_init)
62 |
63 |     if args.common_vocab_filename is not None:
64 |         vocab_subset = load_vocab(args.common_vocab_filename)
65 |         logger.info('Common vocab: %s, size: %d'%(args.common_vocab_filename, len(vocab_subset)))
66 |         filter_indices, index_list = model.init_indices_for_filter_logprobs(vocab_subset)
67 |     else:
68 |         vocab_subset = filter_indices = None  # define vocab_subset so the load_data call below does not fail
69 |         index_list = None
70 |
71 |     if n_gpu > 1:
72 |         model.mlm_model = torch.nn.DataParallel(model.mlm_model)
73 |
74 |     template = init_template(args, model)
75 |     logger.info('Template: %s'%template)
76 |
77 |     eval_samples = load_data(args.test_data, template, vocab_subset=vocab_subset, mask_token=model.MASK)
78 |     eval_samples_batches, eval_sentences_batches = batchify(eval_samples, args.eval_batch_size * max(n_gpu, 1))
79 |     evaluate(model, eval_samples_batches, eval_sentences_batches, filter_indices, index_list, output_topk=args.output_dir if args.output_predictions else None)
80 |
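With `--output_predictions`, `evaluate` writes one `<relation>_predictions.jsonl` file per relation into `--output_dir`; these are the files that `code/accumulate_results.py` aggregates. Each line is a JSON object of the following shape (the field names mirror `evaluate` in `code/utils.py`; the values here are illustrative):

```json
{"uuid": "...", "relation": "P17", "sub_label": "Lummen", "obj_label": "Belgium",
 "masked_sentences": ["Lummen is located in [MASK] ."],
 "topk": [{"token": "Belgium", "log_prob": -0.12}]}
```

The `topk` list holds `k` entries (set by `--k`), sorted by log-probability.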
--------------------------------------------------------------------------------
/code/run_finetune.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import random
5 | import sys
6 | import logging
7 | from tqdm import tqdm
8 | import torch
9 |
10 | from models import Prober
11 | from utils import load_vocab, load_data, batchify, save_model, evaluate, get_relation_meta
12 |
13 | from transformers import AdamW, get_linear_schedule_with_warmup
14 |
15 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
16 |                     datefmt='%m/%d/%Y %H:%M:%S',
17 |                     level=logging.INFO)
18 | logger = logging.getLogger(__name__)
19 |
20 | def init_template(args, model):
21 |     if args.no_template:
22 |         return '[X] [Y]'
23 |     else:
24 |         relation = get_relation_meta(args)
25 |         return relation['template']
26 |
27 | if __name__ == "__main__":
28 |     parser = argparse.ArgumentParser()
29 |     parser.add_argument('--model_name', type=str, default='bert-base-cased', help='the huggingface model name')
30 |     parser.add_argument('--model_dir', type=str, default=None, help='the model directory (if not using --model_name)')
31 |     parser.add_argument('--output_dir', type=str, default='output', help='the output directory to store trained model and prediction results')
32 |     parser.add_argument('--common_vocab_filename', type=str, default='data/common_vocab_cased.txt', help='common vocabulary of models (used to filter triples)')
33 |     parser.add_argument('--relation_profile', type=str, default='data/relations.jsonl', help='meta information of 41 relations, containing the pre-defined templates')
34 |
35 |     parser.add_argument('--train_data', type=str, default=None)
36 |     parser.add_argument('--dev_data', type=str, default=None)
37 |     parser.add_argument('--test_data', type=str, default=None)
38 |
39 |     parser.add_argument('--train_batch_size', type=int, default=2)
40 |     parser.add_argument('--eval_batch_size', type=int, default=8)
41 |     parser.add_argument('--num_epoch', type=int, default=10)
42 |     parser.add_argument('--learning_rate', type=float, default=2e-6)
43 |     parser.add_argument('--warmup_proportion', type=float, default=0.1)
44 |     parser.add_argument('--eval_per_epoch', type=int, default=3)
45 |
46 |     parser.add_argument('--do_shuffle', action='store_true')
47 |     parser.add_argument('--do_eval', action='store_true', help="whether to run evaluation")
48 |     parser.add_argument('--do_train', action='store_true', help="whether to run training process")
49 |     parser.add_argument('--check_step', type=int, default=-1, help='how often to output training loss')
50 |
51 |     parser.add_argument('--seed', type=int, default=6)
52 |
53 |     parser.add_argument('--relation', type=str, required=True, help='which relation is considered in this run')
54 |     parser.add_argument('--random_init', type=str, default='none', choices=['none', 'embedding', 'all'], help='none: use the pre-trained model; embedding: randomly initialize the embedding layer of the model; all: randomly initialize the whole model')
55 |     parser.add_argument('--no_template', action='store_true', help='do not use the manual template during fine-tuning')
56 |
57 |     parser.add_argument('--output_predictions', action='store_true', help='whether to output top-k predictions')
58 |     parser.add_argument('--k', type=int, default=5, help='how many predictions will be outputted')
59 |
60 |     args = parser.parse_args()
61 |
62 |     if args.do_train:
63 |         logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "train.log"), 'w'))
64 |     else:
65 |         logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "eval.log"), 'w'))
66 |
67 |     logger.info(args)
68 |     n_gpu = torch.cuda.device_count()
69 |     logger.info('# GPUs: %d'%n_gpu)
70 |     if n_gpu == 0:
71 |         logger.warning('No GPU found!')
72 |
73 |     logger.info('Model: %s'%args.model_name)
74 |
75 |     random.seed(args.seed)
76 |     torch.manual_seed(args.seed)
77 |     torch.cuda.manual_seed(args.seed)
78 |     if torch.cuda.device_count() > 1:
79 |         torch.cuda.manual_seed_all(args.seed)
80 |
81 |     model = Prober(args, random_init=args.random_init)
82 |
83 |     if args.common_vocab_filename is not None:
84 |         vocab_subset = load_vocab(args.common_vocab_filename)
85 |         logger.info('Common vocab: %s, size: %d'%(args.common_vocab_filename, len(vocab_subset)))
86 |         filter_indices, index_list = model.init_indices_for_filter_logprobs(vocab_subset)
87 |     else:
88 |         vocab_subset = filter_indices = None  # define vocab_subset so the load_data calls below do not fail
89 |         index_list = None
90 |
91 |     if n_gpu > 1:
92 |         model.mlm_model = torch.nn.DataParallel(model.mlm_model)
93 |
94 |     template = init_template(args, model)
95 |     logger.info('Template: %s'%template)
96 |
97 |     if args.do_train:
98 |         # Prepare train/valid data
99 |         train_samples = load_data(args.train_data, template, vocab_subset=vocab_subset, mask_token=model.MASK)
100 |         train_samples_batches, train_sentences_batches = batchify(train_samples, args.train_batch_size * max(n_gpu, 1))
101 |         logger.info('Train batches: %d'%len(train_samples_batches))
102 |         valid_samples = load_data(args.dev_data, template, vocab_subset=vocab_subset, mask_token=model.MASK)
103 |         valid_samples_batches, valid_sentences_batches = batchify(valid_samples, args.eval_batch_size * max(n_gpu, 1))
104 |         logger.info('Valid batches: %d'%len(valid_samples_batches))
105 |
106 |         # Valid set before train
107 |         best_result, result_rel = evaluate(model, valid_samples_batches, valid_sentences_batches, filter_indices, index_list)
108 |         save_model(model, args)
109 |
110 |         param_optimizer = list(model.mlm_model.named_parameters())
111 |         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
112 |         optimizer_grouped_parameters = [
113 |             {'params': [p for n, p in param_optimizer
114 |                         if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
115 |             {'params': [p for n, p in param_optimizer
116 |                         if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
117 |         ]
118 |         optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
119 |         t_total = len(train_samples_batches) * args.num_epoch
120 |         scheduler = get_linear_schedule_with_warmup(optimizer, int(t_total*args.warmup_proportion), t_total)
121 |
122 |         # Train!!!
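        # One optimizer step is taken per batch below: t_total steps in total, of
        # which the first warmup_proportion * t_total linearly warm the learning
        # rate up to --learning_rate before it decays linearly to zero. Every
        # eval_step batches the dev set is scored, and the best checkpoint is
        # kept via save_model().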
123 |         global_step = 0
124 |         tr_loss = 0
125 |         nb_tr_examples = 0
126 |         nb_tr_steps = 0
127 |         eval_step = len(train_samples_batches) // args.eval_per_epoch
128 |         for _ in range(args.num_epoch):
129 |             if args.do_shuffle:
130 |                 logger.info('Shuffle train samples')
131 |                 train_samples_batches, train_sentences_batches = map(list, zip(*random.sample(list(zip(train_samples_batches, train_sentences_batches)), len(train_samples_batches))))  # shuffle the two lists jointly (random.shuffle returns None and cannot shuffle a zip object)
132 |             for i in tqdm(range(len(train_samples_batches))):
133 |                 samples_b = train_samples_batches[i]
134 |                 sentences_b = train_sentences_batches[i]
135 |
136 |                 loss = model.run_batch(sentences_b, samples_b, training=True)
137 |                 if n_gpu > 1:
138 |                     loss = loss.mean()
139 |                 loss.backward()
140 |
141 |                 tr_loss += loss.item()
142 |                 nb_tr_examples += len(samples_b)
143 |                 nb_tr_steps += 1
144 |                 global_step += 1
145 |
146 |                 optimizer.step()
147 |                 scheduler.step()
148 |                 optimizer.zero_grad()
149 |
150 |                 if args.check_step > 0 and ((nb_tr_steps + 1) % args.check_step == 0):
151 |                     logger.info('Epoch=%d, iter=%d, loss=%.5f'%(_, i, tr_loss / nb_tr_examples))
152 |                     sys.stdout.flush()
153 |                     tr_loss = 0
154 |                     nb_tr_examples = 0
155 |
156 |                 if eval_step > 0 and (global_step + 1) % eval_step == 0:
157 |                     # Eval during training
158 |                     logger.info('Global step=%d, evaluating...'%(global_step))
159 |                     precision, current_result = evaluate(model, valid_samples_batches, valid_sentences_batches, filter_indices, index_list)
160 |                     if precision > best_result:
161 |                         best_result = precision
162 |                         result_rel = current_result
163 |                         logger.info('!!! Best valid (epoch=%d): %.2f' %
164 |                                     (_, best_result * 100))
165 |                         save_model(model, args)
166 |         logger.info('Best Valid: %.2f'%(best_result*100))
167 |
168 |     if args.do_eval:
169 |         args.model_dir = args.output_dir
170 |         model = Prober(args)
171 |
172 |         eval_samples = load_data(args.test_data, template, vocab_subset=vocab_subset, mask_token=model.MASK)
173 |         eval_samples_batches, eval_sentences_batches = batchify(eval_samples, args.eval_batch_size * max(n_gpu, 1))
174 |
175 |         evaluate(model, eval_samples_batches, eval_sentences_batches, filter_indices, index_list, output_topk=args.output_dir if args.output_predictions else None)
176 |
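After fine-tuning, the checkpoint written by `save_model` can be reloaded for evaluation only, which is exactly what the `--do_eval` branch above does; a minimal sketch (the output path follows the README example and is a placeholder):

```python
from argparse import Namespace

from models import Prober

# Reload the fine-tuned weights saved by save_model() into a fresh Prober;
# passing model_dir makes Prober load from that directory instead of the hub.
args = Namespace(model_name='bert-base-cased', model_dir='outputs/P101', seed=6, k=5)
model = Prober(args, random_init='none')
model.try_cuda()
```
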
--------------------------------------------------------------------------------
/code/run_optiprompt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import os
4 | import random
5 | import sys
6 | import logging
7 | from tqdm import tqdm
8 |
9 | import torch
10 | from torch import optim
11 |
12 | from models import Prober
13 | from utils import load_vocab, load_data, batchify, evaluate, get_relation_meta
14 |
15 | import numpy as np
16 |
17 | from transformers import AdamW, get_linear_schedule_with_warmup
18 |
19 | MAX_NUM_VECTORS = 10
20 |
21 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
22 |                     datefmt='%m/%d/%Y %H:%M:%S',
23 |                     level=logging.INFO)
24 | logger = logging.getLogger(__name__)
25 |
26 | def get_new_token(vid):
27 |     assert(vid > 0 and vid <= MAX_NUM_VECTORS)
28 |     return '[V%d]'%(vid)
29 |
30 | def convert_manual_to_dense(manual_template, model):
31 |     def assign_embedding(new_token, token):
32 |         """
33 |         assign the embedding of token to new_token
34 |         """
35 |         logger.info('Tie embeddings of tokens: (%s, %s)'%(new_token, token))
36 |         id_a = model.tokenizer.convert_tokens_to_ids([new_token])[0]
37 |         id_b = model.tokenizer.convert_tokens_to_ids([token])[0]
38 |         with torch.no_grad():
39 |             model.base_model.embeddings.word_embeddings.weight[id_a] = model.base_model.embeddings.word_embeddings.weight[id_b].detach().clone()
40 |
41 |     new_token_id = 0
42 |     template = []
43 |     for word in manual_template.split():
44 |         if word in ['[X]', '[Y]']:
45 |             template.append(word)
46 |         else:
47 |             tokens = model.tokenizer.tokenize(' ' + word)
48 |             for token in tokens:
49 |                 new_token_id += 1
50 |                 template.append(get_new_token(new_token_id))
51 |                 assign_embedding(get_new_token(new_token_id), token)
52 |
53 |     return ' '.join(template)
54 |
55 | def init_template(args, model):
56 |     if args.init_manual_template:
57 |         relation = get_relation_meta(args)
58 |         template = convert_manual_to_dense(relation['template'], model)
59 |     else:
60 |         template = '[X] ' + ' '.join(['[V%d]'%(i+1) for i in range(args.num_vectors)]) + ' [Y] .'
61 |     return template
62 |
63 | def prepare_for_dense_prompt(model):
64 |     new_tokens = [get_new_token(i+1) for i in range(MAX_NUM_VECTORS)]
65 |     model.tokenizer.add_tokens(new_tokens)
66 |     model.mlm_model.resize_token_embeddings(len(model.tokenizer))
67 |     logger.info('# vocab after adding new tokens: %d'%len(model.tokenizer))
68 |
69 | def save_optiprompt(args, model, original_vocab_size):
70 |     logger.info("Saving OptiPrompt's [V]s..")
71 |     vs = model.base_model.embeddings.word_embeddings.weight[original_vocab_size:].detach().cpu().numpy()
72 |     with open(os.path.join(args.output_dir, 'prompt_vecs.npy'), 'wb') as f:
73 |         np.save(f, vs)
74 |
75 | def load_optiprompt(args):
76 |     # load bert model (pre-trained)
77 |     model = Prober(args, random_init=args.random_init)
78 |     original_vocab_size = len(list(model.tokenizer.get_vocab()))
79 |     prepare_for_dense_prompt(model)
80 |
81 |     logger.info("Loading OptiPrompt's [V]s..")
82 |     with open(os.path.join(args.output_dir, 'prompt_vecs.npy'), 'rb') as f:
83 |         vs = np.load(f)
84 |
85 |     # copy fine-tuned new_tokens to the pre-trained model
86 |     with torch.no_grad():
87 |         model.base_model.embeddings.word_embeddings.weight[original_vocab_size:] = torch.Tensor(vs)
88 |     return model
89 |
90 | if __name__ == "__main__":
91 |     parser = argparse.ArgumentParser()
92 |     parser.add_argument('--model_name', type=str, default='bert-base-cased', help='the huggingface model name')
93 |     parser.add_argument('--model_dir', type=str, default=None, help='the model directory (if not using --model_name)')
94 |     parser.add_argument('--output_dir', type=str, default='output', help='the output directory to store trained model and prediction results')
95 |     parser.add_argument('--common_vocab_filename', type=str, default=None, help='common vocabulary of models (used to filter triples)')
96 |     parser.add_argument('--relation_profile', type=str, default=None, help='meta information of 41 relations, containing the pre-defined templates')
97 |
98 |     parser.add_argument('--train_data', type=str, default=None)
99 |     parser.add_argument('--dev_data', type=str, default=None)
100 |     parser.add_argument('--test_data', type=str, default=None)
101 |
102 |     parser.add_argument('--train_batch_size', type=int, default=16, help='training batch size per GPU')
103 |     parser.add_argument('--eval_batch_size', type=int, default=8)
104 |     parser.add_argument('--num_epoch', type=int, default=10)
105 |     parser.add_argument('--learning_rate', type=float, default=3e-3)
106 |     parser.add_argument('--warmup_proportion', type=float, default=0.1)
107 |     parser.add_argument('--eval_per_epoch', type=int, default=3)
108 |
109 |     parser.add_argument('--do_shuffle', action='store_true')
110 |     parser.add_argument('--do_eval', action='store_true', help="whether to run evaluation")
evaluation") 111 | parser.add_argument('--do_train', action='store_true', help="whether to run training process") 112 | parser.add_argument('--check_step', type=int, default=-1, help='how often to output training loss') 113 | 114 | parser.add_argument('--seed', type=int, default=6) 115 | 116 | parser.add_argument('--relation', type=str, required=True, help='which relation is considered in this run') 117 | parser.add_argument('--init_manual_template', action='store_true', help='whether to use manual template to initialize the dense vectors') 118 | parser.add_argument('--random_init', type=str, default='none', choices=['none', 'embedding', 'all'], help='none: use pre-trained model; embedding: random initialize the embedding layer of the model; all: random initialize the whole model') 119 | parser.add_argument('--num_vectors', type=int, default=5, help='how many dense vectors are used in OptiPrompt') 120 | 121 | parser.add_argument('--output_predictions', action='store_true', help='whether to output top-k predictions') 122 | parser.add_argument('--k', type=int, default=5, help='how many predictions will be outputted') 123 | 124 | args = parser.parse_args() 125 | 126 | if args.do_train: 127 | logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "train.log"), 'w')) 128 | else: 129 | logger.addHandler(logging.FileHandler(os.path.join(args.output_dir, "eval.log"), 'w')) 130 | 131 | logger.info(args) 132 | n_gpu = torch.cuda.device_count() 133 | logger.info('# GPUs: %d'%n_gpu) 134 | if n_gpu == 0: 135 | logger.warning('No GPUs found!') 136 | 137 | logger.info('Model: %s'%args.model_name) 138 | 139 | random.seed(args.seed) 140 | torch.manual_seed(args.seed) 141 | torch.cuda.manual_seed(args.seed) 142 | if torch.cuda.device_count() > 1: 143 | torch.cuda.manual_seed_all(args.seed) 144 | 145 | model = Prober(args, random_init=args.random_init) 146 | original_vocab_size = len(list(model.tokenizer.get_vocab())) 147 | logger.info('Original vocab size: %d'%original_vocab_size) 148 | prepare_for_dense_prompt(model) 149 | 150 | if args.common_vocab_filename is not None: 151 | vocab_subset = load_vocab(args.common_vocab_filename) 152 | logger.info('Common vocab: %s, size: %d'%(args.common_vocab_filename, len(vocab_subset))) 153 | filter_indices, index_list = model.init_indices_for_filter_logprobs(vocab_subset) 154 | else: 155 | filter_indices = None 156 | index_list = None 157 | 158 | if n_gpu > 1: 159 | model.mlm_model = torch.nn.DataParallel(model.mlm_model) 160 | 161 | template = init_template(args, model) 162 | logger.info('Template: %s'%template) 163 | 164 | if args.do_train: 165 | # Prepare train/valid data 166 | train_samples = load_data(args.train_data, template, vocab_subset=vocab_subset, mask_token=model.MASK) 167 | train_samples_batches, train_sentences_batches = batchify(train_samples, args.train_batch_size * max(n_gpu, 1)) 168 | logger.info('Train batches: %d'%len(train_samples_batches)) 169 | valid_samples = load_data(args.dev_data, template, vocab_subset=vocab_subset, mask_token=model.MASK) 170 | valid_samples_batches, valid_sentences_batches = batchify(valid_samples, args.eval_batch_size * max(n_gpu, 1)) 171 | logger.info('Valid batches: %d'%len(valid_samples_batches)) 172 | 173 | # Valid set before train 174 | best_result, result_rel = evaluate(model, valid_samples_batches, valid_sentences_batches, filter_indices, index_list) 175 | save_optiprompt(args, model, original_vocab_size) 176 | 177 | # Add word embeddings to the optimizer 178 | optimizer = AdamW([{'params': 
178 |         optimizer = AdamW([{'params': model.base_model.embeddings.word_embeddings.parameters()}], lr=args.learning_rate, correct_bias=False)
179 |         t_total = len(train_samples_batches) * args.num_epoch
180 |         scheduler = get_linear_schedule_with_warmup(optimizer, int(t_total*args.warmup_proportion), t_total)
181 |
182 |         # Train!!!
183 |         global_step = 0
184 |         tr_loss = 0
185 |         nb_tr_examples = 0
186 |         nb_tr_steps = 0
187 |         eval_step = len(train_samples_batches) // args.eval_per_epoch
188 |         for _ in range(args.num_epoch):
189 |             if args.do_shuffle:
190 |                 logger.info('Shuffle train samples')
191 |                 train_samples_batches, train_sentences_batches = map(list, zip(*random.sample(list(zip(train_samples_batches, train_sentences_batches)), len(train_samples_batches))))  # shuffle the two lists jointly (random.shuffle returns None and cannot shuffle a zip object)
192 |             for i in tqdm(range(len(train_samples_batches))):
193 |                 samples_b = train_samples_batches[i]
194 |                 sentences_b = train_sentences_batches[i]
195 |
196 |                 loss = model.run_batch(sentences_b, samples_b, training=True)
197 |                 if n_gpu > 1:
198 |                     loss = loss.mean()
199 |                 loss.backward()
200 |
201 |                 tr_loss += loss.item()
202 |                 nb_tr_examples += len(samples_b)
203 |                 nb_tr_steps += 1
204 |                 global_step += 1
205 |
206 |                 # set normal tokens' gradients to be zero
207 |                 for p in model.base_model.embeddings.word_embeddings.parameters():
208 |                     # only update new tokens
209 |                     p.grad[:original_vocab_size, :] = 0.0
210 |
211 |                 optimizer.step()
212 |                 scheduler.step()
213 |                 optimizer.zero_grad()
214 |
215 |                 if args.check_step > 0 and ((nb_tr_steps + 1) % args.check_step == 0):
216 |                     logger.info('Epoch=%d, iter=%d, loss=%.5f'%(_, i, tr_loss / nb_tr_examples))
217 |                     sys.stdout.flush()
218 |                     tr_loss = 0
219 |                     nb_tr_examples = 0
220 |
221 |                 if eval_step > 0 and (global_step + 1) % eval_step == 0:
222 |                     # Eval during training
223 |                     logger.info('Global step=%d, evaluating...'%(global_step))
224 |                     precision, current_result = evaluate(model, valid_samples_batches, valid_sentences_batches, filter_indices, index_list)
225 |                     if precision > best_result:
226 |                         best_result = precision
227 |                         result_rel = current_result
228 |                         logger.info('!!! Best valid (epoch=%d): %.2f' %
229 |                                     (_, best_result * 100))
230 |                         save_optiprompt(args, model, original_vocab_size)
231 |         logger.info('Best Valid: %.2f'%(best_result*100))
232 |
233 |     if args.do_eval:
234 |         model = load_optiprompt(args)
235 |
236 |         eval_samples = load_data(args.test_data, template, vocab_subset=vocab_subset, mask_token=model.MASK)
237 |         eval_samples_batches, eval_sentences_batches = batchify(eval_samples, args.eval_batch_size * max(n_gpu, 1))
238 |
239 |         evaluate(model, eval_samples_batches, eval_sentences_batches, filter_indices, index_list, output_topk=args.output_dir if args.output_predictions else None)
240 |
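The learned prompt vectors saved by `save_optiprompt` above are the only training artifact besides the logs; a quick post-hoc inspection (a minimal sketch; the output path follows the README example):

```python
import numpy as np

# prompt_vecs.npy stores the rows of the word-embedding matrix that were
# added for the [V] tokens, i.e. MAX_NUM_VECTORS x hidden_size.
vs = np.load('outputs/P101/prompt_vecs.npy')
print(vs.shape)  # (10, 768) for bert-base-cased
```
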
--------------------------------------------------------------------------------
/code/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from tqdm import tqdm
4 | import sys
5 | import logging
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | def load_vocab(vocab_filename):
10 |     with open(vocab_filename, "r") as f:
11 |         lines = f.readlines()
12 |     vocab = [x.strip() for x in lines]
13 |     return vocab
14 |
15 | def load_file(filename):
16 |     data = []
17 |     with open(filename, "r") as f:
18 |         for line in f.readlines():
19 |             data.append(json.loads(line))
20 |     return data
21 |
22 | def parse_template(template, subject_label, object_label='[MASK]'):
23 |     SUBJ_SYMBOL = "[X]"
24 |     OBJ_SYMBOL = "[Y]"
25 |     template = template.replace(SUBJ_SYMBOL, subject_label)
26 |     template = template.replace(OBJ_SYMBOL, object_label)
27 |     return [template]
28 |
29 | def convert_tokens_to_string(tokens):
30 |     out_string = " ".join(tokens).replace(" ##", "").strip()
31 |     return out_string
32 |
33 | def get_relation_meta(args):
34 |     relations = load_file(args.relation_profile)
35 |     for relation in relations:
36 |         if relation['relation'] == args.relation:
37 |             return relation
38 |     raise ValueError('Relation info %s not found in file %s'%(args.relation, args.relation_profile))
39 |
40 | def batchify(data, batch_size):
41 |     list_samples_batches = []
42 |     list_sentences_batches = []
43 |     current_samples_batch = []
44 |     current_sentences_batches = []
45 |
46 |     c = 0
47 |     for sample in data:
48 |         input_sentences = sample['input_sentences']
49 |         current_samples_batch.append(sample)
50 |         current_sentences_batches.append(input_sentences)
51 |         c += 1
52 |         if c >= batch_size:
53 |             list_samples_batches.append(current_samples_batch)
54 |             list_sentences_batches.append(current_sentences_batches)
55 |             current_samples_batch = []
56 |             current_sentences_batches = []
57 |             c = 0
58 |
59 |     if current_samples_batch and len(current_samples_batch) > 0:
60 |         list_samples_batches.append(current_samples_batch)
61 |         list_sentences_batches.append(current_sentences_batches)
62 |
63 |     return list_samples_batches, list_sentences_batches
64 |
65 |
66 | def save_model(model, args):
67 |     logger.info('Saving model...')
68 |     model_to_save = model.mlm_model
69 |     model_to_save.save_pretrained(args.output_dir)
70 |     model.tokenizer.save_pretrained(args.output_dir)
71 |
72 | def output_result(result, eval_loss):
73 |     logger.info('* Evaluation result *')
74 |     cor = 0
75 |     tot = 0
76 |     macro = 0.0
77 |     loss = 0.0
78 |     for rel in result:
79 |         cor_, tot_, avg_, loss_ = result[rel]
80 |         cor += cor_
81 |         tot += tot_
82 |         macro += avg_
83 |         loss_ /= tot_
84 |         loss += loss_
85 |         logger.info('%s\t%.5f\t%d\t%d\t%.5f' % (rel, avg_, cor_, tot_, loss_))
86 |     micro = cor / tot if tot > 0 else 0.0  # micro: overall accuracy over all facts
87 |     macro = macro / len(result) if len(result) > 0 else 0.0  # macro: mean of the per-relation accuracies
logger.info('Macro avg: %.5f' % macro) 89 | logger.info('Micro avg: %.5f, Eval_loss: %.5f, Eval_loss (common vocab): %.5f' %(micro, eval_loss / tot, loss / len(result) if len(result) > 0 else 0.0)) 90 | sys.stdout.flush() 91 | return micro, macro 92 | 93 | def evaluate(model, samples_batches, sentences_batches, filter_indices=None, index_list=None, output_topk=None): 94 | vocab_to_common_vocab = None 95 | if index_list is not None: 96 | vocab_to_common_vocab = {} 97 | for cid, idx in enumerate(index_list): 98 | vocab_to_common_vocab[idx] = cid 99 | 100 | cor_all = 0 101 | tot_all = 0 102 | result = {} 103 | list_of_predictions = {} 104 | eval_loss = 0.0 105 | common_eval_loss = 0.0 106 | for i in tqdm(range(len(samples_batches))): 107 | samples_b = samples_batches[i] 108 | sentences_b = sentences_batches[i] 109 | 110 | log_probs, cor_b, tot_b, pred_b, topk_preds, loss, common_vocab_loss = model.run_batch(sentences_b, samples_b, training=False, filter_indices=filter_indices, index_list=index_list, vocab_to_common_vocab=vocab_to_common_vocab) 111 | cor_all += cor_b 112 | tot_all += tot_b 113 | 114 | for pred, sample, topk, vocab_loss in zip(pred_b, samples_b, topk_preds, common_vocab_loss): 115 | rel = sample['predicate_id'] 116 | if rel not in result: 117 | result[rel] = (0, 0, 0, 0.0) 118 | list_of_predictions[rel] = [] 119 | cor, tot, _, rel_tot_loss = result[rel] 120 | tot += 1 121 | cor += pred 122 | rel_tot_loss += vocab_loss 123 | result[rel] = (cor, tot, cor / tot if tot > 0 else 0.0, rel_tot_loss) 124 | list_of_predictions[rel].append({ 125 | 'uuid': sample['uuid'], 126 | 'relation': sample['predicate_id'], 127 | 'sub_label': sample['sub_label'], 128 | 'obj_label': sample['obj_label'], 129 | 'masked_sentences': sample['input_sentences'], 130 | 'topk': topk, 131 | }) 132 | 133 | eval_loss += loss.item() * tot_b 134 | 135 | if output_topk is not None: 136 | logger.info('Output top-k predictions to %s...'%output_topk) 137 | for rel in list_of_predictions: 138 | with open(os.path.join(output_topk, '%s_predictions.jsonl'%rel), 'w') as f: 139 | f.write('\n'.join([json.dumps(x) for x in list_of_predictions[rel]])) 140 | 141 | micro, macro = output_result(result, eval_loss) 142 | return micro, result 143 | 144 | def gen_feature_sample(data_sample, template, mask_token='[MASK]'): 145 | feature_sample = {} 146 | feature_sample['predicate_id'] = data_sample['predicate_id'] 147 | feature_sample['sub_label'] = data_sample['sub_label'] 148 | feature_sample['obj_label'] = data_sample['obj_label'] 149 | feature_sample['uuid'] = data_sample['uuid'] if 'uuid' in data_sample else '' 150 | masked_sentence = parse_template(template.strip(), feature_sample['sub_label'].strip(), mask_token) 151 | feature_sample['input_sentences'] = [masked_sentence[0]] 152 | return feature_sample 153 | 154 | def load_data(data_path, template, vocab_subset=None, mask_token='[MASK]'): 155 | all_samples = [] 156 | 157 | distinct_facts = set() 158 | raw_samples = load_file(data_path) 159 | for data_sample in raw_samples: 160 | # follow the LAMA setting, only keep distinct (sub, obj) pairs 161 | if (data_sample['sub_label'], data_sample['obj_label']) in distinct_facts: 162 | continue 163 | if vocab_subset is not None and data_sample['obj_label'] not in vocab_subset:  # skip facts whose object falls outside the common vocab 164 | continue 165 | distinct_facts.add((data_sample['sub_label'], data_sample['obj_label'])) 166 | 167 | feature_sample = gen_feature_sample(data_sample, template, mask_token) 168 | all_samples.append(feature_sample) 169 | 170 | return all_samples 171 | 172 |
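For orientation, here is a minimal sketch of how the helpers in `code/utils.py` compose into a single evaluation pass, mirroring the `args.do_eval` branch of `code/run_optiprompt.py` shown above. It assumes a `model` wrapper in the style of `code/models.py` (exposing `model.MASK` and `model.run_batch(...)` as called in `evaluate`), plus the `filter_indices`/`index_list` common-vocab mappings that the actual run scripts construct; the data path and batch size are illustrative:

```python
# Sketch only: `model`, `filter_indices`, and `index_list` are assumed to be
# built by the repo's model wrapper (code/models.py); they are not defined here.

# Common vocab shared across probed models (see common_vocabs/README.md).
vocab_subset = load_vocab('common_vocabs/common_vocab_cased.txt')

# Manual LAMA template for P19 (place of birth); [X]/[Y] are filled in
# by parse_template via gen_feature_sample.
template = '[X] was born in [Y] .'

# Keep distinct (subject, object) facts whose object is in the common vocab,
# rendering each fact into a masked sentence.
samples = load_data('data/LAMA-TREx/P19.jsonl', template,
                    vocab_subset=vocab_subset, mask_token=model.MASK)

# Group the samples into aligned (samples, sentences) batches.
samples_batches, sentences_batches = batchify(samples, batch_size=32)

# Returns micro accuracy and per-relation (correct, total, accuracy, loss).
micro, result = evaluate(model, samples_batches, sentences_batches,
                         filter_indices, index_list)
```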
-------------------------------------------------------------------------------- /common_vocabs/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the vocabularies we used in our experiments. 2 | For the main results, we use `common_vocab_cased.txt`; for the experiments comparing BERT, RoBERTa, and ALBERT, we create a new vocabulary `common_vocab_cased_be_ro_al.txt`. 3 | 4 | Vocabs: 5 | * `common_vocab_cased.txt`: the common vocab released in [LAMA](https://github.com/facebookresearch/LAMA). It is the intersection of the vocabularies of five models (Transformer-XL, BERT, ELMo, GPT, and RoBERTa). We use this file in our main experiments for fair comparison. 6 | * `common_vocab_cased_be_ro_al.txt`: the common vocab we used to compare BERT, RoBERTa, and ALBERT (see Appendix C). It is the intersection of the vocabularies of BERT, RoBERTa, and ALBERT. 7 | -------------------------------------------------------------------------------- /figure/optiprompt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/OptiPrompt/9450bc96f28ac3df01337c79be079a2ac3defb5c/figure/optiprompt.png -------------------------------------------------------------------------------- /relation_metainfo/AutoPrompt_relations.jsonl: -------------------------------------------------------------------------------- 1 | {"relation": "P1001", "template": "[X]vik nationwide disabilities policing within [Y]."} 2 | {"relation": "P101", "template": "[X] probability earliest fame totaled studying [Y]."} 3 | {"relation": "P103", "template": "[X]PA communerug speaks proper [Y]."} 4 | {"relation": "P106", "template": "[X] supporters studied politicians musician turned [Y]."} 5 | {"relation": "P108", "template": "[X] 1987adeNBC computing succeeded [Y]."} 6 | {"relation": "P127", "template": "[X] is hindwings mainline architecture within [Y]."} 7 | {"relation": "P1303", "template": "[X] playingdrum concertoative electric [Y]."} 8 | {"relation": "P131", "template": "[X]ediatric close suburb throughout northwest [Y]."} 9 | {"relation": "P136", "template": "[X] freaking genre orchestra fiction acid [Y]."} 10 | {"relation": "P1376", "template": "[X] boasts native territory traditionally called [Y]."} 11 | {"relation": "P138", "template": "[X] consistslanche classical name of [Y]."} 12 | {"relation": "P140", "template": "[X]urn openly explicitly mosques practicing [Y]."} 13 | {"relation": "P1412", "template": "[X] receivedorganisation 1904 speaking only [Y]."} 14 | {"relation": "P159", "template": "[X] isnky galleries headquartered in [Y]."} 15 | {"relation": "P17", "template": "[X] is association footballled southeastern [Y]."} 16 | {"relation": "P176", "template": "[X] was flight series manufactured by [Y]."} 17 | {"relation": "P178", "template": "[X] is memory arcade branding by [Y]."} 18 | {"relation": "P19", "template": "[X] clocks literary economist relocated to [Y]."} 19 | {"relation": "P190", "template": "[X] proceeded worldwidedick offices near [Y]."} 20 | {"relation": "P20", "template": "[X] reorganizationotype photographic studio in [Y]."} 21 | {"relation": "P264", "template": "[X] cameo explanation\u00f6table sued [Y]."} 22 | {"relation": "P27", "template": "[X] m\u00b3 badminton pieces internationally representing [Y]."} 23 | {"relation": "P276", "template": "[X] consists kilograms centred neighborhoods in [Y]."} 24 | {"relation": "P279", "template": "[X] is \u00ee adequately termed coated
[Y]."} 25 | {"relation": "P30", "template": "[X] is commune polar continent in [Y]."} 26 | {"relation": "P31", "template": "[X] isious 1970s southwardlier [Y]."} 27 | {"relation": "P36", "template": "[X] includesiidae geologic countryside near [Y]."} 28 | {"relation": "P361", "template": "[X] isaul archaic section of [Y]."} 29 | {"relation": "P364", "template": "[X]dak \u20ac dancers speak standard [Y]."} 30 | {"relation": "P37", "template": "[X]inen dialects resembled officially exclusively [Y]."} 31 | {"relation": "P39", "template": "[X] explorers voting municipal \u2192 consecrated [Y]."} 32 | {"relation": "P407", "template": "[X] playedi\u0107 every dialect but [Y]."} 33 | {"relation": "P413", "template": "[X] played colors skier \u2194 defensive [Y]."} 34 | {"relation": "P449", "template": "[X] uncredited recording remake aired on [Y]."} 35 | {"relation": "P463", "template": "[X] splits artisticlogy prior joining [Y]."} 36 | {"relation": "P47", "template": "[X] shares undrafted border northeast neighbours [Y]."} 37 | {"relation": "P495", "template": "[X] album spanninggie chart in [Y]."} 38 | {"relation": "P527", "template": "[X] nickname involves \u032f\u00bddized [Y]."} 39 | {"relation": "P530", "template": "[X] nightclubrah preceding relations with [Y]."} 40 | {"relation": "P740", "template": "[X] refers drum blog centred downtown [Y]."} 41 | {"relation": "P937", "template": "[X] vol \u300elson gallery in [Y]."} -------------------------------------------------------------------------------- /relation_metainfo/LAMA_relations.jsonl: -------------------------------------------------------------------------------- 1 | {"relation": "P19", "template": "[X] was born in [Y] .", "label": "place of birth", "description": "most specific known (e.g. city instead of country, or hospital instead of city) birth location of a person, animal or fictional character", "type": "N-1"} 2 | {"relation": "P20", "template": "[X] died in [Y] .", "label": "place of death", "description": "most specific known (e.g. city instead of country, or hospital instead of city) death location of a person, animal or fictional character", "type": "N-1"} 3 | {"relation": "P279", "template": "[X] is a subclass of [Y] .", "label": "subclass of", "description": "all instances of these items are instances of those items; this item is a class (subset) of that item. Not to be confused with P31 (instance of)", "type": "N-1"} 4 | {"relation": "P37", "template": "The official language of [X] is [Y] .", "label": "official language", "description": "language designated as official by this item", "type": "N-1"} 5 | {"relation": "P413", "template": "[X] plays in [Y] position .", "label": "position played on team / speciality", "description": "position or specialism of a player on a team, e.g. 
Small Forward", "type": "N-1"} 6 | {"relation": "P166", "template": "[X] was awarded the [Y] .", "label": "award received", "description": "award or recognition received by a person, organisation or creative work", "type": "N-M"} 7 | {"relation": "P449", "template": "[X] was originally aired on [Y] .", "label": "original network", "description": "network(s) the radio or television show was originally aired on, including", "type": "N-1"} 8 | {"relation": "P69", "template": "[X] was educated at the University of [Y] .", "label": "educated at", "description": "educational institution attended by subject", "type": "N-M"} 9 | {"relation": "P47", "template": "[X] shares border with [Y] .", "label": "shares border with", "description": "countries or administrative subdivisions, of equal level, that this item borders, either by land or water", "type": "N-M"} 10 | {"relation": "P138", "template": "[X] is named after [Y] .", "label": "named after", "description": "entity or event that inspired the subject's name, or namesake (in at least one language)", "type": "N-1"} 11 | {"relation": "P364", "template": "The original language of [X] is [Y] .", "label": "original language of film or TV show", "description": "language in which a film or a performance work was originally created. Deprecated for written works; use P407 (\"language of work or name\") instead.", "type": "N-1"} 12 | {"relation": "P54", "template": "[X] plays with [Y] .", "label": "member of sports team", "description": "sports teams or clubs that the subject currently represents or formerly represented", "type": "N-1"} 13 | {"relation": "P463", "template": "[X] is a member of [Y] .", "label": "member of", "description": "organization or club to which the subject belongs. Do not use for membership in ethnic or social groups, nor for holding a position such as a member of parliament (use P39 for that).", "type": "N-M"} 14 | {"relation": "P101", "template": "[X] works in the field of [Y] .", "label": "field of work", "description": "specialization of a person or organization; see P106 for the occupation", "type": "N-M"} 15 | {"relation": "P1923", "template": "[Y] participated in the [X] .", "label": "participating team", "description": "Like 'Participant' (P710) but for teams. For an event like a cycle race or a football match you can use this property to list the teams and P710 to list the individuals (with 'member of sports team' (P54)' as a qualifier for the individuals)", "type": "N-M"} 16 | {"relation": "P106", "template": "[X] is a [Y] by profession .", "label": "occupation", "description": "occupation of a person; see also \"field of work\" (Property:P101), \"position held\" (Property:P39)", "type": "N-M"} 17 | {"relation": "P527", "template": "[X] consists of [Y] .", "label": "has part", "description": "part of this subject; inverse property of \"part of\" (P361). 
See also \"has parts of the class\" (P2670).", "type": "N-M"} 18 | {"relation": "P102", "template": "[X] is a member of the [Y] political party .", "label": "member of political party", "description": "the political party of which this politician is or has been a member", "type": "N-1"} 19 | {"relation": "P530", "template": "[X] maintains diplomatic relations with [Y] .", "label": "diplomatic relation", "description": "diplomatic relations of the country", "type": "N-M"} 20 | {"relation": "P176", "template": "[X] is produced by [Y] .", "label": "manufacturer", "description": "manufacturer or producer of this product", "type": "N-1"} 21 | {"relation": "P27", "template": "[X] is [Y] citizen .", "label": "country of citizenship", "description": "the object is a country that recognizes the subject as its citizen", "type": "N-M"} 22 | {"relation": "P407", "template": "[X] was written in [Y] .", "label": "language of work or name", "description": "language associated with this creative work (such as books, shows, songs, or websites) or a name (for persons use P103 and P1412)", "type": "N-1"} 23 | {"relation": "P30", "template": "[X] is located in [Y] .", "label": "continent", "description": "continent of which the subject is a part", "type": "N-1"} 24 | {"relation": "P178", "template": "[X] is developed by [Y] .", "label": "developer", "description": "organisation or person that developed the item", "type": "N-M"} 25 | {"relation": "P1376", "template": "[X] is the capital of [Y] .", "label": "capital of", "description": "country, state, department, canton or other administrative division of which the municipality is the governmental seat", "type": "1-1"} 26 | {"relation": "P131", "template": "[X] is located in [Y] .", "label": "located in the administrative territorial entity", "description": "the item is located on the territory of the following administrative entity. Use P276 (location) for specifying the location of non-administrative places and for items about events", "type": "N-1"} 27 | {"relation": "P1412", "template": "[X] used to communicate in [Y] .", "label": "languages spoken, written or signed", "description": "language(s) that a person speaks or writes, including the native language(s)", "type": "N-M"} 28 | {"relation": "P108", "template": "[X] works for [Y] .", "label": "employer", "description": "person or organization for which the subject works or worked", "type": "N-M"} 29 | {"relation": "P136", "template": "[X] plays [Y] music .", "label": "genre", "description": "creative work's genre or an artist's field of work (P101). Use main subject (P921) to relate creative works to their topic", "type": "N-1"} 30 | {"relation": "P17", "template": "[X] is located in [Y] .", "label": "country", "description": "sovereign state of this item; don't use on humans", "type": "N-1"} 31 | {"relation": "P39", "template": "[X] has the position of [Y] .", "label": "position held", "description": "subject currently or formerly holds the object position or public office", "type": "N-M"} 32 | {"relation": "P264", "template": "[X] is represented by music label [Y] .", "label": "record label", "description": "brand and trademark associated with the marketing of subject music recordings and music videos", "type": "N-1"} 33 | {"relation": "P276", "template": "[X] is located in [Y] .", "label": "location", "description": "location of the item, physical object or event is within. In case of an administrative entity use P131. 
In case of a distinct terrain feature use P706.", "type": "N-1"} 34 | {"relation": "P937", "template": "[X] used to work in [Y] .", "label": "work location", "description": "location where persons were active", "type": "N-M"} 35 | {"relation": "P140", "template": "[X] is affiliated with the [Y] religion .", "label": "religion", "description": "religion of a person, organization or religious building, or associated with this subject", "type": "N-1"} 36 | {"relation": "P1303", "template": "[X] plays [Y] .", "label": "instrument", "description": "musical instrument that a person plays", "type": "N-M"} 37 | {"relation": "P127", "template": "[X] is owned by [Y] .", "label": "owned by", "description": "owner of the subject", "type": "N-1"} 38 | {"relation": "P103", "template": "The native language of [X] is [Y] .", "label": "native language", "description": "language or languages a person has learned from early childhood", "type": "N-1"} 39 | {"relation": "P190", "template": "[X] and [Y] are twin cities .", "label": "twinned administrative body", "description": "twin towns, sister cities, twinned municipalities and other localities that have a partnership or cooperative agreement, either legally or informally acknowledged by their governments", "type": "N-M"} 40 | {"relation": "P1001", "template": "[X] is a legal term in [Y] .", "label": "applies to jurisdiction", "description": "the item (an institution, law, public office ...) or statement belongs to or has power over or applies to the value (a territorial jurisdiction: a country, state, municipality, ...)", "type": "N-M"} 41 | {"relation": "P31", "template": "[X] is a [Y] .", "label": "instance of", "description": "that class of which this subject is a particular example and member (subject typically an individual member with a proper name label); different from P279; using this property as a qualifier is deprecated\u2014use P2868 or P3831 instead", "type": "N-M"} 42 | {"relation": "P495", "template": "[X] was created in [Y] .", "label": "country of origin", "description": "country of origin of this item (creative work, food, phrase, product, etc.)", "type": "N-1"} 43 | {"relation": "P159", "template": "The headquarter of [X] is in [Y] .", "label": "headquarters location", "description": "specific location where an organization's headquarters is or has been situated. Inverse property of \"occupant\" (P466).", "type": "N-1"} 44 | {"relation": "P36", "template": "The capital of [X] is [Y] .", "label": "capital", "description": "primary city of a country, state or other type of administrative territorial entity", "type": "1-1"} 45 | {"relation": "P740", "template": "[X] was founded in [Y] .", "label": "location of formation", "description": "location where a group or organization was formed", "type": "N-1"} 46 | {"relation": "P361", "template": "[X] is part of [Y] .", "label": "part of", "description": "object of which the subject is a part (it's not useful to link objects which are themselves parts of other objects already listed as parts of the subject). 
Inverse property of \"has part\" (P527, see also \"has parts of the class\" (P2670)).", "type": "N-1"} 47 | -------------------------------------------------------------------------------- /relation_metainfo/LPAQA_relations.jsonl: -------------------------------------------------------------------------------- 1 | {"relation": "P1001", "template": "[X] is a legal term used in [Y]."} 2 | {"relation": "P101", "template": "[X] works in the domain of [Y]"} 3 | {"relation": "P103", "template": "The native language of [X] is [Y] ."} 4 | {"relation": "P106", "template": "[X] is professional [Y]."} 5 | {"relation": "P108", "template": "[X] is working for [Y]."} 6 | {"relation": "P127", "template": "[X] is the property of [Y]."} 7 | {"relation": "P1303", "template": "[X] playing [Y]."} 8 | {"relation": "P131", "template": "[X] in [Y]."} 9 | {"relation": "P136", "template": "[X] plays music from [Y]."} 10 | {"relation": "P1376", "template": "[X] is the main town of [Y]."} 11 | {"relation": "P138", "template": "[X] is called [Y]."} 12 | {"relation": "P140", "template": "[X] is related to the religion [Y]."} 13 | {"relation": "P1412", "template": "[X] communicated in [Y]."} 14 | {"relation": "P159", "template": "[X] is headquartered in [Y]."} 15 | {"relation": "P176", "template": "[X] will be produced by [Y]."} 16 | {"relation": "P178", "template": "[X] was created by [Y]."} 17 | {"relation": "P17", "template": "[X] is located in [Y].."} 18 | {"relation": "P190", "template": "[X] and [Y] are twinned towns."} 19 | {"relation": "P19", "template": "[X] was born in [Y] ."} 20 | {"relation": "P20", "template": "[X] died in [Y] ."} 21 | {"relation": "P264", "template": "[X] is represented by the music label [Y]."} 22 | {"relation": "P276", "template": "[X] is located in [Y] ."} 23 | {"relation": "P279", "template": "[X] is a subclass [Y]."} 24 | {"relation": "P27", "template": "[X] is a citizen of [Y]"} 25 | {"relation": "P30", "template": "[X] is found in [Y].."} 26 | {"relation": "P31", "template": "[X] is a [Y] ."} 27 | {"relation": "P361", "template": "[X] is a component of [Y]."} 28 | {"relation": "P364", "template": "The default language of [X] is [Y]."} 29 | {"relation": "P36", "template": "The capital of [X] is [Y] ."} 30 | {"relation": "P37", "template": "The official language [X] is [Y]."} 31 | {"relation": "P39", "template": "[X] has the position of a [Y]."} 32 | {"relation": "P407", "template": "[X] was written in [Y],"} 33 | {"relation": "P413", "template": "[X] plays at [Y] position."} 34 | {"relation": "P449", "template": "[X] was originally transmitted on [Y]."} 35 | {"relation": "P463", "template": "[X] is the member of [Y]."} 36 | {"relation": "P47", "template": "[X] shares border with [Y] ."} 37 | {"relation": "P495", "template": "[X] is made in [Y]."} 38 | {"relation": "P527", "template": "[X] means [Y]."} 39 | {"relation": "P530", "template": "[X] is in diplomatic relations with [Y]."} 40 | {"relation": "P740", "template": "[X] was born in [Y]."} 41 | {"relation": "P937", "template": "[X] worked in [Y]."} -------------------------------------------------------------------------------- /relation_metainfo/README.md: -------------------------------------------------------------------------------- 1 | This folder contains (discrete-token-based) prompts used in previous factual probing works. These files can be used to reproduce their probing results on LAMA. 2 | 3 | Prompts: 4 | * `LAMA_relations.jsonl`: manually-defined prompts used in [LAMA](https://github.com/facebookresearch/LAMA). 
We use these prompts to initialize the dense vectors in OptiPrompt. 5 | * `LPAQA_relations.jsonl`: the top-1 prompts used in [LPAQA](https://github.com/jzbjyb/LPAQA). 6 | * `AutoPrompt_relations.jsonl`: the prompts used in [AutoPrompt](https://github.com/ucinlp/autoprompt). They are optimized on the same training set as our OptiPrompt. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | torch==1.4.0 4 | transformers==3.0.2 -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget https://nlp.cs.princeton.edu/projects/optiprompt/data.tar.gz 4 | tar -xf data.tar.gz; rm -f data.tar.gz 5 | -------------------------------------------------------------------------------- /scripts/run_eval_prompts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | method=${1} 4 | 5 | if [ "${method}" = "lama" ]; then 6 | OUTPUTS_DIR=lama-outputs 7 | prompt_file=LAMA_relations.jsonl 8 | fi 9 | 10 | if [ "${method}" = "lpaqa" ]; then 11 | OUTPUTS_DIR=lpaqa-outputs 12 | prompt_file=LPAQA_relations.jsonl 13 | fi 14 | 15 | if [ "${method}" = "autoprompt" ]; then 16 | OUTPUTS_DIR=autoprompt-outputs 17 | prompt_file=AutoPrompt_relations.jsonl 18 | fi 19 | 20 | MODEL=bert-base-cased 21 | RAND=none 22 | 23 | for REL in P1001 P101 P103 P106 P108 P127 P1303 P131 P136 P1376 P138 P140 P1412 P159 P17 P176 P178 P19 P190 P20 P264 P27 P276 P279 P30 P31 P36 P361 P364 P37 P39 P407 P413 P449 P463 P47 P495 P527 P530 P740 P937; do 24 | 25 | DIR=${OUTPUTS_DIR}/${REL} 26 | mkdir -p ${DIR} 27 | 28 | python code/run_eval_prompts.py \ 29 | --relation_profile relation_metainfo/${prompt_file} \ 30 | --relation ${REL} \ 31 | --common_vocab_filename common_vocabs/common_vocab_cased.txt \ 32 | --model_name ${MODEL} \ 33 | --test_data data/LAMA-TREx/${REL}.jsonl \ 34 | --output_dir ${DIR} \ 35 | --output_predictions 36 | 37 | done 38 | 39 | python code/accumulate_results.py ${OUTPUTS_DIR} -------------------------------------------------------------------------------- /scripts/run_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUTPUTS_DIR=finetune-outputs 4 | MODEL=bert-base-cased 5 | RAND=none 6 | 7 | for REL in P1001 P101 P103 P106 P108 P127 P1303 P131 P136 P1376 P138 P140 P1412 P159 P17 P176 P178 P19 P190 P20 P264 P27 P276 P279 P30 P31 P36 P361 P364 P37 P39 P407 P413 P449 P463 P47 P495 P527 P530 P740 P937; do 8 | 9 | DIR=${OUTPUTS_DIR}/${REL} 10 | mkdir -p ${DIR} 11 | 12 | python code/run_finetune.py \ 13 | --relation_profile relation_metainfo/LAMA_relations.jsonl \ 14 | --relation ${REL} \ 15 | --common_vocab_filename common_vocabs/common_vocab_cased.txt \ 16 | --model_name ${MODEL} \ 17 | --do_train \ 18 | --train_data data/autoprompt_data/${REL}/train.jsonl \ 19 | --dev_data data/autoprompt_data/${REL}/dev.jsonl \ 20 | --do_eval \ 21 | --test_data data/LAMA-TREx/${REL}.jsonl \ 22 | --output_dir ${DIR} \ 23 | --random_init ${RAND} \ 24 | --output_predictions 25 | 26 | done 27 | 28 | python code/accumulate_results.py ${OUTPUTS_DIR} -------------------------------------------------------------------------------- /scripts/run_optiprompt.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OUTPUTS_DIR=optiprompt-outputs 4 | MODEL=bert-base-cased 5 | RAND=none 6 | 7 | for REL in P1001 P101 P103 P106 P108 P127 P1303 P131 P136 P1376 P138 P140 P1412 P159 P17 P176 P178 P19 P190 P20 P264 P27 P276 P279 P30 P31 P36 P361 P364 P37 P39 P407 P413 P449 P463 P47 P495 P527 P530 P740 P937; do 8 | 9 | DIR=${OUTPUTS_DIR}/${REL} 10 | mkdir -p ${DIR} 11 | 12 | python code/run_optiprompt.py \ 13 | --relation_profile relation_metainfo/LAMA_relations.jsonl \ 14 | --relation ${REL} \ 15 | --common_vocab_filename common_vocabs/common_vocab_cased.txt \ 16 | --model_name ${MODEL} \ 17 | --do_train \ 18 | --train_data data/autoprompt_data/${REL}/train.jsonl \ 19 | --dev_data data/autoprompt_data/${REL}/dev.jsonl \ 20 | --do_eval \ 21 | --test_data data/LAMA-TREx/${REL}.jsonl \ 22 | --output_dir ${DIR} \ 23 | --random_init ${RAND} \ 24 | --init_manual_template \ 25 | --output_predictions 26 | 27 | done 28 | 29 | python code/accumulate_results.py ${OUTPUTS_DIR} -------------------------------------------------------------------------------- /slides/slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/princeton-nlp/OptiPrompt/9450bc96f28ac3df01337c79be079a2ac3defb5c/slides/slides.pdf --------------------------------------------------------------------------------
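As a closing illustration of the prompt format used throughout `relation_metainfo/`, the sketch below renders one discrete prompt into a masked query, mirroring `parse_template` in `code/utils.py`; the subject entity here is illustrative:

```python
import json

# One line from relation_metainfo/LAMA_relations.jsonl (P103, native language).
row = json.loads(
    '{"relation": "P103", "template": "The native language of [X] is [Y] ."}')

def fill_template(template, subject, mask_token='[MASK]'):
    # Mirrors parse_template: [X] -> subject, [Y] -> the mask token whose
    # slot the probed model fills from the common vocab.
    return template.replace('[X]', subject).replace('[Y]', mask_token)

print(fill_template(row['template'], 'Dante'))
# The native language of Dante is [MASK] .
```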