├── model.png ├── requirements.txt ├── src ├── functions │ ├── metric.py │ ├── utils.py │ ├── biattention.py │ └── processor.py ├── model │ ├── main_functions.py │ └── model.py └── dependency │ └── merge.py ├── README.md └── run_NLI.py /model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/NeuralSymbolic_KU_NLI/HEAD/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | tqdm 4 | tokenizers==0.10.3 5 | attrdict 6 | fastprogress 7 | transformers==4.6.1 8 | torch==1.9.0+cu111 9 | -------------------------------------------------------------------------------- /src/functions/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score 3 | 4 | 5 | def get_sklearn_score(predicts, corrects, idx2label): 6 | predicts = [idx2label[predict] for predict in predicts] 7 | corrects = [idx2label[correct] for correct in corrects] 8 | result = {"accuracy": accuracy_score(corrects, predicts), 9 | "macro_precision": precision_score(corrects, predicts, average="macro"), 10 | "micro_precision": precision_score(corrects, predicts, average="micro"), 11 | "macro_f1": f1_score(corrects, predicts, average="macro"), 12 | "micro_f1": f1_score(corrects, predicts, average="micro"), 13 | "macro_recall": recall_score(corrects, predicts, average="macro"), 14 | "micro_recall": recall_score(corrects, predicts, average="micro"), 15 | 16 | } 17 | for k, v in result.items(): 18 | result[k] = round(v, 3) 19 | print(k + ": " + str(v)) 20 | return result 21 | 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Natural Language Inference using Dependency Parsing 2 | Code for HCLT 2021 paper: *[Natural Language Inference using Dependency Parsing](https://koreascience.kr/article/CFKO202130060562801.page?&lang=ko)* 3 | 4 | 5 | ## Setting up the code environment 6 | 7 | ``` 8 | $ virtualenv --python=python3.7 venv 9 | $ source venv/bin/activate 10 | $ pip install -r requirements.txt 11 | ``` 12 | 13 | All code only supports running on Linux. 
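To quickly confirm that the pinned packages installed correctly (a minimal check; `torch.cuda.is_available()` depends on your GPU and driver setup), the versions printed should match requirements.txt (torch 1.9.0+cu111, transformers 4.6.1):

```
$ python -c "import torch, transformers; print(torch.__version__, transformers.__version__, torch.cuda.is_available())"
```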
14 | 15 | 16 | ## Model Structure 17 | 18 | 19 | 20 | ## Data 21 | 22 | Korean Language Understanding Evaluation-Natural Language Inference version1: *[KLUE-NLI](https://klue-benchmark.com/tasks/68/data/description)* 23 | 24 | ### Directory and Pre-processing 25 | `의존 구문 분석 모델은 미공개(The dependency parser model is unpublished)` 26 | ``` 27 | ├── data 28 | │ ├── klue-nli-v1_train.json 29 | │   ├── klue-nli-v1_dev.json 30 | │   └── parsing 31 | │   ├── parsing_1_klue_nli_train.json 32 | │   └── parsing_1_klue_nli_dev.json 33 | │   └── merge 34 | │   ├── parsing_1_klue_nli_train.json 35 | │   └── parsing_1_klue_nli_dev.json 36 | ├── roberta 37 | │ ├── init_weight 38 | │   └── my_model 39 | ├── src 40 | │   ├── dependency 41 | │   └── merge.py 42 | │   ├── functions 43 | │   ├── biattention.py 44 | │   ├── utils.py 45 | │   ├── metric.py 46 | │   └── processor.json 47 | │   └── model 48 | │   ├── main_functions.py 49 | │   └── model.py 50 | ├── run_NLI.py 51 | ├── requirements.txt 52 | └── README.md 53 | ``` 54 | 55 | * 원시 데이터(data/klue-nli-v1_train.json)를 의존 구문 분석 모델을 활용하여 입력 문장 쌍에 대한 어절 단위 의존 구문 구조 추출(data/parsing/klue-nli-v1_train.json) 56 | 57 | * 입력 문장 쌍에 대한 어절 단위 의존 구문 구조(data/parsing/klue-nli-v1_train.json)를 `src/dependency/merge.py`를 통해 입력 문장 쌍에 대한 청크 단위 의존 구문 구조로 변환(data/merge/klue-nli-v1_train.json) 58 | 59 | * [roberta/init_weight](https://huggingface.co/klue/roberta-base)/vocab.json에 청크 단위로 구분해주는 스폐셜 토큰(Special Token) `` 추가 60 | 61 | 62 | ## Train & Test 63 | 64 | ### Pretrained Model 65 | [klue/roberta-base](https://huggingface.co/klue/roberta-base) 66 | 67 | ### How To Run 68 | ``` 69 | python run_NLI.py 70 | ``` 71 | 72 | ## Results on KLUE-NLI 73 | 74 | | Model | Acc | 75 | |---|--------- | 76 | | NLI w/ DP | 90.78% | 77 | -------------------------------------------------------------------------------- /src/functions/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import torch 4 | import numpy as np 5 | import os 6 | 7 | 8 | from src.functions.processor import ( 9 | KLUE_NLIV1Processor, 10 | klue_convert_examples_to_features 11 | ) 12 | 13 | 14 | def init_logger(): 15 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 16 | datefmt='%m/%d/%Y %H:%M:%S', 17 | level=logging.INFO) 18 | 19 | def set_seed(args): 20 | random.seed(args.seed) 21 | np.random.seed(args.seed) 22 | torch.manual_seed(args.seed) 23 | if not args.no_cuda and torch.cuda.is_available(): 24 | torch.cuda.manual_seed_all(args.seed) 25 | 26 | # tensor를 list 형으로 변환하기위한 함수 27 | def to_list(tensor): 28 | return tensor.detach().cpu().tolist() 29 | 30 | 31 | # dataset을 load 하는 함수 32 | def load_examples(args, tokenizer, evaluate=False, output_examples=False, do_predict=False, input_dict=None): 33 | ''' 34 | :param args: 하이퍼 파라미터 35 | :param tokenizer: tokenization에 사용되는 tokenizer 36 | :param evaluate: 평가나 open test시, True 37 | :param output_examples: 평가나 open test 시, True / True 일 경우, examples와 features를 같이 return 38 | :param do_predict: open test시, True 39 | :param input_dict: open test시 입력되는 문서와 질문으로 이루어진 dictionary 40 | :return: 41 | examples : max_length 상관 없이, 원문으로 각 데이터를 저장한 리스트 42 | features : max_length에 따라 분할 및 tokenize된 원문 리스트 43 | dataset : max_length에 따라 분할 및 학습에 직접적으로 사용되는 tensor 형태로 변환된 입력 ids 44 | ''' 45 | input_dir = args.data_dir 46 | print("Creating features from dataset file at {}".format(input_dir)) 47 | 48 | # processor 선언 49 | ## json으로 된 train과 dev data_file명 50 | if 
len(set(input_dir.split("/")).intersection(["snli", "mnli", "qnli", "hans", "sick"])) != 0:processor = NLIV1Processor() 51 | else: processor = KLUE_NLIV1Processor() 52 | 53 | # open test 시 54 | if do_predict: 55 | ## input_dict: guid, premise, hypothesis로 이루어진 dictionary 56 | # examples = processor.get_example_from_input(input_dict) 57 | examples = processor.get_dev_examples(os.path.join(args.data_dir), 58 | filename=args.predict_file) 59 | # 평가 시 60 | elif evaluate: 61 | examples = processor.get_dev_examples(os.path.join(args.data_dir), 62 | filename=args.eval_file) 63 | # 학습 시 64 | else: 65 | examples = processor.get_train_examples(os.path.join(args.data_dir), 66 | filename=args.train_file) 67 | 68 | 69 | features, dataset = klue_convert_examples_to_features( 70 | examples=examples, 71 | tokenizer=tokenizer, 72 | max_seq_length=args.max_seq_length, 73 | is_training=not evaluate, 74 | return_dataset="pt", 75 | threads=args.threads, 76 | prem_max_sentence_length = args.prem_max_sentence_length, 77 | hypo_max_sentence_length = args.hypo_max_sentence_length, 78 | language = args.model_name_or_path.split("/")[-2] 79 | ) 80 | if output_examples: 81 | ## example == feature == dataset 82 | return dataset, examples, features 83 | return dataset 84 | -------------------------------------------------------------------------------- /src/functions/biattention.py: -------------------------------------------------------------------------------- 1 | # 해당 코드는 아래 링크에서 가져옴 2 | # https://github.com/KLUE-benchmark/KLUE-baseline/blob/8a03c9447e4c225e806877a84242aea11258c790/klue_baseline/models/dependency_parsing.py 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parameter import Parameter 9 | import torch.nn.functional as F 10 | 11 | 12 | class BiAttention(nn.Module): 13 | def __init__( 14 | self, 15 | input_size_encoder, 16 | input_size_decoder, 17 | num_labels, 18 | biaffine=True, 19 | **kwargs 20 | ): 21 | super(BiAttention, self).__init__() 22 | self.input_size_encoder = input_size_encoder 23 | self.input_size_decoder = input_size_decoder 24 | self.num_labels = num_labels 25 | self.biaffine = biaffine 26 | 27 | self.W_e = Parameter(torch.Tensor(self.num_labels, self.input_size_encoder)) 28 | self.W_d = Parameter(torch.Tensor(self.num_labels, self.input_size_decoder)) 29 | self.b = Parameter(torch.Tensor(self.num_labels, 1, 1)) 30 | if self.biaffine: 31 | self.U = Parameter( 32 | torch.Tensor( 33 | self.num_labels, self.input_size_decoder, self.input_size_encoder 34 | ) 35 | ) 36 | else: 37 | self.register_parameter("U", None) 38 | 39 | self.reset_parameters() 40 | 41 | def reset_parameters(self): 42 | nn.init.xavier_uniform_(self.W_e) 43 | nn.init.xavier_uniform_(self.W_d) 44 | nn.init.constant_(self.b, 0.0) 45 | if self.biaffine: 46 | nn.init.xavier_uniform_(self.U) 47 | 48 | def forward(self, input_e, input_d, mask_d=None, mask_e=None): 49 | assert input_d.size(0) == input_e.size(0) 50 | batch, length_decoder, _ = input_d.size() 51 | _, length_encoder, _ = input_e.size() 52 | 53 | # input_d : [b, t, d] 54 | # input_e : [b, s, e] 55 | # out_d : [b, l, d, 1] 56 | # out_e : [b, l ,1, e] 57 | out_d = torch.matmul(self.W_d, input_d.transpose(1, 2)).unsqueeze(3) 58 | out_e = torch.matmul(self.W_e, input_e.transpose(1, 2)).unsqueeze(2) 59 | 60 | if self.biaffine: 61 | # output : [b, 1, t, d] * [l, d, e] -> [b, l, t, e] 62 | output = torch.matmul(input_d.unsqueeze(1), self.U) 63 | # output : [b, l, t, e] * [b, 1, e, s] -> [b, l, t, s] 64 | output = 
torch.matmul(output, input_e.unsqueeze(1).transpose(2, 3)) 65 | output = output + out_d + out_e + self.b 66 | else: 67 | output = out_d + out_d + self.b 68 | 69 | if mask_d is not None: 70 | output = ( 71 | output 72 | * mask_d.unsqueeze(1).unsqueeze(3) 73 | * mask_e.unsqueeze(1).unsqueeze(2) 74 | ) 75 | 76 | # input1 = (batch_size, input11, input12) 77 | # input2 = (batch_size, input21, input22) 78 | return output # (batch_size, output_size, input11, input21) 79 | 80 | class BiLinear(nn.Module): 81 | def __init__(self, left_features: int, right_features: int, out_features: int): 82 | super(BiLinear, self).__init__() 83 | self.left_features = left_features 84 | self.right_features = right_features 85 | self.out_features = out_features 86 | 87 | self.U = Parameter(torch.Tensor(self.out_features, self.left_features, self.right_features)) 88 | self.W_l = Parameter(torch.Tensor(self.out_features, self.left_features)) 89 | self.W_r = Parameter(torch.Tensor(self.out_features, self.right_features)) 90 | self.bias = Parameter(torch.Tensor(out_features)) 91 | 92 | self.reset_parameters() 93 | 94 | def reset_parameters(self) -> None: 95 | nn.init.xavier_uniform_(self.W_l) 96 | nn.init.xavier_uniform_(self.W_r) 97 | nn.init.constant_(self.bias, 0.0) 98 | nn.init.xavier_uniform_(self.U) 99 | 100 | def forward(self, input_left: torch.Tensor, input_right: torch.Tensor) -> torch.Tensor: 101 | left_size = input_left.size() 102 | right_size = input_right.size() 103 | assert left_size[:-1] == right_size[:-1], "batch size of left and right inputs mis-match: (%s, %s)" % ( 104 | left_size[:-1], 105 | right_size[:-1], 106 | ) 107 | batch = int(np.prod(left_size[:-1])) 108 | 109 | input_left = input_left.contiguous().view(batch, self.left_features) 110 | input_right = input_right.contiguous().view(batch, self.right_features) 111 | 112 | output = F.bilinear(input_left, input_right, self.U, self.bias) 113 | output = output + F.linear(input_left, self.W_l, None) + F.linear(input_right, self.W_r, None) 114 | return output.view(left_size[:-1] + (self.out_features,)) 115 | -------------------------------------------------------------------------------- /run_NLI.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | from attrdict import AttrDict 5 | 6 | # roberta 7 | from transformers import AutoTokenizer 8 | from transformers import RobertaConfig 9 | 10 | from src.model.model import RobertaForSequenceClassification 11 | from src.model.main_functions import train, evaluate, predict 12 | 13 | from src.functions.utils import init_logger, set_seed 14 | 15 | import sys 16 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 17 | 18 | def create_model(args): 19 | 20 | if args.model_name_or_path.split("/")[-2] == "roberta": 21 | 22 | # 모델 파라미터 Load 23 | config = RobertaConfig.from_pretrained( 24 | args.model_name_or_path 25 | if args.from_init_weight else os.path.join(args.output_dir,"model/checkpoint-{}".format(args.checkpoint)), 26 | cache_dir=args.cache_dir, 27 | ) 28 | 29 | config.num_labels = args.num_labels 30 | # roberta attention 추출하기 31 | config.output_attentions=True 32 | 33 | # tokenizer는 pre-trained된 것을 불러오는 과정이 아닌 불러오는 모델의 vocab 등을 Load 34 | # BertTokenizerFast로 되어있음 35 | tokenizer = AutoTokenizer.from_pretrained( 36 | args.model_name_or_path 37 | if args.from_init_weight else os.path.join(args.output_dir,"model/checkpoint-{}".format(args.checkpoint)), 38 | do_lower_case=args.do_lower_case, 39 | 
cache_dir=args.cache_dir, 40 | 41 | ) 42 | print(tokenizer) 43 | 44 | model = RobertaForSequenceClassification.from_pretrained( 45 | args.model_name_or_path 46 | if args.from_init_weight else os.path.join(args.output_dir,"model/checkpoint-{}".format(args.checkpoint)), 47 | cache_dir=args.cache_dir, 48 | config=config, 49 | prem_max_sentence_length=args.prem_max_sentence_length, 50 | hypo_max_sentence_length=args.hypo_max_sentence_length, 51 | # from_tf=True if args.from_init_weight else False 52 | ) 53 | 54 | args.model_name_or_path = args.cache_dir 55 | # print(tokenizer.convert_tokens_to_ids("")) 56 | 57 | model.to(args.device) 58 | 59 | print(" idx") 60 | print(tokenizer.convert_tokens_to_ids("")) 61 | return model, tokenizer 62 | 63 | def main(cli_args): 64 | # 파라미터 업데이트 65 | args = AttrDict(vars(cli_args)) 66 | args.device = "cuda" 67 | logger = logging.getLogger(__name__) 68 | 69 | # logger 및 seed 지정 70 | init_logger() 71 | set_seed(args) 72 | 73 | # 모델 불러오기 74 | model, tokenizer = create_model(args) 75 | 76 | # Running mode에 따른 실행 77 | if args.do_train: 78 | train(args, model, tokenizer, logger) 79 | elif args.do_eval: 80 | #for i in range(1, 13): evaluate(args, model, tokenizer, logger, epoch_idx =i) 81 | evaluate(args, model, tokenizer, logger, epoch_idx =args.checkpoint) 82 | 83 | elif args.do_predict: 84 | predict(args, model, tokenizer) 85 | 86 | 87 | if __name__ == '__main__': 88 | cli_parser = argparse.ArgumentParser() 89 | 90 | # Directory 91 | 92 | # cli_parser.add_argument("--data_dir", type=str, default="./data") 93 | # cli_parser.add_argument("--train_file", type=str, default="klue-nli-v1_train.json") 94 | # cli_parser.add_argument("--eval_file", type=str, default="klue-nli-v1_dev.json") 95 | # cli_parser.add_argument("--predict_file", type=str, default="klue-nli-v1_dev.json") #"klue-nli-v1_dev_sample_10.json") 96 | 97 | #cli_parser.add_argument("--num_labels", type=int, default=3) 98 | 99 | # roberta 100 | cli_parser.add_argument("--model_name_or_path", type=str, default="./roberta/init_weight") 101 | cli_parser.add_argument("--cache_dir", type=str, default="./roberta/init_weight") 102 | 103 | # ------------------------------------------------------------------------------------------------ 104 | cli_parser.add_argument("--data_dir", type=str, default="./data/merge") 105 | 106 | cli_parser.add_argument("--num_labels", type=int, default=3) 107 | 108 | cli_parser.add_argument("--train_file", type=str, default='parsing_1_klue_nli_train.json') 109 | cli_parser.add_argument("--eval_file", type=str, default='parsing_1_klue_nli_dev.json') 110 | cli_parser.add_argument("--predict_file", type=str, default='parsing_1_klue_nli_dev.json') 111 | #----------------------------------------------------------------------------------------------------------- 112 | ################################################################################################################################## 113 | 114 | #cli_parser.add_argument("--output_dir", type=str, default="./roberta/my_model/parsing/ver1_wDP") # checkout-5 90.60 90.56 ± 0.04 115 | #cli_parser.add_argument("--output_dir", type=str, default="./roberta/my_model/parsing/ver3_wDP") # checkout-3 90.51 90.45 ± 0.06 116 | cli_parser.add_argument("--output_dir", type=str, default="./roberta/my_model/parsing/ver4_wDP") # checkout-4 90.82 90.78 ± 0.04 117 | # ------------------------------------------------------------------------------------------------------------ 118 | ## klue # ver1 = 18 ver3,4 = 27 119 | 
cli_parser.add_argument("--prem_max_sentence_length", type=int, default=27) 120 | cli_parser.add_argument("--hypo_max_sentence_length", type=int, default=27) 121 | 122 | # https://github.com/KLUE-benchmark/KLUE-baseline/blob/main/run_all.sh 123 | # Model Hyper Parameter 124 | cli_parser.add_argument("--max_seq_length", type=int, default=256) #512) 125 | # Training Parameter 126 | cli_parser.add_argument("--learning_rate", type=float, default=1e-5) 127 | cli_parser.add_argument("--train_batch_size", type=int, default=8) 128 | cli_parser.add_argument("--eval_batch_size", type=int, default=16) 129 | cli_parser.add_argument("--num_train_epochs", type=int, default=5) 130 | 131 | #cli_parser.add_argument("--save_steps", type=int, default=2000) 132 | cli_parser.add_argument("--logging_steps", type=int, default=100) 133 | cli_parser.add_argument("--seed", type=int, default=42) 134 | cli_parser.add_argument("--threads", type=int, default=8) 135 | 136 | cli_parser.add_argument("--weight_decay", type=float, default=0.0) 137 | cli_parser.add_argument("--adam_epsilon", type=int, default=1e-10) 138 | cli_parser.add_argument("--gradient_accumulation_steps", type=int, default=4) 139 | cli_parser.add_argument("--warmup_steps", type=int, default=0) 140 | cli_parser.add_argument("--max_steps", type=int, default=-1) 141 | cli_parser.add_argument("--max_grad_norm", type=int, default=1.0) 142 | 143 | cli_parser.add_argument("--verbose_logging", type=bool, default=False) 144 | cli_parser.add_argument("--do_lower_case", type=bool, default=False) 145 | cli_parser.add_argument("--no_cuda", type=bool, default=False) 146 | 147 | # Running Mode 148 | cli_parser.add_argument("--from_init_weight", type=bool, default= True) #False)#True) 149 | cli_parser.add_argument("--checkpoint", type=str, default="4") 150 | 151 | cli_parser.add_argument("--do_train", type=bool, default=True)#False)#True) 152 | cli_parser.add_argument("--do_eval", type=bool, default=False)#True) 153 | cli_parser.add_argument("--do_predict", type=bool, default=False)#True)#False) 154 | 155 | cli_args = cli_parser.parse_args() 156 | 157 | main(cli_args) 158 | -------------------------------------------------------------------------------- /src/model/main_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | import timeit 6 | from fastprogress.fastprogress import master_bar, progress_bar 7 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 8 | from transformers.file_utils import is_torch_available 9 | 10 | from transformers import ( 11 | AdamW, 12 | get_linear_schedule_with_warmup 13 | ) 14 | 15 | from src.functions.utils import load_examples, set_seed, to_list 16 | from src.functions.metric import get_score, get_sklearn_score 17 | 18 | from functools import partial 19 | 20 | def train(args, model, tokenizer, logger): 21 | max_acc =0 22 | # 학습에 사용하기 위한 dataset Load 23 | ## dataset: tensor형태의 데이터셋 24 | ## all_input_ids, 25 | # all_attention_masks, 26 | # all_labels, 27 | # all_cls_index, 28 | # all_p_mask, 29 | # all_example_indices, 30 | # all_feature_index 31 | 32 | train_dataset = load_examples(args, tokenizer, evaluate=False, output_examples=False) 33 | 34 | # tokenizing 된 데이터를 batch size만큼 가져오기 위한 random sampler 및 DataLoader 35 | ## RandomSampler: 데이터 index를 무작위로 선택하여 조정 36 | ## SequentialSampler: 데이터 index를 항상 같은 순서로 조정 37 | train_sampler = RandomSampler(train_dataset) 38 | 39 | train_dataloader = 
DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 40 | 41 | # t_total: total optimization step 42 | # optimization 최적화 schedule 을 위한 전체 training step 계산 43 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 44 | 45 | # Layer에 따른 가중치 decay 적용 46 | no_decay = ["bias", "LayerNorm.weight"] 47 | optimizer_grouped_parameters = [ 48 | { 49 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 50 | "weight_decay": args.weight_decay, 51 | }, 52 | { 53 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 54 | "weight_decay": 0.0}, 55 | ] 56 | 57 | # optimizer 및 scheduler 선언 58 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 59 | scheduler = get_linear_schedule_with_warmup( 60 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 61 | ) 62 | 63 | # Training Step 64 | logger.info("***** Running training *****") 65 | logger.info(" Num examples = %d", len(train_dataset)) 66 | logger.info(" Num Epochs = %d", args.num_train_epochs) 67 | logger.info(" Train batch size per GPU = %d", args.train_batch_size) 68 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps) 69 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 70 | logger.info(" Total optimization steps = %d", t_total) 71 | 72 | global_step = 1 73 | if not args.from_init_weight: global_step += int(args.checkpoint) 74 | 75 | tr_loss, logging_loss = 0.0, 0.0 76 | 77 | # loss buffer 초기화 78 | model.zero_grad() 79 | 80 | mb = master_bar(range(int(args.num_train_epochs))) 81 | set_seed(args) 82 | 83 | epoch_idx=0 84 | if not args.from_init_weight: epoch_idx += int(args.checkpoint) 85 | 86 | for epoch in mb: 87 | epoch_iterator = progress_bar(train_dataloader, parent=mb) 88 | for step, batch in enumerate(epoch_iterator): 89 | # train 모드로 설정 90 | model.train() 91 | batch = tuple(t.to(args.device) for t in batch) 92 | 93 | # 모델에 입력할 입력 tensor 저장 94 | inputs_list = ["input_ids", "attention_mask"] 95 | if args.model_name_or_path.split("/")[-2] == "electra": inputs_list.append("token_type_ids") 96 | inputs_list.append("labels") 97 | inputs = dict() 98 | for n, input in enumerate(inputs_list): inputs[input] = batch[n] 99 | 100 | inputs_list2 = ['hypo_word_idxs', 'prem_word_idxs', 'hypo_span', 'prem_span'] 101 | for m, input in enumerate(inputs_list2): inputs[input] = batch[-(m+1)] 102 | 103 | # Loss 계산 및 저장 104 | ## outputs = (total_loss,) + outputs 105 | outputs = model(**inputs) 106 | loss = outputs[0] 107 | 108 | # 높은 batch size는 학습이 진행하는 중에 발생하는 noisy gradient가 경감되어 불안정한 학습을 안정적이게 되도록 해줌 109 | # 높은 batch size 효과를 주기위한 "gradient_accumulation_step" 110 | ## batch size *= gradient_accumulation_step 111 | # batch size: 16 112 | # gradient_accumulation_step: 2 라고 가정 113 | # 실제 batch size 32의 효과와 동일하진 않지만 비슷한 효과를 보임 114 | if args.gradient_accumulation_steps > 1: 115 | loss = loss / args.gradient_accumulation_steps 116 | 117 | ## batch_size의 개수만큼의 데이터를 입력으로 받아 만들어진 모델의 loss는 118 | ## 입력 데이터들에 대한 특징을 보유하고 있다(loss를 어떻게 만드느냐에 따라 달라) 119 | ### loss_fct = CrossEntropyLoss(ignore_index=ignored_index, reduction = ?) 120 | ### reduction = mean : 입력 데이터에 대한 평균 121 | loss.backward() 122 | tr_loss += loss.item() 123 | 124 | # Loss 출력 125 | if (global_step + 1) % 50 == 0: 126 | print("{} step processed.. 
Current Loss : {}".format((global_step+1),loss.item())) 127 | 128 | if (step + 1) % args.gradient_accumulation_steps == 0: 129 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 130 | 131 | optimizer.step() 132 | scheduler.step() # Update learning rate schedule 133 | model.zero_grad() 134 | global_step += 1 135 | 136 | epoch_idx += 1 137 | #logger.info("***** Eval results *****") 138 | #results = evaluate(args, model, tokenizer, logger, epoch_idx = str(epoch_idx), tr_loss = loss.item()) 139 | 140 | output_dir = os.path.join(args.output_dir, "model/checkpoint-{}".format(epoch_idx)) 141 | if not os.path.exists(output_dir): 142 | os.makedirs(output_dir) 143 | 144 | # 학습된 가중치 및 vocab 저장 145 | ## pretrained 모델같은 경우 model.save_pretrained(...)로 저장 146 | ## nn.Module로 만들어진 모델일 경우 model.save(...)로 저장 147 | ### 두개가 모두 사용되는 모델일 경우 이 두가지 방법으로 저장을 해야한다!!!! 148 | model.save_pretrained(output_dir) 149 | tokenizer.save_pretrained(output_dir) 150 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 151 | logger.info("Saving model checkpoint to %s", output_dir) 152 | 153 | # # model save 154 | # if (args.logging_steps > 0 and max_acc < float(results["accuracy"])): 155 | # max_acc = float(results["accuracy"]) 156 | # # 모델 저장 디렉토리 생성 157 | # output_dir = os.path.join(args.output_dir, "model/checkpoint-{}".format(epoch_idx)) 158 | # if not os.path.exists(output_dir): 159 | # os.makedirs(output_dir) 160 | # 161 | # # 학습된 가중치 및 vocab 저장 162 | # ## pretrained 모델같은 경우 model.save_pretrained(...)로 저장 163 | # ## nn.Module로 만들어진 모델일 경우 model.save(...)로 저장 164 | # ### 두개가 모두 사용되는 모델일 경우 이 두가지 방법으로 저장을 해야한다!!!! 165 | # model.save_pretrained(output_dir) 166 | # tokenizer.save_pretrained(output_dir) 167 | # # torch.save(args, os.path.join(output_dir, "training_args.bin")) 168 | # logger.info("Saving model checkpoint to %s", output_dir) 169 | 170 | mb.write("Epoch {} done".format(epoch + 1)) 171 | 172 | return global_step, tr_loss / global_step 173 | 174 | # 정답이 사전부착된 데이터로부터 평가하기 위한 함수 175 | def evaluate(args, model, tokenizer, logger, epoch_idx = "", tr_loss = 1): 176 | # 데이터셋 Load 177 | ## dataset: tensor형태의 데이터셋 178 | ## example: json형태의 origin 데이터셋 179 | ## features: index번호가 추가된 list형태의 examples 데이터셋 180 | dataset, examples, features = load_examples(args, tokenizer, evaluate=True, output_examples=True) 181 | 182 | # 최종 출력 파일 저장을 위한 디렉토리 생성 183 | if not os.path.exists(args.output_dir): 184 | os.makedirs(args.output_dir) 185 | 186 | # tokenizing 된 데이터를 batch size만큼 가져오기 위한 random sampler 및 DataLoader 187 | ## RandomSampler: 데이터 index를 무작위로 선택하여 조정 188 | ## SequentialSampler: 데이터 index를 항상 같은 순서로 조정 189 | eval_sampler = SequentialSampler(dataset) 190 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 191 | 192 | # Eval! 
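    # The evaluation loop below reads input_ids and attention_mask from the start of each
    # batch (plus token_type_ids when the checkpoint path contains "electra"), and
    # hypo_word_idxs, prem_word_idxs, hypo_span and prem_span from the end of the batch
    # via batch[-(m + 1)], matching the tensor order produced in load_examples().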
193 | logger.info("***** Running evaluation {} *****".format(epoch_idx)) 194 | logger.info(" Num examples = %d", len(dataset)) 195 | logger.info(" Batch size = %d", args.eval_batch_size) 196 | 197 | # 평가 시간 측정을 위한 time 변수 198 | start_time = timeit.default_timer() 199 | 200 | # 예측 라벨 201 | pred_logits = torch.tensor([], dtype = torch.long).to(args.device) 202 | for batch in progress_bar(eval_dataloader): 203 | # 모델을 평가 모드로 변경 204 | model.eval() 205 | batch = tuple(t.to(args.device) for t in batch) 206 | 207 | with torch.no_grad(): 208 | # 평가에 필요한 입력 데이터 저장 209 | inputs_list = ["input_ids", "attention_mask"] 210 | if args.model_name_or_path.split("/")[-2] == "electra": inputs_list.append("token_type_ids") 211 | inputs = dict() 212 | for n, input in enumerate(inputs_list): inputs[input] = batch[n] 213 | 214 | inputs_list2 = ['hypo_word_idxs', 'prem_word_idxs', 'hypo_span', 'prem_span'] 215 | for m, input in enumerate(inputs_list2): inputs[input] = batch[-(m + 1)] 216 | 217 | # outputs = (label_logits, ) 218 | # label_logits: [batch_size, num_labels] 219 | outputs = model(**inputs) 220 | 221 | pred_logits = torch.cat([pred_logits,outputs[0]], dim = 0) 222 | 223 | # pred_label과 gold_label 비교 224 | pred_logits= pred_logits.detach().cpu().numpy() 225 | pred_labels = np.argmax(pred_logits, axis=-1) 226 | ## gold_labels = 0 or 1 or 2 227 | gold_labels = [example.gold_label for example in examples] 228 | 229 | # print('\n\n=====================outputs=====================') 230 | # for g,p in zip(gold_labels, pred_labels): 231 | # print(str(g)+"\t"+str(p)) 232 | # print('===========================================================') 233 | 234 | # 평가 시간 측정을 위한 time 변수 235 | evalTime = timeit.default_timer() - start_time 236 | logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset)) 237 | 238 | # 최종 예측값과 원문이 저장된 example로 부터 성능 평가 239 | ## results = {"macro_precision":round(macro_precision, 4), "macro_recall":round(macro_recall, 4), "macro_f1_score":round(macro_f1_score, 4), \ 240 | ## "accuracy":round(total_accuracy, 4), \ 241 | ## "micro_precision":round(micro_precision, 4), "micro_recall":round(micro_recall, 4), "micro_f1":round(micro_f1_score, 4)} 242 | idx2label = {0:"entailment", 1:"contradiction", 2:"neutral"} 243 | #results = get_score(pred_labels, gold_labels, idx2label) 244 | results = get_sklearn_score(pred_labels, gold_labels, idx2label) 245 | 246 | output_dir = os.path.join( args.output_dir, 'eval') 247 | 248 | out_file_type = 'a' 249 | if not os.path.exists(output_dir): 250 | os.makedirs(output_dir) 251 | out_file_type ='w' 252 | 253 | # 평가 스크립트 기반 성능 저장을 위한 파일 생성 254 | if os.path.exists(args.model_name_or_path): 255 | print(args.model_name_or_path) 256 | eval_file_name = list(filter(None, args.model_name_or_path.split("/"))).pop() 257 | else: eval_file_name = "init_weight" 258 | output_eval_file = os.path.join(output_dir, "eval_result_{}.txt".format(eval_file_name)) 259 | 260 | with open(output_eval_file, out_file_type, encoding='utf-8') as f: 261 | f.write("train loss: {}\n".format(tr_loss)) 262 | f.write("epoch: {}\n".format(epoch_idx)) 263 | for k in results.keys(): 264 | f.write("{} : {}\n".format(k, results[k])) 265 | f.write("=======================================\n\n") 266 | return results 267 | 268 | def predict(args, model, tokenizer): 269 | dataset, examples, features = load_examples(args, tokenizer, evaluate=True, output_examples=True, do_predict=True) 270 | 271 | # 최종 출력 파일 저장을 위한 디렉토리 생성 272 | if not 
os.path.exists(args.output_dir): 273 | os.makedirs(args.output_dir) 274 | 275 | # tokenizing 된 데이터를 batch size만큼 가져오기 위한 random sampler 및 DataLoader 276 | ## RandomSampler: 데이터 index를 무작위로 선택하여 조정 277 | ## SequentialSampler: 데이터 index를 항상 같은 순서로 조정 278 | eval_sampler = SequentialSampler(dataset) 279 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 280 | 281 | print("***** Running Prediction *****") 282 | print(" Num examples = %d", len(dataset)) 283 | 284 | # 예측 라벨 285 | pred_logits = torch.tensor([], dtype=torch.long).to(args.device) 286 | for batch in progress_bar(eval_dataloader): 287 | # 모델을 평가 모드로 변경 288 | model.eval() 289 | batch = tuple(t.to(args.device) for t in batch) 290 | 291 | with torch.no_grad(): 292 | # 평가에 필요한 입력 데이터 저장 293 | inputs_list = ["input_ids", "attention_mask"] 294 | if args.model_name_or_path.split("/")[-2] == "electra": inputs_list.append("token_type_ids") 295 | inputs = dict() 296 | for n, input in enumerate(inputs_list): inputs[input] = batch[n] 297 | 298 | inputs_list2 = ['hypo_word_idxs', 'prem_word_idxs', 'hypo_span', 'prem_span'] 299 | for m, input in enumerate(inputs_list2): inputs[input] = batch[-(m + 1)] 300 | 301 | # outputs = (label_logits, ) 302 | # label_logits: [batch_size, num_labels] 303 | outputs = model(**inputs) 304 | 305 | pred_logits = torch.cat([pred_logits, outputs[0]], dim=0) 306 | 307 | # pred_label과 gold_label 비교 308 | pred_logits = pred_logits.detach().cpu().numpy() 309 | pred_labels = np.argmax(pred_logits, axis=-1) 310 | ## gold_labels = 0 or 1 or 2 311 | gold_labels = [example.gold_label for example in examples] 312 | 313 | idx2label = {0:"entailment", 1:"contradiction", 2:"neutral"} 314 | #results = get_score(pred_labels, gold_labels, idx2label) 315 | results = get_sklearn_score(pred_labels, gold_labels, idx2label) 316 | 317 | # 검증 스크립트 기반 성능 저장 318 | output_dir = os.path.join(args.output_dir, 'test') 319 | 320 | out_file_type = 'a' 321 | if not os.path.exists(output_dir): 322 | os.makedirs(output_dir) 323 | out_file_type = 'w' 324 | 325 | ## 검증 스크립트 기반 성능 저장을 위한 파일 생성 326 | if os.path.exists(args.model_name_or_path): 327 | print(args.model_name_or_path) 328 | eval_file_name = list(filter(None, args.model_name_or_path.split("/"))).pop() 329 | else: 330 | eval_file_name = "init_weight" 331 | output_test_file = os.path.join(output_dir, "test_result_{}_incorrect.txt".format(eval_file_name)) 332 | 333 | with open(output_test_file, out_file_type, encoding='utf-8') as f: 334 | print('\n\n=====================outputs=====================') 335 | for i,(g,p) in enumerate(zip(gold_labels, pred_labels)): 336 | if g != p: 337 | f.write("premise: {}\thypothesis: {}\tcorrect: {}\tpredict: {}\n".format(examples[i].premise, examples[i].hypothesis, idx2label[g], idx2label[p])) 338 | for k in results.keys(): 339 | f.write("{} : {}\n".format(k, results[k])) 340 | f.write("=======================================\n\n") 341 | 342 | out = {"premise":[], "hypothesis":[], "correct":[], "predict":[]} 343 | for i,(g,p) in enumerate(zip(gold_labels, pred_labels)): 344 | #if g != p: 345 | for k,v in zip(out.keys(),[examples[i].premise, examples[i].hypothesis, idx2label[g], idx2label[p]]): 346 | out[k].append(v) 347 | for k, v in zip(out.keys(), [examples[i].premise, examples[i].hypothesis, idx2label[g], idx2label[p]]): 348 | out[k].append(v) 349 | df = pd.DataFrame(out) 350 | df.to_csv(os.path.join(output_dir, "test_result_{}.csv".format(eval_file_name)), index=False) 351 | 352 | return results 
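
# Minimal illustrative sketch of the scoring path shared by evaluate() and predict() above:
# integer predictions and gold labels plus the idx2label mapping are passed to
# get_sklearn_score (src/functions/metric.py), which converts the indices to label strings
# and computes accuracy and macro/micro precision, recall and F1, rounded to 3 decimals.
# The label values below are made-up placeholders.
if __name__ == "__main__":
    idx2label = {0: "entailment", 1: "contradiction", 2: "neutral"}
    pred_labels = [0, 1, 2, 0]  # e.g. argmax over the model's label logits
    gold_labels = [0, 2, 2, 1]  # e.g. example.gold_label for each evaluated example
    print(get_sklearn_score(pred_labels, gold_labels, idx2label))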
-------------------------------------------------------------------------------- /src/dependency/merge.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from operator import itemgetter 4 | 5 | def change_tag(change_dic, tag_list, text): 6 | #print("change_dic: " + str(change_dic)) 7 | del_dic = [] 8 | for key in change_dic.keys(): 9 | if change_dic[key][0] in change_dic.keys(): 10 | del_dic.append(change_dic[key][0]) 11 | change_dic[key] = change_dic[change_dic[key][0]] 12 | #print(del_dic) 13 | for dic in set(del_dic): 14 | del change_dic[dic] 15 | 16 | dic_val_idx = [val[1] for val in change_dic.values()] 17 | for i, val_idx in enumerate(dic_val_idx): 18 | for j, val2 in zip([j for j, val2 in enumerate(dic_val_idx) if j != i], [val2 for j, val2 in enumerate(dic_val_idx) if j != i]): 19 | #print(str(i) + " " + str(j)) 20 | #print(val_idx) 21 | #print(val2) 22 | if (len(set(val_idx).intersection(set(val2))) != 0) or (len(set(val2).intersection(set(val_idx))) != 0): 23 | #print(str(i)+" "+str(j)) 24 | change_dic[list(change_dic.keys())[i]] = [" ".join([text.split()[k] for k in list(set( 25 | change_dic[list(change_dic.keys())[i]][1] + change_dic[list(change_dic.keys())[j]][1]))]), 26 | list(set(change_dic[list(change_dic.keys())[i]][1] + change_dic[list(change_dic.keys())[j]][1]))] 27 | 28 | change_dic[list(change_dic.keys())[j]] = [" ".join([text.split()[k] for k in list(set( 29 | change_dic[list(change_dic.keys())[i]][1] + change_dic[list(change_dic.keys())[j]][1]))]), 30 | list(set(change_dic[list(change_dic.keys())[i]][1] + 31 | change_dic[list(change_dic.keys())[j]][1]))] 32 | #print("change_dic: " + str(change_dic)) 33 | dic_val_idx = [val[1] for val in change_dic.values()] 34 | #print("change_dic: " + str(change_dic)) 35 | for tag_idx, tag_li in enumerate(tag_list): 36 | del_list = [] 37 | for tag_i,tag_l in enumerate(tag_li): 38 | if tag_l[0][0] in change_dic.keys(): 39 | tag_list[tag_idx][tag_i][2][0] = change_dic[tag_l[0][0]][1] 40 | tag_list[tag_idx][tag_i][0][0] = change_dic[tag_l[0][0]][0] 41 | if tag_l[0][1] in change_dic.keys(): 42 | tag_list[tag_idx][tag_i][2][1] = change_dic[tag_l[0][1]][1] 43 | tag_list[tag_idx][tag_i][0][1] = change_dic[tag_l[0][1]][0] 44 | 45 | if tag_l[0][0] == tag_l[0][1]: del_list.append(tag_l) 46 | tag_list[tag_idx] = [x for x in tag_list[tag_idx] if x not in del_list] 47 | 48 | return tag_list, {} 49 | 50 | def merge_tag(inf_dir, outf_dir, tag_li_type = "modifier_w/_phrase"): 51 | dp = inf_dir.split("_")[0].split("/")[-1] 52 | with open(inf_dir, "r", encoding="utf-8") as inf: 53 | datas = json.load(inf) 54 | 55 | outputs = [];outputs2 = []; 56 | for id, data in tqdm(enumerate(datas)): 57 | output = [];output2 = []; 58 | # {"origin": text, dp:root, words:[[word, tag, idx], [], ... 
]} 59 | texts_list = [data["premise"], data["hypothesis"]] 60 | #print("\n=================="+str(data["guid"])+"=============================") 61 | for texts in texts_list: 62 | #print("\n--------------------------------------------------------------------") 63 | text = texts["origin"] 64 | #print(text) 65 | #print("--------------------------------------------------------------------") 66 | 67 | # [{'R', 'VNP', 'L', 'VP', 'S', 'AP', 'NP', 'DP', 'IP', 'X'}, {'None', 'MOD', 'CNJ', 'AJT', 'OBJ', 'SBJ', 'CMP'}] 68 | r_list = []; l_list = []; s_list = []; x_list = []; np_list = []; dp_list = []; vp_list = [];vnp_list = []; ap_list = []; ip_list = [] 69 | MOD = []; AJT = []; CMP = []; 70 | np_cnj_list = [] 71 | #print(texts[dp]["words"]) 72 | 73 | # 전처리 74 | ## [['어떤 방에서도', '금지됩니다.'], ['NP', 'AJT'], [[0, 1], [3]]], [['방에서도', '흡연은'], ['NP', 'MOD'], [[1], [2]]]와 같은 경우 75 | for i, koala in enumerate(texts[dp]["words"]): 76 | for j, other in enumerate(texts[dp]["words"]): 77 | if (texts[dp]["words"][i][2][0] != other[2][0]) and (set(texts[dp]["words"][i][2][0]+other[2][0]) == set(other[2][0])): 78 | texts[dp]["words"][i][0][0] = other[0][0] 79 | texts[dp]["words"][i][2][0] = other[2][0] 80 | if (texts[dp]["words"][i][2][0] != other[2][1]) and (set(texts[dp]["words"][i][2][0]+other[2][1]) == set(other[2][1])): 81 | texts[dp]["words"][i][0][0] = other[0][1] 82 | texts[dp]["words"][i][2][0] = other[2][1] 83 | if (texts[dp]["words"][i][2][1] != other[2][0]) and (set(texts[dp]["words"][i][2][1]+other[2][0]) == set(other[2][0])): 84 | texts[dp]["words"][i][0][1] = other[0][0] 85 | texts[dp]["words"][i][2][1] = other[2][0] 86 | if (texts[dp]["words"][i][2][1] != other[2][1]) and (set(texts[dp]["words"][i][2][1]+other[2][1]) == set(other[2][1])): 87 | texts[dp]["words"][i][0][1] = other[0][1] 88 | texts[dp]["words"][i][2][1] = other[2][1] 89 | #print(texts[dp]["words"]) 90 | for i, koala in enumerate(texts[dp]["words"]): 91 | #print(koala) 92 | tag = koala[1] 93 | word_idx = koala[2] 94 | if (tag[0] == "NP") and (tag[1] != "CNJ"): 95 | np_list.append(koala) 96 | elif (tag[0] == "DP"): 97 | dp_list.append(koala) 98 | elif (tag[0] == "VP"): 99 | vp_list.append(koala) 100 | elif (tag[0] == "VNP"): 101 | vnp_list.append(koala) 102 | elif (tag[0] == "AP"): 103 | ap_list.append(koala) 104 | elif (tag[0] == "IP"): 105 | ip_list.append(koala) 106 | elif (tag[0] == "R"): 107 | r_list.append(koala) 108 | elif (tag[0] == "L"): 109 | l_list.append(koala) 110 | elif (tag[0] == "S"): 111 | s_list.append(koala) 112 | elif (tag[0] == "X"): 113 | x_list.append(koala) 114 | 115 | if (tag[1] == "MOD"): MOD.append(koala); 116 | elif (tag[1] == "AJT"): AJT.append(koala); 117 | elif (tag[1] == "CMP"): CMP.append(koala); 118 | 119 | 120 | if (tag[0] == "NP") and (tag[1] == "CNJ"): 121 | np_cnj_list.append(koala) 122 | 123 | vp_list = vp_list+vnp_list 124 | tag_list = [] 125 | if tag_li_type == "modifier": 126 | tag_list = [MOD+ AJT+ CMP] #3 127 | # tag_list = [MOD, AJT, CMP] #2 128 | elif tag_li_type == "phrase": 129 | tag_list = [x for x in [np_list, dp_list, vp_list, ap_list, ip_list, r_list, l_list, s_list, x_list] if len(x) != 0] # 1 130 | elif tag_li_type == "modifier_w/_phrase": 131 | tag_list = [MOD+ AJT+ CMP]+[x for x in [np_list, dp_list, vp_list, ap_list, ip_list, r_list, l_list, s_list, x_list] if len(x) != 0] #4 132 | 133 | change_dic = {} 134 | for tag_idx,tag_li in enumerate(tag_list): 135 | #print("tag_li: "+str(tag_li)) 136 | conti = True 137 | while conti: 138 | tag_list, change_dic = change_tag(change_dic, 
tag_list, text) 139 | tag_li = tag_list[tag_idx] 140 | #print("tag_li: " + str(tag_li)) 141 | new_tag_li = [] 142 | other_tag_li = [] 143 | for tag_l in tag_li: 144 | if abs(max(tag_l[2][1])-min(tag_l[2][0]))==1:new_tag_li.append(tag_l) 145 | else: other_tag_li.append(tag_l) 146 | 147 | # print("new_tag_li: "+str(new_tag_li)) 148 | 149 | if (len(new_tag_li)==0) or (len(tag_li)==1): 150 | conti = False; 151 | else: 152 | new_tag_li = sorted(new_tag_li, key = lambda x:(max(x[2][1]))) 153 | # print("new_tag_li after sorted: "+str(new_tag_li)) 154 | 155 | tag_li = other_tag_li 156 | del other_tag_li 157 | #print("tag_li: " + str(tag_li)) 158 | 159 | # 거리의 길이가 1일 때 양 옆에 이어서 있는 경우 160 | i = 1 161 | while (i != len(new_tag_li)): 162 | if (new_tag_li[-i+1][0][1] == new_tag_li[-i][0][0]) and (new_tag_li[-i+1][2][1] == new_tag_li[-i][2][0]): 163 | if min(new_tag_li[-i][2][0])min(new_tag_li[-i][2][1]): 167 | change_dic.update({new_tag_li[-i][0][0]: [" ".join([new_tag_li[-i][0][1], new_tag_li[-i][0][0]]), list(set(new_tag_li[-i+1][2][1]+new_tag_li[-i][2][0]+new_tag_li[-i][2][1]))]}) 168 | change_dic.update({new_tag_li[-i][0][1]:[" ".join([new_tag_li[-i][0][1], new_tag_li[-i][0][0]]), list(set(new_tag_li[-i + 1][2][1] + new_tag_li[-i][2][0] + new_tag_li[-i][2][1]))]}) 169 | new_tag_li[-i + 1] = [[new_tag_li[-i+1][0][0], " ".join(new_tag_li[-i][0])], new_tag_li[-i+1][1], [new_tag_li[-i+1][2][0], list(set(new_tag_li[-i+1][2][1]+new_tag_li[-i][2][0]+new_tag_li[-i][2][1]))]] 170 | new_tag_li.pop() 171 | else: i+=1 172 | 173 | # 거리의 길이가 2인 경우 174 | del_list = [] 175 | for new_tag_l in new_tag_li: 176 | for tag_l in tag_li: 177 | if (tag_l[0][1] == new_tag_l[0][1]) and (max(tag_l[2][0])+1 == min(new_tag_l[2][0])): 178 | if min(tag_l[2][0]) < min(new_tag_l[2][0]): 179 | change_dic.update({new_tag_l[0][0]:[" ".join([tag_l[0][0], new_tag_l[0][0]]), list(set(tag_l[2][0] + new_tag_l[2][0]))]}) 180 | change_dic.update({tag_l[0][0]:[" ".join([tag_l[0][0], new_tag_l[0][0]]),list(set(tag_l[2][0] + new_tag_l[2][0]))]}) 181 | elif min(tag_l[2][0]) > min(new_tag_l[2][0]): 182 | change_dic.update({new_tag_l[0][0]: [" ".join([new_tag_l[0][0], tag_l[0][0]]), 183 | list(set(tag_l[2][0] + new_tag_l[2][0]))]}) 184 | change_dic.update({tag_l[0][0]: [" ".join([new_tag_l[0][0], tag_l[0][0]]), 185 | list(set(tag_l[2][0] + new_tag_l[2][0]))]}) 186 | 187 | del_list.append(new_tag_l) 188 | 189 | new_tag_li = [x for x in new_tag_li if x not in del_list] 190 | 191 | #print("new_tag_li: "+str(new_tag_li)) 192 | if len(new_tag_li) != 0: 193 | for new_tag_l in new_tag_li: 194 | if (min(new_tag_l[2][0]) < min(new_tag_l[2][1])) and (max(new_tag_l[2][0])+1 == min(new_tag_l[2][1])): 195 | change_dic.update({new_tag_l[0][0]:[" ".join(new_tag_l[0]), list(set(sum(new_tag_l[2],[])))]}) 196 | change_dic.update({new_tag_l[0][1]: [" ".join(new_tag_l[0]), list(set(sum(new_tag_l[2], [])))]}) 197 | elif (min(new_tag_l[2][0]) > min(new_tag_l[2][1])) and (max(new_tag_l[2][1])+1 == min(new_tag_l[2][0])): 198 | change_dic.update({new_tag_l[0][0]:[" ".join([new_tag_l[0][1], new_tag_l[0][0]]), list(set(sum(new_tag_l[2],[])))]}) 199 | change_dic.update({new_tag_l[0][1]: [" ".join([new_tag_l[0][1], new_tag_l[0][0]]), list(set(sum(new_tag_l[2], [])))]}) 200 | new_tag_li = [] 201 | 202 | if change_dic == {}: 203 | conti = False; 204 | #print("change_dic: " + str(change_dic)) 205 | 206 | tag_list[tag_idx] = tag_li 207 | #print("change tag_li: " + str(tag_list[tag_idx])) 208 | #print("tag_list: "+ str(tag_list)) 209 | tag_list = [x for x in tag_list if x != []] 
210 | 211 | #for i, koala in enumerate(texts[dp]["words"]): 212 | # # print(koala) 213 | # tag = koala[1] 214 | # if (tag[0] == "NP") and (tag[1] == "CNJ"): 215 | # if (koala[2][0] != koala[2][1]):np_cnj_list.append(koala) 216 | #""" 217 | new_koalas = [[] for tag_idx, _ in enumerate(tag_list)] 218 | for tag_idx, _ in enumerate(tag_list): 219 | # NP-CNJ 220 | if (len(np_cnj_list) != 0): 221 | for cnj_koala in np_cnj_list: 222 | for ttag_list in tag_list[tag_idx]: 223 | new_koala = [] 224 | if (len(set(cnj_koala[2][0]).intersection(set(ttag_list[2][0]))) != 0): 225 | new_koala = [[cnj_koala[0][0], ttag_list[0][1]], 226 | cnj_koala[1], 227 | [cnj_koala[2][0], ttag_list[2][1]]] 228 | 229 | if (len(set(cnj_koala[2][0]).intersection(set(ttag_list[2][1]))) != 0): 230 | new_koala = [[ttag_list[0][1], cnj_koala[0][1]], 231 | cnj_koala[1], 232 | [ttag_list[2][1], cnj_koala[2][1]]] 233 | 234 | if (len(set(cnj_koala[2][1]).intersection(set(ttag_list[2][0]))) != 0): 235 | new_koala = [[cnj_koala[0][0], ttag_list[0][0]], 236 | cnj_koala[1], 237 | [cnj_koala[2][0], ttag_list[2][0]]] 238 | 239 | if (len(set(cnj_koala[2][1]).intersection(set(ttag_list[2][1]))) != 0): 240 | new_koala = [[cnj_koala[0][0], ttag_list[0][1]], 241 | cnj_koala[1], 242 | [cnj_koala[2][0], ttag_list[2][1]]] 243 | 244 | if new_koala != []: new_koalas[tag_idx].append(new_koala) 245 | 246 | for k_idx, new_koala in enumerate(new_koalas): 247 | new_tag_list = tag_list[k_idx]+new_koala 248 | tag_list[k_idx] = new_tag_list 249 | #""" 250 | 251 | tag_list = sum(tag_list, []) 252 | tag_list = [tag for tag in tag_list if len(set(tag[2][0]).intersection(set(tag[2][1]))) == 0] 253 | #print("tag_list: " + str(tag_list)) 254 | 255 | sub_output2 = {} 256 | for tag in tag_list: 257 | #print(tag) 258 | sub_output2[tag[0][0]] = tag[2][0] 259 | sub_output2[tag[0][1]] = tag[2][1] 260 | sub_output2 = [[key, sorted(value)] for key,value in sub_output2.items()] 261 | # print(sub_output2) 262 | sub_output2.sort(key=itemgetter(1)) 263 | # print("sub_output2: " + str(sub_output2)) 264 | # 후처리 265 | i = 0 266 | while (i < len(sub_output2) - 1): 267 | sub1 = sub_output2[i] 268 | sub2 = sub_output2[i + 1] 269 | if (sub1[0] != sub2[0]) and (sub1[1] != sub2[1]) and ((len(set(sub1[1]).intersection(set(sub2[1]))) != 0) or (len(set(sub2[1]).intersection(set(sub1[1]))) != 0)): 270 | #print("sub_output2: " + str(sub_output2)) 271 | new_idx = sorted(list(set(sub1[1] + sub2[1]))) 272 | sub_output2[i] = (" ".join(texts["origin"].split()[min(new_idx):max(new_idx) + 1]), new_idx) 273 | sub_output2 = sub_output2[:i + 1] + sub_output2[i + 2:] 274 | #print("sub_output2: " + str(sub_output2)) 275 | for j, tag in enumerate(tag_list): 276 | if (tag[0][0] in [sub1[0], sub2[0]]): 277 | tag_list[j][0][0] = " ".join(texts["origin"].split()[min(new_idx):max(new_idx) + 1]) 278 | tag_list[j][2][0] = new_idx 279 | elif (tag[0][1] in [sub1[0], sub2[0]]): 280 | tag_list[j][0][1] = " ".join(texts["origin"].split()[min(new_idx):max(new_idx) + 1]) 281 | tag_list[j][2][1] = new_idx 282 | else: i += 1 283 | 284 | # print("sub_output2: " + str(sub_output2)) 285 | # print("tag_list: " + str(tag_list)) 286 | 287 | if sub_output2 == []: 288 | sub_output2 = [[" ".join(text.split()[:-1]), [i for i in range(0, len(text.split())-1)]], [text.split()[-1], [len(text.split())-1]]] 289 | tag_list = [[[sub_output2[0][0], sub_output2[1][0]], ['VP', 'MOD'], [sub_output2[0][1], sub_output2[1][1]]]] 290 | 291 | if sum([sub[1] for sub in sub_output2], []) != [i for i,_ in enumerate(text.split())]: 292 | for li 
in sum([[sorted(tag[2][0]), sorted(tag[2][1])] for tag in tag_list], []): 293 | if (len(set(li).intersection(set([t for t, _ in enumerate(text.split()) if t not in sum([sub[1] for sub in sub_output2], [])]))) != 0): 294 | sub_output2 += [[" ".join([text.split()[t] for t in li]), li]] 295 | sub_output2 += [[t, [i]] for i, t in enumerate(text.split()) if i not in sum([sub[1] for sub in sub_output2], [])] 296 | sub_output2.sort(key=itemgetter(1)) 297 | #print("sub_output2: " + str(sub_output2)) 298 | #print("tag_list: " + str(tag_list)) 299 | #print(sum([sub[1] for sub in sub_output2], [])) 300 | #print([[i, t] for i, t in enumerate(text.split())]) 301 | assert sum([sub[1] for sub in sub_output2], []) == [i for i,_ in enumerate(text.split())] 302 | 303 | sub_output2 = [[sub[0], sorted(sub[1])] for sub in sub_output2] 304 | output2.append(sub_output2) 305 | tag_list = [[tag[0], tag[1], [sorted(tag[2][0]), sorted(tag[2][1])]] for tag in tag_list] 306 | #tag_list += np_cnj_list 307 | output.append(tag_list) 308 | #print("output2: " + str(output2)) 309 | #print("output: " + str(output)) 310 | outputs.append(output) 311 | # sub_output2: [('흡연자분들은', [0]), ('발코니가 있는', [1, 2]), ('방이면', [3]), ('발코니에서 흡연이', [4, 5]), ('가능합니다.', [6])] 312 | outputs2.append(output2) 313 | 314 | for i, (data, merge1, merge2) in enumerate(zip(datas, outputs, outputs2)): 315 | datas[i]["premise"]["merge"] = {"origin": merge2[0], dp: merge1[0]} 316 | datas[i]["hypothesis"]["merge"] = {"origin": merge2[1], dp: merge1[1]} 317 | 318 | for i, data in tqdm(enumerate(datas)): 319 | for sen in ["premise", "hypothesis"]: 320 | for j, merge in enumerate(data[sen]["merge"]["origin"]): 321 | if merge == ["", []]: 322 | data[sen]["merge"]["origin"] = [data[sen]["merge"]["origin"][j+1]] 323 | datas[i][sen]["merge"]["origin"] = data[sen]["merge"]["origin"] 324 | data[sen]["merge"]["parsing"][0][0][0] = data[sen]["merge"]["parsing"][0][0][1] 325 | data[sen]["merge"]["parsing"][0][2][0] = data[sen]["merge"]["parsing"][0][2][1] 326 | datas[i][sen]["merge"]["parsing"] = data[sen]["merge"]["parsing"] 327 | 328 | for sen in ["premise", "hypothesis"]: 329 | new_koala = [] 330 | for k, words in enumerate(data[sen]["merge"][dp]): 331 | origin_idx = [merge[1] for merge in data[sen]["merge"]["origin"]] 332 | if (words[2][0] in origin_idx): 333 | if (words[2][1] in origin_idx): 334 | new_koala.append(words) 335 | datas[i][sen]["merge"][dp] = new_koala 336 | 337 | 338 | with open(outf_dir, 'w', encoding="utf-8") as f: 339 | json.dump(datas, f, ensure_ascii=False, indent=4) 340 | print("\n\nfinish!!") 341 | 342 | if __name__ =='__main__': 343 | 344 | inf_dirs = ["./../../data/parsing/parsing_1_klue_nli_train.json", "./../../data/parsing/parsing_1_klue_nli_dev.json"] 345 | #["./../../data/koala/koala_ver1_klue_nli_train.json", "./../../data/koala/koala_ver1_klue_nli_dev.json"] 346 | outf_dirs = ["./../../data/merge/parsing_1_klue_nli_train.json", "./../../data/merge/parsing_1_klue_nli_dev.json"] 347 | #["./../../data/merge/merge_3_klue_nli_train.json", "./../../data/merge/merge_3_klue_nli_dev.json"] 348 | 349 | for inf_dir, outf_dir in zip(inf_dirs, outf_dirs): 350 | merge_tag(inf_dir, outf_dir, tag_li_type = "modifier_w/_phrase") 351 | # merge_tag(inf_dir, outf_dir, tag_li_type = "modifier") 352 | # merge_tag(inf_dir, outf_dir, tag_li_type = "phrase") 353 | 354 | -------------------------------------------------------------------------------- /src/functions/processor.py: -------------------------------------------------------------------------------- 1 
| import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from multiprocessing import Pool, cpu_count 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | import transformers 11 | from transformers.file_utils import is_tf_available, is_torch_available 12 | from transformers.data.processors.utils import DataProcessor 13 | 14 | if is_torch_available(): 15 | import torch 16 | from torch.utils.data import TensorDataset 17 | 18 | if is_tf_available(): 19 | import tensorflow as tf 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | 25 | def klue_convert_example_to_features(example, max_seq_length, is_training, prem_max_sentence_length, hypo_max_sentence_length, language): 26 | 27 | # 데이터의 유효성 검사를 위한 부분 28 | # ======================================================== 29 | label = None 30 | if is_training: 31 | # Get label 32 | label = example.label 33 | 34 | # label_dictionary에 주어진 label이 존재하지 않으면 None을 feature로 출력 35 | # If the label cannot be found in the text, then skip this example. 36 | ## kind_of_label: label의 종류 37 | kind_of_label = ["entailment", "contradiction", "neutral"] 38 | actual_text = kind_of_label[label] if label<=len(kind_of_label) else label 39 | if actual_text not in kind_of_label: 40 | logger.warning("Could not find label: '%s' \n not in entailment, contradiction, and neutral", actual_text) 41 | return None 42 | # ======================================================== 43 | 44 | # 단어와 토큰 간의 위치 정보 확인 45 | tok_to_orig_index = {"premise": [], "hypothesis": []} # token 개수만큼 # token에 대한 word의 위치 46 | orig_to_tok_index = {"premise": [], "hypothesis": []} # origin 개수만큼 # word를 토큰화하여 나온 첫번째 token의 위치 47 | all_doc_tokens = {"premise": [], "hypothesis": []} # origin text를 tokenization 48 | token_to_orig_map = {"premise": {}, "hypothesis": {}} 49 | 50 | for case in example.merge.keys(): 51 | new_merge = [] 52 | new_word = [] 53 | idx = 0 54 | for merge_idx in example.merge[case]: 55 | for m_idx in merge_idx: 56 | new_word.append(example.doc_tokens[case][m_idx]) 57 | new_word.append("") 58 | merge_idx = [m_idx+idx for m_idx in range(0,len(merge_idx))] 59 | new_merge.append(merge_idx) 60 | idx = max(merge_idx)+1 61 | new_merge.append([idx]) 62 | idx+=1 63 | example.merge[case] = new_merge 64 | example.doc_tokens[case] = new_word 65 | 66 | for case in example.merge.keys(): 67 | for merge_idx in example.merge[case]: 68 | for word_idx in merge_idx: 69 | # word를 토큰화하여 나온 첫번째 token의 위치 70 | orig_to_tok_index[case].append(len(tok_to_orig_index[case])) 71 | if (example.doc_tokens[case][word_idx] == ""): 72 | sub_tokens = [""] 73 | else: sub_tokens = tokenizer.tokenize(example.doc_tokens[case][word_idx]) 74 | for sub_token in sub_tokens: 75 | # token 저장 76 | all_doc_tokens[case].append(sub_token) 77 | # token에 대한 word의 위치 78 | tok_to_orig_index[case].append(word_idx) 79 | # token_to_orig_map: {token:word} 80 | token_to_orig_map[case][len(tok_to_orig_index[case]) - 1] = len(orig_to_tok_index[case]) - 1 81 | 82 | # print("tok_to_orig_index\n"+str(tok_to_orig_index)) 83 | # print("orig_to_tok_index\n"+str(orig_to_tok_index)) 84 | # print("all_doc_tokens\n"+str(all_doc_tokens)) 85 | # print("token_to_orig_map\n\tindex of token : index of word\n\t"+str(token_to_orig_map)) 86 | 87 | # ========================================================= 88 | 89 | if int(transformers.__version__[0]) <= 3: 90 | # sequence_added_tokens: [CLS], [SEP]가 추가된 토큰이므로 2 91 | ## roberta or camembert: 3 92 | ## sen1 sen2 93 | sequence_added_tokens = ( 94 | tokenizer.max_len - 
tokenizer.max_len_single_sentence + 1 95 | if "roberta" in language or "camembert" in language 96 | else tokenizer.max_len - tokenizer.max_len_single_sentence 97 | ) 98 | #print("sequence_added_tokens(# using special token): "+str(sequence_added_tokens)) 99 | 100 | # special token을 제외한 최대 들어갈 수 있는 실제 premise와 hypothesis의 token길이 101 | ## BERT같은 경우 입력으로 '[CLS] P [SEP] H [SEP]'이므로 102 | ### sequence_pair_added_tokens는 special token의 개수인 3 103 | ### tokenizer.max_len = 512 & tokenizer.max_len_sentences_pair = 509 104 | ## RoBERTa는 입력으로 P H 로 구성되므로 105 | ### sequence_pair_added_tokens는 special token의 개수인 4 106 | ### tokenizer.max_len = 512 & tokenizer.max_len_sentences_pair = 508 107 | sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair 108 | #print("sequence_pair_added_tokens(# of special token in text): "+str(sequence_pair_added_tokens)) 109 | 110 | # 최대 길이 넘는지 확인 111 | assert len(all_doc_tokens["premise"]) + len(all_doc_tokens["hypothesis"]) + sequence_pair_added_tokens <= tokenizer.max_len 112 | 113 | else: 114 | sequence_added_tokens = ( 115 | tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1 116 | if "roberta" in language or "camembert" in language 117 | else tokenizer.model_max_length - tokenizer.max_len_single_sentence 118 | ) 119 | #print("sequence_added_tokens(# using special token): "+str(sequence_added_tokens)) 120 | 121 | sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair 122 | #print("sequence_pair_added_tokens(# of special token in text): "+str(sequence_pair_added_tokens)) 123 | 124 | # 최대 길이 넘는지 확인 125 | assert len(all_doc_tokens["premise"]) + len( 126 | all_doc_tokens["hypothesis"]) + sequence_pair_added_tokens <= tokenizer.model_max_length 127 | 128 | input_ids = [tokenizer.cls_token_id] + [tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["premise"]] 129 | prem_word_idxs = [0] + list(filter(lambda x: input_ids[x] == tokenizer.convert_tokens_to_ids(""),range(len(input_ids)))) 130 | 131 | input_ids += [tokenizer.sep_token_id] 132 | if "roberta" in language or "camembert" in language: input_ids += [tokenizer.sep_token_id] 133 | token_type_ids = [0] * len(input_ids) 134 | 135 | input_ids += [tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["hypothesis"]] + [tokenizer.sep_token_id] 136 | hypo_word_idxs = list(filter(lambda x: [tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["hypothesis"]][x] == tokenizer.convert_tokens_to_ids(""),range(len([tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["hypothesis"]])))) 137 | 138 | hypo_word_idxs = [len(token_type_ids)-1] + [x+len(token_type_ids) for x in hypo_word_idxs] 139 | 140 | token_type_ids = token_type_ids + [1] * (len(input_ids) - len(token_type_ids)) 141 | position_ids = list(range(0, len(input_ids))) 142 | 143 | # non_padded_ids: padding을 제외한 토큰의 index 번호 144 | non_padded_ids = [i for i in input_ids] 145 | 146 | # tokens: padding을 제외한 토큰 147 | non_padded_tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 148 | 149 | attention_mask = [1]*len(input_ids) 150 | 151 | paddings = [tokenizer.pad_token_id]*(max_seq_length - len(input_ids)) 152 | 153 | if tokenizer.padding_side == "right": 154 | input_ids += paddings 155 | attention_mask += [0]*len(paddings) 156 | token_type_ids += paddings 157 | position_ids += paddings 158 | else: 159 | input_ids = paddings + input_ids 160 | attention_mask = [0]*len(paddings) + attention_mask 161 | token_type_ids = paddings + token_type_ids 
162 | position_ids = paddings + position_ids 163 | 164 | prem_word_idxs = [x+len(paddings) for x in prem_word_idxs] 165 | hypo_word_idxs = [x + len(paddings) for x in hypo_word_idxs] 166 | 167 | # """ 168 | # mean pooling 169 | prem_not_word_list = [] 170 | for k, p_idx in enumerate(prem_word_idxs[1:]): 171 | prem_not_word_idxs = [0] * len(input_ids); 172 | for j in range(prem_word_idxs[k] + 1, p_idx): 173 | prem_not_word_idxs[j] = 1 / (p_idx - prem_word_idxs[k] - 1) 174 | prem_not_word_list.append(prem_not_word_idxs) 175 | prem_not_word_list = prem_not_word_list + [[0] * len(input_ids)] * ( 176 | prem_max_sentence_length - len(prem_not_word_list)) 177 | 178 | 179 | hypo_not_word_list = []; 180 | for k, h_idx in enumerate(hypo_word_idxs[1:]): 181 | hypo_not_word_idxs = [0] * len(input_ids); 182 | for j in range(hypo_word_idxs[k] + 1, h_idx): 183 | hypo_not_word_idxs[j] = 1 / (h_idx - hypo_word_idxs[k] - 1) 184 | hypo_not_word_list.append(hypo_not_word_idxs) 185 | hypo_not_word_list = hypo_not_word_list + [[0] * len(input_ids)] * ( 186 | hypo_max_sentence_length - len(hypo_not_word_list)) 187 | """ 188 | # (a,b, |a-b|, a*b) 189 | prem_not_word_list = [[], []] 190 | for k, p_idx in enumerate(prem_word_idxs[1:]): 191 | prem_not_word_list[0].append(prem_word_idxs[k] + 1) 192 | prem_not_word_list[1].append(p_idx - 1) 193 | prem_not_word_list[0] = prem_not_word_list[0] + [int(hypo_word_idxs[-1]+i+2) for i in range(0, (prem_max_sentence_length - len(prem_not_word_list)))] 194 | prem_not_word_list[1] = prem_not_word_list[1] + [int(hypo_word_idxs[-1] + i + 2) for i in range(0, (prem_max_sentence_length - len(prem_not_word_list)))] 195 | 196 | hypo_not_word_list = [[],[]]; 197 | for k, h_idx in enumerate(hypo_word_idxs[1:]): 198 | hypo_not_word_list[0].append(hypo_word_idxs[k]+1) 199 | hypo_not_word_list[1].append(h_idx+1) 200 | hypo_not_word_list[0] = hypo_not_word_list[0] + [int(hypo_word_idxs[-1]+i+2) for i in range(0,(hypo_max_sentence_length - len(hypo_not_word_list)))] 201 | hypo_not_word_list[1] = hypo_not_word_list[1] + [int(hypo_word_idxs[-1] + i + 2) for i in range(0, (hypo_max_sentence_length - len(hypo_not_word_list)))] 202 | """ 203 | 204 | # p_mask: mask with 0 for token which belong premise and hypothesis including CLS TOKEN 205 | # and with 1 otherwise. 206 | # Original TF implem also keep the classification token (set to 0) 207 | p_mask = np.ones_like(token_type_ids) 208 | if tokenizer.padding_side == "right": 209 | # [CLS] P [SEP] H [SEP] PADDING 210 | p_mask[:len(all_doc_tokens["premise"]) + len(all_doc_tokens["hypothesis"]) + 1] = 0 211 | else: 212 | p_mask[-(len(all_doc_tokens["premise"]) + len(all_doc_tokens["hypothesis"]) + 1): ] = 0 213 | 214 | # pad_token_indices: input_ids에서 padding된 위치 215 | pad_token_indices = np.array(range(len(non_padded_ids), len(input_ids))) 216 | # special_token_indices: special token의 위치 217 | special_token_indices = np.asarray( 218 | tokenizer.get_special_tokens_mask(input_ids, already_has_special_tokens=True) 219 | ).nonzero() 220 | 221 | p_mask[pad_token_indices] = 1 222 | p_mask[special_token_indices] = 1 223 | 224 | # Set the cls index to 0: the CLS index can be used for impossible answers 225 | # Identify the position of the CLS token 226 | cls_index = input_ids.index(tokenizer.cls_token_id) 227 | 228 | p_mask[cls_index] = 0 229 | 230 | # prem_dependency = [[premise_tail, premise_head, dependency], [], ...] 
231 | # hypo_dependency = [[hypothesis_tail, hypothesis_head, dependency], [], ...]] 232 | if example.dependency["premise"] == [[]]: 233 | example.dependency["premise"] = [[prem_max_sentence_length-1,prem_max_sentence_length-1,0] for _ in range(0,prem_max_sentence_length)] 234 | else: 235 | example.dependency["premise"] = example.dependency["premise"] + [[prem_max_sentence_length-1,prem_max_sentence_length-1,0] for i in range(0, abs(prem_max_sentence_length-len(example.dependency["premise"])))] 236 | if example.dependency["hypothesis"] == [[]]: 237 | example.dependency["hypothesis"] = [[hypo_max_sentence_length-1,hypo_max_sentence_length-1,0] for _ in range(0,hypo_max_sentence_length)] 238 | else: 239 | example.dependency["hypothesis"] = example.dependency["hypothesis"]+ [[hypo_max_sentence_length-1,hypo_max_sentence_length-1,0] for i in range(0, abs(hypo_max_sentence_length-len(example.dependency["hypothesis"])))] 240 | 241 | prem_dependency = example.dependency["premise"] 242 | hypo_dependency = example.dependency["hypothesis"] 243 | 244 | return KLUE_NLIFeatures( 245 | input_ids, 246 | attention_mask, 247 | token_type_ids, 248 | #position_ids, 249 | cls_index, 250 | p_mask.tolist(), 251 | example_index=0, 252 | tokens=non_padded_tokens, 253 | token_to_orig_map=token_to_orig_map, 254 | label = label, 255 | guid = example.guid, 256 | language = language, 257 | prem_dependency = prem_dependency, 258 | hypo_dependency=hypo_dependency, 259 | prem_not_word_list = prem_not_word_list, 260 | hypo_not_word_list = hypo_not_word_list, 261 | ) 262 | 263 | 264 | 265 | def klue_convert_example_to_features_init(tokenizer_for_convert): 266 | global tokenizer 267 | tokenizer = tokenizer_for_convert 268 | 269 | 270 | def klue_convert_examples_to_features( 271 | examples, 272 | tokenizer, 273 | max_seq_length, 274 | is_training, 275 | return_dataset=False, 276 | threads=1, 277 | prem_max_sentence_length = 0, 278 | hypo_max_sentence_length = 0, 279 | tqdm_enabled=True, 280 | language = None, 281 | ): 282 | """ 283 | Converts a list of examples into a list of features that can be directly given as input to a model. 284 | It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs. 285 | 286 | Args: 287 | examples: list of :class:`~transformers.data.processors.squad.SquadExample` 288 | tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer` 289 | max_seq_length: The maximum sequence length of the inputs. 290 | doc_stride: The stride used when the context is too large and is split across several features. 291 | max_query_length: The maximum length of the query. 292 | is_training: whether to create features for model evaluation or model training. 293 | return_dataset: Default False. Either 'pt' or 'tf'. 
294 | if 'pt': returns a torch.data.TensorDataset, 295 | if 'tf': returns a tf.data.Dataset 296 | threads: multiple processing threadsa-smi 297 | 298 | 299 | Returns: 300 | list of :class:`~transformers.data.processors.squad.SquadFeatures` 301 | 302 | Example:: 303 | 304 | processor = SquadV2Processor() 305 | examples = processor.get_dev_examples(data_dir) 306 | 307 | features = squad_convert_examples_to_features( 308 | examples=examples, 309 | tokenizer=tokenizer, 310 | max_seq_length=args.max_seq_length, 311 | doc_stride=args.doc_stride, 312 | max_query_length=args.max_query_length, 313 | is_training=not evaluate, 314 | ) 315 | """ 316 | 317 | # Defining helper methods 318 | features = [] 319 | threads = min(threads, cpu_count()) 320 | with Pool(threads, initializer=klue_convert_example_to_features_init, initargs=(tokenizer,)) as p: 321 | 322 | # annotate_ = 하나의 example에 대한 여러 feature를 리스트로 모은 것 323 | # annotate_ = list(feature1, feature2, ...) 324 | annotate_ = partial( 325 | klue_convert_example_to_features, 326 | max_seq_length=max_seq_length, 327 | prem_max_sentence_length=prem_max_sentence_length, 328 | hypo_max_sentence_length=hypo_max_sentence_length, 329 | is_training=is_training, 330 | language = language, 331 | ) 332 | 333 | # examples에 대한 annotate_ 334 | # features = list( feature1, feature2, feature3, ... ) 335 | ## len(features) == len(examples) 336 | features = list( 337 | tqdm( 338 | p.imap(annotate_, examples, chunksize=32), 339 | total=len(examples), 340 | desc="convert klue nli examples to features", 341 | disable=not tqdm_enabled, 342 | ) 343 | ) 344 | new_features = [] 345 | example_index = 0 # example의 id ## len(features) == len(examples) 346 | for example_feature in tqdm( 347 | features, total=len(features), desc="add example index", disable=not tqdm_enabled 348 | ): 349 | if not example_feature: 350 | continue 351 | 352 | example_feature.example_index = example_index 353 | new_features.append(example_feature) 354 | example_index += 1 355 | 356 | features = new_features 357 | del new_features 358 | 359 | if return_dataset == "pt": 360 | if not is_torch_available(): 361 | raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.") 362 | 363 | # Convert to Tensors and build dataset 364 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 365 | all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 366 | 367 | ## RoBERTa doesn’t have token_type_ids, you don’t need to indicate which token belongs to which segment. 368 | if language == "electra": 369 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 370 | # all_position_ids = torch.tensor([f.#position_ids for f in features], dtype=torch.long) 371 | 372 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 373 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 374 | 375 | all_example_indices = torch.tensor([f.example_index for f in features], dtype=torch.long) 376 | all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long) # 전체 feature의 개별 index 377 | 378 | # all_dependency = [[[premise_tail, premise_head, dependency], [], ...],[[hypothesis_tail, hypothesis_head, dependency], [], ...]], [[],[]], ... 
] 379 | all_prem_dependency = torch.tensor([f.prem_dependency for f in features], dtype=torch.long) 380 | all_hypo_dependency = torch.tensor([f.hypo_dependency for f in features], dtype=torch.long) 381 | 382 | all_prem_not_word_list = torch.tensor([f.prem_not_word_list for f in features], dtype=torch.float) 383 | all_hypo_not_word_list = torch.tensor([f.hypo_not_word_list for f in features], dtype=torch.float) 384 | 385 | if not is_training: 386 | 387 | if language == "electra": 388 | dataset = TensorDataset( 389 | all_input_ids, 390 | all_attention_masks, all_token_type_ids, #all_position_ids, 391 | all_cls_index, all_p_mask, all_feature_index, 392 | all_prem_dependency, all_hypo_dependency, 393 | all_prem_not_word_list, all_hypo_not_word_list 394 | ) 395 | else: 396 | dataset = TensorDataset( 397 | all_input_ids, 398 | all_attention_masks, # all_token_type_ids, all_position_ids, 399 | all_cls_index, all_p_mask, all_feature_index, 400 | all_prem_dependency, all_hypo_dependency, 401 | all_prem_not_word_list, all_hypo_not_word_list 402 | ) 403 | else: 404 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long) 405 | # label_dict = {"entailment": 0, "contradiction": 1, "neutral": 2} 406 | # all_labels = torch.tensor([label_dict[f.label] for f in features], dtype=torch.long) 407 | 408 | if language == "electra": 409 | dataset = TensorDataset( 410 | all_input_ids, 411 | all_attention_masks, 412 | all_token_type_ids, 413 | #all_position_ids, 414 | all_labels, 415 | all_cls_index, 416 | all_p_mask, 417 | all_example_indices, 418 | all_feature_index, 419 | all_prem_dependency, all_hypo_dependency, 420 | all_prem_not_word_list, all_hypo_not_word_list 421 | ) 422 | else: 423 | dataset = TensorDataset( 424 | all_input_ids, 425 | all_attention_masks, 426 | # all_token_type_ids, 427 | # all_position_ids, 428 | all_labels, 429 | all_cls_index, 430 | all_p_mask, 431 | all_example_indices, 432 | all_feature_index, 433 | all_prem_dependency, all_hypo_dependency, 434 | all_prem_not_word_list, all_hypo_not_word_list 435 | ) 436 | 437 | 438 | return features, dataset 439 | else: 440 | return features 441 | 442 | class KLUE_NLIProcessor(DataProcessor): 443 | train_file = None 444 | dev_file = None 445 | 446 | def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False): 447 | if not evaluate: 448 | gold_label = None 449 | label = tensor_dict["gold_label"].numpy().decode("utf-8") 450 | else: 451 | gold_label = tensor_dict["gold_label"].numpy().decode("utf-8") 452 | label = None 453 | 454 | return KLUE_NLIExample( 455 | guid=tensor_dict["guid"].numpy().decode("utf-8"), 456 | genre=tensor_dict["genre"].numpy().decode("utf-8"), 457 | premise=tensor_dict["premise"]["origin"].numpy().decode("utf-8"), 458 | premise_koala=tensor_dict["premise"]["merge"]["koala"].numpy().decode("utf-8"), 459 | premise_merge=tensor_dict["premise"]["merge"]["origin"].numpy().decode("utf-8"), 460 | hypothesis=tensor_dict["hypothesis"]["origin"].numpy().decode("utf-8"), 461 | hypothesis_koala=tensor_dict["hypothesis"]["merge"]["koala"].numpy().decode("utf-8"), 462 | hypothesis_merge=tensor_dict["hypothesis"]["merge"]["origin"].numpy().decode("utf-8"), 463 | gold_label=gold_label, 464 | label=label, 465 | ) 466 | 467 | def get_examples_from_dataset(self, dataset, evaluate=False): 468 | """ 469 | Creates a list of :class:`~transformers.data.processors.squad.KLUE_NLIExample` using a TFDS dataset. 
470 | 471 | Args: 472 | dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")` 473 | evaluate: boolean specifying if in evaluation mode or in training mode 474 | 475 | Returns: 476 | List of KLUE_NLIExample 477 | 478 | Examples:: 479 | 480 | import tensorflow_datasets as tfds 481 | dataset = tfds.load("squad") 482 | 483 | training_examples = get_examples_from_dataset(dataset, evaluate=False) 484 | evaluation_examples = get_examples_from_dataset(dataset, evaluate=True) 485 | """ 486 | 487 | if evaluate: 488 | dataset = dataset["validation"] 489 | else: 490 | dataset = dataset["train"] 491 | 492 | examples = [] 493 | for tensor_dict in tqdm(dataset): 494 | examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) 495 | 496 | return examples 497 | 498 | def get_train_examples(self, data_dir, filename=None, depend_embedding = None): 499 | """ 500 | Returns the training examples from the data directory. 501 | 502 | Args: 503 | data_dir: Directory containing the data files used for training and evaluating. 504 | filename: None by default. 505 | 506 | """ 507 | if data_dir is None: 508 | data_dir = "" 509 | 510 | if self.train_file is None: 511 | raise ValueError("KLUE_NLIProcessor should be instantiated via KLUE_NLIV1Processor.") 512 | 513 | with open( 514 | os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8" 515 | ) as reader: 516 | data_file = self.train_file if filename is None else filename 517 | if data_file.split(".")[-1] == "json": 518 | input_data = json.load(reader) 519 | elif data_file.split(".")[-1] == "jsonl": 520 | input_data = [] 521 | for line in reader: 522 | data = json.loads(line) 523 | input_data.append(data) 524 | return self._create_examples(input_data, 'train', self.train_file if filename is None else filename) 525 | 526 | def get_dev_examples(self, data_dir, filename=None, depend_embedding = None): 527 | """ 528 | Returns the evaluation example from the data directory. 529 | 530 | Args: 531 | data_dir: Directory containing the data files used for training and evaluating. 532 | filename: None by default. 
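The file may be a .json file (loaded with json.load) or a .jsonl file (one JSON object per line), as in get_train_examples.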
533 | """ 534 | if data_dir is None: 535 | data_dir = "" 536 | 537 | if self.dev_file is None: 538 | raise ValueError("KLUE_NLIProcessor should be instantiated via KLUE_NLIV1Processor.") 539 | 540 | with open( 541 | os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8" 542 | ) as reader: 543 | data_file = self.dev_file if filename is None else filename 544 | if data_file.split(".")[-1] == "json": 545 | input_data = json.load(reader) 546 | elif data_file.split(".")[-1] == "jsonl": 547 | input_data = [] 548 | for line in reader: 549 | data = json.loads(line) 550 | input_data.append(data) 551 | return self._create_examples(input_data, "dev", self.dev_file if filename is None else filename) 552 | 553 | def get_example_from_input(self, input_dictionary): 554 | # guid, genre, premise, hypothesis 555 | guid=input_dictionary["guid"] 556 | premise=input_dictionary["premise"] 557 | hypothesis=input_dictionary["hypothesis"] 558 | gold_label=None 559 | label = None 560 | 561 | examples = [KLUE_NLIExample( 562 | guid=guid, 563 | genre="", 564 | premise=premise, 565 | hypothesis=hypothesis, 566 | gold_label=gold_label, 567 | label=label, 568 | )] 569 | return examples 570 | 571 | def _create_examples(self, input_data, set_type, data_file): 572 | is_training = set_type == "train" 573 | num = 0 574 | examples = [] 575 | parsing = data_file.split("/")[-1].split("_")[0] 576 | if parsing == "merge": parsing = "koala" 577 | for entry in tqdm(input_data): 578 | guid = entry["guid"] 579 | if "genre" in entry: genre = entry["genre"] 580 | elif "source" in entry: genre = entry["source"] 581 | else: genre = entry["guid"].split("_")[0] 582 | premise = entry["premise"]["origin"] 583 | premise_merge = entry["premise"]["merge"]["origin"] 584 | premise_koala = entry["premise"]["merge"][parsing] 585 | 586 | hypothesis = entry["hypothesis"]["origin"] 587 | hypothesis_merge = entry["hypothesis"]["merge"]["origin"] 588 | hypothesis_koala = entry["hypothesis"]["merge"][parsing] 589 | 590 | gold_label = None 591 | label = None 592 | 593 | if is_training: 594 | label = entry["gold_label"] 595 | else: 596 | gold_label = entry["gold_label"] 597 | 598 | example = KLUE_NLIExample( 599 | guid=guid, 600 | genre=genre, 601 | premise=premise, 602 | premise_koala=premise_koala, 603 | premise_merge=premise_merge, 604 | hypothesis=hypothesis, 605 | hypothesis_koala=hypothesis_koala, 606 | hypothesis_merge=hypothesis_merge, 607 | gold_label=gold_label, 608 | label=label, 609 | ) 610 | examples.append(example) 611 | # len(examples) == len(input_data) 612 | return examples 613 | 614 | 615 | class KLUE_NLIV1Processor(KLUE_NLIProcessor): 616 | train_file = "klue-nil-v1_train.json" 617 | dev_file = "klue-nli-v1_dev.json" 618 | 619 | 620 | class KLUE_NLIExample(object): 621 | def __init__( 622 | self, 623 | guid, 624 | genre, 625 | premise, 626 | premise_koala, 627 | premise_merge, 628 | hypothesis, 629 | hypothesis_koala, 630 | hypothesis_merge, 631 | gold_label=None, 632 | label=None, 633 | ): 634 | self.guid = guid 635 | self.genre = genre 636 | self.premise = premise 637 | self.premise_merge = premise_merge 638 | self.premise_koala = premise_koala 639 | self.hypothesis = hypothesis 640 | self.hypothesis_koala = hypothesis_koala 641 | self.hypothesis_merge = hypothesis_merge 642 | 643 | label_dict = {"entailment": 0, "contradiction": 1, "neutral": 2} 644 | if gold_label in label_dict.keys(): 645 | gold_label = label_dict[gold_label] 646 | self.gold_label = gold_label 647 | 648 | if label in 
label_dict.keys(): 649 | label = label_dict[label] 650 | self.label = label 651 | 652 | # doct_tokens : 띄어쓰기 기준으로 나누어진 어절(word)로 만들어진 리스트 653 | ## sentence1 sentence2 654 | self.doc_tokens = {"premise":self.premise.strip().split(), "hypothesis":self.hypothesis.strip().split()} 655 | 656 | # merge: 말뭉치의 시작위치를 어절 기준으로 만든 리스트 657 | merge_index = [[],[]] 658 | merge_word = [[],[]] 659 | for merge in self.premise_merge: 660 | if merge[1] != []: merge_index[0].append(merge[1]) 661 | for merge in self.hypothesis_merge: 662 | if merge[1] != []: merge_index[1].append(merge[1]) 663 | 664 | # 구문구조 종류 665 | depend2idx = {"None":0}; idx2depend ={0:"None"} 666 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']: 667 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]: 668 | depend2idx[depend1 + "-" + depend2] = len(depend2idx) 669 | idx2depend[len(idx2depend)] = depend1 + "-" + depend2 670 | 671 | if ([words for words in self.premise_koala if words[2][0] != words[2][1]] == []): merge_word[0].append([]) 672 | else: 673 | for words in self.premise_koala: 674 | if words[2][0] != words[2][1]: 675 | if not [merge_index[0].index(words[2][0]), merge_index[0].index(words[2][1])] in [merge_word[:-1] for merge_word in merge_word[0]]: 676 | merge_word[0].append([merge_index[0].index(words[2][0]), merge_index[0].index(words[2][1]),depend2idx[words[1][0]+"-"+words[1][1]]]) 677 | else: 678 | idx = [merge_w[:-1] for merge_w in merge_word[0]].index([merge_index[0].index(words[2][0]), merge_index[0].index(words[2][1])]) 679 | tag = idx2depend[[merge_w[-1] for merge_w in merge_word[0]][idx]].split("-") 680 | if (words[1][1] in ['SBJ', 'CNJ', 'OBJ']) and(tag[1] in ['CMP', 'MOD', 'AJT', 'None', "UNDEF"]): 681 | merge_word[0][idx][2] = depend2idx[tag[0]+"-"+words[1][1]] 682 | 683 | 684 | if ([words for words in self.hypothesis_koala if words[2][0] != words[2][1]] == []): merge_word[1].append([]) 685 | else: 686 | for words in self.hypothesis_koala: 687 | if words[2][0] != words[2][1]: 688 | if not [merge_index[1].index(words[2][0]), merge_index[1].index(words[2][1])] in [merge_word[:-1] for merge_word in merge_word[1]]: 689 | merge_word[1].append([merge_index[1].index(words[2][0]), merge_index[1].index(words[2][1]),depend2idx[words[1][0]+"-"+words[1][1]]]) 690 | else: 691 | idx = [merge_w[:-1] for merge_w in merge_word[1]].index([merge_index[1].index(words[2][0]), merge_index[1].index(words[2][1])]) 692 | tag = idx2depend[[merge_w[-1] for merge_w in merge_word[1]][idx]].split("-") 693 | if (words[1][1] in ['SBJ', 'CNJ', 'OBJ']) and(tag[1] in ['CMP', 'MOD', 'AJT', 'None', "UNDEF"]): 694 | merge_word[1][idx][2] = depend2idx[tag[0]+"-"+words[1][1]] 695 | 696 | self.merge = {"premise":merge_index[0], "hypothesis":merge_index[1]} 697 | self.dependency = {"premise":merge_word[0], "hypothesis":merge_word[1]} 698 | 699 | 700 | class KLUE_NLIFeatures(object): 701 | def __init__( 702 | self, 703 | input_ids, 704 | attention_mask, 705 | token_type_ids, 706 | #position_ids, 707 | cls_index, 708 | p_mask, 709 | example_index, 710 | token_to_orig_map, 711 | guid, 712 | tokens, 713 | label, 714 | language, 715 | prem_dependency, 716 | hypo_dependency, 717 | prem_not_word_list, 718 | hypo_not_word_list, 719 | ): 720 | self.input_ids = input_ids 721 | self.attention_mask = attention_mask 722 | if language == "electra": self.token_type_ids = token_type_ids 723 | #self.position_ids = #position_ids 724 | self.cls_index = cls_index 725 | self.p_mask = p_mask 726 | 727 | self.example_index = 
example_index 728 | self.token_to_orig_map = token_to_orig_map 729 | self.guid = guid 730 | self.tokens = tokens 731 | 732 | self.label = label 733 | 734 | self.prem_dependency = prem_dependency 735 | self.hypo_dependency = hypo_dependency 736 | 737 | self.prem_not_word_list = prem_not_word_list 738 | self.hypo_not_word_list = hypo_not_word_list 739 | 740 | 741 | class KLUEResult(object): 742 | def __init__(self, example_index, label_logits, gold_label=None, cls_logits=None): 743 | self.label_logits = label_logits 744 | self.example_index = example_index 745 | 746 | if gold_label is not None: 747 | self.gold_label = gold_label 748 | self.cls_logits = cls_logits -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | # model += Parsing Info Collecting Layer (PIC) 2 | 3 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | #from transformers.modeling_electra import ElectraModel, ElectraPreTrainedModel 9 | 10 | from transformers import ElectraModel, RobertaModel 11 | 12 | import transformers 13 | if int(transformers.__version__[0]) <= 3: 14 | from transformers.modeling_roberta import RobertaPreTrainedModel 15 | from transformers.modeling_bert import BertPreTrainedModel 16 | from transformers.modeling_electra import ElectraModel, ElectraPreTrainedModel 17 | else: 18 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel 19 | from transformers.models.bert.modeling_bert import BertPreTrainedModel 20 | from transformers.models.electra.modeling_electra import ElectraPreTrainedModel 21 | 22 | from src.functions.biattention import BiAttention, BiLinear 23 | 24 | class RobertaForSequenceClassification(BertPreTrainedModel): 25 | 26 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length): 27 | super().__init__(config) 28 | self.num_labels = config.num_labels 29 | self.config = config 30 | self.roberta = RobertaModel(config) 31 | 32 | # Treat (index of token1, index of token2) in the input tokens as a single span and learn a representation for it 33 | self.span_info_collect = SICModel1(config) 34 | #self.span_info_collect = SICModel2(config) 35 | 36 | # Combine the premise and hypothesis span information through biaffine attention, then normalize 37 | self.parsing_info_collect = PICModel1(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info + klue-biaffine attention + bilstm + klue-bilinear classification 38 | #self.parsing_info_collect = PICModel2(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + bilstm + klue-bilinear classification 39 | #self.parsing_info_collect = PICModel3(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info + bilstm + klue-bilinear classification 40 | #self.parsing_info_collect = PICModel4(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info (1) + bilstm + bilinear classification 41 | #self.parsing_info_collect = PICModel5(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info + biaffine attention + bilstm + bilinear classification 42 | 43 | self.init_weights() 44 | 45 | def forward( 46 | self, 47 | input_ids=None, 48 | attention_mask=None, 49 | token_type_ids=None, 50 | position_ids=None, 51 | head_mask=None, 52 | inputs_embeds=None, 53 | labels=None, 54 | prem_span=None, 55 | hypo_span=None, 56 | prem_word_idxs=None, 57 | hypo_word_idxs=None, 58 | ): 59 | 
batch_size = input_ids.shape[0] 60 | discriminator_hidden_states = self.roberta( 61 | input_ids=input_ids, 62 | attention_mask=attention_mask, 63 | ) 64 | # last-layer hidden state 65 | # sequence_output: [batch_size, seq_length, hidden_size] 66 | sequence_output = discriminator_hidden_states[0] 67 | 68 | # span info collecting layer(SIC) 69 | h_ij = self.span_info_collect(sequence_output, prem_word_idxs, hypo_word_idxs) 70 | 71 | # parser info collecting layer(PIC) 72 | logits = self.parsing_info_collect(h_ij, 73 | batch_size= batch_size, 74 | prem_span=prem_span,hypo_span=hypo_span,) 75 | 76 | outputs = (logits, ) + discriminator_hidden_states[2:] 77 | 78 | if labels is not None: 79 | if self.num_labels == 1: 80 | # We are doing regression 81 | loss_fct = MSELoss() 82 | loss = loss_fct(logits.view(-1), labels.view(-1)) 83 | else: 84 | loss_fct = CrossEntropyLoss() 85 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 86 | print("loss: "+str(loss)) 87 | outputs = (loss,) + outputs 88 | 89 | return outputs # (loss), logits, (hidden_states), (attentions) 90 | 91 | 92 | 93 | class SICModel1(nn.Module): 94 | def __init__(self, config): 95 | super().__init__() 96 | 97 | def forward(self, hidden_states, prem_word_idxs, hypo_word_idxs): 98 | # (batch, max_pre_sen, seq_len) @ (batch, seq_len, hidden) = (batch, max_pre_sen, hidden) 99 | prem_word_idxs = prem_word_idxs.squeeze(1) 100 | hypo_word_idxs = hypo_word_idxs.squeeze(1) 101 | 102 | prem = torch.matmul(prem_word_idxs, hidden_states) 103 | hypo = torch.matmul(hypo_word_idxs, hidden_states) 104 | 105 | return [prem, hypo] 106 | 107 | class SICModel2(nn.Module): 108 | def __init__(self, config): 109 | super().__init__() 110 | self.hidden_size = config.hidden_size 111 | 112 | self.W_p_1 = nn.Linear(self.hidden_size, self.hidden_size) 113 | self.W_p_2 = nn.Linear(self.hidden_size, self.hidden_size) 114 | self.W_p_3 = nn.Linear(self.hidden_size, self.hidden_size) 115 | self.W_p_4 = nn.Linear(self.hidden_size, self.hidden_size) 116 | 117 | self.W_h_1 = nn.Linear(self.hidden_size, self.hidden_size) 118 | self.W_h_2 = nn.Linear(self.hidden_size, self.hidden_size) 119 | self.W_h_3 = nn.Linear(self.hidden_size, self.hidden_size) 120 | self.W_h_4 = nn.Linear(self.hidden_size, self.hidden_size) 121 | 122 | def forward(self, hidden_states, prem_word_idxs, hypo_word_idxs): 123 | prem_word_idxs = prem_word_idxs.squeeze(1).type(torch.LongTensor).to("cuda") 124 | hypo_word_idxs = hypo_word_idxs.squeeze(1).type(torch.LongTensor).to("cuda") 125 | 126 | Wp1_h = self.W_p_1(hidden_states) # (bs, length, hidden_size) 127 | Wp2_h = self.W_p_2(hidden_states) 128 | Wp3_h = self.W_p_3(hidden_states) 129 | Wp4_h = self.W_p_4(hidden_states) 130 | 131 | Wh1_h = self.W_h_1(hidden_states) # (bs, length, hidden_size) 132 | Wh2_h = self.W_h_2(hidden_states) 133 | Wh3_h = self.W_h_3(hidden_states) 134 | Wh4_h = self.W_h_4(hidden_states) 135 | 136 | W1_hi_emb=torch.tensor([], dtype=torch.long).to("cuda") 137 | W2_hi_emb=torch.tensor([], dtype=torch.long).to("cuda") 138 | W3_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda") 139 | W3_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda") 140 | W4_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda") 141 | W4_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda") 142 | for i in range(0, hidden_states.shape[0]): 143 | sub_W1_hi_emb = torch.index_select(Wp1_h[i], 0, prem_word_idxs[i][0]) # (prem_max_seq_length, hidden_size) 144 | sub_W2_hi_emb = torch.index_select(Wp2_h[i], 0, 
prem_word_idxs[i][1]) 145 | sub_W3_hi_start_emb = torch.index_select(Wp3_h[i], 0, prem_word_idxs[i][0]) 146 | sub_W3_hi_end_emb = torch.index_select(Wp3_h[i], 0, prem_word_idxs[i][1]) 147 | sub_W4_hi_start_emb = torch.index_select(Wp4_h[i], 0, prem_word_idxs[i][0]) 148 | sub_W4_hi_end_emb = torch.index_select(Wp4_h[i], 0, prem_word_idxs[i][1]) 149 | 150 | W1_hi_emb = torch.cat((W1_hi_emb, sub_W1_hi_emb.unsqueeze(0))) 151 | W2_hi_emb = torch.cat((W2_hi_emb, sub_W2_hi_emb.unsqueeze(0))) 152 | W3_hi_start_emb = torch.cat((W3_hi_start_emb, sub_W3_hi_start_emb.unsqueeze(0))) 153 | W3_hi_end_emb = torch.cat((W3_hi_end_emb, sub_W3_hi_end_emb.unsqueeze(0))) 154 | W4_hi_start_emb = torch.cat((W4_hi_start_emb, sub_W4_hi_start_emb.unsqueeze(0))) 155 | W4_hi_end_emb = torch.cat((W4_hi_end_emb, sub_W4_hi_end_emb.unsqueeze(0))) 156 | 157 | # [w1*hi, w2*hj, w3(hi-hj), w4(hi⊗hj)] 158 | prem_span = W1_hi_emb + W2_hi_emb + (W3_hi_start_emb - W3_hi_end_emb) + torch.mul(W4_hi_start_emb, W4_hi_end_emb) # (batch_size, prem_max_seq_length, hidden_size) 159 | prem_h_ij = torch.tanh(prem_span) 160 | 161 | W1_hi_emb = torch.tensor([], dtype=torch.long).to("cuda") 162 | W2_hi_emb = torch.tensor([], dtype=torch.long).to("cuda") 163 | W3_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda") 164 | W3_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda") 165 | W4_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda") 166 | W4_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda") 167 | for i in range(0, hidden_states.shape[0]): 168 | sub_W1_hi_emb = torch.index_select(Wh1_h[i], 0, hypo_word_idxs[i][0]) # (hypo_max_seq_length, hidden_size) 169 | sub_W2_hi_emb = torch.index_select(Wh2_h[i], 0, hypo_word_idxs[i][1]) 170 | sub_W3_hi_start_emb = torch.index_select(Wh3_h[i], 0, hypo_word_idxs[i][0]) 171 | sub_W3_hi_end_emb = torch.index_select(Wh3_h[i], 0, hypo_word_idxs[i][1]) 172 | sub_W4_hi_start_emb = torch.index_select(Wh4_h[i], 0, hypo_word_idxs[i][0]) 173 | sub_W4_hi_end_emb = torch.index_select(Wh4_h[i], 0, hypo_word_idxs[i][1]) 174 | 175 | W1_hi_emb = torch.cat((W1_hi_emb, sub_W1_hi_emb.unsqueeze(0))) 176 | W2_hi_emb = torch.cat((W2_hi_emb, sub_W2_hi_emb.unsqueeze(0))) 177 | W3_hi_start_emb = torch.cat((W3_hi_start_emb, sub_W3_hi_start_emb.unsqueeze(0))) 178 | W3_hi_end_emb = torch.cat((W3_hi_end_emb, sub_W3_hi_end_emb.unsqueeze(0))) 179 | W4_hi_start_emb = torch.cat((W4_hi_start_emb, sub_W4_hi_start_emb.unsqueeze(0))) 180 | W4_hi_end_emb = torch.cat((W4_hi_end_emb, sub_W4_hi_end_emb.unsqueeze(0))) 181 | 182 | # [w1*hi, w2*hj, w3(hi-hj), w4(hi⊗hj)] 183 | hypo_span = W1_hi_emb + W2_hi_emb + (W3_hi_start_emb - W3_hi_end_emb) + torch.mul(W4_hi_start_emb, W4_hi_end_emb) # (batch_size, hypo_max_seq_length, hidden_size) 184 | hypo_h_ij = torch.tanh(hypo_span) 185 | 186 | h_ij = [prem_h_ij, hypo_h_ij] 187 | 188 | return h_ij 189 | 190 | 191 | class PICModel1(nn.Module): 192 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length): 193 | super().__init__() 194 | self.hidden_size = config.hidden_size 195 | self.prem_max_sentence_length = prem_max_sentence_length 196 | self.hypo_max_sentence_length = hypo_max_sentence_length 197 | self.num_labels = config.num_labels 198 | 199 | # 구문구조 종류 200 | depend2idx = {"None": 0}; 201 | idx2depend = {0: "None"}; 202 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']: 203 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]: 204 | depend2idx[depend1 + "-" + depend2] = len(depend2idx) 205 | 
idx2depend[len(idx2depend)] = depend1 + "-" + depend2 206 | self.depend2idx = depend2idx 207 | self.idx2depend = idx2depend 208 | self.depend_embedding = nn.Embedding(len(idx2depend), self.hidden_size, padding_idx=0).to("cuda") 209 | 210 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 211 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 212 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 3)) 213 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 3)) 214 | 215 | self.biaffine1 = BiAttention(int(self.hidden_size // 3), int(self.hidden_size // 3), 100) 216 | self.biaffine2 = BiAttention(int(self.hidden_size // 3), int(self.hidden_size // 3), 100) 217 | 218 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 219 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 220 | 221 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels) 222 | 223 | # self.W_1_bilinear = nn.Bilinear(int(self.hidden_size // 3), int(self.hidden_size // 3), self.hidden_size, bias=False) 224 | # self.W_1_linear = nn.Linear(int(self.hidden_size // 3) + int(self.hidden_size // 3), self.hidden_size) 225 | # self.W_2_bilinear = nn.Bilinear(int(self.hidden_size // 3), int(self.hidden_size // 3), self.hidden_size, bias=False) 226 | # self.W_2_linear = nn.Linear(int(self.hidden_size // 3) + int(self.hidden_size // 3), self.hidden_size) 227 | # 228 | # self.bi_lism_1 = nn.LSTM(input_size=self.hidden_size, hidden_size=self.prem_max_sentence_length//2, num_layers=1, bidirectional=True) 229 | # self.bi_lism_2 = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hypo_max_sentence_length//2, num_layers=1, bidirectional=True) 230 | # 231 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) # 일반화된 정보를 사용 232 | # self.biaffine_W_bilinear = nn.Linear((2*(self.prem_max_sentence_length//2))*(2*(self.hypo_max_sentence_length//2)), self.num_labels, bias=False) 233 | # self.biaffine_W_linear = nn.Linear(2*(self.prem_max_sentence_length//2) + 2*(self.hypo_max_sentence_length//2), self.num_labels) 234 | # #self.biaffine_W_bilinear = nn.Bilinear(self.hidden_size, self.hidden_size, self.num_labels, bias=False) 235 | # #self.biaffine_W_linear = nn.Linear(self.hidden_size*2, self.num_labels) 236 | # 237 | # self.reset_parameters() 238 | 239 | def forward(self, hidden_states, batch_size, prem_span, hypo_span): 240 | # hidden_states: [[batch_size, word_idxs, hidden_size], []] 241 | # span: [batch_size, max_sentence_length, max_sentence_length] 242 | # word_idxs: [batch_size, seq_length] 243 | # -> sequence_outputs: [batch_size, seq_length, hidden_size] 244 | 245 | prem_hidden_states= hidden_states[0] 246 | hypo_hidden_states= hidden_states[1] 247 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape) 248 | 249 | # span: (batch, max_prem_len, 3) -> (batch, max_prem_len, 3*hidden_size) 250 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda") 251 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda") 252 | 253 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())): 254 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len) 255 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda") 256 | p_span_dep = torch.tensor([span[2] for span in p_span]).to("cuda") 257 | 258 | p_span_head = 
torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size) 259 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail) 260 | p_span_dep = self.depend_embedding(p_span_dep) 261 | 262 | n_p_span = p_span_head + p_span_tail + p_span_dep 263 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0))) 264 | 265 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len) 266 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda") 267 | h_span_dep = torch.tensor([span[2] for span in h_span]).to("cuda") 268 | 269 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size) 270 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail) 271 | h_span_dep = self.depend_embedding(h_span_dep) 272 | 273 | n_h_span = h_span_head + h_span_tail + h_span_dep 274 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0))) 275 | 276 | prem_span = new_prem_span 277 | hypo_span = new_hypo_span 278 | 279 | del new_prem_span 280 | del new_hypo_span 281 | 282 | # biaffine attention 283 | # hidden_states: (batch_size, max_prem_len, hidden_size) 284 | # span: (batch, max_prem_len, hidden_size) 285 | # -> biaffine_outputs: [batch_size, 100, max_prem_len, max_prem_len] 286 | prem_span = self.reduction1(prem_span) 287 | prem_hidden_states = self.reduction2(prem_hidden_states) 288 | hypo_span = self.reduction3(hypo_span) 289 | hypo_hidden_states = self.reduction4(hypo_hidden_states) 290 | 291 | prem_biaffine_outputs= self.biaffine1(prem_hidden_states, prem_span) 292 | hypo_biaffine_outputs = self.biaffine2(hypo_hidden_states, hypo_span) 293 | 294 | # outputs = self.bilinear(prem_biaffine_outputs.view(-1,self.prem_max_sentence_length*self.prem_max_sentence_length), 295 | # hypo_biaffine_outputs.view(-1,self.hypo_max_sentence_length*self.hypo_max_sentence_length)) 296 | 297 | # bilstm 298 | # biaffine_outputs: [batch_size, 100, max_prem_len, max_prem_len] -> [batch_size, 100, max_prem_len] -> [max_prem_len, batch_size, 100] 299 | # -> hidden_states: [batch_size, max_sentence_length] 300 | prem_biaffine_outputs = prem_biaffine_outputs.mean(-1) 301 | hypo_biaffine_outputs = hypo_biaffine_outputs.mean(-1) 302 | 303 | prem_biaffine_outputs = prem_biaffine_outputs.transpose(1,2).transpose(0,1) 304 | hypo_biaffine_outputs = hypo_biaffine_outputs.transpose(1,2).transpose(0,1) 305 | 306 | prem_states = None 307 | hypo_states = None 308 | 309 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_biaffine_outputs) 310 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_biaffine_outputs) 311 | 312 | 313 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 314 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 315 | 316 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states) 317 | 318 | # new_prem_hidden_states = prem_hidden_states.view(-1, self.prem_max_sentence_length, 1, self.hidden_size) # (batch_size, max_prem_len, 1, hidden_size) 319 | # new_hypo_hidden_states = hypo_hidden_states.view(-1, self.hypo_max_sentence_length, 1, self.hidden_size) 320 | # new_prem_span = prem_span.view(-1, self.prem_max_sentence_length, 3, self.hidden_size)# (batch_size, max_prem_len, 3, hidden_size) 321 | # new_hypo_span = hypo_span.view(-1, self.hypo_max_sentence_length, 3, self.hidden_size) 322 | # 323 | # prem_depend = (new_prem_hidden_states.unsqueeze(-1) * new_prem_span.unsqueeze(-2)).view(-1, 
self.prem_max_sentence_length, 3*self.hidden_size, self.hidden_size) # (batch_size, max_prem_len, 3*hidden_size, hidden_size) 324 | # prem_depend = self.reduction1(prem_depend.transpose(2,3)).view(-1, self.prem_max_sentence_length, int(3 * self.hidden_size // 64)*self.hidden_size) # (batch_size, max_prem_len, hidden_size, 3*hidden_size) -> (batch_size, max_prem_len, hidden_size, int(3 * self.hidden_size // 64)) 325 | # 326 | # hypo_depend = (new_hypo_hidden_states.unsqueeze(-1) * new_hypo_span.unsqueeze(-2)).view(-1, self.hypo_max_sentence_length, 3*self.hidden_size, self.hidden_size) 327 | # hypo_depend = self.reduction2(hypo_depend.transpose(2,3)).view(-1, self.hypo_max_sentence_length, int(3 * self.hidden_size // 64)*self.hidden_size) 328 | # 329 | # prem_biaffine_outputs= self.W_1_bilinear(prem_depend) + self.W_1_linear(torch.cat((prem_span, prem_hidden_states), dim = -1)) 330 | # hypo_biaffine_outputs = self.W_2_bilinear(hypo_depend) + self.W_2_linear(torch.cat((hypo_span, hypo_hidden_states), dim = -1)) 331 | # 332 | # # bilstm 333 | # # biaffine_outputs: [batch_size, max_sentence_length, hidden_size] 334 | # # -> hidden_states: [batch_size, max_sentence_length] 335 | # prem_biaffine_outputs = prem_biaffine_outputs.transpose(0,1) 336 | # hypo_biaffine_outputs = hypo_biaffine_outputs.transpose(0,1) 337 | # 338 | # prem_states = None 339 | # hypo_states = None 340 | # 341 | # prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_biaffine_outputs) 342 | # hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_biaffine_outputs) 343 | # 344 | # prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 345 | # hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 346 | # # biaffine attention 347 | # # prem_hidden_states: (batch_size, max_prem_len) 348 | # # hypo_hidden_states: (batch_size, max_hypo_len) 349 | # 350 | # prem_hypo = (prem_hidden_states.unsqueeze(-1) * hypo_hidden_states.unsqueeze(-2)).view(-1, 351 | # (2*(self.prem_max_sentence_length//2))*(2*(self.hypo_max_sentence_length//2))) 352 | # 353 | # outputs = self.biaffine_W_bilinear(prem_hypo) + self.biaffine_W_linear(torch.cat((prem_hidden_states, hypo_hidden_states), dim=-1)) 354 | 355 | return outputs 356 | 357 | def reset_parameters(self): 358 | self.W_1_bilinear.reset_parameters() 359 | self.W_1_linear.reset_parameters() 360 | self.W_2_bilinear.reset_parameters() 361 | self.W_2_linear.reset_parameters() 362 | 363 | self.biaffine_W_bilinear.reset_parameters() 364 | self.biaffine_W_linear.reset_parameters() 365 | 366 | 367 | 368 | class PICModel2(nn.Module): 369 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length): 370 | super().__init__() 371 | self.hidden_size = config.hidden_size 372 | self.prem_max_sentence_length = prem_max_sentence_length 373 | self.hypo_max_sentence_length = hypo_max_sentence_length 374 | self.num_labels = config.num_labels 375 | 376 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 377 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 378 | 379 | self.bi_lism_1 = nn.LSTM(input_size=int(self.hidden_size // 3), hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 380 | self.bi_lism_2 = nn.LSTM(input_size=int(self.hidden_size // 3), hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 381 | 382 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels) 383 | 384 | def forward(self, hidden_states, batch_size, 
prem_span, hypo_span): 385 | # hidden_states: [[batch_size, word_idxs, hidden_size], []] 386 | # span: [batch_size, max_sentence_length, max_sentence_length] 387 | # word_idxs: [batch_size, seq_length] 388 | # -> sequence_outputs: [batch_size, seq_length, hidden_size] 389 | 390 | prem_hidden_states= hidden_states[0] 391 | hypo_hidden_states= hidden_states[1] 392 | 393 | # biLSTM 394 | # hidden_states: (batch_size, max_prem_len, hidden_size) 395 | # -> # -> hidden_states: [batch_size, hidden_size] 396 | prem_hidden_states = self.reduction1(prem_hidden_states) 397 | hypo_hidden_states = self.reduction2(hypo_hidden_states) 398 | prem_hidden_states = prem_hidden_states.transpose(0,1) 399 | hypo_hidden_states = hypo_hidden_states.transpose(0,1) 400 | 401 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_hidden_states) 402 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_hidden_states) 403 | 404 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 405 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 406 | 407 | # bilinear classification 408 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states) 409 | 410 | return outputs 411 | 412 | 413 | class PICModel3(nn.Module): 414 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length): 415 | super().__init__() 416 | self.hidden_size = config.hidden_size 417 | self.prem_max_sentence_length = prem_max_sentence_length 418 | self.hypo_max_sentence_length = hypo_max_sentence_length 419 | self.num_labels = config.num_labels 420 | 421 | # 구문구조 종류 422 | depend2idx = {"None": 0}; 423 | idx2depend = {0: "None"}; 424 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']: 425 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]: 426 | depend2idx[depend1 + "-" + depend2] = len(depend2idx) 427 | idx2depend[len(idx2depend)] = depend1 + "-" + depend2 428 | self.depend2idx = depend2idx 429 | self.idx2depend = idx2depend 430 | self.depend_embedding = nn.Embedding(len(idx2depend), self.hidden_size, padding_idx=0).to("cuda") 431 | 432 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 433 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 434 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 3)) 435 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 3)) 436 | 437 | self.tag1 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100) 438 | self.tag2 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100) 439 | 440 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 441 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 442 | 443 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels) 444 | 445 | def forward(self, hidden_states, batch_size, prem_span, hypo_span): 446 | # hidden_states: [[batch_size, word_idxs, hidden_size], []] 447 | # span: [batch_size, max_sentence_length, max_sentence_length] 448 | # word_idxs: [batch_size, seq_length] 449 | # -> sequence_outputs: [batch_size, seq_length, hidden_size] 450 | 451 | prem_hidden_states= hidden_states[0] 452 | hypo_hidden_states= hidden_states[1] 453 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape) 454 | 455 | # span: (batch, max_prem_len, 3) -> (batch, 
max_prem_len, 3*hidden_size) 456 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda") 457 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda") 458 | 459 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())): 460 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len) 461 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda") 462 | p_span_dep = torch.tensor([span[2] for span in p_span]).to("cuda") 463 | 464 | p_span_head = torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size) 465 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail) 466 | p_span_dep = self.depend_embedding(p_span_dep) 467 | 468 | n_p_span = p_span_head + p_span_tail + p_span_dep 469 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0))) 470 | 471 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len) 472 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda") 473 | h_span_dep = torch.tensor([span[2] for span in h_span]).to("cuda") 474 | 475 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size) 476 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail) 477 | h_span_dep = self.depend_embedding(h_span_dep) 478 | 479 | n_h_span = h_span_head + h_span_tail + h_span_dep 480 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0))) 481 | 482 | prem_span = new_prem_span 483 | hypo_span = new_hypo_span 484 | 485 | del new_prem_span 486 | del new_hypo_span 487 | 488 | # bilinear 489 | # hidden_states: (batch_size, max_prem_len, hidden_size) 490 | # span: (batch, max_prem_len, hidden_size) 491 | # -> bilinear_outputs: [batch_size, max_prem_len, 100] 492 | prem_span = self.reduction1(prem_span) 493 | prem_hidden_states = self.reduction2(prem_hidden_states) 494 | hypo_span = self.reduction3(hypo_span) 495 | hypo_hidden_states = self.reduction4(hypo_hidden_states) 496 | 497 | prem_bilinear_outputs= self.tag1(prem_hidden_states, prem_span) 498 | hypo_bilinear_outputs = self.tag2(hypo_hidden_states, hypo_span) 499 | 500 | # bilstm 501 | # biaffine_outputs: [batch_size, max_prem_len, 100] 502 | # -> hidden_states: [batch_size, hidden_size] 503 | 504 | prem_bilinear_outputs = prem_bilinear_outputs. 
transpose(0,1) 505 | hypo_bilinear_outputs = hypo_bilinear_outputs.transpose(0,1) 506 | 507 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_bilinear_outputs) 508 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_bilinear_outputs) 509 | 510 | 511 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 512 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 513 | 514 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states) 515 | 516 | return outputs 517 | 518 | class PICModel4(nn.Module): 519 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length): 520 | super().__init__() 521 | self.hidden_size = config.hidden_size 522 | self.prem_max_sentence_length = prem_max_sentence_length 523 | self.hypo_max_sentence_length = hypo_max_sentence_length 524 | self.num_labels = config.num_labels 525 | 526 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 527 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3)) 528 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 3)) 529 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 3)) 530 | 531 | self.tag1 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100) 532 | self.tag2 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100) 533 | 534 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 535 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 536 | 537 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels) 538 | 539 | def forward(self, hidden_states, batch_size, prem_span, hypo_span): 540 | # hidden_states: [[batch_size, word_idxs, hidden_size], []] 541 | # span: [batch_size, max_sentence_length, max_sentence_length] 542 | # word_idxs: [batch_size, seq_length] 543 | # -> sequence_outputs: [batch_size, seq_length, hidden_size] 544 | 545 | prem_hidden_states= hidden_states[0] 546 | hypo_hidden_states= hidden_states[1] 547 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape) 548 | 549 | # span: (batch, max_prem_len, 3) -> (batch, max_prem_len, 3*hidden_size) 550 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda") 551 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda") 552 | 553 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())): 554 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len) 555 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda") 556 | 557 | p_span_head = torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size) 558 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail) 559 | 560 | n_p_span = p_span_head + p_span_tail 561 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0))) 562 | 563 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len) 564 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda") 565 | 566 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size) 567 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail) 568 | 569 | n_h_span = h_span_head + h_span_tail 570 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0))) 571 | 572 | 
prem_span = new_prem_span 573 | hypo_span = new_hypo_span 574 | 575 | del new_prem_span 576 | del new_hypo_span 577 | 578 | # bilinear 579 | # hidden_states: (batch_size, max_prem_len, hidden_size) 580 | # span: (batch, max_prem_len, hidden_size) 581 | # -> bilinear_outputs: [batch_size, max_prem_len, 100] 582 | prem_span = self.reduction1(prem_span) 583 | prem_hidden_states = self.reduction2(prem_hidden_states) 584 | hypo_span = self.reduction3(hypo_span) 585 | hypo_hidden_states = self.reduction4(hypo_hidden_states) 586 | 587 | prem_bilinear_outputs= self.tag1(prem_hidden_states, prem_span) 588 | hypo_bilinear_outputs = self.tag2(hypo_hidden_states, hypo_span) 589 | 590 | # bilstm 591 | # biaffine_outputs: [batch_size, max_prem_len, 100] 592 | # -> hidden_states: [batch_size, hidden_size] 593 | 594 | prem_bilinear_outputs = prem_bilinear_outputs. transpose(0,1) 595 | hypo_bilinear_outputs = hypo_bilinear_outputs.transpose(0,1) 596 | 597 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_bilinear_outputs) 598 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_bilinear_outputs) 599 | 600 | 601 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 602 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 603 | 604 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states) 605 | 606 | return outputs 607 | 608 | 609 | class PICModel5(nn.Module): 610 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length): 611 | super().__init__() 612 | self.hidden_size = config.hidden_size 613 | self.prem_max_sentence_length = prem_max_sentence_length 614 | self.hypo_max_sentence_length = hypo_max_sentence_length 615 | self.num_labels = config.num_labels 616 | 617 | # 구문구조 종류 618 | depend2idx = {"None": 0}; 619 | idx2depend = {0: "None"}; 620 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']: 621 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]: 622 | depend2idx[depend1 + "-" + depend2] = len(depend2idx) 623 | idx2depend[len(idx2depend)] = depend1 + "-" + depend2 624 | self.depend2idx = depend2idx 625 | self.idx2depend = idx2depend 626 | self.depend_embedding = nn.Embedding(len(idx2depend), self.hidden_size, padding_idx=0).to("cuda") 627 | 628 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 6)) 629 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 6)) 630 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 6)) 631 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 6)) 632 | 633 | self.W_1_bilinear = nn.Bilinear(int(self.hidden_size // 6), int(self.hidden_size // 6), 100, bias=False) 634 | self.W_1_linear1 = nn.Linear(int(self.hidden_size // 6), 100) 635 | self.W_1_linear2 = nn.Linear(int(self.hidden_size // 6), 100) 636 | self.W_2_bilinear = nn.Bilinear(int(self.hidden_size // 6), int(self.hidden_size // 6), 100, bias=False) 637 | self.W_2_linear1 = nn.Linear(int(self.hidden_size // 6), 100) 638 | self.W_2_linear2 = nn.Linear(int(self.hidden_size // 6), 100) 639 | 640 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 641 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True) 642 | 643 | self.dropout = nn.Dropout(config.hidden_dropout_prob) # 일반화된 정보를 사용 644 | self.biaffine_W_bilinear = 
nn.Bilinear((2*(self.hidden_size//2)),(2*(self.hidden_size//2)), self.num_labels, bias=False) 645 | self.biaffine_W_linear1 = nn.Linear(2 * (self.hidden_size//2), self.num_labels) 646 | self.biaffine_W_linear2 = nn.Linear(2 * (self.hidden_size // 2), self.num_labels) 647 | self.reset_parameters() 648 | 649 | def forward(self, hidden_states, batch_size, prem_span, hypo_span): 650 | # hidden_states: [[batch_size, word_idxs, hidden_size], []] 651 | # span: [batch_size, max_sentence_length, max_sentence_length] 652 | # word_idxs: [batch_size, seq_length] 653 | # -> sequence_outputs: [batch_size, seq_length, hidden_size] 654 | 655 | prem_hidden_states= hidden_states[0] 656 | hypo_hidden_states= hidden_states[1] 657 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape) 658 | 659 | # span: (batch, max_prem_len, 3) -> (batch, max_prem_len, 3*hidden_size) 660 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda") 661 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda") 662 | 663 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())): 664 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len) 665 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda") 666 | p_span_dep = torch.tensor([span[2] for span in p_span]).to("cuda") 667 | 668 | p_span_head = torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size) 669 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail) 670 | p_span_dep = self.depend_embedding(p_span_dep) 671 | 672 | n_p_span = p_span_head + p_span_tail + p_span_dep 673 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0))) 674 | 675 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len) 676 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda") 677 | h_span_dep = torch.tensor([span[2] for span in h_span]).to("cuda") 678 | 679 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size) 680 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail) 681 | h_span_dep = self.depend_embedding(h_span_dep) 682 | 683 | n_h_span = h_span_head + h_span_tail + h_span_dep 684 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0))) 685 | 686 | prem_span = new_prem_span 687 | hypo_span = new_hypo_span 688 | 689 | del new_prem_span 690 | del new_hypo_span 691 | 692 | # biaffine attention 693 | # hidden_states: (batch_size, max_prem_len, hidden_size) 694 | # span: (batch, max_prem_len, hidden_size) 695 | # -> biaffine_outputs: [batch_size, max_prem_len, 100] 696 | prem_span = self.reduction1(prem_span) 697 | prem_hidden_states = self.reduction2(prem_hidden_states) 698 | hypo_span = self.reduction3(hypo_span) 699 | hypo_hidden_states = self.reduction4(hypo_hidden_states) 700 | 701 | prem_biaffine_outputs= self.W_1_bilinear(prem_span, prem_hidden_states) + self.W_1_linear1(prem_span) + self.W_1_linear2(prem_hidden_states) 702 | hypo_biaffine_outputs = self.W_2_bilinear(hypo_span, hypo_hidden_states) + self.W_2_linear1(hypo_span) + self.W_2_linear2(hypo_hidden_states) 703 | 704 | # bilstm 705 | # biaffine_outputs: [batch_size, max_sentence_length, hidden_size] 706 | # -> hidden_states: [batch_size, max_sentence_length] 707 | prem_biaffine_outputs = prem_biaffine_outputs.transpose(0,1) 708 | hypo_biaffine_outputs = hypo_biaffine_outputs.transpose(0,1) 709 | 710 | prem_bilstm_outputs, prem_states = 
self.bi_lism_1(prem_biaffine_outputs) 711 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_biaffine_outputs) 712 | 713 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 714 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1) 715 | 716 | # biaffine attention 717 | # prem_hidden_states: (batch_size, max_prem_len) 718 | # hypo_hidden_states: (batch_size, max_hypo_len) 719 | # -> outputs: (batch_size, num_labels) 720 | outputs = self.biaffine_W_bilinear(prem_hidden_states, hypo_hidden_states) + self.biaffine_W_linear1(prem_hidden_states) +self.biaffine_W_linear2(hypo_hidden_states) 721 | 722 | return outputs 723 | 724 | def reset_parameters(self): 725 | self.W_1_bilinear.reset_parameters() 726 | self.W_1_linear1.reset_parameters() 727 | self.W_1_linear2.reset_parameters() 728 | self.W_2_bilinear.reset_parameters() 729 | self.W_2_linear1.reset_parameters() 730 | self.W_2_linear2.reset_parameters() 731 | 732 | self.biaffine_W_bilinear.reset_parameters() 733 | self.biaffine_W_linear1.reset_parameters() 734 | self.biaffine_W_linear2.reset_parameters() --------------------------------------------------------------------------------
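
Usage sketch (illustrative, not part of the repository): the snippet below shows how `KLUE_NLIV1Processor` and `klue_convert_examples_to_features` from `src/functions/processor.py` are meant to fit together, roughly the data flow that `run_NLI.py` drives. The data directory, file name, tokenizer checkpoint, and the length/thread values are assumptions made for the example; the chunk-boundary special token that the README adds to the RoBERTa vocabulary is not reproduced here.

```
# Minimal sketch, assuming klue/roberta-base and the data layout from the README;
# all concrete values here are illustrative, not taken from the repository.
from transformers import AutoTokenizer

from src.functions.processor import KLUE_NLIV1Processor, klue_convert_examples_to_features

tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")  # assumed checkpoint

processor = KLUE_NLIV1Processor()
# Assumed location/name of the chunk-level parses produced by src/dependency/merge.py.
train_examples = processor.get_train_examples("data/merge", filename="parsing_1_klue_nli_train.json")

# Convert the examples into padded feature tensors; the sentence-length limits are illustrative.
features, dataset = klue_convert_examples_to_features(
    examples=train_examples,
    tokenizer=tokenizer,
    max_seq_length=512,
    is_training=True,
    return_dataset="pt",
    threads=4,
    prem_max_sentence_length=25,
    hypo_max_sentence_length=25,
    language="roberta",
)

print(len(features))             # one feature per example that passed label validation
print(dataset.tensors[0].shape)  # input_ids: (num_features, max_seq_length)
```

The resulting `TensorDataset` also carries the dependency triples (`prem_dependency`, `hypo_dependency`) and the chunk mean-pooling matrices (`prem_not_word_list`, `hypo_not_word_list`), which appear to be what `RobertaForSequenceClassification` in `src/model/model.py` consumes through its `prem_span` / `hypo_span` and `prem_word_idxs` / `hypo_word_idxs` arguments; for actual training, use `python run_NLI.py` as described in the README.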