├── .DS_Store
├── .gitignore
├── README.md
├── data
│   ├── .DS_Store
│   └── FewCOMM
│       ├── dev.txt
│       ├── test.txt
│       └── train.txt
├── main.py
├── model
│   ├── .DS_Store
│   ├── maml.py
│   ├── mtnet.py
│   ├── nnshot.py
│   ├── proto.py
│   ├── relation_ner.py
│   └── siamese.py
├── pic
│   └── dataset.png
├── requirements.txt
├── run.sh
├── transformer_model
│   ├── .DS_Store
│   └── tip.md
└── utils
    ├── .DS_Store
    ├── config.py
    ├── contrastiveloss.py
    ├── data_loader.py
    ├── fewshotsampler.py
    ├── framework.py
    ├── framework_mtnet.py
    ├── tripletloss.py
    ├── viterbi.py
    └── word_encoder.py

/.DS_Store:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/hccngu/MeTNet/a5112c600364e682eca45b278386b773ba56d5d6/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | temp.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meta-Learning Triplet Network with Adaptive Margins for Few-Shot Named Entity Recognition
2 | 
3 | This repository contains the code and data for our paper:
4 | 
5 | [*Meta-Learning Triplet Network with Adaptive Margins for Few-Shot Named Entity Recognition*](https://arxiv.org/pdf/2302.07739)
6 | 
7 | If you find this work useful in your own research, please cite our paper.
8 | 
9 | ```
10 | @article{han2023meta,
11 |   title={Meta-Learning Triplet Network with Adaptive Margins for Few-Shot Named Entity Recognition},
12 |   author={Chengcheng Han and
13 |           Renyu Zhu and
14 |           Jun Kuang and
15 |           FengJiao Chen and
16 |           Xiang Li and
17 |           Ming Gao and
18 |           Xuezhi Cao and
19 |           Wei Wu},
20 |   journal={arXiv preprint arXiv:2302.07739},
21 |   year={2023}
22 | }
23 | ```
24 | 
25 | ### Overview
26 | 
27 | We propose an improved triplet network with adaptive margins (MeTNet) and a new inference procedure for few-shot NER.
28 | 
29 | We release FEW-COMM, the first Chinese few-shot NER dataset.
30 | 
31 | ### Data
32 | 
33 | The datasets used in our experiments are in the `data/` folder, including FEW-COMM, FEW-NERD, WNUT17, Restaurant and Multiwoz.
34 | 
35 | 
36 | 
37 | 
40 | 
41 | **FEW-COMM** is a Chinese few-shot NER dataset we released. It consists of 66,165 product description texts that merchants display on a large e-commerce platform, including 140,936 entities and 92 pre-defined entity types. These entity types are commodity attributes manually defined by domain experts, such as "material", "color" and "origin". Please see Appendix C of our paper for more details on the dataset.
42 | 
43 | 
44 | 
48 | 
49 | 
50 | ## Setup
51 | 
52 | This implementation is based on Python 3.7. To run the code, you need the following dependencies:
53 | 
54 | - nltk>=3.6.4
55 | - numpy==1.21.0
56 | - pandas==1.3.5
57 | - torch==1.7.1
58 | - transformers==4.0.1
59 | - apex==0.9.10dev
60 | - scikit_learn==0.24.1
61 | - seqeval
62 | 
63 | You can install them all with:
64 | 
65 | ```sh
66 | pip install -r requirements.txt
67 | ```
68 | 
69 | ## Repository structure
70 | 
71 | We describe some important files in detail below.
72 | 
73 | ```shell
74 | |-- data                    # experiments for five datasets
75 |     |-- Few-COMM/           # a Chinese few-shot NER dataset we released
76 | |-- model                   # includes all model implementations
77 | |-- transformer_model       # includes BERT pre-trained checkpoints
78 |     |-- bert-base-chinese
79 |     |-- bert-base-uncased
80 | |-- utils
81 |     |-- config.py           # configuration
82 |     |-- data_loader.py      # loads data
83 |     |-- fewshotsampler.py   # constructs meta-tasks
84 |     |-- framework.py        # train/eval/test procedures
85 |     |-- tripletloss.py      # an improved triplet loss
86 | |-- main.py
87 | |-- run.sh
88 | ```
89 | 
90 | ### Quickstart
91 | 
92 | 1. Unzip our processed data file `data.zip` and put the data files under the `data/` folder.
93 | 
94 | 2. Download the pre-trained BERT files [bert-base-chinese](https://huggingface.co/bert-base-chinese/tree/main) and [bert-base-uncased](https://huggingface.co/bert-base-uncased/tree/main) and put them under the `transformer_model/` folder.
95 | 
96 | 3. Run `sh run.sh`.
97 | 
98 | You can also adjust the model by modifying the parameters in the `run.sh` file.
99 | 
100 | Currently, the benchmarks on the FEW-COMM dataset are as follows:
101 | 
102 | | FEW-COMM | 5-way 1-shot | 5-way 5-shot | 10-way 1-shot | 10-way 5-shot |
103 | | ---| ---| ---| ---| ---|
104 | |MAML|28.16|54.38|26.23|44.66|
105 | |NNShot|48.40|71.55|41.75|67.91|
106 | |StructShot|48.61|70.62|47.77|65.09|
107 | |PROTO|22.73|53.95|22.17|45.81|
108 | |CONTaiNER|57.13|63.38|51.87|60.98|
109 | |ESD|65.37|73.29|58.32|70.93|
110 | |DecomMETA|68.01|72.89|62.13|72.14|
111 | |SpanProto|70.97|76.59|63.94|74.67|
112 | |**MeTNet**|**71.89**|**78.14**|**65.11**|**77.58**|
113 | 
114 | If you have the latest experimental results on the FEW-COMM dataset, please contact us to update the benchmark.
115 | 
116 | For the [FewNERD](https://github.com/thunlp/Few-NERD) dataset, please download it from its official website.
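The central idea behind MeTNet's improved triplet loss — a learnable, per-entity-type margin instead of one fixed margin — can be sketched as follows. This is a simplified illustration, not the exact implementation in `utils/tripletloss.py`; the class name and signature are our own, and the initial margin of 8.5 mirrors the `trainable_margin_init` default noted in the model code.

```python
import torch
from torch import nn


class AdaptiveMarginTripletLoss(nn.Module):
    """Triplet loss with one learnable margin per entity type (simplified sketch)."""

    def __init__(self, num_types, margin_init=8.5):
        super().__init__()
        # One trainable margin per entity type, initialized like trainable_margin_init.
        self.margins = nn.Parameter(torch.full((num_types,), margin_init))

    def forward(self, anchor, positive, negative, type_ids):
        # Euclidean distances of anchor-positive and anchor-negative pairs.
        d_pos = torch.norm(anchor - positive, dim=-1)
        d_neg = torch.norm(anchor - negative, dim=-1)
        # Each anchor uses the margin of its own entity type.
        m = self.margins[type_ids]
        return torch.relu(d_pos - d_neg + m).mean()
```

Because the margins are `nn.Parameter`s, they receive gradients like any other weight, so the separation each entity type must achieve is meta-learned rather than hand-tuned.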
117 | 118 | 124 | 125 | 135 | 142 | 143 | ## Attribution 144 | 145 | Parts of this code are based on the following repositories: 146 | 147 | - [FewNERD](https://github.com/thunlp/Few-NERD) 148 | - [MLADA](https://github.com/hccngu/MLADA) 149 | 150 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hccngu/MeTNet/a5112c600364e682eca45b278386b773ba56d5d6/data/.DS_Store -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 4 | 5 | import sys 6 | import time 7 | import torch 8 | from torch.utils.data import dataloader 9 | from torch.multiprocessing import reductions 10 | from multiprocessing.reduction import ForkingPickler 11 | 12 | import numpy as np 13 | import json 14 | import argparse 15 | import random 16 | 17 | # import torch 18 | from torch import optim, nn 19 | 20 | from transformers import BertTokenizer 21 | 22 | from utils.config import config, setSeed, print_args 23 | from utils.word_encoder import BERTWordEncoder 24 | from utils.data_loader import get_loader 25 | from utils.framework import FewShotNERFramework, FewShotNERFramework_MAML, FewShotNERFramework_RelationNER, FewShotNERFramework_Siamese, FewShotNERFramework_SiameseMAML 26 | from utils.framework_mtnet import FewShotNERFramework_MTNet 27 | from model.proto import Proto, Proto_multiOclass, ProtoMAML 28 | # from model.mtnet import MTNet 29 | from model.nnshot import NNShot 30 | from model.maml import MAML 31 | from model.relation_ner import RelationNER 32 | from model.siamese import Siamese, SiameseMAML 33 | from model.mtnet import MTNet 34 | 35 | def default_collate_override(batch): 36 | dataloader._use_shared_memory 
= False 37 | return default_collate_func(batch) 38 | 39 | def main(): 40 | opt = config() 41 | print_args(opt) 42 | 43 | trainN = opt.trainN 44 | N = opt.N 45 | K = opt.K 46 | Q = opt.Q 47 | batch_size = opt.batch_size 48 | model_name = opt.model 49 | max_length = opt.max_length 50 | 51 | print("{}-way-{}-shot Few-Shot NER".format(N, K)) 52 | print("model: {}".format(model_name)) 53 | print("max_length: {}".format(max_length)) 54 | print('mode: {}'.format(opt.mode)) 55 | 56 | setSeed(opt.seed) 57 | 58 | print('loading pre-trained language model and tokenizer...') 59 | if opt.dataset == 'fewcomm': 60 | UNCASED = './transformer_model/bert-base-chinese' 61 | VOCAB = 'vocab.txt' 62 | pretrain_ckpt = opt.pretrain_ckpt or './transformer_model/bert-base-chinese' 63 | word_encoder = BERTWordEncoder(pretrain_ckpt) 64 | tokenizer = BertTokenizer.from_pretrained(os.path.join(UNCASED,VOCAB)) 65 | else: 66 | UNCASED = './transformer_model/bert-base-uncased' 67 | VOCAB = 'vocab.txt' 68 | pretrain_ckpt = opt.pretrain_ckpt or './transformer_model/bert-base-uncased' 69 | word_encoder = BERTWordEncoder(pretrain_ckpt) 70 | tokenizer = BertTokenizer.from_pretrained(os.path.join(UNCASED,VOCAB)) 71 | 72 | print('loading data...') 73 | if opt.dataset == 'fewcomm': 74 | if opt.dataset_mode == 'BIO': 75 | opt.train = f'./data/Few-COMM/train_BIO.txt' 76 | opt.test = f'./data/Few-COMM/test_BIO.txt' 77 | opt.dev = f'./data/Few-COMM/dev_BIO.txt' 78 | else: 79 | opt.train = f'./data/Few-COMM/train.txt' 80 | opt.test = f'./data/Few-COMM/test.txt' 81 | opt.dev = f'./data/Few-COMM/dev.txt' 82 | elif opt.dataset == 'mitrestaurant' or opt.dataset == 'multiwoz' or opt.dataset == 'wnut17': 83 | opt.train = f'./data/Few-NERD/inter/train.txt' 84 | opt.test = f'./data/{opt.dataset}/test.txt' 85 | opt.dev = f'./data/Few-NERD/inter/dev.txt' 86 | elif opt.dataset == 'fewnerd': 87 | if not opt.use_sampled_data: 88 | opt.train = f'./data/Few-NERD/{opt.mode}/train.txt' 89 | opt.test = 
f'./data/Few-NERD/{opt.mode}/test.txt'
90 |             opt.dev = f'./data/Few-NERD/{opt.mode}/dev.txt'
91 | 
92 |         else:
93 |             opt.train = f'./data/Few-NERD/episode-data/{opt.mode}/train_{opt.N}_{opt.K}.jsonl'
94 |             opt.test = f'./data/Few-NERD/episode-data/{opt.mode}/test_{opt.N}_{opt.K}.jsonl'
95 |             opt.dev = f'./data/Few-NERD/episode-data/{opt.mode}/dev_{opt.N}_{opt.K}.jsonl'
96 |     else:
97 |         raise NotImplementedError
98 | 
99 |     if not (os.path.exists(opt.train) and os.path.exists(opt.dev) and os.path.exists(opt.test)):
100 |         raise RuntimeError('data file does not exist!')
101 | 
102 |     train_data_loader = get_loader(opt.train, tokenizer,
103 |             N=trainN, K=K, Q=Q, batch_size=batch_size, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
104 |     val_data_loader = get_loader(opt.dev, tokenizer,
105 |             N=N, K=K, Q=Q, batch_size=1, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
106 |     test_data_loader = get_loader(opt.test, tokenizer,
107 |             N=N, K=K, Q=Q, batch_size=1, max_length=max_length, ignore_index=opt.ignore_index, use_sampled_data=opt.use_sampled_data)
108 | 
109 | 
110 |     prefix = '-'.join([model_name, opt.dataset, opt.mode, opt.dataset_mode, str(N), str(K), 'seed'+str(opt.seed), str(int(round(time.time() * 1000)))])
111 |     if opt.dot:
112 |         prefix += '-dot'
113 |     if len(opt.ckpt_name) > 0:
114 |         prefix += '-' + opt.ckpt_name
115 | 
116 |     print('Loading model...')
117 |     if model_name == 'proto':
118 |         print('use proto')
119 |         model = Proto(opt, word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
120 |         framework = FewShotNERFramework(opt, tokenizer, train_data_loader, val_data_loader, test_data_loader, use_sampled_data=opt.use_sampled_data)
121 |     elif model_name == 'nnshot':
122 |         print('use nnshot')
123 |         model = NNShot(opt, word_encoder, dot=opt.dot, ignore_index=opt.ignore_index)
124 |         framework = FewShotNERFramework(opt, tokenizer, train_data_loader, val_data_loader, test_data_loader, 
use_sampled_data=opt.use_sampled_data) 125 | elif model_name == 'structshot': 126 | print('use structshot') 127 | model = NNShot(opt, word_encoder, dot=opt.dot, ignore_index=opt.ignore_index) 128 | framework = FewShotNERFramework(opt, tokenizer, train_data_loader, val_data_loader, test_data_loader, N=opt.N, tau=opt.tau, train_fname=opt.train, viterbi=True, use_sampled_data=opt.use_sampled_data) 129 | elif model_name == 'MTNet': 130 | print('use MTNet') 131 | model = MTNet(word_encoder, dot=opt.dot, args=opt, ignore_index=opt.ignore_index) 132 | framework = FewShotNERFramework_MTNet(train_data_loader, val_data_loader, test_data_loader, args=opt, tokenizer=tokenizer, use_sampled_data=opt.use_sampled_data) 133 | # model = MTNet(word_encoder, dot=opt.dot, ignore_index=opt.ignore_index) 134 | # framework = FewShotNERFramework(train_data_loader, val_data_loader, test_data_loader, use_sampled_data=opt.use_sampled_data) 135 | elif model_name == 'MAML': 136 | print('use MAML') 137 | model = MAML(word_encoder, dot=opt.dot, args=opt, ignore_index=opt.ignore_index) 138 | framework = FewShotNERFramework_MAML(tokenizer, train_data_loader, val_data_loader, test_data_loader, args=opt, use_sampled_data=opt.use_sampled_data) 139 | elif model_name == 'relation_ner': 140 | print('use RelationNER') 141 | model = RelationNER(word_encoder, dot=opt.dot, args=opt, ignore_index=opt.ignore_index) 142 | framework = FewShotNERFramework_RelationNER(train_data_loader, val_data_loader, test_data_loader, args=opt, tokenizer=tokenizer, use_sampled_data=opt.use_sampled_data) 143 | elif model_name == 'proto_maml': 144 | print('use ProtoMAML') 145 | model = ProtoMAML(opt, word_encoder, dot=opt.dot, ignore_index=opt.ignore_index) 146 | framework = FewShotNERFramework(opt, tokenizer, train_data_loader, val_data_loader, test_data_loader, use_sampled_data=opt.use_sampled_data) 147 | elif model_name == 'Siamese': 148 | print('use Siamese Network') 149 | model = Siamese(word_encoder, dot=opt.dot, args=opt, 
ignore_index=opt.ignore_index) 150 | framework = FewShotNERFramework_Siamese(train_data_loader, val_data_loader, test_data_loader, args=opt, tokenizer=tokenizer, use_sampled_data=opt.use_sampled_data) 151 | elif model_name == 'SiameseMAML': 152 | print('use Siamese MAML') 153 | model = SiameseMAML(word_encoder, dot=opt.dot, args=opt, ignore_index=opt.ignore_index) 154 | framework = FewShotNERFramework_SiameseMAML(train_data_loader, val_data_loader, test_data_loader, args=opt, tokenizer=tokenizer, use_sampled_data=opt.use_sampled_data) 155 | elif model_name == 'proto_multiOclass': 156 | print('use proto_multiOclass') 157 | model = Proto_multiOclass(opt, word_encoder, dot=opt.dot, ignore_index=opt.ignore_index) 158 | framework = FewShotNERFramework(opt, tokenizer, train_data_loader, val_data_loader, test_data_loader, use_sampled_data=opt.use_sampled_data) 159 | else: 160 | raise NotImplementedError 161 | 162 | if not os.path.exists('checkpoint'): 163 | os.mkdir('checkpoint') 164 | ckpt = 'checkpoint/{}.pth.tar'.format(prefix) 165 | if opt.save_ckpt: 166 | ckpt = opt.save_ckpt 167 | print('model-save-path:', ckpt) 168 | 169 | if torch.cuda.is_available(): 170 | model.cuda() 171 | 172 | if not opt.only_test: 173 | if opt.lr == -1: 174 | opt.lr = 2e-5 175 | 176 | framework.train(model, prefix, 177 | load_ckpt=opt.load_ckpt, save_ckpt=ckpt, 178 | val_step=opt.val_step, fp16=opt.fp16, 179 | train_iter=opt.train_iter, warmup_step=int(opt.train_iter * 0.1), val_iter=opt.val_iter, learning_rate=opt.lr, use_sgd_for_bert=opt.use_sgd_for_bert) 180 | else: 181 | ckpt = opt.load_ckpt 182 | if ckpt is None: 183 | print("Warning: --load_ckpt is not specified. 
Will load the Hugging Face pre-trained checkpoint.")
184 |             ckpt = 'none'
185 | 
186 |     print('testing...')
187 |     precision, recall, f1, fp, fn, within, outer = framework.eval(model, opt.test_iter, ckpt=ckpt)
188 |     print("RESULT: precision: %.4f, recall: %.4f, f1: %.4f" % (precision, recall, f1))
189 |     print('ERROR ANALYSIS: fp: %.4f, fn: %.4f, within: %.4f, outer: %.4f' % (fp, fn, within, outer))
190 | 
191 | if __name__ == '__main__':
192 | 
193 |     default_collate_func = dataloader.default_collate
194 | 
195 |     setattr(dataloader, 'default_collate', default_collate_override)
196 | 
197 |     for t in torch._storage_classes:
198 |         if sys.version_info[0] == 2:
199 |             if t in ForkingPickler.dispatch:
200 |                 del ForkingPickler.dispatch[t]
201 |         else:
202 |             if t in ForkingPickler._extra_reducers:
203 |                 del ForkingPickler._extra_reducers[t]
204 | 
205 |     main()
--------------------------------------------------------------------------------
/model/.DS_Store:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/hccngu/MeTNet/a5112c600364e682eca45b278386b773ba56d5d6/model/.DS_Store
--------------------------------------------------------------------------------
/model/maml.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('..')
3 | import utils
4 | import torch
5 | from torch import autograd, optim, nn
6 | from torch.autograd import Variable
7 | from torch.nn import functional as F
8 | 
9 | class MAML(utils.framework.FewShotNERModel):
10 | 
11 |     def __init__(self, word_encoder, dot=False, args=None, ignore_index=-1):
12 |         utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index)
13 |         self.drop = nn.Dropout()
14 |         self.dot = dot
15 |         self.fc = nn.Linear(768, args.N+1)
16 |         # self.fc1 = nn.Linear(768, 512)
17 |         # self.fc2 = nn.Linear(512, 256)
18 |         # self.fc3 = nn.Linear(256, 128)
19 |         # self.fc4 = nn.Linear(128, args.N+1)
20 |         # 
self.loss = nn.CrossEntropyLoss() 21 | 22 | def __dist__(self, x, y, dim): 23 | if self.dot: 24 | return (x * y).sum(dim) 25 | else: 26 | return -(torch.pow(x - y, 2)).sum(dim) 27 | 28 | def __batch_dist__(self, S, Q, q_mask): 29 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 30 | assert Q.size()[:2] == q_mask.size() 31 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 32 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 33 | 34 | def __get_proto__(self, embedding, tag, mask): 35 | proto = [] 36 | embedding = embedding[mask==1].view(-1, embedding.size(-1)) 37 | tag = torch.cat(tag, 0) 38 | assert tag.size(0) == embedding.size(0) 39 | for label in range(torch.max(tag)+1): 40 | proto.append(torch.mean(embedding[tag==label], 0)) 41 | proto = torch.stack(proto) 42 | return proto 43 | 44 | def forward(self, data, model_parameters=None): 45 | ''' 46 | data : support or query 47 | support: Inputs of the support set. 48 | query: Inputs of the query set. 
49 | N: Num of classes 50 | K: Num of instances for each class in the support set 51 | Q: Num of instances in the query set 52 | ''' 53 | data_emb = self.word_encoder(data['word'], data['mask']) # [num_sent, number_of_tokens, 768] 54 | data_emb = self.drop(data_emb) 55 | # data_emb = data_emb.view(-1, 768) # [num_sent*number_of_tokens, 768] 56 | temp_sent = [] 57 | for i, line in enumerate(data_emb): 58 | temp_sent.append(line[data['text_mask'][i]==1]) 59 | data_emb = torch.cat(temp_sent, 0) 60 | if model_parameters == None: 61 | # emb = self.fc1(data_emb) # [num_sent*number_of_tokens, N+1] 62 | # emb = F.relu(emb) 63 | # emb = self.fc2(emb) 64 | # emb = F.relu(emb) 65 | # emb = self.fc3(emb) 66 | # emb = F.relu(emb) 67 | # logits = self.fc4(emb) 68 | logits = self.fc(data_emb) 69 | _, pred = torch.max(logits, 1) 70 | else: 71 | # emb = F.linear(data_emb, model_parameters['fc1']['weight']) 72 | # emb = F.relu(emb) 73 | # emb = F.linear(emb, model_parameters['fc2']['weight']) 74 | # emb = F.relu(emb) 75 | # emb = F.linear(emb, model_parameters['fc3']['weight']) 76 | # emb = F.relu(emb) 77 | # logits = F.linear(emb, model_parameters['fc4']['weight']) 78 | logits = F.linear(data_emb, model_parameters['fc']['weight']) 79 | _, pred = torch.max(logits, 1) 80 | 81 | return logits, pred 82 | 83 | def cloned_fc_dict(self): 84 | return {key: val.clone() for key, val in self.fc.state_dict().items()} 85 | 86 | def cloned_fc1_dict(self): 87 | return {key: val.clone() for key, val in self.fc1.state_dict().items()} 88 | 89 | def cloned_fc2_dict(self): 90 | return {key: val.clone() for key, val in self.fc2.state_dict().items()} 91 | 92 | def cloned_fc3_dict(self): 93 | return {key: val.clone() for key, val in self.fc3.state_dict().items()} 94 | 95 | def cloned_fc4_dict(self): 96 | return {key: val.clone() for key, val in self.fc4.state_dict().items()} 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /model/mtnet.py: 
--------------------------------------------------------------------------------
1 | 
2 | import sys
3 | 
4 | sys.path.append('..')
5 | import utils
6 | import torch
7 | from torch import autograd, optim, nn
8 | from torch.autograd import Variable
9 | from torch.nn import functional as F
10 | 
11 | 
12 | 
13 | class MTNet(utils.framework.FewShotNERModel):
14 | 
15 |     def __init__(self, word_encoder, dot=False, args=None, ignore_index=-1):
16 |         utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index)
17 |         self.drop = nn.Dropout(p=args.dropout)
18 |         self.dot = dot
19 |         self.args = args
20 |         self.fc = nn.Linear(768, 512)
21 |         self.att = nn.Linear(768, 768, bias=False)
22 |         if self.args.multi_margin is True:
23 |             if self.args.have_otherO is True:
24 |                 self.param = nn.Parameter(torch.Tensor([args.trainable_margin_init for _ in range(self.args.N+1)])) # 8.5
25 |             else:
26 |                 self.param = nn.Parameter(torch.Tensor([args.trainable_margin_init for _ in range(self.args.N)])) # 8.5
27 |         else:
28 |             self.param = nn.Parameter(torch.Tensor([args.trainable_margin_init])) # 8.5
29 |         # if self.args.alpha_is_trainable:
30 |         self.alpha = nn.Parameter(torch.Tensor([args.trainable_alpha_init])) # 0.5
31 |         # else:
32 |         #     self.alpha = args.trainable_alpha_init
33 | 
34 |     def __dist__(self, x, y, dim):
35 |         if self.dot:
36 |             return (x * y).sum(dim)
37 |         else:
38 |             return -(torch.pow(x - y, 2)).sum(dim)
39 | 
40 |     def __batch_dist__(self, S, Q, q_mask):
41 |         # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim]
42 |         assert Q.size()[:2] == q_mask.size()
43 |         Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim]
44 |         return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2)
45 | 
46 |     def __neg_dist__(self, instances, class_proto): # ins: [N*K, 256], cla: [N, 256]
47 |         return -torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5)
48 | 
49 |     def 
__pos_dist__(self, instances, class_proto): # ins: [N*K, 256], cla: [N, 256]
50 |         return torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5)
51 | 
52 |     def __get_proto__(self, label_data_emb, mask):
53 |         proto = []
54 |         assert label_data_emb.shape[0] == mask.shape[0]
55 |         for i, l_ebd in enumerate(label_data_emb): # [10, 768]
56 |             p = l_ebd[mask[i]].mean(0)
57 |             proto.append(p.view(1,-1))
58 |         proto = torch.cat(proto,0)
59 |         proto = proto.view((-1,self.args.N,768))
60 |         return proto
61 | 
62 |     def __forward_once__(self, emb):
63 |         emb = self.fc(emb)
64 |         return emb
65 | 
66 |     def __forward_once_with_param__(self, emb, param):
67 |         emb = F.linear(emb, param['fc']['weight'])
68 |         return emb
69 | 
70 |     def __get_sample_pairs__(self, data):
71 |         data_1 = {}
72 |         data_2 = {}
73 |         data['word_emb'] = data['word_emb'][data['text_mask']==1]
74 |         data['label'] = torch.cat(data['label'], dim=0)
75 |         data_1['word_emb'] = data['word_emb'][[l in [*range(1, self.args.N+1)] for l in data['label']]]
76 |         data_1['label'] = data['label'][[l in [*range(1, self.args.N+1)] for l in data['label']]]
77 |         data_2['word_emb'] = data['word_emb'][[l in [*range(0, self.args.N+1)] for l in data['label']]]
78 |         data_2['label'] = data['label'][[l in [*range(0, self.args.N+1)] for l in data['label']]]
79 | 
80 |         return data_1, data_2
81 | 
82 |     def __generate_query_pair_label__(self, query, query_dis_output):
83 |         query_label = torch.cat(query['label'], 0)
84 |         query_label = query_label.cuda()
85 |         assert query_label.shape[0] == query_dis_output.shape[0]
86 |         query_dis_output = query_dis_output[query_label!=-1]
87 |         query_label = query_label[query_label!=-1]
88 |         return query_dis_output, query_label
89 |     def forward(self, data, model_parameters=None):
90 |         '''
91 |         data : support or query
92 |         support: Inputs of the support set.
93 |         query: Inputs of the query set.
94 |         N: Num of classes
95 |         K: Num of instances for each class in the support set
96 |         Q: Num of instances in the query set
97 |         '''
98 | 
99 |         if model_parameters is None:
100 |             out = self.__forward_once__(data) # [x, 768] -> [x, 256]
101 |             out = self.drop(out)
102 |         else:
103 |             out = self.__forward_once_with_param__(data, model_parameters)
104 |             out = self.drop(out)
105 | 
106 |         # calculate distance
107 |         # dis = self.__pos_dist__(out1, out2).view(-1)
108 |         # dis = F.layer_norm(dis, normalized_shape=[dis.shape[0]], bias=torch.full((dis.shape[0],), self.args.ln_bias).cuda()) # weight multiplies, bias adds; same shape as dis
109 | 
110 |         return out
111 | 
112 | 
113 |     '''
114 |         support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768]
115 |         support_emb = self.drop(support_emb)
116 |         support['word_emb'] = support_emb
117 | 
118 |         if query_flag is not True:
119 | 
120 |             # get sample pairs
121 |             support_1, support_2 = self.__get_sample_pairs__(support)
122 | 
123 |             if model_parameters is None:
124 |                 out1 = self.__forward_once__(support_1['word_emb']) # [x, 768] -> [x, 256]
125 |                 out2 = self.__forward_once__(support_2['word_emb'])
126 |             else:
127 |                 out1 = self.__forward_once_with_param__(support_1['word_emb'], model_parameters)
128 |                 out2 = self.__forward_once_with_param__(support_2['word_emb'], model_parameters)
129 | 
130 |             # calculate distance
131 |             dis = self.__pos_dist__(out1, out2).view(-1)
132 |             # print('out1', out1)
133 |             # print('out2', out2)
134 |             # print('support12', support_1['word_emb'], support_2['word_emb'])
135 |             print('dis', dis)
136 |             dis = F.layer_norm(dis, normalized_shape=[dis.shape[0]], bias=torch.full((dis.shape[0],), self.args.ln_bias).cuda()) # weight multiplies, bias adds; same shape as dis
137 |             print('dis_after_ln', dis)
138 |             pair_label = self.__generate_pair_label__(support_1['label'], support_2['label'])
139 | 
140 |             return dis, pair_label
141 |         else:
142 |             query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, 
number_of_tokens, 768] 143 | query_emb = self.drop(query_emb) 144 | query_emb = query_emb[query['text_mask']==1] 145 | query['word_emb'] = query_emb 146 | 147 | support_out = self.__forward_once_with_param__(support_emb, model_parameters)[support['text_mask']==1] 148 | query_out = self.__forward_once_with_param__(query_emb, model_parameters)[query['text_mask']==1] 149 | proto = [] 150 | assert support_out.shape[0] == support['label'].shape[0] 151 | for label in range(1, self.args.N+1): 152 | proto.append(torch.mean(support_out[support['label']==label], dim=0)) 153 | proto = torch.stack(proto) 154 | query_dis = self.__pos_dist__(query_out, proto).view(-1) 155 | query_dis_output = F.layer_norm(query_dis, normalized_shape=[query_dis.shape[0]], bias=torch.full((query_dis.shape[0],), self.args.ln_bias).cuda()) 156 | query_dis = query_dis_output.view(-1, proto.shape[0]) 157 | query_pred = [] 158 | for tmp in query_dis: 159 | if any(t < margin for t in tmp): 160 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 161 | else: 162 | query_pred.append(0) 163 | query_dis_output, query_pair_label = self.__generate_query_pair_label__(query, query_dis_output) 164 | return torch.Tensor(query_pred).cuda(), query_dis_output, query_pair_label 165 | ''' 166 | 167 | def __generate_pair_label__(self, label1, label2): 168 | pair_label = [] 169 | for l1 in label1: 170 | for l2 in label2: 171 | if l1 == l2: 172 | pair_label.append(1.0) 173 | else: 174 | pair_label.append(0.0) 175 | return torch.Tensor(pair_label).cuda() 176 | 177 | def cloned_fc_dict(self): 178 | return {key: val.clone() for key, val in self.fc.state_dict().items()} 179 | 180 | def cloned_fc1_dict(self): 181 | return {key: val.clone() for key, val in self.fc1.state_dict().items()} 182 | 183 | def cloned_fc2_dict(self): 184 | return {key: val.clone() for key, val in self.fc2.state_dict().items()} 185 | 186 | def cloned_fc3_dict(self): 187 | return {key: val.clone() for key, val in self.fc3.state_dict().items()} 188 | 
189 | def cloned_fc4_dict(self): 190 | return {key: val.clone() for key, val in self.fc4.state_dict().items()} 191 | 192 | 193 | -------------------------------------------------------------------------------- /model/nnshot.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import utils 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | 9 | class NNShot(utils.framework.FewShotNERModel): 10 | 11 | def __init__(self, args, word_encoder, dot=False, ignore_index=-1): 12 | utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index) 13 | self.drop = nn.Dropout() 14 | self.dot = dot 15 | self.args = args 16 | 17 | def __dist__(self, x, y, dim): 18 | if self.dot: 19 | return (x * y).sum(dim) 20 | else: 21 | return -(torch.pow(x - y, 2)).sum(dim) 22 | 23 | def __batch_dist__(self, S, Q, q_mask): 24 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 25 | assert Q.size()[:2] == q_mask.size() 26 | Q = Q[q_mask==1].view(-1, Q.size(-1)) 27 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 28 | 29 | def __get_nearest_dist__(self, embedding, tag, mask, query, q_mask): 30 | nearest_dist = [] 31 | S = embedding[mask==1].view(-1, embedding.size(-1)) 32 | tag = torch.cat(tag, 0) 33 | assert tag.size(0) == S.size(0) 34 | dist = self.__batch_dist__(S, query, q_mask) # [num_of_query_tokens, num_of_support_tokens] 35 | for label in range(torch.max(tag)+1): 36 | nearest_dist.append(torch.max(dist[:,tag==label], 1)[0]) 37 | nearest_dist = torch.stack(nearest_dist, dim=1) # [num_of_query_tokens, class_num] 38 | return nearest_dist 39 | 40 | def __get_nearest_dist_for_BIO__(self, embedding, tag, mask, query, q_mask): 41 | nearest_dist = [] 42 | S = embedding[mask==1].view(-1, embedding.size(-1)) 43 | tag = torch.cat(tag, 0) 44 | assert tag.size(0) == S.size(0) 45 | 
dist = self.__batch_dist__(S, query, q_mask) # [num_of_query_tokens, num_of_support_tokens] 46 | for label in range(self.args.N*2+1): 47 | if label not in tag: 48 | nearest_dist.append(torch.max(dist[:,tag==0], 1)[0]) 49 | else: 50 | nearest_dist.append(torch.max(dist[:,tag==label], 1)[0]) 51 | nearest_dist = torch.stack(nearest_dist, dim=1) # [num_of_query_tokens, class_num] 52 | return nearest_dist 53 | 54 | def forward(self, support, query): 55 | ''' 56 | support: Inputs of the support set. 57 | query: Inputs of the query set. 58 | N: Num of classes 59 | K: Num of instances for each class in the support set 60 | Q: Num of instances in the query set 61 | ''' 62 | support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768] 63 | query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 64 | support_emb = self.drop(support_emb) 65 | query_emb = self.drop(query_emb) 66 | 67 | logits = [] 68 | current_support_num = 0 69 | current_query_num = 0 70 | assert support_emb.size()[:2] == support['mask'].size() 71 | assert query_emb.size()[:2] == query['mask'].size() 72 | 73 | if self.args.dataset_mode == 'BIO': 74 | get_nearest_dist = self.__get_nearest_dist_for_BIO__ 75 | else: 76 | get_nearest_dist = self.__get_nearest_dist__ 77 | 78 | for i, sent_support_num in enumerate(support['sentence_num']): 79 | sent_query_num = query['sentence_num'][i] 80 | # Calculate nearest distance to single entity in each class in support set 81 | logits.append(get_nearest_dist(support_emb[current_support_num:current_support_num+sent_support_num], 82 | support['label'][current_support_num:current_support_num+sent_support_num], 83 | support['text_mask'][current_support_num: current_support_num+sent_support_num], 84 | query_emb[current_query_num:current_query_num+sent_query_num], 85 | query['text_mask'][current_query_num: current_query_num+sent_query_num])) 86 | current_query_num += sent_query_num 87 | 
current_support_num += sent_support_num 88 | logits = torch.cat(logits, 0) 89 | _, pred = torch.max(logits, 1) 90 | return logits, pred 91 | 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /model/proto.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import utils 4 | import numpy as np 5 | import os 6 | import time 7 | import torch 8 | from torch import autograd, optim, nn 9 | from torch.autograd import Variable 10 | from torch.nn import functional as F 11 | 12 | class Proto(utils.framework.FewShotNERModel): 13 | 14 | def __init__(self, args, word_encoder, dot=False, ignore_index=-1): 15 | utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index) 16 | self.drop = nn.Dropout() 17 | self.dot = dot 18 | self.args = args 19 | if self.args.mlp is True: 20 | self.fc = nn.Linear(768, 256) 21 | 22 | def __dist__(self, x, y, dim): 23 | if self.dot: 24 | return (x * y).sum(dim) 25 | else: 26 | return -(torch.pow(x - y, 2)).sum(dim) 27 | 28 | def __batch_dist__(self, S, Q, q_mask): 29 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 30 | assert Q.size()[:2] == q_mask.size() 31 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 32 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 33 | 34 | def __get_proto_for_BIO__(self, embedding, tag, mask): 35 | proto = [] 36 | embedding = embedding[mask==1].view(-1, embedding.size(-1)) 37 | tag = torch.cat(tag, 0) 38 | assert tag.size(0) == embedding.size(0) 39 | for label in range(self.args.N*2+1): 40 | if embedding[tag==label].shape[0] == 0: 41 | proto.append(proto[0]) 42 | else: 43 | proto.append(torch.mean(embedding[tag==label], 0)) 44 | # print(torch.mean(embedding[tag==label], 0)) 45 | proto = torch.stack(proto) 46 | return proto 47 | 48 | def __get_proto__(self, embedding, tag, mask): 49 | proto = [] 50 | 
embedding = embedding[mask==1].view(-1, embedding.size(-1)) 51 | tag = torch.cat(tag, 0) 52 | assert tag.size(0) == embedding.size(0) 53 | for label in range(torch.max(tag)+1): 54 | proto.append(torch.mean(embedding[tag==label], 0)) 55 | proto = torch.stack(proto) 56 | return proto 57 | 58 | def forward(self, support, query): 59 | ''' 60 | support: Inputs of the support set. 61 | query: Inputs of the query set. 62 | N: Num of classes 63 | K: Num of instances for each class in the support set 64 | Q: Num of instances in the query set 65 | ''' 66 | support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768] 67 | query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 68 | if (self.args.only_test is True) and (self.args.save_query_ebd is True): 69 | save_qe_path = '_'.join(['425', self.args.dataset, self.args.mode, 'proto', str(self.args.N), str(self.args.K), str(self.args.Q), str(int(round(time.time() * 1000)))]) 70 | if not os.path.exists(save_qe_path): 71 | os.mkdir(save_qe_path) 72 | f_write = open(os.path.join(save_qe_path, 'label2tag.txt'), 'w', encoding='utf-8') 73 | for ln in query['label2tag'][0]: 74 | f_write.write(query['label2tag'][0][ln] + '\n') 75 | f_write.flush() 76 | f_write.close() 77 | se = support_emb[support['text_mask']==1].view(-1, support_emb.size(-1)) 78 | qe = query_emb[query['text_mask']==1].view(-1, query_emb.size(-1)) 79 | sl = torch.cat(support['label'], 0) 80 | ql = torch.cat(query['label'], 0) 81 | for i in range(self.args.N+1): 82 | if i == 0: 83 | p = torch.mean(se[sl==i], dim=0, keepdim=True) 84 | else: 85 | p = torch.cat((p, torch.mean(se[sl==i], dim=0, keepdim=True)), dim=0) 86 | np.save(os.path.join(save_qe_path, str(i)+'.npy'), qe[ql==i].cpu().detach().numpy()) 87 | np.save(os.path.join(save_qe_path, 'proto.npy'), p.cpu().detach().numpy()) 88 | sys.exit() 89 | 90 | support_emb = self.drop(support_emb) 91 | query_emb = self.drop(query_emb) 92 | if 
self.args.mlp is True: 93 | support_emb = self.fc(support_emb) 94 | query_emb = self.fc(query_emb) 95 | 96 | # Prototypical Networks 97 | logits = [] 98 | current_support_num = 0 99 | current_query_num = 0 100 | assert support_emb.size()[:2] == support['mask'].size() 101 | assert query_emb.size()[:2] == query['mask'].size() 102 | 103 | if self.args.dataset_mode == 'BIO': 104 | get_proto = self.__get_proto_for_BIO__ 105 | else: 106 | get_proto = self.__get_proto__ 107 | 108 | for i, sent_support_num in enumerate(support['sentence_num']): 109 | sent_query_num = query['sentence_num'][i] 110 | # Calculate prototype for each class 111 | support_proto = get_proto( 112 | support_emb[current_support_num:current_support_num+sent_support_num], 113 | support['label'][current_support_num:current_support_num+sent_support_num], 114 | support['text_mask'][current_support_num: current_support_num+sent_support_num]) 115 | # calculate distance to each prototype 116 | logits.append(self.__batch_dist__( # e.g. logits[0]: [110, 6], distances between 110 query tokens and 6 prototypes 117 | support_proto, 118 | query_emb[current_query_num:current_query_num+sent_query_num], 119 | query['text_mask'][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num] 120 | current_query_num += sent_query_num 121 | current_support_num += sent_support_num 122 | logits = torch.cat(logits, 0) 123 | _, pred = torch.max(logits, 1) 124 | return logits, pred 125 | 126 | 127 | class ProtoMAML(utils.framework.FewShotNERModel): 128 | 129 | def __init__(self, args, word_encoder, dot=False, ignore_index=-1): 130 | utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index) 131 | self.drop = nn.Dropout() 132 | self.dot = dot 133 | self.args = args 134 | self.fc = nn.Linear(768, 256) 135 | self.bn = nn.LayerNorm(768) 136 | 137 | def __dist__(self, x, y, dim): 138 | if self.dot: 139 | return (x * y).sum(dim) 140 | else: 141 | return -(torch.pow(x - y, 2)).sum(dim) 142 | 143 | def 
__batch_dist__(self, S, Q, q_mask): 144 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 145 | assert Q.size()[:2] == q_mask.size() 146 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 147 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 148 | 149 | def __get_proto_for_BIO__(self, embedding, tag, mask): 150 | proto = [] 151 | embedding = embedding[mask==1].view(-1, embedding.size(-1)) 152 | tag = torch.cat(tag, 0) 153 | assert tag.size(0) == embedding.size(0) 154 | for label in range(self.args.N*2+1): 155 | if embedding[tag==label].shape[0] == 0: 156 | proto.append(proto[0]) 157 | else: 158 | proto.append(torch.mean(embedding[tag==label], 0)) 159 | # print(torch.mean(embedding[tag==label], 0)) 160 | proto = torch.stack(proto) 161 | return proto 162 | 163 | def __get_proto__(self, embedding, tag, mask): 164 | proto = [] 165 | embedding = embedding[mask==1].view(-1, embedding.size(-1)) 166 | tag = torch.cat(tag, 0) 167 | assert tag.size(0) == embedding.size(0) 168 | for label in range(torch.max(tag)+1): 169 | proto.append(torch.mean(embedding[tag==label], 0)) 170 | proto = torch.stack(proto) 171 | return proto 172 | 173 | def forward(self, data, query, query_flag=False, model_parameters=None): 174 | ''' 175 | support: Inputs of the support set. 176 | query: Inputs of the query set. 
177 | N: Num of classes 178 | K: Num of instances for each class in the support set 179 | Q: Num of instances in the query set 180 | ''' 181 | data_emb = self.word_encoder(data['word'], data['mask']) # [num_sent, number_of_tokens, 768] 182 | # query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 183 | data_emb = self.drop(data_emb) 184 | # print(data_emb.shape) 185 | data_emb = self.bn(data_emb) 186 | # query_emb = self.drop(query_emb) 187 | if model_parameters == None: 188 | data_emb = self.fc(data_emb) 189 | else: 190 | data_emb = F.linear(data_emb, model_parameters['fc']['weight']) 191 | # query_emb = self.fc(query_emb) 192 | if query_flag is True: 193 | query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 194 | query_emb = self.drop(query_emb) 195 | if model_parameters == None: 196 | query_emb = self.fc(query_emb) 197 | else: 198 | query_emb = F.linear(query_emb, model_parameters['fc']['weight']) 199 | else: 200 | query = data 201 | query_emb = data_emb 202 | 203 | # Prototypical Networks 204 | logits = [] 205 | current_support_num = 0 206 | current_query_num = 0 207 | assert data_emb.size()[:2] == data['mask'].size() 208 | assert query_emb.size()[:2] == query['mask'].size() 209 | 210 | if self.args.dataset_mode == 'BIO': 211 | get_proto = self.__get_proto_for_BIO__ 212 | else: 213 | get_proto = self.__get_proto__ 214 | 215 | for i, sent_support_num in enumerate(data['sentence_num']): 216 | sent_query_num = query['sentence_num'][i] 217 | # Calculate prototype for each class 218 | support_proto = get_proto( 219 | data_emb[current_support_num:current_support_num+sent_support_num], 220 | data['label'][current_support_num:current_support_num+sent_support_num], 221 | data['text_mask'][current_support_num: current_support_num+sent_support_num]) 222 | # calculate distance to each prototype 223 | logits.append(self.__batch_dist__( # e.g. logits[0]: [110, 6], distances between 110 query tokens and 6 prototypes 224 | 
support_proto, 225 | query_emb[current_query_num:current_query_num+sent_query_num], 226 | query['text_mask'][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num] 227 | current_query_num += sent_query_num 228 | current_support_num += sent_support_num 229 | logits = torch.cat(logits, 0) 230 | _, pred = torch.max(logits, 1) 231 | return logits, pred 232 | 233 | def cloned_fc_dict(self): 234 | return {key: val.clone() for key, val in self.fc.state_dict().items()} 235 | 236 | 237 | class NoOtherProto(utils.framework.FewShotNERModel): 238 | 239 | def __init__(self,word_encoder, dot=False, ignore_index=-1): 240 | utils.framework.FewShotNERModel.__init__(self, word_encoder, ignore_index=ignore_index) 241 | self.drop = nn.Dropout() 242 | self.dot = dot 243 | 244 | def __dist__(self, x, y, dim): 245 | if self.dot: 246 | return (x * y).sum(dim) 247 | else: 248 | return -(torch.pow(x - y, 2)).sum(dim) 249 | 250 | def __batch_dist__(self, S, Q, q_mask): 251 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 252 | assert Q.size()[:2] == q_mask.size() 253 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 254 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 255 | 256 | def __get_proto__(self, embedding, tag, mask): 257 | proto = [] 258 | embedding = embedding[mask==1].view(-1, embedding.size(-1)) 259 | tag = torch.cat(tag, 0) 260 | assert tag.size(0) == embedding.size(0) 261 | for label in range(torch.max(tag)+1): 262 | proto.append(torch.mean(embedding[tag==label], 0)) 263 | proto = torch.stack(proto) 264 | return proto 265 | 266 | def forward(self, support, query): 267 | ''' 268 | support: Inputs of the support set. 269 | query: Inputs of the query set. 
270 | N: Num of classes 271 | K: Num of instances for each class in the support set 272 | Q: Num of instances in the query set 273 | ''' 274 | support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768] 275 | query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 276 | support_emb = self.drop(support_emb) 277 | query_emb = self.drop(query_emb) 278 | 279 | # Prototypical Networks 280 | logits = [] 281 | current_support_num = 0 282 | current_query_num = 0 283 | assert support_emb.size()[:2] == support['mask'].size() 284 | assert query_emb.size()[:2] == query['mask'].size() 285 | 286 | for i, sent_support_num in enumerate(support['sentence_num']): 287 | sent_query_num = query['sentence_num'][i] 288 | # Calculate prototype for each class 289 | support_proto = self.__get_proto__( 290 | support_emb[current_support_num:current_support_num+sent_support_num], 291 | support['label'][current_support_num:current_support_num+sent_support_num], 292 | support['text_mask'][current_support_num: current_support_num+sent_support_num]) 293 | # calculate distance to each prototype 294 | logits.append(self.__batch_dist__( # e.g. logits[0]: [110, 6], distances between 110 query tokens and 6 prototypes 295 | support_proto, 296 | query_emb[current_query_num:current_query_num+sent_query_num], 297 | query['text_mask'][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num] 298 | current_query_num += sent_query_num 299 | current_support_num += sent_support_num 300 | logits = torch.cat(logits, 0) # [x, 6] 301 | logits = -logits[:, 1:]/torch.mean(logits[:,1:]) 302 | logits = torch.sigmoid(logits) # torch.sigmoid replaces the deprecated F.sigmoid 303 | _, pred = torch.max(logits, 1) 304 | return logits, pred 305 | 306 | 307 | class Proto_multiOclass(utils.framework.FewShotNERModel): 308 | 309 | def __init__(self, args, word_encoder, dot=False, ignore_index=-1): 310 | utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index) 311 | 
self.drop = nn.Dropout() 312 | self.dot = dot 313 | self.args = args 314 | if self.args.mlp is True: 315 | self.fc = nn.Linear(768, 256) 316 | 317 | def __dist__(self, x, y, dim): 318 | if self.dot: 319 | return (x * y).sum(dim) 320 | else: 321 | return -(torch.pow(x - y, 2)).sum(dim) 322 | 323 | def __batch_dist__(self, S, Q, q_mask): 324 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 325 | assert Q.size()[:2] == q_mask.size() 326 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 327 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 328 | 329 | def __get_proto_for_BIO__(self, embedding, tag, mask): 330 | proto = [] 331 | embedding = embedding[mask==1].view(-1, embedding.size(-1)) 332 | tag = torch.cat(tag, 0) 333 | assert tag.size(0) == embedding.size(0) 334 | for label in range(self.args.N*2+1): 335 | if embedding[tag==label].shape[0] == 0: 336 | proto.append(proto[0]) 337 | else: 338 | proto.append(torch.mean(embedding[tag==label], 0)) 339 | # print(torch.mean(embedding[tag==label], 0)) 340 | proto = torch.stack(proto) 341 | return proto 342 | 343 | def __get_proto__(self, embedding, tag, mask): 344 | proto = [] 345 | embedding = embedding[mask==1].view(-1, embedding.size(-1)) 346 | tag = torch.cat(tag, 0) 347 | assert tag.size(0) == embedding.size(0) 348 | for label in range(torch.max(tag)+1): 349 | proto.append(torch.mean(embedding[tag==label], 0)) 350 | proto = torch.stack(proto) 351 | return proto 352 | 353 | def forward(self, support, query): 354 | ''' 355 | support: Inputs of the support set. 356 | query: Inputs of the query set. 
357 | N: Num of classes 358 | K: Num of instances for each class in the support set 359 | Q: Num of instances in the query set 360 | ''' 361 | support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768] 362 | query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 363 | if (self.args.only_test is True) and (self.args.save_query_ebd is True): 364 | save_qe_path = '_'.join(['425', self.args.dataset, self.args.mode, 'proto', str(self.args.N), str(self.args.K), str(self.args.Q), str(int(round(time.time() * 1000)))]) 365 | if not os.path.exists(save_qe_path): 366 | os.mkdir(save_qe_path) 367 | f_write = open(os.path.join(save_qe_path, 'label2tag.txt'), 'w', encoding='utf-8') 368 | for ln in query['label2tag'][0]: 369 | f_write.write(query['label2tag'][0][ln] + '\n') 370 | f_write.flush() 371 | f_write.close() 372 | se = support_emb[support['text_mask']==1].view(-1, support_emb.size(-1)) 373 | qe = query_emb[query['text_mask']==1].view(-1, query_emb.size(-1)) 374 | sl = torch.cat(support['label'], 0) 375 | ql = torch.cat(query['label'], 0) 376 | for i in range(self.args.N+1): 377 | if i == 0: 378 | p = torch.mean(se[sl==i], dim=0, keepdim=True) 379 | else: 380 | p = torch.cat((p, torch.mean(se[sl==i], dim=0, keepdim=True)), dim=0) 381 | np.save(os.path.join(save_qe_path, str(i)+'.npy'), qe[ql==i].cpu().detach().numpy()) 382 | np.save(os.path.join(save_qe_path, 'proto.npy'), p.cpu().detach().numpy()) 383 | sys.exit() 384 | 385 | support_emb = self.drop(support_emb) 386 | query_emb = self.drop(query_emb) 387 | if self.args.mlp is True: 388 | support_emb = self.fc(support_emb) 389 | query_emb = self.fc(query_emb) 390 | 391 | # Prototypical Networks 392 | logits = [] 393 | current_support_num = 0 394 | current_query_num = 0 395 | assert support_emb.size()[:2] == support['mask'].size() 396 | assert query_emb.size()[:2] == query['mask'].size() 397 | 398 | if self.args.dataset_mode == 'BIO': 399 | 
get_proto = self.__get_proto_for_BIO__ 400 | else: 401 | get_proto = self.__get_proto__ 402 | 403 | for i, sent_support_num in enumerate(support['sentence_num']): 404 | sent_query_num = query['sentence_num'][i] 405 | # Calculate prototype for each class 406 | support_proto = get_proto( 407 | support_emb[current_support_num:current_support_num+sent_support_num], 408 | support['label'][current_support_num:current_support_num+sent_support_num], 409 | support['text_mask'][current_support_num: current_support_num+sent_support_num]) 410 | # calculate distance to each prototype 411 | logits.append(self.__batch_dist__( # e.g. logits[0]: [110, 6], distances between 110 query tokens and 6 prototypes 412 | support_proto, 413 | query_emb[current_query_num:current_query_num+sent_query_num], 414 | query['text_mask'][current_query_num: current_query_num+sent_query_num])) # [num_of_query_tokens, class_num] 415 | current_query_num += sent_query_num 416 | current_support_num += sent_support_num 417 | logits = torch.cat(logits, 0) 418 | _, pred = torch.max(logits, 1) 419 | return logits, pred 420 | 421 | -------------------------------------------------------------------------------- /model/relation_ner.py: -------------------------------------------------------------------------------- 1 | # unused import: from cProfile import label 2 | import sys 3 | sys.path.append('..') 4 | import utils 5 | import torch 6 | from torch import autograd, optim, nn 7 | from torch.autograd import Variable 8 | from torch.nn import functional as F 9 | 10 | class RelationNER(utils.framework.FewShotNERModel): 11 | 12 | def __init__(self, word_encoder, dot=False, args=None, ignore_index=-1): 13 | utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index) 14 | self.drop = nn.Dropout(p=args.dropout) 15 | self.dot = dot 16 | self.args = args 17 | self.fc = nn.Linear(768*2, 1) 18 | self.bn1 = nn.BatchNorm1d(768*2) 19 | if self.args.alpha == -1: 20 | self.param = nn.Parameter(torch.Tensor([0.5])) 21 | # self.bn2 = nn.BatchNorm1d(1) 
22 | # self.fc1 = nn.Linear(768, 512) 23 | # self.fc2 = nn.Linear(512, 256) 24 | # self.fc3 = nn.Linear(256, 128) 25 | # self.fc4 = nn.Linear(128, args.N+1) 26 | # self.loss = nn.CrossEntropyLoss() 27 | 28 | def __dist__(self, x, y, dim): 29 | if self.dot: 30 | return (x * y).sum(dim) 31 | else: 32 | return -(torch.pow(x - y, 2)).sum(dim) 33 | 34 | def __batch_dist__(self, S, Q, q_mask): 35 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 36 | assert Q.size()[:2] == q_mask.size() 37 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 38 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 39 | 40 | def __get_proto__(self, label_data_emb, mask): 41 | proto = [] 42 | assert label_data_emb.shape[0] == mask.shape[0] 43 | for i, l_ebd in enumerate(label_data_emb): # [10, 768] 44 | p = l_ebd[mask[i]].mean(0) 45 | proto.append(p.view(1,-1)) 46 | proto = torch.cat(proto,0) 47 | proto = proto.view((-1,self.args.N,768)) 48 | return proto 49 | 50 | def forward(self, data, label_data, model_parameters=None): 51 | ''' 52 | data : support or query 53 | support: Inputs of the support set. 54 | query: Inputs of the query set. 55 | N: Num of classes 56 | K: Num of instances for each class in the support set 57 | Q: Num of instances in the query set 58 | ''' 59 | 60 | # 1. 
get prototypes by label name 61 | label_data_emb = self.word_encoder(label_data['word'], label_data['mask']) # [num_label_sent, 10, 768] 62 | proto_list = self.__get_proto__(label_data_emb, label_data['text_mask']) # [batch_num, N, 768] 63 | 64 | data_emb = self.word_encoder(data['word'], data['mask']) # [num_sent, number_of_tokens, 768] 65 | data_emb = self.drop(data_emb) 66 | 67 | word_label = data['label'] 68 | temp_sen_num = [] 69 | temp_sent = [] # [x,768] 70 | temp_label = [] 71 | 72 | temp_count = 0 73 | temp_sen_num.append(0) 74 | for num in data['sentence_num']: 75 | temp_count += num 76 | temp_sen_num.append(temp_count) 77 | 78 | for i, line in enumerate(data_emb): 79 | temp_sent.append(line[data['text_mask'][i]==1]) 80 | # data_emb = torch.cat(temp_sent, 0) # [num_sent*number_of_tokens, N+1] 81 | temp_label_list = [] # [batch_num,x] 82 | temp_sent_list = [] # [batch_num,x,768] 83 | for i in range(len(temp_sen_num)-1): 84 | temp_sent_list.append(temp_sent[temp_sen_num[i]: temp_sen_num[i+1]]) 85 | temp_label_list.append(word_label[temp_sen_num[i]: temp_sen_num[i+1]]) 86 | 87 | sample_pairs = [] # n rows; each row is a tensor of shape [768*2] 88 | pair_labels = [] # 0/1 vector, e.g. [0,0,0,1,...,0] 89 | for samples, protos, labels in zip(temp_sent_list, proto_list, temp_label_list): 90 | for _sample, _label in zip(samples, labels): 91 | assert _sample.shape[0] == _label.shape[0] 92 | for sample, label in zip(_sample, _label): 93 | if label == -1: 94 | continue 95 | for i, proto in enumerate(protos): 96 | tmp_pair = torch.cat((sample, proto), dim=0) 97 | sample_pairs.append(tmp_pair.view(1,-1)) 98 | if label == i+1: 99 | pair_labels.append(1) 100 | else: 101 | pair_labels.append(0) 102 | emb = torch.cat(sample_pairs, dim=0) 103 | emb = self.bn1(emb) 104 | 105 | if model_parameters == None: 106 | # emb = self.fc1(data_emb) # [num_sent*number_of_tokens, N+1] 107 | # emb = F.relu(emb) 108 | # emb = self.fc2(emb) 109 | # emb = F.relu(emb) 110 | # emb = self.fc3(emb) 111 | # emb = F.relu(emb) 
112 | emb = self.fc(emb) 113 | logits = torch.sigmoid(emb) 114 | # _, pred = torch.max(logits, 1) 115 | else: 116 | # emb = F.linear(data_emb, model_parameters['fc1']['weight']) 117 | # emb = F.relu(emb) 118 | # emb = F.linear(emb, model_parameters['fc2']['weight']) 119 | # emb = F.relu(emb) 120 | # emb = F.linear(emb, model_parameters['fc3']['weight']) 121 | # emb = F.relu(emb) 122 | emb = F.linear(emb, model_parameters['fc']['weight']) 123 | logits = torch.sigmoid(emb) # [num*N, 1] 124 | # _, pred = torch.max(logits, 1) 125 | 126 | # get pred 127 | pred = [] 128 | tmp_logits = logits.view(-1, self.args.N) 129 | for tmp in tmp_logits: 130 | if any(t > self.args.alpha for t in tmp): 131 | pred.append(torch.max(tmp,dim=0)[1].item()+1) 132 | else: 133 | pred.append(0) 134 | # _, pred = torch.max(logits.view(-1, self.args.N), 1) 135 | 136 | return emb, pair_labels, pred # the loss applies sigmoid itself, so raw emb is returned 137 | 138 | def cloned_fc_dict(self): 139 | return {key: val.clone() for key, val in self.fc.state_dict().items()} 140 | 141 | def cloned_fc1_dict(self): 142 | return {key: val.clone() for key, val in self.fc1.state_dict().items()} 143 | 144 | def cloned_fc2_dict(self): 145 | return {key: val.clone() for key, val in self.fc2.state_dict().items()} 146 | 147 | def cloned_fc3_dict(self): 148 | return {key: val.clone() for key, val in self.fc3.state_dict().items()} 149 | 150 | def cloned_fc4_dict(self): 151 | return {key: val.clone() for key, val in self.fc4.state_dict().items()} 152 | 153 | 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /model/siamese.py: -------------------------------------------------------------------------------- 1 | # unused import: from cProfile import label 2 | import sys 3 | # unused import: from this import d 4 | sys.path.append('..') 5 | import utils 6 | import torch 7 | from torch import autograd, optim, nn 8 | from torch.autograd import Variable 9 | from torch.nn import functional as F 10 | 11 | class 
Siamese(utils.framework.FewShotNERModel): 12 | 13 | def __init__(self, word_encoder, dot=False, args=None, ignore_index=-1): 14 | utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index) 15 | self.drop = nn.Dropout(p=args.dropout) 16 | self.dot = dot 17 | self.args = args 18 | self.fc = nn.Linear(768, 512) 19 | # self.fc2 = nn.Linear(512, 128) 20 | # self.ln = nn.LayerNorm(768) 21 | # self.bn1 = nn.BatchNorm1d(768*2) 22 | # self.bn2 = nn.BatchNorm1d(1) 23 | # self.fc1 = nn.Linear(768, 512) 24 | # self.fc2 = nn.Linear(512, 256) 25 | # self.fc3 = nn.Linear(256, 128) 26 | # self.fc4 = nn.Linear(128, args.N+1) 27 | # self.loss = nn.CrossEntropyLoss() 28 | 29 | def __dist__(self, x, y, dim): 30 | if self.dot: 31 | return (x * y).sum(dim) 32 | else: 33 | return -(torch.pow(x - y, 2)).sum(dim) 34 | 35 | def __batch_dist__(self, S, Q, q_mask): 36 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 37 | assert Q.size()[:2] == q_mask.size() 38 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 39 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 40 | 41 | def __neg_dist__(self, instances, class_proto): # ins:[N*K, 256], cla:[N, 256] 42 | return -torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5) 43 | 44 | def __pos_dist__(self, instances, class_proto): # ins:[N*K, 256], cla:[N, 256] 45 | return torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5) 46 | 47 | def __get_proto__(self, label_data_emb, mask): 48 | proto = [] 49 | assert label_data_emb.shape[0] == mask.shape[0] 50 | for i, l_ebd in enumerate(label_data_emb): # [10, 768] 51 | p = l_ebd[mask[i]].mean(0) 52 | proto.append(p.view(1,-1)) 53 | proto = torch.cat(proto,0) 54 | proto = proto.view((-1,self.args.N,768)) 55 | return proto 56 | 57 | def __forward_once__(self, emb): 58 | emb = self.fc(emb) 59 | # emb = self.fc2(emb) 60 | return emb 61 | 62 | def 
__forward_once_with_param__(self, emb, param): 63 | emb = F.linear(emb, param['fc']['weight']) 64 | return emb 65 | 66 | def __get_sample_pairs__(self, data): 67 | data_1 = {} 68 | data_2 = {} 69 | data['word_emb'] = data['word_emb'][data['text_mask']==1] 70 | data['label'] = torch.cat(data['label'], dim=0) 71 | data_1['word_emb'] = data['word_emb'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 72 | data_1['label'] = data['label'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 73 | data_2['word_emb'] = data['word_emb'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 74 | data_2['label'] = data['label'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 75 | 76 | return data_1, data_2 77 | 78 | def forward(self, support, query, query_flag=False, margin=None, model_parameters=None): 79 | ''' 80 | data : support or query 81 | support: Inputs of the support set. 82 | query: Inputs of the query set. 83 | N: Num of classes 84 | K: Num of instances for each class in the support set 85 | Q: Num of instances in the query set 86 | ''' 87 | 88 | support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768] 89 | support_emb = self.drop(support_emb) 90 | support['word_emb'] = support_emb 91 | 92 | if query_flag is not True: 93 | 94 | # get sample pairs 95 | support_1, support_2 = self.__get_sample_pairs__(support) 96 | 97 | if model_parameters is None: 98 | out1 = self.__forward_once__(support_1['word_emb']) # [x, 768] -> [x, 256] 99 | out2 = self.__forward_once__(support_2['word_emb']) 100 | else: 101 | out1 = self.__forward_once_with_param__(support_1['word_emb'], model_parameters) 102 | out2 = self.__forward_once_with_param__(support_2['word_emb'], model_parameters) 103 | 104 | # calculate distance 105 | dis = self.__pos_dist__(out1, out2).view(-1) 106 | print('out1', out1) 107 | print('out2', out2) 108 | print('support12', support_1['word_emb'], support_2['word_emb']) 109 | print('dis', dis) 110 
| dis = F.layer_norm(dis, normalized_shape=[dis.shape[0]], bias=torch.full((dis.shape[0],),10.).cuda()) # in layer_norm, weight multiplies and bias adds; both must match dis's shape 111 | print('dis_after_ln', dis) 112 | pair_label = self.__generate_pair_label__(support_1['label'], support_2['label']) 113 | 114 | return dis, pair_label 115 | else: 116 | query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 117 | query_emb = self.drop(query_emb) 118 | query['word_emb'] = query_emb 119 | 120 | support_out = self.__forward_once__(support_emb)[support['text_mask']==1] 121 | query_out = self.__forward_once__(query_emb)[query['text_mask']==1] 122 | proto = [] 123 | assert support_out.shape[0] == support['label'].shape[0] 124 | for label in range(1, self.args.N+1): 125 | proto.append(torch.mean(support_out[support['label']==label], dim=0)) 126 | proto = torch.stack(proto) 127 | query_dis = self.__pos_dist__(query_out, proto).view(-1) 128 | query_dis = F.layer_norm(query_dis, normalized_shape=[query_dis.shape[0]], bias=torch.full((query_dis.shape[0],),10.).cuda()).view(-1, proto.shape[0]) 129 | query_pred = [] 130 | for tmp in query_dis: 131 | if any(t < margin for t in tmp): 132 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 133 | else: 134 | query_pred.append(0) 135 | return torch.Tensor(query_pred).cuda() 136 | 137 | 138 | 139 | def __generate_pair_label__(self, label1, label2): 140 | pair_label = [] 141 | for l1 in label1: 142 | for l2 in label2: 143 | if l1 == l2: 144 | pair_label.append(1.0) 145 | else: 146 | pair_label.append(0.0) 147 | return torch.Tensor(pair_label).cuda() 148 | 149 | def cloned_fc_dict(self): 150 | return {key: val.clone() for key, val in self.fc.state_dict().items()} 151 | 152 | def cloned_fc1_dict(self): 153 | return {key: val.clone() for key, val in self.fc1.state_dict().items()} 154 | 155 | def cloned_fc2_dict(self): 156 | return {key: val.clone() for key, val in self.fc2.state_dict().items()} 157 | 158 | def cloned_fc3_dict(self): 159 | 
return {key: val.clone() for key, val in self.fc3.state_dict().items()} 160 | 161 | def cloned_fc4_dict(self): 162 | return {key: val.clone() for key, val in self.fc4.state_dict().items()} 163 | 164 | 165 | 166 | 167 | class SiameseMAML(utils.framework.FewShotNERModel): 168 | 169 | def __init__(self, word_encoder, dot=False, args=None, ignore_index=-1): 170 | utils.framework.FewShotNERModel.__init__(self, args, word_encoder, ignore_index=ignore_index) 171 | self.drop = nn.Dropout(p=args.dropout) 172 | self.dot = dot 173 | self.args = args 174 | self.fc = nn.Linear(768, 512) 175 | if args.margin_num == -1: 176 | self.param = nn.Parameter(torch.Tensor([args.trainable_margin_init])) # 8.5 177 | 178 | def __dist__(self, x, y, dim): 179 | if self.dot: 180 | return (x * y).sum(dim) 181 | else: 182 | return -(torch.pow(x - y, 2)).sum(dim) 183 | 184 | def __batch_dist__(self, S, Q, q_mask): 185 | # S [class, embed_dim], Q [num_of_sent, num_of_tokens, embed_dim] 186 | assert Q.size()[:2] == q_mask.size() 187 | Q = Q[q_mask==1].view(-1, Q.size(-1)) # [num_of_all_text_tokens, embed_dim] 188 | return self.__dist__(S.unsqueeze(0), Q.unsqueeze(1), 2) 189 | 190 | def __neg_dist__(self, instances, class_proto): # ins:[N*K, 256], cla:[N, 256] 191 | return -torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5) 192 | 193 | def __pos_dist__(self, instances, class_proto): # ins:[N*K, 256], cla:[N, 256] 194 | return torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5) 195 | 196 | def __get_proto__(self, label_data_emb, mask): 197 | proto = [] 198 | assert label_data_emb.shape[0] == mask.shape[0] 199 | for i, l_ebd in enumerate(label_data_emb): # [10, 768] 200 | p = l_ebd[mask[i]].mean(0) 201 | proto.append(p.view(1,-1)) 202 | proto = torch.cat(proto,0) 203 | proto = proto.view((-1,self.args.N,768)) 204 | return proto 205 | 206 | def __forward_once__(self, emb): 207 | emb = self.fc(emb) 208 | # emb = self.fc2(emb) 209 | 
return emb 210 | 211 | def __forward_once_with_param__(self, emb, param): 212 | emb = F.linear(emb, param['fc']['weight']) 213 | return emb 214 | 215 | def __get_sample_pairs__(self, data): 216 | data_1 = {} 217 | data_2 = {} 218 | data['word_emb'] = data['word_emb'][data['text_mask']==1] 219 | data['label'] = torch.cat(data['label'], dim=0) 220 | data_1['word_emb'] = data['word_emb'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 221 | data_1['label'] = data['label'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 222 | data_2['word_emb'] = data['word_emb'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 223 | data_2['label'] = data['label'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 224 | 225 | return data_1, data_2 226 | 227 | def __generate_query_pair_label__(self, query, query_dis_output): 228 | query_label = torch.cat(query['label'], 0) 229 | query_label = query_label.cuda() 230 | assert query_label.shape[0] == query_dis_output.shape[0] 231 | query_dis_output = query_dis_output[query_label!=-1] 232 | query_label = query_label[query_label!=-1] 233 | return query_dis_output, query_label 234 | def forward(self, data1, data2, model_parameters=None): 235 | ''' 236 | data : support or query 237 | support: Inputs of the support set. 238 | query: Inputs of the query set. 
239 | N: Num of classes 240 | K: Num of instances for each class in the support set 241 | Q: Num of instances in the query set 242 | ''' 243 | 244 | if model_parameters is None: 245 | out1 = self.__forward_once__(data1['word_emb']) # [x, 768] -> [x, 256] 246 | out2 = self.__forward_once__(data2['word_emb']) 247 | else: 248 | out1 = self.__forward_once_with_param__(data1['word_emb'], model_parameters) 249 | out2 = self.__forward_once_with_param__(data2['word_emb'], model_parameters) 250 | 251 | # calculate distance 252 | dis = self.__pos_dist__(out1, out2).view(-1) 253 | dis = F.layer_norm(dis, normalized_shape=[dis.shape[0]], bias=torch.full((dis.shape[0],), self.args.ln_bias).cuda()) # in layer_norm, weight multiplies and bias adds; both must match dis's shape 254 | 255 | return dis 256 | 257 | 258 | ''' 259 | support_emb = self.word_encoder(support['word'], support['mask']) # [num_sent, number_of_tokens, 768] 260 | support_emb = self.drop(support_emb) 261 | support['word_emb'] = support_emb 262 | 263 | if query_flag is not True: 264 | 265 | # get sample pairs 266 | support_1, support_2 = self.__get_sample_pairs__(support) 267 | 268 | if model_parameters is None: 269 | out1 = self.__forward_once__(support_1['word_emb']) # [x, 768] -> [x, 256] 270 | out2 = self.__forward_once__(support_2['word_emb']) 271 | else: 272 | out1 = self.__forward_once_with_param__(support_1['word_emb'], model_parameters) 273 | out2 = self.__forward_once_with_param__(support_2['word_emb'], model_parameters) 274 | 275 | # calculate distance 276 | dis = self.__pos_dist__(out1, out2).view(-1) 277 | # print('out1', out1) 278 | # print('out2', out2) 279 | # print('support12', support_1['word_emb'], support_2['word_emb']) 280 | print('dis', dis) 281 | dis = F.layer_norm(dis, normalized_shape=[dis.shape[0]], bias=torch.full((dis.shape[0],), self.args.ln_bias).cuda()) # in layer_norm, weight multiplies and bias adds; both must match dis's shape 282 | print('dis_after_ln', dis) 283 | pair_label = self.__generate_pair_label__(support_1['label'], support_2['label']) 284 | 285 | return dis, 
pair_label 286 | else: 287 | query_emb = self.word_encoder(query['word'], query['mask']) # [num_sent, number_of_tokens, 768] 288 | query_emb = self.drop(query_emb) 289 | query_emb = query_emb[query['text_mask']==1] 290 | query['word_emb'] = query_emb 291 | 292 | support_out = self.__forward_once_with_param__(support_emb, model_parameters)[support['text_mask']==1] 293 | query_out = self.__forward_once_with_param__(query_emb, model_parameters)[query['text_mask']==1] 294 | proto = [] 295 | assert support_out.shape[0] == support['label'].shape[0] 296 | for label in range(1, self.args.N+1): 297 | proto.append(torch.mean(support_out[support['label']==label], dim=0)) 298 | proto = torch.stack(proto) 299 | query_dis = self.__pos_dist__(query_out, proto).view(-1) 300 | query_dis_output = F.layer_norm(query_dis, normalized_shape=[query_dis.shape[0]], bias=torch.full((query_dis.shape[0],), self.args.ln_bias).cuda()) 301 | query_dis = query_dis_output.view(-1, proto.shape[0]) 302 | query_pred = [] 303 | for tmp in query_dis: 304 | if any(t < margin for t in tmp): 305 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 306 | else: 307 | query_pred.append(0) 308 | query_dis_output, query_pair_label = self.__generate_query_pair_label__(query, query_dis_output) 309 | return torch.Tensor(query_pred).cuda(), query_dis_output, query_pair_label 310 | ''' 311 | 312 | def __generate_pair_label__(self, label1, label2): 313 | pair_label = [] 314 | for l1 in label1: 315 | for l2 in label2: 316 | if l1 == l2: 317 | pair_label.append(1.0) 318 | else: 319 | pair_label.append(0.0) 320 | return torch.Tensor(pair_label).cuda() 321 | 322 | def cloned_fc_dict(self): 323 | return {key: val.clone() for key, val in self.fc.state_dict().items()} 324 | 325 | def cloned_fc1_dict(self): 326 | return {key: val.clone() for key, val in self.fc1.state_dict().items()} 327 | 328 | def cloned_fc2_dict(self): 329 | return {key: val.clone() for key, val in self.fc2.state_dict().items()} 330 | 331 | def 
cloned_fc3_dict(self): 332 | return {key: val.clone() for key, val in self.fc3.state_dict().items()} 333 | 334 | def cloned_fc4_dict(self): 335 | return {key: val.clone() for key, val in self.fc4.state_dict().items()} 336 | 337 | 338 | -------------------------------------------------------------------------------- /pic/dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hccngu/MeTNet/a5112c600364e682eca45b278386b773ba56d5d6/pic/dataset.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk>=3.6.4 2 | numpy==1.21.0 3 | pandas==1.3.5 4 | torch==1.7.1 5 | transformers==4.0.1 6 | apex==0.9.10dev 7 | scikit_learn==0.24.1 8 | seqeval -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | startTime=`date +%Y%m%d-%H:%M` 4 | startTime_s=`date +%s` 5 | mode=inter # --mode is only read for fewnerd; define it so "$mode" below is not empty (matches the config.py default) 6 | dataset=fewcomm 7 | dataset_mode=IO 8 | N=5 9 | K=1 10 | python -u main.py --multi_margin --use_proto_as_neg --model MTNet --dataset $dataset --dataset_mode $dataset_mode --mode $mode --trainN $N --N $N --K $K --Q 1 --trainable_margin_init 6.5 11 | 12 | K=5 13 | python -u main.py --multi_margin --use_proto_as_neg --model MTNet --dataset $dataset --dataset_mode $dataset_mode --mode $mode --trainN $N --N $N --K $K --Q 1 --trainable_margin_init 7.3 14 | 15 | N=10 16 | K=1 17 | python -u main.py --multi_margin --use_proto_as_neg --model MTNet --dataset $dataset --dataset_mode $dataset_mode --mode $mode --trainN $N --N $N --K $K --Q 1 --trainable_margin_init 6.1 18 | 19 | N=10 20 | K=5 21 | python -u main.py --multi_margin --use_proto_as_neg --model MTNet --dataset $dataset --dataset_mode $dataset_mode --mode $mode --trainN $N --N $N --K $K --Q 1 --trainable_margin_init 6.4 22 | 
23 | endTime=`date +%Y%m%d-%H:%M` 24 | endTime_s=`date +%s` 25 | sumTime=$(( $endTime_s - $startTime_s )) 26 | echo "$startTime ---> $endTime" "Total:$sumTime seconds" 27 | -------------------------------------------------------------------------------- /transformer_model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hccngu/MeTNet/a5112c600364e682eca45b278386b773ba56d5d6/transformer_model/.DS_Store -------------------------------------------------------------------------------- /transformer_model/tip.md: -------------------------------------------------------------------------------- 1 | Download pretrained BERT files [bert-base-chinese](https://huggingface.co/bert-base-chinese/tree/main) and [bert-base-uncased](https://huggingface.co/bert-base-uncased/tree/main) and put them here. -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hccngu/MeTNet/a5112c600364e682eca45b278386b773ba56d5d6/utils/.DS_Store -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import random 4 | 5 | import torch 6 | 7 | def config(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dataset', default='fewnerd', help='[fewcomm, fewnerd]') 10 | parser.add_argument('--dataset_mode', default='IO', help='[BIO, IO] only for fewcomm') 11 | parser.add_argument('--mode', default='inter', help='training mode [inter, intra, supervised]; only used for fewnerd') 12 | parser.add_argument('--trainN', default=5, type=int, help='N in train') 13 | parser.add_argument('--N', default=5, type=int, help='N way') 14 | parser.add_argument('--K', default=1, type=int, help='K shot') 15 | 
parser.add_argument('--Q', default=1, type=int, help='Num of queries per class') 16 | parser.add_argument('--batch_size', default=1, type=int, help='batch size') 17 | parser.add_argument('--train_iter', default=6000, type=int, help='num of iters in training, default=6000') 18 | parser.add_argument('--val_iter', default=100, type=int, help='num of iters in validation') 19 | parser.add_argument('--test_iter', default=500, type=int, help='num of iters in testing, default=500') 20 | parser.add_argument('--val_step', default=200, type=int, help='run validation every x training iters, default=200') 21 | parser.add_argument('--model', default='proto_multiOclass', help='model name [proto, nnshot, structshot, MTNet, MAML, relation_ner, Siamese, SiameseMAML, proto_multiOclass]') 22 | parser.add_argument('--max_length', default=100, type=int, help='max length') 23 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate for proto, nnshot, structshot') 24 | parser.add_argument('--grad_iter', default=1, type=int, help='accumulate gradient every x iterations') 25 | # checkpoint/proto-fewnerd-inter-IO-5-1-seed0-1649991137204.pth.tar 26 | # checkpoint/MTNet-fewnerd-inter-IO-5-1-seed0-1649354243516.pth.tar 27 | parser.add_argument('--load_ckpt', default=None, help='load ckpt') 28 | parser.add_argument('--save_ckpt', default=None, help='save ckpt') 29 | parser.add_argument('--fp16', action='store_true', help='use nvidia apex fp16') 30 | parser.add_argument('--only_test', action='store_true', help='only test') 31 | parser.add_argument('--ckpt_name', type=str, default='', help='checkpoint name.') 32 | parser.add_argument('--seed', type=int, default=0, help='random seed') 33 | parser.add_argument('--ignore_index', type=int, default=-1, help='label index to ignore when calculating loss and metrics') 34 | parser.add_argument('--use_sampled_data', action='store_true', help='use released sampled data, the data should be stored at "data/episode-data/" ') 35 | # experiment 36 
| parser.add_argument('--use_sgd_for_bert', action='store_true', help='use SGD instead of AdamW for BERT.') 37 | # only for bert / roberta 38 | parser.add_argument('--pretrain_ckpt', default=None, help='bert / roberta pre-trained checkpoint') 39 | # for print inference 40 | parser.add_argument('--save_test_inference', default='none', help='where to save test inference results; "none" disables saving') 41 | 42 | 43 | # only for prototypical networks 44 | parser.add_argument('--dot', action='store_true', help='use dot instead of L2 distance for proto') 45 | parser.add_argument('--mlp', action='store_true', help='use a mlp for proto') 46 | # only for structshot 47 | parser.add_argument('--tau', default=0.05, type=float, help='StructShot parameter that re-normalizes the transition probabilities') 48 | # RelationNER 49 | parser.add_argument("--use_class_weights", type=bool, default=True, help="use class weights for MAML, SiameseMAML") 50 | parser.add_argument("--cs", type=float, default=100, help="class weight hyper-param") 51 | parser.add_argument("--alpha", type=float, default=0.5, help="predict O when all predictions are smaller than this value") 52 | # Siamese 53 | parser.add_argument("--margin_num", type=int, default=-1, help="control margin (*N*K) by the number of distances; -1 to disable") 54 | parser.add_argument("--margin", type=float, default=-1, help="fixed distance margin; -1 to disable") 55 | parser.add_argument('--only_use_test', action='store_true', help='evaluate on the test set only.') 56 | # parser.add_argument('--alpha_is_trainable', action='store_true', help='use trainable alpha or not.') 57 | parser.add_argument("--trainable_alpha_init", type=float, default=1.0, help="initial value of the trainable alpha, default=1.0") 58 | 59 | 60 | # MTNet 61 | parser.add_argument("--bert_lr", type=float, default=2e-5, help="learning rate of bert") 62 | parser.add_argument("--meta_lr", type=float, default=5e-4, help="learning rate of the meta (outer) loop") 63 | parser.add_argument("--task_lr", type=float, default=1e-1, help="learning rate of 
the task (inner) loop") 64 | parser.add_argument("--train_support_iter", type=int, default=3, help="number of inner-loop training iterations on the support set") 65 | parser.add_argument("--neg_num", type=int, default=1, help="number of negative samples") 66 | parser.add_argument("--use_proto_as_neg", action='store_true', help="use proto as neg") 67 | parser.add_argument("--use_diff_threshold", action='store_true', help="use different thresholds") 68 | parser.add_argument('--threshold_mode', default='max', help='mean or max') 69 | parser.add_argument("--ln_bias", type=float, default=10.0, help="the bias of layer norm for SiameseMAML and MTNet") 70 | parser.add_argument("--trainable_margin_init", type=float, default=6.0, help="initial value of the trainable margin, default=6.0") 71 | parser.add_argument("--dropout", type=float, default=0.0, help="dropout") 72 | parser.add_argument("--bert_wd", type=float, default=1e-5, help="bert weight decay") 73 | parser.add_argument("--wobert_wd", type=float, default=1e-5, help="weight decay for parameters outside bert") 74 | parser.add_argument('--multi_margin', action='store_true', help='use multiple adaptive margins') 75 | 76 | # Ablation study 77 | parser.add_argument('--label_name_mode', default='LnAsQKV', help='[mean, LnAsQ, LnAsQKV]') 78 | parser.add_argument('--tripletloss_mode', default='sig+dp+dn', help='[tl, tl+dp, sig+dp+dn]') 79 | parser.add_argument('--have_otherO', action='store_true', help='[for ablation study of MTNet]') 80 | 81 | # for visualization 82 | parser.add_argument('--save_query_ebd', action='store_true', help='save query embeddings') 83 | parser.add_argument('--load_ckpt_proto', default=None, help='load ckpt for proto') 84 | parser.add_argument('--load_ckpt_metnet', default=None, help='load ckpt for MeTNet') 85 | 86 | 87 | opt = parser.parse_args() 88 | 89 | return opt 90 | 91 | 92 | def setSeed(seed): 93 | torch.manual_seed(seed) 94 | torch.cuda.manual_seed_all(seed) 95 | np.random.seed(seed) 96 | random.seed(seed) 97 | torch.backends.cudnn.deterministic = True 98 | 99 | 100 | 
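To make the effect of `setSeed` concrete, here is a minimal, framework-free sketch: once every RNG source is seeded identically, episode sampling is reproducible run-to-run. Only the stdlib `random` module is shown; the real `setSeed` additionally seeds NumPy, torch, and CUDA the same way. The helper name `sample_episode_classes` is illustrative and not part of this codebase.

```python
import random

# Hypothetical helper (not in this repo): draw the N target classes of one
# episode from a locally-seeded RNG, mirroring what FewshotSampler does with
# the module-level `random` after setSeed has run.
def sample_episode_classes(classes, n, seed):
    rng = random.Random(seed)  # local RNG keeps the sketch side-effect free
    return rng.sample(classes, n)

# identical seeds -> identical episodes
a = sample_episode_classes(list(range(92)), 5, seed=0)
b = sample_episode_classes(list(range(92)), 5, seed=0)
assert a == b
```

With 92 FEW-COMM entity types and N=5, any fixed seed pins down the same 5-way episode on every run, which is what makes the reported results repeatable.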
def print_args(args): 101 | """ 102 | Print arguments (only show the relevant arguments) 103 | """ 104 | 105 | print("\nParameters:") 106 | for attr, value in sorted(args.__dict__.items()): 107 | print("\t{}={}".format(attr.upper(), value)) 108 | print(""" 109 | MTNet -> Go!Go!Go! 110 | """) 111 | 112 | return -------------------------------------------------------------------------------- /utils/contrastiveloss.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch 3 | 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | # Custom contrastive loss 9 | class ContrastiveLoss(torch.nn.Module): 10 | """ 11 | Contrastive loss function. 12 | Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 13 | """ 14 | 15 | def __init__(self, args): 16 | super(ContrastiveLoss, self).__init__() 17 | self.args = args 18 | 19 | def forward(self, dis, label, margin): 20 | 21 | tmp1 = label * torch.pow(dis, 2).squeeze(-1) 22 | # mean_val = torch.mean(euclidean_distance) 23 | tmp2 = (1 - label) * torch.pow(torch.clamp(margin - dis, min=0.0), 24 | 2).squeeze(-1) 25 | loss_contrastive = torch.mean(tmp1 + tmp2) 26 | 27 | return loss_contrastive 28 | 29 | 30 | -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | from .fewshotsampler import FewshotSampler, FewshotSampleBase 5 | import numpy as np 6 | import json 7 | 8 | def get_class_name(rawtag): 9 | # get (finegrained) class name 10 | if rawtag.startswith('B-') or rawtag.startswith('I-'): 11 | return rawtag[2:] 12 | else: 13 | return rawtag 14 | 15 | class Sample(FewshotSampleBase): 16 | def __init__(self, filelines): 17 | filelines = [line.split('\t') for line in filelines] 
18 | filelines_new = [] 19 | for l in filelines: 20 | if len(l) == 2: 21 | filelines_new.append(l) 22 | self.words, self.tags = zip(*filelines_new) 23 | self.words = [word.lower() for word in self.words] 24 | # strip B-, I- 25 | self.normalized_tags = list(map(get_class_name, self.tags)) 26 | self.class_count = {} 27 | 28 | def __count_entities__(self): 29 | current_tag = self.normalized_tags[0] 30 | for tag in self.normalized_tags[1:]: 31 | if tag == current_tag: 32 | continue 33 | else: 34 | if current_tag != 'O': 35 | if current_tag in self.class_count: 36 | self.class_count[current_tag] += 1 37 | else: 38 | self.class_count[current_tag] = 1 39 | current_tag = tag 40 | if current_tag != 'O': 41 | if current_tag in self.class_count: 42 | self.class_count[current_tag] += 1 43 | else: 44 | self.class_count[current_tag] = 1 45 | 46 | def get_class_count(self): 47 | if self.class_count: 48 | return self.class_count 49 | else: 50 | self.__count_entities__() 51 | return self.class_count 52 | 53 | def get_tag_class(self): 54 | # strip 'B' 'I' 55 | tag_class = list(set(self.normalized_tags)) 56 | if 'O' in tag_class: 57 | tag_class.remove('O') 58 | return tag_class # tag_class: distinct non-'O' class names, e.g. ['organization-education', ...] 
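The span counting done by `__count_entities__` above can be restated compactly: under the IO scheme, each maximal run of identical non-`'O'` normalized tags counts as one entity. A minimal, stand-alone sketch of that logic (the function name is illustrative; the real implementation is the method above):

```python
# Illustrative re-statement of Sample.__count_entities__: count maximal runs
# of identical non-'O' tags. A trailing 'O' sentinel flushes the final run.
def count_entity_spans(tags):
    counts, prev = {}, 'O'
    for tag in list(tags) + ['O']:
        if tag != prev and prev != 'O':
            counts[prev] = counts.get(prev, 0) + 1  # a run just ended
        prev = tag
    return counts
```

Like the original, this counts two adjacent same-type entities as one span, which is an inherent limitation of the IO tagging scheme.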
59 | 60 | def valid(self, target_classes): 61 | return (set(self.get_class_count().keys()).intersection(set(target_classes))) and not (set(self.get_class_count().keys()).difference(set(target_classes))) 62 | 63 | def __str__(self): 64 | newlines = zip(self.words, self.tags) 65 | return '\n'.join(['\t'.join(line) for line in newlines]) 66 | 67 | class FewShotNERDatasetWithRandomSampling(data.Dataset): 68 | """ 69 | Fewshot NER Dataset 70 | """ 71 | def __init__(self, filepath, tokenizer, N, K, Q, max_length, ignore_label_id=-1): 72 | if not os.path.exists(filepath): 73 | print("[ERROR] Data file does not exist!") 74 | assert(0) 75 | self.class2sampleid = {} 76 | self.N = N 77 | self.K = K 78 | self.Q = Q 79 | self.tokenizer = tokenizer 80 | self.samples, self.classes = self.__load_data_from_file__(filepath) 81 | self.max_length = max_length 82 | self.sampler = FewshotSampler(N, K, Q, self.samples, classes=self.classes) 83 | self.ignore_label_id = ignore_label_id 84 | 85 | def __insert_sample__(self, index, sample_classes): 86 | ''' 87 | get dict -> { 88 | 'label name(organization-education)':[0, 1, ...] 89 | ... 
90 | } 91 | ''' 92 | for item in sample_classes: 93 | if item in self.class2sampleid: 94 | self.class2sampleid[item].append(index) 95 | else: 96 | self.class2sampleid[item] = [index] 97 | 98 | def __load_data_from_file__(self, filepath): 99 | samples = [] 100 | classes = [] 101 | with open(filepath, 'r', encoding='utf-8') as f: 102 | lines = f.readlines() 103 | samplelines = [] 104 | index = 0 105 | for line in lines: 106 | line = line.strip() 107 | if line: 108 | samplelines.append(line) 109 | else: 110 | sample = Sample(samplelines) 111 | samples.append(sample) 112 | sample_classes = sample.get_tag_class() 113 | self.__insert_sample__(index, sample_classes) 114 | classes += sample_classes # classes: all non-'O' tags in order of appearance (with duplicates) 115 | samplelines = [] 116 | index += 1 117 | if samplelines: 118 | sample = Sample(samplelines) 119 | samples.append(sample) 120 | sample_classes = sample.get_tag_class() 121 | self.__insert_sample__(index, sample_classes) 122 | classes += sample_classes 123 | samplelines = [] 124 | index += 1 125 | classes = list(set(classes)) 126 | return samples, classes # samples holds Sample instances; classes holds the label names 127 | 128 | def __get_token_label_list__(self, sample): 129 | tokens = [] 130 | labels = [] 131 | for word, tag in zip(sample.words, sample.normalized_tags): 132 | word_tokens = self.tokenizer.tokenize(word) # word is a single token from the sentence 133 | if word_tokens: 134 | tokens.extend(word_tokens) 135 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 136 | word_labels = [self.tag2label[tag]] + [self.ignore_label_id] * (len(word_tokens) - 1) 137 | labels.extend(word_labels) 138 | return tokens, labels 139 | 140 | 141 | def __getraw__(self, tokens, labels): 142 | # get tokenized word list, attention mask, text mask (mask [CLS], [SEP] as well), tags 143 | 144 | # split into chunks of length (max_length-2) 145 | # 2 is for special tokens [CLS] and [SEP] 146 | tokens_list = [] 147 | labels_list = [] 148 | while 
len(tokens) > self.max_length - 2: 149 | tokens_list.append(tokens[:self.max_length-2]) 150 | tokens = tokens[self.max_length-2:] 151 | labels_list.append(labels[:self.max_length-2]) 152 | labels = labels[self.max_length-2:] 153 | if tokens: 154 | tokens_list.append(tokens) 155 | labels_list.append(labels) 156 | 157 | # add special tokens and get masks 158 | indexed_tokens_list = [] 159 | mask_list = [] 160 | text_mask_list = [] 161 | for i, tokens in enumerate(tokens_list): 162 | # token -> ids 163 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 164 | indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens) 165 | 166 | # padding 167 | while len(indexed_tokens) < self.max_length: 168 | indexed_tokens.append(0) 169 | indexed_tokens_list.append(indexed_tokens) 170 | 171 | # mask 172 | mask = np.zeros((self.max_length), dtype=np.int32) 173 | mask[:len(tokens)] = 1 174 | mask_list.append(mask) 175 | 176 | # text mask, also mask [CLS] and [SEP] 177 | text_mask = np.zeros((self.max_length), dtype=np.int32) 178 | text_mask[1:len(tokens)-1] = 1 179 | text_mask_list.append(text_mask) 180 | 181 | assert len(labels_list[i]) == len(tokens) - 2, (labels_list[i], tokens) 182 | return indexed_tokens_list, mask_list, text_mask_list, labels_list 183 | 184 | def __additem__(self, index, d, word, mask, text_mask, label): 185 | d['index'].append(index) 186 | d['word'] += word 187 | d['mask'] += mask 188 | d['label'] += label 189 | d['text_mask'] += text_mask 190 | 191 | def __populate__(self, idx_list, savelabeldic=False): 192 | ''' 193 | populate samples into data dict 194 | set savelabeldic=True if you want to save label2tag dict 195 | 'index': sample_index 196 | 'word': tokenized word ids 197 | 'mask': attention mask in BERT 198 | 'label': NER labels 199 | 'sentence_num': number of sentences in this set (a batch contains multiple sets) 200 | 'text_mask': 0 for special tokens and paddings, 1 for real text 201 | ''' 202 | dataset = {'index':[], 'word': [], 'mask': [], 'label':[], 
'sentence_num':[], 'text_mask':[] } 203 | for idx in idx_list: 204 | tokens, labels = self.__get_token_label_list__(self.samples[idx]) 205 | word, mask, text_mask, label = self.__getraw__(tokens, labels) 206 | word = torch.tensor(word).long() 207 | mask = torch.tensor(np.array(mask)).long() 208 | text_mask = torch.tensor(np.array(text_mask)).long() 209 | self.__additem__(idx, dataset, word, mask, text_mask, label) 210 | dataset['sentence_num'] = [len(dataset['word'])] 211 | if savelabeldic: 212 | dataset['label2tag'] = [self.label2tag] 213 | return dataset 214 | 215 | def __getitem__(self, index): 216 | target_classes, support_idx, query_idx = self.sampler.__next__() 217 | # add 'O' and make sure 'O' is labeled 0 218 | distinct_tags = ['O'] + target_classes 219 | self.tag2label = {tag:idx for idx, tag in enumerate(distinct_tags)} 220 | self.label2tag = {idx:tag for idx, tag in enumerate(distinct_tags)} 221 | support_set = self.__populate__(support_idx) 222 | query_set = self.__populate__(query_idx, savelabeldic=True) 223 | return support_set, query_set 224 | 225 | def __len__(self): 226 | return 100000 227 | 228 | 229 | class FewShotNERDataset(FewShotNERDatasetWithRandomSampling): 230 | def __init__(self, filepath, tokenizer, max_length, ignore_label_id=-1): 231 | if not os.path.exists(filepath): 232 | print("[ERROR] Data file does not exist!") 233 | assert(0) 234 | self.class2sampleid = {} 235 | self.tokenizer = tokenizer 236 | self.samples = self.__load_data_from_file__(filepath) 237 | self.max_length = max_length 238 | self.ignore_label_id = ignore_label_id 239 | 240 | def __load_data_from_file__(self, filepath): 241 | with open(filepath)as f: 242 | lines = f.readlines() 243 | for i in range(len(lines)): 244 | lines[i] = json.loads(lines[i].strip()) 245 | return lines 246 | 247 | def __additem__(self, d, word, mask, text_mask, label): 248 | d['word'] += word 249 | d['mask'] += mask 250 | d['label'] += label 251 | d['text_mask'] += text_mask 252 | 253 | def 
__get_token_label_list__(self, words, tags): 254 | tokens = [] 255 | labels = [] 256 | for word, tag in zip(words, tags): 257 | word_tokens = self.tokenizer.tokenize(word) 258 | if word_tokens: 259 | tokens.extend(word_tokens) 260 | # Use the real label id for the first token of the word, and padding ids for the remaining tokens 261 | word_labels = [self.tag2label[tag]] + [self.ignore_label_id] * (len(word_tokens) - 1) 262 | labels.extend(word_labels) 263 | return tokens, labels 264 | 265 | def __populate__(self, data, savelabeldic=False): 266 | ''' 267 | populate samples into data dict 268 | set savelabeldic=True if you want to save label2tag dict 269 | 'word': tokenized word ids 270 | 'mask': attention mask in BERT 271 | 'label': NER labels 272 | 'sentence_num': number of sentences in this set (a batch contains multiple sets) 273 | 'text_mask': 0 for special tokens and paddings, 1 for real text 274 | ''' 275 | dataset = {'word': [], 'mask': [], 'label':[], 'sentence_num':[], 'text_mask':[] } 276 | for i in range(len(data['word'])): 277 | tokens, labels = self.__get_token_label_list__(data['word'][i], data['label'][i]) 278 | word, mask, text_mask, label = self.__getraw__(tokens, labels) 279 | word = torch.tensor(word).long() 280 | mask = torch.tensor(mask).long() 281 | text_mask = torch.tensor(text_mask).long() 282 | self.__additem__(dataset, word, mask, text_mask, label) 283 | dataset['sentence_num'] = [len(dataset['word'])] 284 | if savelabeldic: 285 | dataset['label2tag'] = [self.label2tag] 286 | return dataset 287 | 288 | def __getitem__(self, index): 289 | sample = self.samples[index] 290 | target_classes = sample['types'] 291 | support = sample['support'] 292 | query = sample['query'] 293 | # add 'O' and make sure 'O' is labeled 0 294 | distinct_tags = ['O'] + target_classes 295 | self.tag2label = {tag:idx for idx, tag in enumerate(distinct_tags)} 296 | self.label2tag = {idx:tag for idx, tag in enumerate(distinct_tags)} 297 | support_set = 
self.__populate__(support) 298 | query_set = self.__populate__(query, savelabeldic=True) 299 | return support_set, query_set 300 | 301 | def __len__(self): 302 | return len(self.samples) 303 | 304 | 305 | def collate_fn(data): 306 | batch_support = {'word': [], 'mask': [], 'label':[], 'sentence_num':[], 'text_mask':[]} 307 | batch_query = {'word': [], 'mask': [], 'label':[], 'sentence_num':[], 'label2tag':[], 'text_mask':[]} 308 | support_sets, query_sets = zip(*data) 309 | for i in range(len(support_sets)): 310 | for k in batch_support: 311 | batch_support[k] += support_sets[i][k] 312 | for k in batch_query: 313 | batch_query[k] += query_sets[i][k] 314 | for k in batch_support: 315 | if k != 'label' and k != 'sentence_num': 316 | batch_support[k] = torch.stack(batch_support[k], 0) 317 | for k in batch_query: 318 | if k !='label' and k != 'sentence_num' and k!= 'label2tag': 319 | batch_query[k] = torch.stack(batch_query[k], 0) 320 | batch_support['label'] = [torch.tensor(tag_list).long() for tag_list in batch_support['label']] 321 | batch_query['label'] = [torch.tensor(tag_list).long() for tag_list in batch_query['label']] 322 | return batch_support, batch_query 323 | 324 | def get_loader(filepath, tokenizer, N, K, Q, batch_size, max_length, 325 | num_workers=8, collate_fn=collate_fn, ignore_index=-1, use_sampled_data=True): 326 | if not use_sampled_data: 327 | dataset = FewShotNERDatasetWithRandomSampling(filepath, tokenizer, N, K, Q, max_length, ignore_label_id=ignore_index) 328 | else: 329 | dataset = FewShotNERDataset(filepath, tokenizer, max_length, ignore_label_id=ignore_index) 330 | data_loader = data.DataLoader(dataset=dataset, 331 | batch_size=batch_size, 332 | shuffle=True, 333 | pin_memory=True, 334 | num_workers=num_workers, 335 | collate_fn=collate_fn) 336 | return data_loader -------------------------------------------------------------------------------- /utils/fewshotsampler.py: 
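The greedy sampler defined in this file accepts a candidate sentence into a support or query set only if (1) it contains no entity class outside the episode's target classes, (2) it would not push any class past a budget of 2*K entities, and (3) it adds at least one entity to a still under-filled class. A minimal, pure-Python sketch of that check follows; the function name and the separate `k` argument are illustrative (the real logic lives in `FewshotSampler.__valid_sample__`, where `k` is stored inside the counts dict):

```python
# Hypothetical stand-alone version of FewshotSampler.__valid_sample__:
# sample_counts - {class: entity count} for the candidate sentence
# set_counts    - {class: entity count} accumulated in the set so far
def valid_for_set(sample_counts, set_counts, target_classes, k):
    if not sample_counts:
        return False  # sentence has no entities at all
    for c, n in sample_counts.items():
        if c not in target_classes:
            return False  # out-of-episode class -> reject
        if n + set_counts.get(c, 0) > 2 * k:
            return False  # would overshoot the 2*K per-class budget
    # accept only if it helps fill an under-represented class
    return any(set_counts.get(c, 0) < k for c in sample_counts)
```

For example, with k=1 a sentence containing one 'color' entity is valid for an empty {'color', 'material'} episode, while a sentence containing a 'brand' entity is rejected outright.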
-------------------------------------------------------------------------------- 1 | import random 2 | class FewshotSampleBase: 3 | ''' 4 | Abstract Class 5 | DO NOT USE 6 | Build your own Sample class and inherit from this class 7 | ''' 8 | def __init__(self): 9 | self.class_count = {} 10 | 11 | def get_class_count(self): 12 | ''' 13 | return a dictionary of {class_name:count} in format {any : int} 14 | ''' 15 | return self.class_count 16 | 17 | 18 | class FewshotSampler: 19 | ''' 20 | sample one support set and one query set 21 | ''' 22 | def __init__(self, N, K, Q, samples, classes=None, random_state=0): 23 | ''' 24 | N: int, how many types in each set 25 | K: int, how many instances for each type in support set 26 | Q: int, how many instances for each type in query set 27 | samples: List[Sample], Sample class must have `get_class_count` attribute 28 | classes[Optional]: List[any], all unique classes in samples. If not given, the classes will be derived from samples.get_class_count() 29 | random_state[Optional]: int, the random seed 30 | ''' 31 | self.K = K 32 | self.N = N 33 | self.Q = Q 34 | self.samples = samples 35 | self.__check__() # check that samples have the correct type (i.e. that get_class_count exists) 36 | if classes: 37 | self.classes = classes 38 | else: 39 | self.classes = self.__get_all_classes__() 40 | random.seed(random_state) 41 | 42 | def __get_all_classes__(self): 43 | classes = [] 44 | for sample in self.samples: 45 | classes += list(sample.get_class_count().keys()) 46 | return list(set(classes)) 47 | 48 | def __check__(self): 49 | for idx, sample in enumerate(self.samples): 50 | if not hasattr(sample, 'get_class_count'): 51 | print(f'[ERROR] samples in self.samples expected to have `get_class_count` attribute, but self.samples[{idx}] does not') 52 | raise ValueError 53 | 54 | def __additem__(self, index, set_class): 55 | class_count = self.samples[index].get_class_count() 56 | for class_name in class_count: 57 | if class_name in set_class: 58 | 
set_class[class_name] += class_count[class_name] 59 | else: 60 | set_class[class_name] = class_count[class_name] 61 | 62 | def __valid_sample__(self, sample, set_class, target_classes): 63 | threshold = 2 * set_class['k'] 64 | class_count = sample.get_class_count() 65 | if not class_count: 66 | return False 67 | isvalid = False 68 | for class_name in class_count: 69 | if class_name not in target_classes: 70 | return False 71 | if class_count[class_name] + set_class.get(class_name, 0) > threshold: 72 | return False 73 | if set_class.get(class_name, 0) < set_class['k']: 74 | isvalid = True 75 | return isvalid 76 | 77 | def __finish__(self, set_class): 78 | if len(set_class) < self.N+1: 79 | return False 80 | for k in set_class: 81 | if set_class[k] < set_class['k']: 82 | return False 83 | return True 84 | 85 | def __get_candidates__(self, target_classes): 86 | return [idx for idx, sample in enumerate(self.samples) if sample.valid(target_classes)] 87 | 88 | def __next__(self): 89 | ''' 90 | randomly sample one support set and one query set 91 | return: 92 | target_classes: List[any] 93 | support_idx: List[int], indices of support-set samples in the samples list 94 | query_idx: List[int], indices of query-set samples in the samples list 95 | ''' 96 | support_class = {'k':self.K} 97 | support_idx = [] 98 | query_class = {'k':self.Q} 99 | query_idx = [] 100 | target_classes = random.sample(self.classes, self.N) 101 | candidates = self.__get_candidates__(target_classes) 102 | while not candidates: 103 | target_classes = random.sample(self.classes, self.N) 104 | candidates = self.__get_candidates__(target_classes) 105 | 106 | # greedy search for support set 107 | while not self.__finish__(support_class): 108 | index = random.choice(candidates) 109 | if index not in support_idx: 110 | if self.__valid_sample__(self.samples[index], support_class, target_classes): 111 | self.__additem__(index, support_class) 112 | support_idx.append(index) 113 | # same for query set 114 | while not 
self.__finish__(query_class): 115 | index = random.choice(candidates) 116 | if index not in query_idx and index not in support_idx: 117 | if self.__valid_sample__(self.samples[index], query_class, target_classes): 118 | self.__additem__(index, query_class) 119 | query_idx.append(index) 120 | return target_classes, support_idx, query_idx 121 | 122 | def __iter__(self): 123 | return self 124 | 125 | 126 | -------------------------------------------------------------------------------- /utils/framework_mtnet.py: -------------------------------------------------------------------------------- 1 | from cmath import nan 2 | from copy import deepcopy 3 | import os 4 | import sklearn.metrics 5 | import numpy as np 6 | import pandas as pd 7 | import sys 8 | import time 9 | from collections import OrderedDict 10 | from . import word_encoder 11 | from . import data_loader 12 | import torch 13 | from torch import autograd, optim, nn, threshold 14 | from torch.autograd import Variable 15 | from torch.nn import functional as F 16 | # from pytorch_pretrained_bert import BertAdam 17 | from transformers import AdamW, get_linear_schedule_with_warmup 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | 20 | from .viterbi import ViterbiDecoder 21 | from utils.tripletloss import TripletLoss 22 | 23 | 24 | class FewShotNERFramework_MTNet: 25 | 26 | def __init__(self, train_data_loader, val_data_loader, test_data_loader, args, tokenizer, use_sampled_data=False): 27 | ''' 28 | train_data_loader: DataLoader for training. 29 | val_data_loader: DataLoader for validating. 30 | test_data_loader: DataLoader for testing. 
31 | ''' 32 | self.train_data_loader = train_data_loader 33 | self.val_data_loader = val_data_loader 34 | self.test_data_loader = test_data_loader 35 | self.args = args 36 | self.tokenizer = tokenizer 37 | self.use_sampled_data = use_sampled_data 38 | 39 | def __load_model__(self, ckpt): 40 | ''' 41 | ckpt: Path of the checkpoint 42 | return: Checkpoint dict 43 | ''' 44 | if os.path.isfile(ckpt): 45 | checkpoint = torch.load(ckpt) 46 | print("Successfully loaded checkpoint '%s'" % ckpt) 47 | return checkpoint 48 | else: 49 | raise Exception("No checkpoint found at '%s'" % ckpt) 50 | 51 | def item(self, x): 52 | ''' 53 | PyTorch before and after 0.4 54 | ''' 55 | torch_version = torch.__version__.split('.') 56 | if int(torch_version[0]) == 0 and int(torch_version[1]) < 4: 57 | return x[0] 58 | else: 59 | return x.item() 60 | 61 | def __generate_label_data__(self, query): 62 | label_tokens_index = [] 63 | label_tokens_mask = [] 64 | label_text_mask = [] 65 | if self.args.have_otherO is True: 66 | for label_dic in query['label2tag']: 67 | for label_id in label_dic: 68 | if label_id == 0: 69 | label_tokens = ['other'] 70 | else: 71 | label_tokens = label_dic[label_id].split('-') 72 | label_tokens = ['[CLS]'] + label_tokens + ['[SEP]'] 73 | indexed_label_tokens = self.tokenizer.convert_tokens_to_ids(label_tokens) 74 | # padding 75 | while len(indexed_label_tokens) < 10: 76 | indexed_label_tokens.append(0) 77 | label_tokens_index.append(indexed_label_tokens) 78 | # mask 79 | mask = np.zeros((10), dtype=np.int32) 80 | mask[:len(label_tokens)] = 1 81 | label_tokens_mask.append(mask) 82 | # text mask, also mask [CLS] and [SEP] 83 | text_mask = np.zeros((10), dtype=np.int32) 84 | text_mask[1:len(label_tokens)-1] = 1 85 | label_text_mask.append(text_mask) 86 | else: 87 | for label_dic in query['label2tag']: 88 | for label_id in label_dic: 89 | if label_id != 0: 90 | label_tokens = label_dic[label_id].split('-') 91 | label_tokens = ['[CLS]'] + label_tokens + ['[SEP]'] 92 | 
indexed_label_tokens = self.tokenizer.convert_tokens_to_ids(label_tokens) 93 | # padding 94 | while len(indexed_label_tokens) < 10: 95 | indexed_label_tokens.append(0) 96 | label_tokens_index.append(indexed_label_tokens) 97 | # mask 98 | mask = np.zeros((10), dtype=np.int32) 99 | mask[:len(label_tokens)] = 1 100 | label_tokens_mask.append(mask) 101 | # text mask, also mask [CLS] and [SEP] 102 | text_mask = np.zeros((10), dtype=np.int32) 103 | text_mask[1:len(label_tokens)-1] = 1 104 | label_text_mask.append(text_mask) 105 | 106 | label_tokens_index = torch.Tensor(label_tokens_index).long().cuda() 107 | label_tokens_mask = torch.Tensor(label_tokens_mask).long().cuda() 108 | label_text_mask = torch.Tensor(label_text_mask).long().cuda() 109 | 110 | label_data = {} 111 | label_data['word'] = label_tokens_index 112 | label_data['mask'] = label_tokens_mask 113 | label_data['text_mask'] = label_text_mask 114 | return label_data 115 | 116 | def __zero_grad__(self, params): 117 | for p in params: 118 | if p.grad is not None: 119 | p.grad.zero_() 120 | 121 | def __get_sample_pairs__(self, data): 122 | data_1 = {} 123 | data_2 = {} 124 | data_1['word_emb'] = data['word_emb'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 125 | data_1['label'] = data['label'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 126 | data_2['word_emb'] = data['word_emb'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 127 | data_2['label'] = data['label'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 128 | 129 | return data_1, data_2 130 | 131 | def __generate_pair_label__(self, label1, label2): 132 | pair_label = [] 133 | for l1 in label1: 134 | for l2 in label2: 135 | if l1 == l2: 136 | pair_label.append(1.0) 137 | else: 138 | pair_label.append(0.0) 139 | return torch.Tensor(pair_label).cuda() 140 | 141 | def __generate_query_pair_label__(self, query_dis, query_label): 142 | query_pair_label = [] 143 | after_query_dis = [] 144 | for i, l in 
enumerate(query_label): 145 | tmp = torch.zeros([1, self.args.N]) 146 | if l == -1: 147 | continue 148 | elif l == 0: 149 | query_pair_label.append(tmp) 150 | after_query_dis.append(query_dis[i]) 151 | else: 152 | tmp[0, l-1] = 1 153 | query_pair_label.append(tmp) 154 | after_query_dis.append(query_dis[i]) 155 | query_pair_label = torch.cat(query_pair_label, dim=0).view(-1).cuda() 156 | after_query_dis = torch.stack(after_query_dis).view(-1).cuda() 157 | 158 | return after_query_dis, query_pair_label 159 | 160 | def __get_proto__(self, label_data_emb, label_data_text_mask, support_emb, support, model): 161 | if self.args.label_name_mode == 'mean': 162 | temp_word_list = [] 163 | for i, word_emb_list in enumerate(support_emb): 164 | temp_word_list.append(word_emb_list[support['text_mask'][i]==1]) 165 | temp_label_list = [] 166 | temp_word_list = torch.cat(temp_word_list) # [x, 768] 167 | temp_label_list = torch.cat(support['label'], dim=0) # [x,] 168 | assert temp_word_list.shape[0] == temp_label_list.shape[0] 169 | Proto = [] 170 | if self.args.have_otherO is True: 171 | for i in range(self.args.N+1): 172 | Proto.append(torch.mean(temp_word_list[temp_label_list==i], dim=0).view(1,-1)) 173 | else: 174 | for i in range(self.args.N): 175 | Proto.append(torch.mean(temp_word_list[temp_label_list==i+1], dim=0).view(1,-1)) 176 | Proto = torch.cat(Proto) 177 | elif self.args.label_name_mode == 'LnAsQ': 178 | # get Q = mean(init label name) 179 | Q = [] 180 | K = {} 181 | assert label_data_emb.shape[0] == label_data_text_mask.shape[0] 182 | for i, l_ebd in enumerate(label_data_emb): # [10, 768] 183 | p = l_ebd[label_data_text_mask[i]==1] 184 | # K[i] = p 185 | p = p.mean(dim=0) 186 | Q.append(p.view(1,-1)) 187 | Q = torch.cat(Q,0) 188 | 189 | # get K or V = cat(label name and word in class i) 190 | temp_word_list = [] 191 | for i, word_emb_list in enumerate(support_emb): 192 | temp_word_list.append(word_emb_list[support['text_mask'][i]==1]) 193 | temp_label_list = [] 194 | 
temp_word_list = torch.cat(temp_word_list) # [x, 768] 195 | temp_label_list = torch.cat(support['label'], dim=0) # [x,] 196 | assert temp_word_list.shape[0] == temp_label_list.shape[0] 197 | if self.args.have_otherO is True: 198 | for i in range(self.args.N+1): 199 | K[i] = temp_word_list[temp_label_list==i] 200 | else: 201 | for i in range(self.args.N): 202 | K[i] = temp_word_list[temp_label_list==i+1] 203 | 204 | # Attention 205 | Proto = [] 206 | for i, q in enumerate(Q): 207 | temp = torch.mm(model.att(q.view(1, -1)), K[i].t()) 208 | att_weights = F.softmax(F.layer_norm(temp, normalized_shape=(temp.shape[0],temp.shape[1])), dim=1) 209 | # print("att_weights:", att_weights) 210 | proto = torch.mm(att_weights, K[i]) # [1, 768] 211 | Proto.append(proto) 212 | Proto = torch.cat(Proto) 213 | elif self.args.label_name_mode == 'LnAsQKV': 214 | # get Q = mean(init label name) 215 | Q = [] 216 | K = {} 217 | assert label_data_emb.shape[0] == label_data_text_mask.shape[0] 218 | for i, l_ebd in enumerate(label_data_emb): # [10, 768] 219 | p = l_ebd[label_data_text_mask[i]==1] 220 | K[i] = p 221 | p = p.mean(dim=0) 222 | Q.append(p.view(1,-1)) 223 | Q = torch.cat(Q,0) 224 | 225 | # get K or V = cat(label name and word in class i) 226 | temp_word_list = [] 227 | for i, word_emb_list in enumerate(support_emb): 228 | temp_word_list.append(word_emb_list[support['text_mask'][i]==1]) 229 | temp_label_list = [] 230 | temp_word_list = torch.cat(temp_word_list) # [x, 768] 231 | temp_label_list = torch.cat(support['label'], dim=0) # [x,] 232 | assert temp_word_list.shape[0] == temp_label_list.shape[0] 233 | 234 | if self.args.have_otherO is True: 235 | for i in range(self.args.N+1): 236 | K[i] = torch.cat((K[i],temp_word_list[temp_label_list==i]),dim=0) 237 | else: 238 | for i in range(self.args.N): 239 | K[i] = torch.cat((K[i],temp_word_list[temp_label_list==i+1]),dim=0) 240 | 241 | # Attention 242 | Proto = [] 243 | for i, q in enumerate(Q): 244 | temp = 
torch.mm(model.att(q.view(1, -1)), K[i].t()) 245 | att_weights = F.softmax(F.layer_norm(temp, normalized_shape=(temp.shape[0],temp.shape[1])), dim=1) 246 | # print("att_weights:", att_weights) 247 | proto = torch.mm(att_weights, K[i]) # [1, 768] 248 | Proto.append(proto) 249 | Proto = torch.cat(Proto) 250 | else: 251 | raise NotImplementedError 252 | 253 | return Proto 254 | 255 | def __pos_dist__(self, instances, class_proto): # ins:[N*K, 256], cla:[N, 256] 256 | return torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5) 257 | 258 | def train(self, 259 | model, 260 | model_name, 261 | learning_rate=1e-4, 262 | train_iter=30000, 263 | val_iter=1000, 264 | val_step=2000, 265 | load_ckpt=None, 266 | save_ckpt=None, 267 | warmup_step=300, 268 | grad_iter=1, 269 | fp16=False, 270 | use_sgd_for_bert=False): 271 | ''' 272 | model: a FewShotREModel instance 273 | model_name: Name of the model 274 | B: Batch size 275 | N: Num of classes for each batch 276 | K: Num of instances for each class in the support set 277 | Q: Num of instances for each class in the query set 278 | ckpt_dir: Directory of checkpoints 279 | learning_rate: Initial learning rate 280 | train_iter: Num of iterations of training 281 | val_iter: Num of iterations of validating 282 | val_step: Validate every val_step steps 283 | ''' 284 | print("Start training...") 285 | 286 | # Init optimizer 287 | print('Use bert optim!') 288 | 289 | # set Bert learning rate 290 | parameters_to_optimize = list(model.word_encoder.named_parameters()) 291 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 292 | parameters_to_optimize = [ 293 | {'params': [p for n, p in parameters_to_optimize 294 | if not any(nd in n for nd in no_decay)], 'weight_decay': self.args.bert_wd}, 295 | {'params': [p for n, p in parameters_to_optimize 296 | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 297 | ] 298 | if use_sgd_for_bert: 299 | bert_optimizer = 
torch.optim.SGD(parameters_to_optimize, lr=self.args.bert_lr)
300 | else:
301 | bert_optimizer = AdamW(parameters_to_optimize, lr=self.args.bert_lr, correct_bias=False)
302 | bert_scheduler = get_linear_schedule_with_warmup(bert_optimizer, num_warmup_steps=warmup_step, num_training_steps=train_iter)
303 |
304 | # set learning rate of model without Bert
305 | parameters_to_optimize = list(model.named_parameters())
306 | without = ['word_encoder']
307 | no_decay = ['bias']
308 | parameters_to_optimize = [
309 | {'params': [p for n, p in parameters_to_optimize
310 | if not (any(nd in n for nd in no_decay) or any(wo in n for wo in without))], 'weight_decay': self.args.wobert_wd},
311 | {'params': [p for n, p in parameters_to_optimize
312 | if any(nd in n for nd in no_decay) and not any(wo in n for wo in without)], 'weight_decay': 0.0}
313 | ]
314 | wobert_optimizer = AdamW(parameters_to_optimize, lr=self.args.meta_lr, correct_bias=False)
315 | wobert_scheduler = get_linear_schedule_with_warmup(wobert_optimizer, num_warmup_steps=warmup_step, num_training_steps=train_iter)
316 |
317 | # load model
318 | if load_ckpt:
319 | state_dict = self.__load_model__(load_ckpt)['state_dict']
320 | own_state = model.state_dict()
321 | for name, param in state_dict.items():
322 | if name not in own_state:
323 | print('ignore {}'.format(name))
324 | continue
325 | print('load {} from {}'.format(name, load_ckpt))
326 | own_state[name].copy_(param)
327 |
328 | if fp16:
329 | from apex import amp
330 | model, [bert_optimizer, wobert_optimizer] = amp.initialize(model, [bert_optimizer, wobert_optimizer], opt_level='O1')
331 |
332 | model.train()
333 | loss_func = TripletLoss(args=self.args)
334 |
335 | # Training
336 | best_f1 = 0.0
337 | iter_loss = 0.0
338 | iter_sample = 0
339 | pred_cnt = 0
340 | label_cnt = 0
341 | correct_cnt = 0
342 |
343 | it = 0
344 | while it + 1 < train_iter:
345 | for _, (support, query) in enumerate(self.train_data_loader):
346 | '''
347 | support/query:
348 | {
349 | 'word': 2-D tensor [~n*k, max_length] of vocabulary indices of the tokens, e.g. [[101, 1996,...,0],[...]],
350 | 'mask': same shape; 0 at PAD positions, 1 elsewhere, e.g. [[1, 1,..., 0],[...]],
351 | 'label': list of tensors [tensor([0, 1,..., 0, 0]), tensor([0, 0,..., 0]),...], values in {-1, 0, 1, 2, ...},
352 | 'sentence_num': [5, 5, 5, 5] (length = batch_size; each entry is the number of sentences in that part of the batch),
353 | 'text_mask': like 'mask', but the [CLS] and [SEP] positions are also 0,
354 | query only:
355 | 'label2tag': # one dict per part of the batch
356 | [
357 | { 0: 'O',
358 | 1: 'product-software',
359 | 2: 'location-island',
360 | 3: 'person-director',
361 | 4: 'event-protest',
362 | 5: 'other-disease'
363 | },
364 | {
365 | 0: 'O',
366 | 1: 'location-GPE',
367 | 2: 'location-road/railway/highway/transit',
368 | 3: 'person-director',
369 | 4: 'other-biologything',
370 | 5: 'building-airport'
371 | },
372 | {
373 | 0: 'O',
374 | 1: 'event-attack/battle/war/militaryconflict',
375 | 2: 'product-software',
376 | 3: 'other-award',
377 | 4: 'building-restaurant',
378 | 5: 'person-politician'
379 | },
380 | {
381 | 0: 'O',
382 | 1: 'person-artist/author',
383 | 2: 'building-hotel',
384 | 3: 'other-award',
385 | 4: 'location-mountain',
386 | 5: 'other-god'
387 | }
388 | ]
389 | }
390 | '''
391 | margin = model.param
392 | alpha = model.alpha
393 | if torch.cuda.is_available():
394 | for k in support:
395 | if k != 'label' and k != 'sentence_num':
396 | support[k] = support[k].cuda()
397 | query[k] = query[k].cuda()
398 | query_label = torch.cat(query['label'], 0)
399 | query_label = query_label.cuda()
400 |
401 | # get proto init rep
402 | label_data = self.__generate_label_data__(query)
403 | label_data_emb = model.word_encoder(label_data['word'], label_data['mask']) # [num_label_sent, 10, 768]
404 | support_emb = model.word_encoder(support['word'], support['mask'])
405 | Proto = self.__get_proto__(label_data_emb, label_data['text_mask'], support_emb, support, model) #[N, 768]
406 |
407 | # support,proto -> MLP -> new emb
408 | support_label = torch.cat(support['label'], dim=0)
409 | support_emb = support_emb[support['text_mask']==1]
410 | 
support_afterMLP_emb = model(support_emb) 411 | proto_afterMLP_emb = model(Proto) 412 | if self.args.use_proto_as_neg is True: 413 | if self.args.have_otherO is True: 414 | support_label = torch.cat((support_label, torch.tensor([0 for _ in range(self.args.N)],dtype=torch.int64))) 415 | else: 416 | support_label = torch.cat((support_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64))) 417 | support_afterMLP_emb = torch.cat((support_afterMLP_emb, proto_afterMLP_emb+1e-8),dim=0) 418 | support_dis = [] 419 | if self.args.have_otherO is True: 420 | for i in range(self.args.N+1): 421 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 422 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 423 | temp_lst[-(self.args.N+1-i)] = 1 424 | temp_lst = np.array(temp_lst) 425 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 426 | support_dis.append(support_dis_one_line) 427 | else: 428 | for i in range(self.args.N): 429 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 430 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 431 | temp_lst[-(self.args.N-i)] = 1 432 | temp_lst = np.array(temp_lst) 433 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 434 | support_dis.append(support_dis_one_line) 435 | support_dis = torch.cat(support_dis, dim=0).view(-1) # [N+1, N*K] 436 | else: 437 | support_dis = self.__pos_dist__(proto_afterMLP_emb, support_afterMLP_emb).view(-1) # [N, N*K] 438 | # if self.args.use_diff_threshold == False: 439 | # print("support_dis_before_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 440 | support_dis = F.layer_norm(support_dis, normalized_shape=[support_dis.shape[0]], bias=torch.full((support_dis.shape[0],), self.args.ln_bias).cuda()) 441 | # print("support_dis_after_norm:", 
torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item())
442 | if self.args.have_otherO is True:
443 | support_dis = support_dis.view(self.args.N+1, -1)
444 | else:
445 | support_dis = support_dis.view(self.args.N, -1)
446 | # if self.args.use_diff_threshold == True:
447 | # margin = torch.mean(support_dis)
448 | support_loss = loss_func(support_dis, support_label, margin, alpha)
449 |
450 | self.__zero_grad__(model.fc.parameters())
451 | grads_fc = autograd.grad(support_loss, model.fc.parameters(), allow_unused=True, retain_graph=True)
452 | fast_weights_fc, orderd_params_fc = model.cloned_fc_dict(), OrderedDict()
453 | for (key, val), grad in zip(model.fc.named_parameters(), grads_fc):
454 | fast_weights_fc[key] = orderd_params_fc[key] = val - self.args.task_lr * grad # weight grads here are ~1e-4 and bias grads ~1e-11; maybe too small?
455 |
456 | fast_weights = {}
457 | fast_weights['fc'] = fast_weights_fc
458 |
459 | train_support_loss = []
460 | for _ in range(self.args.train_support_iter - 1):
461 | support_afterMLP_emb = model(support_emb, fast_weights)
462 | proto_afterMLP_emb = model(Proto, fast_weights)
463 | if self.args.use_proto_as_neg == True:
464 | # support_label = torch.cat((support_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64)))
465 | support_afterMLP_emb = torch.cat((support_afterMLP_emb, proto_afterMLP_emb+1e-8),dim=0)
466 | support_dis = []
467 | if self.args.have_otherO is True:
468 | for i in range(self.args.N+1):
469 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb)
470 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])]
471 | temp_lst[-(self.args.N+1-i)] = 1
472 | temp_lst = np.array(temp_lst)
473 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1)
474 | support_dis.append(support_dis_one_line)
475 | else:
476 | for i in range(self.args.N):
477 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1,
-1), support_afterMLP_emb) 478 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 479 | temp_lst[-(self.args.N-i)] = 1 480 | temp_lst = np.array(temp_lst) 481 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 482 | support_dis.append(support_dis_one_line) 483 | support_dis = torch.cat(support_dis, dim=0).view(-1) 484 | else: 485 | support_dis = self.__pos_dist__(proto_afterMLP_emb, support_afterMLP_emb).view(-1) 486 | # if self.args.use_diff_threshold == False: 487 | # print("support_dis_before_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 488 | support_dis = F.layer_norm(support_dis, normalized_shape=[support_dis.shape[0]], bias=torch.full((support_dis.shape[0],), self.args.ln_bias).cuda()) 489 | # print("support_dis_after_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 490 | if self.args.have_otherO is True: 491 | support_dis = support_dis.view(self.args.N+1, -1) 492 | else: 493 | support_dis = support_dis.view(self.args.N, -1) 494 | # if self.args.use_diff_threshold == True: 495 | # margin = torch.mean(support_dis) 496 | support_loss = loss_func(support_dis, support_label, margin, alpha) 497 | train_support_loss.append(support_loss.item()) 498 | # print_info = 'train_support, ' + str(support_loss.item()) 499 | # print('\033[0;31;40m{}\033[0m'.format(print_info)) 500 | self.__zero_grad__(orderd_params_fc.values()) 501 | 502 | grads_fc = torch.autograd.grad(support_loss, orderd_params_fc.values(), allow_unused=True, retain_graph=True) 503 | for (key, val), grad in zip(orderd_params_fc.items(), grads_fc): 504 | if grad is not None: 505 | fast_weights['fc'][key] = orderd_params_fc[key] = val - self.args.task_lr * grad 506 | 507 | # query, proto -> MLP -> new emb 508 | query_emb = model.word_encoder(query['word'], query['mask']) 509 | query_emb = query_emb[query['text_mask']==1] 510 | query_afterMLP_emb = 
model(query_emb, fast_weights) 511 | proto_afterMLP_emb = model(Proto, fast_weights) 512 | if self.args.use_proto_as_neg == True: 513 | if self.args.have_otherO is True: 514 | query_label = torch.cat((query_label, torch.tensor([0 for _ in range(self.args.N)],dtype=torch.int64).cuda())) 515 | else: 516 | query_label = torch.cat((query_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64).cuda())) 517 | query_afterMLP_emb = torch.cat((query_afterMLP_emb,proto_afterMLP_emb+1e-8),dim=0) 518 | query_dis = [] 519 | if self.args.have_otherO is True: 520 | for i in range(self.args.N+1): 521 | query_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), query_afterMLP_emb) 522 | temp_lst = [0 for _ in range(query_dis_one_line.shape[1])] 523 | temp_lst[-(self.args.N+1-i)] = 1 524 | temp_lst = np.array(temp_lst) 525 | query_dis_one_line = query_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 526 | query_dis.append(query_dis_one_line) 527 | else: 528 | for i in range(self.args.N): 529 | query_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), query_afterMLP_emb) 530 | temp_lst = [0 for _ in range(query_dis_one_line.shape[1])] 531 | temp_lst[-(self.args.N-i)] = 1 532 | temp_lst = np.array(temp_lst) 533 | query_dis_one_line = query_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 534 | query_dis.append(query_dis_one_line) 535 | query_dis = torch.cat(query_dis, dim=0).view(-1) 536 | else: 537 | query_dis = self.__pos_dist__(proto_afterMLP_emb, query_afterMLP_emb).view(-1) # [N, N*K] 538 | # if self.args.use_diff_threshold == False: 539 | # print("query_dis_before_norm:", torch.max(query_dis).item(), torch.min(query_dis).item(), torch.mean(query_dis).item()) 540 | query_dis = F.layer_norm(query_dis, normalized_shape=[query_dis.shape[0]], bias=torch.full((query_dis.shape[0],), self.args.ln_bias).cuda()) 541 | # print("query_dis_before_norm:", torch.max(query_dis).item(), torch.min(query_dis).item(), torch.mean(query_dis).item()) 542 
| if self.args.have_otherO is True: 543 | query_dis = query_dis.view(self.args.N+1, -1) 544 | else: 545 | query_dis = query_dis.view(self.args.N, -1) 546 | # if self.args.use_diff_threshold == True: 547 | # margin = torch.mean(query_dis) 548 | query_loss = loss_func(query_dis, query_label, margin, alpha) 549 | 550 | # update param 551 | bert_optimizer.zero_grad() 552 | wobert_optimizer.zero_grad() 553 | query_loss.backward() 554 | bert_optimizer.step() 555 | wobert_optimizer.step() 556 | bert_scheduler.step() 557 | wobert_scheduler.step() 558 | 559 | # make prediction 560 | if self.args.use_proto_as_neg == True: 561 | if self.args.have_otherO is True: 562 | query_dis = query_dis[:, :-(self.args.N)] 563 | query_label = query_label[:-(self.args.N)] 564 | else: 565 | query_dis = query_dis[:, :-(self.args.N-1)] 566 | query_label = query_label[:-(self.args.N-1)] 567 | 568 | if self.args.use_diff_threshold == True: 569 | threshold = [] 570 | if self.args.have_otherO is True: 571 | for i in range(0, self.args.N+1): 572 | if self.args.threshold_mode == 'mean': 573 | threshold.append(torch.mean(query_dis[i][query_label==i]).item()) 574 | elif self.args.threshold_mode == 'max': 575 | threshold.append(torch.max(query_dis[i][query_label==i]).item()) 576 | else: 577 | raise NotImplementedError 578 | else: 579 | for i in range(1, self.args.N+1): 580 | if self.args.threshold_mode == 'mean': 581 | threshold.append(torch.mean(query_dis[i-1][query_label==i]).item()) 582 | elif self.args.threshold_mode == 'max': 583 | threshold.append(torch.max(query_dis[i-1][query_label==i]).item()) 584 | else: 585 | raise NotImplementedError 586 | 587 | print("threshold:", threshold) 588 | query_dis = query_dis.t() # [x, N] 589 | query_pred = [] 590 | 591 | if self.args.have_otherO is True: 592 | for tmp in query_dis: 593 | query_pred.append(torch.min(tmp, dim=0)[1].item()) 594 | else: 595 | for tmp in query_dis: 596 | flag = [] 597 | for tm, th in zip(tmp, threshold): 598 | if tm < th: 599 | 
flag.append(tm.view(1,-1)) 600 | else: 601 | flag.append(torch.tensor([[9999.0]]).cuda()) 602 | flag = torch.cat(flag).view(-1) 603 | if torch.min(flag) != 9999.0: 604 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 605 | else: 606 | query_pred.append(0) 607 | query_pred = torch.Tensor(query_pred).cuda() 608 | else: 609 | query_dis = query_dis.t() # [x, N] 610 | query_pred = [] 611 | 612 | if self.args.have_otherO is True: 613 | for tmp in query_dis: 614 | query_pred.append(torch.min(tmp, dim=0)[1].item()) 615 | else: 616 | if self.args.multi_margin is True: 617 | for tmp in query_dis: 618 | flag = [] 619 | for tm, th in zip(tmp, margin): 620 | if tm < th: 621 | flag.append(tm.view(1,-1)) 622 | else: 623 | flag.append(torch.tensor([[9999.0]]).cuda()) 624 | flag = torch.cat(flag).view(-1) 625 | if torch.min(flag) != 9999.0: 626 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 627 | else: 628 | query_pred.append(0) 629 | else: 630 | for tmp in query_dis: 631 | if any(t < margin for t in tmp): 632 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 633 | else: 634 | query_pred.append(0) 635 | query_pred = torch.Tensor(query_pred).cuda() 636 | 637 | assert query_pred.shape[0] == query_label.shape[0] 638 | 639 | tmp_pred_cnt, tmp_label_cnt, correct = model.metrics_by_entity(query_pred, query_label) 640 | 641 | iter_loss += self.item(query_loss.data) 642 | pred_cnt += tmp_pred_cnt 643 | label_cnt += tmp_label_cnt 644 | correct_cnt += correct 645 | iter_sample += 1 646 | 647 | if (it + 1) % 100 == 0 or (it + 1) % val_step == 0: 648 | precision = correct_cnt / (pred_cnt + 1e-10) 649 | recall = correct_cnt / (label_cnt + 1e-10) 650 | f1 = 2 * precision * recall / (precision + recall + 1e-10) 651 | print('step: {0:4} | loss: {1:2.6f} | [ENTITY] precision: {2:3.4f}, recall: {3:3.4f}, f1: {4:3.4f}'\ 652 | .format(it + 1, iter_loss/ iter_sample, precision, recall, f1) + '\r') 653 | print('margin:', margin) 654 | # print('alpha:', alpha.item()) 655 | # a = 
deepcopy(alpha)
656 | # print('alpha_after_sigmoid:', torch.sigmoid(a).item())
657 |
658 | iter_loss = 0.
659 | iter_sample = 0.
660 | pred_cnt = 0
661 | label_cnt = 0
662 | correct_cnt = 0
663 |
664 | if (it + 1) % val_step == 0:
665 | # torch.save({'state_dict': model.state_dict()}, 'current_siamese.ckpt')
666 | _, _, f1, _, _, _, _ = self.eval(model, val_iter)
667 | model.train()
668 | if f1 > best_f1:
669 | print('Best checkpoint')
670 | torch.save({'state_dict': model.state_dict()}, save_ckpt)
671 | best_f1 = f1
672 | iter_loss = 0.
673 | iter_sample = 0.
674 | pred_cnt = 0
675 | label_cnt = 0
676 | correct_cnt = 0
677 |
678 | if (it + 1) == train_iter:
679 | break
680 | it += 1
681 |
682 | print("\n####################\n")
683 | print("Finish training " + model_name)
684 |
685 | def __save_test_inference__(self, logits, pred, query):
686 |
687 | # drop the -1 (padding) labels from the query
688 | new_query_label = []
689 | new_query_word = []
690 | new_query_textmask = []
691 | for lq, wq, tq in zip(query['label'], query['word'], query['text_mask']):
692 | pass
693 |
694 |
695 | # logits = F.softmax(logits, dim=-1)
696 | # convert word ids back to real tokens
697 | sentence_list = [] # list of lists
698 | for words, mask in zip(query['word'], query['text_mask']):
699 | real_words = []
700 | for word in words[mask==1]:
701 | real_words.append(self.tokenizer.decode(word).replace(" ", ""))
702 | sentence_list.append(real_words)
703 |
704 | # convert labels and predictions to label words
705 | sentence_num = []
706 | sentence_num.append(0)
707 | tmp = 0
708 | for i in query['sentence_num']:
709 | tmp += i
710 | sentence_num.append(tmp)
711 | real_label_list = [] # list of lists
712 | pred_label_list = [] # list of lists
713 | label_name_list = [] # list of lists
714 | # slice pred and logits into per-sentence pieces
715 | pred_list = []
716 | # logits_list = []
717 | sentence_len = []
718 | sentence_len.append(0)
719 | tmp = 0
720 | for _labels in query['label']:
721 | tmp += _labels.shape[0]
722 | sentence_len.append(tmp)
723 | for i in range(len(sentence_len)-1):
724 | tmp2 = pred[sentence_len[i]:
sentence_len[i+1]]
725 | # tmp3 = logits[sentence_len[i]: sentence_len[i+1]]
726 | pred_list.append(tmp2.cpu())
727 | # logits_list.append(tmp3.cpu().detach().numpy().tolist())
728 |
729 | for i in range(len(sentence_num)-1):
730 | for j in range(sentence_num[i], sentence_num[i+1]):
731 | tmp_label_list = []
732 | tmp_pred_list = []
733 | tmp_label_name_list = []
734 | assert query['label'][j].shape[0] == pred_list[j].shape[0]
735 | for l, p in zip(query['label'][j], pred_list[j]):
736 | if l == -1:
737 | tmp_label_list.append(str(-1))
738 | else:
739 | tmp_label_list.append(query['label2tag'][i][l.item()])
740 | tmp_pred_list.append(query['label2tag'][i][p.item()])
741 | tmp_label_name_list.append(str(query['label2tag'][i]))
742 | real_label_list.append(tmp_label_list)
743 | pred_label_list.append(tmp_pred_list)
744 | label_name_list.append(tmp_label_name_list) # label_list for each meta-task
745 |
746 | return sentence_list, real_label_list, pred_label_list, label_name_list
747 |
748 | def eval(self,
749 | model,
750 | eval_iter,
751 | ckpt=None):
752 | '''
753 | model: a FewShotREModel instance
754 | B: Batch size
755 | N: Num of classes for each batch
756 | K: Num of instances for each class in the support set
757 | Q: Num of instances for each class in the query set
758 | eval_iter: Num of iterations
759 | ckpt: Checkpoint path. Set as None if using current model parameters.
760 | return: Accuracy 761 | ''' 762 | print("evaluating...") 763 | 764 | model.eval() 765 | loss_func = TripletLoss(args=self.args) 766 | 767 | if ckpt is None: 768 | print("Use val dataset") 769 | eval_dataset = self.val_data_loader 770 | # print("Use test dataset") 771 | # eval_dataset = self.test_data_loader 772 | else: 773 | print("Use test dataset") 774 | if ckpt != 'none': 775 | state_dict = self.__load_model__(ckpt)['state_dict'] 776 | own_state = model.state_dict() 777 | for name, param in state_dict.items(): 778 | if name not in own_state: 779 | continue 780 | own_state[name].copy_(param) 781 | eval_dataset = self.test_data_loader 782 | 783 | if self.args.only_use_test: 784 | eval_dataset = self.test_data_loader 785 | 786 | margin = model.param 787 | alpha = model.alpha 788 | print("margin:", margin) 789 | print("alpha:", alpha) 790 | 791 | if self.args.margin != -1: 792 | margin = self.args.margin 793 | print('set margin:', margin) 794 | 795 | pred_cnt = 0 # pred entity cnt 796 | label_cnt = 0 # true label entity cnt 797 | correct_cnt = 0 # correct predicted entity cnt 798 | 799 | fp_cnt = 0 # misclassify O as I- 800 | fn_cnt = 0 # misclassify I- as O 801 | total_token_cnt = 0 # total token cnt 802 | within_cnt = 0 # span correct but of wrong fine-grained type 803 | outer_cnt = 0 # span correct but of wrong coarse-grained type 804 | total_span_cnt = 0 # span correct 805 | 806 | query_loss_all = [] 807 | 808 | eval_iter = min(eval_iter, len(eval_dataset)) 809 | 810 | # print test inference 811 | # if ckpt is not None: 812 | # if self.args.save_test_inference is not 'none': 813 | # # generate save path 814 | # if self.args.dataset == 'fewnerd': 815 | # save_path = '_'.join([self.args.save_test_inference, self.args.dataset, self.args.mode, self.args.model, str(self.args.N), str(self.args.K)]) 816 | # else: 817 | # save_path = '_'.join([self.args.save_test_inference, self.args.dataset, self.args.model, str(self.args.N), str(self.args.K)]) 818 | # f_write = 
open(save_path + '.txt', 'a', encoding='utf-8') 819 | 820 | 821 | it = 0 822 | while it + 1 < eval_iter: 823 | for _, (support, query) in enumerate(eval_dataset): 824 | if torch.cuda.is_available(): 825 | for k in support: 826 | if k != 'label' and k != 'sentence_num': 827 | support[k] = support[k].cuda() 828 | query[k] = query[k].cuda() 829 | query_label = torch.cat(query['label'], 0) 830 | query_label = query_label.cuda() 831 | 832 | # get proto init rep 833 | label_data = self.__generate_label_data__(query) 834 | label_data_emb = model.word_encoder(label_data['word'], label_data['mask']) # [num_label_sent, 10, 768] 835 | support_emb = model.word_encoder(support['word'], support['mask']) 836 | Proto = self.__get_proto__(label_data_emb, label_data['text_mask'], support_emb, support, model) #[N, 768] 837 | 838 | # support,proto -> MLP -> new emb 839 | support_label = torch.cat(support['label'], dim=0) 840 | support_emb = support_emb[support['text_mask']==1] 841 | support_afterMLP_emb = model(support_emb) 842 | proto_afterMLP_emb = model(Proto) 843 | if self.args.use_proto_as_neg == True: 844 | if self.args.have_otherO is True: 845 | support_label = torch.cat((support_label, torch.tensor([0 for _ in range(self.args.N)],dtype=torch.int64))) 846 | else: 847 | support_label = torch.cat((support_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64))) 848 | support_afterMLP_emb = torch.cat((support_afterMLP_emb, proto_afterMLP_emb+1e-8),dim=0) 849 | support_dis = [] 850 | if self.args.have_otherO is True: 851 | for i in range(self.args.N+1): 852 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 853 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 854 | temp_lst[-(self.args.N+1-i)] = 1 855 | temp_lst = np.array(temp_lst) 856 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 857 | support_dis.append(support_dis_one_line) 858 | else: 859 | for i in 
range(self.args.N):
860 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb)
861 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])]
862 | temp_lst[-(self.args.N-i)] = 1
863 | temp_lst = np.array(temp_lst)
864 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1)
865 | support_dis.append(support_dis_one_line)
866 | support_dis = torch.cat(support_dis, dim=0).view(-1) # [N+1, N*K]
867 | else:
868 | support_dis = self.__pos_dist__(proto_afterMLP_emb, support_afterMLP_emb).view(-1) # [N, N*K]
869 | # if self.args.use_diff_threshold == False:
870 | # print("support_dis_before_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item())
871 | support_dis = F.layer_norm(support_dis, normalized_shape=[support_dis.shape[0]], bias=torch.full((support_dis.shape[0],), self.args.ln_bias).cuda())
872 | # print("support_dis_after_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item())
873 | if self.args.have_otherO is True:
874 | support_dis = support_dis.view(self.args.N+1, -1)
875 | else:
876 | support_dis = support_dis.view(self.args.N, -1)
877 | # if self.args.use_diff_threshold == True:
878 | # margin = torch.mean(support_dis)
879 | support_loss = loss_func(support_dis, support_label, margin, alpha)
880 |
881 | self.__zero_grad__(model.fc.parameters())
882 | grads_fc = autograd.grad(support_loss, model.fc.parameters(), allow_unused=True, retain_graph=True)
883 | fast_weights_fc, orderd_params_fc = model.cloned_fc_dict(), OrderedDict()
884 | for (key, val), grad in zip(model.fc.named_parameters(), grads_fc):
885 | fast_weights_fc[key] = orderd_params_fc[key] = val - self.args.task_lr * grad # weight grads here are ~1e-4 and bias grads ~1e-11; maybe too small?
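The step above is the core of the framework's episode-level adaptation: gradients of the support loss are taken with `autograd.grad` and applied to *copies* of the `fc` parameters (the "fast weights"), MAML-style, while the original parameters are only updated later from the query loss by the outer optimizers. A minimal NumPy sketch of that fast-weight idea (hypothetical linear layer and shapes; analytic gradients stand in for `torch.autograd.grad`):

```python
import numpy as np

def inner_loop_step(W, b, x, y, task_lr=0.1):
    """One adaptation step for a toy linear layer under squared loss.

    Returns updated *copies* (fast weights); W and b are untouched,
    mirroring how fast_weights_fc is built from cloned fc parameters.
    """
    err = x @ W + b - y                  # [n, out]
    grad_W = x.T @ err / x.shape[0]      # gradient of 0.5 * mean_n sum_out err**2
    grad_b = err.mean(axis=0)
    return W - task_lr * grad_W, b - task_lr * grad_b

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))              # hypothetical support embeddings
y = x @ rng.normal(size=(4, 2))          # hypothetical regression targets
W, b = np.zeros((4, 2)), np.zeros(2)     # "slow" weights

def loss(Wf, bf):
    return 0.5 * np.mean(np.sum((x @ Wf + bf - y) ** 2, axis=1))

fast_W, fast_b = W, b
for _ in range(3):                       # cf. args.train_support_iter steps
    fast_W, fast_b = inner_loop_step(fast_W, fast_b, x, y)

assert loss(fast_W, fast_b) < loss(W, b)  # support loss decreased
assert np.allclose(W, 0.0)                # slow weights unchanged
```

Here `task_lr` plays the role of `args.task_lr`; in the real code the adapted `fast_weights['fc']` are passed back into `model(...)` for the query forward pass.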
886 | 887 | fast_weights = {} 888 | fast_weights['fc'] = fast_weights_fc 889 | 890 | train_support_loss = [] 891 | for _ in range(self.args.train_support_iter - 1): 892 | support_afterMLP_emb = model(support_emb, fast_weights) 893 | proto_afterMLP_emb = model(Proto, fast_weights) 894 | if self.args.use_proto_as_neg == True: 895 | # support_label = torch.cat((support_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64))) 896 | support_afterMLP_emb = torch.cat((support_afterMLP_emb, proto_afterMLP_emb+1e-8),dim=0) 897 | support_dis = [] 898 | if self.args.have_otherO is True: 899 | for i in range(self.args.N+1): 900 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 901 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 902 | temp_lst[-(self.args.N+1-i)] = 1 903 | temp_lst = np.array(temp_lst) 904 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 905 | support_dis.append(support_dis_one_line) 906 | else: 907 | for i in range(self.args.N): 908 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 909 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 910 | temp_lst[-(self.args.N-i)] = 1 911 | temp_lst = np.array(temp_lst) 912 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 913 | support_dis.append(support_dis_one_line) 914 | support_dis = torch.cat(support_dis, dim=0).view(-1) 915 | else: 916 | support_dis = self.__pos_dist__(proto_afterMLP_emb, support_afterMLP_emb).view(-1) 917 | # if self.args.use_diff_threshold == False: 918 | # print("support_dis_before_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 919 | support_dis = F.layer_norm(support_dis, normalized_shape=[support_dis.shape[0]], bias=torch.full((support_dis.shape[0],), self.args.ln_bias).cuda()) 920 | # print("support_dis_after_norm:", 
torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 921 | if self.args.have_otherO is True: 922 | support_dis = support_dis.view(self.args.N+1, -1) 923 | else: 924 | support_dis = support_dis.view(self.args.N, -1) 925 | # if self.args.use_diff_threshold == True: 926 | # margin = torch.mean(support_dis) 927 | support_loss = loss_func(support_dis, support_label, margin, alpha) 928 | train_support_loss.append(support_loss.item()) 929 | # print_info = 'train_support, ' + str(support_loss.item()) 930 | # print('\033[0;31;40m{}\033[0m'.format(print_info)) 931 | self.__zero_grad__(orderd_params_fc.values()) 932 | 933 | grads_fc = torch.autograd.grad(support_loss, orderd_params_fc.values(), allow_unused=True, retain_graph=True) 934 | for (key, val), grad in zip(orderd_params_fc.items(), grads_fc): 935 | if grad is not None: 936 | fast_weights['fc'][key] = orderd_params_fc[key] = val - self.args.task_lr * grad 937 | 938 | # query, proto -> MLP -> new emb 939 | query_emb = model.word_encoder(query['word'], query['mask']) 940 | query_emb = query_emb[query['text_mask']==1] 941 | query_afterMLP_emb = model(query_emb, fast_weights) 942 | proto_afterMLP_emb = model(Proto, fast_weights) 943 | if self.args.save_query_ebd is True: 944 | save_qe_path = '_'.join([self.args.dataset, self.args.mode, self.args.model, str(self.args.N), str(self.args.K), str(self.args.Q), str(int(round(time.time() * 1000)))]) 945 | if not os.path.exists(save_qe_path): 946 | os.mkdir(save_qe_path) 947 | f_write = open(os.path.join(save_qe_path, 'label2tag.txt'), 'w', encoding='utf-8') 948 | for ln in query['label2tag'][0]: 949 | f_write.write(query['label2tag'][0][ln] + '\n') 950 | f_write.flush() 951 | f_write.close() 952 | np.save(os.path.join(save_qe_path, 'proto.npy'), proto_afterMLP_emb.cpu().detach().numpy()) 953 | np.save(os.path.join(save_qe_path, '0.npy'), query_afterMLP_emb[query_label == 0].cpu().detach().numpy()) 954 | np.save(os.path.join(save_qe_path, 
'1.npy'), query_afterMLP_emb[query_label == 1].cpu().detach().numpy()) 955 | np.save(os.path.join(save_qe_path, '2.npy'), query_afterMLP_emb[query_label == 2].cpu().detach().numpy()) 956 | np.save(os.path.join(save_qe_path, '3.npy'), query_afterMLP_emb[query_label == 3].cpu().detach().numpy()) 957 | np.save(os.path.join(save_qe_path, '4.npy'), query_afterMLP_emb[query_label == 4].cpu().detach().numpy()) 958 | np.save(os.path.join(save_qe_path, '5.npy'), query_afterMLP_emb[query_label == 5].cpu().detach().numpy()) 959 | sys.exit() 960 | 961 | if self.args.use_proto_as_neg == True: 962 | if self.args.have_otherO is True: 963 | query_label = torch.cat((query_label, torch.tensor([0 for _ in range(self.args.N)],dtype=torch.int64).cuda())) 964 | else: 965 | query_label = torch.cat((query_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64).cuda())) 966 | query_afterMLP_emb = torch.cat((query_afterMLP_emb,proto_afterMLP_emb+1e-8),dim=0) 967 | query_dis = [] 968 | if self.args.have_otherO is True: 969 | for i in range(self.args.N+1): 970 | query_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), query_afterMLP_emb) 971 | temp_lst = [0 for _ in range(query_dis_one_line.shape[1])] 972 | temp_lst[-(self.args.N+1-i)] = 1 973 | temp_lst = np.array(temp_lst) 974 | query_dis_one_line = query_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 975 | query_dis.append(query_dis_one_line) 976 | else: 977 | for i in range(self.args.N): 978 | query_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), query_afterMLP_emb) 979 | temp_lst = [0 for _ in range(query_dis_one_line.shape[1])] 980 | temp_lst[-(self.args.N-i)] = 1 981 | temp_lst = np.array(temp_lst) 982 | query_dis_one_line = query_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 983 | query_dis.append(query_dis_one_line) 984 | query_dis = torch.cat(query_dis, dim=0).view(-1) 985 | else: 986 | query_dis = self.__pos_dist__(proto_afterMLP_emb, query_afterMLP_emb).view(-1) # [N, N*K] 
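The distance computation above relies on `__pos_dist__`, which broadcasts one tensor of vectors against another to get all pairwise Euclidean distances in a single operation. A standalone sketch with hypothetical toy tensors (not the repository's data; argument names follow the method's signature):

```python
import torch

def pos_dist(instances, class_proto):
    # instances: [M, D], class_proto: [N, D]
    diff = class_proto.unsqueeze(0) - instances.unsqueeze(1)  # broadcast to [M, N, D]
    return diff.pow(2).sum(-1).sqrt()                         # Euclidean dists, [M, N]

protos = torch.tensor([[0.0, 0.0], [3.0, 4.0]])  # two 2-D "prototypes"
inst = torch.tensor([[0.0, 0.0]])                # one instance at the origin
d = pos_dist(inst, protos)
# d[0] holds the distance from the instance to each prototype: [0.0, 5.0]
```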
987 | # if self.args.use_diff_threshold == False: 988 | # print("query_dis_before_norm:", torch.max(query_dis).item(), torch.min(query_dis).item(), torch.mean(query_dis).item()) 989 | query_dis = F.layer_norm(query_dis, normalized_shape=[query_dis.shape[0]], bias=torch.full((query_dis.shape[0],), self.args.ln_bias).cuda()) 990 | # print("query_dis_after_norm:", torch.max(query_dis).item(), torch.min(query_dis).item(), torch.mean(query_dis).item()) 991 | if self.args.have_otherO is True: 992 | query_dis = query_dis.view(self.args.N+1, -1) 993 | else: 994 | query_dis = query_dis.view(self.args.N, -1) 995 | # if self.args.use_diff_threshold == True: 996 | # margin = torch.mean(query_dis) 997 | query_loss = loss_func(query_dis, query_label, margin, alpha) 998 | query_loss_all.append(query_loss.item()) 999 | 1000 | # make prediction 1001 | if self.args.use_proto_as_neg == True: 1002 | if self.args.have_otherO is True: 1003 | query_dis = query_dis[:, :-(self.args.N)] 1004 | query_label = query_label[:-(self.args.N)] 1005 | else: 1006 | query_dis = query_dis[:, :-(self.args.N-1)] 1007 | query_label = query_label[:-(self.args.N-1)] 1008 | 1009 | if self.args.use_diff_threshold == True: 1010 | threshold = [] 1011 | if self.args.have_otherO is True: 1012 | for i in range(0, self.args.N+1): 1013 | if self.args.threshold_mode == 'mean': 1014 | threshold.append(torch.mean(query_dis[i][query_label==i]).item()) 1015 | elif self.args.threshold_mode == 'max': 1016 | threshold.append(torch.max(query_dis[i][query_label==i]).item()) 1017 | else: 1018 | raise NotImplementedError 1019 | else: 1020 | for i in range(1, self.args.N+1): 1021 | if self.args.threshold_mode == 'mean': 1022 | threshold.append(torch.mean(query_dis[i-1][query_label==i]).item()) 1023 | elif self.args.threshold_mode == 'max': 1024 | threshold.append(torch.max(query_dis[i-1][query_label==i]).item()) 1025 | else: 1026 | raise NotImplementedError 1027 | 1028 | print("threshold:", threshold) 1029 | query_dis = 
query_dis.t() # [x, N] 1030 | query_pred = [] 1031 | 1032 | if self.args.have_otherO is True: 1033 | for tmp in query_dis: 1034 | query_pred.append(torch.min(tmp, dim=0)[1].item()) 1035 | else: 1036 | for tmp in query_dis: 1037 | flag = [] 1038 | for tm, th in zip(tmp, threshold): 1039 | if tm < th: 1040 | flag.append(tm.view(1,-1)) 1041 | else: 1042 | flag.append(torch.tensor([[9999.0]]).cuda()) 1043 | flag = torch.cat(flag).view(-1) 1044 | if torch.min(flag) != 9999.0: 1045 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 1046 | else: 1047 | query_pred.append(0) 1048 | query_pred = torch.Tensor(query_pred).cuda() 1049 | else: 1050 | query_dis = query_dis.t() # [x, N] 1051 | query_pred = [] 1052 | 1053 | if self.args.have_otherO is True: 1054 | for tmp in query_dis: 1055 | query_pred.append(torch.min(tmp, dim=0)[1].item()) 1056 | else: 1057 | if self.args.multi_margin is True: 1058 | for tmp in query_dis: 1059 | flag = [] 1060 | for tm, th in zip(tmp, margin): 1061 | if tm < th: 1062 | flag.append(tm.view(1,-1)) 1063 | else: 1064 | flag.append(torch.tensor([[9999.0]]).cuda()) 1065 | flag = torch.cat(flag).view(-1) 1066 | if torch.min(flag) != 9999.0: 1067 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 1068 | else: 1069 | query_pred.append(0) 1070 | else: 1071 | for tmp in query_dis: 1072 | if any(t < margin for t in tmp): 1073 | query_pred.append(torch.min(tmp, dim=0)[1].item()+1) 1074 | else: 1075 | query_pred.append(0) 1076 | query_pred = torch.Tensor(query_pred).cuda() 1077 | 1078 | assert query_pred.shape[0] == query_label.shape[0] 1079 | 1080 | tmp_pred_cnt, tmp_label_cnt, correct = model.metrics_by_entity(query_pred, query_label) 1081 | 1082 | fp, fn, token_cnt, within, outer, total_span = model.error_analysis(query_pred, query_label, query) 1083 | pred_cnt += tmp_pred_cnt 1084 | label_cnt += tmp_label_cnt 1085 | correct_cnt += correct 1086 | 1087 | fn_cnt += self.item(fn.data) 1088 | fp_cnt += self.item(fp.data) 1089 | total_token_cnt += 
token_cnt 1090 | outer_cnt += outer 1091 | within_cnt += within 1092 | total_span_cnt += total_span 1093 | 1094 | # # if ckpt is not None: 1095 | # # if self.args.save_test_inference is not 'none': 1096 | # # sentence_list, real_label_list, pred_label_list, label_name_list = self.__save_test_inference__(query_logits, query_label_no_neg, query) 1097 | # # assert len(sentence_list) == len(real_label_list) == len(pred_label_list) == len(label_name_list) 1098 | # # for i in range(len(sentence_list)): 1099 | # # assert len(sentence_list[i]) == len(real_label_list[i]) == len(pred_label_list[i]) == len(label_name_list[i]) 1100 | # # for j in range(len(sentence_list[i])): 1101 | # # f_write.write(sentence_list[i][j] + '\t' + real_label_list[i][j] + '\t' + pred_label_list[i][j] + '\n') 1102 | # # f_write.flush() 1103 | # # f_write.write('\n') 1104 | # # f_write.flush() 1105 | 1106 | if it + 1 == eval_iter: 1107 | break 1108 | it += 1 1109 | 1110 | precision = correct_cnt / (pred_cnt + 1e-10) 1111 | recall = correct_cnt / (label_cnt + 1e-10) 1112 | f1 = 2 * precision * recall / (precision + recall + 1e-10) 1113 | fp_error = fp_cnt / total_token_cnt 1114 | fn_error = fn_cnt / total_token_cnt 1115 | within_error = within_cnt / total_span_cnt 1116 | outer_error = outer_cnt / total_span_cnt 1117 | qloss = np.mean(np.array(query_loss_all)) 1118 | print('[EVAL] step: {0:4} loss: {4:3.4f}| [ENTITY] precision: {1:3.4f}, recall: {2:3.4f}, f1: {3:3.4f}'.format(it + 1, precision, recall, f1, qloss) + '\r') 1119 | 1120 | # sys.stdout.write('[EVAL] step: {0:4} | [ENTITY] precision: {1:3.4f}, recall: {2:3.4f}, f1: {3:3.4f}'.format(it + 1, precision, recall, f1) + '\r') 1121 | # sys.stdout.flush() 1122 | # print("") 1123 | # if ckpt is not None: 1124 | # if self.args.save_test_inference is not 'none': 1125 | # f_write.close() 1126 | 1127 | return precision, recall, f1, fp_error, fn_error, within_error, outer_error 1128 | 1129 | 1130 | class FewShotNERFramework_draw: 1131 | 1132 | def 
__init__(self, train_data_loader, val_data_loader, test_data_loader, args, tokenizer, use_sampled_data=False): 1133 | ''' 1134 | train_data_loader: DataLoader for training. 1135 | val_data_loader: DataLoader for validating. 1136 | test_data_loader: DataLoader for testing. 1137 | ''' 1138 | self.train_data_loader = train_data_loader 1139 | self.val_data_loader = val_data_loader 1140 | self.test_data_loader = test_data_loader 1141 | self.args = args 1142 | self.tokenizer = tokenizer 1143 | self.use_sampled_data = use_sampled_data 1144 | 1145 | def __load_model__(self, ckpt): 1146 | ''' 1147 | ckpt: Path of the checkpoint 1148 | return: Checkpoint dict 1149 | ''' 1150 | if os.path.isfile(ckpt): 1151 | checkpoint = torch.load(ckpt) 1152 | print("Successfully loaded checkpoint '%s'" % ckpt) 1153 | return checkpoint 1154 | else: 1155 | raise Exception("No checkpoint found at '%s'" % ckpt) 1156 | 1157 | def item(self, x): 1158 | ''' 1159 | PyTorch before and after 0.4 1160 | ''' 1161 | torch_version = torch.__version__.split('.') 1162 | if int(torch_version[0]) == 0 and int(torch_version[1]) < 4: 1163 | return x[0] 1164 | else: 1165 | return x.item() 1166 | 1167 | def __generate_label_data__(self, query): 1168 | label_tokens_index = [] 1169 | label_tokens_mask = [] 1170 | label_text_mask = [] 1171 | if self.args.have_otherO is True: 1172 | for label_dic in query['label2tag']: 1173 | for label_id in label_dic: 1174 | if label_id == 0: 1175 | label_tokens = ['other'] 1176 | else: 1177 | label_tokens = label_dic[label_id].split('-') 1178 | label_tokens = ['[CLS]'] + label_tokens + ['[SEP]'] 1179 | indexed_label_tokens = self.tokenizer.convert_tokens_to_ids(label_tokens) 1180 | # padding 1181 | while len(indexed_label_tokens) < 10: 1182 | indexed_label_tokens.append(0) 1183 | label_tokens_index.append(indexed_label_tokens) 1184 | # mask 1185 | mask = np.zeros((10), dtype=np.int32) 1186 | mask[:len(label_tokens)] = 1 1187 | label_tokens_mask.append(mask) 1188 | # text mask, 
also mask [CLS] and [SEP] 1189 | text_mask = np.zeros((10), dtype=np.int32) 1190 | text_mask[1:len(label_tokens)-1] = 1 1191 | label_text_mask.append(text_mask) 1192 | else: 1193 | for label_dic in query['label2tag']: 1194 | for label_id in label_dic: 1195 | if label_id != 0: 1196 | label_tokens = label_dic[label_id].split('-') 1197 | label_tokens = ['[CLS]'] + label_tokens + ['[SEP]'] 1198 | indexed_label_tokens = self.tokenizer.convert_tokens_to_ids(label_tokens) 1199 | # padding 1200 | while len(indexed_label_tokens) < 10: 1201 | indexed_label_tokens.append(0) 1202 | label_tokens_index.append(indexed_label_tokens) 1203 | # mask 1204 | mask = np.zeros((10), dtype=np.int32) 1205 | mask[:len(label_tokens)] = 1 1206 | label_tokens_mask.append(mask) 1207 | # text mask, also mask [CLS] and [SEP] 1208 | text_mask = np.zeros((10), dtype=np.int32) 1209 | text_mask[1:len(label_tokens)-1] = 1 1210 | label_text_mask.append(text_mask) 1211 | 1212 | label_tokens_index = torch.Tensor(label_tokens_index).long().cuda() 1213 | label_tokens_mask = torch.Tensor(label_tokens_mask).long().cuda() 1214 | label_text_mask = torch.Tensor(label_text_mask).long().cuda() 1215 | 1216 | label_data = {} 1217 | label_data['word'] = label_tokens_index 1218 | label_data['mask'] = label_tokens_mask 1219 | label_data['text_mask'] = label_text_mask 1220 | return label_data 1221 | 1222 | def __zero_grad__(self, params): 1223 | for p in params: 1224 | if p.grad is not None: 1225 | p.grad.zero_() 1226 | 1227 | def __get_sample_pairs__(self, data): 1228 | data_1 = {} 1229 | data_2 = {} 1230 | data_1['word_emb'] = data['word_emb'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 1231 | data_1['label'] = data['label'][[l in [*range(1, self.args.N+1)] for l in data['label']]] 1232 | data_2['word_emb'] = data['word_emb'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 1233 | data_2['label'] = data['label'][[l in [*range(0, self.args.N+1)] for l in data['label']]] 1234 | 1235 | return 
data_1, data_2 1236 | 1237 | def __generate_pair_label__(self, label1, label2): 1238 | pair_label = [] 1239 | for l1 in label1: 1240 | for l2 in label2: 1241 | if l1 == l2: 1242 | pair_label.append(1.0) 1243 | else: 1244 | pair_label.append(0.0) 1245 | return torch.Tensor(pair_label).cuda() 1246 | 1247 | def __generate_query_pair_label__(self, query_dis, query_label): 1248 | query_pair_label = [] 1249 | after_query_dis = [] 1250 | for i, l in enumerate(query_label): 1251 | tmp = torch.zeros([1, self.args.N]) 1252 | if l == -1: 1253 | continue 1254 | elif l == 0: 1255 | query_pair_label.append(tmp) 1256 | after_query_dis.append(query_dis[i]) 1257 | else: 1258 | tmp[0, l-1] = 1 1259 | query_pair_label.append(tmp) 1260 | after_query_dis.append(query_dis[i]) 1261 | query_pair_label = torch.cat(query_pair_label, dim=0).view(-1).cuda() 1262 | after_query_dis = torch.stack(after_query_dis).view(-1).cuda() 1263 | 1264 | return after_query_dis, query_pair_label 1265 | 1266 | def __get_proto__(self, label_data_emb, label_data_text_mask, support_emb, support, model): 1267 | if self.args.label_name_mode == 'mean': 1268 | temp_word_list = [] 1269 | for i, word_emb_list in enumerate(support_emb): 1270 | temp_word_list.append(word_emb_list[support['text_mask'][i]==1]) 1271 | temp_label_list = [] 1272 | temp_word_list = torch.cat(temp_word_list) # [x, 768] 1273 | temp_label_list = torch.cat(support['label'], dim=0) # [x,] 1274 | assert temp_word_list.shape[0] == temp_label_list.shape[0] 1275 | Proto = [] 1276 | if self.args.have_otherO is True: 1277 | for i in range(self.args.N+1): 1278 | Proto.append(torch.mean(temp_word_list[temp_label_list==i], dim=0).view(1,-1)) 1279 | else: 1280 | for i in range(self.args.N): 1281 | Proto.append(torch.mean(temp_word_list[temp_label_list==i+1], dim=0).view(1,-1)) 1282 | Proto = torch.cat(Proto) 1283 | elif self.args.label_name_mode == 'LnAsQ': 1284 | # get Q = mean(init label name) 1285 | Q = [] 1286 | K = {} 1287 | assert 
label_data_emb.shape[0] == label_data_text_mask.shape[0] 1288 | for i, l_ebd in enumerate(label_data_emb): # [10, 768] 1289 | p = l_ebd[label_data_text_mask[i]==1] 1290 | # K[i] = p 1291 | p = p.mean(dim=0) 1292 | Q.append(p.view(1,-1)) 1293 | Q = torch.cat(Q,0) 1294 | 1295 | # get K or V = cat(label name and word in class i) 1296 | temp_word_list = [] 1297 | for i, word_emb_list in enumerate(support_emb): 1298 | temp_word_list.append(word_emb_list[support['text_mask'][i]==1]) 1299 | temp_label_list = [] 1300 | temp_word_list = torch.cat(temp_word_list) # [x, 768] 1301 | temp_label_list = torch.cat(support['label'], dim=0) # [x,] 1302 | assert temp_word_list.shape[0] == temp_label_list.shape[0] 1303 | if self.args.have_otherO is True: 1304 | for i in range(self.args.N+1): 1305 | K[i] = temp_word_list[temp_label_list==i] 1306 | else: 1307 | for i in range(self.args.N): 1308 | K[i] = temp_word_list[temp_label_list==i+1] 1309 | 1310 | # Attention 1311 | Proto = [] 1312 | for i, q in enumerate(Q): 1313 | temp = torch.mm(model.att(q.view(1, -1)), K[i].t()) 1314 | att_weights = F.softmax(F.layer_norm(temp, normalized_shape=(temp.shape[0],temp.shape[1])), dim=1) 1315 | # print("att_weights:", att_weights) 1316 | proto = torch.mm(att_weights, K[i]) # [1, 768] 1317 | Proto.append(proto) 1318 | Proto = torch.cat(Proto) 1319 | elif self.args.label_name_mode == 'LnAsQKV': 1320 | # get Q = mean(init label name) 1321 | Q = [] 1322 | K = {} 1323 | assert label_data_emb.shape[0] == label_data_text_mask.shape[0] 1324 | for i, l_ebd in enumerate(label_data_emb): # [10, 768] 1325 | p = l_ebd[label_data_text_mask[i]==1] 1326 | K[i] = p 1327 | p = p.mean(dim=0) 1328 | Q.append(p.view(1,-1)) 1329 | Q = torch.cat(Q,0) 1330 | 1331 | # get K or V = cat(label name and word in class i) 1332 | temp_word_list = [] 1333 | for i, word_emb_list in enumerate(support_emb): 1334 | temp_word_list.append(word_emb_list[support['text_mask'][i]==1]) 1335 | temp_label_list = [] 1336 | temp_word_list = 
torch.cat(temp_word_list) # [x, 768] 1337 | temp_label_list = torch.cat(support['label'], dim=0) # [x,] 1338 | assert temp_word_list.shape[0] == temp_label_list.shape[0] 1339 | 1340 | if self.args.have_otherO is True: 1341 | for i in range(self.args.N+1): 1342 | K[i] = torch.cat((K[i],temp_word_list[temp_label_list==i]),dim=0) 1343 | else: 1344 | for i in range(self.args.N): 1345 | K[i] = torch.cat((K[i],temp_word_list[temp_label_list==i+1]),dim=0) 1346 | 1347 | # Attention 1348 | Proto = [] 1349 | for i, q in enumerate(Q): 1350 | temp = torch.mm(model.att(q.view(1, -1)), K[i].t()) 1351 | att_weights = F.softmax(F.layer_norm(temp, normalized_shape=(temp.shape[0],temp.shape[1])), dim=1) 1352 | # print("att_weights:", att_weights) 1353 | proto = torch.mm(att_weights, K[i]) # [1, 768] 1354 | Proto.append(proto) 1355 | Proto = torch.cat(Proto) 1356 | else: 1357 | raise NotImplementedError 1358 | 1359 | return Proto 1360 | 1361 | def __pos_dist__(self, instances, class_proto): # ins:[N*K, 256], cla:[N, 256] 1362 | return torch.pow(torch.pow(class_proto.unsqueeze(0) - instances.unsqueeze(1), 2).sum(-1), 0.5) 1363 | 1364 | def __save_test_inference__(self, logits, pred, query): 1365 | 1366 | # drop tokens labeled -1 from the query 1367 | new_query_label = [] 1368 | new_query_word = [] 1369 | new_query_textmask = [] 1370 | for lq, wq, tq in zip(query['label'], query['word'], query['text_mask']): 1371 | pass 1372 | 1373 | 1374 | # logits = F.softmax(logits, dim=-1) 1375 | # convert word ids back to real tokens 1376 | sentence_list = [] # 2-D list (one sub-list per sentence) 1377 | for words, mask in zip(query['word'], query['text_mask']): 1378 | real_words = [] 1379 | for word in words[mask==1]: 1380 | real_words.append(self.tokenizer.decode(word).replace(" ", "")) 1381 | sentence_list.append(real_words) 1382 | 1383 | # map labels and predictions to label words 1384 | sentence_num = [] 1385 | sentence_num.append(0) 1386 | tmp = 0 1387 | for i in query['sentence_num']: 1388 | tmp += i 1389 | sentence_num.append(tmp) 1390 | real_label_list = [] # 2-D list 1391 | pred_label_list = 
[] # 2-D list 1392 | label_name_list = [] # 2-D list 1393 | # slice pred (and logits) into per-sentence 2-D lists 1394 | pred_list = [] 1395 | # logits_list = [] 1396 | sentence_len = [] 1397 | sentence_len.append(0) 1398 | tmp = 0 1399 | for _labels in query['label']: 1400 | tmp += _labels.shape[0] 1401 | sentence_len.append(tmp) 1402 | for i in range(len(sentence_len)-1): 1403 | tmp2 = pred[sentence_len[i]: sentence_len[i+1]] 1404 | # tmp3 = logits[sentence_len[i]: sentence_len[i+1]] 1405 | pred_list.append(tmp2.cpu()) 1406 | # logits_list.append(tmp3.cpu().detach().numpy().tolist()) 1407 | 1408 | for i in range(len(sentence_num)-1): 1409 | for j in range(sentence_num[i], sentence_num[i+1]): 1410 | tmp_label_list = [] 1411 | tmp_pred_list = [] 1412 | tmp_label_name_list = [] 1413 | assert query['label'][j].shape[0] == pred_list[j].shape[0] 1414 | for l, p in zip(query['label'][j], pred_list[j]): 1415 | if l == -1: 1416 | tmp_label_list.append(str(-1)) 1417 | else: 1418 | tmp_label_list.append(query['label2tag'][i][l.item()]) 1419 | tmp_pred_list.append(query['label2tag'][i][p.item()]) 1420 | tmp_label_name_list.append(str(query['label2tag'][i])) 1421 | real_label_list.append(tmp_label_list) 1422 | pred_label_list.append(tmp_pred_list) 1423 | label_name_list.append(tmp_label_name_list) # label_list for each meta-task 1424 | 1425 | return sentence_list, real_label_list, pred_label_list, label_name_list 1426 | 1427 | def eval(self, 1428 | model_proto, 1429 | model_metnet, 1430 | eval_iter, 1431 | model_nn=None, 1432 | model_struct=None, 1433 | ckpt_proto=None, 1434 | ckpt_metnet=None, 1435 | ckpt_nn=None, 1436 | ckpt_struct=None): 1437 | ''' 1438 | model: a FewShotREModel instance 1439 | B: Batch size 1440 | N: Num of classes for each batch 1441 | K: Num of instances for each class in the support set 1442 | Q: Num of instances for each class in the query set 1443 | eval_iter: Num of iterations 1444 | ckpt: Checkpoint path. Set as None if using current model parameters. 
1445 | return: Accuracy 1446 | ''' 1447 | print("evaluating...") 1448 | 1449 | model_proto.eval() 1450 | model_metnet.eval() 1451 | # model_nn.eval() 1452 | # model_struct.eval() 1453 | 1454 | loss_func = TripletLoss(args=self.args) 1455 | 1456 | if (ckpt_proto is None) or (ckpt_metnet is None): # or (ckpt_nn is None) or (ckpt_struct is None): 1457 | print('ckpt_proto:', ckpt_proto) 1458 | print('ckpt_metnet:', ckpt_metnet) 1459 | 1460 | print("Use test dataset") 1461 | eval_dataset = self.test_data_loader 1462 | 1463 | pred_cnt = 0 # pred entity cnt 1464 | label_cnt = 0 # true label entity cnt 1465 | correct_cnt = 0 # correct predicted entity cnt 1466 | 1467 | fp_cnt = 0 # misclassify O as I- 1468 | fn_cnt = 0 # misclassify I- as O 1469 | total_token_cnt = 0 # total token cnt 1470 | within_cnt = 0 # span correct but of wrong fine-grained type 1471 | outer_cnt = 0 # span correct but of wrong coarse-grained type 1472 | total_span_cnt = 0 # span correct 1473 | 1474 | query_loss_all = [] 1475 | 1476 | eval_iter = min(eval_iter, len(eval_dataset)) 1477 | 1478 | # print test inference 1479 | # if ckpt is not None: 1480 | # if self.args.save_test_inference is not 'none': 1481 | # # generate save path 1482 | # if self.args.dataset == 'fewnerd': 1483 | # save_path = '_'.join([self.args.save_test_inference, self.args.dataset, self.args.mode, self.args.model, str(self.args.N), str(self.args.K)]) 1484 | # else: 1485 | # save_path = '_'.join([self.args.save_test_inference, self.args.dataset, self.args.model, str(self.args.N), str(self.args.K)]) 1486 | # f_write = open(save_path + '.txt', 'a', encoding='utf-8') 1487 | 1488 | 1489 | it = 0 1490 | while it + 1 < eval_iter: 1491 | for _, (support, query) in enumerate(eval_dataset): 1492 | 1493 | if ckpt_metnet != 'none': 1494 | state_dict = self.__load_model__(ckpt_metnet)['state_dict'] 1495 | own_state = model_metnet.state_dict() 1496 | for name, param in state_dict.items(): 1497 | if name not in own_state: 1498 | continue 1499 | 
own_state[name].copy_(param) 1500 | 1501 | margin = model_metnet.param 1502 | alpha = model_metnet.alpha 1503 | print("margin:", margin) 1504 | print("alpha:", alpha) 1505 | 1506 | if self.args.margin != -1: 1507 | margin = self.args.margin 1508 | print('set margin:', margin) 1509 | 1510 | if torch.cuda.is_available(): 1511 | for k in support: 1512 | if k != 'label' and k != 'sentence_num': 1513 | support[k] = support[k].cuda() 1514 | query[k] = query[k].cuda() 1515 | query_label = torch.cat(query['label'], 0) 1516 | query_label = query_label.cuda() 1517 | 1518 | # get proto init rep 1519 | label_data = self.__generate_label_data__(query) 1520 | label_data_emb = model_metnet.word_encoder(label_data['word'], label_data['mask']) # [num_label_sent, 10, 768] 1521 | support_emb = model_metnet.word_encoder(support['word'], support['mask']) 1522 | Proto = self.__get_proto__(label_data_emb, label_data['text_mask'], support_emb, support, model_metnet) #[N, 768] 1523 | 1524 | # support,proto -> MLP -> new emb 1525 | support_label = torch.cat(support['label'], dim=0) 1526 | support_emb = support_emb[support['text_mask']==1] 1527 | support_afterMLP_emb = model_metnet(support_emb) 1528 | proto_afterMLP_emb = model_metnet(Proto) 1529 | if self.args.use_proto_as_neg == True: 1530 | if self.args.have_otherO is True: 1531 | support_label = torch.cat((support_label, torch.tensor([0 for _ in range(self.args.N)],dtype=torch.int64))) 1532 | else: 1533 | support_label = torch.cat((support_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64))) 1534 | support_afterMLP_emb = torch.cat((support_afterMLP_emb, proto_afterMLP_emb+1e-8),dim=0) 1535 | support_dis = [] 1536 | if self.args.have_otherO is True: 1537 | for i in range(self.args.N+1): 1538 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 1539 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 1540 | temp_lst[-(self.args.N+1-i)] = 1 1541 | temp_lst = 
np.array(temp_lst) 1542 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 1543 | support_dis.append(support_dis_one_line) 1544 | else: 1545 | for i in range(self.args.N): 1546 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 1547 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 1548 | temp_lst[-(self.args.N-i)] = 1 1549 | temp_lst = np.array(temp_lst) 1550 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 1551 | support_dis.append(support_dis_one_line) 1552 | support_dis = torch.cat(support_dis, dim=0).view(-1) # [N+1, N*K] 1553 | else: 1554 | support_dis = self.__pos_dist__(proto_afterMLP_emb, support_afterMLP_emb).view(-1) # [N, N*K] 1555 | # if self.args.use_diff_threshold == False: 1556 | # print("support_dis_before_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 1557 | support_dis = F.layer_norm(support_dis, normalized_shape=[support_dis.shape[0]], bias=torch.full((support_dis.shape[0],), self.args.ln_bias).cuda()) 1558 | # print("support_dis_after_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 1559 | if self.args.have_otherO is True: 1560 | support_dis = support_dis.view(self.args.N+1, -1) 1561 | else: 1562 | support_dis = support_dis.view(self.args.N, -1) 1563 | # if self.args.use_diff_threshold == True: 1564 | # margin = torch.mean(support_dis) 1565 | support_loss = loss_func(support_dis, support_label, margin, alpha) 1566 | 1567 | self.__zero_grad__(model_metnet.fc.parameters()) 1568 | grads_fc = autograd.grad(support_loss, model_metnet.fc.parameters(), allow_unused=True, retain_graph=True) 1569 | fast_weights_fc, orderd_params_fc = model_metnet.cloned_fc_dict(), OrderedDict() 1570 | for (key, val), grad in zip(model_metnet.fc.named_parameters(), grads_fc): 1571 | fast_weights_fc[key] = orderd_params_fc[key] = val - 
self.args.task_lr * grad # grad magnitudes: ~1e-4 for weights, ~1e-11 for biases -- possibly too small? 1572 | 1573 | fast_weights = {} 1574 | fast_weights['fc'] = fast_weights_fc 1575 | 1576 | train_support_loss = [] 1577 | for _ in range(self.args.train_support_iter - 1): 1578 | support_afterMLP_emb = model_metnet(support_emb, fast_weights) 1579 | proto_afterMLP_emb = model_metnet(Proto, fast_weights) 1580 | if self.args.use_proto_as_neg == True: 1581 | # support_label = torch.cat((support_label, torch.tensor([0 for _ in range(1,self.args.N)],dtype=torch.int64))) 1582 | support_afterMLP_emb = torch.cat((support_afterMLP_emb, proto_afterMLP_emb+1e-8),dim=0) 1583 | support_dis = [] 1584 | if self.args.have_otherO is True: 1585 | for i in range(self.args.N+1): 1586 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 1587 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 1588 | temp_lst[-(self.args.N+1-i)] = 1 1589 | temp_lst = np.array(temp_lst) 1590 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 1591 | support_dis.append(support_dis_one_line) 1592 | else: 1593 | for i in range(self.args.N): 1594 | support_dis_one_line = self.__pos_dist__(proto_afterMLP_emb[i].view(1, -1), support_afterMLP_emb) 1595 | temp_lst = [0 for _ in range(support_dis_one_line.shape[1])] 1596 | temp_lst[-(self.args.N-i)] = 1 1597 | temp_lst = np.array(temp_lst) 1598 | support_dis_one_line = support_dis_one_line.view(-1)[temp_lst == 0].view(1, -1) 1599 | support_dis.append(support_dis_one_line) 1600 | support_dis = torch.cat(support_dis, dim=0).view(-1) 1601 | else: 1602 | support_dis = self.__pos_dist__(proto_afterMLP_emb, support_afterMLP_emb).view(-1) 1603 | # if self.args.use_diff_threshold == False: 1604 | # print("support_dis_before_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 1605 | support_dis = F.layer_norm(support_dis, normalized_shape=[support_dis.shape[0]], 
bias=torch.full((support_dis.shape[0],), self.args.ln_bias).cuda()) 1606 | # print("support_dis_after_norm:", torch.max(support_dis).item(), torch.min(support_dis).item(), torch.mean(support_dis).item()) 1607 | if self.args.have_otherO is True: 1608 | support_dis = support_dis.view(self.args.N+1, -1) 1609 | else: 1610 | support_dis = support_dis.view(self.args.N, -1) 1611 | # if self.args.use_diff_threshold == True: 1612 | # margin = torch.mean(support_dis) 1613 | support_loss = loss_func(support_dis, support_label, margin, alpha) 1614 | train_support_loss.append(support_loss.item()) 1615 | # print_info = 'train_support, ' + str(support_loss.item()) 1616 | # print('\033[0;31;40m{}\033[0m'.format(print_info)) 1617 | self.__zero_grad__(orderd_params_fc.values()) 1618 | 1619 | grads_fc = torch.autograd.grad(support_loss, orderd_params_fc.values(), allow_unused=True, retain_graph=True) 1620 | for (key, val), grad in zip(orderd_params_fc.items(), grads_fc): 1621 | if grad is not None: 1622 | fast_weights['fc'][key] = orderd_params_fc[key] = val - self.args.task_lr * grad 1623 | 1624 | # query, proto -> MLP -> new emb 1625 | query_emb = model_metnet.word_encoder(query['word'], query['mask']) 1626 | query_emb = query_emb[query['text_mask']==1] 1627 | query_afterMLP_emb = model_metnet(query_emb, fast_weights) 1628 | proto_afterMLP_emb = model_metnet(Proto, fast_weights) 1629 | if self.args.save_query_ebd is True: 1630 | save_qe_path = '_'.join([self.args.dataset, self.args.mode, 'metnet',str(self.args.N), str(self.args.K), str(self.args.Q), str(int(round(time.time() * 1000)))]) 1631 | if not os.path.exists(save_qe_path): 1632 | os.mkdir(save_qe_path) 1633 | f_write = open(os.path.join(save_qe_path, 'label2tag.txt'), 'w', encoding='utf-8') 1634 | for ln in query['label2tag'][0]: 1635 | f_write.write(query['label2tag'][0][ln] + '\n') 1636 | f_write.flush() 1637 | f_write.close() 1638 | np.save(os.path.join(save_qe_path, 'proto.npy'), 
            proto_afterMLP_emb.cpu().detach().numpy())
    # save the query embeddings of each label id (0-5) separately
    for c in range(6):
        np.save(os.path.join(save_qe_path, '{}.npy'.format(c)),
                query_afterMLP_emb[query_label == c].cpu().detach().numpy())

del model_metnet
del support_emb
del support_afterMLP_emb
del Proto
del proto_afterMLP_emb
del query_emb
del query_afterMLP_emb


# PROTO
if ckpt_proto != 'none':
    state_dict = self.__load_model__(ckpt_proto)['state_dict']
    own_state = model_proto.state_dict()
    for name, param in state_dict.items():
        if name not in own_state:
            continue
        own_state[name].copy_(param)

if torch.cuda.is_available():
    for k in support:
        if k != 'label' and k != 'sentence_num':
            support[k] = support[k].cuda()
            query[k] = query[k].cuda()
    label = torch.cat(query['label'], 0)
    label = label.cuda()
    support_label = torch.cat(support['label'], 0).cuda()

logits, pred = model_proto(support, query)
sys.exit()  # stop after the PROTO pass
# del model_proto


# nnshot (kept for reference; enable by uncommenting)
# if ckpt_nn != 'none':
#     state_dict = self.__load_model__(ckpt_nn)['state_dict']
#     own_state = model_nn.state_dict()
#     for name, param in state_dict.items():
#         if name not in own_state:
#             continue
#         own_state[name].copy_(param)

# if torch.cuda.is_available():
#     for k in support:
#         if k != 'label' and k != 'sentence_num':
#             support[k] = support[k].cuda()
#             query[k] = query[k].cuda()
#     label = torch.cat(query['label'], 0)
#     label = label.cuda()
#     support_label = torch.cat(support['label'], 0).cuda()

# logits, pred = model_nn(support, query)
# sys.exit()

# del model_nn


# structshot (kept for reference; enable by uncommenting)
# if ckpt_struct != 'none':
#     state_dict = self.__load_model__(ckpt_struct)['state_dict']
#     own_state = model_struct.state_dict()
#     for name, param in state_dict.items():
#         if name not in own_state:
#             continue
#         own_state[name].copy_(param)

# if torch.cuda.is_available():
#     for k in support:
#         if k != 'label' and k != 'sentence_num':
#             support[k] = support[k].cuda()
#             query[k] = query[k].cuda()
#     label = torch.cat(query['label'], 0)
#     label = label.cuda()
#     support_label = torch.cat(support['label'], 0).cuda()

# logits, pred = model_struct(support, query)
# sys.exit()
-------------------------------------------------------------------------------- /utils/tripletloss.py: --------------------------------------------------------------------------------
import torch


# custom loss module (adapted from a ContrastiveLoss implementation)
class TripletLoss(torch.nn.Module):
    """
    Triplet loss function.
    """

    def __init__(self, args):
        super(TripletLoss, self).__init__()
        self.args = args

    def __tripletloss__(self, dp_one, dn_one, margin, alpha):
        # plain hinge: max(0, d(anchor, pos) - d(anchor, neg) + margin)
        return torch.clamp(dp_one - dn_one + margin, min=0.0)

    def __tripletloss_dp__(self, dp_one, dn_one, margin, alpha):
        # hinge plus an extra pull on the positive distance
        return 0.5 * torch.clamp(dp_one - dn_one + margin, min=0.0) + 0.5 * dp_one

    def __sigmoid_tripletloss__(self, dp_one, dn_one, margin, alpha):
        # sigmoid-weighted variant: weight each term by how far it violates the margin
        pos_weight = torch.sigmoid(dp_one - margin)
        neg_weight = torch.sigmoid(margin - dn_one)
        return pos_weight * dp_one + neg_weight * torch.clamp(margin - dn_one, min=0.0)

    def forward(self, dis, label, margin, alpha):
        '''
        dis: [N, N*K] distances from each class prototype to each support token
        label: [N*K] gold class id of each support token
        margin: a trainable parameter (one per class if args.multi_margin is set)
        '''
        loss = 0
        for i, ps_dis in enumerate(dis):
            temp_loss = 0
            if self.args.have_otherO:
                dp = ps_dis[label == i]
                dn = ps_dis[label != i]
            else:
                dp = ps_dis[label == i + 1]
                dn = ps_dis[label != i + 1]
            # keep only the hardest (closest) negatives
            dn, _ = torch.sort(dn)
            if dn.shape[0] > self.args.neg_num:
                dn = dn[:self.args.neg_num]
            for dp_one in dp:
                for dn_one in dn:
                    if self.args.tripletloss_mode == 'tl':
                        if self.args.multi_margin:
                            temp_loss += self.__tripletloss__(dp_one, dn_one, margin[i], alpha)
                        else:
                            temp_loss += self.__tripletloss__(dp_one, dn_one, margin, alpha)
                    elif self.args.tripletloss_mode == 'tl+dp':
                        if self.args.multi_margin:
                            temp_loss += self.__tripletloss_dp__(dp_one, dn_one, margin[i], alpha)
                        else:
                            temp_loss += self.__tripletloss_dp__(dp_one, dn_one, margin, alpha)
                    elif self.args.tripletloss_mode == 'sig+dp+dn':
                        if self.args.multi_margin:
                            temp_loss += \
                                self.__sigmoid_tripletloss__(dp_one, dn_one, margin[i], alpha)
                        else:
                            temp_loss += self.__sigmoid_tripletloss__(dp_one, dn_one, margin, alpha)
                    else:
                        raise NotImplementedError
            # average over all (positive, negative) pairs of this class
            temp_loss = temp_loss / dp.shape[0] / dn.shape[0]
            loss += temp_loss
        # average over the N classes
        loss = loss / dis.shape[0]
        return loss
-------------------------------------------------------------------------------- /utils/viterbi.py: --------------------------------------------------------------------------------
import torch


START_ID = 0
O_ID = 1

class ViterbiDecoder:
    """
    Generalized Viterbi decoding
    """

    def __init__(self, n_tag, abstract_transitions, tau):
        """
        We assume the batch size is 1, so no need to worry about PAD for now
        n_tag: START, O, and the I-X tags
        """
        super().__init__()
        self.transitions = self.project_target_transitions(n_tag, abstract_transitions, tau)

    @staticmethod
    def project_target_transitions(n_tag, abstract_transitions, tau):
        s_o, s_i, o_o, o_i, i_o, i_i, x_y = abstract_transitions
        # self-transitions for I-X tags
        a = torch.eye(n_tag) * i_i
        # transitions from I-X to I-Y
        b = torch.ones(n_tag, n_tag) * x_y / (n_tag - 3)
        c = torch.eye(n_tag) * x_y / (n_tag - 3)
        transitions = a + b - c
        # transition from START to O
        transitions[START_ID, O_ID] = s_o
        # transitions from START to I-X
        transitions[START_ID, O_ID+1:] = s_i / (n_tag - 2)
        # transition from O to O
        transitions[O_ID, O_ID] = o_o
        # transitions from O to I-X
        transitions[O_ID, O_ID+1:] = o_i / (n_tag - 2)
        # transitions from I-X to O
        transitions[O_ID+1:, O_ID] = i_o
        # no transitions into START
        transitions[:, START_ID] = 0.

        # temperature-sharpen, then renormalize each row into a distribution
        powered = torch.pow(transitions, tau)
        summed = powered.sum(dim=1)
        transitions = powered / summed.view(n_tag, 1)

        # floor zero entries to avoid log(0)
        transitions = torch.where(transitions > 0, transitions, torch.tensor(1e-6))

        return torch.log(transitions)

    def forward(self, scores: torch.Tensor) -> torch.Tensor:  # type: ignore
        """
        Take the emission scores calculated by NERModel, and return a tensor of CRF features,
        which is the sum of transition scores and emission scores.
        :param scores: emission scores calculated by NERModel.
            shape: (batch_size, sentence_length, ntags)
        :return: a tensor containing the CRF features whose shape is
            (batch_size, sentence_len, ntags, ntags). F[b, t, i, j] represents
            emission[t, j] + transition[i, j] for the b'th sentence in this batch.
        """
        batch_size, sentence_len, _ = scores.size()

        # expand the transition matrix batch-wise as well as sentence-wise
        transitions = self.transitions.expand(batch_size, sentence_len, -1, -1)

        # add another dimension for the "from" state, then expand to match
        # the dimensions of the expanded transition matrix above
        emissions = scores.unsqueeze(2).expand_as(transitions)

        # add them up
        return transitions + emissions

    @staticmethod
    def viterbi(features: torch.Tensor) -> torch.Tensor:
        """
        Decode the most probable sequence of tags.
        Note that the delta values are calculated in the log space.
        :param features: the feature matrix from the forward method of CRF.
            shaped (batch_size, sentence_len, ntags, ntags)
        :return: a tensor containing the most probable sequences for the batch,
            shaped (batch_size, sentence_len)
        """
        batch_size, sentence_len, ntags, _ = features.size()

        # initialize the deltas with the transitions out of START
        delta_t = features[:, 0, START_ID, :]
        deltas = [delta_t]

        # use dynamic programming to iteratively calculate the delta values
        for t in range(1, sentence_len):
            f_t = features[:, t]
            delta_t, _ = torch.max(f_t + delta_t.unsqueeze(2).expand_as(f_t), 1)
            deltas.append(delta_t)

        # now iterate backward to figure out the most probable tags
        sequences = [torch.argmax(deltas[-1], 1, keepdim=True)]
        for t in reversed(range(sentence_len - 1)):
            f_prev = features[:, t + 1].gather(
                2, sequences[-1].unsqueeze(2).expand(batch_size, ntags, 1)).squeeze(2)
            sequences.append(torch.argmax(f_prev + deltas[t], 1, keepdim=True))
        sequences.reverse()
        return torch.cat(sequences, dim=1)
-------------------------------------------------------------------------------- /utils/word_encoder.py: --------------------------------------------------------------------------------
import torch
import torch.nn as nn
from transformers import BertModel

class BERTWordEncoder(nn.Module):

    def __init__(self, pretrain_path):
        nn.Module.__init__(self)
        self.bert = BertModel.from_pretrained(pretrain_path)

    def forward(self, words, masks):
        outputs = self.bert(words, attention_mask=masks, output_hidden_states=True, return_dict=True)
        # use the sum of the last 4 hidden layers as the token representation
        last_four_hidden_states = torch.cat([hidden_state.unsqueeze(0) for hidden_state in
                                             outputs['hidden_states'][-4:]], 0)
        del outputs
        word_embeddings = torch.sum(last_four_hidden_states, 0)  # [num_sent, num_tokens, 768]
        return word_embeddings
--------------------------------------------------------------------------------
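The inner loop in `framework_mtnet` above takes gradients of the support loss with respect to the current "fast" weights and applies a functional SGD step, leaving the original parameters untouched for the outer loop. A minimal self-contained sketch of that pattern, using a toy linear model and loss (all names here are illustrative, not the repository's API):

```python
import torch

torch.manual_seed(0)

# "slow" parameter and its functional "fast" copy
w = torch.randn(4, 2, requires_grad=True)
fast_weights = {'fc': w}

x, y = torch.randn(8, 4), torch.randn(8, 2)
task_lr = 0.1

losses = []
for _ in range(5):
    pred = x @ fast_weights['fc']
    support_loss = ((pred - y) ** 2).mean()
    losses.append(support_loss.item())
    # gradient w.r.t. the current fast weights; create_graph=True keeps the
    # update differentiable so an outer loop could backprop through it
    (grad,) = torch.autograd.grad(support_loss, (fast_weights['fc'],), create_graph=True)
    fast_weights['fc'] = fast_weights['fc'] - task_lr * grad  # functional SGD step

# the original parameter w is unchanged; only the fast copy moved
```

The key design point is that the update is out-of-place: assigning into the dictionary rather than mutating `w.data` keeps every inner step on the autograd graph.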
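The nested pair loops in `TripletLoss.forward` above can be collapsed into one broadcasted hinge. A sketch of the plain `'tl'` mode with a scalar margin (no `args` object; hardest-negative truncation omitted; labels assumed to run 1..N, so row `i` of `dis` belongs to class `i + 1`):

```python
import torch

def triplet_loss_tl(dis, label, margin):
    # dis: [N, N*K] distances from each class prototype to each support token
    # label: [N*K] class id of each support token (1..N)
    loss = 0.0
    for i, row in enumerate(dis):
        dp = row[label == i + 1]   # anchor-positive distances
        dn = row[label != i + 1]   # anchor-negative distances
        # hinge over every (positive, negative) pair, then average the pairs,
        # matching the temp_loss / dp.shape[0] / dn.shape[0] normalization above
        pair = dp.unsqueeze(1) - dn.unsqueeze(0) + margin
        loss = loss + torch.clamp(pair, min=0.0).mean()
    return loss / dis.shape[0]

dis = torch.tensor([[0.2, 1.5, 1.4],
                    [1.3, 0.1, 1.2]])
label = torch.tensor([1, 2, 2])
loss = triplet_loss_tl(dis, label, margin=0.5)
print(loss.item())  # ~0.1: only the pair (1.2, 1.3) violates the 0.5 margin
```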
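`ViterbiDecoder.viterbi` maximizes the summed CRF features with a standard max-product recursion. A quick self-contained sanity check (toy sizes, random features) that the same recurrence recovers the brute-force best path score:

```python
import itertools
import torch

torch.manual_seed(0)
T, n = 4, 3                      # sentence length, number of tags (tag 0 = START)
feats = torch.randn(1, T, n, n)  # feats[0, t, i, j]: score of tag j at step t after tag i

# forward DP, same recurrence as ViterbiDecoder.viterbi
delta = feats[0, 0, 0, :]        # transitions out of START at t = 0
deltas = [delta]
for t in range(1, T):
    # new_delta[j] = max_i (feats[0, t, i, j] + delta[i])
    delta, _ = torch.max(feats[0, t] + delta.unsqueeze(1), 0)
    deltas.append(delta)
best_dp = deltas[-1].max().item()

# brute force over all n**T tag sequences
best_bf = max(
    sum(feats[0, t, (0, *path)[t], path[t]].item() for t in range(T))
    for path in itertools.product(range(n), repeat=T)
)
assert abs(best_dp - best_bf) < 1e-4
```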
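`BERTWordEncoder.forward` builds token embeddings by summing the last four hidden layers. On dummy tensors (shapes illustrative), the `cat`/`unsqueeze`/`sum` pattern there is equivalent to a plain `torch.stack(...).sum(0)`:

```python
import torch

num_layers, num_sent, num_tokens, dim = 13, 2, 6, 768
hidden_states = [torch.randn(num_sent, num_tokens, dim) for _ in range(num_layers)]

# the cat/unsqueeze/sum pattern from word_encoder.py
last_four = torch.cat([h.unsqueeze(0) for h in hidden_states[-4:]], 0)
word_embeddings = torch.sum(last_four, 0)

# equivalent, slightly more idiomatic form
stacked = torch.stack(hidden_states[-4:], dim=0).sum(dim=0)

print(word_embeddings.shape)  # torch.Size([2, 6, 768])
```

Summing a few top layers keeps the per-token dimensionality at 768 while mixing surface-level and task-level features, which is why the encoder returns `[num_sent, num_tokens, 768]` regardless of how many layers are combined.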