├── README.md
├── logs
│   └── zeshel_hyper_param.txt
├── muver
│   ├── __init__.py
│   ├── multi_view
│   │   ├── data_loader.py
│   │   ├── model.py
│   │   ├── train.py
│   │   └── zeshel_evaluate.py
│   └── utils
│       ├── logger.py
│       ├── multigpu.py
│       ├── params.py
│       └── tools.py
└── requirements.txt
/README.md:
--------------------------------------------------------------------------------
1 | # MuVER
2 | This repo contains the code and pre-trained model for our EMNLP 2021 paper:
3 | **MuVER: Improving First-Stage Entity Retrieval with Multi-View Entity Representations**. Xinyin Ma, Yong Jiang, Nguyen Bach, Tao Wang, Zhongqiang Huang, Fei Huang, Weiming Lu
4 | 
5 | ## Quick Start
6 | ### 1. Requirements
7 | The requirements for our code are listed in requirements.txt. Install the packages with the following command:
8 | ```
9 | pip install -r requirements.txt
10 | ```
11 | For [huggingface/transformers](https://github.com/huggingface/transformers), we tested our code under versions 4.1.x and 4.2.x.
12 | 
13 | ### 2. Download data and model
14 | * Data:
15 | We follow facebookresearch/BLINK to download and preprocess data. See the [instructions](https://github.com/facebookresearch/BLINK/tree/master/examples/zeshel) on how to download the data and convert it to BLINK format. You will get a folder with the following structure:
16 | ```
17 | - zeshel
18 | | - mentions
19 | | - documents
20 | | - blink_format
21 | ```
22 | 
23 | * Model:
24 | The model for zeshel can be downloaded from https://drive.google.com/file/d/1BBTue5Vmr3MteGcse-ePqplWjccqm9_A/view?usp=sharing
25 | 
26 | ### 3. Use the released model to reproduce our results
27 | * **Without View Merging**:
28 | ```
29 | export PYTHONPATH='.'
30 | CUDA_VISIBLE_DEVICES=YOUR_GPU_DEVICES python muver/multi_view/train.py \
31 |     --pretrained_model path_to_model/bert-base \
32 |     --dataset_path path_to_dataset/zeshel \
33 |     --bi_ckpt_path path_to_model/best_zeshel.bin \
34 |     --max_cand_len 40 \
35 |     --max_seq_len 128 \
36 |     --do_test \
37 |     --test_mode test \
38 |     --data_parallel \
39 |     --eval_batch_size 16 \
40 |     --accumulate_score
41 | ```
42 | 
43 | 
44 | Expected Result:
45 | 
46 | | World            | R@1    | R@2    | R@4    | R@8    | R@16   | R@32   | R@50   | R@64   |
47 | |------------------|--------|--------|--------|--------|--------|--------|--------|--------|
48 | | forgotten_realms | 0.6208 | 0.7783 | 0.8592 | 0.8983 | 0.9342 | 0.9533 | 0.9633 | 0.9700 |
49 | | lego             | 0.4904 | 0.6714 | 0.7690 | 0.8357 | 0.8791 | 0.9091 | 0.9208 | 0.9249 |
50 | | star_trek        | 0.4743 | 0.6130 | 0.6967 | 0.7606 | 0.8159 | 0.8581 | 0.8805 | 0.8919 |
51 | | yugioh           | 0.3432 | 0.4861 | 0.6040 | 0.7004 | 0.7596 | 0.8201 | 0.8512 | 0.8672 |
52 | | total            | 0.4496 | 0.5970 | 0.6936 | 0.7658 | 0.8187 | 0.8628 | 0.8854 | 0.8969 |
53 | 
54 | * **With View Merging**:
55 | ```
56 | export PYTHONPATH='.'
57 | CUDA_VISIBLE_DEVICES=YOUR_GPU_DEVICES python muver/multi_view/train.py \
58 |     --pretrained_model path_to_model/bert-base \
59 |     --dataset_path path_to_dataset/zeshel \
60 |     --bi_ckpt_path path_to_model/best_zeshel.bin \
61 |     --max_cand_len 40 \
62 |     --max_seq_len 128 \
63 |     --do_test \
64 |     --test_mode test \
65 |     --data_parallel \
66 |     --eval_batch_size 16 \
67 |     --accumulate_score \
68 |     --view_expansion \
69 |     --merge_layers 4 \
70 |     --top_k 0.4
71 | ```
72 | Expected result:
73 | | World            | R@1    | R@2    | R@4    | R@8    | R@16   | R@32   | R@50   | R@64   |
74 | |------------------|--------|--------|--------|--------|--------|--------|--------|--------|
75 | | forgotten_realms | 0.6175 | 0.7867 | 0.8733 | 0.9150 | 0.9375 | 0.9600 | 0.9675 | 0.9708 |
76 | | lego             | 0.5046 | 0.6889 | 0.7882 | 0.8449 | 0.8882 | 0.9183 | 0.9324 | 0.9374 |
77 | | star_trek        | 0.4810 | 0.6253 | 0.7121 | 0.7783 | 0.8271 | 0.8706 | 0.8935 | 0.9030 |
78 | | yugioh           | 0.3444 | 0.5027 | 0.6322 | 0.7300 | 0.7902 | 0.8429 | 0.8690 | 0.8826 |
79 | | total            | 0.4541 | 0.6109 | 0.7136 | 0.7864 | 0.8352 | 0.8777 | 0.8988 | 0.9084 |
80 | 
81 | Optional Arguments:
82 | * --data_parallel: whether to use multiple GPUs.
83 | * --accumulate_score: accumulate the scores of all views for each entity. This yields better results but takes considerably more time at inference.
84 | * --view_expansion: whether to merge and expand views.
85 | * --top_k: the proportion of view pairs to merge in each layer.
86 | * --merge_layers: the number of layers for merging.
87 | * --test_mode: if you want to generate candidates for the train/dev set, change test_mode to train or dev. The generated candidates are saved under the directory that contains the test model.
88 | 
89 | ### 4. How to train your MuVER
90 | We provide the code to train your own MuVER. Start training with the following command:
91 | ```
92 | export PYTHONPATH='.'
93 | CUDA_VISIBLE_DEVICES=YOUR_GPU_DEVICES python muver/multi_view/train.py \
94 |     --pretrained_model path_to_model/bert-base \
95 |     --dataset_path path_to_dataset/zeshel \
96 |     --epoch 30 \
97 |     --train_batch_size 128 \
98 |     --learning_rate 1e-5 \
99 |     --do_train --do_eval \
100 |     --data_parallel \
101 |     --name distributed_multi_view
102 | ```
103 | **Important**: Since contrastive learning relies heavily on a large batch size, we use eight V100 (16GB) GPUs to train our model, as reported in our paper. The hyperparameters for our best model are listed in `logs/zeshel_hyper_param.txt`.
104 | 
105 | The code will create a directory `runtime_log` to save the logs, the models and the hyperparameters you used. Every time you train your model (with or without grid search), a directory is created under `runtime_log/name_in_your_args/start_time`, e.g., `runtime_log/distributed_multi_view/2021-09-07-15-12-21`, to store all the checkpoints, the curves for visualization and the training log.
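
To reuse a trained checkpoint outside `train.py`, e.g., in your own evaluation script, it can be loaded as sketched below. This is a minimal, hypothetical helper (`load_biencoder` is not part of this repo); it mirrors the checkpoint handling in `muver/multi_view/train.py`, where parameter names may carry a `module.` prefix if the model was saved under `DistributedDataParallel`:
```python
import torch
from muver.multi_view.model import BiEncoder

def load_biencoder(pretrained_model, ckpt_path):
    # Hypothetical helper: build the bi-encoder and load a trained checkpoint.
    model = BiEncoder(pretrained_model)
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # Checkpoints saved under DistributedDataParallel prefix every parameter with 'module.'.
    state_dict = {k[7:] if k.startswith('module.') else k: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model.eval()

model = load_biencoder('path_to_model/bert-base', 'path_to_model/best_zeshel.bin')
```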
106 | 
107 | 
108 | 
109 | 
--------------------------------------------------------------------------------
/logs/zeshel_hyper_param.txt:
--------------------------------------------------------------------------------
1 | - Training Parameters: 
2 |   - pretrained_model: models/bert-base/
3 |   - dataset_path: data/zeshel
4 |   - cross_ckpt_path: None
5 |   - bi_ckpt_path: None
6 |   - epoch: 30
7 |   - train_batch_size: 128
8 |   - max_cand_len: 40
9 |   - max_seq_len: 128
10 |   - max_sentence_num: 10
11 |   - learning_rate: 1e-05
12 |   - weight_decay: 0.01
13 |   - warmup_ratio: 0.1
14 |   - max_grad_norm: 1.0
15 |   - gradient_accumulation: 1
16 |   - eval_batch_size: 8
17 |   - logging_interval: 50
18 |   - eval_interval: 500
19 |   - do_train: True
20 |   - do_eval: True
21 |   - do_test: True
22 |   - debug: True
23 |   - data_parallel: True
24 |   - no_cuda: False
25 |   - seed: 10000
26 |   - name: distributed_multi_view
27 |   - n_gpu: 8
28 |   - local_rank: 0
29 |   - device: cuda
30 | 
--------------------------------------------------------------------------------
/muver/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alibaba-NLP/MuVER/cbc7d7f63f4630d66cbcfc8d83a6f609bebe1329/muver/__init__.py
--------------------------------------------------------------------------------
/muver/multi_view/data_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @File    :   data_loader.py
5 | @Time    :   2021/04/07 16:13:52
6 | @Author  :   Xinyin Ma
7 | @Version :   0.1
8 | @Contact :   maxinyin@zju.edu.cn
9 | '''
10 | 
11 | import os
12 | import json
13 | import random
14 | from tqdm import tqdm
15 | import numpy as np
16 | 
17 | import nltk
18 | import torch
19 | from torch.utils.data import Dataset, Sampler, DataLoader
20 | from torch.utils.data.distributed import DistributedSampler
21 | 
22 | from muver.utils.params import WORLDS
23 | 
24 | CLS_TAG = "[CLS]"
25 | SEP_TAG = "[SEP]"
26 | MENTION_START_TAG = "[unused0]"
27 | MENTION_END_TAG = "[unused1]"
28 | ENTITY_TAG = "[unused2]"
29 | 
30 | class SubworldBatchSampler(Sampler):
31 |     def __init__(self, batch_size, subworld_idx):
32 |         self.batch_size = batch_size
33 |         self.subworld_idx = subworld_idx
34 | 
35 |     def __iter__(self):
36 |         for world_name, world_value in self.subworld_idx.items():
37 |             world_value['perm_idx'] = torch.randperm(len(world_value['idx']))
38 |             world_value['pointer'] = 0
39 |         world_names = list(self.subworld_idx.keys())
40 | 
41 |         while len(world_names) > 0:
42 |             world_name = np.random.choice(world_names)
43 |             world_value = self.subworld_idx[world_name]
44 |             start_pointer = world_value['pointer']
45 |             sample_perm_idx = world_value['perm_idx'][start_pointer:start_pointer + self.batch_size]
46 |             sample_idx = [world_value['idx'][idx] for idx in sample_perm_idx]
47 | 
48 |             if len(sample_idx) > 0:
49 |                 yield sample_idx
50 | 
51 |             if len(sample_idx) < self.batch_size:
52 |                 world_names.remove(world_name)
53 |             world_value['pointer'] += self.batch_size
54 | 
55 |     def __len__(self):
56 |         return sum([len(value['idx']) // self.batch_size + 1 for _, value in self.subworld_idx.items()])
57 | 
58 | class SubWorldDistributedSampler(DistributedSampler):
59 |     def __init__(self, batch_size, subworld_idx, num_replicas, rank):
60 |         self.batch_size = batch_size
61 |         self.subworld_idx = subworld_idx
62 |         self.num_replicas = num_replicas
63 |         self.rank = rank
64 | 
65 |         self.epoch = 0
66 | 
67 |     def __iter__(self):
68 |         g = torch.Generator()
69 |         g.manual_seed(self.epoch)
70 | 
71 |         for world_name, world_value in self.subworld_idx.items():
72 |             world_value['perm_idx'] = torch.randperm(len(world_value['idx']), generator=g).tolist()
73 |             world_value['pointer'] = 0
74 |         world_names = list(self.subworld_idx.keys())
75 | 
76 |         while len(world_names) > 0:
77 |             world_idx = torch.randint(len(world_names), size=(1, ), generator=g).tolist()[0]
78 |             world_name = world_names[world_idx]
79 | 
80 |             world_value = self.subworld_idx[world_name]
81 |             start_pointer = world_value['pointer']
82 |             sample_perm_idx = world_value['perm_idx'][start_pointer : start_pointer + self.batch_size]
83 | 
84 |             if len(sample_perm_idx) == 0:
85 |                 world_names.remove(world_name)
86 |                 continue
87 | 
88 |             if len(sample_perm_idx) < self.batch_size:
89 |                 world_names.remove(world_name)
90 |                 sample_perm_idx = sample_perm_idx + world_value['perm_idx'][:self.batch_size - len(sample_perm_idx)]
91 |             #print(self.rank, sample_perm_idx)
92 |             sample_perm_idx = sample_perm_idx[self.rank::self.num_replicas]
93 | 
94 |             sample_idx = [world_value['idx'][idx] for idx in sample_perm_idx]
95 |             # Every replica should receive an equal share of the world-level batch.
96 |             if len(sample_idx) != self.batch_size // self.num_replicas:
97 |                 print(world_name, sample_perm_idx, sample_idx, len(world_value['idx']))
98 | 
99 |             yield sample_idx
100 |             world_value['pointer'] += self.batch_size
101 | 
102 |         self.epoch += 1
103 | 
104 |     #def __len__(self):
105 |     #    return sum([len(value['idx']) // self.batch_size + 1 for _, value in self.subworld_idx.items()])
106 | 
107 | class EncodeDataset(Dataset):
108 |     def __init__(self, document_path, world, tokenizer, max_seq_len, max_sentence_num, all_sentences = False):
109 |         self.tokenizer = tokenizer
110 |         self.max_seq_len = max_seq_len
111 |         self.max_sentence_num = max_sentence_num
112 |         self.all_sentences = all_sentences
113 |         self.world = world
114 | 
115 |         self.seq_lens = {}
116 | 
117 |         preprocess_path = os.path.join(document_path, 'preprocess_multiview')
118 | 
119 |         if os.path.exists(preprocess_path) and os.path.exists(os.path.join(preprocess_path, world + '.pt')):
120 |             self.samples, self.entity_title_to_id = torch.load(os.path.join(preprocess_path, world + '.pt'))
121 |             print("World/{}: Load {} samples".format(world, len(self.samples)))
122 |         else:
123 |             if not os.path.exists(preprocess_path):
124 |                 os.mkdir(preprocess_path)
125 | 
126 |             document_path = os.path.join(document_path, world + '.json')
127 |             self.samples, self.entity_title_to_id = self.load_entity_description(document_path, tokenizer, world)
128 |             torch.save([self.samples, self.entity_title_to_id], os.path.join(preprocess_path, world + '.pt'))
129 | 
130 |     def __len__(self):
131 |         return len(self.samples)
132 | 
133 |     def get_nth_title(self, idx):
134 |         return self.samples[idx]['title']
135 | 
136 |     def load_entity_description(self, document_path, tokenizer, world):
137 |         entity_desc = []
138 |         entity_title_to_id = {}
139 | 
140 |         num_lines = sum(1 for line in open(document_path, 'r'))
141 |         print("World/{}: preprocessing {} samples".format(world, num_lines))
142 | 
143 |         sentence_nums = {}
144 |         with open(document_path, 'r') as f:
145 |             for idx, line in enumerate(tqdm(f, total=num_lines)):
146 |                 info = json.loads(line)
147 |                 token_ids = self.tokenize_split_description(info['title'], info['text'], tokenizer)
148 |                 entity_desc.append({
149 |                     "token_ids": token_ids,
150 |                     "title": info['title']
151 |                 })
152 |                 num = sentence_nums.get(len(token_ids), 0)
153 |                 sentence_nums[len(token_ids)] = num + 1
154 |                 entity_title_to_id[info['title']] = idx
155 | 
156 | 
#print(sorted(sentence_nums.items(), key = lambda x: x[0])) 157 | return entity_desc, entity_title_to_id 158 | 159 | def tokenize_description(self, title, desc, tokenizer): 160 | encode_text = [CLS_TAG] + tokenizer.tokenize(title) + [ENTITY_TAG] + tokenizer.tokenize(desc) 161 | encode_text = encode_text[:self.max_cand_len - 1] + [SEP_TAG] 162 | 163 | tokens = tokenizer.convert_tokens_to_ids(encode_text) 164 | if len(tokens) < self.max_cand_len: 165 | tokens += [0] * (self.max_cand_len - len(tokens)) 166 | 167 | assert(len(tokens) == self.max_cand_len) 168 | return tokens 169 | 170 | def tokenize_split_description(self, title, desc, tokenizer): 171 | #if not is_split_by_sentence: 172 | # encode_text = [CLS_TAG] + tokenizer.tokenize(title) + [ENTITY_TAG] + tokenizer.tokenize(desc) 173 | # encode_text = encode_text[:self.max_cand_len - 1] + [SEP_TAG] 174 | #else: 175 | title_text = tokenizer.tokenize(title) + [ENTITY_TAG] 176 | 177 | multi_sent = [] 178 | pre_text = [] 179 | for sent in nltk.sent_tokenize(desc.replace(' .', '.')): 180 | text = tokenizer.tokenize(sent) 181 | pre_text += text 182 | if len(pre_text) <= 5: 183 | continue 184 | whole_text = title_text + pre_text 185 | whole_text = [CLS_TAG] + whole_text[:self.max_seq_len - 2] + [SEP_TAG] 186 | tokens = tokenizer.convert_tokens_to_ids(whole_text) 187 | pre_text = [] 188 | 189 | if len(tokens) < self.max_seq_len: 190 | tokens += [0] * (self.max_seq_len - len(tokens)) 191 | assert len(tokens) == self.max_seq_len 192 | multi_sent.append(tokens) 193 | 194 | return multi_sent 195 | 196 | def __getitem__(self, idx): 197 | if self.all_sentences: 198 | entity_ids = self.samples[idx]['token_ids'] 199 | else: 200 | entity_ids = self.samples[idx]['token_ids'][:self.max_sentence_num] 201 | if len(entity_ids) <= self.max_sentence_num: 202 | entity_ids += [[0] * self.max_seq_len for _ in range(self.max_sentence_num - len(entity_ids))] 203 | 204 | assert len(entity_ids) == self.max_sentence_num 205 | ''' 206 | if len(self.samples[idx]['token_ids']) <= self.max_sentence_num: 207 | entity_ids = self.samples[idx]['token_ids'] 208 | else: 209 | #sentence_idx = np.random.choice(len(self.samples[idx]['token_ids']), self.max_sentence_num) 210 | entity_ids = [] 211 | sentence_idx = [] 212 | for _ in range(self.max_sentence_num): 213 | s_idx = np.random.randint(len(self.samples[idx]['token_ids'])) 214 | while s_idx in sentence_idx: 215 | s_idx = np.random.randint(len(self.samples[idx]['token_ids'])) 216 | sentence_idx.append(s_idx) 217 | entity_ids.append(self.samples[idx]['token_ids'][s_idx]) 218 | #print("random_select_sentence: ", self.samples[idx]['title'], sentence_idx, len(self.samples[idx]['token_ids'])) 219 | if len(entity_ids) < self.max_sentence_num: 220 | entity_ids += [[0] * self.max_seq_len for _ in range(self.max_sentence_num - len(entity_ids))] 221 | ''' 222 | assert len(entity_ids) == self.max_sentence_num 223 | return { 224 | 'token_ids': entity_ids, 225 | 'title': self.samples[idx]['title'], 226 | 'title_ids': idx 227 | } 228 | 229 | def bi_collate_fn(batch): 230 | token_ids = torch.tensor([sample['token_ids'] for sample in batch]) # sentence_num * max_seq_len 231 | title = [sample['title'] for sample in batch] 232 | title_ids = torch.tensor([sample['title_ids'] for sample in batch]) 233 | return { 234 | 'token_ids': token_ids, 235 | 'title': title, 236 | 'title_ids': title_ids 237 | } 238 | 239 | class ZeshelDataset(Dataset): 240 | def __init__(self, 241 | mode, desc_path, dataset_path, tokenizer, 242 | max_cand_len = 30, max_seq_len = 
128, max_sentence_num = 10, all_sentences = False, 243 | ): 244 | self.tokenizer = tokenizer 245 | self.mode = mode 246 | self.max_cand_len = max_cand_len 247 | self.max_sentence_num = max_sentence_num 248 | self.max_seq_len = max_seq_len 249 | self.all_sentences = all_sentences 250 | 251 | self.entity_desc = { 252 | world[0]: EncodeDataset( 253 | document_path = desc_path, 254 | world = world[0], 255 | tokenizer = tokenizer, 256 | max_seq_len = self.max_cand_len, 257 | max_sentence_num = self.max_sentence_num, 258 | all_sentences = self.all_sentences 259 | ) 260 | for world in WORLDS[mode] 261 | } 262 | 263 | self.load_training_samples(dataset_path, mode, max_seq_len) 264 | self.subworld_idx = self.get_subworld_idx() 265 | 266 | def get_subworld_idx(self): 267 | worlds_sample_idx = {world[0]: {'idx': [], 'num': 0} for world in WORLDS[self.mode]} 268 | for idx, sample in enumerate(self.samples): 269 | world = sample['world'] 270 | worlds_sample_idx[world]['idx'].append(idx) 271 | worlds_sample_idx[world]['num'] += 1 272 | 273 | return worlds_sample_idx 274 | 275 | def load_training_samples(self, dataset_path, mode, max_seq_len): 276 | token_path = os.path.join(dataset_path, "{}_token.jsonl".format(mode)) 277 | if os.path.exists(token_path): 278 | with open(token_path, 'r') as f: 279 | self.samples = [json.loads(line) for line in f] 280 | print("Set/{}: Load {} samples".format(mode, len(self.samples))) 281 | else: 282 | data_path = os.path.join(dataset_path, "{}.jsonl".format(mode)) 283 | num_lines = sum(1 for line in open(data_path, 'r')) 284 | 285 | self.samples = [] 286 | print("Set/{}: preprocessing {} samples".format(mode, num_lines)) 287 | 288 | with open(data_path, 'r') as sample_f: 289 | for sample_line in tqdm(sample_f, total = num_lines): 290 | sample = self.tokenize_context(json.loads(sample_line), max_seq_len) 291 | self.samples.append(sample) 292 | 293 | with open(token_path, 'w') as f: 294 | for sample in self.samples: 295 | f.write(json.dumps(sample) + '\n') 296 | 297 | def __len__(self): 298 | return len(self.samples) 299 | 300 | def __getitem__(self, idx): 301 | sample = self.samples[idx] 302 | context_ids = sample['ids'] 303 | label, world = sample['label'], sample['world'] 304 | 305 | if self.all_sentences: 306 | return { 307 | "label_ids": [-1], 308 | "context_ids": context_ids, 309 | "world": world, 310 | "label_world_idx": label 311 | } 312 | else: 313 | return { 314 | "label_ids": self.entity_desc[world][label]['token_ids'], 315 | "context_ids": context_ids, 316 | "world": world, 317 | "label_world_idx": label 318 | } 319 | 320 | def concat_context_entity_ids(self, context_ids, candidate_idx, world): 321 | if 0 in context_ids: 322 | context_ids = context_ids[:context_ids.index(0)] 323 | 324 | entity_token_ids = self.entity_desc[world][candidate_idx]['token_ids'] 325 | input_ids = context_ids + entity_token_ids[1:] 326 | padding = [0] * (self.max_cand_len + self.max_seq_len - len(input_ids)) 327 | input_ids += padding 328 | assert len(input_ids) == self.max_cand_len + self.max_seq_len 329 | 330 | return input_ids 331 | 332 | def tokenize_context( 333 | self, 334 | sample, 335 | max_seq_len 336 | ): 337 | ''' 338 | https://github.com/facebookresearch/BLINK/blob/master/blink/biencoder/data_process.py 339 | ''' 340 | mention_tokens = [] 341 | if sample['mention'] and len(sample['mention']) > 0: 342 | mention_tokens = self.tokenizer.tokenize(sample['mention']) 343 | mention_tokens = [MENTION_START_TAG] + mention_tokens + [MENTION_END_TAG] 344 | 345 | context_left = 
sample["context_left"]
346 |         context_right = sample["context_right"]
347 |         context_left = self.tokenizer.tokenize(context_left)
348 |         context_right = self.tokenizer.tokenize(context_right)
349 | 
350 |         left_quota = (max_seq_len - len(mention_tokens)) // 2 - 1
351 |         right_quota = max_seq_len - len(mention_tokens) - left_quota - 2
352 |         left_add = len(context_left)
353 |         right_add = len(context_right)
354 |         if left_add <= left_quota:
355 |             if right_add > right_quota:
356 |                 right_quota += left_quota - left_add
357 |         else:
358 |             if right_add <= right_quota:
359 |                 left_quota += right_quota - right_add
360 | 
361 |         context_tokens = (
362 |             context_left[-left_quota:] + mention_tokens + context_right[:right_quota]
363 |         )
364 | 
365 |         context_tokens = ["[CLS]"] + context_tokens + ["[SEP]"]
366 |         input_ids = self.tokenizer.convert_tokens_to_ids(context_tokens)
367 |         padding = [0] * (max_seq_len - len(input_ids))
368 |         input_ids += padding
369 |         assert len(input_ids) == max_seq_len
370 | 
371 |         return {
372 |             "tokens": context_tokens,
373 |             "ids": input_ids,
374 |             "label": sample['label_id'],
375 |             "world": sample['world'],
376 |         }
377 | 
378 | def cross_collate_fn(batch):
379 |     world = [sample['world'] for sample in batch]
380 |     label_world_idx = torch.tensor([sample['label_world_idx'] for sample in batch])
381 |     label_ids = torch.tensor([sample['label_ids'] for sample in batch])
382 |     context_ids = torch.tensor([sample['context_ids'] for sample in batch])
383 |     #label_split = torch.tensor([sample['label_split'] for sample in batch])
384 | 
385 |     return {
386 |         'context_ids': context_ids,
387 |         'label_ids': label_ids,
388 |         'world': world,
389 |         'label_world_idx': label_world_idx
390 |     }
391 | 
392 | 
393 | 
--------------------------------------------------------------------------------
/muver/multi_view/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | '''
4 | @File    :   model.py
5 | @Time    :   2021/03/16 15:32:52
6 | @Author  :   Xinyin Ma
7 | @Version :   0.1
8 | @Contact :   maxinyin@zju.edu.cn
9 | '''
10 | import os
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | 
15 | from muver.utils.multigpu import GatherLayer
16 | 
17 | from transformers import BertModel, BertConfig
18 | 
19 | class CrossEncoder(nn.Module):
20 |     def __init__(self, pretrained_model):
21 |         super(CrossEncoder, self).__init__()
22 |         self.cross_encoder = BertModel.from_pretrained(pretrained_model, output_attentions=True, return_dict=True)
23 |         self.additional_linear = nn.Linear(self.cross_encoder.config.hidden_size, 1)
24 |         self.dropout = nn.Dropout(0.1)
25 | 
26 |     def to_bert_input(self, input_ids):
27 |         attention_mask = 1 - (input_ids == 0).long()
28 |         token_type_ids = torch.zeros_like(input_ids).long()
29 |         return {
30 |             'input_ids': input_ids,
31 |             'attention_mask': attention_mask,
32 |             'token_type_ids': token_type_ids
33 |         }
34 | 
35 |     def forward(self, ctx_ids, split_len = None, ctx_mask = None, target = None, is_return_score = False):
36 |         '''
37 |         - ctx_ids: batch_size/top_k * seq_len
38 |         - split_len: batch_size/top_k * [split_num]
39 | 
40 |         Only support for batch_size = 1 / top_k = 1
41 |         '''
42 | 
43 |         batch_size = 1
44 |         top_k, seq_len = ctx_ids.size()
45 | 
46 |         ctx_repr = self.cross_encoder(**self.to_bert_input(ctx_ids)).last_hidden_state # (batch_size/top_k) * seq_len * hidden_state
47 |         # NOTE: the per-view score computation is missing from this file. The next line is a minimal reconstruction (an assumption, not the released logic): each candidate is scored from its [CLS] vector, so the last dimension is 1 instead of the original 9; split_len is kept only for API compatibility with zeshel_evaluate.py.
48 |         scores = [self.additional_linear(self.dropout(ctx_repr[k, 0, :])).view(1, 1) for k in range(top_k)]
49 |         scores = torch.cat(scores, 0).unsqueeze(0) # 1 * top_k * 1 (originally: 1 * top_k * 9)
50 |         max_score = torch.max(scores, -1).values
51 |         if target is not None:
52 | 
53 |             loss = F.cross_entropy(max_score, target, reduction="mean")
54 |             predict = torch.max(max_score, -1).indices
55 |             acc = sum(predict == target) * 1.0
56 |             return loss, acc
57 |         else:
58 |             return max_score, scores
59 | 
60 | 
61 | class BiEncoder(nn.Module):
62 |     def __init__(self, pretrained_model):
63 |         super(BiEncoder, self).__init__()
64 |         self.ctx_encoder = BertModel.from_pretrained(pretrained_model, output_attentions = True, return_dict = True)
65 |         self.ent_encoder = BertModel.from_pretrained(pretrained_model, output_attentions = True, return_dict = True)
66 | 
67 |     def to_bert_input(self, input_ids):
68 |         attention_mask = 1 - (input_ids == 0).long()
69 |         token_type_ids = torch.zeros_like(input_ids).long()
70 |         return {
71 |             'input_ids': input_ids,
72 |             'attention_mask': attention_mask,
73 |             'token_type_ids': token_type_ids
74 |         }
75 | 
76 |     def encode_candidates(self, ent_ids, ent_mask = None, interval=500, mode='train', view_expansion=False, top_k = 0.4, merge_layers = 3):
77 |         sentence_num = 0
78 |         if len(ent_ids.size()) == 1: # only one sentence
79 |             ent_ids = ent_ids.unsqueeze(0)
80 |             batch_size, sentence_num = 1, 1
81 |         elif len(ent_ids.size()) == 3:
82 |             batch_size, sentence_num, ent_seq_len = ent_ids.size()
83 |             ent_ids = ent_ids.view(-1, ent_seq_len)
84 |         else:
85 |             batch_size, ent_seq_len = ent_ids.size()
86 |             sentence_num = 1
87 | 
88 |         start_ids = 0
89 |         candidate_output = []
90 |         while start_ids < sentence_num * batch_size:
91 |             model_output = self.ent_encoder(**self.to_bert_input(ent_ids[start_ids:start_ids + interval]))
92 |             candidate_output.append(model_output.last_hidden_state[:, 0, :]) # the [CLS] vector of each view
93 |             start_ids += interval
94 |         candidate_output = torch.cat(candidate_output, 0).view(batch_size, sentence_num, -1)
95 | 
96 |         if view_expansion:
97 |             ori_views = candidate_output
98 |             ori_ent_ids = ent_ids.view(batch_size, sentence_num, -1).tolist()
99 |             new_pools = []
100 | 
101 |             def merge_sequence(seq, ent_ids):
102 |                 # Concatenate the token ids of the views in seq into one merged view.
103 |                 # Vocabulary ids: 102 = [SEP], 3 = [unused2] (ENTITY_TAG, the title/description separator), 0 = [PAD].
104 |                 s = ent_ids[seq[0]]
105 |                 sentence = s[:s.index(102)]
106 |                 for i in range(1, len(seq)):
107 |                     s = ent_ids[seq[i]]
108 |                     mid_sentence = s[s.index(3)+1:s.index(102)]
109 |                     if 0 in mid_sentence:
110 |                         mid_sentence = mid_sentence[:mid_sentence.index(0)]
111 |                     sentence += mid_sentence
112 |                 if len(sentence) > 511:
113 |                     sentence = sentence[:511]
114 |                 sentence += [102]
115 |                 return sentence
116 |             def batch_sentences(sentences):
117 |                 max_len = max([len(s) for s in sentences])
118 |                 s_tensor = torch.zeros((len(sentences), max_len), dtype=torch.int64)
119 |                 for i, s in enumerate(sentences):
120 |                     s_tensor[i, :len(s)] = torch.tensor(s)
121 |                 return s_tensor
122 | 
123 |             top_k = [top_k for _ in range(merge_layers)]
124 |             target_sentence_num = int(sum(top_k) * sentence_num)
125 |             for ori_view, ori_ent_id in zip(ori_views, ori_ent_ids):
126 |                 new_pool = []
127 |                 views, seq_ids = ori_view, [[i] for i in range(len(ori_ent_id))]
128 | 
129 |                 for layer in range(len(top_k)):
130 |                     new_views, new_seq_ids = [], []
131 |                     dis = torch.sum(F.mse_loss(views[:-1], views[1:], reduction='none'), -1) # distance between adjacent views
132 |                     merge_ids = torch.sort(dis, descending=True).indices[: int(views.size(0) * top_k[layer])].tolist()
133 | 
134 |                     new_sentences = []
135 |                     for i in range(len(seq_ids)):
136 |                         if i in merge_ids and ori_ent_id[i][0] != 0:
137 |                             seq_ids_a, seq_ids_b = seq_ids[i], seq_ids[i+1]
138 |                             seq_ids_merge = [] + seq_ids_a
139 |                             for ids, ids_b in enumerate(seq_ids_b):
140 |                                 if ids_b > seq_ids_a[-1] and ori_ent_id[ids_b][0] != 0:
141 |                                     seq_ids_merge += 
[ids_b] 142 | 143 | if len(seq_ids_merge) != len(seq_ids_a): 144 | new_sentence = merge_sequence(seq_ids_merge, ori_ent_id) 145 | new_seq_ids.append(seq_ids_merge) 146 | 147 | new_sentences.append(new_sentence) 148 | #new_repr = self.ent_encoder(**self.to_bert_input(new_sentence.unsqueeze(0).cuda())).last_hidden_state[:, 0, :] 149 | #new_pool.append(new_repr) 150 | new_views.append(None) 151 | else: 152 | new_seq_ids.append(seq_ids[i]) 153 | new_views.append(views[i]) 154 | else: 155 | new_seq_ids.append(seq_ids[i]) 156 | new_views.append(views[i]) 157 | 158 | if len(new_sentences) > 0: 159 | new_sentences = batch_sentences(new_sentences) 160 | pool = [] 161 | start_ids = 0 162 | while start_ids < new_sentences.size(0): 163 | new_repr = self.ent_encoder(**self.to_bert_input(new_sentences[start_ids:start_ids+50].cuda())).last_hidden_state[:, 0, :] 164 | pool += new_repr 165 | start_ids += 50 166 | new_repr = pool 167 | new_pool += new_repr 168 | view_idx = 0 169 | for i, new_view in enumerate(new_views): 170 | if new_view is None: 171 | new_views[i] = new_repr[view_idx] 172 | view_idx += 1 173 | assert view_idx == len(new_repr) 174 | views, seq_ids = torch.stack(new_views, 0), new_seq_ids 175 | 176 | if mode == 'train': 177 | if len(new_pool) == 0: 178 | new_pool = ori_view[-1].unsqueeze(0).repeat(target_sentence_num, 1) 179 | else: 180 | new_pool = torch.stack(new_pool, 0) 181 | if len(new_pool) < target_sentence_num: 182 | new_pool = torch.cat([new_pool, ori_view[-1].unsqueeze(0).repeat(target_sentence_num - len(new_pool), 1)], 0) 183 | 184 | if mode == 'train': 185 | new_pools.append(new_pool) 186 | else: 187 | new_pools += new_pool 188 | if mode == 'train': 189 | new_pools = torch.stack(new_pools, 0) 190 | #print(new_pools.shape, ori_views.shape) 191 | candidate_output = torch.cat([ori_views, new_pools], 1) 192 | else: 193 | if len(new_pools) > 0: 194 | new_pools = torch.stack(new_pools, 0) 195 | candidate_output = torch.cat([candidate_output.squeeze(0), new_pools], 0).unsqueeze(0) 196 | return candidate_output 197 | 198 | def encode_context(self, ctx_ids, ctx_mask = None): 199 | if len(ctx_ids.size()) == 1: # only one sentence 200 | ctx_ids = ctx_ids.unsqueeze(0) 201 | model_output = self.ctx_encoder(**self.to_bert_input(ctx_ids)) 202 | context_output = model_output.last_hidden_state 203 | if ctx_mask is None: 204 | context_output = context_output[:, 0, :] 205 | else: 206 | context_output = torch.bmm(ctx_mask.unsqueeze(1), context_output).squeeze(1) 207 | return context_output 208 | 209 | def score_candidates(self, ctx_ids, ctx_world, ctx_mask = None, candidate_pool = None): 210 | # candidate_pool: (entity_num * 9) * hidden_state 211 | ctx_output = self.encode_context(ctx_ids, ctx_mask).cpu().detach() 212 | res = [] 213 | for world, ctx_repr in zip(ctx_world, ctx_output): 214 | ctx_repr = ctx_repr.to(candidate_pool[world].device) 215 | res.append(ctx_repr.unsqueeze(0).mm(candidate_pool[world].T).squeeze(0)) 216 | return res 217 | 218 | def forward(self, ctx_ids, ent_ids, num_gpus = 0): 219 | ''' 220 | if num_gpus > 1: 221 | ctx_ids = ctx_ids.to("cuda:0") 222 | ent_gpus = num_gpus - 1 223 | per_gpu_batch = ent_ids // ent_gpus 224 | 225 | ent_ids = ent_ids.to("cuda:1") 226 | elif num_gpus == 1: 227 | ctx_ids = ctx_ids.cuda() 228 | ent_ids = ent_ids.cuda() 229 | ''' 230 | batch_size, sentence_num, ent_seq_len = ent_ids.size() 231 | 232 | ctx_output = self.encode_context(ctx_ids) # batch_size * hidden_size 233 | ent_output = self.encode_candidates(ent_ids, view_expansion=False) # 
(batch_size * sentence_num) * hidden_size 234 | return ctx_output.contiguous(), ent_output.contiguous() 235 | 236 | class NCE_Random(nn.Module): 237 | def __init__(self, num_gpus): 238 | super(NCE_Random, self).__init__() 239 | self.num_gpus = num_gpus 240 | 241 | def forward(self, ctx_output, ent_output): 242 | batch_size = ctx_output.size(0) 243 | sentence_num = ent_output.size(0) // batch_size 244 | 245 | if self.num_gpus > 1: 246 | ctx_output = torch.cat(GatherLayer.apply(ctx_output), dim=0) 247 | ent_output = torch.cat(GatherLayer.apply(ent_output), dim=0) 248 | 249 | #score = ent_output.mm(ctx_output.T).view(batch_size * self.num_gpus, sentence_num, batch_size * self.num_gpus) #(batch_size * sentence_num) * batch_size(context) 250 | score = torch.matmul(ent_output, ctx_output.T) 251 | score = score.permute(2, 0, 1) # batch_size(context) * batch_size(entity) * sentence_num 252 | max_score = torch.max(score, -1).values 253 | target = torch.arange(score.size(0)).to(ctx_output.device) 254 | loss = F.cross_entropy(max_score, target, reduction="mean") 255 | 256 | predict = torch.max(max_score, -1).indices 257 | acc = sum(predict == target) * 1.0 / score.size(0) 258 | 259 | return loss, acc, score 260 | 261 | 262 | ''' 263 | ent_diff_output = ent_output.clone().detach().to("cuda:0") 264 | score = ent_diff_output.mm(ctx_output.T).view(batch_size, sentence_num, batch_size) #(batch_size * sentence_num) * batch_size(context) 265 | score = score.permute(2, 0, 1) # batch_size(context) * batch_size(entity) * sentence_num 266 | 267 | max_score = torch.max(score, -1).values 268 | target = torch.arange(score.size(0)).to(ctx_ids.device) 269 | context_loss = F.cross_entropy(max_score, target, reduction="mean") 270 | 271 | # on entity device 272 | ctx_diff_output = ctx_output.clone().detach().to("cuda:1") 273 | score = ent_output.mm(ctx_diff_output.T).view(batch_size, sentence_num, batch_size) #(batch_size * sentence_num) * batch_size(context) 274 | score = score.permute(2, 0, 1) # batch_size(context) * batch_size(entity) * sentence_num 275 | ''' 276 | 277 | -------------------------------------------------------------------------------- /muver/multi_view/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : train.py 5 | @Time : 2021/09/08 10:36:15 6 | @Author : Xinyin Ma 7 | @Version : 1.0 8 | @Contact : maxinyin@zju.edu.cn 9 | ''' 10 | 11 | import os 12 | import time 13 | import json 14 | import random 15 | import argparse 16 | from tqdm import tqdm 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | import torch.distributed as dist 22 | import torch.multiprocessing as mp 23 | from torch.nn.parallel import DataParallel 24 | from torch.nn.parallel import DistributedDataParallel as DDP 25 | from torch.utils.data import DataLoader, RandomSampler 26 | 27 | from transformers import BertTokenizerFast, AdamW, get_linear_schedule_with_warmup 28 | 29 | from muver.utils.logger import LoggerWithDepth 30 | from muver.utils.tools import grid_search_hyperparamters, set_random_seed 31 | 32 | from data_loader import EncodeDataset, ZeshelDataset, cross_collate_fn, SubworldBatchSampler, SubWorldDistributedSampler 33 | from model import BiEncoder, NCE_Random 34 | from zeshel_evaluate import evaluate_bi_model 35 | 36 | 37 | 38 | def argument_parser(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('--pretrained_model', type=str, required=True) 41 | 
parser.add_argument('--dataset_path', type=str, default='data/zeshel') 42 | parser.add_argument('--bi_ckpt_path', type=str, nargs='+', default=None) 43 | 44 | parser.add_argument('--epoch', type=int, default=5) 45 | parser.add_argument('--train_batch_size', type=int, default=12) 46 | parser.add_argument('--max_cand_len', type=int, default=40) 47 | parser.add_argument('--max_seq_len', type=int, default=128) 48 | parser.add_argument('--max_sentence_num', type=int, default=10) 49 | parser.add_argument('--learning_rate', type=float, nargs='+', default=1e-5) 50 | parser.add_argument('--weight_decay', type=float, nargs='+', default=0.01) 51 | parser.add_argument('--warmup_ratio', type=float, nargs='+', default=0.1) 52 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 53 | parser.add_argument('--gradient_accumulation', type=int, default=1) 54 | parser.add_argument('--merge_layers', type=int, default=3) 55 | parser.add_argument('--top_k', type=float, default=0.4) 56 | #parser.add_argument('--alpha', type=float, default=0.5) 57 | #parser.add_argument('--beta', type=float, nargs='+', default=50) 58 | 59 | parser.add_argument('--eval_batch_size', type=int, default=8) 60 | parser.add_argument('--logging_interval', type=int, default=50) 61 | parser.add_argument('--eval_interval', type=int, default=2000) 62 | parser.add_argument('--accumulate_score', action="store_true") 63 | 64 | parser.add_argument('--do_train', action="store_true") 65 | parser.add_argument('--do_eval', action="store_true") 66 | parser.add_argument('--do_test', action="store_true") 67 | parser.add_argument('--view_expansion', action="store_true") 68 | parser.add_argument('--test_mode', type=str, default='test') 69 | 70 | parser.add_argument("--data_parallel", action="store_true") 71 | parser.add_argument("--no_cuda", action="store_true") 72 | parser.add_argument("--seed", type=int, default=10000) 73 | parser.add_argument("--name", type=str, default='test') 74 | return parser.parse_args() 75 | 76 | 77 | def main(local_rank, args, train_dataset, valid_dataset, test_dataset, tokenizer): 78 | args.local_rank = local_rank 79 | if args.do_train and args.local_rank in [0, -1]: 80 | logger = LoggerWithDepth( 81 | env_name=args.name, 82 | config=args.__dict__, 83 | ) 84 | else: 85 | logger = None 86 | 87 | # Set Training Device 88 | if args.data_parallel: 89 | 90 | if args.n_gpu == 1: 91 | args.data_parallel = False 92 | else: 93 | dist.init_process_group("nccl", rank=args.local_rank, world_size=args.n_gpu) 94 | torch.cuda.set_device(args.local_rank) 95 | 96 | args.device = "cuda" if not args.no_cuda else "cpu" 97 | set_random_seed(args.seed) 98 | 99 | grid_arguments = grid_search_hyperparamters(args) 100 | for grid_args in grid_arguments: 101 | if args.do_train and args.local_rank in [0, -1]: 102 | sub_name = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 103 | logger.setup_sublogger(sub_name, grid_args.__dict__) 104 | 105 | # Load Model and Tokenizer 106 | bi_model = BiEncoder(args.pretrained_model).cuda() 107 | 108 | criterion = NCE_Random(args.n_gpu) 109 | 110 | # Load From checkpoint 111 | if grid_args.bi_ckpt_path is not None: 112 | state_dict = torch.load(grid_args.bi_ckpt_path, map_location='cpu') 113 | new_state_dict = {} 114 | for param_name, param_value in state_dict.items(): 115 | if param_name[:7] == 'module.': 116 | new_state_dict[param_name[7:]] = param_value 117 | else: 118 | new_state_dict[param_name] = param_value 119 | bi_model.load_state_dict(new_state_dict) 120 | 121 | if args.n_gpu > 1: 122 | 
bi_model = DDP(bi_model, device_ids=[args.local_rank], find_unused_parameters=True) 123 | 124 | # Load Data 125 | if args.do_train: 126 | train_batch_size = grid_args.train_batch_size // grid_args.gradient_accumulation 127 | 128 | if args.data_parallel: 129 | sampler = SubWorldDistributedSampler(batch_size=grid_args.train_batch_size, subworld_idx=train_dataset.subworld_idx, num_replicas=args.n_gpu, rank=args.local_rank) 130 | else: 131 | sampler = SubworldBatchSampler(batch_size=grid_args.train_batch_size, subworld_idx=train_dataset.subworld_idx) 132 | train_dataloader = DataLoader(dataset=train_dataset, batch_sampler=sampler, collate_fn = cross_collate_fn) 133 | 134 | # optimizer & scheduler 135 | no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] 136 | optimizer_grouped_parameters = [ 137 | {'params': [p for n, p in bi_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': grid_args.weight_decay}, 138 | {'params': [p for n, p in bi_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, 139 | ] 140 | optimizer = AdamW(optimizer_grouped_parameters, lr=grid_args.learning_rate) 141 | 142 | total_steps = len(train_dataset) * args.epoch // train_batch_size 143 | warmup_steps = int(grid_args.warmup_ratio * total_steps) 144 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) 145 | if args.local_rank in [0, -1]: 146 | logger.writer.info("Optimization steps = {}, Warmup steps = {}".format(total_steps, warmup_steps)) 147 | 148 | # Train 149 | if args.do_train: 150 | step, max_score = 0, 0 151 | with tqdm(total = total_steps) as pbar: 152 | for e in range(args.epoch): 153 | tr_loss = [] 154 | for batch in train_dataloader: 155 | bi_model.train() 156 | step += 1 157 | 158 | world = batch['world'] 159 | for w in world[1:]: 160 | assert world[0] == w 161 | 162 | ctx_ids, ent_ids = batch['context_ids'], batch['label_ids'] 163 | 164 | ctx_ids = ctx_ids.cuda(non_blocking=True) 165 | ent_ids = ent_ids.cuda(non_blocking=True) 166 | ctx_output, ent_output = bi_model( 167 | ctx_ids = ctx_ids, 168 | ent_ids = ent_ids, 169 | num_gpus = args.n_gpu 170 | ) 171 | loss, bi_acc, bi_score = criterion(ctx_output, ent_output) 172 | loss.backward() 173 | 174 | #if args.n_gpu > 1: 175 | # dist.all_reduce(loss.div_(args.n_gpu)) 176 | if step % grid_args.gradient_accumulation == 0: 177 | if args.max_grad_norm > 0: 178 | torch.nn.utils.clip_grad_norm_(bi_model.parameters(), args.max_grad_norm) 179 | optimizer.step() 180 | scheduler.step() 181 | bi_model.zero_grad() 182 | 183 | if args.local_rank in [0, -1]: 184 | pbar.set_description("epoch: {}, loss: {}, acc: {}".format( 185 | e + 1, loss.item(), bi_acc.item() 186 | )) 187 | pbar.update() 188 | 189 | tr_loss.append(loss.item()) 190 | if step % args.logging_interval == 0 and args.local_rank in [0, -1]: 191 | logger.writer.info("Step {}: Average Loss = {}".format(step, sum(tr_loss) / len(tr_loss))) 192 | tr_loss = [] 193 | 194 | if step % args.eval_interval == 0 and args.do_eval: 195 | with torch.no_grad(): 196 | score, _ = evaluate_bi_model( 197 | bi_model, tokenizer, valid_dataset, 198 | mode='valid', 199 | device=args.device, 200 | local_rank=args.local_rank, 201 | n_gpu=args.n_gpu) 202 | 203 | if args.local_rank in [0, -1]: 204 | logger.writer.info(score) 205 | torch.save(bi_model.state_dict(), logger.lastest_checkpoint_path) 206 | 207 | if max_score < score: 208 | torch.save(bi_model.state_dict(), logger.checkpoint_path) 209 | max_score = 
score 210 | 211 | 212 | if args.local_rank in [0, -1]: 213 | torch.save(bi_model.state_dict(), os.path.join(logger.sub_dir, 'epoch_{}.bin'.format(e))) 214 | 215 | grid_args.best_evaluation_score = max_score 216 | if args.local_rank in [0, -1]: 217 | logger.write_description_to_folder(os.path.join(logger.sub_dir, 'description.txt'), grid_args.__dict__) 218 | del optimizer 219 | 220 | if args.do_test: 221 | if args.do_train and args.local_rank in [0, -1]: 222 | bi_model.load_state_dict(torch.load(logger.checkpoint_path, map_location='cpu')) 223 | with torch.no_grad(): 224 | score, candidates = evaluate_bi_model( 225 | bi_model, tokenizer, test_dataset, 226 | mode=args.test_mode, 227 | device = args.device, 228 | local_rank=args.local_rank, 229 | n_gpu=args.n_gpu, 230 | encode_batch_size=args.eval_batch_size, 231 | view_expansion = args.view_expansion, 232 | is_accumulate_score = args.accumulate_score, 233 | merge_layers=args.merge_layers, 234 | top_k=args.top_k) 235 | 236 | if args.local_rank in [-1, 0]: 237 | if logger is not None: 238 | result_path = os.path.join(logger.sub_dir, 'score.json') 239 | candidate_path = os.path.join(logger.sub_dir, 'candidates.json') 240 | else: 241 | dir_path = os.path.dirname(os.path.abspath(grid_args.bi_ckpt_path)) 242 | result_path = os.path.join(dir_path, '{}_score.json'.format(args.test_mode)) 243 | candidate_path = os.path.join(dir_path, '{}_candidates.json'.format(args.test_mode)) 244 | 245 | with open(result_path, 'w') as f: 246 | f.write(json.dumps(score)) 247 | 248 | with open(candidate_path, 'w') as f: 249 | for candidate in candidates: 250 | f.write(json.dumps(candidate) + '\n') 251 | 252 | del bi_model 253 | 254 | if __name__ == "__main__": 255 | args = argument_parser() 256 | print(args.__dict__) 257 | 258 | os.environ["MASTER_ADDR"] = "127.0.0.1" 259 | os.environ["MASTER_PORT"] = "18101" 260 | 261 | args.n_gpu = torch.cuda.device_count() 262 | 263 | # before multiprocessing, preprocess the data 264 | train_dataset, valid_dataset, test_dataset = None, None, None 265 | tokenizer = BertTokenizerFast.from_pretrained(args.pretrained_model) 266 | 267 | if args.do_train: 268 | train_dataset = ZeshelDataset( 269 | mode='train', 270 | desc_path=os.path.join(args.dataset_path, 'documents'), 271 | dataset_path=os.path.join(args.dataset_path, 'blink_format'), 272 | tokenizer=tokenizer, 273 | max_cand_len=args.max_cand_len, 274 | max_sentence_num=args.max_sentence_num, 275 | max_seq_len=args.max_seq_len, 276 | ) 277 | 278 | if args.do_eval: 279 | valid_dataset = ZeshelDataset( 280 | mode='valid', 281 | desc_path=os.path.join(args.dataset_path, 'documents'), 282 | dataset_path=os.path.join(args.dataset_path, 'blink_format'), 283 | tokenizer=tokenizer, 284 | max_cand_len=args.max_cand_len, 285 | max_seq_len=args.max_seq_len, 286 | all_sentences = True 287 | ) 288 | 289 | if args.do_test: 290 | test_dataset = ZeshelDataset( 291 | mode=args.test_mode, 292 | desc_path=os.path.join(args.dataset_path, 'documents'), 293 | dataset_path=os.path.join(args.dataset_path, 'blink_format'), 294 | tokenizer=tokenizer, 295 | max_cand_len=args.max_cand_len, 296 | max_seq_len=args.max_seq_len, 297 | all_sentences = True 298 | ) 299 | 300 | mp.spawn(main, args=(args, train_dataset, valid_dataset, test_dataset, tokenizer,), nprocs=args.n_gpu, join=True) 301 | -------------------------------------------------------------------------------- /muver/multi_view/zeshel_evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env 
python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : evaluate.py 5 | @Time : 2021/03/17 23:41:28 6 | @Author : Xinyin Ma 7 | @Version : 0.1 8 | @Contact : maxinyin@zju.edu.cn 9 | ''' 10 | import time 11 | import random 12 | 13 | import torch 14 | from tqdm import tqdm 15 | from torch.utils.data import DataLoader 16 | import torch.distributed as dist 17 | 18 | from prettytable import PrettyTable 19 | from muver.utils.params import WORLDS 20 | from data_loader import bi_collate_fn, cross_collate_fn 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.utils.data import SequentialSampler 23 | 24 | 25 | def pretty_visualize(scores, top_k): 26 | rows = [] 27 | for world, score in scores.items(): 28 | rows.append([world] + [round(s * 1.0 / score[1], 4) for s in score[0]]) 29 | 30 | table = PrettyTable() 31 | table.field_names = ["World"] + ["R@{}".format(k) for k in top_k] 32 | table.add_rows(rows) 33 | print(table) 34 | 35 | def evaluate_bi_model(model, tokenizer, dataset, mode, encode_batch_size = 16, device = "cpu", local_rank = -1, n_gpu = 1, 36 | view_expansion = False, top_k = 0.4, merge_layers = 3, is_accumulate_score = False): 37 | 38 | model.eval() 39 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 40 | test_module = model.module 41 | else: 42 | test_module = model 43 | 44 | world_entity_pool, world_entity_titles = {}, {} 45 | for world, world_dataset in dataset.entity_desc.items(): 46 | entity_pool, entity_title = [], [] 47 | if n_gpu > 1: 48 | sampler = DistributedSampler(world_dataset) 49 | else: 50 | sampler = SequentialSampler(world_dataset) 51 | encode_dataloader = DataLoader(dataset = world_dataset, batch_size = 1, collate_fn=bi_collate_fn, shuffle=False, sampler=sampler) 52 | 53 | disable = True if local_rank not in [-1, 0] else False 54 | for sample in tqdm(encode_dataloader, disable=disable): 55 | candidate_encode = test_module.encode_candidates( 56 | ent_ids = sample['token_ids'].cuda(), 57 | view_expansion = view_expansion, 58 | top_k = top_k, 59 | merge_layers = merge_layers, 60 | mode='test' 61 | ).squeeze(0).detach().to("cpu") # not support for encode_batch_size > 1 62 | entity_pool.append(candidate_encode) 63 | entity_title += [sample['title'][0]] * candidate_encode.size(0) 64 | 65 | world_entity_pool[world] = torch.cat(entity_pool, 0) 66 | world_entity_titles[world] = entity_title 67 | 68 | torch.save([world_entity_pool, world_entity_titles], 'entity_{}.pt'.format(local_rank)) 69 | 70 | 71 | if n_gpu > 1: 72 | torch.distributed.barrier() 73 | 74 | if local_rank not in [-1, 0]: 75 | return None, None 76 | 77 | world_entity_pool, world_entity_titles = {}, {} 78 | for i in range(n_gpu): 79 | sub_entity_pool, sub_entity_titles = torch.load('entity_{}.pt'.format(i), map_location='cpu') 80 | for world_name, world_num in WORLDS[mode]: 81 | titles = world_entity_titles.get(world_name, []) 82 | pool = world_entity_pool.get(world_name, []) 83 | 84 | sub_titles = sub_entity_titles[world_name] 85 | sub_pool = sub_entity_pool[world_name] 86 | if world_num % n_gpu and world_num % n_gpu - 1 < i: 87 | end_idx = len(sub_titles) - 2 88 | while sub_titles[end_idx] == sub_titles[-1]: 89 | end_idx -= 1 90 | 91 | sub_titles = sub_titles[:end_idx + 1] 92 | sub_pool = sub_pool[:end_idx+1, :] 93 | 94 | titles += sub_titles 95 | world_entity_titles[world_name] = titles 96 | 97 | pool.append(sub_pool) 98 | world_entity_pool[world_name] = pool 99 | 100 | for key, _ in WORLDS[mode]: 101 | pool = world_entity_pool[key] 102 | pool = torch.cat(pool, 
0).to("cuda:0")
103 |         world_entity_pool[key] = pool
104 |         #print(world_entity_pool[key].shape)
105 | 
106 |     world_entity_ids_range = {}
107 |     for key, titles in world_entity_titles.items():
108 |         ids_range = {}
109 |         for ids, title in enumerate(titles):
110 |             title_range = ids_range.get(title, [])
111 |             title_range.append(ids)
112 |             ids_range[title] = title_range
113 |         world_entity_ids_range[key] = ids_range
114 | 
115 |     top_k = [1, 2, 4, 8, 16, 32, 50, 64]
116 |     score_metrics = {world_name: [[0] * len(top_k), 0] for world_name, _ in WORLDS[mode]}
117 |     score_metrics['total'] = [[0] * len(top_k), 0]
118 |     candidates = []
119 |     # Then encode the mentions and compare them against the entity pool
120 |     dataloader = DataLoader(dataset=dataset, batch_size=encode_batch_size, collate_fn=cross_collate_fn, shuffle=False)
121 |     for batch in tqdm(dataloader):
122 |         worlds, labels = batch['world'], batch['label_world_idx']
123 |         predict_scores = test_module.score_candidates(
124 |             ctx_ids = batch['context_ids'].to("cuda:0"),
125 |             ctx_world = batch['world'],
126 |             candidate_pool = world_entity_pool
127 |         ) # [candidates_num] * batch_size
128 | 
129 |         for predict_score, world, label in zip(predict_scores, worlds, labels):
130 |             predict_score = torch.softmax(predict_score, -1)
131 |             predict_ids = torch.sort(predict_score, -1, descending=True).indices.cpu()
132 |             scores = torch.sort(predict_score, -1, descending=True).values.cpu()
133 |             label_title = dataset.entity_desc[world].get_nth_title(label)
134 |             accumulate_score = is_accumulate_score
135 |             if accumulate_score:
136 |                 predict_title_dict = {}
137 | 
138 |                 ids = 0
139 |                 while len(predict_title_dict.keys()) < 200: # accumulate view scores until 200 distinct entities are collected (cf. top_k[-1])
140 |                     title = world_entity_titles[world][predict_ids[ids]]
141 |                     title_score = predict_title_dict.get(title, 0) + scores[ids]
142 |                     predict_title_dict[title] = title_score
143 | 
144 |                     ids += 1
145 |                 predict_title = sorted(predict_title_dict, key=predict_title_dict.get)[::-1]
146 |             else:
147 |                 predict_title = []
148 |                 ids = 0
149 |                 while len(predict_title) < 64:
150 |                     title = world_entity_titles[world][predict_ids[ids]]
151 |                     if title not in predict_title:
152 |                         predict_title.append(title)
153 |                     ids += 1
154 | 
155 |             for k_idx, k in enumerate(top_k):
156 |                 if label_title in predict_title[:k]:
157 |                     score_metrics[world][0][k_idx] += 1
158 |                     score_metrics['total'][0][k_idx] += 1
159 |             score_metrics[world][1] += 1
160 |             score_metrics['total'][1] += 1
161 | 
162 |             candidates.append([{'title': title} for title in predict_title])
163 |     print(score_metrics)
164 |     pretty_visualize(score_metrics, top_k)
165 | 
166 |     return score_metrics['total'][0][-1], candidates
167 | 
168 | def evaluate_cross_model(model, tokenizer, dataset, mode, encode_batch_size = 1, device = "cpu"):
169 |     model.eval()
170 |     if isinstance(model, torch.nn.DataParallel):
171 |         test_module = model.module
172 |     else:
173 |         test_module = model
174 | 
175 |     dataloader = DataLoader(dataset=dataset, batch_size=encode_batch_size, num_workers=1, shuffle=False, collate_fn=collate_fn) # NOTE: the collate_fn for the cross-encoder dataset (producing 'candidate_ids' and 'split_len') is not included in this snapshot
176 |     #normalized_correct, normalized_total, unnormalized_total = 0, 0, 0
177 |     score_metrics = {world_name: [0, 0, 0] for world_name, _ in WORLDS[mode]}
178 |     score_metrics['total'] = [0, 0, 0]
179 |     for batch in tqdm(dataloader):
180 |         ctx_ids = batch['candidate_ids'].to(device)
181 |         split_len = batch['split_len']
182 |         if encode_batch_size == 1:
183 |             ctx_ids = ctx_ids.squeeze(0)
184 |             split_len = split_len[0]
185 | 
186 |         score, _ = model(ctx_ids, split_len=split_len, target = None) # batch_size * top_k
187 |         predict_idx = 
torch.max(score, -1).indices 188 | 189 | for idx, t, w in zip(predict_idx, batch['label'], batch['world']): 190 | if idx == t: 191 | score_metrics[w][0] += 1 192 | score_metrics['total'][0] += 1 193 | if t != -1: 194 | score_metrics[w][1] += 1 195 | score_metrics['total'][1] += 1 196 | 197 | score_metrics[w][2] += 1 198 | score_metrics['total'][2] += 1 199 | 200 | return score_metrics 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | -------------------------------------------------------------------------------- /muver/utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : logger.py 5 | @Time : 2021/03/15 17:00:32 6 | @Author : Xinyin Ma 7 | @Version : 0.1 8 | @Contact : maxinyin@zju.edu.cn 9 | ''' 10 | 11 | import os 12 | import codecs 13 | import logging 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | class LoggerWithDepth(): 17 | def __init__(self, env_name, config, root_dir = 'runtime_log', overwrite = True): 18 | if os.path.exists(os.path.join(root_dir, env_name)) and not overwrite: 19 | raise Exception("Logging Directory {} Has Already Exists. Change to another name or set OVERWRITE to True".format(os.path.join(root_dir, env_name))) 20 | 21 | self.env_name = env_name 22 | self.root_dir = root_dir 23 | self.log_dir = os.path.join(root_dir, env_name) 24 | self.overwrite = overwrite 25 | 26 | if not os.path.exists(root_dir): 27 | os.mkdir(root_dir) 28 | if not os.path.exists(self.log_dir): 29 | os.mkdir(self.log_dir) 30 | 31 | # Save Hyperparameters 32 | self.write_description_to_folder(os.path.join(self.log_dir, 'description.txt'), config) 33 | self.best_checkpoint_path = os.path.join(self.log_dir, 'pytorch_model.bin') 34 | 35 | def setup_sublogger(self, sub_name, sub_config): 36 | self.sub_dir = os.path.join(self.log_dir, sub_name) 37 | if os.path.exists(self.sub_dir): 38 | raise Exception("Logging Directory {} Has Already Exists. 
Change to another sub name or set OVERWRITE to True".format(self.sub_dir)) 39 | else: 40 | os.mkdir(self.sub_dir) 41 | 42 | # Setup File/Stream Writer 43 | log_format=logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s") 44 | 45 | self.writer = logging.getLogger() 46 | fileHandler = logging.FileHandler(os.path.join(self.sub_dir, "training.log")) 47 | fileHandler.setFormatter(log_format) 48 | self.writer.addHandler(fileHandler) 49 | 50 | ''' 51 | consoleHandler = logging.StreamHandler() 52 | consoleHandler.setFormatter(log_format) 53 | self.writer.addHandler(consoleHandler) 54 | ''' 55 | self.writer.setLevel(logging.INFO) 56 | 57 | # Setup tensorboard Writer 58 | self.painter = SummaryWriter(self.sub_dir) 59 | tb_dir = self.painter.log_dir 60 | 61 | # Checkpoint 62 | self.checkpoint_path = os.path.join(self.sub_dir, 'pytorch_model.bin') 63 | self.lastest_checkpoint_path = os.path.join(self.sub_dir, 'latest_model.bin') 64 | 65 | 66 | def write_description_to_folder(self, file_name, config): 67 | with codecs.open(file_name, 'w') as desc_f: 68 | desc_f.write("- Training Parameters: \n") 69 | for key, value in config.items(): 70 | desc_f.write(" - {}: {}\n".format(key, value)) -------------------------------------------------------------------------------- /muver/utils/multigpu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : multigpu.py 5 | @Time : 2021/04/15 14:56:18 6 | @Author : Xinyin Ma 7 | @Version : 0.1 8 | @Contact : maxinyin@zju.edu.cn 9 | ''' 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | 15 | class GatherLayer(torch.autograd.Function): 16 | """Gather tensors from all process, supporting backward propagation.""" 17 | 18 | @staticmethod 19 | def forward(ctx, input): 20 | ctx.save_for_backward(input) 21 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 22 | dist.all_gather(output, input) 23 | return tuple(output) 24 | 25 | @staticmethod 26 | def backward(ctx, *grads): 27 | (input,) = ctx.saved_tensors 28 | grad_out = torch.zeros_like(input) 29 | grad_out[:] = grads[dist.get_rank()] 30 | return grad_out -------------------------------------------------------------------------------- /muver/utils/params.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | WORLDS = { 4 | 'train': [("american_football", 31929), ("doctor_who", 40281), ("fallout", 16992), ("final_fantasy", 14044), ("military", 104520), ("pro_wrestling", 10133), ("starwars", 87056), ("world_of_warcraft", 27677)], 5 | 'valid': [("coronation_street", 17809), ("muppets", 21344), ("ice_hockey", 28684), ("elder_scrolls", 21712)], 6 | 'test': [("forgotten_realms", 15603), ("lego", 10076), ("star_trek", 34430), ("yugioh", 10031)] 7 | } 8 | 9 | ENTITY_LINKING_BENCHMARK = { 10 | 'train':['AIDA-YAGO2_train'], 11 | 'dev': ['AIDA-YAGO2_testa'], 12 | 'aida_test': ['AIDA-YAGO2_testb'], 13 | 'test': ['ace2004_questions', 'AIDA-YAGO2_testb', 'aquaint_questions', 'clueweb_questions', 'msnbc_questions', 'wnedwiki_questions'] 14 | } -------------------------------------------------------------------------------- /muver/utils/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import deepcopy 4 | from itertools import product 5 | 6 | import numpy as np 7 | import torch 8 | 9 | def grid_search_hyperparamters(args): 10 | search_key, 
search_value = [], [] 11 | for key, value in args.__dict__.items(): 12 | if isinstance(value, list) and key != 'test_set': 13 | search_key.append(key) 14 | search_value.append(value) 15 | 16 | new_args = [] 17 | for one_search_value in product(*search_value): 18 | arg = deepcopy(args) 19 | for key, value in zip(search_key, one_search_value): 20 | arg.__setattr__(key, value) 21 | new_args.append(arg) 22 | return new_args 23 | 24 | def set_random_seed(seed): 25 | # Set Random Seed 26 | random.seed(seed) 27 | np.random.seed(seed) 28 | torch.manual_seed(seed) 29 | torch.cuda.manual_seed_all(seed) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6.0 2 | transformers==4.1.1 3 | prettytable 4 | nltk --------------------------------------------------------------------------------