├── .gitignore ├── LICENSE ├── README.md ├── download_gdrive.py ├── finetune ├── dataloader.py ├── metrics.py ├── model.py ├── run_re.py ├── run_typing.py └── utils.py ├── lama ├── batch_eval_KB_completion.py ├── eval_lama.py ├── evaluation_metrics.py ├── lama_utils.py └── model.py ├── preprocess ├── WikiExtractor.py ├── extract.py ├── gen_data.py ├── merge.py └── statistic.py ├── pretrain ├── dataset.py ├── emb_ip.cfg ├── init_ent_rel.py ├── large_emb.py ├── metrics.py ├── model.py ├── run_pretrain.py ├── run_pretrain.sh └── utils.py ├── read_ent_vocab.bin └── read_rel_vocab.bin /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # mine 132 | .idea/ 133 | *.DS_Store 134 | *.npy 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tianxiang Sun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoLAKE 2 | 3 | Source code for paper "[CoLAKE: Contextualized Language and Knowledge Embedding](https://arxiv.org/abs/2010.00309)". If you have any problem about reproducing the experiments, please feel free to contact us or propose an issue. 4 | 5 | ## Prepare your environment 6 | 7 | We recommend to create a new environment. 8 | 9 | ```bash 10 | conda create --name colake python=3.7 11 | source activate colake 12 | ``` 13 | 14 | CoLAKE is implemented based on [fastNLP](https://github.com/fastnlp/fastNLP) and [huggingface's transformers](https://github.com/huggingface/transformers), and uses [fitlog](https://github.com/fastnlp/fitlog) to record the experiments. 15 | 16 | ```bash 17 | git clone https://github.com/fastnlp/fastNLP.git 18 | cd fastNLP/ & python setup.py install 19 | git clone https://github.com/fastnlp/fitlog.git 20 | cd fitlog/ & python setup.py install 21 | pip install transformers==2.11 22 | pip install sklearn 23 | ``` 24 | 25 | To re-train CoLAKE, you may need mixed CPU-GPU training to handle the large number of entities. Our implementation is based on KVStore provided by [DGL](https://github.com/dmlc/dgl). In addition, to reproduce the experiments on link prediction, you may also need [DGL-KE](https://github.com/awslabs/dgl-ke). 
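After running the two install commands below, a quick way to confirm that the pinned dependencies import cleanly (an optional sanity check, not part of the original setup):

```bash
# optional: verify the environment (assumes the packages installed above and below)
python -c "import fastNLP, fitlog, dgl, transformers; print('transformers', transformers.__version__, '| dgl', dgl.__version__)"
```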
26 | 27 | ```bash 28 | pip install dgl==0.4.3 29 | pip install dglke 30 | ``` 31 | 32 | ## Reproduce the experiments 33 | 34 | ### 1. Download the model and entity embeddings 35 | 36 | Download the pre-trained CoLAKE [model](https://drive.google.com/file/d/1MEGcmJUBXOyxKaK6K88fZFyj_IbH9U5b) and [embeddings](https://drive.google.com/file/d/1_FG9mpTrOnxV2NolXlu1n2ihgSZFXHnI) for more than 3M entities. To reproduce the experiments on LAMA and LAMA-UHN, you only need to download the model. You can use the `download_gdrive.py` in this repo to directly download files from Google Drive to your server: 37 | 38 | ```bash 39 | mkdir model 40 | python download_gdrive.py 1MEGcmJUBXOyxKaK6K88fZFyj_IbH9U5b ./model/model.bin 41 | python download_gdrive.py 1_FG9mpTrOnxV2NolXlu1n2ihgSZFXHnI ./model/entities.npy 42 | ``` 43 | 44 | Alternatively, you can use `gdown`: 45 | 46 | ```bash 47 | pip install gdown 48 | gdown https://drive.google.com/uc?id=1MEGcmJUBXOyxKaK6K88fZFyj_IbH9U5b 49 | gdown https://drive.google.com/uc?id=1_FG9mpTrOnxV2NolXlu1n2ihgSZFXHnI 50 | ``` 51 | 52 | ### 2. Run the experiments 53 | 54 | Download the datasets for the experiments in the paper: [Google Drive](https://drive.google.com/file/d/1UNXICdkB5JbRyS5WTq6QNX4ndpMlNob6/view?usp=sharing). 55 | 56 | ```bash 57 | python download_gdrive.py 1UNXICdkB5JbRyS5WTq6QNX4ndpMlNob6 ./data.tar.gz 58 | tar -xzvf data.tar.gz 59 | cd finetune/ 60 | ``` 61 | 62 | #### FewRel 63 | 64 | ```bash 65 | python run_re.py --debug --gpu 0 66 | ``` 67 | 68 | #### Open Entity 69 | 70 | ```bash 71 | python run_typing.py --debug --gpu 0 72 | ``` 73 | 74 | #### LAMA and LAMA-UHN 75 | 76 | ```bash 77 | cd ../lama/ 78 | python eval_lama.py 79 | ``` 80 | 81 | ## Re-train CoLAKE 82 | 83 | ### 1. Download the data 84 | 85 | Download the latest wiki dump (XML format): 86 | 87 | ```bash 88 | wget -c https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2 89 | ``` 90 | 91 | Download the knowledge graph (Wikidata5M): 92 | 93 | ```bash 94 | wget -c https://www.dropbox.com/s/6sbhm0rwo4l73jq/wikidata5m_transductive.tar.gz?dl=1 95 | tar -xzvf wikidata5m_transductive.tar.gz 96 | ``` 97 | 98 | Download the Wikidata5M entity & relation aliases: 99 | 100 | ```bash 101 | wget -c https://www.dropbox.com/s/lnbhc8yuhit4wm5/wikidata5m_alias.tar.gz?dl=1 102 | tar -xzvf wikidata5m_alias.tar.gz 103 | ``` 104 | 105 | ### 2. Preprocess the data 106 | 107 | Preprocess wiki dump: 108 | 109 | ```bash 110 | mkdir pretrain_data 111 | # process xml-format wiki dump 112 | python preprocess/WikiExtractor.py enwiki-latest-pages-articles.xml.bz2 -o pretrain_data/output -l --min_text_length 100 --filter_disambig_pages -it abbr,b,big --processes 4 113 | # Modify anchors 114 | python preprocess/extract.py 4 115 | python preprocess/gen_data.py 4 116 | # Count entity & relation frequency and generate vocabs 117 | python statistic.py 118 | ``` 119 | 120 | ### 3. 
Train CoLAKE 121 | 122 | Initialize entity and relation embeddings with the average of RoBERTa BPE embedding of entity and relation aliases: 123 | 124 | ```bash 125 | cd pretrain/ 126 | python init_ent_rel.py 127 | ``` 128 | 129 | Train CoLAKE with mixed CPU-GPU: 130 | 131 | ```bash 132 | ./run_pretrain.sh 133 | ``` 134 | 135 | ## Cite 136 | 137 | If you use the code and model, please cite this paper: 138 | 139 | ``` 140 | @inproceedings{sun2020colake, 141 | author = {Tianxiang Sun and Yunfan Shao and Xipeng Qiu and Qipeng Guo and Yaru Hu and Xuanjing Huang and Zheng Zhang}, 142 | title = {CoLAKE: Contextualized Language and Knowledge Embedding}, 143 | booktitle = {Proceedings of the 28th International Conference on Computational Linguistics, {COLING}}, 144 | year = {2020} 145 | } 146 | ``` 147 | 148 | ## Acknowledgments 149 | 150 | - [fastNLP](https://github.com/fastnlp/fastNLP) 151 | 152 | - [LAMA](https://github.com/facebookresearch/LAMA) 153 | 154 | - [ERNIE](https://github.com/thunlp/ERNIE) 155 | 156 | -------------------------------------------------------------------------------- /download_gdrive.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Guide for usage: 3 | In your terminal, run the command: 4 | python download_gdrive.py GoogleFileID /path/for/this/file/to/download/file.type 5 | Credited to 6 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 7 | author: https://stackoverflow.com/users/1475331/user115202 8 | ''' 9 | 10 | import requests 11 | 12 | from tqdm import tqdm 13 | 14 | def download_file_from_google_drive(id, destination): 15 | def get_confirm_token(response): 16 | for key, value in response.cookies.items(): 17 | if key.startswith('download_warning'): 18 | return value 19 | 20 | return None 21 | 22 | def save_response_content(response, destination): 23 | CHUNK_SIZE = 32768 24 | 25 | with open(destination, "wb") as f: 26 | with tqdm(unit='B', unit_scale=True, unit_divisor=1024) as bar: 27 | for chunk in response.iter_content(CHUNK_SIZE): 28 | if chunk: # filter out keep-alive new chunks 29 | f.write(chunk) 30 | bar.update(CHUNK_SIZE) 31 | 32 | URL = "https://docs.google.com/uc?export=download" 33 | 34 | session = requests.Session() 35 | 36 | response = session.get(URL, params = { 'id' : id }, stream = True) 37 | token = get_confirm_token(response) 38 | 39 | if token: 40 | params = { 'id' : id, 'confirm' : token } 41 | response = session.get(URL, params = params, stream = True) 42 | 43 | save_response_content(response, destination) 44 | 45 | 46 | if __name__ == "__main__": 47 | import sys 48 | if len(sys.argv) is not 3: 49 | print("Usage: python google_drive.py drive_file_id destination_file_path") 50 | else: 51 | # TAKE ID FROM SHAREABLE LINK 52 | file_id = sys.argv[1] 53 | # DESTINATION FILE ON YOUR DISK 54 | destination = sys.argv[2] 55 | download_file_from_google_drive(file_id, destination) -------------------------------------------------------------------------------- /finetune/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from torch.utils.data import Dataset 5 | from fastNLP import seq_len_to_mask 6 | from transformers import RobertaTokenizer 7 | 8 | WORD_PADDING_INDEX = 1 9 | ENTITY_PADDING_INDEX = 1 10 | 11 | 12 | class REGraphDataSet(Dataset): 13 | def __init__(self, path, set_type, label_vocab, ent_vocab): 14 | self.set_type = set_type + '.json' 15 | self.label_vocab = 
label_vocab 16 | self.data = [] 17 | with open(os.path.join(path, self.set_type), 'r', encoding='utf-8') as fin: 18 | raw_data = json.load(fin) 19 | for ins in raw_data: 20 | nodes_index = [] 21 | n_words = 0 22 | for node in ins['nodes']: 23 | if isinstance(node, str): 24 | nodes_index.append(ent_vocab[node]) 25 | else: 26 | nodes_index.append(node) 27 | n_words += 1 28 | 29 | self.data.append({ 30 | 'input_ids': nodes_index, 31 | 'n_word_nodes': n_words, 32 | 'n_entity_nodes': len(nodes_index) - n_words, 33 | 'position_ids': ins['soft_position'], 34 | 'token_type_ids': ins['token_type_ids'], 35 | 'target': label_vocab[ins['label']] 36 | }) 37 | 38 | def __len__(self): 39 | return len(self.data) 40 | 41 | def __getitem__(self, item): 42 | return self.data[item] 43 | 44 | def collate_fn(self, batch): 45 | input_keys = ['input_ids', 'n_word_nodes', 'n_entity_nodes', 'position_ids', 'attention_mask', 'target', 46 | 'token_type_ids'] 47 | target_keys = ['target'] 48 | max_words = max_ents = 0 49 | batch_x = {n: [] for n in input_keys} 50 | batch_y = {n: [] for n in target_keys} 51 | for sample in batch: 52 | max_words = sample['n_word_nodes'] if sample['n_word_nodes'] > max_words else max_words 53 | max_ents = sample['n_entity_nodes'] if sample['n_entity_nodes'] > max_ents else max_ents 54 | 55 | for sample in batch: 56 | word_pad = max_words - sample['n_word_nodes'] 57 | ent_pad = max_ents - sample['n_entity_nodes'] 58 | n_words = sample['n_word_nodes'] 59 | batch_x['input_ids'].append(sample['input_ids'][:n_words] + [WORD_PADDING_INDEX] * word_pad + \ 60 | sample['input_ids'][n_words:] + [ENTITY_PADDING_INDEX] * ent_pad) 61 | 62 | batch_x['position_ids'].append(sample['position_ids'][:n_words] + [0] * word_pad + \ 63 | sample['position_ids'][n_words:] + [0] * ent_pad) 64 | 65 | batch_x['token_type_ids'].append(sample['token_type_ids'][:n_words] + [0] * word_pad + \ 66 | sample['token_type_ids'][n_words:] + [0] * ent_pad) 67 | 68 | batch_x['n_word_nodes'].append(max_words) 69 | batch_x['n_entity_nodes'].append(max_ents) 70 | 71 | adj = torch.ones(len(sample['input_ids']), len(sample['input_ids']), dtype=torch.int) 72 | adj = torch.cat((adj[:n_words, :], 73 | torch.ones(word_pad, adj.shape[1], dtype=torch.int), 74 | adj[n_words:, :], 75 | torch.ones(ent_pad, adj.shape[1], dtype=torch.int)), dim=0) 76 | 77 | adj = torch.cat((adj[:, :n_words], 78 | torch.zeros(max_words + max_ents, word_pad, dtype=torch.int), 79 | adj[:, n_words:], 80 | torch.zeros(max_words + max_ents, ent_pad, dtype=torch.int)), dim=1) 81 | 82 | batch_x['attention_mask'].append(adj) 83 | batch_x['target'].append(sample['target']) 84 | batch_y['target'].append(sample['target']) 85 | 86 | for k, v in batch_x.items(): 87 | if k == 'attention_mask': 88 | batch_x[k] = torch.stack(v, dim=0) 89 | else: 90 | batch_x[k] = torch.tensor(v) 91 | for k, v in batch_y.items(): 92 | batch_y[k] = torch.tensor(v) 93 | 94 | return (batch_x, batch_y) 95 | 96 | 97 | class TypingGraphDataSet(Dataset): 98 | def __init__(self, path, set_type, label_vocab, ent_vocab): 99 | self.set_type = set_type + '.json' 100 | self.label_vocab = label_vocab 101 | self.data = [] 102 | with open(os.path.join(path, self.set_type), 'r', encoding='utf-8') as fin: 103 | raw_data = json.load(fin) 104 | for ins in raw_data: 105 | nodes_index = [] 106 | n_words = 0 107 | for node in ins['nodes']: 108 | if isinstance(node, str): 109 | nodes_index.append(ent_vocab[node]) 110 | else: 111 | nodes_index.append(node) 112 | n_words += 1 113 | 114 | label_index = 
[self.label_vocab[x] for x in ins['labels']] 115 | label_vec = [0] * len(self.label_vocab) 116 | for golden_id in label_index: 117 | label_vec[golden_id] = 1 118 | 119 | self.data.append({ 120 | 'input_ids': nodes_index, 121 | 'n_word_nodes': n_words, 122 | 'n_entity_nodes': len(nodes_index) - n_words, 123 | 'position_ids': ins['soft_position'], 124 | 'token_type_ids': ins['token_type_ids'], 125 | 'target': label_vec 126 | }) 127 | 128 | def __len__(self): 129 | return len(self.data) 130 | 131 | def __getitem__(self, item): 132 | return self.data[item] 133 | 134 | def collate_fn(self, batch): 135 | input_keys = ['input_ids', 'n_word_nodes', 'n_entity_nodes', 'position_ids', 'attention_mask', 'target', 136 | 'token_type_ids'] 137 | target_keys = ['target'] 138 | max_words = max_ents = 0 139 | batch_x = {n: [] for n in input_keys} 140 | batch_y = {n: [] for n in target_keys} 141 | for sample in batch: 142 | max_words = sample['n_word_nodes'] if sample['n_word_nodes'] > max_words else max_words 143 | max_ents = sample['n_entity_nodes'] if sample['n_entity_nodes'] > max_ents else max_ents 144 | 145 | for sample in batch: 146 | word_pad = max_words - sample['n_word_nodes'] 147 | ent_pad = max_ents - sample['n_entity_nodes'] 148 | n_words = sample['n_word_nodes'] 149 | batch_x['input_ids'].append(sample['input_ids'][:n_words] + [WORD_PADDING_INDEX] * word_pad + \ 150 | sample['input_ids'][n_words:] + [ENTITY_PADDING_INDEX] * ent_pad) 151 | 152 | batch_x['position_ids'].append(sample['position_ids'][:n_words] + [0] * word_pad + \ 153 | sample['position_ids'][n_words:] + [0] * ent_pad) 154 | 155 | batch_x['token_type_ids'].append(sample['token_type_ids'][:n_words] + [0] * word_pad + \ 156 | sample['token_type_ids'][n_words:] + [0] * ent_pad) 157 | 158 | batch_x['n_word_nodes'].append(max_words) 159 | batch_x['n_entity_nodes'].append(max_ents) 160 | 161 | adj = torch.ones(len(sample['input_ids']), len(sample['input_ids']), dtype=torch.int) 162 | adj = torch.cat((adj[:n_words, :], 163 | torch.ones(word_pad, adj.shape[1], dtype=torch.int), 164 | adj[n_words:, :], 165 | torch.ones(ent_pad, adj.shape[1], dtype=torch.int)), dim=0) 166 | 167 | adj = torch.cat((adj[:, :n_words], 168 | torch.zeros(max_words + max_ents, word_pad, dtype=torch.int), 169 | adj[:, n_words:], 170 | torch.zeros(max_words + max_ents, ent_pad, dtype=torch.int)), dim=1) 171 | 172 | batch_x['attention_mask'].append(adj) 173 | batch_x['target'].append(sample['target']) 174 | batch_y['target'].append(sample['target']) 175 | 176 | for k, v in batch_x.items(): 177 | if k == 'attention_mask': 178 | batch_x[k] = torch.stack(v, dim=0) 179 | elif k == 'target': 180 | batch_x[k] = torch.FloatTensor(v) 181 | else: 182 | batch_x[k] = torch.tensor(v) 183 | for k, v in batch_y.items(): 184 | batch_y[k] = torch.tensor(v) 185 | 186 | return (batch_x, batch_y) -------------------------------------------------------------------------------- /finetune/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fastNLP.core.metrics import MetricBase 3 | from fastNLP.core.utils import _get_func_signature 4 | from sklearn.metrics import f1_score, precision_recall_fscore_support 5 | 6 | 7 | class MacroMetric(MetricBase): 8 | def __init__(self, pred=None, target=None): 9 | super().__init__() 10 | self._init_param_map(pred=pred, target=target, seq_len=None) 11 | self._target = [] 12 | self._pred = [] 13 | 14 | def evaluate(self, pred, target, seq_len=None): 15 | ''' 16 | :param pred: batch_size 17 | 
:param target: batch_size 18 | :param seq_len: not uesed when doing text classification 19 | :return: 20 | ''' 21 | 22 | if not isinstance(pred, torch.Tensor): 23 | raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 24 | f"got {type(pred)}.") 25 | if not isinstance(target, torch.Tensor): 26 | raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 27 | f"got {type(target)}.") 28 | 29 | if pred.dim() != target.dim(): 30 | raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " 31 | f"size:{pred.size()}, target should have size: {pred.size()} or " 32 | f"{pred.size()[:-1]}, got {target.size()}.") 33 | 34 | pred = pred.detach().cpu().numpy().tolist() 35 | target = target.to('cpu').numpy().tolist() 36 | self._pred.extend(pred) 37 | self._target.extend(target) 38 | 39 | def get_metric(self, reset=True): 40 | precision, recall, f_score, _ = precision_recall_fscore_support(self._target, self._pred, average='macro') 41 | evaluate_result = { 42 | 'f_score': f_score, 43 | 'precision': precision, 44 | 'recall': recall, 45 | } 46 | if reset: 47 | self._pred = [] 48 | self._target = [] 49 | 50 | return evaluate_result 51 | 52 | 53 | class MicroMetric(MetricBase): 54 | def __init__(self, pred=None, target=None, no_relation_idx=0): 55 | super().__init__() 56 | self._init_param_map(pred=pred, target=target, seq_len=None) 57 | self.no_relation = no_relation_idx 58 | self.num_predict = 0 59 | self.num_golden = 0 60 | self.true_positive = 0 61 | 62 | def evaluate(self, pred, target, seq_len=None): 63 | ''' 64 | :param pred: batch_size 65 | :param target: batch_size 66 | :param seq_len: not uesed when doing text classification 67 | :return: 68 | ''' 69 | 70 | if not isinstance(pred, torch.Tensor): 71 | raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 72 | f"got {type(pred)}.") 73 | if not isinstance(target, torch.Tensor): 74 | raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 75 | f"got {type(target)}.") 76 | 77 | if pred.dim() != target.dim(): 78 | raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " 79 | f"size:{pred.size()}, target should have size: {pred.size()} or " 80 | f"{pred.size()[:-1]}, got {target.size()}.") 81 | 82 | preds = pred.detach().cpu().numpy().tolist() 83 | targets = target.to('cpu').numpy().tolist() 84 | for pred, target in zip(preds, targets): 85 | if pred == target and pred != self.no_relation: 86 | self.true_positive += 1 87 | if target != self.no_relation: 88 | self.num_golden += 1 89 | if pred != self.no_relation: 90 | self.num_predict += 1 91 | 92 | def get_metric(self, reset=True): 93 | if self.num_predict > 0: 94 | micro_precision = self.true_positive / self.num_predict 95 | else: 96 | micro_precision = 0. 97 | micro_recall = self.true_positive / self.num_golden 98 | micro_fscore = self._calculate_f1(micro_precision, micro_recall) 99 | evaluate_result = { 100 | 'f_score': micro_fscore, 101 | 'precision': micro_precision, 102 | 'recall': micro_recall 103 | } 104 | 105 | if reset: 106 | self.num_predict = 0 107 | self.num_golden = 0 108 | self.true_positive = 0 109 | 110 | return evaluate_result 111 | 112 | def _calculate_f1(self, p, r): 113 | if r == 0.: 114 | return 0. 
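        # F1 = 2PR / (P + R), the harmonic mean of precision and recall; the r == 0 guard
        # above keeps this denominator positive, since p >= 0 and r > 0 on this path.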
115 | return 2 * p * r / float(p + r) 116 | 117 | 118 | class TypingMetric(MetricBase): 119 | def __init__(self, pred=None, target=None): 120 | super().__init__() 121 | self._init_param_map(pred=pred, target=target, seq_len=None) 122 | self.acc_count = 0 123 | self.total = 0 124 | self._target = [] 125 | self._pred = [] 126 | 127 | def evaluate(self, pred, target, seq_len=None): 128 | ''' 129 | :param pred: batch_size x num_labels 130 | :param target: batch_size x num_labels 131 | :param seq_len: not uesed when doing text classification 132 | :return: 133 | ''' 134 | 135 | if not isinstance(pred, torch.Tensor): 136 | raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 137 | f"got {type(pred)}.") 138 | if not isinstance(target, torch.Tensor): 139 | raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 140 | f"got {type(target)}.") 141 | 142 | if pred.dim() != target.dim(): 143 | raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " 144 | f"size:{pred.size()}, target should have size: {pred.size()} or " 145 | f"{pred.size()[:-1]}, got {target.size()}.") 146 | 147 | pred = pred.detach().cpu().numpy() 148 | target = target.to('cpu').numpy() 149 | cnt = 0 150 | y1, y2 = [], [] 151 | for x1, x2 in zip(pred, target): 152 | yy1 = [] 153 | yy2 = [] 154 | for i in range(len(x1)): 155 | if x1[i] > 0: 156 | yy1.append(i) 157 | if x2[i] > 0: 158 | yy2.append(i) 159 | y1.append(yy1) 160 | y2.append(yy2) 161 | cnt += set(yy1) == set(yy2) 162 | 163 | self.acc_count += cnt 164 | self.total += len(pred) 165 | self._pred.extend(y1) 166 | self._target.extend(y2) 167 | 168 | def get_metric(self, reset=True): 169 | # for calculating macro F1 170 | num_predict, num_golden = 0, 0 171 | p = 0. 172 | r = 0. 173 | # for calculating micro F1 174 | num_predicted_labels = 0. 175 | num_golden_labels = 0. 176 | num_correct_labels = 0. 177 | 178 | for true_labels, predicted_labels in zip(self._target, self._pred): 179 | overlap = len(set(predicted_labels).intersection(set(true_labels))) 180 | # calculating macro F1 181 | if len(predicted_labels) > 0: 182 | p += overlap / float(len(predicted_labels)) 183 | num_predict += 1 184 | if len(true_labels) > 0: 185 | r += overlap / float(len(true_labels)) 186 | num_golden += 1 187 | # calculating micro F1 188 | num_predicted_labels += len(predicted_labels) 189 | num_golden_labels += len(true_labels) 190 | num_correct_labels += overlap 191 | 192 | if num_predict > 0: 193 | macro_precision = p / num_predict 194 | else: 195 | macro_precision = 0. 196 | macro_recall = r / num_golden 197 | macro = self._calculate_f1(macro_precision, macro_recall) 198 | 199 | if num_predicted_labels > 0: 200 | micro_precision = num_correct_labels / num_predicted_labels 201 | else: 202 | micro_precision = 0. 203 | micro_recall = num_correct_labels / num_golden_labels 204 | micro = self._calculate_f1(micro_precision, micro_recall) 205 | 206 | evaluate_result = {'micro_f': micro, 207 | 'micro_p': micro_precision, 208 | 'micro_r': micro_recall, 209 | 'acc': round(float(self.acc_count) / (self.total + 1e-12), 6), 210 | # 'macro_p': macro_precision, 211 | # 'macro_r': macro_recall, 212 | # 'macro_f': macro, 213 | } 214 | if reset: 215 | self.acc_count = 0 216 | self.total = 0 217 | self._pred = [] 218 | self._target = [] 219 | 220 | return evaluate_result 221 | 222 | def _calculate_f1(self, p, r): 223 | if r == 0.: 224 | return 0. 
225 | return 2 * p * r / float(p + r) -------------------------------------------------------------------------------- /finetune/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss 4 | from transformers import BertPreTrainedModel, RobertaModel, RobertaConfig 5 | from transformers.modeling_bert import BertLayerNorm, gelu 6 | 7 | 8 | class ClsHead(nn.Module): 9 | def __init__(self, config, num_labels, dropout=0.3): 10 | super().__init__() 11 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 12 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 13 | 14 | self.decoder = nn.Linear(config.hidden_size, num_labels, bias=False) 15 | self.bias = nn.Parameter(torch.zeros(num_labels), requires_grad=True) 16 | self.dropout = nn.Dropout(p=dropout) 17 | self.decoder.bias = self.bias 18 | 19 | def forward(self, features, **kwargs): 20 | x = self.dense(features) 21 | x = gelu(x) 22 | x = self.dropout(x) 23 | x = self.layer_norm(x) 24 | x = self.decoder(x) 25 | return x 26 | 27 | 28 | 29 | class CoLAKEForRE(BertPreTrainedModel): 30 | base_model_prefix = "roberta" 31 | 32 | def __init__(self, config, num_types, ent_emb): 33 | super().__init__(config) 34 | self.roberta = RobertaModel(config) 35 | self.rel_head = ClsHead(config, num_types) 36 | self.apply(self._init_weights) 37 | self.ent_embeddings_n = nn.Embedding.from_pretrained(ent_emb) 38 | self.num_types = num_types 39 | 40 | def tie_rel_weights(self, rel_cls_weight): 41 | # rel_index: num_types 42 | self.rel_head.decoder.weight.data = rel_cls_weight 43 | if getattr(self.rel_head.decoder, "bias", None) is not None: 44 | self.rel_head.decoder.bias.data = torch.nn.functional.pad( 45 | self.rel_head.decoder.bias.data, 46 | (0, self.rel_head.decoder.weight.shape[0] - self.rel_head.decoder.bias.shape[0],), 47 | "constant", 48 | 0, 49 | ) 50 | 51 | def forward( 52 | self, 53 | input_ids=None, 54 | attention_mask=None, 55 | token_type_ids=None, 56 | position_ids=None, 57 | head_mask=None, 58 | inputs_embeds=None, 59 | n_word_nodes=None, 60 | n_entity_nodes=None, 61 | target=None 62 | ): 63 | n_word_nodes = n_word_nodes[0] 64 | word_embeddings = self.roberta.embeddings.word_embeddings( 65 | input_ids[:, :n_word_nodes]) # batch x n_word_nodes x hidden_size 66 | 67 | ent_embeddings = self.ent_embeddings_n( 68 | input_ids[:, n_word_nodes:]) 69 | 70 | inputs_embeds = torch.cat([word_embeddings, ent_embeddings], 71 | dim=1) # batch x seq_len x hidden_size 72 | 73 | outputs = self.roberta( 74 | input_ids=None, 75 | attention_mask=attention_mask, 76 | token_type_ids=token_type_ids, 77 | position_ids=position_ids, 78 | head_mask=head_mask, 79 | inputs_embeds=inputs_embeds, 80 | ) 81 | pooler_output = outputs[0][:, 0, :] # batch x hidden_size 82 | logits = self.rel_head(pooler_output) 83 | loss_fct = CrossEntropyLoss() 84 | loss = loss_fct(logits.view(-1, logits.size(-1)), target.view(-1)) 85 | return {'loss': loss, 'pred': torch.argmax(logits, dim=-1)} 86 | 87 | 88 | class CoLAKEForTyping(BertPreTrainedModel): 89 | base_model_prefix = "roberta" 90 | 91 | def __init__(self, config, num_types, ent_emb): 92 | super().__init__(config) 93 | self.roberta = RobertaModel(config) 94 | self.cls_head = ClsHead(config, num_types) 95 | self.apply(self._init_weights) 96 | self.ent_embeddings_n = nn.Embedding.from_pretrained(ent_emb) 97 | self.num_types = num_types 98 | 99 | def forward( 100 | self, 
| self,
101 | input_ids=None, 102 | attention_mask=None, 103 | token_type_ids=None, 104 | position_ids=None, 105 | head_mask=None, 106 | inputs_embeds=None, 107 | n_word_nodes=None, 108 | n_entity_nodes=None, 109 | target=None 110 | ): 111 | n_word_nodes = n_word_nodes[0] 112 | word_embeddings = self.roberta.embeddings.word_embeddings( 113 | input_ids[:, :n_word_nodes]) # batch x n_word_nodes x hidden_size 114 | 115 | ent_embeddings = self.ent_embeddings_n( 116 | input_ids[:, n_word_nodes:]) 117 | 118 | inputs_embeds = torch.cat([word_embeddings, ent_embeddings], 119 | dim=1) # batch x seq_len x hidden_size 120 | 121 | outputs = self.roberta( 122 | input_ids=None, 123 | attention_mask=attention_mask, 124 | token_type_ids=token_type_ids, 125 | position_ids=position_ids, 126 | head_mask=head_mask, 127 | inputs_embeds=inputs_embeds, 128 | ) 129 | pooler_output = outputs[0][:, 0, :] # batch x hidden_size 130 | logits = self.cls_head(pooler_output) 131 | loss_fct = BCEWithLogitsLoss() 132 | loss = loss_fct(logits.view(-1, self.num_types), target.view(-1, self.num_types)) 133 | return {'loss': loss, 'pred': logits} -------------------------------------------------------------------------------- /finetune/run_re.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import argparse 5 | import numpy as np 6 | import torch 7 | from torch import optim 8 | import torch.nn as nn 9 | from transformers import RobertaConfig, RobertaTokenizer 10 | 11 | import fitlog 12 | from fastNLP import cache_results 13 | from fastNLP import FitlogCallback, WarmupCallback, GradientClipCallback 14 | from fastNLP import RandomSampler, TorchLoaderIter, LossInForward, Trainer, Tester 15 | 16 | sys.path.append('../') 17 | from finetune.dataloader import REGraphDataSet 18 | from finetune.model import CoLAKEForRE 19 | from finetune.metrics import MacroMetric 20 | from finetune.utils import build_label_vocab, build_temp_ent_vocab 21 | from pretrain.utils import load_ent_rel_vocabs 22 | 23 | 24 | @cache_results(_cache_fp='fewrel_CoLAKE.bin', _refresh=False) 25 | def load_fewrel_graph_data(data_dir): 26 | datasets = ['train', 'dev', 'test'] 27 | label_vocab = build_label_vocab(data_dir) 28 | ent_vocab = build_temp_ent_vocab(data_dir) 29 | result = [] 30 | for set_type in datasets: 31 | print('processing {} set...'.format(set_type)) 32 | dataset = REGraphDataSet(data_dir, set_type=set_type, label_vocab=label_vocab, ent_vocab=ent_vocab) 33 | result.append(dataset) 34 | result.append(ent_vocab) 35 | return result 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('--data_dir', type=str, default='../data/fewrel', 41 | help="data directory path") 42 | parser.add_argument('--log_dir', type=str, default='./logs/', 43 | help="fitlog directory path") 44 | parser.add_argument('--batch_size', type=int, default=32, help="batch size") 45 | parser.add_argument('--lr', type=float, default=5e-5, help="learning rate") 46 | parser.add_argument('--beta', type=float, default=0.999, help="beta_2 of adam") 47 | parser.add_argument('--weight_decay', type=float, default=0.01, help="weight decay") 48 | parser.add_argument('--warm_up', type=float, default=0.1, help="warmup proportion or steps") 49 | parser.add_argument('--epoch', type=int, default=3, help="number of epochs") 50 | parser.add_argument('--grad_accumulation', type=int, default=1, help="gradient accumulation") 51 | parser.add_argument('--gpu', type=str, default='all', help="run script on 
which devices") 52 | parser.add_argument('--debug', action='store_true', help="do not log") 53 | parser.add_argument('--model_path', type=str, default="../model/", 54 | help="the path of directory containing model and entity embeddings.") 55 | return parser.parse_args() 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | 61 | if args.debug: 62 | fitlog.debug() 63 | 64 | fitlog.set_log_dir(args.log_dir) 65 | fitlog.commit(__file__) 66 | fitlog.add_hyper_in_file(__file__) 67 | fitlog.add_hyper(args) 68 | if args.gpu != 'all': 69 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 70 | 71 | train_set, dev_set, test_set, temp_ent_vocab = load_fewrel_graph_data(data_dir=args.data_dir) 72 | 73 | print('data directory: {}'.format(args.data_dir)) 74 | print('# of train samples: {}'.format(len(train_set))) 75 | print('# of dev samples: {}'.format(len(dev_set))) 76 | print('# of test samples: {}'.format(len(test_set))) 77 | 78 | ent_vocab, rel_vocab = load_ent_rel_vocabs(path='../') 79 | 80 | # load entity embeddings 81 | ent_index = [] 82 | for k, v in temp_ent_vocab.items(): 83 | ent_index.append(ent_vocab[k]) 84 | ent_index = torch.tensor(ent_index) 85 | ent_emb = np.load(os.path.join(args.model_path, 'entities.npy')) 86 | ent_embedding = nn.Embedding.from_pretrained(torch.from_numpy(ent_emb)) 87 | ent_emb = ent_embedding(ent_index.view(1, -1)).squeeze().detach() 88 | 89 | # load CoLAKE parameters 90 | config = RobertaConfig.from_pretrained('roberta-base', type_vocab_size=3) 91 | model = CoLAKEForRE(config, 92 | num_types=len(train_set.label_vocab), 93 | ent_emb=ent_emb) 94 | states_dict = torch.load(os.path.join(args.model_path, 'model.bin')) 95 | model.load_state_dict(states_dict, strict=False) 96 | print('parameters below are randomly initializecd:') 97 | for name, param in model.named_parameters(): 98 | if name not in states_dict: 99 | print(name) 100 | 101 | # tie relation classification head 102 | rel_index = [] 103 | for k, v in train_set.label_vocab.items(): 104 | rel_index.append(rel_vocab[k]) 105 | rel_index = torch.LongTensor(rel_index) 106 | rel_embeddings = nn.Embedding.from_pretrained(states_dict['rel_embeddings.weight']) 107 | rel_index = rel_index.cuda() 108 | rel_cls_weight = rel_embeddings(rel_index.view(1, -1)).squeeze() 109 | model.tie_rel_weights(rel_cls_weight) 110 | 111 | model.rel_head.dense.weight.data = states_dict['rel_lm_head.dense.weight'] 112 | model.rel_head.dense.bias.data = states_dict['rel_lm_head.dense.bias'] 113 | model.rel_head.layer_norm.weight.data = states_dict['rel_lm_head.layer_norm.weight'] 114 | model.rel_head.layer_norm.bias.data = states_dict['rel_lm_head.layer_norm.bias'] 115 | 116 | model.resize_token_embeddings(len(RobertaTokenizer.from_pretrained('roberta-base')) + 4) 117 | print('parameters of CoLAKE has been loaded.') 118 | 119 | # fine-tune 120 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'embedding'] 121 | param_optimizer = list(model.named_parameters()) 122 | optimizer_grouped_parameters = [ 123 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 124 | 'weight_decay': args.weight_decay}, 125 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 126 | ] 127 | optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.lr, betas=(0.9, args.beta), eps=1e-6) 128 | 129 | metrics = [MacroMetric(pred='pred', target='target')] 130 | 131 | test_data_iter = TorchLoaderIter(dataset=test_set, batch_size=args.batch_size, sampler=RandomSampler(), 132 | 
num_workers=4, 133 | collate_fn=test_set.collate_fn) 134 | devices = list(range(torch.cuda.device_count())) 135 | tester = Tester(data=test_data_iter, model=model, metrics=metrics, device=devices) 136 | # tester.test() 137 | 138 | fitlog_callback = FitlogCallback(tester=tester, log_loss_every=100, verbose=1) 139 | gradient_clip_callback = GradientClipCallback(clip_value=1, clip_type='norm') 140 | warmup_callback = WarmupCallback(warmup=args.warm_up, schedule='linear') 141 | 142 | bsz = args.batch_size // args.grad_accumulation 143 | 144 | train_data_iter = TorchLoaderIter(dataset=train_set, 145 | batch_size=bsz, 146 | sampler=RandomSampler(), 147 | num_workers=4, 148 | collate_fn=train_set.collate_fn) 149 | dev_data_iter = TorchLoaderIter(dataset=dev_set, 150 | batch_size=bsz, 151 | sampler=RandomSampler(), 152 | num_workers=4, 153 | collate_fn=dev_set.collate_fn) 154 | 155 | trainer = Trainer(train_data=train_data_iter, 156 | dev_data=dev_data_iter, 157 | model=model, 158 | optimizer=optimizer, 159 | loss=LossInForward(), 160 | batch_size=bsz, 161 | update_every=args.grad_accumulation, 162 | n_epochs=args.epoch, 163 | metrics=metrics, 164 | callbacks=[fitlog_callback, gradient_clip_callback, warmup_callback], 165 | device=devices, 166 | use_tqdm=True) 167 | 168 | trainer.train(load_best_model=False) 169 | 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /finetune/run_typing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import argparse 5 | import numpy as np 6 | import torch 7 | from torch import optim 8 | import torch.nn as nn 9 | from transformers import RobertaConfig, RobertaTokenizer 10 | 11 | import fitlog 12 | from fastNLP import cache_results 13 | from fastNLP import FitlogCallback, WarmupCallback, GradientClipCallback 14 | from fastNLP import RandomSampler, TorchLoaderIter, LossInForward, Trainer, Tester 15 | 16 | sys.path.append('../') 17 | from finetune.dataloader import TypingGraphDataSet 18 | from finetune.model import CoLAKEForTyping 19 | from finetune.metrics import TypingMetric 20 | from finetune.utils import build_label_vocab, build_temp_ent_vocab 21 | from pretrain.utils import load_ent_rel_vocabs 22 | 23 | 24 | @cache_results(_cache_fp='openentity_CoLAKE.bin', _refresh=False) 25 | def load_openentity_graph_data(data_dir): 26 | datasets = ['train', 'dev', 'test'] 27 | label_vocab = build_label_vocab(data_dir, task_type='typing') 28 | ent_vocab = build_temp_ent_vocab(data_dir) 29 | result = [] 30 | for set_type in datasets: 31 | print('processing {} set...'.format(set_type)) 32 | dataset = TypingGraphDataSet(data_dir, set_type=set_type, label_vocab=label_vocab, ent_vocab=ent_vocab) 33 | result.append(dataset) 34 | result.append(ent_vocab) 35 | return result 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('--data_dir', type=str, default='../data/openentity', 41 | help="data directory path") 42 | parser.add_argument('--log_dir', type=str, default='./logs/', 43 | help="fitlog directory path") 44 | parser.add_argument('--batch_size', type=int, default=32, help="batch size") 45 | parser.add_argument('--lr', type=float, default=5e-5, help="learning rate") 46 | parser.add_argument('--beta', type=float, default=0.999, help="beta_2 of adam") 47 | parser.add_argument('--weight_decay', type=float, default=0.01, help="weight decay") 48 | parser.add_argument('--warm_up', 
type=float, default=0.1, help="warmup proportion or steps") 49 | parser.add_argument('--epoch', type=int, default=5, help="number of epochs") 50 | parser.add_argument('--grad_accumulation', type=int, default=1, help="gradient accumulation") 51 | parser.add_argument('--gpu', type=str, default='all', help="run script on which devices") 52 | parser.add_argument('--debug', action='store_true', help="do not log") 53 | parser.add_argument('--model_path', type=str, default='../model/', 54 | help="load params of trained model") 55 | return parser.parse_args() 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | 61 | if args.debug: 62 | fitlog.debug() 63 | 64 | fitlog.set_log_dir(args.log_dir) 65 | fitlog.commit(__file__) 66 | fitlog.add_hyper_in_file(__file__) 67 | fitlog.add_hyper(args) 68 | if args.gpu != 'all': 69 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 70 | 71 | train_set, dev_set, test_set, temp_ent_vocab = load_openentity_graph_data(data_dir=args.data_dir) 72 | 73 | print('data directory: {}'.format(args.data_dir)) 74 | print('# of train samples: {}'.format(len(train_set))) 75 | print('# of dev samples: {}'.format(len(dev_set))) 76 | print('# of test samples: {}'.format(len(test_set))) 77 | print('# of labels: {}'.format(len(train_set.label_vocab))) 78 | 79 | ent_vocab, rel_vocab = load_ent_rel_vocabs(path='../') 80 | 81 | # load entity embeddings 82 | ent_index = [] 83 | for k, v in temp_ent_vocab.items(): 84 | assert v == len(ent_index) 85 | ent_index.append(ent_vocab[k]) 86 | ent_index = torch.tensor(ent_index) 87 | ent_emb = np.load(os.path.join(args.model_path, 'entities.npy')) 88 | ent_embedding = nn.Embedding.from_pretrained(torch.from_numpy(ent_emb)) 89 | ent_emb = ent_embedding(ent_index.view(1, -1)).squeeze().detach() 90 | 91 | # load CoLAKE parameters 92 | config = RobertaConfig.from_pretrained('roberta-base', type_vocab_size=3) 93 | model = CoLAKEForTyping(config, 94 | num_types=len(train_set.label_vocab), 95 | ent_emb=ent_emb) 96 | states_dict = torch.load(os.path.join(args.model_path, 'model.bin')) 97 | model.load_state_dict(states_dict, strict=False) 98 | print('parameters below are randomly initializecd:') 99 | for name, param in model.named_parameters(): 100 | if name not in states_dict: 101 | print(name) 102 | 103 | model.cls_head.dense.weight.data = states_dict['ent_lm_head.dense.weight'] 104 | model.cls_head.dense.bias.data = states_dict['ent_lm_head.dense.bias'] 105 | model.cls_head.layer_norm.weight.data = states_dict['ent_lm_head.layer_norm.weight'] 106 | model.cls_head.layer_norm.bias.data = states_dict['ent_lm_head.layer_norm.bias'] 107 | 108 | model.resize_token_embeddings(len(RobertaTokenizer.from_pretrained('roberta-base')) + 2) 109 | print('parameters of CoLAKE has been loaded.') 110 | 111 | # fine-tune 112 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'embedding'] 113 | param_optimizer = list(model.named_parameters()) 114 | optimizer_grouped_parameters = [ 115 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 116 | 'weight_decay': args.weight_decay}, 117 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 118 | ] 119 | optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.lr, betas=(0.9, args.beta), eps=1e-6) 120 | 121 | metrics = [TypingMetric(pred='pred', target='target')] 122 | 123 | test_data_iter = TorchLoaderIter(dataset=test_set, batch_size=args.batch_size, sampler=RandomSampler(), 124 | num_workers=4, 125 | 
collate_fn=test_set.collate_fn) 126 | devices = list(range(torch.cuda.device_count())) 127 | tester = Tester(data=test_data_iter, model=model, metrics=metrics, device=devices) 128 | # tester.test() 129 | 130 | fitlog_callback = FitlogCallback(tester=tester, log_loss_every=100, verbose=1) 131 | gradient_clip_callback = GradientClipCallback(clip_value=1, clip_type='norm') 132 | warmup_callback = WarmupCallback(warmup=args.warm_up, schedule='linear') 133 | 134 | bsz = args.batch_size // args.grad_accumulation 135 | 136 | train_data_iter = TorchLoaderIter(dataset=train_set, 137 | batch_size=bsz, 138 | sampler=RandomSampler(), 139 | num_workers=4, 140 | collate_fn=train_set.collate_fn) 141 | dev_data_iter = TorchLoaderIter(dataset=dev_set, 142 | batch_size=bsz, 143 | sampler=RandomSampler(), 144 | num_workers=4, 145 | collate_fn=dev_set.collate_fn) 146 | 147 | trainer = Trainer(train_data=train_data_iter, 148 | dev_data=dev_data_iter, 149 | model=model, 150 | optimizer=optimizer, 151 | loss=LossInForward(), 152 | batch_size=bsz, 153 | update_every=args.grad_accumulation, 154 | n_epochs=args.epoch, 155 | metrics=metrics, 156 | callbacks=[fitlog_callback, gradient_clip_callback, warmup_callback], 157 | device=devices, 158 | use_tqdm=True) 159 | 160 | trainer.train(load_best_model=False) 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /finetune/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | def build_label_vocab(data_dir, task_type='re'): 6 | label_vocab = {} 7 | with open(os.path.join(data_dir, 'train.json'), 'r', encoding='utf-8') as fin: 8 | data = json.load(fin) 9 | if task_type == 're': 10 | for ins in data: 11 | label = ins['label'] 12 | if label not in label_vocab: 13 | label_vocab[label] = len(label_vocab) 14 | elif task_type == 'typing': 15 | for ins in data: 16 | labels = ins['labels'] 17 | for label in labels: 18 | if label not in label_vocab: 19 | label_vocab[label] = len(label_vocab) 20 | else: 21 | raise RuntimeError('wrong task_type') 22 | print('# of labels: {}'.format(len(label_vocab))) 23 | return label_vocab 24 | 25 | 26 | def build_temp_ent_vocab(path): 27 | ent_vocab = {'': 0, '': 1, '': 2} 28 | files = ['train.json', 'dev.json', 'test.json'] 29 | for file in files: 30 | with open(os.path.join(path, file), 'r', encoding='utf-8') as fin: 31 | data = json.load(fin) 32 | for ins in data: 33 | for node in ins['nodes']: 34 | if isinstance(node, str) and node.startswith('Q'): 35 | if node not in ent_vocab: 36 | ent_vocab[node] = len(ent_vocab) 37 | print('# of entities occurred in train/dev/test files: {}'.format(len(ent_vocab))) 38 | return ent_vocab -------------------------------------------------------------------------------- /lama/batch_eval_KB_completion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from tqdm import tqdm 8 | from random import shuffle 9 | import os 10 | import json 11 | import spacy 12 | from pprint import pprint 13 | import logging.config 14 | import logging 15 | import pickle 16 | from multiprocessing.pool import ThreadPool 17 | import multiprocessing 18 | import lama.evaluation_metrics as metrics 19 | import time, sys 20 | from lama.lama_utils import parse_template, load_vocab, load_file 21 | 22 | 23 | def batchify(data, batch_size): 24 | msg = "" 25 | list_samples_batches = [] 26 | list_sentences_batches = [] 27 | current_samples_batch = [] 28 | current_sentences_batches = [] 29 | c = 0 30 | 31 | # sort to group togheter sentences with similar length 32 | for sample in sorted( 33 | data, key=lambda k: len(" ".join(k["masked_sentences"]).split()) 34 | ): 35 | masked_sentences = sample["masked_sentences"] 36 | current_samples_batch.append(sample) 37 | current_sentences_batches.append(masked_sentences) 38 | c += 1 39 | if c >= batch_size: 40 | list_samples_batches.append(current_samples_batch) 41 | list_sentences_batches.append(current_sentences_batches) 42 | current_samples_batch = [] 43 | current_sentences_batches = [] 44 | c = 0 45 | 46 | # last batch 47 | if current_samples_batch and len(current_samples_batch) > 0: 48 | list_samples_batches.append(current_samples_batch) 49 | list_sentences_batches.append(current_sentences_batches) 50 | 51 | return list_samples_batches, list_sentences_batches, msg 52 | 53 | 54 | def run_thread(arguments): 55 | 56 | msg = "" 57 | 58 | # 1. compute the ranking metrics on the filtered log_probs tensor 59 | sample_MRR, sample_P, experiment_result, return_msg = metrics.get_ranking( 60 | arguments["filtered_log_probs"], 61 | arguments["masked_indices"], 62 | arguments["vocab"], 63 | label_index=arguments["label_index"], 64 | index_list=arguments["index_list"], 65 | print_generation=arguments["interactive"], 66 | topk=10000, 67 | ) 68 | msg += "\n" + return_msg 69 | 70 | sample_perplexity = 0.0 71 | 72 | return experiment_result, sample_MRR, sample_P, sample_perplexity, msg 73 | 74 | 75 | def filter_samples(model, samples, vocab_subset, max_sentence_length, template): 76 | msg = "" 77 | new_samples = [] 78 | samples_exluded = 0 79 | for sample in samples: 80 | excluded = False 81 | if "obj_label" in sample and "sub_label" in sample: 82 | 83 | obj_label_ids = model.get_id(sample["obj_label"]) 84 | 85 | if obj_label_ids: 86 | recostructed_word = " ".join( 87 | [model.vocab[x] for x in obj_label_ids] 88 | ).strip() 89 | else: 90 | recostructed_word = None 91 | 92 | excluded = False 93 | if not template or len(template) == 0: 94 | masked_sentences = sample["masked_sentences"] 95 | text = " ".join(masked_sentences) 96 | if len(text.split()) > max_sentence_length: 97 | msg += "\tEXCLUDED for exeeding max sentence length: {}\n".format( 98 | masked_sentences 99 | ) 100 | samples_exluded += 1 101 | excluded = True 102 | 103 | # MAKE SURE THAT obj_label IS IN VOCABULARIES 104 | if vocab_subset: 105 | for x in sample["obj_label"].split(" "): 106 | if x not in vocab_subset: 107 | excluded = True 108 | msg += "\tEXCLUDED object label {} not in vocab subset\n".format( 109 | sample["obj_label"] 110 | ) 111 | samples_exluded += 1 112 | break 113 | 114 | if excluded: 115 | pass 116 | elif obj_label_ids is None: 117 | msg += "\tEXCLUDED object label {} not in model vocabulary\n".format( 118 | sample["obj_label"] 119 | ) 120 | samples_exluded += 1 121 | elif not recostructed_word or recostructed_word != sample["obj_label"]: 122 | msg += 
"\tEXCLUDED object label {} not in model vocabulary\n".format( 123 | sample["obj_label"] 124 | ) 125 | samples_exluded += 1 126 | # elif vocab_subset is not None and sample['obj_label'] not in vocab_subset: 127 | # msg += "\tEXCLUDED object label {} not in vocab subset\n".format(sample['obj_label']) 128 | # samples_exluded+=1 129 | elif "judgments" in sample: 130 | # only for Google-RE 131 | num_no = 0 132 | num_yes = 0 133 | for x in sample["judgments"]: 134 | if x["judgment"] == "yes": 135 | num_yes += 1 136 | else: 137 | num_no += 1 138 | if num_no > num_yes: 139 | # SKIP NEGATIVE EVIDENCE 140 | pass 141 | else: 142 | new_samples.append(sample) 143 | else: 144 | new_samples.append(sample) 145 | else: 146 | msg += "\tEXCLUDED since 'obj_label' not sample or 'sub_label' not in sample: {}\n".format( 147 | sample 148 | ) 149 | samples_exluded += 1 150 | msg += "samples exluded : {}\n".format(samples_exluded) 151 | return new_samples, msg 152 | 153 | 154 | def run_evaluation(args, shuffle_data=True, model=None): 155 | 156 | msg = "" 157 | print(model) 158 | 159 | # deal with vocab subset 160 | vocab_subset = None 161 | index_list = None 162 | msg += "args: {}\n".format(args) 163 | if args.common_vocab_filename is not None: 164 | vocab_subset = load_vocab(args.common_vocab_filename) 165 | print('size of common vocabulary: {}'.format(len(vocab_subset))) 166 | msg += "common vocabulary size: {}\n".format(len(vocab_subset)) 167 | 168 | filter_logprob_indices, index_list = model.init_indices_for_filter_logprobs( 169 | vocab_subset 170 | ) 171 | 172 | 173 | # stats 174 | samples_with_negative_judgement = 0 175 | samples_with_positive_judgement = 0 176 | 177 | # Mean reciprocal rank 178 | MRR = 0.0 179 | MRR_negative = 0.0 180 | MRR_positive = 0.0 181 | 182 | # Precision at (default 10) 183 | Precision = 0.0 184 | Precision1 = 0.0 185 | Precision_negative = 0.0 186 | Precision_positivie = 0.0 187 | 188 | # spearman rank correlation 189 | # overlap at 1 190 | 191 | data = load_file(args.dataset_filename) 192 | print('# raw samples: {}'.format(len(data))) 193 | 194 | all_samples, ret_msg = filter_samples( 195 | model, data, vocab_subset, args.max_sentence_length, args.template 196 | ) 197 | print("# filtered samples: {}".format(len(all_samples))) 198 | 199 | # if template is active (1) use a single example for (sub,obj) and (2) ... 
200 | if args.template and args.template != "": 201 | facts = [] 202 | uris = [] 203 | for sample in all_samples: 204 | sub = sample["sub_label"] 205 | obj = sample["obj_label"] 206 | if (sub, obj) not in facts: 207 | facts.append((sub, obj)) 208 | if "sub_uri" in sample: 209 | uris.append(sample["sub_uri"]) 210 | if len(uris) > 0: 211 | assert len(uris) == len(facts), "{} {}".format(len(uris), len(facts)) 212 | local_msg = "distinct template facts: {}".format(len(facts)) 213 | print(local_msg) 214 | all_samples = [] 215 | for idx, fact in enumerate(facts): 216 | (sub, obj) = fact 217 | sample = {} 218 | sample["sub_label"] = sub 219 | sample["obj_label"] = obj 220 | if len(uris) > 0: 221 | sample["sub_uri"] = uris[idx] 222 | # sobstitute all sentences with a standard template 223 | sample["masked_sentences"] = parse_template( 224 | args.template.strip(), sample["sub_label"].strip(), model.tokenizer.mask_token 225 | ) 226 | 227 | all_samples.append(sample) 228 | 229 | # create uuid if not present 230 | i = 0 231 | for sample in all_samples: 232 | if "uuid" not in sample: 233 | sample["uuid"] = i 234 | i += 1 235 | 236 | # shuffle data 237 | if shuffle_data: 238 | shuffle(all_samples) 239 | 240 | samples_batches, sentences_batches, ret_msg = batchify(all_samples, args.batch_size) 241 | 242 | # ThreadPool 243 | num_threads = args.threads 244 | if num_threads <= 0: 245 | # use all available threads 246 | num_threads = multiprocessing.cpu_count() 247 | pool = ThreadPool(num_threads) 248 | list_of_results = [] 249 | 250 | for i in tqdm(range(len(samples_batches))): 251 | 252 | samples_b = samples_batches[i] 253 | sentences_b = sentences_batches[i] 254 | 255 | ( 256 | original_log_probs_list, 257 | token_ids_list, 258 | masked_indices_list, 259 | ) = model.get_batch_generation(samples_b) 260 | 261 | if vocab_subset is not None: 262 | # filter log_probs 263 | filtered_log_probs_list = model.filter_logprobs( 264 | original_log_probs_list, filter_logprob_indices 265 | ) 266 | else: 267 | filtered_log_probs_list = original_log_probs_list 268 | 269 | label_index_list = [] 270 | for sample in samples_b: 271 | obj_label_id = model.get_id(sample["obj_label"]) 272 | 273 | # MAKE SURE THAT obj_label IS IN VOCABULARIES 274 | if obj_label_id is None: 275 | raise ValueError( 276 | "object label {} not in model vocabulary".format( 277 | sample["obj_label"] 278 | ) 279 | ) 280 | elif model.vocab[obj_label_id[0]] != sample["obj_label"]: 281 | raise ValueError( 282 | "object label {} not in model vocabulary".format( 283 | sample["obj_label"] 284 | ) 285 | ) 286 | elif vocab_subset is not None and sample["obj_label"] not in vocab_subset: 287 | raise ValueError( 288 | "object label {} not in vocab subset".format(sample["obj_label"]) 289 | ) 290 | 291 | label_index_list.append(obj_label_id) 292 | 293 | arguments = [ 294 | { 295 | "original_log_probs": original_log_probs, 296 | "filtered_log_probs": filtered_log_probs, 297 | "token_ids": token_ids, 298 | "vocab": model.vocab, 299 | "label_index": label_index[0], 300 | "masked_indices": masked_indices, 301 | "interactive": False, 302 | "index_list": index_list, 303 | "sample": sample, 304 | } 305 | for original_log_probs, filtered_log_probs, token_ids, masked_indices, label_index, sample in zip( 306 | original_log_probs_list, 307 | filtered_log_probs_list, 308 | token_ids_list, 309 | masked_indices_list, 310 | label_index_list, 311 | samples_b, 312 | ) 313 | ] 314 | # single thread for debug 315 | # for isx,a in enumerate(arguments): 316 | # print(samples_b[isx]) 317 
| # run_thread(a) 318 | 319 | # multithread 320 | res = pool.map(run_thread, arguments) 321 | 322 | for idx, result in enumerate(res): 323 | 324 | result_masked_topk, sample_MRR, sample_P, sample_perplexity, msg = result 325 | 326 | sample = samples_b[idx] 327 | 328 | element = {} 329 | element["sample"] = sample 330 | element["uuid"] = sample["uuid"] 331 | element["token_ids"] = token_ids_list[idx] 332 | element["masked_indices"] = masked_indices_list[idx] 333 | element["label_index"] = label_index_list[idx] 334 | element["masked_topk"] = result_masked_topk 335 | element["sample_MRR"] = sample_MRR 336 | element["sample_Precision"] = sample_P 337 | element["sample_perplexity"] = sample_perplexity 338 | element["sample_Precision1"] = result_masked_topk["P_AT_1"] 339 | 340 | # print() 341 | # print("idx: {}".format(idx)) 342 | # print("masked_entity: {}".format(result_masked_topk['masked_entity'])) 343 | # for yi in range(10): 344 | # print("\t{} {}".format(yi,result_masked_topk['topk'][yi])) 345 | # print("masked_indices_list: {}".format(masked_indices_list[idx])) 346 | # print("sample_MRR: {}".format(sample_MRR)) 347 | # print("sample_P: {}".format(sample_P)) 348 | # print("sample: {}".format(sample)) 349 | # print() 350 | 351 | MRR += sample_MRR 352 | Precision += sample_P 353 | Precision1 += element["sample_Precision1"] 354 | 355 | # the judgment of the annotators recording whether they are 356 | # evidence in the sentence that indicates a relation between two entities. 357 | num_yes = 0 358 | num_no = 0 359 | 360 | if "judgments" in sample: 361 | # only for Google-RE 362 | for x in sample["judgments"]: 363 | if x["judgment"] == "yes": 364 | num_yes += 1 365 | else: 366 | num_no += 1 367 | if num_no >= num_yes: 368 | samples_with_negative_judgement += 1 369 | element["judgement"] = "negative" 370 | MRR_negative += sample_MRR 371 | Precision_negative += sample_P 372 | else: 373 | samples_with_positive_judgement += 1 374 | element["judgement"] = "positive" 375 | MRR_positive += sample_MRR 376 | Precision_positivie += sample_P 377 | 378 | list_of_results.append(element) 379 | 380 | pool.close() 381 | pool.join() 382 | 383 | # stats 384 | # Mean reciprocal rank 385 | MRR /= len(list_of_results) 386 | 387 | # Precision 388 | Precision /= len(list_of_results) 389 | Precision1 /= len(list_of_results) 390 | 391 | msg = "all_samples: {}\n".format(len(all_samples)) 392 | msg += "list_of_results: {}\n".format(len(list_of_results)) 393 | msg += "global MRR: {}\n".format(MRR) 394 | msg += "global Precision at 10: {}\n".format(Precision) 395 | msg += "global Precision at 1: {}\n".format(Precision1) 396 | 397 | if samples_with_negative_judgement > 0 and samples_with_positive_judgement > 0: 398 | # Google-RE specific 399 | MRR_negative /= samples_with_negative_judgement 400 | MRR_positive /= samples_with_positive_judgement 401 | Precision_negative /= samples_with_negative_judgement 402 | Precision_positivie /= samples_with_positive_judgement 403 | msg += "samples_with_negative_judgement: {}\n".format( 404 | samples_with_negative_judgement 405 | ) 406 | msg += "samples_with_positive_judgement: {}\n".format( 407 | samples_with_positive_judgement 408 | ) 409 | msg += "MRR_negative: {}\n".format(MRR_negative) 410 | msg += "MRR_positive: {}\n".format(MRR_positive) 411 | msg += "Precision_negative: {}\n".format(Precision_negative) 412 | msg += "Precision_positivie: {}\n".format(Precision_positivie) 413 | 414 | return Precision1 415 | 416 | 417 | 
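# Note: run_evaluation() returns the mean P@1 over the evaluated samples; eval_lama.py
# (the next file) calls it once per relation and reports the mean over relations.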
-------------------------------------------------------------------------------- /lama/eval_lama.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pprint 3 | import statistics 4 | from os import listdir 5 | import os 6 | from os.path import isfile, join 7 | from shutil import copyfile 8 | from collections import defaultdict 9 | import sys 10 | sys.path.append('../') 11 | from lama.lama_utils import load_file 12 | from lama.model import Roberta 13 | from lama.batch_eval_KB_completion import run_evaluation 14 | 15 | common_vocab_path = "../data/LAMA/common_vocab_cased.txt" 16 | model_path = "../model/" 17 | 18 | def get_TREx_parameters(data_path_pre="../data/LAMA/"): 19 | relations = load_file("{}relations.jsonl".format(data_path_pre)) 20 | data_path_pre += "TREx/" 21 | data_path_post = ".jsonl" 22 | return relations, data_path_pre, data_path_post 23 | 24 | 25 | def get_GoogleRE_parameters(): 26 | relations = [ 27 | { 28 | "relation": "place_of_birth", 29 | "template": "[X] was born in [Y] .", 30 | "template_negated": "[X] was not born in [Y] .", 31 | }, 32 | { 33 | "relation": "date_of_birth", 34 | "template": "[X] (born [Y]).", 35 | "template_negated": "[X] (not born [Y]).", 36 | }, 37 | { 38 | "relation": "place_of_death", 39 | "template": "[X] died in [Y] .", 40 | "template_negated": "[X] did not die in [Y] .", 41 | }, 42 | ] 43 | data_path_pre = "../data/LAMA/Google_RE/" 44 | data_path_post = "_test.jsonl" 45 | return relations, data_path_pre, data_path_post 46 | 47 | 48 | def eval_model(relations, data_path_pre, data_path_post): 49 | all_Precision1 = [] 50 | type_Precision1 = defaultdict(list) 51 | type_count = defaultdict(list) 52 | 53 | for relation in relations: 54 | PARAMETERS = { 55 | "dataset_filename": "{}{}{}".format( 56 | data_path_pre, relation["relation"], data_path_post 57 | ), 58 | "common_vocab_filename": common_vocab_path, 59 | "template": "", 60 | "batch_size": 64, 61 | "max_sentence_length": 100, 62 | "threads": -1, 63 | "model_path": model_path 64 | } 65 | 66 | if "template" in relation: 67 | PARAMETERS["template"] = relation["template"] 68 | 69 | print(PARAMETERS) 70 | 71 | args = argparse.Namespace(**PARAMETERS) 72 | 73 | # see if file exists 74 | try: 75 | data = load_file(args.dataset_filename) 76 | except Exception as e: 77 | print("Relation {} excluded.".format(relation["relation"])) 78 | print("Exception: {}".format(e)) 79 | continue 80 | 81 | model = Roberta(args) 82 | print("Model: {}".format(model.__class__.__name__)) 83 | 84 | Precision1 = run_evaluation(args, shuffle_data=False, model=model) 85 | print("P@1 : {}".format(Precision1), flush=True) 86 | all_Precision1.append(Precision1) 87 | 88 | if "type" in relation: 89 | type_Precision1[relation["type"]].append(Precision1) 90 | data = load_file(PARAMETERS["dataset_filename"]) 91 | type_count[relation["type"]].append(len(data)) 92 | 93 | mean_p1 = statistics.mean(all_Precision1) 94 | print("@@@ mean P@1: {}".format(mean_p1)) 95 | 96 | for t, l in type_Precision1.items(): 97 | print( 98 | "@@@ ", 99 | t, 100 | statistics.mean(l), 101 | sum(type_count[t]), 102 | len(type_count[t]), 103 | flush=True, 104 | ) 105 | 106 | return mean_p1, all_Precision1 107 | 108 | 109 | if __name__ == "__main__": 110 | print("1. Google-RE") 111 | parameters = get_GoogleRE_parameters() 112 | eval_model(*parameters) 113 | 114 | print("2. 
T-REx") 115 | parameters = get_TREx_parameters() 116 | eval_model(*parameters) 117 | -------------------------------------------------------------------------------- /lama/evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import numpy as np 9 | import scipy 10 | 11 | 12 | def __max_probs_values_indices(masked_indices, log_probs, topk=1000): 13 | 14 | # score only first mask 15 | masked_indices = masked_indices[:1] 16 | 17 | masked_index = masked_indices[0] 18 | log_probs = log_probs[masked_index] 19 | 20 | value_max_probs, index_max_probs = torch.topk(input=log_probs,k=topk,dim=0) 21 | index_max_probs = index_max_probs.numpy().astype(int) 22 | value_max_probs = value_max_probs.detach().numpy() 23 | 24 | return log_probs, index_max_probs, value_max_probs 25 | 26 | 27 | def __print_top_k(value_max_probs, index_max_probs, vocab, mask_topk, index_list, max_printouts = 10): 28 | result = [] 29 | msg = "\n| Top{} predictions\n".format(max_printouts) 30 | for i in range(mask_topk): 31 | filtered_idx = index_max_probs[i].item() 32 | 33 | if index_list is not None: 34 | # the softmax layer has been filtered using the vocab_subset 35 | # the original idx should be retrieved 36 | idx = index_list[filtered_idx] 37 | else: 38 | idx = filtered_idx 39 | 40 | log_prob = value_max_probs[i].item() 41 | word_form = vocab[idx] 42 | 43 | if i < max_printouts: 44 | msg += "{:<8d}{:<20s}{:<12.3f}\n".format( 45 | i, 46 | word_form, 47 | log_prob 48 | ) 49 | element = {'i' : i, 'token_idx': idx, 'log_prob': log_prob, 'token_word_form': word_form} 50 | result.append(element) 51 | return result, msg 52 | 53 | 54 | def get_ranking(log_probs, masked_indices, vocab, label_index = None, index_list = None, topk = 1000, P_AT = 10, print_generation=True): 55 | 56 | experiment_result = {} 57 | 58 | log_probs, index_max_probs, value_max_probs = __max_probs_values_indices(masked_indices, log_probs, topk=topk) 59 | result_masked_topk, return_msg = __print_top_k(value_max_probs, index_max_probs, vocab, topk, index_list) 60 | experiment_result['topk'] = result_masked_topk 61 | 62 | if print_generation: 63 | print(return_msg) 64 | 65 | MRR = 0. 66 | P_AT_X = 0. 67 | P_AT_1 = 0. 68 | PERPLEXITY = None 69 | 70 | if label_index is not None: 71 | 72 | # check if the labe_index should be converted to the vocab subset 73 | if index_list is not None: 74 | label_index = index_list.index(label_index) 75 | 76 | query = torch.full(value_max_probs.shape, label_index, dtype=torch.long).numpy().astype(int) 77 | ranking_position = (index_max_probs==query).nonzero() 78 | 79 | # LABEL PERPLEXITY 80 | tokens = torch.from_numpy(np.asarray(label_index)) 81 | label_perplexity = log_probs.gather( 82 | dim=0, 83 | index=tokens, 84 | ) 85 | PERPLEXITY = label_perplexity.item() 86 | 87 | if len(ranking_position) >0 and ranking_position[0].shape[0] != 0: 88 | rank = ranking_position[0][0] + 1 89 | 90 | # print("rank: {}".format(rank)) 91 | 92 | if rank >= 0: 93 | MRR = (1/rank) 94 | if rank >= 0 and rank <= P_AT: 95 | P_AT_X = 1. 96 | if rank == 1: 97 | P_AT_1 = 1. 
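# Editor's note (hedged summary of the ranking block above, not original code):
# rank is the 1-based position of the gold token among the top-`topk`
# (default 1000) candidates at the first masked position. From that rank:
#   MRR    = 1 / rank                                   e.g. rank 3 -> 0.333...
#   P_AT_X = 1.0 if rank <= P_AT (default 10), else 0.0 e.g. rank 3 -> 1.0
#   P_AT_1 = 1.0 only if rank == 1                      e.g. rank 3 -> 0.0
# If the gold token does not appear in the top-`topk` list at all, the three
# metrics keep their initial value of 0. Note also that, despite its name,
# PERPLEXITY holds the log-probability gathered for the gold token above,
# not a true perplexity value.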
98 | 99 | experiment_result["MRR"] = MRR 100 | experiment_result["P_AT_X"] = P_AT_X 101 | experiment_result["P_AT_1"] = P_AT_1 102 | experiment_result["PERPLEXITY"] = PERPLEXITY 103 | # 104 | # print("MRR: {}".format(experiment_result["MRR"])) 105 | # print("P_AT_X: {}".format(experiment_result["P_AT_X"])) 106 | # print("P_AT_1: {}".format(experiment_result["P_AT_1"])) 107 | # print("PERPLEXITY: {}".format(experiment_result["PERPLEXITY"])) 108 | 109 | return MRR, P_AT_X, experiment_result, return_msg 110 | 111 | 112 | def __overlap_negation(index_max_probs__negated, index_max_probs): 113 | # compares first ranked prediction of affirmative and negated statements 114 | # if true 1, else: 0 115 | return int(index_max_probs__negated == index_max_probs) 116 | 117 | 118 | def get_negation_metric(log_probs, masked_indices, log_probs_negated, 119 | masked_indices_negated, vocab, index_list=None, 120 | topk = 1): 121 | 122 | return_msg = "" 123 | # if negated sentence present 124 | if len(masked_indices_negated) > 0: 125 | 126 | log_probs, index_max_probs, _ = \ 127 | __max_probs_values_indices(masked_indices, log_probs, topk=topk) 128 | log_probs_negated, index_max_probs_negated, _ = \ 129 | __max_probs_values_indices(masked_indices_negated, 130 | log_probs_negated, topk=topk) 131 | 132 | # overlap btw. affirmative and negated first ranked prediction: 0 or 1 133 | overlap = __overlap_negation(index_max_probs_negated[0], 134 | index_max_probs[0]) 135 | # rank corrl. btw. affirmative and negated predicted log_probs 136 | spearman_rank_corr = scipy.stats.spearmanr(log_probs, 137 | log_probs_negated)[0] 138 | 139 | else: 140 | overlap = np.nan 141 | spearman_rank_corr = np.nan 142 | 143 | return overlap, spearman_rank_corr, return_msg 144 | -------------------------------------------------------------------------------- /lama/lama_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | def load_file(filename): 6 | data = [] 7 | with open(filename, "r", encoding='utf-8') as f: 8 | for line in f.readlines(): 9 | data.append(json.loads(line)) 10 | return data 11 | 12 | 13 | def parse_template(template, subject_label, object_label): 14 | SUBJ_SYMBOL = "[X]" 15 | OBJ_SYMBOL = "[Y]" 16 | template = template.replace(SUBJ_SYMBOL, subject_label) 17 | template = template.replace(OBJ_SYMBOL, object_label) 18 | return [template] 19 | 20 | def load_vocab(vocab_filename): 21 | with open(vocab_filename, "r", encoding='utf-8') as f: 22 | lines = f.readlines() 23 | vocab = [x.strip() for x in lines] 24 | return vocab -------------------------------------------------------------------------------- /lama/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from fastNLP import seq_len_to_mask 6 | from transformers import RobertaForMaskedLM, RobertaTokenizer, BertPreTrainedModel, RobertaModel, RobertaConfig 7 | from transformers.modeling_roberta import RobertaLMHead 8 | 9 | 10 | class Roberta(object): 11 | 12 | def __init__(self, args): 13 | # self.dict_file = "{}/{}".format(args.roberta_model_dir, args.roberta_vocab_name) 14 | self.tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 15 | if args.model_path is not None: 16 | print("Testing CoLAKE...") 17 | print('loading model parameters from {}...'.format(args.model_path)) 18 | config = RobertaConfig.from_pretrained('roberta-base', type_vocab_size=3) 19 | self.model 
= RobertaForMaskedLM(config=config) 20 | states_dict = torch.load(os.path.join(args.model_path, 'model.bin')) 21 | self.model.load_state_dict(states_dict, strict=False) 22 | else: 23 | print("Testing RoBERTa baseline...") 24 | self.model = RobertaForMaskedLM.from_pretrained('roberta-base') 25 | 26 | self._build_vocab() 27 | self._init_inverse_vocab() 28 | self._model_device = 'cpu' 29 | self.max_sentence_length = args.max_sentence_length 30 | 31 | def _cuda(self): 32 | self.model.cuda() 33 | 34 | def _build_vocab(self): 35 | self.vocab = [] 36 | for key in range(len(self.tokenizer)): 37 | value = self.tokenizer.decode([key]) 38 | if value[0] == " ": # if the token starts with a whitespace 39 | value = value.strip() 40 | else: 41 | # this is subword information 42 | value = "_{}_".format(value) 43 | 44 | if value in self.vocab: 45 | # print("WARNING: token '{}' is already in the vocab".format(value)) 46 | value = "{}_{}".format(value, key) 47 | 48 | self.vocab.append(value) 49 | print("size of vocabulary: {}".format(len(self.vocab))) 50 | 51 | def _init_inverse_vocab(self): 52 | self.inverse_vocab = {w: i for i, w in enumerate(self.vocab)} 53 | 54 | def try_cuda(self): 55 | """Move model to GPU if one is available.""" 56 | if torch.cuda.is_available(): 57 | if self._model_device != 'cuda': 58 | self._cuda() 59 | self._model_device = 'cuda' 60 | else: 61 | print('No CUDA found') 62 | 63 | def init_indices_for_filter_logprobs(self, vocab_subset): 64 | index_list = [] 65 | new_vocab_subset = [] 66 | for word in vocab_subset: 67 | if word in self.inverse_vocab: 68 | inverse_id = self.inverse_vocab[word] 69 | index_list.append(inverse_id) 70 | new_vocab_subset.append(word) 71 | else: 72 | msg = "word {} from vocab_subset not in model vocabulary!".format(word) 73 | print("WARNING: {}".format(msg)) 74 | 75 | indices = torch.as_tensor(index_list) 76 | return indices, index_list 77 | 78 | def filter_logprobs(self, log_probs, indices): 79 | new_log_probs = log_probs.index_select(dim=2, index=indices) 80 | return new_log_probs 81 | 82 | def get_id(self, input_string): 83 | # Roberta predicts ' London' and not 'London' 84 | string = " " + str(input_string).strip() 85 | tokens = self.tokenizer.encode(string, add_special_tokens=False) 86 | # return [element.item() for element in tokens.long().flatten()] 87 | return tokens 88 | 89 | def get_batch_generation(self, samples_list, try_cuda=True): 90 | if not samples_list: 91 | return None 92 | if try_cuda: 93 | self.try_cuda() 94 | 95 | tensor_list = [] 96 | masked_indices_list = [] 97 | max_len = 0 98 | output_tokens_list = [] 99 | seq_len = [] 100 | for sample in samples_list: 101 | masked_inputs_list = sample["masked_sentences"] 102 | 103 | tokens_list = [self.tokenizer.bos_token_id] 104 | 105 | for idx, masked_input in enumerate(masked_inputs_list): 106 | tokens_list.extend(self.tokenizer.encode(" " + masked_input.strip(), add_special_tokens=False)) 107 | tokens_list.append(self.tokenizer.eos_token_id) 108 | 109 | # tokens = torch.cat(tokens_list)[: self.max_sentence_length] 110 | tokens = torch.tensor(tokens_list)[: self.max_sentence_length] 111 | output_tokens_list.append(tokens.long().cpu().numpy()) 112 | 113 | seq_len.append(len(tokens)) 114 | if len(tokens) > max_len: 115 | max_len = len(tokens) 116 | tensor_list.append(tokens) 117 | masked_index = (tokens == self.tokenizer.mask_token_id).nonzero().numpy() 118 | for x in masked_index: 119 | masked_indices_list.append([x[0]]) 120 | tokens_list = [] 121 | for tokens in tensor_list: 122 | pad_lenght = 
max_len - len(tokens) 123 | if pad_lenght > 0: 124 | pad_tensor = torch.full([pad_lenght], self.tokenizer.pad_token_id, dtype=torch.int) 125 | tokens = torch.cat((tokens, pad_tensor.long())) 126 | tokens_list.append(tokens) 127 | 128 | batch_tokens = torch.stack(tokens_list) 129 | seq_len = torch.LongTensor(seq_len) 130 | attn_mask = seq_len_to_mask(seq_len) 131 | 132 | with torch.no_grad(): 133 | # with utils.eval(self.model.model): 134 | self.model.eval() 135 | outputs = self.model( 136 | batch_tokens.long().to(device=self._model_device), 137 | attention_mask=attn_mask.to(device=self._model_device) 138 | ) 139 | log_probs = outputs[0] 140 | 141 | return log_probs.cpu(), output_tokens_list, masked_indices_list 142 | 143 | -------------------------------------------------------------------------------- /preprocess/extract.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | import sys 3 | from urllib import parse 4 | import os 5 | from multiprocessing import Pool 6 | 7 | input_folder = "../pretrain_data/output" 8 | 9 | file_list = [] 10 | for path, _, filenames in os.walk(input_folder): 11 | for filename in filenames: 12 | file_list.append(os.path.join(path, filename)) 13 | print(len(file_list)) 14 | 15 | def run_proc(idx, n, file_list): 16 | for i in range(len(file_list)): 17 | if i % n == idx: 18 | input_name = file_list[i] 19 | print('{}: {}'.format(i, input_name)) 20 | target = input_name.replace('pretrain_data/output', "pretrain_data/ann") 21 | folder = '/'.join(target.split('/')[:-1]) 22 | if not os.path.exists(folder): 23 | os.makedirs(folder) 24 | 25 | soup = BeautifulSoup(open(input_name, encoding='utf-8'), features="html5lib") 26 | docs = soup.find_all('doc') 27 | 28 | fout = open(target, 'w', encoding='utf-8') 29 | 30 | for doc in docs: 31 | content = doc.get_text(" sepsepsep ") 32 | while content[0] == "\n": 33 | content = content[1:] 34 | content = [x.strip() for x in content.split("\n")] 35 | content = "".join(content[1:]) 36 | 37 | lookup = [] 38 | for x in doc.find_all("a"): 39 | if x.get('href') is not None: 40 | lookup.append((x.get_text().strip(), parse.unquote(x.get('href')))) 41 | # lookup = [(x.get_text().strip(), parse.unquote(x.get('href'))) for x in doc.find_all("a")] 42 | lookup = "[_end_]".join(["[_map_]".join(x) for x in lookup]) 43 | fout.write(content+"[_end_]"+lookup+"\n") 44 | 45 | fout.close() 46 | 47 | 48 | n = int(sys.argv[1]) 49 | p = Pool(n) 50 | for i in range(n): 51 | p.apply_async(run_proc, args=(i, n, file_list)) 52 | p.close() 53 | p.join() -------------------------------------------------------------------------------- /preprocess/gen_data.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | from nltk.tokenize import sent_tokenize 4 | import random 5 | from transformers import RobertaTokenizer 6 | import os 7 | from tqdm import tqdm 8 | from time import * 9 | from multiprocessing import Pool 10 | import json 11 | 12 | input_folder = "../pretrain_data/ann" 13 | 14 | file_list = [] 15 | for path, _, filenames in os.walk(input_folder): 16 | for filename in filenames: 17 | file_list.append(os.path.join(path, filename)) 18 | print('# of files', len(file_list)) 19 | 20 | 21 | def load_data(): 22 | wiki5m_alias2qid, wiki5m_qid2alias = {}, {} 23 | with open("../wikidata5m_alias/wikidata5m_entity.txt", 'r', 24 | encoding='utf-8') as fin: 25 | lines = fin.readlines() 26 | for i in tqdm(range(len(lines))): 
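# Editor's note (assumption inferred from the parsing below, not original
# code): each line of wikidata5m_entity.txt is tab-separated, a QID followed
# by one or more aliases, e.g. "Q90\tParis\tCity of Light" (the example values
# here are purely illustrative). The loop maps every alias to its QID in
# wiki5m_alias2qid, while wiki5m_qid2alias ends up holding only the last alias
# listed for each QID.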
27 | line = lines[i] 28 | v = line.strip().split("\t") 29 | if len(v) < 2: 30 | continue 31 | qid = v[0] 32 | for alias in v[1:]: 33 | wiki5m_qid2alias[qid] = alias 34 | wiki5m_alias2qid[alias] = qid 35 | 36 | d_ent = wiki5m_alias2qid 37 | print('wikidata5m_entity.txt (Wikidata5M) loaded!') 38 | 39 | wiki5m_pid2alias = {} 40 | with open("../wikidata5m_alias/wikidata5m_relation.txt", 'r', 41 | encoding='utf-8') as fin: 42 | lines = fin.readlines() 43 | for i in tqdm(range(len(lines))): 44 | line = lines[i] 45 | v = line.strip().split("\t") 46 | if len(v) < 2: 47 | continue 48 | wiki5m_pid2alias[v[0]] = v[1] 49 | print('wikidata5m_relation.txt (Wikidata5M) loaded!') 50 | 51 | # This is to remove FewRel test set from our training data. If your need is not just reproducing the experiments, 52 | # you can discard this part. The `ernie_data` is obtained from https://github.com/thunlp/ERNIE 53 | fewrel_triples = set() 54 | with open('../ernie_data/fewrel/test.json', 'r', encoding='utf-8') as fin: 55 | fewrel_data = json.load(fin) 56 | for ins in fewrel_data: 57 | r = ins['label'] 58 | h, t = ins['ents'][0][0], ins['ents'][1][0] 59 | fewrel_triples.add((h, r, t)) 60 | print('# triples in FewRel test set: {}'.format(len(fewrel_triples))) 61 | print(list(fewrel_triples)[0]) 62 | head_cluster, tail_cluster = {}, {} 63 | num_del = total = 0 64 | with open("../wikidata5m_triplet.txt", 'r', encoding='utf-8') as fin: 65 | lines = fin.readlines() 66 | for i in tqdm(range(len(lines))): 67 | line = lines[i] 68 | v = line.strip().split("\t") 69 | if len(v) != 3: 70 | continue 71 | h, r, t = v 72 | if (h, r, t) not in fewrel_triples: 73 | if h in head_cluster: 74 | head_cluster[h].append((r, t)) 75 | else: 76 | head_cluster[h] = [(r, t)] 77 | if t in tail_cluster: 78 | tail_cluster[t].append((r, h)) 79 | else: 80 | tail_cluster[t] = [(r, h)] 81 | else: 82 | num_del += 1 83 | total += 1 84 | print('wikidata5m_triplet.txt (Wikidata5M) loaded!') 85 | print('deleted {} triples from Wikidata5M.'.format(num_del)) 86 | 87 | return d_ent, head_cluster 88 | 89 | 90 | d_ent, head_cluster = load_data() 91 | 92 | # args 93 | max_neighbors = 15 94 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 95 | 96 | 97 | def run_proc(index, n, file_list, min_seq_len=80, max_seq_len=200, n_samples_per_file=5000): 98 | output_folder = '../pretrain_data/data/' + str(index) 99 | if not os.path.exists(output_folder): 100 | os.makedirs(output_folder) 101 | j = index 102 | large_j = index 103 | drop_samples = 0 104 | is_large = False 105 | n_normal_data, n_large_data = 0, 0 106 | target_filename = os.path.join(output_folder, str(j)) 107 | large_target_filename = os.path.join(output_folder, 'large_' + str(large_j)) 108 | fout_normal = open(target_filename, 'a+', encoding='utf-8') 109 | fout_large = open(large_target_filename, 'a+', encoding='utf-8') 110 | for i in range(len(file_list)): 111 | if i % n == index: 112 | # num_words, num_ents = 0, 0 113 | start_time = time() 114 | input_name = file_list[i] 115 | print('[processing] # {}: {}'.format(i, input_name)) 116 | fin = open(input_name, 'r', encoding='utf-8') 117 | 118 | for doc in fin: 119 | doc = doc.strip() 120 | segs = doc.split("[_end_]") 121 | content = segs[0] 122 | sentences = sent_tokenize(content) 123 | map_segs = segs[1:] 124 | maps = {} # mention -> QID 125 | for x in map_segs: 126 | v = x.split("[_map_]") 127 | if len(v) != 2: 128 | continue 129 | if v[1] in d_ent: # if a wikipedia title is the alias of an entity in wikidata 130 | maps[v[0]] = d_ent[v[1]] 131 | 
elif v[1].lower() in d_ent: 132 | maps[v[0]] = d_ent[v[1].lower()] 133 | blocks, word_lst = [], [] 134 | s = '' 135 | for sent in sentences: 136 | s = '{} {}'.format(s, sent) 137 | # s = s + ' ' + sent 138 | word_lst = tokenizer.encode(s) 139 | if len(word_lst) >= min_seq_len: 140 | blocks.append(s) 141 | s = '' 142 | if len(s) > 0: 143 | blocks.append(s) 144 | for block in blocks: 145 | anchor_segs = [x.strip() for x in block.split("sepsepsep")] 146 | tokens, entities = [0], [] # [] 147 | node2label = {0: 0} # node:0 -> :0 148 | # edges = [] # mention - entity links 149 | idx = 1 # idx of word nodes in G 150 | pos = 1 # position of current node 151 | soft_position = [0] 152 | entity_position = [] 153 | 154 | for x in anchor_segs: 155 | if len(x) < 1: 156 | continue 157 | if x in maps and maps[x] not in entities: 158 | entities.append(maps[x]) 159 | entity_position.append(pos) 160 | pos += 1 161 | else: 162 | words = tokenizer.encode(x, add_special_tokens=False, add_prefix_space=True) 163 | words = words[:max_seq_len] 164 | for word in words: 165 | node2label[idx] = word 166 | tokens.append(word) 167 | soft_position.append(pos) 168 | idx += 1 169 | pos += 1 170 | if len(entities) == 0: 171 | continue 172 | node2label[idx] = 2 # node:idx -> :2 173 | tokens.append(2) # : 2 174 | soft_position.append(pos) 175 | idx += 1 176 | assert len(tokens) == idx 177 | 178 | G = nx.complete_graph(idx) 179 | for entity, pos in zip(entities, entity_position): 180 | if entity not in G.nodes: 181 | G.add_node(entity) 182 | node2label[entity] = entity 183 | soft_position.append(pos) 184 | G = nx.complete_graph(G) 185 | n_word_nodes = idx 186 | token_types = [0] * n_word_nodes + [1] * len(entities) 187 | relation_to_add = [] 188 | for entity, pos in zip(entities, entity_position): 189 | if entity in head_cluster and random.uniform(0, 1) > 0.5: 190 | triple_lst = head_cluster[entity] 191 | random.shuffle(triple_lst) 192 | head_neighbors = 0 193 | for (r, t) in triple_lst: 194 | if head_neighbors >= max_neighbors: 195 | break 196 | if t not in G.nodes: 197 | G.add_node(t) 198 | node2label[t] = t 199 | soft_position.append(pos + 2) 200 | token_types.append(1) 201 | relation_to_add.append((idx, r, entity, pos + 1, t)) 202 | head_neighbors += 1 203 | idx += 1 204 | for idx, r, entity, pos, t in relation_to_add: 205 | G.add_node(idx) 206 | node2label[idx] = r 207 | G.add_edge(entity, idx) 208 | soft_position.append(pos) 209 | token_types.append(2) 210 | G.add_edge(idx, t) 211 | # check dimension 212 | if len(G.nodes) != len(soft_position): 213 | print('[warning] number of nodes does not match length of position ids') 214 | continue 215 | if len(G.nodes) != len(token_types): 216 | print('[warning] number of nodes does not match length of token_types') 217 | continue 218 | 219 | if len(G.nodes) > 256 and len(G.nodes) <= 512: 220 | is_large = True 221 | elif len(G.nodes) <= 256: 222 | is_large = False 223 | elif len(G.nodes) > 512: 224 | drop_samples += 1 225 | # print('\texceed max nodes limitation. 
drop instance.') 226 | continue 227 | 228 | adj = np.array(nx.adjacency_matrix(G).todense()) 229 | adj = adj + np.eye(adj.shape[0], dtype=int) 230 | if not is_large: 231 | normal_ins = {'n_word_nodes': n_word_nodes, 'nodes': [node2label[k] for k in G.nodes], 232 | 'soft_position': soft_position, 'adj': adj.tolist(), 233 | 'token_type_ids': token_types} 234 | fout_normal.write(json.dumps(normal_ins) + '\n') 235 | n_normal_data += 1 236 | 237 | else: 238 | large_ins = {'n_word_nodes': n_word_nodes, 'nodes': [node2label[k] for k in G.nodes], 239 | 'soft_position': soft_position, 'adj': adj.tolist(), 'token_type_ids': token_types} 240 | fout_large.write(json.dumps(large_ins) + '\n') 241 | n_large_data += 1 242 | 243 | if n_normal_data >= n_samples_per_file: 244 | n_normal_data = 0 245 | fout_normal.close() 246 | j += n 247 | target_filename = os.path.join(output_folder, str(j)) 248 | fout_normal = open(target_filename, 'a+', encoding='utf-8') 249 | if n_large_data >= n_samples_per_file: 250 | n_large_data = 0 251 | fout_large.close() 252 | large_j += n 253 | large_target_filename = os.path.join(output_folder, 'large_' + str(large_j)) 254 | fout_large = open(large_target_filename, 'a+', encoding='utf-8') 255 | 256 | # avg_pro = num_words * 1.0 / num_ents 257 | # print('words : entities/relations = {}'.format(avg_pro)) 258 | fin.close() 259 | end_time = time() 260 | print('[TIME] {}s'.format(end_time - start_time)) 261 | print('drop {} samples due to max nodes limitation.\n[finished].'.format(drop_samples)) 262 | fout_normal.close() 263 | fout_large.close() 264 | print(target_filename) 265 | print(large_target_filename) 266 | 267 | 268 | import sys 269 | 270 | n = int(sys.argv[1]) 271 | p = Pool(n) 272 | for i in range(n): 273 | p.apply_async(run_proc, args=(i, n, file_list)) 274 | p.close() 275 | p.join() 276 | -------------------------------------------------------------------------------- /preprocess/merge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | 5 | input_folder = "../pretrain_data/data" 6 | output_folder = "../pretrain_data/data" 7 | if not os.path.exists(output_folder): 8 | os.mkdir(output_folder) 9 | 10 | normal_data_path = os.path.join(output_folder, 'normal') 11 | if not os.path.exists(normal_data_path): 12 | os.mkdir(normal_data_path) 13 | 14 | large_data_path = os.path.join(output_folder, 'large') 15 | if not os.path.exists(large_data_path): 16 | os.mkdir(large_data_path) 17 | 18 | file_list = [] 19 | for path, _, filenames in os.walk(input_folder): 20 | for filename in filenames: 21 | file_list.append(os.path.join(path, filename)) 22 | print('# of files', len(file_list)) 23 | 24 | data = [] 25 | i = 0 26 | num_sample_per_file = 10000 27 | num_sample_per_large_file = 10000 28 | 29 | for j in tqdm(range(len(file_list))): 30 | file = file_list[j] 31 | if '.txt' in file: 32 | continue 33 | if 'large' in file: 34 | continue 35 | # regular file (may not be of 1k) 36 | tmp = pickle.load(open(file, 'rb')) 37 | # print('processing file {}. 
# of samples: {}'.format(file, len(tmp))) 38 | data.extend(tmp) 39 | if len(data) >= num_sample_per_file: 40 | with open(os.path.join(normal_data_path, str(i)), 'wb') as fout: 41 | pickle.dump(data[:num_sample_per_file], fout, protocol=4) 42 | i += 1 43 | if len(data) > num_sample_per_file: 44 | data = data[num_sample_per_file:] 45 | else: 46 | data = [] 47 | print('# rest samples: {}'.format(len(data))) 48 | 49 | large_data = [] 50 | i = 0 51 | for j in tqdm(range(len(file_list))): 52 | file = file_list[j] 53 | if 'large' in file: 54 | # large file (may not be of 1k) 55 | tmp = pickle.load(open(file, 'rb')) 56 | # print('processing file {}. # of samples: {}'.format(file, len(tmp))) 57 | large_data.extend(tmp) 58 | if len(large_data) >= num_sample_per_large_file: 59 | with open(os.path.join(large_data_path, str(i)), 'wb') as fout: 60 | pickle.dump(large_data[:num_sample_per_large_file], fout, protocol=4) 61 | i += 1 62 | if len(large_data) > num_sample_per_large_file: 63 | large_data = large_data[num_sample_per_large_file:] 64 | else: 65 | large_data = [] 66 | print('# rest samples: {}'.format(len(large_data))) 67 | 68 | -------------------------------------------------------------------------------- /preprocess/statistic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | from tqdm import tqdm 5 | 6 | file_list = [] 7 | for path, _, filenames in os.walk('../pretrain_data/data'): 8 | for filename in filenames: 9 | file_list.append(os.path.join(path, filename)) 10 | 11 | ent_freq = {'': 0, '': 0, '': 0} 12 | rel_freq = {'': 0, '': 0, '': 0} 13 | ent_vocab = {'': 0, '': 1, '': 2} 14 | rel_vocab = {'': 0, '': 1, '': 2} 15 | for i in tqdm(range(len(file_list))): 16 | with open(file_list[i], 'r', encoding='utf-8') as fin: 17 | for x in fin: 18 | ins = json.loads(x) 19 | for node in ins['nodes']: 20 | if isinstance(node, str): 21 | if node.startswith('Q'): 22 | if node not in ent_freq: 23 | ent_freq[node] = 1 24 | else: 25 | ent_freq[node] += 1 26 | if node not in ent_vocab: 27 | ent_vocab[node] = len(ent_vocab) 28 | if node.startswith('P'): 29 | if node not in rel_freq: 30 | rel_freq[node] = 1 31 | else: 32 | rel_freq[node] += 1 33 | if node not in rel_vocab: 34 | rel_vocab[node] = len(rel_vocab) 35 | 36 | with open('../read_rel_freq.bin', 'wb') as fout: 37 | pickle.dump(rel_freq, fout) 38 | with open('../read_rel_vocab.bin', 'wb') as fout: 39 | pickle.dump(rel_vocab, fout) 40 | print(len(rel_vocab)) 41 | 42 | with open('../read_ent_freq.bin', 'wb') as fout: 43 | pickle.dump(ent_freq, fout) 44 | with open('../read_ent_vocab.bin', 'wb') as fout: 45 | pickle.dump(ent_vocab, fout) 46 | print(len(ent_vocab)) 47 | -------------------------------------------------------------------------------- /pretrain/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pickle 5 | import torch 6 | import random 7 | import numpy as np 8 | from torch.utils.data import Dataset, Sampler 9 | from pretrain.utils import create_mlm_labels 10 | 11 | WORD_PADDING_INDEX = 1 12 | ENTITY_PADDING_INDEX = 1 13 | RELATION_PADDING_INDEX = 1 14 | 15 | 16 | class GraphOTFDataSet(Dataset): 17 | def __init__(self, indexed_train_fps, n_workers, rank, word_mask_index, word_vocab_size, k_negative_samples, 18 | ent_vocab, rel_vocab, ent_freq): 19 | # index and mask 20 | 21 | self.input_ids, self.n_word_nodes, self.n_entity_nodes, self.position_ids, 
self.attention_mask, self.masked_lm_labels, \ 22 | self.ent_masked_lm_labels, self.rel_masked_lm_labels, self.token_type_ids = [], [], [], [], [], [], [], [], [] 23 | self.word_mask_index = word_mask_index 24 | self.word_vocab_size = word_vocab_size 25 | self.rank = rank 26 | self.ent_vocab = ent_vocab 27 | self.rel_vocab = rel_vocab 28 | self.ent_freq = ent_freq 29 | 30 | self.k_negative_samples = k_negative_samples 31 | self.nega_samp_weight = self._nagative_sampling_weight() 32 | 33 | file_per_process = len(indexed_train_fps) // n_workers 34 | if file_per_process * n_workers != len(indexed_train_fps): 35 | if rank == 0: 36 | print('Drop {} files.'.format(len(indexed_train_fps) - file_per_process * n_workers)) 37 | print('# files per process: {}'.format(file_per_process)) 38 | self.fps = indexed_train_fps[rank * file_per_process:(rank + 1) * file_per_process] 39 | 40 | self.current_file_idx = 0 41 | self.data = self.read_file(self.current_file_idx) 42 | self.num_samples_per_file = len(self.data) 43 | self.total_num_samples = self.num_samples_per_file * len(self.fps) 44 | 45 | def read_file(self, idx): 46 | data = [] 47 | with open(self.fps[idx], 'r', encoding='utf-8') as fin: 48 | for x in fin: 49 | instance = json.loads(x) 50 | # n_word_nodes = instance['n_word_nodes'] 51 | words, entities, relations = self._split_nodes(instance['nodes'], instance['token_type_ids']) 52 | anchor_entities = self._find_anchor_entities(instance['adj'], len(words), entities) 53 | # entities = self._replace_anchor_entities(anchor_entities, entities) 54 | entities = [self.ent_vocab[ent] for ent in entities] 55 | anchor_entities = [self.ent_vocab[ent] for ent in anchor_entities] 56 | relations = [self.rel_vocab[rel] for rel in relations] 57 | words, word_mlm_labels = create_mlm_labels(words, self.word_mask_index, self.word_vocab_size) 58 | entities, entity_mlm_labels = create_mlm_labels(entities, self.ent_vocab[''], len(self.ent_vocab), 59 | anchor_nodes=anchor_entities) 60 | relations, relation_mlm_labels = create_mlm_labels(relations, self.rel_vocab[''], 61 | len(self.rel_vocab)) 62 | assert len(instance['nodes']) == len(words + entities + relations) 63 | assert len(instance['nodes']) == len(instance['soft_position']) 64 | assert len(instance['nodes']) == len(instance['adj']) 65 | assert len(instance['nodes']) == len(instance['token_type_ids']) 66 | data.append({ 67 | 'input_ids': words + entities + relations, 68 | 'n_word_nodes': len(words), 69 | 'n_entity_nodes': len(entities), 70 | 'position_ids': instance['soft_position'], 71 | 'attention_mask': instance['adj'], 72 | 'token_type_ids': instance['token_type_ids'], 73 | 'masked_lm_labels': word_mlm_labels, 74 | 'ent_masked_lm_labels': entity_mlm_labels, 75 | 'rel_masked_lm_labels': relation_mlm_labels 76 | }) 77 | # data is a list of dict 78 | return data 79 | 80 | def __getitem__(self, item): 81 | file_idx = item // self.num_samples_per_file 82 | if file_idx != self.current_file_idx: 83 | self.data = self.read_file(file_idx) 84 | self.current_file_idx = file_idx 85 | sample = self.data[item - file_idx * self.num_samples_per_file] 86 | # label = { 87 | # 'masked_lm_labels': sample['masked_lm_labels'], 88 | # 'ent_masked_lm_labels': sample['ent_masked_lm_labels'], 89 | # 'rel_masked_lm_labels': sample['rel_masked_lm_labels'] 90 | # } 91 | return sample 92 | 93 | def __len__(self): 94 | return self.total_num_samples 95 | 96 | def _split_nodes(self, nodes, types): 97 | assert len(nodes) == len(types) 98 | words, entities, relations = [], [], [] 99 | for node, 
type in zip(nodes, types): 100 | if type == 0: 101 | words.append(node) 102 | elif type == 1: 103 | entities.append(node) 104 | elif type == 2: 105 | relations.append(node) 106 | else: 107 | raise ValueError('unknown token type id.') 108 | return words, entities, relations 109 | 110 | def _find_anchor_entities(self, adj, n_word_nodes, entities): 111 | anchor_entities = [] 112 | for i in adj[:n_word_nodes]: 113 | ents = i[n_word_nodes:n_word_nodes + len(entities)] 114 | for j, mask in enumerate(ents): 115 | if mask == 1 and entities[j] not in anchor_entities: 116 | anchor_entities.append(entities[j]) 117 | # if have relations 118 | # if len(adj) > n_word_nodes + len(entities): 119 | # for idx, attn_mask in enumerate(adj[n_word_nodes:n_word_nodes + len(entities)]): 120 | # if entities[idx] in anchor_entities: 121 | # i = idx + n_word_nodes 122 | # for j in range(n_word_nodes + len(entities), len(adj)): 123 | # adj[i][j] = 0 124 | 125 | return anchor_entities 126 | 127 | def _replace_anchor_entities(self, anchor_entities, entities): 128 | replaced_ents = [] 129 | for entity in entities: 130 | if entity in anchor_entities: 131 | x = random.uniform(0, 1) 132 | if x < 0.3: 133 | replaced_ents.append('') 134 | else: 135 | replaced_ents.append(entity) 136 | else: 137 | replaced_ents.append(entity) 138 | return replaced_ents 139 | 140 | def _nagative_sampling_weight(self, pwr=0.75): 141 | ef = [] 142 | for i, ent in enumerate(self.ent_vocab.keys()): 143 | assert self.ent_vocab[ent] == i 144 | ef.append(self.ent_freq[ent]) 145 | # freq = np.array([self.ent_freq[ent] for ent in self.ent_vocab.keys()]) 146 | ef = np.array(ef) 147 | ef = ef / ef.sum() 148 | ef = np.power(ef, pwr) 149 | ef = ef / ef.sum() 150 | return torch.FloatTensor(ef) 151 | 152 | def collate_fn(self, batch): 153 | # batch: [[x1:dict, y1:dict], [x2:dict, y2:dict], ...] 
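# Editor's note (descriptive sketch of the padding scheme implemented below,
# not original code). Each sample's input_ids are split back into word,
# entity and relation segments and padded per-batch to
#   [ words | WORD_PAD... | entities | ENTITY_PAD... | relations | RELATION_PAD... ]
# e.g. with batch maxima of 5 word, 2 entity and 1 relation nodes, a sample
# holding 4 words, 1 entity and 1 relation becomes [w w w w PAD | e PAD | r].
# position_ids and token_type_ids are padded with 0 in the same slots, the
# adjacency matrix gets all-one rows and all-zero columns at the pad
# positions, and ent_masked_lm_labels are remapped through ent_convert_dict,
# which additionally appends k_negative_samples * (#gold entities) entities
# drawn from the frequency^0.75 distribution (nega_samp_weight) to ent_index.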
154 | input_keys = ['input_ids', 'n_word_nodes', 'n_entity_nodes', 'position_ids', 'attention_mask', 'ent_index', 155 | 'masked_lm_labels', 'ent_masked_lm_labels', 'rel_masked_lm_labels', 'token_type_ids'] 156 | target_keys = ['masked_lm_labels', 'ent_masked_lm_labels', 'rel_masked_lm_labels', 157 | 'word_seq_len', 'ent_seq_len', 'rel_seq_len'] 158 | max_word_nodes, max_entity_nodes, max_relation_nodes = 0, 0, 0 159 | batch_word, batch_entity, batch_relation = [], [], [] 160 | batch_x = {n: [] for n in input_keys} 161 | batch_y = {n: [] for n in target_keys} 162 | 163 | ent_convert_dict = {} 164 | ent_index = [] 165 | # convert list of dict into dict of list 166 | for sample in batch: 167 | for n, v in sample.items(): 168 | if n in input_keys: 169 | batch_x[n].append(v) 170 | if n in target_keys: 171 | batch_y[n].append(v) 172 | # for n, v in x.items(): 173 | # batch_x[n].append(v) 174 | # for n, v in y.items(): 175 | # batch_y[n].append(v) 176 | batch_x['ent_index'].append([]) 177 | n_word_nodes = sample['n_word_nodes'] 178 | n_entity_nodes = sample['n_entity_nodes'] 179 | words = sample['input_ids'][0:n_word_nodes] 180 | entities = sample['input_ids'][n_word_nodes:n_word_nodes + n_entity_nodes] 181 | relations = sample['input_ids'][n_word_nodes + n_entity_nodes:] 182 | batch_word.append(words) 183 | batch_entity.append(entities) 184 | batch_relation.append(relations) 185 | 186 | batch_y['word_seq_len'].append(n_word_nodes) 187 | batch_y['ent_seq_len'].append(n_entity_nodes) 188 | batch_y['rel_seq_len'].append(len(relations)) 189 | 190 | max_word_nodes = len(words) if len(words) > max_word_nodes else max_word_nodes 191 | max_entity_nodes = len(entities) if len(entities) > max_entity_nodes else max_entity_nodes 192 | max_relation_nodes = len(relations) if len(relations) > max_relation_nodes else max_relation_nodes 193 | 194 | for golden_ent in sample['ent_masked_lm_labels']: 195 | if golden_ent >= 0 and golden_ent not in ent_convert_dict: 196 | ent_convert_dict[golden_ent] = len(ent_convert_dict) 197 | ent_index.append(golden_ent) 198 | 199 | # # check convert 200 | # for i in range(len(ent_index)): 201 | # assert ent_convert_dict[ent_index[i]] == i 202 | 203 | if len(ent_index) > 0: 204 | # negative sampling 205 | k_negas = self.k_negative_samples * len(ent_index) 206 | nega_samples = torch.multinomial(self.nega_samp_weight, num_samples=k_negas, replacement=True) 207 | for nega_ent in nega_samples: 208 | ent = int(nega_ent) 209 | if ent not in ent_convert_dict: # 保证无重复 210 | ent_convert_dict[ent] = len(ent_convert_dict) 211 | ent_index.append(ent) 212 | else: 213 | ent_index = [ENTITY_PADDING_INDEX] 214 | 215 | # pad 216 | seq_len = max_word_nodes + max(max_entity_nodes, 1) + max(max_relation_nodes, 1) 217 | for i in range(len(batch_word)): 218 | word_pad = max_word_nodes - len(batch_word[i]) 219 | entity_pad = max_entity_nodes - len(batch_entity[i]) if max_entity_nodes > 0 else 1 220 | relation_pad = max_relation_nodes - len(batch_relation[i]) if max_relation_nodes > 0 else 1 221 | batch_x['input_ids'][i] = batch_word[i] + [WORD_PADDING_INDEX] * word_pad + \ 222 | batch_entity[i] + [ENTITY_PADDING_INDEX] * entity_pad + \ 223 | batch_relation[i] + [RELATION_PADDING_INDEX] * relation_pad 224 | 225 | n_words = batch_x['n_word_nodes'][i] 226 | n_entities = batch_x['n_entity_nodes'][i] 227 | batch_x['position_ids'][i] = batch_x['position_ids'][i][:n_words] + [0] * word_pad + \ 228 | batch_x['position_ids'][i][n_words:n_words + n_entities] + [0] * entity_pad + \ 229 | 
batch_x['position_ids'][i][n_words + n_entities:] + [0] * relation_pad 230 | 231 | batch_x['token_type_ids'][i] = batch_x['token_type_ids'][i][:n_words] + [0] * word_pad + \ 232 | batch_x['token_type_ids'][i][n_words:n_words + n_entities] + [ 233 | 0] * entity_pad + \ 234 | batch_x['token_type_ids'][i][n_words + n_entities:] + [0] * relation_pad 235 | 236 | adj = torch.tensor(batch_x['attention_mask'][i], dtype=torch.int) 237 | adj = torch.cat((adj[:n_words, :], 238 | torch.ones(word_pad, adj.shape[1], dtype=torch.int), 239 | adj[n_words:n_words + n_entities, :], 240 | torch.ones(entity_pad, adj.shape[1], dtype=torch.int), 241 | adj[n_words + n_entities:, :], 242 | torch.ones(relation_pad, adj.shape[1], dtype=torch.int)), dim=0) 243 | assert adj.shape[0] == seq_len 244 | adj = torch.cat((adj[:, :n_words], 245 | torch.zeros(seq_len, word_pad, dtype=torch.int), 246 | adj[:, n_words:n_words + n_entities], 247 | torch.zeros(seq_len, entity_pad, dtype=torch.int), 248 | adj[:, n_words + n_entities:], 249 | torch.zeros(seq_len, relation_pad, dtype=torch.int)), dim=1) 250 | 251 | batch_x['attention_mask'][i] = adj 252 | batch_x['masked_lm_labels'][i] = batch_x['masked_lm_labels'][i] + [-1] * word_pad 253 | batch_y['masked_lm_labels'][i] = batch_y['masked_lm_labels'][i] + [-1] * word_pad 254 | 255 | batch_x['ent_masked_lm_labels'][i] = [ent_convert_dict[lb] if lb in ent_convert_dict else -1 for lb in 256 | batch_x['ent_masked_lm_labels'][i]] + [-1] * entity_pad 257 | batch_y['ent_masked_lm_labels'][i] = batch_x['ent_masked_lm_labels'][i] 258 | 259 | batch_x['rel_masked_lm_labels'][i] = batch_x['rel_masked_lm_labels'][i] + [-1] * relation_pad 260 | batch_y['rel_masked_lm_labels'][i] = batch_x['rel_masked_lm_labels'][i] 261 | 262 | batch_x['n_word_nodes'][i] = max(max_word_nodes, 1) 263 | batch_x['n_entity_nodes'][i] = max(max_entity_nodes, 1) 264 | batch_x['ent_index'][i] = ent_index 265 | 266 | for k, v in batch_x.items(): 267 | if k == 'attention_mask': 268 | batch_x[k] = torch.stack(v, dim=0) 269 | else: 270 | batch_x[k] = torch.tensor(v) 271 | for k, v in batch_y.items(): 272 | batch_y[k] = torch.tensor(v) 273 | return (batch_x, batch_y) 274 | 275 | 276 | class GraphDataSet(Dataset): 277 | def __init__(self, data_dir, word_mask_index, word_vocab_size, k_negative_samples, 278 | ent_vocab, rel_vocab, ent_freq): 279 | # index and mask 280 | self.input_ids, self.n_word_nodes, self.n_entity_nodes, self.position_ids, self.attention_mask, self.masked_lm_labels, \ 281 | self.ent_masked_lm_labels, self.rel_masked_lm_labels, self.token_type_ids = [], [], [], [], [], [], [], [], [] 282 | self.word_mask_index = word_mask_index 283 | self.word_vocab_size = word_vocab_size 284 | self.ent_vocab = ent_vocab 285 | self.rel_vocab = rel_vocab 286 | self.ent_freq = ent_freq 287 | 288 | self.k_negative_samples = k_negative_samples 289 | self.nega_samp_weight = self._nagative_sampling_weight() 290 | 291 | self.data = self.read_file(data_dir) 292 | 293 | def read_file(self, path): 294 | data = [] 295 | with open(path, 'r', encoding='utf-8') as fin: 296 | for x in fin: 297 | instance = json.loads(x) 298 | # n_word_nodes = instance['n_word_nodes'] 299 | words, entities, relations = self._split_nodes(instance['nodes'], instance['token_type_ids']) 300 | anchor_entities = self._find_anchor_entities(instance['adj'], len(words), entities) 301 | # entities = self._replace_anchor_entities(anchor_entities, entities) 302 | entities = [self.ent_vocab[ent] for ent in entities] 303 | anchor_entities = [self.ent_vocab[ent] for ent 
in anchor_entities] 304 | relations = [self.rel_vocab[rel] for rel in relations] 305 | words, word_mlm_labels = create_mlm_labels(words, self.word_mask_index, self.word_vocab_size) 306 | entities, entity_mlm_labels = create_mlm_labels(entities, self.ent_vocab[''], len(self.ent_vocab), 307 | anchor_nodes=anchor_entities) 308 | relations, relation_mlm_labels = create_mlm_labels(relations, self.rel_vocab[''], 309 | len(self.rel_vocab)) 310 | assert len(instance['nodes']) == len(words + entities + relations) 311 | assert len(instance['nodes']) == len(instance['soft_position']) 312 | assert len(instance['nodes']) == len(instance['adj']) 313 | assert len(instance['nodes']) == len(instance['token_type_ids']) 314 | data.append({ 315 | 'input_ids': words + entities + relations, 316 | 'n_word_nodes': len(words), 317 | 'n_entity_nodes': len(entities), 318 | 'position_ids': instance['soft_position'], 319 | 'attention_mask': instance['adj'], 320 | 'token_type_ids': instance['token_type_ids'], 321 | 'masked_lm_labels': word_mlm_labels, 322 | 'ent_masked_lm_labels': entity_mlm_labels, 323 | 'rel_masked_lm_labels': relation_mlm_labels 324 | }) 325 | # data is a list of dict 326 | return data 327 | 328 | def __getitem__(self, item): 329 | return self.data[item] 330 | 331 | def __len__(self): 332 | return len(self.data) 333 | 334 | def _split_nodes(self, nodes, types): 335 | assert len(nodes) == len(types) 336 | words, entities, relations = [], [], [] 337 | for node, type in zip(nodes, types): 338 | if type == 0: 339 | words.append(node) 340 | elif type == 1: 341 | entities.append(node) 342 | elif type == 2: 343 | relations.append(node) 344 | else: 345 | raise ValueError('unknown token type id.') 346 | return words, entities, relations 347 | 348 | def _nagative_sampling_weight(self, pwr=0.75): 349 | ef = [] 350 | for i, ent in enumerate(self.ent_vocab.keys()): 351 | assert self.ent_vocab[ent] == i 352 | ef.append(self.ent_freq[ent]) 353 | # freq = np.array([self.ent_freq[ent] for ent in self.ent_vocab.keys()]) 354 | ef = np.array(ef) 355 | ef = ef / ef.sum() 356 | ef = np.power(ef, pwr) 357 | ef = ef / ef.sum() 358 | return torch.FloatTensor(ef) 359 | 360 | def _find_anchor_entities(self, adj, n_word_nodes, entities): 361 | anchor_entities = [] 362 | for i in adj[:n_word_nodes]: 363 | ents = i[n_word_nodes:n_word_nodes + len(entities)] 364 | for j, mask in enumerate(ents): 365 | if mask == 1 and entities[j] not in anchor_entities: 366 | anchor_entities.append(entities[j]) 367 | 368 | return anchor_entities 369 | 370 | def collate_fn(self, batch): 371 | # batch: [[x1:dict, y1:dict], [x2:dict, y2:dict], ...] 
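# Editor's note (not original code): this collate_fn mirrors
# GraphOTFDataSet.collate_fn above, with the same segment layout, padding
# indices and negative sampling. One detail worth flagging from the padding
# code below: padded rows of the attention_mask are filled with ones while
# padded columns are filled with zeros, so real nodes never attend to padding
# nodes, and the padded rows themselves are presumably neutralized downstream
# by the -1 label padding.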
372 | input_keys = ['input_ids', 'n_word_nodes', 'n_entity_nodes', 'position_ids', 'attention_mask', 'ent_index', 373 | 'masked_lm_labels', 'ent_masked_lm_labels', 'rel_masked_lm_labels', 'token_type_ids'] 374 | target_keys = ['masked_lm_labels', 'ent_masked_lm_labels', 'rel_masked_lm_labels', 375 | 'word_seq_len', 'ent_seq_len', 'rel_seq_len'] 376 | max_word_nodes, max_entity_nodes, max_relation_nodes = 0, 0, 0 377 | batch_word, batch_entity, batch_relation = [], [], [] 378 | batch_x = {n: [] for n in input_keys} 379 | batch_y = {n: [] for n in target_keys} 380 | 381 | ent_convert_dict = {} 382 | ent_index = [] 383 | # convert list of dict into dict of list 384 | for sample in batch: 385 | for n, v in sample.items(): 386 | if n in input_keys: 387 | batch_x[n].append(v) 388 | if n in target_keys: 389 | batch_y[n].append(v) 390 | # for n, v in x.items(): 391 | # batch_x[n].append(v) 392 | # for n, v in y.items(): 393 | # batch_y[n].append(v) 394 | batch_x['ent_index'].append([]) 395 | n_word_nodes = sample['n_word_nodes'] 396 | n_entity_nodes = sample['n_entity_nodes'] 397 | words = sample['input_ids'][0:n_word_nodes] 398 | entities = sample['input_ids'][n_word_nodes:n_word_nodes + n_entity_nodes] 399 | relations = sample['input_ids'][n_word_nodes + n_entity_nodes:] 400 | batch_word.append(words) 401 | batch_entity.append(entities) 402 | batch_relation.append(relations) 403 | 404 | batch_y['word_seq_len'].append(n_word_nodes) 405 | batch_y['ent_seq_len'].append(n_entity_nodes) 406 | batch_y['rel_seq_len'].append(len(relations)) 407 | 408 | max_word_nodes = len(words) if len(words) > max_word_nodes else max_word_nodes 409 | max_entity_nodes = len(entities) if len(entities) > max_entity_nodes else max_entity_nodes 410 | max_relation_nodes = len(relations) if len(relations) > max_relation_nodes else max_relation_nodes 411 | 412 | for golden_ent in sample['ent_masked_lm_labels']: 413 | if golden_ent >= 0 and golden_ent not in ent_convert_dict: 414 | ent_convert_dict[golden_ent] = len(ent_convert_dict) 415 | ent_index.append(golden_ent) 416 | 417 | # # check convert 418 | # for i in range(len(ent_index)): 419 | # assert ent_convert_dict[ent_index[i]] == i 420 | 421 | if len(ent_index) > 0: 422 | # negative sampling 423 | k_negas = self.k_negative_samples * len(ent_index) 424 | nega_samples = torch.multinomial(self.nega_samp_weight, num_samples=k_negas, replacement=True) 425 | for nega_ent in nega_samples: 426 | ent = int(nega_ent) 427 | if ent not in ent_convert_dict: # 保证无重复 428 | ent_convert_dict[ent] = len(ent_convert_dict) 429 | ent_index.append(ent) 430 | else: 431 | ent_index = [ENTITY_PADDING_INDEX] 432 | 433 | # pad 434 | seq_len = max_word_nodes + max(max_entity_nodes, 1) + max(max_relation_nodes, 1) 435 | for i in range(len(batch_word)): 436 | word_pad = max_word_nodes - len(batch_word[i]) 437 | entity_pad = max_entity_nodes - len(batch_entity[i]) if max_entity_nodes > 0 else 1 438 | relation_pad = max_relation_nodes - len(batch_relation[i]) if max_relation_nodes > 0 else 1 439 | batch_x['input_ids'][i] = batch_word[i] + [WORD_PADDING_INDEX] * word_pad + \ 440 | batch_entity[i] + [ENTITY_PADDING_INDEX] * entity_pad + \ 441 | batch_relation[i] + [RELATION_PADDING_INDEX] * relation_pad 442 | 443 | n_words = batch_x['n_word_nodes'][i] 444 | n_entities = batch_x['n_entity_nodes'][i] 445 | batch_x['position_ids'][i] = batch_x['position_ids'][i][:n_words] + [0] * word_pad + \ 446 | batch_x['position_ids'][i][n_words:n_words + n_entities] + [0] * entity_pad + \ 447 | 
batch_x['position_ids'][i][n_words + n_entities:] + [0] * relation_pad 448 | 449 | batch_x['token_type_ids'][i] = batch_x['token_type_ids'][i][:n_words] + [0] * word_pad + \ 450 | batch_x['token_type_ids'][i][n_words:n_words + n_entities] + [ 451 | 0] * entity_pad + \ 452 | batch_x['token_type_ids'][i][n_words + n_entities:] + [0] * relation_pad 453 | 454 | adj = torch.tensor(batch_x['attention_mask'][i], dtype=torch.int) 455 | adj = torch.cat((adj[:n_words, :], 456 | torch.ones(word_pad, adj.shape[1], dtype=torch.int), 457 | adj[n_words:n_words + n_entities, :], 458 | torch.ones(entity_pad, adj.shape[1], dtype=torch.int), 459 | adj[n_words + n_entities:, :], 460 | torch.ones(relation_pad, adj.shape[1], dtype=torch.int)), dim=0) 461 | assert adj.shape[0] == seq_len 462 | adj = torch.cat((adj[:, :n_words], 463 | torch.zeros(seq_len, word_pad, dtype=torch.int), 464 | adj[:, n_words:n_words + n_entities], 465 | torch.zeros(seq_len, entity_pad, dtype=torch.int), 466 | adj[:, n_words + n_entities:], 467 | torch.zeros(seq_len, relation_pad, dtype=torch.int)), dim=1) 468 | 469 | batch_x['attention_mask'][i] = adj 470 | batch_x['masked_lm_labels'][i] = batch_x['masked_lm_labels'][i] + [-1] * word_pad 471 | batch_y['masked_lm_labels'][i] = batch_y['masked_lm_labels'][i] + [-1] * word_pad 472 | 473 | batch_x['ent_masked_lm_labels'][i] = [ent_convert_dict[lb] if lb in ent_convert_dict else -1 for lb in 474 | batch_x['ent_masked_lm_labels'][i]] + [-1] * entity_pad 475 | batch_y['ent_masked_lm_labels'][i] = batch_x['ent_masked_lm_labels'][i] 476 | 477 | batch_x['rel_masked_lm_labels'][i] = batch_x['rel_masked_lm_labels'][i] + [-1] * relation_pad 478 | batch_y['rel_masked_lm_labels'][i] = batch_x['rel_masked_lm_labels'][i] 479 | 480 | batch_x['n_word_nodes'][i] = max(max_word_nodes, 1) 481 | batch_x['n_entity_nodes'][i] = max(max_entity_nodes, 1) 482 | batch_x['ent_index'][i] = ent_index 483 | 484 | for k, v in batch_x.items(): 485 | if k == 'attention_mask': 486 | batch_x[k] = torch.stack(v, dim=0) 487 | else: 488 | batch_x[k] = torch.tensor(v) 489 | for k, v in batch_y.items(): 490 | batch_y[k] = torch.tensor(v) 491 | return (batch_x, batch_y) 492 | 493 | 494 | class FewRelDevDataSet(Dataset): 495 | def __init__(self, path, label_vocab, ent_vocab): 496 | self.label_vocab = label_vocab 497 | self.ent_vocab = ent_vocab 498 | self.data = [] 499 | with open(path, 'r', encoding='utf-8') as fin: 500 | raw_data = json.load(fin) 501 | for ins in raw_data: 502 | nodes_index = [] 503 | for node in ins['nodes']: 504 | if isinstance(node, str): 505 | nodes_index.append(ent_vocab[node]) 506 | else: 507 | nodes_index.append(node) 508 | 509 | self.data.append({ 510 | 'input_ids': nodes_index, 511 | # 'n_word_nodes': n_words, 512 | # 'n_entity_nodes': 2, 513 | 'position_ids': ins['soft_position'], 514 | 'attention_mask': ins['adj'], 515 | # 'token_type_ids': [0]*n_words + [1]*2 + [2], 516 | 'target': label_vocab[ins['label']] 517 | }) 518 | 519 | def __len__(self): 520 | return len(self.data) 521 | 522 | def __getitem__(self, item): 523 | return self.data[item] 524 | 525 | def collate_fn(self, batch): 526 | input_keys = ['input_ids', 'n_word_nodes', 'n_entity_nodes', 'position_ids', 'attention_mask', 'ent_index', 527 | 'masked_lm_labels', 'ent_masked_lm_labels', 'rel_masked_lm_labels', 'token_type_ids'] 528 | target_keys = ['masked_lm_labels', 'ent_masked_lm_labels', 'rel_masked_lm_labels', 529 | 'word_seq_len', 'ent_seq_len', 'rel_seq_len'] 530 | max_nodes = 0 531 | batch_x = {n: [] for n in input_keys} 532 | 
batch_y = {n: [] for n in target_keys} 533 | 534 | ent_index = [ENTITY_PADDING_INDEX, 3] 535 | for sample in batch: 536 | max_nodes = len(sample['input_ids']) if len(sample['input_ids']) > max_nodes else max_nodes 537 | 538 | for sample in batch: 539 | word_pad = max_nodes - len(sample['input_ids']) 540 | n_words = len(sample['input_ids']) - 3 541 | batch_y['word_seq_len'].append(n_words) 542 | batch_y['ent_seq_len'].append(2) 543 | batch_y['rel_seq_len'].append(1) 544 | 545 | batch_x['input_ids'].append(sample['input_ids'][:-3] + [WORD_PADDING_INDEX] * word_pad + sample['input_ids'][-3:]) 546 | batch_x['n_word_nodes'].append(max_nodes - 3) 547 | batch_x['n_entity_nodes'].append(2) 548 | batch_x['position_ids'].append(sample['position_ids'][:-3] + [0] * word_pad + sample['position_ids'][-3:]) 549 | adj = torch.tensor(sample['attention_mask'], dtype=torch.int) 550 | adj = torch.cat((adj[:-3, :], 551 | torch.ones(word_pad, adj.shape[1], dtype=torch.int), 552 | adj[-3:, :]), dim=0) 553 | 554 | adj = torch.cat((adj[:, :-3], 555 | torch.zeros(max_nodes, word_pad, dtype=torch.int), 556 | adj[:, -3:]), dim=1) 557 | 558 | batch_x['attention_mask'].append(adj) 559 | batch_x['token_type_ids'].append([0] * (max_nodes - 3) + [1, 1, 2]) 560 | batch_x['ent_index'].append(ent_index) 561 | batch_x['masked_lm_labels'].append([-1] * (max_nodes - 3)) 562 | batch_x['ent_masked_lm_labels'].append([-1, -1]) 563 | batch_x['rel_masked_lm_labels'].append([sample['target']]) 564 | 565 | batch_y['masked_lm_labels'].append([-1] * (max_nodes - 3)) 566 | batch_y['ent_masked_lm_labels'].append([-1, -1]) 567 | batch_y['rel_masked_lm_labels'].append([sample['target']]) 568 | 569 | for k, v in batch_x.items(): 570 | if k == 'attention_mask': 571 | batch_x[k] = torch.stack(v, dim=0) 572 | else: 573 | batch_x[k] = torch.tensor(v) 574 | for k, v in batch_y.items(): 575 | batch_y[k] = torch.tensor(v) 576 | 577 | return (batch_x, batch_y) 578 | -------------------------------------------------------------------------------- /pretrain/emb_ip.cfg: -------------------------------------------------------------------------------- 1 | 127.0.0.1 201203 1 -------------------------------------------------------------------------------- /pretrain/init_ent_rel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import numpy as np 5 | from transformers import RobertaConfig, RobertaModel, RobertaTokenizer 6 | 7 | 8 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 9 | roberta = RobertaModel.from_pretrained('roberta-base') 10 | path = '../wikidata5m_alias' 11 | if not os.path.exists('../wikidata5m_alias_emb'): 12 | os.makedirs('../wikidata5m_alias_emb') 13 | 14 | with open('../read_ent_vocab.bin', 'rb') as fin: 15 | ent_vocab = pickle.load(fin) 16 | with open('../read_rel_vocab.bin', 'rb') as fin: 17 | rel_vocab = pickle.load(fin) 18 | print(len(ent_vocab)) 19 | print(len(rel_vocab)) 20 | 21 | aliases = {} 22 | with open(os.path.join(path, 'wikidata5m_entity.txt'), 'r', encoding='utf-8') as fin: 23 | for line in fin: 24 | segs = line.strip().split('\t') 25 | entity = segs[0] 26 | alias = segs[1:] 27 | aliases[entity] = alias 28 | print(len(aliases)) 29 | 30 | miss = 0 31 | entity_embeddings = [] 32 | for k, v in ent_vocab.items(): 33 | if k in aliases: 34 | alias = aliases[k][0] 35 | tokens = tokenizer.encode(' '+alias, add_special_tokens=False) 36 | embedding = 
roberta.embeddings.word_embeddings(torch.tensor(tokens).view(1,-1)).squeeze(0).mean(dim=0) 37 | else: 38 | miss += 1 39 | embedding = torch.randn(768) / 10 40 | entity_embeddings.append(embedding) 41 | 42 | assert len(ent_vocab) == len(entity_embeddings) 43 | entity_embeddings = torch.stack(entity_embeddings, dim=0) 44 | print(miss * 1.0 / len(ent_vocab)) 45 | print(entity_embeddings.shape) 46 | 47 | np.save('../wikidata5m_alias_emb/entities.npy', entity_embeddings.detach().numpy()) 48 | del entity_embeddings 49 | 50 | rel_aliases = {} 51 | with open(os.path.join(path, 'wikidata5m_relation.txt'), 'r', encoding='utf-8') as fin: 52 | for line in fin: 53 | segs = line.strip().split('\t') 54 | relation = segs[0] 55 | alias = segs[1:] 56 | rel_aliases[relation] = alias 57 | 58 | miss = 0 59 | relation_embeddings = [] 60 | for k, v in rel_vocab.items(): 61 | if k in rel_aliases: 62 | alias = rel_aliases[k][0] 63 | tokens = tokenizer.encode(' '+alias, add_special_tokens=False) 64 | embedding = roberta.embeddings.word_embeddings(torch.tensor(tokens).view(1,-1)).squeeze(0).mean(dim=0) 65 | else: 66 | miss += 1 67 | embedding = torch.randn(768) / 10 68 | relation_embeddings.append(embedding) 69 | 70 | assert len(rel_vocab) == len(relation_embeddings) 71 | relation_embeddings = torch.stack(relation_embeddings, dim=0) 72 | print(relation_embeddings.shape) 73 | print(miss * 1.0 / len(ent_vocab)) 74 | np.save('../wikidata5m_alias_emb/relations.npy', relation_embeddings.detach().numpy()) -------------------------------------------------------------------------------- /pretrain/large_emb.py: -------------------------------------------------------------------------------- 1 | """ 2 | callback 3 | - wrapper of client 4 | client 5 | - pass emb_idx, return embeddings 6 | - pass emb_grads to server 7 | server 8 | - given emb_idx, return embeddings 9 | - given grad, implement adam to update embeddings 10 | """ 11 | import os 12 | from dgl.contrib import KVClient, KVServer, read_ip_config 13 | import socket 14 | import torch 15 | import math 16 | from torch import nn 17 | from argparse import ArgumentParser 18 | from fastNLP import Callback, get_local_rank 19 | from torch import distributed as dist 20 | from torch.optim import Adam 21 | import numpy as np 22 | 23 | def is_overflow(n): 24 | norm = torch.norm(n) 25 | if norm == float('inf') or norm == float('nan'): 26 | return True 27 | return False 28 | 29 | def row_sparse_adagrad(name, ID, data, target, args): 30 | """Row-Sparse Adagrad update function 31 | """ 32 | lr = args['lr'] 33 | original_name = name[0:-6] 34 | state_sum = target[original_name + '_state-data-'] 35 | # import pdb; pdb.set_trace() 36 | grad_sum = (data * data).mean(1) 37 | state_sum.index_add_(0, ID, grad_sum) 38 | std = state_sum[ID] # _sparse_mask 39 | std_values = std.sqrt_().add_(1e-10).unsqueeze(1) 40 | tmp = (-lr * data / std_values) 41 | target[name].index_add_(0, ID, tmp) 42 | 43 | 44 | def sparse_adam(name, ID, data, target, args): 45 | lr = args['lr'] 46 | b1, b2 = args['betas'] 47 | eps = args['eps'] 48 | 49 | original_name = name[0:-6] 50 | exp_avg_name = original_name + '_exp_avg-data-' 51 | exp_avg_sq_name = original_name + '_exp_avg_sq-data-' 52 | # Exponential moving average of gradient values 53 | exp_avg = target[exp_avg_name][ID] 54 | # Exponential moving average of squared gradient values 55 | exp_avg_sq = target[exp_avg_sq_name][ID] 56 | step = target[original_name + '_step-data-'][0].add_(1) 57 | 58 | exp_avg.mul_(b1).add_(1 - b1, data) 59 | 
exp_avg_sq.mul_(b2).addcmul_(1 - b2, data, data) 60 | 61 | # bias_correction1 = (1-torch.zeros_like(step).fill_(b1).pow_(step)).view(-1, 1) 62 | # bias_correction2 = (1-torch.zeros_like(step).fill_(b2).pow_(step)).view(-1, 1) 63 | # denom = (exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) 64 | bias_correction1 = 1 - b1 ** step 65 | bias_correction2 = 1 - b2 ** step 66 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) 67 | 68 | step_size = lr / bias_correction1 69 | update = -step_size * exp_avg / denom 70 | # print('grad: {:.3f}, update: {:.3f}, exp_avg: {:.3f}, exp_avg_sq: {:.3f}'.format( 71 | # torch.norm(data), torch.norm(update), torch.norm(exp_avg), torch.norm(exp_avg_sq))) 72 | 73 | target[name].index_add_(0, ID, update) 74 | target[exp_avg_name][ID] = exp_avg 75 | target[exp_avg_sq_name][ID] = exp_avg_sq 76 | 77 | 78 | class EmbServer(KVServer): 79 | def __init__(self, server_id, server_namebook, num_client, queue_size=20 * 1024 * 1024 * 1024, net_type='socket'): 80 | super().__init__(server_id, server_namebook, num_client, queue_size, net_type) 81 | self._udf_push_handler = sparse_adam 82 | self._args = {} 83 | 84 | def set_args(self, args): 85 | self._args.update(args) 86 | 87 | def set_push_handler(self, handler): 88 | self._udf_push_handler = handler 89 | 90 | def _push_handler(self, name, ID, data, target): 91 | """push gradient only""" 92 | self._udf_push_handler(name, ID, data, target, self._args) 93 | 94 | 95 | class EmbClient(KVClient): 96 | def __init__(self, server_namebook, queue_size=20 * 1024 * 1024 * 1024, net_type='socket'): 97 | super().__init__(server_namebook, queue_size, net_type) 98 | self._udf_push_handler = sparse_adam 99 | self._args = {} 100 | 101 | def set_args(self, args): 102 | self._args.update(args) 103 | 104 | def set_push_handler(self, handler): 105 | self._udf_push_handler = handler 106 | 107 | def _push_handler(self, name, ID, data, target): 108 | """push gradient only""" 109 | self._udf_push_handler(name, ID, data, target, self._args) 110 | 111 | 112 | def check_port_available(port): 113 | """Return True if the port is available to use 114 | """ 115 | while True: 116 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 117 | try: 118 | s.bind(('', port)) ## Try to open port 119 | except OSError as e: 120 | if e.errno == 98: ## errno 98 means the address is already bound 121 | return False 122 | raise e 123 | s.close() 124 | 125 | return True 126 | 127 | 128 | def start_server(args): 129 | # torch.set_num_threads(1) 130 | 131 | server_namebook = read_ip_config(filename=args.ip_config) 132 | 133 | port = server_namebook[args.server_id][2] 134 | if check_port_available(port) == False: 135 | print("Error: port %d is not available."
% port) 136 | exit() 137 | 138 | my_server = EmbServer(server_id=args.server_id, 139 | server_namebook=server_namebook, 140 | num_client=args.total_client) 141 | 142 | my_server.set_args({'lr': args.lr, 'betas': (0.9, 0.999), 'eps': 1e-8}) 143 | my_server.set_push_handler(sparse_adam) 144 | 145 | if os.path.exists(args.ent_emb): 146 | print('load pretrained entity embedding: {}.\nnum_ent and ent_dim are ignored.'.format(args.ent_emb)) 147 | ent_emb = np.load(args.ent_emb) 148 | if args.add_special_tokens: 149 | add_embs = np.random.randn(3, ent_emb.shape[1]) # add , , 150 | ent_emb = np.r_[add_embs, ent_emb] 151 | entity_emb = torch.from_numpy(ent_emb).float() 152 | print('shape: ', entity_emb.shape) 153 | else: 154 | entity_emb = torch.randn((args.num_ent, args.ent_dim)) 155 | name = args.emb_name 156 | print('starting server...') 157 | my_server.init_data(name=name, data_tensor=entity_emb) 158 | my_server.init_data(name=name + '_exp_avg', data_tensor=torch.zeros_like(entity_emb)) 159 | my_server.init_data(name=name + '_exp_avg_sq', data_tensor=torch.zeros_like(entity_emb)) 160 | my_server.init_data(name=name + '_step', data_tensor=torch.zeros(entity_emb.size(0))) 161 | # my_server.init_data(name=name + '_state', data_tensor=torch.zeros(entity_emb.size(0))) 162 | 163 | print('KVServer listen at {}:{}'.format(my_server._ip, my_server._port)) 164 | my_server.start() 165 | 166 | 167 | class LargeEmbedding(nn.Module): 168 | """ 169 | :param ip_config: path to ip_config file 170 | :param emb_name: name of emb in server 171 | :param lr: learning rate 172 | :param num_emb: number of embeddings 173 | """ 174 | 175 | def __init__(self, ip_config, emb_name, lr, num_emb): 176 | super().__init__() 177 | server_namebook = read_ip_config(ip_config) 178 | self.client = EmbClient(server_namebook) 179 | optim_args = {'lr': lr, 'betas': (0.9, 0.999), 'eps': 1e-8} 180 | self.client.set_args(optim_args) 181 | self.client.set_push_handler(sparse_adam) 182 | self.client.connect() 183 | if not dist.is_initialized() or get_local_rank() == 0: 184 | self.client.set_partition_book(emb_name, torch.zeros(num_emb)) 185 | else: 186 | self.client.set_partition_book(emb_name, None) 187 | self.name = emb_name 188 | self.trace = [] 189 | self.num_emb = num_emb 190 | 191 | def __del__(self): 192 | self.client.shut_down() 193 | 194 | def forward(self, idx): 195 | """pull emb from server""" 196 | with torch.no_grad(): 197 | bsz, slen = idx.size() 198 | cpu_idx = idx.cpu() 199 | unique_idx = torch.unique(cpu_idx) 200 | unique_emb = self.client.pull(self.name, unique_idx) 201 | gpu_emb = unique_emb.to(idx.device).detach_().requires_grad_(True) 202 | idx_mapping = {i.item(): j for j, i in enumerate(unique_idx)} 203 | mapped_idx = torch.zeros((bsz, slen), dtype=torch.long) 204 | for i in range(bsz): 205 | for j in range(slen): 206 | mapped_idx[i][j] = idx_mapping[cpu_idx[i][j].item()] 207 | 208 | # emb = torch.index_select(gpu_emb, 0, mapped_idx.cuda().view(-1)).view(bsz, slen, -1) 209 | emb = torch.embedding(gpu_emb, mapped_idx.cuda()) 210 | # print('emb norm: {:.3f}, dtype: {}'.format(torch.norm(emb.data), emb.dtype)) 211 | if self.training: 212 | self.trace.append((unique_idx, gpu_emb)) 213 | return emb 214 | 215 | def update(self, skip=False): 216 | """push grad to server""" 217 | # print('update skip: {}'.format(skip)) 218 | if skip: 219 | self.trace.clear() 220 | return 221 | with torch.no_grad(): 222 | for idx, gpu_emb in self.trace: 223 | if gpu_emb.grad is not None: 224 | grad = gpu_emb.grad.cpu() 225 | 
self.client.push(self.name, idx, grad) 226 | self.trace.clear() 227 | 228 | def save(self, save_path): 229 | """pull all entity embeddings from server and save to disk""" 230 | with torch.no_grad(): 231 | ent_emb = self.client.pull(self.name, torch.arange(0, self.num_emb)) 232 | ent_emb = ent_emb.cpu().detach_().numpy() 233 | np.save(os.path.join(save_path, 'entities.npy'), ent_emb) 234 | 235 | 236 | class EmbUpdateCallback(Callback): 237 | def __init__(self, large_emb): 238 | super().__init__() 239 | self.large_emb = large_emb 240 | 241 | def on_backward_end(self): 242 | # when a gradient overflow occurs, amp patches the optimizer to skip the step 243 | if hasattr(self.optimizer, '_amp_stash'): 244 | # print('optimizer patched {}'.format(self.optimizer._amp_stash.already_patched)) 245 | if self.optimizer._amp_stash.already_patched: 246 | self.large_emb.update(skip=True) 247 | return 248 | self.large_emb.update(skip=False) 249 | 250 | 251 | if __name__ == '__main__': 252 | parser = ArgumentParser() 253 | parser.add_argument('--lr', type=float, default=3e-4, help='learning rate to update the emb') 254 | parser.add_argument('--server_id', type=int, default=0, help='id of server, from 0 to (num of server - 1)') 255 | parser.add_argument('--ip_config', type=str, default='emb_ip.cfg', 256 | help='path to config of server, every line in file is [ip] [port] [Num of server]') 257 | parser.add_argument('--emb_name', type=str, default='entity_emb', 258 | help='name of the embedding, should match with clients\' side') 259 | parser.add_argument('--num_ent', type=int, default=3000000, help='num of embeddings') # Wikidata5M has ~5 million entities 260 | parser.add_argument('--ent_dim', type=int, default=200, help='embedding dim') 261 | parser.add_argument('--total_client', type=int, default=1, help='num of client') 262 | parser.add_argument('--add_special_tokens', action='store_true', help='whether to add special tokens') 263 | parser.add_argument('--ent_emb', type=str, default=None) 264 | args = parser.parse_args() 265 | start_server(args) 266 | 267 | 268 | 269 | -------------------------------------------------------------------------------- /pretrain/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fastNLP import AccuracyMetric 3 | from fastNLP.core.utils import _get_func_signature 4 | 5 | 6 | class MLMAccuracyMetric(AccuracyMetric): 7 | 8 | def __init__(self, pred=None, target=None, seq_len=None): 9 | super(MLMAccuracyMetric, self).__init__(pred, target, seq_len) 10 | 11 | def evaluate(self, pred, target, seq_len=None): 12 | 13 | if not isinstance(pred, torch.Tensor): 14 | raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 15 | f"got {type(pred)}.") 16 | if not isinstance(target, torch.Tensor): 17 | raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 18 | f"got {type(target)}.") 19 | 20 | if seq_len is not None and not isinstance(seq_len, torch.Tensor): 21 | raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," 22 | f"got {type(seq_len)}.") 23 | 24 | masks = target != -1 # ignore_index = -1 25 | 26 | if pred.dim() == target.dim(): 27 | pass 28 | elif pred.dim() == target.dim() + 1: 29 | pred = pred.argmax(dim=-1) 30 | else: 31 | raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " 32 | f"size:{pred.size()}, target should have size: {pred.size()} or " 33 | f"{pred.size()[:-1]}, got {target.size()}.") 34 | 35 | target = target.to(pred) 36 | 37 | self.acc_count +=
torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item() 38 | self.total += torch.sum(masks).item() 39 | 40 | 41 | class WordMLMAccuracy(MLMAccuracyMetric): 42 | def __init__(self, pred=None, target=None, seq_len=None): 43 | super(WordMLMAccuracy, self).__init__(pred, target, seq_len) 44 | 45 | 46 | class EntityMLMAccuracy(MLMAccuracyMetric): 47 | def __init__(self, pred=None, target=None, seq_len=None): 48 | super(EntityMLMAccuracy, self).__init__(pred, target, seq_len) 49 | 50 | 51 | class RelationMLMAccuracy(MLMAccuracyMetric): 52 | def __init__(self, pred=None, target=None, seq_len=None): 53 | super(RelationMLMAccuracy, self).__init__(pred, target, seq_len) 54 | -------------------------------------------------------------------------------- /pretrain/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import CrossEntropyLoss 5 | from pretrain.large_emb import LargeEmbedding 6 | 7 | from transformers import RobertaConfig, RobertaForMaskedLM, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 8 | from transformers.modeling_bert import BertLayerNorm, gelu 9 | 10 | 11 | class CoLAKE(RobertaForMaskedLM): 12 | config_class = RobertaConfig 13 | pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 14 | base_model_prefix = "roberta" 15 | 16 | def __init__(self, config, num_ent, num_rel, ent_lr, ip_config='emb_ip.cfg', rel_emb=None, emb_name='entity_emb'): 17 | super().__init__(config) 18 | # self.ent_embeddings = nn.Embedding(num_ent, ent_dim, padding_idx=1) 19 | self.ent_embeddings = LargeEmbedding(ip_config, emb_name, ent_lr, num_ent) 20 | self.rel_embeddings = nn.Embedding(num_rel, config.hidden_size, padding_idx=1) 21 | self.ent_lm_head = EntLMHead(config) 22 | self.rel_lm_head = RelLMHead(config, num_rel) 23 | self.apply(self._init_weights) 24 | if rel_emb is not None: 25 | self.rel_embeddings = nn.Embedding.from_pretrained(rel_emb, padding_idx=1) 26 | print('pre-trained relation embeddings loaded.') 27 | self.tie_rel_weights() 28 | 29 | def extend_type_embedding(self, token_type=3): 30 | self.roberta.embeddings.token_type_embeddings = nn.Embedding(token_type, self.config.hidden_size, 31 | _weight=torch.zeros( 32 | (token_type, self.config.hidden_size))) 33 | 34 | def tie_rel_weights(self): 35 | self.rel_lm_head.decoder.weight = self.rel_embeddings.weight 36 | if getattr(self.rel_lm_head.decoder, "bias", None) is not None: 37 | self.rel_lm_head.decoder.bias.data = torch.nn.functional.pad( 38 | self.rel_lm_head.decoder.bias.data, 39 | (0, self.rel_lm_head.decoder.weight.shape[0] - self.rel_lm_head.decoder.bias.shape[0],), 40 | "constant", 41 | 0, 42 | ) 43 | 44 | def forward( 45 | self, 46 | input_ids=None, 47 | attention_mask=None, 48 | token_type_ids=None, 49 | position_ids=None, 50 | head_mask=None, 51 | inputs_embeds=None, 52 | masked_lm_labels=None, 53 | ent_masked_lm_labels=None, 54 | rel_masked_lm_labels=None, 55 | n_word_nodes=None, 56 | n_entity_nodes=None, 57 | ent_index=None 58 | ): 59 | n_word_nodes = n_word_nodes[0] 60 | n_entity_nodes = n_entity_nodes[0] 61 | word_embeddings = self.roberta.embeddings.word_embeddings( 62 | input_ids[:, :n_word_nodes]) # batch x n_word_nodes x hidden_size 63 | 64 | ent_embeddings = self.ent_embeddings( 65 | input_ids[:, n_word_nodes:n_word_nodes + n_entity_nodes]) 66 | 67 | rel_embeddings = self.rel_embeddings( 68 | input_ids[:, n_word_nodes + n_entity_nodes:]) 69 | 70 | inputs_embeds = 
torch.cat([word_embeddings, ent_embeddings, rel_embeddings], 71 | dim=1) # batch x seq_len x hidden_size 72 | 73 | outputs = self.roberta( 74 | input_ids=None, 75 | attention_mask=attention_mask, 76 | token_type_ids=token_type_ids, 77 | position_ids=position_ids, 78 | head_mask=head_mask, 79 | inputs_embeds=inputs_embeds, 80 | ) 81 | sequence_output = outputs[0] # batch x seq_len x hidden_size 82 | 83 | loss_fct = CrossEntropyLoss(ignore_index=-1, reduction='mean') 84 | word_logits = self.lm_head(sequence_output[:, :n_word_nodes, :]) 85 | word_predict = torch.argmax(word_logits, dim=-1) 86 | masked_lm_loss = loss_fct(word_logits.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) 87 | ent_cls_weight = self.ent_embeddings(ent_index[0].view(1,-1)).squeeze() 88 | ent_logits = self.ent_lm_head(sequence_output[:, n_word_nodes:n_word_nodes + n_entity_nodes, :], 89 | ent_cls_weight) 90 | ent_predict = torch.argmax(ent_logits, dim=-1) 91 | ent_masked_lm_loss = loss_fct(ent_logits.view(-1, ent_logits.size(-1)), ent_masked_lm_labels.view(-1)) 92 | 93 | rel_logits = self.rel_lm_head(sequence_output[:, n_word_nodes + n_entity_nodes:, :]) 94 | rel_predict = torch.argmax(rel_logits, dim=-1) 95 | rel_masked_lm_loss = loss_fct(rel_logits.view(-1, rel_logits.size(-1)), rel_masked_lm_labels.view(-1)) 96 | loss = masked_lm_loss + ent_masked_lm_loss + rel_masked_lm_loss 97 | return {'loss': loss, 98 | 'word_pred': word_predict, 99 | 'entity_pred': ent_predict, 100 | 'relation_pred': rel_predict} 101 | 102 | 103 | class EntLMHead(nn.Module): 104 | def __init__(self, config): 105 | super().__init__() 106 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 107 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 108 | # self.dropout = nn.Dropout(p=dropout) 109 | 110 | def forward(self, features, weight, **kwargs): 111 | x = self.dense(features) 112 | x = gelu(x) 113 | # x = self.dropout(x) 114 | x = self.layer_norm(x) 115 | x = x.matmul(weight.t()) 116 | 117 | return x 118 | 119 | 120 | class RelLMHead(nn.Module): 121 | def __init__(self, config, num_rel): 122 | super().__init__() 123 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 124 | self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps) 125 | 126 | self.decoder = nn.Linear(config.hidden_size, num_rel, bias=False) 127 | self.bias = nn.Parameter(torch.zeros(num_rel), requires_grad=True) 128 | # self.dropout = nn.Dropout(p=dropout) 129 | 130 | self.decoder.bias = self.bias 131 | 132 | def forward(self, features, **kwargs): 133 | x = self.dense(features) 134 | x = gelu(x) 135 | # x = self.dropout(x) 136 | x = self.layer_norm(x) 137 | 138 | x = self.decoder(x) 139 | 140 | return x 141 | -------------------------------------------------------------------------------- /pretrain/run_pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | sys.path.append('../') 6 | 7 | import fitlog 8 | import argparse 9 | import torch 10 | from torch import optim 11 | import torch.distributed as dist 12 | from transformers import RobertaTokenizer 13 | from fastNLP import RandomSampler, TorchLoaderIter, LossInForward, Tester, DistTrainer, get_local_rank 14 | from fastNLP import FitlogCallback, WarmupCallback, GradientClipCallback, logger, init_logger_dist 15 | from pretrain.model import CoLAKE 16 | from pretrain.utils import OTFDistributedSampler, SaveModelCallback 17 | from pretrain.utils import 
load_ent_rel_vocabs, get_ent_freq, MyFitlogCallback 18 | from pretrain.dataset import GraphOTFDataSet, GraphDataSet, FewRelDevDataSet 19 | from pretrain.metrics import WordMLMAccuracy, EntityMLMAccuracy, RelationMLMAccuracy 20 | from pretrain.large_emb import EmbUpdateCallback 21 | from transformers import PYTORCH_PRETRAINED_BERT_CACHE, RobertaConfig 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--name', type=str, default='test', help="experiment name") 27 | parser.add_argument('--data_dir', type=str, 28 | default='../pretrain_data/data', 29 | help="data directory path") 30 | parser.add_argument('--test_data', type=str, default=None, 31 | help="fewrel test data directory path") 32 | parser.add_argument('--save_dir', type=str, default='../ckpts/', 33 | help="model directory path") 34 | parser.add_argument('--log_dir', type=str, default='./pretrain_logs', 35 | help="fitlog directory path") 36 | parser.add_argument('--rel_emb', type=str, 37 | default='../wikidata5m_alias_emb/relations.npy') 38 | parser.add_argument('--kg_path', type=str, default='../wikidata5m') 39 | parser.add_argument('--emb_name', type=str, default='entity_emb') 40 | parser.add_argument('--data_prop', type=float, default=0.3, help="using what proportion of wiki to train") 41 | parser.add_argument('--n_negs', type=int, default=100, help="number of negative samples") 42 | parser.add_argument('--batch_size', type=int, default=256, help="batch size") 43 | parser.add_argument('--lr', type=float, default=1e-4, help="learning rate") 44 | parser.add_argument('--beta', type=float, default=0.999, help="beta_2 in Adam") 45 | parser.add_argument('--warm_up', type=float, default=0.1, help="warmup proportion or steps") 46 | parser.add_argument('--epoch', type=int, default=1, help="number of epochs") 47 | parser.add_argument('--grad_accumulation', type=int, default=4, help="gradient accumulation") 48 | parser.add_argument('--local_rank', type=int, default=0, help="local rank") 49 | parser.add_argument('--fp16', action='store_true', help="whether to use fp16") 50 | parser.add_argument('--save_model', action='store_true', help="whether save model") 51 | parser.add_argument('--do_test', action='store_true', help="test trained model") 52 | parser.add_argument('--debug', action='store_true', help="do not log") 53 | parser.add_argument('--model_name', type=str, default=None, help="test or further train") 54 | parser.add_argument('--ent_dim', type=int, default=200, help="dimension of entity embeddings") 55 | parser.add_argument('--rel_dim', type=int, default=200, help="dimension of relation embeddings") 56 | parser.add_argument('--ip_config', type=str, default='emb_ip.cfg') 57 | parser.add_argument('--ent_lr', type=float, default=1e-4, help="entity embedding learning rate") 58 | return parser.parse_args() 59 | 60 | 61 | def train(): 62 | args = parse_args() 63 | if args.debug: 64 | fitlog.debug() 65 | args.save_model = False 66 | # ================= define ================= 67 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 68 | word_mask_index = tokenizer.mask_token_id 69 | word_vocab_size = len(tokenizer) 70 | 71 | if get_local_rank() == 0: 72 | fitlog.set_log_dir(args.log_dir) 73 | fitlog.commit(__file__, fit_msg=args.name) 74 | fitlog.add_hyper_in_file(__file__) 75 | fitlog.add_hyper(args) 76 | 77 | # ================= load data ================= 78 | dist.init_process_group('nccl') 79 | init_logger_dist() 80 | 81 | n_proc = dist.get_world_size() 82 | bsz = args.batch_size 
// args.grad_accumulation // n_proc 83 | args.local_rank = get_local_rank() 84 | args.save_dir = os.path.join(args.save_dir, args.name) if args.save_model else None 85 | if args.save_dir is not None and os.path.exists(args.save_dir): 86 | raise RuntimeError('save_dir has already existed.') 87 | logger.info('save directory: {}'.format('None' if args.save_dir is None else args.save_dir)) 88 | devices = list(range(torch.cuda.device_count())) 89 | NUM_WORKERS = 4 90 | 91 | ent_vocab, rel_vocab = load_ent_rel_vocabs() 92 | logger.info('# entities: {}'.format(len(ent_vocab))) 93 | logger.info('# relations: {}'.format(len(rel_vocab))) 94 | ent_freq = get_ent_freq() 95 | assert len(ent_vocab) == len(ent_freq), '{} {}'.format(len(ent_vocab), len(ent_freq)) 96 | 97 | ##### 98 | root = args.data_dir 99 | dirs = os.listdir(root) 100 | drop_files = [] 101 | for dir in dirs: 102 | path = os.path.join(root, dir) 103 | max_idx = 0 104 | for file_name in os.listdir(path): 105 | if 'large' in file_name: 106 | continue 107 | max_idx = int(file_name) if int(file_name) > max_idx else max_idx 108 | drop_files.append(os.path.join(path, str(max_idx))) 109 | ##### 110 | 111 | file_list = [] 112 | for path, _, filenames in os.walk(args.data_dir): 113 | for filename in filenames: 114 | file = os.path.join(path, filename) 115 | if 'large' in file or file in drop_files: 116 | continue 117 | file_list.append(file) 118 | logger.info('used {} files in {}.'.format(len(file_list), args.data_dir)) 119 | if args.data_prop > 1: 120 | used_files = file_list[:int(args.data_prop)] 121 | else: 122 | used_files = file_list[:round(args.data_prop * len(file_list))] 123 | 124 | data = GraphOTFDataSet(used_files, n_proc, args.local_rank, word_mask_index, word_vocab_size, 125 | args.n_negs, ent_vocab, rel_vocab, ent_freq) 126 | dev_data = GraphDataSet(used_files[0], word_mask_index, word_vocab_size, args.n_negs, ent_vocab, 127 | rel_vocab, ent_freq) 128 | 129 | sampler = OTFDistributedSampler(used_files, n_proc, get_local_rank()) 130 | train_data_iter = TorchLoaderIter(dataset=data, batch_size=bsz, sampler=sampler, num_workers=NUM_WORKERS, 131 | collate_fn=data.collate_fn) 132 | dev_data_iter = TorchLoaderIter(dataset=dev_data, batch_size=bsz, sampler=RandomSampler(), 133 | num_workers=NUM_WORKERS, 134 | collate_fn=dev_data.collate_fn) 135 | if args.test_data is not None: 136 | test_data = FewRelDevDataSet(path=args.test_data, label_vocab=rel_vocab, ent_vocab=ent_vocab) 137 | test_data_iter = TorchLoaderIter(dataset=test_data, batch_size=32, sampler=RandomSampler(), 138 | num_workers=NUM_WORKERS, 139 | collate_fn=test_data.collate_fn) 140 | 141 | if args.local_rank == 0: 142 | print('full wiki files: {}'.format(len(file_list))) 143 | print('used wiki files: {}'.format(len(used_files))) 144 | print('# of trained samples: {}'.format(len(data) * n_proc)) 145 | print('# of trained entities: {}'.format(len(ent_vocab))) 146 | print('# of trained relations: {}'.format(len(rel_vocab))) 147 | 148 | # ================= prepare model ================= 149 | logger.info('model init') 150 | if args.rel_emb is not None: # load pretrained relation embeddings 151 | rel_emb = np.load(args.rel_emb) 152 | # add_embs = np.random.randn(3, rel_emb.shape[1]) # add , , 153 | # rel_emb = np.r_[add_embs, rel_emb] 154 | rel_emb = torch.from_numpy(rel_emb).float() 155 | assert rel_emb.shape[0] == len(rel_vocab), '{} {}'.format(rel_emb.shape[0], len(rel_vocab)) 156 | # assert rel_emb.shape[1] == args.rel_dim 157 | logger.info('loaded pretrained relation 
embeddings. dim: {}'.format(rel_emb.shape[1])) 158 | else: 159 | rel_emb = None 160 | if args.model_name is not None: 161 | logger.info('further pre-train.') 162 | config = RobertaConfig.from_pretrained('roberta-base', type_vocab_size=3) 163 | model = CoLAKE(config=config, 164 | num_ent=len(ent_vocab), 165 | num_rel=len(rel_vocab), 166 | ent_dim=args.ent_dim, 167 | rel_dim=args.rel_dim, 168 | ent_lr=args.ent_lr, 169 | ip_config=args.ip_config, 170 | rel_emb=None, 171 | emb_name=args.emb_name) 172 | states_dict = torch.load(args.model_name) 173 | model.load_state_dict(states_dict, strict=True) 174 | else: 175 | model = CoLAKE.from_pretrained('roberta-base', 176 | num_ent=len(ent_vocab), 177 | num_rel=len(rel_vocab), 178 | ent_lr=args.ent_lr, 179 | ip_config=args.ip_config, 180 | rel_emb=rel_emb, 181 | emb_name=args.emb_name, 182 | cache_dir=PYTORCH_PRETRAINED_BERT_CACHE + '/dist_{}'.format(args.local_rank)) 183 | model.extend_type_embedding(token_type=3) 184 | # if args.local_rank == 0: 185 | # for name, param in model.named_parameters(): 186 | # if param.requires_grad is True: 187 | # print('{}: {}'.format(name, param.shape)) 188 | 189 | # ================= train model ================= 190 | # lr=1e-4 for peak value, lr=5e-5 for initial value 191 | logger.info('trainer init') 192 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'layer_norm.bias', 'layer_norm.weight'] 193 | param_optimizer = list(model.named_parameters()) 194 | optimizer_grouped_parameters = [ 195 | {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 196 | {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 197 | ] 198 | word_acc = WordMLMAccuracy(pred='word_pred', target='masked_lm_labels', seq_len='word_seq_len') 199 | ent_acc = EntityMLMAccuracy(pred='entity_pred', target='ent_masked_lm_labels', seq_len='ent_seq_len') 200 | rel_acc = RelationMLMAccuracy(pred='relation_pred', target='rel_masked_lm_labels', seq_len='rel_seq_len') 201 | metrics = [word_acc, ent_acc, rel_acc] 202 | 203 | if args.test_data is not None: 204 | test_metric = [rel_acc] 205 | tester = Tester(data=test_data_iter, model=model, metrics=test_metric, device=list(range(torch.cuda.device_count()))) 206 | # tester.test() 207 | else: 208 | tester = None 209 | 210 | optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.lr, betas=(0.9, args.beta), eps=1e-6) 211 | # warmup_callback = WarmupCallback(warmup=args.warm_up, schedule='linear') 212 | fitlog_callback = MyFitlogCallback(tester=tester, log_loss_every=100, verbose=1) 213 | gradient_clip_callback = GradientClipCallback(clip_value=1, clip_type='norm') 214 | emb_callback = EmbUpdateCallback(model.ent_embeddings) 215 | all_callbacks = [gradient_clip_callback, emb_callback] 216 | if args.save_dir is None: 217 | master_callbacks = [fitlog_callback] 218 | else: 219 | save_callback = SaveModelCallback(args.save_dir, model.ent_embeddings, only_params=True) 220 | master_callbacks = [fitlog_callback, save_callback] 221 | 222 | if args.do_test: 223 | states_dict = torch.load(os.path.join(args.save_dir, args.model_name)).state_dict() 224 | model.load_state_dict(states_dict) 225 | data_iter = TorchLoaderIter(dataset=data, batch_size=args.batch_size, sampler=RandomSampler(), 226 | num_workers=NUM_WORKERS, 227 | collate_fn=data.collate_fn) 228 | tester = Tester(data=data_iter, model=model, metrics=metrics, device=devices) 229 | tester.test() 230 | else: 231 | trainer = 
DistTrainer(train_data=train_data_iter, 232 | dev_data=dev_data_iter, 233 | model=model, 234 | optimizer=optimizer, 235 | loss=LossInForward(), 236 | batch_size_per_gpu=bsz, 237 | update_every=args.grad_accumulation, 238 | n_epochs=args.epoch, 239 | metrics=metrics, 240 | callbacks_master=master_callbacks, 241 | callbacks_all=all_callbacks, 242 | validate_every=5000, 243 | use_tqdm=True, 244 | fp16='O1' if args.fp16 else '') 245 | trainer.train(load_best_model=False) 246 | 247 | 248 | if __name__ == '__main__': 249 | train() 250 | -------------------------------------------------------------------------------- /pretrain/run_pretrain.sh: -------------------------------------------------------------------------------- 1 | EMB_NAME=entity_emb 2 | LR=1e-4 3 | 4 | python large_emb.py --lr $LR --total_client 8 --emb_name $EMB_NAME \ 5 | --ent_emb ../wikidata5m_alias_emb/entities.npy & 6 | python -m torch.distributed.launch --nproc_per_node=8 run_pretrain.py --name CoLAKE --data_prop 1.0 \ 7 | --batch_size 2048 --lr $LR --ent_lr $LR --epoch 1 --grad_accumulation 16 --save_model --emb_name $EMB_NAME \ 8 | --n_negs 200 --beta 0.98 -------------------------------------------------------------------------------- /pretrain/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pickle 4 | import torch 5 | import collections 6 | import numpy as np 7 | from fastNLP import Callback, cache_results 8 | from torch.utils.data import Sampler 9 | from itertools import chain 10 | import pandas as pd 11 | from copy import deepcopy 12 | from fastNLP import Tester, DataSet 13 | import fitlog 14 | 15 | 16 | MaskedLmInstance = collections.namedtuple("MaskedLmInstance", ["index", "label"]) 17 | 18 | 19 | class OTFDistributedSampler(Sampler): 20 | # On-The-Fly sampler 21 | def __init__(self, indexed_train_fps, n_workers, rank, shuffle=True): 22 | super(OTFDistributedSampler, self).__init__(0) 23 | self.epoch = 0 24 | self.shuffle = shuffle 25 | 26 | file_per_process = len(indexed_train_fps) // n_workers 27 | if file_per_process * n_workers != len(indexed_train_fps): 28 | if rank == 0: 29 | print('[Sampler] Drop {} files.'.format(len(indexed_train_fps) - file_per_process * n_workers)) 30 | print('[Sampler] # files per process: {}'.format(file_per_process)) 31 | self.fps = indexed_train_fps[rank * file_per_process:(rank + 1) * file_per_process] 32 | self.file_per_process = file_per_process 33 | 34 | data = [] 35 | with open(self.fps[0], 'r', encoding='utf-8') as fin: 36 | import json 37 | for x in fin: 38 | data.append(json.loads(x)) 39 | self.num_samples_per_file = len(data) 40 | assert self.num_samples_per_file == 5000 41 | self.total_num_samples = self.num_samples_per_file * len(self.fps) 42 | 43 | def __iter__(self): 44 | # deterministically shuffle based on epoch 45 | g = torch.Generator() 46 | g.manual_seed(self.epoch) 47 | if self.shuffle: 48 | indices = [] 49 | for i in range(self.file_per_process): 50 | indexes = list(np.arange(self.num_samples_per_file) + i * self.num_samples_per_file) # indices within one file 51 | np.random.shuffle(indexes) 52 | indices.append(indexes) 53 | np.random.shuffle(indices) 54 | indices = list(chain(*indices)) 55 | else: 56 | indices = [] 57 | for i in range(self.file_per_process): 58 | indexes = list(np.arange(self.num_samples_per_file) + i * self.num_samples_per_file) 59 | indices.extend(indexes) 60 | 61 | return iter(indices) 62 | 63 | def __len__(self): 64 | return self.total_num_samples 65 | 66 | def 
set_epoch(self, epoch): 67 | self.epoch = epoch 68 | 69 | 70 | class SaveModelCallback(Callback): 71 | def __init__(self, save_path, ent_emb, only_params=True): 72 | super(SaveModelCallback, self).__init__() 73 | self.save_path = save_path 74 | self.only_params = only_params 75 | self.ent_emb = ent_emb 76 | 77 | def on_epoch_end(self): 78 | if self.is_master: 79 | path = os.path.join(self.save_path, 'epoch_' + str(self.epoch)) 80 | os.makedirs(path, exist_ok=True) 81 | model_path = os.path.join(path, 'model.bin') 82 | model_to_save = self.trainer.ddp_model.module 83 | if self.only_params: 84 | model_to_save = model_to_save.state_dict() 85 | torch.save(model_to_save, model_path) 86 | self.ent_emb.save(path) 87 | self.trainer.logger.info('Saved checkpoint to {}.'.format(path)) 88 | 89 | 90 | def create_mlm_labels(tokens, mask_index, vocab_size, masked_lm_prob=0.15, max_predictions_per_seq=15, anchor_nodes=None): 91 | rng = random.Random(2020) 92 | cand_indexes = [] 93 | if mask_index == 50264: # indicates word nodes 94 | special_tokens = [0, 1, 2, 3] # RoBERTa: 0: <s>, 1: <pad>, 2: </s>, 3: <unk> 95 | else: 96 | special_tokens = [0, 1] # 0 and 1 are reserved indices in the entity/relation vocab (1 is the padding index) 97 | for (i, token) in enumerate(tokens): 98 | if token in special_tokens: 99 | continue 100 | cand_indexes.append(i) 101 | 102 | rng.shuffle(cand_indexes) 103 | output_tokens = list(tokens) 104 | num_to_predict = min(max_predictions_per_seq, 105 | max(1, int(round(len(tokens) * masked_lm_prob)))) 106 | masked_labels = [] 107 | covered_indexes = set() 108 | for index in cand_indexes: 109 | if len(masked_labels) >= num_to_predict: 110 | if anchor_nodes is None: 111 | break 112 | elif tokens[index] not in anchor_nodes: 113 | continue 114 | else: # tokens[index] is anchor node 115 | if index in covered_indexes: 116 | continue 117 | covered_indexes.add(index) 118 | if rng.random() < 0.8: 119 | masked_token = tokens[index] # keep the original token itself with 80% probability 120 | else: 121 | if rng.random() < 0.5: 122 | masked_token = mask_index 123 | else: 124 | masked_token = rng.randint(0, vocab_size - 1) 125 | else: 126 | if index in covered_indexes: 127 | continue 128 | covered_indexes.add(index) 129 | if rng.random() < 0.8: 130 | masked_token = mask_index # [MASK] 131 | else: 132 | if rng.random() < 0.5: 133 | masked_token = tokens[index] 134 | else: 135 | masked_token = rng.randint(0, vocab_size - 1) 136 | output_tokens[index] = masked_token 137 | masked_labels.append(MaskedLmInstance(index=index, label=tokens[index])) 138 | masked_labels = sorted(masked_labels, key=lambda x: x.index) 139 | masked_lm_positions = [] 140 | masked_lm_labels = [] 141 | for p in masked_labels: 142 | masked_lm_positions.append(p.index) 143 | masked_lm_labels.append(p.label) 144 | masked_labels = np.ones(len(tokens), dtype=int) * -1 145 | masked_labels[masked_lm_positions] = masked_lm_labels 146 | masked_labels = list(masked_labels) 147 | return output_tokens, masked_labels 148 | 149 | 150 | class MyFitlogCallback(Callback): 151 | def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): 152 | super().__init__() 153 | self.datasets = {} 154 | self.testers = {} 155 | self._log_exception = log_exception 156 | assert isinstance(log_loss_every, int) and log_loss_every >= 0 157 | if tester is not None: 158 | if isinstance(tester, dict): 159 | for name, test in tester.items(): 160 | if not isinstance(test, Tester): 161 | raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") 162 | self.testers['tester-' + name] = test 163 | if isinstance(tester, Tester): 164 | self.testers['tester-test'] = tester
165 | for tester in self.testers.values(): 166 | setattr(tester, 'verbose', 0) 167 | 168 | if isinstance(data, dict): 169 | for key, value in data.items(): 170 | assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." 171 | for key, value in data.items(): 172 | self.datasets['data-' + key] = value 173 | elif isinstance(data, DataSet): 174 | self.datasets['data-test'] = data 175 | elif data is not None: 176 | raise TypeError("data receives dict[DataSet] or DataSet object.") 177 | 178 | self.verbose = verbose 179 | self._log_loss_every = log_loss_every 180 | self._avg_loss = 0 181 | 182 | def on_train_begin(self): 183 | if len(self.datasets) > 0: 184 | for key, data in self.datasets.items(): 185 | tester = Tester(data=data, model=self.model, 186 | batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), 187 | metrics=self.trainer.metrics, 188 | verbose=0, 189 | use_tqdm=self.trainer.test_use_tqdm) 190 | self.testers[key] = tester 191 | fitlog.add_progress(total_steps=self.n_steps) 192 | 193 | def on_valid_end(self, eval_result, metric_key, optimizer, better_result): 194 | if better_result: 195 | eval_result = deepcopy(eval_result) 196 | eval_result['step'] = self.step 197 | eval_result['epoch'] = self.epoch 198 | fitlog.add_best_metric(eval_result) 199 | fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch) 200 | if len(self.testers) > 0: 201 | for key, tester in self.testers.items(): 202 | try: 203 | eval_result = tester.test() 204 | if self.verbose != 0: 205 | self.pbar.write("FitlogCallback evaluation on {}:".format(key)) 206 | self.pbar.write(tester._format_eval_results(eval_result)) 207 | fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) 208 | if better_result: 209 | fitlog.add_best_metric(eval_result, name=key) 210 | except Exception as e: 211 | self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) 212 | raise e 213 | 214 | def on_train_end(self): 215 | fitlog.finish() 216 | 217 | def on_exception(self, exception): 218 | fitlog.finish(status=1) 219 | if self._log_exception: 220 | fitlog.add_other(repr(exception), name='except_info') 221 | 222 | 223 | @cache_results(_cache_fp='ent_freq.bin', _refresh=False) 224 | def get_ent_freq(path): 225 | with open(path, 'rb') as fin: 226 | ent_freq = pickle.load(fin) 227 | print('# of entities: {}'.format(len(ent_freq))) 228 | return ent_freq 229 | 230 | 231 | @cache_results(_cache_fp='ent_rel_vocab.bin', _refresh=False) 232 | def load_ent_rel_vocabs(path): 233 | with open(os.path.join(path, 'read_ent_vocab.bin'), 'rb') as fin: 234 | ent_vocab = pickle.load(fin) 235 | print('# of entities: {}'.format(len(ent_vocab))) 236 | 237 | with open(os.path.join(path, 'read_rel_vocab.bin'), 'rb') as fin: 238 | rel_vocab = pickle.load(fin) 239 | print('# of relations: {}'.format(len(rel_vocab))) 240 | 241 | return ent_vocab, rel_vocab 242 | 243 | 244 | 245 | 246 | 247 | -------------------------------------------------------------------------------- /read_ent_vocab.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txsun1997/CoLAKE/c9874bcbe6824a7e6bea0821aaaaf629dccb801e/read_ent_vocab.bin -------------------------------------------------------------------------------- /read_rel_vocab.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/txsun1997/CoLAKE/c9874bcbe6824a7e6bea0821aaaaf629dccb801e/read_rel_vocab.bin 
--------------------------------------------------------------------------------
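
Note on the entity-embedding update path: the KVStore server in `pretrain/large_emb.py` registers `sparse_adam` as its push handler, so an Adam step is applied only to the embedding rows whose gradients a client pushes after each backward pass. The snippet below is a minimal, self-contained sketch of that row-wise Adam step for illustration only; the function name `sparse_adam_rows` and the use of plain in-memory tensors (instead of the DGL KVStore state tensors) are assumptions, not part of the repository.

```python
import torch

def sparse_adam_rows(emb, exp_avg, exp_avg_sq, step, ids, grad,
                     lr=1e-4, betas=(0.9, 0.999), eps=1e-8):
    """Apply an Adam update only to the rows `ids` of `emb`, given their gradients `grad`."""
    b1, b2 = betas
    step += 1
    # first and second moment estimates for the touched rows only
    row_avg = exp_avg[ids]
    row_avg_sq = exp_avg_sq[ids]
    row_avg.mul_(b1).add_(grad, alpha=1 - b1)
    row_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)
    # bias-corrected step, mirroring the handler in pretrain/large_emb.py
    denom = (row_avg_sq.sqrt() / (1 - b2 ** step) ** 0.5).add_(eps)
    update = -(lr / (1 - b1 ** step)) * row_avg / denom
    emb.index_add_(0, ids, update)
    # write the updated moments back for the touched rows
    exp_avg[ids] = row_avg
    exp_avg_sq[ids] = row_avg_sq
    return step
```

This mirrors how `LargeEmbedding.forward` pulls only the unique entity indices of a batch and `LargeEmbedding.update` pushes back only their gradients, so the full entity table and its optimizer state stay on the CPU-side server while the transformer trains on GPU.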