├── model.png
├── requirements.txt
├── src
│   ├── functions
│   │   ├── metric.py
│   │   ├── utils.py
│   │   ├── biattention.py
│   │   └── processor.py
│   ├── model
│   │   ├── main_functions.py
│   │   └── model.py
│   └── dependency
│       └── merge.py
├── README.md
└── run_NLI.py
/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/NeuralSymbolic_KU_NLI/HEAD/model.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | tqdm
4 | tokenizers==0.10.3
5 | attrdict
6 | fastprogress
7 | transformers==4.6.1
8 | torch==1.9.0+cu111
9 |
--------------------------------------------------------------------------------
/src/functions/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
3 |
4 |
5 | def get_sklearn_score(predicts, corrects, idx2label):
6 | predicts = [idx2label[predict] for predict in predicts]
7 | corrects = [idx2label[correct] for correct in corrects]
8 | result = {"accuracy": accuracy_score(corrects, predicts),
9 | "macro_precision": precision_score(corrects, predicts, average="macro"),
10 | "micro_precision": precision_score(corrects, predicts, average="micro"),
11 | "macro_f1": f1_score(corrects, predicts, average="macro"),
12 | "micro_f1": f1_score(corrects, predicts, average="micro"),
13 | "macro_recall": recall_score(corrects, predicts, average="macro"),
14 | "micro_recall": recall_score(corrects, predicts, average="micro"),
15 |
16 | }
17 | for k, v in result.items():
18 | result[k] = round(v, 3)
19 | print(k + ": " + str(v))
20 | return result
21 |
22 |
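# A minimal usage sketch (assumed, not part of the original file): predictions and
# gold labels are passed as class indices and mapped to label names before scoring.
if __name__ == "__main__":
    idx2label = {0: "entailment", 1: "contradiction", 2: "neutral"}
    preds = [0, 1, 2, 2]
    golds = [0, 1, 1, 2]
    scores = get_sklearn_score(preds, golds, idx2label)  # prints each metric, returns a dict rounded to 3 decimals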
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Natural Language Inference using Dependency Parsing
2 | Code for HCLT 2021 paper: *[Natural Language Inference using Dependency Parsing](https://koreascience.kr/article/CFKO202130060562801.page?&lang=ko)*
3 |
4 |
5 | ## Setting up the code environment
6 |
7 | ```
8 | $ virtualenv --python=python3.7 venv
9 | $ source venv/bin/activate
10 | $ pip install -r requirements.txt
11 | ```
12 |
13 | The code is only supported on Linux.
14 |
15 |
16 | ## Model Structure
17 |
18 | ![Model Structure](./model.png)
19 |
20 | ## Data
21 |
22 | Korean Language Understanding Evaluation-Natural Language Inference version1: *[KLUE-NLI](https://klue-benchmark.com/tasks/68/data/description)*
23 |
24 | ### Directory and Pre-processing
25 | `The dependency parser model is unpublished (not publicly released).`
26 | ```
27 | ├── data
28 | │   ├── klue-nli-v1_train.json
29 | │   ├── klue-nli-v1_dev.json
30 | │   ├── parsing
31 | │   │   ├── parsing_1_klue_nli_train.json
32 | │   │   └── parsing_1_klue_nli_dev.json
33 | │   └── merge
34 | │       ├── parsing_1_klue_nli_train.json
35 | │       └── parsing_1_klue_nli_dev.json
36 | ├── roberta
37 | │   ├── init_weight
38 | │   └── my_model
39 | ├── src
40 | │   ├── dependency
41 | │   │   └── merge.py
42 | │   ├── functions
43 | │   │   ├── biattention.py
44 | │   │   ├── utils.py
45 | │   │   ├── metric.py
46 | │   │   └── processor.py
47 | │   └── model
48 | │       ├── main_functions.py
49 | │       └── model.py
50 | ├── run_NLI.py
51 | ├── requirements.txt
52 | └── README.md
53 | ```
54 |
55 | * Run the dependency parsing model on the raw data (data/klue-nli-v1_train.json) to extract an eojeol-level (word-level) dependency structure for each input sentence pair (data/parsing/klue-nli-v1_train.json)
56 |
57 | * Convert the eojeol-level dependency structure of each input sentence pair (data/parsing/klue-nli-v1_train.json) into a chunk-level dependency structure (data/merge/klue-nli-v1_train.json) with `src/dependency/merge.py`
58 |
59 | * Add the special token `` that marks chunk boundaries to [roberta/init_weight](https://huggingface.co/klue/roberta-base)/vocab.json (see the sketch below)
60 |
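The relative paths in `src/dependency/merge.py` assume it is launched from its own directory (e.g. `cd src/dependency && python merge.py`). If you prefer not to edit `vocab.json` by hand, a chunk-boundary token can also be registered through the tokenizer API; the sketch below is an assumption rather than the repository's actual procedure, and `[CHNK]` is only a placeholder name (the real token string is not reproduced in this README):

```
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
model = AutoModel.from_pretrained("klue/roberta-base")

# "[CHNK]" is a hypothetical chunk-boundary token name
tokenizer.add_special_tokens({"additional_special_tokens": ["[CHNK]"]})
model.resize_token_embeddings(len(tokenizer))

tokenizer.save_pretrained("./roberta/init_weight")
model.save_pretrained("./roberta/init_weight")
```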
61 |
62 | ## Train & Test
63 |
64 | ### Pretrained Model
65 | [klue/roberta-base](https://huggingface.co/klue/roberta-base)
66 |
67 | ### How To Run
68 | ```
69 | python run_NLI.py
70 | ```
71 |
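Note that the running mode is selected by the boolean defaults inside `run_NLI.py` rather than by command-line strings (the flags are declared with `type=bool`, so values passed on the command line are not parsed reliably). For example, to evaluate a saved checkpoint, the defaults in the Running Mode block can be edited roughly as follows:

```
cli_parser.add_argument("--from_init_weight", type=bool, default=False)  # load <output_dir>/model/checkpoint-N
cli_parser.add_argument("--checkpoint", type=str, default="4")
cli_parser.add_argument("--do_train", type=bool, default=False)
cli_parser.add_argument("--do_eval", type=bool, default=True)
cli_parser.add_argument("--do_predict", type=bool, default=False)
```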
72 | ## Results on KLUE-NLI
73 |
74 | | Model | Acc |
75 | |---|--------- |
76 | | NLI w/ DP | 90.78% |
77 |
--------------------------------------------------------------------------------
/src/functions/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import torch
4 | import numpy as np
5 | import os
6 |
7 |
8 | from src.functions.processor import (
9 | KLUE_NLIV1Processor,
10 | klue_convert_examples_to_features
11 | )
12 |
13 |
14 | def init_logger():
15 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
16 | datefmt='%m/%d/%Y %H:%M:%S',
17 | level=logging.INFO)
18 |
19 | def set_seed(args):
20 | random.seed(args.seed)
21 | np.random.seed(args.seed)
22 | torch.manual_seed(args.seed)
23 | if not args.no_cuda and torch.cuda.is_available():
24 | torch.cuda.manual_seed_all(args.seed)
25 |
26 | # convert a tensor to a plain Python list
27 | def to_list(tensor):
28 | return tensor.detach().cpu().tolist()
29 |
30 |
31 | # load the dataset
32 | def load_examples(args, tokenizer, evaluate=False, output_examples=False, do_predict=False, input_dict=None):
33 | '''
34 | :param args: hyperparameters
35 | :param tokenizer: tokenizer used for tokenization
36 | :param evaluate: True for evaluation or open test
37 | :param output_examples: True for evaluation or open test / if True, examples and features are returned together
38 | :param do_predict: True for open test
39 | :param input_dict: dictionary of documents and questions given at open-test time
40 | :return:
41 | examples : list of the raw data, regardless of max_length
42 | features : list of the raw data, split and tokenized according to max_length
43 | dataset : input ids converted to tensors, split by max_length and used directly for training
44 | '''
45 | input_dir = args.data_dir
46 | print("Creating features from dataset file at {}".format(input_dir))
47 |
48 | # declare the processor
49 | ## names of the json train and dev data files
50 | if len(set(input_dir.split("/")).intersection(["snli", "mnli", "qnli", "hans", "sick"])) != 0:processor = NLIV1Processor()
51 | else: processor = KLUE_NLIV1Processor()
52 |
53 | # for open test
54 | if do_predict:
55 | ## input_dict: dictionary made up of guid, premise, and hypothesis
56 | # examples = processor.get_example_from_input(input_dict)
57 | examples = processor.get_dev_examples(os.path.join(args.data_dir),
58 | filename=args.predict_file)
59 | # for evaluation
60 | elif evaluate:
61 | examples = processor.get_dev_examples(os.path.join(args.data_dir),
62 | filename=args.eval_file)
63 | # for training
64 | else:
65 | examples = processor.get_train_examples(os.path.join(args.data_dir),
66 | filename=args.train_file)
67 |
68 |
69 | features, dataset = klue_convert_examples_to_features(
70 | examples=examples,
71 | tokenizer=tokenizer,
72 | max_seq_length=args.max_seq_length,
73 | is_training=not evaluate,
74 | return_dataset="pt",
75 | threads=args.threads,
76 | prem_max_sentence_length = args.prem_max_sentence_length,
77 | hypo_max_sentence_length = args.hypo_max_sentence_length,
78 | language = args.model_name_or_path.split("/")[-2]
79 | )
80 | if output_examples:
81 | ## examples, features, and dataset are aligned by index
82 | return dataset, examples, features
83 | return dataset
84 |
--------------------------------------------------------------------------------
/src/functions/biattention.py:
--------------------------------------------------------------------------------
1 | # This code is taken from the link below
2 | # https://github.com/KLUE-benchmark/KLUE-baseline/blob/8a03c9447e4c225e806877a84242aea11258c790/klue_baseline/models/dependency_parsing.py
3 |
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.parameter import Parameter
9 | import torch.nn.functional as F
10 |
11 |
12 | class BiAttention(nn.Module):
13 | def __init__(
14 | self,
15 | input_size_encoder,
16 | input_size_decoder,
17 | num_labels,
18 | biaffine=True,
19 | **kwargs
20 | ):
21 | super(BiAttention, self).__init__()
22 | self.input_size_encoder = input_size_encoder
23 | self.input_size_decoder = input_size_decoder
24 | self.num_labels = num_labels
25 | self.biaffine = biaffine
26 |
27 | self.W_e = Parameter(torch.Tensor(self.num_labels, self.input_size_encoder))
28 | self.W_d = Parameter(torch.Tensor(self.num_labels, self.input_size_decoder))
29 | self.b = Parameter(torch.Tensor(self.num_labels, 1, 1))
30 | if self.biaffine:
31 | self.U = Parameter(
32 | torch.Tensor(
33 | self.num_labels, self.input_size_decoder, self.input_size_encoder
34 | )
35 | )
36 | else:
37 | self.register_parameter("U", None)
38 |
39 | self.reset_parameters()
40 |
41 | def reset_parameters(self):
42 | nn.init.xavier_uniform_(self.W_e)
43 | nn.init.xavier_uniform_(self.W_d)
44 | nn.init.constant_(self.b, 0.0)
45 | if self.biaffine:
46 | nn.init.xavier_uniform_(self.U)
47 |
48 | def forward(self, input_e, input_d, mask_d=None, mask_e=None):
49 | assert input_d.size(0) == input_e.size(0)
50 | batch, length_decoder, _ = input_d.size()
51 | _, length_encoder, _ = input_e.size()
52 |
53 | # input_d : [b, t, d]
54 | # input_e : [b, s, e]
55 | # out_d : [b, l, d, 1]
56 | # out_e : [b, l ,1, e]
57 | out_d = torch.matmul(self.W_d, input_d.transpose(1, 2)).unsqueeze(3)
58 | out_e = torch.matmul(self.W_e, input_e.transpose(1, 2)).unsqueeze(2)
59 |
60 | if self.biaffine:
61 | # output : [b, 1, t, d] * [l, d, e] -> [b, l, t, e]
62 | output = torch.matmul(input_d.unsqueeze(1), self.U)
63 | # output : [b, l, t, e] * [b, 1, e, s] -> [b, l, t, s]
64 | output = torch.matmul(output, input_e.unsqueeze(1).transpose(2, 3))
65 | output = output + out_d + out_e + self.b
66 | else:
67 | output = out_d + out_e + self.b
68 |
69 | if mask_d is not None:
70 | output = (
71 | output
72 | * mask_d.unsqueeze(1).unsqueeze(3)
73 | * mask_e.unsqueeze(1).unsqueeze(2)
74 | )
75 |
76 | # input1 = (batch_size, input11, input12)
77 | # input2 = (batch_size, input21, input22)
78 | return output # (batch_size, output_size, input11, input21)
79 |
80 | class BiLinear(nn.Module):
81 | def __init__(self, left_features: int, right_features: int, out_features: int):
82 | super(BiLinear, self).__init__()
83 | self.left_features = left_features
84 | self.right_features = right_features
85 | self.out_features = out_features
86 |
87 | self.U = Parameter(torch.Tensor(self.out_features, self.left_features, self.right_features))
88 | self.W_l = Parameter(torch.Tensor(self.out_features, self.left_features))
89 | self.W_r = Parameter(torch.Tensor(self.out_features, self.right_features))
90 | self.bias = Parameter(torch.Tensor(out_features))
91 |
92 | self.reset_parameters()
93 |
94 | def reset_parameters(self) -> None:
95 | nn.init.xavier_uniform_(self.W_l)
96 | nn.init.xavier_uniform_(self.W_r)
97 | nn.init.constant_(self.bias, 0.0)
98 | nn.init.xavier_uniform_(self.U)
99 |
100 | def forward(self, input_left: torch.Tensor, input_right: torch.Tensor) -> torch.Tensor:
101 | left_size = input_left.size()
102 | right_size = input_right.size()
103 | assert left_size[:-1] == right_size[:-1], "batch size of left and right inputs mis-match: (%s, %s)" % (
104 | left_size[:-1],
105 | right_size[:-1],
106 | )
107 | batch = int(np.prod(left_size[:-1]))
108 |
109 | input_left = input_left.contiguous().view(batch, self.left_features)
110 | input_right = input_right.contiguous().view(batch, self.right_features)
111 |
112 | output = F.bilinear(input_left, input_right, self.U, self.bias)
113 | output = output + F.linear(input_left, self.W_l, None) + F.linear(input_right, self.W_r, None)
114 | return output.view(left_size[:-1] + (self.out_features,))
115 |
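# A minimal shape-check sketch (assumed usage, not part of the original file):
# BiAttention scores every decoder/encoder position pair per label,
# BiLinear combines two feature vectors into per-label scores.
if __name__ == "__main__":
    enc = torch.randn(2, 5, 16)   # [batch, encoder_len, input_size_encoder]
    dec = torch.randn(2, 7, 16)   # [batch, decoder_len, input_size_decoder]
    att = BiAttention(input_size_encoder=16, input_size_decoder=16, num_labels=3)
    print(att(enc, dec).shape)    # torch.Size([2, 3, 7, 5])

    bil = BiLinear(left_features=16, right_features=16, out_features=3)
    print(bil(torch.randn(2, 16), torch.randn(2, 16)).shape)  # torch.Size([2, 3])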
--------------------------------------------------------------------------------
/run_NLI.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import logging
4 | from attrdict import AttrDict
5 |
6 | # roberta
7 | from transformers import AutoTokenizer
8 | from transformers import RobertaConfig
9 |
10 | from src.model.model import RobertaForSequenceClassification
11 | from src.model.main_functions import train, evaluate, predict
12 |
13 | from src.functions.utils import init_logger, set_seed
14 |
15 | import sys
16 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
17 |
18 | def create_model(args):
19 |
20 | if args.model_name_or_path.split("/")[-2] == "roberta":
21 |
22 | # load the model config
23 | config = RobertaConfig.from_pretrained(
24 | args.model_name_or_path
25 | if args.from_init_weight else os.path.join(args.output_dir,"model/checkpoint-{}".format(args.checkpoint)),
26 | cache_dir=args.cache_dir,
27 | )
28 |
29 | config.num_labels = args.num_labels
30 | # extract roberta attentions
31 | config.output_attentions=True
32 |
33 | # this step loads the vocab etc. of the chosen model rather than new pre-trained weights
34 | # it is a BertTokenizerFast
35 | tokenizer = AutoTokenizer.from_pretrained(
36 | args.model_name_or_path
37 | if args.from_init_weight else os.path.join(args.output_dir,"model/checkpoint-{}".format(args.checkpoint)),
38 | do_lower_case=args.do_lower_case,
39 | cache_dir=args.cache_dir,
40 |
41 | )
42 | print(tokenizer)
43 |
44 | model = RobertaForSequenceClassification.from_pretrained(
45 | args.model_name_or_path
46 | if args.from_init_weight else os.path.join(args.output_dir,"model/checkpoint-{}".format(args.checkpoint)),
47 | cache_dir=args.cache_dir,
48 | config=config,
49 | prem_max_sentence_length=args.prem_max_sentence_length,
50 | hypo_max_sentence_length=args.hypo_max_sentence_length,
51 | # from_tf=True if args.from_init_weight else False
52 | )
53 |
54 | args.model_name_or_path = args.cache_dir
55 | # print(tokenizer.convert_tokens_to_ids(""))
56 |
57 | model.to(args.device)
58 |
59 | print(" idx")
60 | print(tokenizer.convert_tokens_to_ids(""))
61 | return model, tokenizer
62 |
63 | def main(cli_args):
64 | # update the parameters
65 | args = AttrDict(vars(cli_args))
66 | args.device = "cuda"
67 | logger = logging.getLogger(__name__)
68 |
69 | # set the logger and seed
70 | init_logger()
71 | set_seed(args)
72 |
73 | # load the model
74 | model, tokenizer = create_model(args)
75 |
76 | # run according to the running mode
77 | if args.do_train:
78 | train(args, model, tokenizer, logger)
79 | elif args.do_eval:
80 | #for i in range(1, 13): evaluate(args, model, tokenizer, logger, epoch_idx =i)
81 | evaluate(args, model, tokenizer, logger, epoch_idx =args.checkpoint)
82 |
83 | elif args.do_predict:
84 | predict(args, model, tokenizer)
85 |
86 |
87 | if __name__ == '__main__':
88 | cli_parser = argparse.ArgumentParser()
89 |
90 | # Directory
91 |
92 | # cli_parser.add_argument("--data_dir", type=str, default="./data")
93 | # cli_parser.add_argument("--train_file", type=str, default="klue-nli-v1_train.json")
94 | # cli_parser.add_argument("--eval_file", type=str, default="klue-nli-v1_dev.json")
95 | # cli_parser.add_argument("--predict_file", type=str, default="klue-nli-v1_dev.json") #"klue-nli-v1_dev_sample_10.json")
96 |
97 | #cli_parser.add_argument("--num_labels", type=int, default=3)
98 |
99 | # roberta
100 | cli_parser.add_argument("--model_name_or_path", type=str, default="./roberta/init_weight")
101 | cli_parser.add_argument("--cache_dir", type=str, default="./roberta/init_weight")
102 |
103 | # ------------------------------------------------------------------------------------------------
104 | cli_parser.add_argument("--data_dir", type=str, default="./data/merge")
105 |
106 | cli_parser.add_argument("--num_labels", type=int, default=3)
107 |
108 | cli_parser.add_argument("--train_file", type=str, default='parsing_1_klue_nli_train.json')
109 | cli_parser.add_argument("--eval_file", type=str, default='parsing_1_klue_nli_dev.json')
110 | cli_parser.add_argument("--predict_file", type=str, default='parsing_1_klue_nli_dev.json')
111 | #-----------------------------------------------------------------------------------------------------------
112 | ##################################################################################################################################
113 |
114 | #cli_parser.add_argument("--output_dir", type=str, default="./roberta/my_model/parsing/ver1_wDP") # checkout-5 90.60 90.56 ± 0.04
115 | #cli_parser.add_argument("--output_dir", type=str, default="./roberta/my_model/parsing/ver3_wDP") # checkout-3 90.51 90.45 ± 0.06
116 | cli_parser.add_argument("--output_dir", type=str, default="./roberta/my_model/parsing/ver4_wDP") # checkout-4 90.82 90.78 ± 0.04
117 | # ------------------------------------------------------------------------------------------------------------
118 | ## klue # ver1 = 18 ver3,4 = 27
119 | cli_parser.add_argument("--prem_max_sentence_length", type=int, default=27)
120 | cli_parser.add_argument("--hypo_max_sentence_length", type=int, default=27)
121 |
122 | # https://github.com/KLUE-benchmark/KLUE-baseline/blob/main/run_all.sh
123 | # Model Hyper Parameter
124 | cli_parser.add_argument("--max_seq_length", type=int, default=256) #512)
125 | # Training Parameter
126 | cli_parser.add_argument("--learning_rate", type=float, default=1e-5)
127 | cli_parser.add_argument("--train_batch_size", type=int, default=8)
128 | cli_parser.add_argument("--eval_batch_size", type=int, default=16)
129 | cli_parser.add_argument("--num_train_epochs", type=int, default=5)
130 |
131 | #cli_parser.add_argument("--save_steps", type=int, default=2000)
132 | cli_parser.add_argument("--logging_steps", type=int, default=100)
133 | cli_parser.add_argument("--seed", type=int, default=42)
134 | cli_parser.add_argument("--threads", type=int, default=8)
135 |
136 | cli_parser.add_argument("--weight_decay", type=float, default=0.0)
137 | cli_parser.add_argument("--adam_epsilon", type=int, default=1e-10)
138 | cli_parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
139 | cli_parser.add_argument("--warmup_steps", type=int, default=0)
140 | cli_parser.add_argument("--max_steps", type=int, default=-1)
141 | cli_parser.add_argument("--max_grad_norm", type=int, default=1.0)
142 |
143 | cli_parser.add_argument("--verbose_logging", type=bool, default=False)
144 | cli_parser.add_argument("--do_lower_case", type=bool, default=False)
145 | cli_parser.add_argument("--no_cuda", type=bool, default=False)
146 |
147 | # Running Mode
148 | cli_parser.add_argument("--from_init_weight", type=bool, default= True) #False)#True)
149 | cli_parser.add_argument("--checkpoint", type=str, default="4")
150 |
151 | cli_parser.add_argument("--do_train", type=bool, default=True)#False)#True)
152 | cli_parser.add_argument("--do_eval", type=bool, default=False)#True)
153 | cli_parser.add_argument("--do_predict", type=bool, default=False)#True)#False)
154 |
155 | cli_args = cli_parser.parse_args()
156 |
157 | main(cli_args)
158 |
--------------------------------------------------------------------------------
/src/model/main_functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | import timeit
6 | from fastprogress.fastprogress import master_bar, progress_bar
7 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
8 | from transformers.file_utils import is_torch_available
9 |
10 | from transformers import (
11 | AdamW,
12 | get_linear_schedule_with_warmup
13 | )
14 |
15 | from src.functions.utils import load_examples, set_seed, to_list
16 | from src.functions.metric import get_sklearn_score
17 |
18 | from functools import partial
19 |
20 | def train(args, model, tokenizer, logger):
21 | max_acc =0
22 | # load the dataset used for training
23 | ## dataset: the dataset as tensors
24 | ## all_input_ids,
25 | # all_attention_masks,
26 | # all_labels,
27 | # all_cls_index,
28 | # all_p_mask,
29 | # all_example_indices,
30 | # all_feature_index
31 |
32 | train_dataset = load_examples(args, tokenizer, evaluate=False, output_examples=False)
33 |
34 | # random sampler and DataLoader to fetch the tokenized data in batches
35 | ## RandomSampler: samples data indices in random order
36 | ## SequentialSampler: iterates over data indices in a fixed order
37 | train_sampler = RandomSampler(train_dataset)
38 |
39 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
40 |
41 | # t_total: total optimization step
42 | # compute the total number of training steps for the optimization schedule
43 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
44 |
45 | # apply weight decay per parameter group
46 | no_decay = ["bias", "LayerNorm.weight"]
47 | optimizer_grouped_parameters = [
48 | {
49 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
50 | "weight_decay": args.weight_decay,
51 | },
52 | {
53 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
54 | "weight_decay": 0.0},
55 | ]
56 |
57 | # declare the optimizer and scheduler
58 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
59 | scheduler = get_linear_schedule_with_warmup(
60 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
61 | )
62 |
63 | # Training Step
64 | logger.info("***** Running training *****")
65 | logger.info(" Num examples = %d", len(train_dataset))
66 | logger.info(" Num Epochs = %d", args.num_train_epochs)
67 | logger.info(" Train batch size per GPU = %d", args.train_batch_size)
68 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", args.train_batch_size * args.gradient_accumulation_steps)
69 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
70 | logger.info(" Total optimization steps = %d", t_total)
71 |
72 | global_step = 1
73 | if not args.from_init_weight: global_step += int(args.checkpoint)
74 |
75 | tr_loss, logging_loss = 0.0, 0.0
76 |
77 | # zero out accumulated gradients
78 | model.zero_grad()
79 |
80 | mb = master_bar(range(int(args.num_train_epochs)))
81 | set_seed(args)
82 |
83 | epoch_idx=0
84 | if not args.from_init_weight: epoch_idx += int(args.checkpoint)
85 |
86 | for epoch in mb:
87 | epoch_iterator = progress_bar(train_dataloader, parent=mb)
88 | for step, batch in enumerate(epoch_iterator):
89 | # set the model to train mode
90 | model.train()
91 | batch = tuple(t.to(args.device) for t in batch)
92 |
93 | # collect the input tensors for the model
94 | inputs_list = ["input_ids", "attention_mask"]
95 | if args.model_name_or_path.split("/")[-2] == "electra": inputs_list.append("token_type_ids")
96 | inputs_list.append("labels")
97 | inputs = dict()
98 | for n, input in enumerate(inputs_list): inputs[input] = batch[n]
99 |
100 | inputs_list2 = ['hypo_word_idxs', 'prem_word_idxs', 'hypo_span', 'prem_span']
101 | for m, input in enumerate(inputs_list2): inputs[input] = batch[-(m+1)]
102 |
103 | # compute and accumulate the loss
104 | ## outputs = (total_loss,) + outputs
105 | outputs = model(**inputs)
106 | loss = outputs[0]
107 |
108 | # a larger batch size dampens the noisy gradients that arise during training, making unstable training more stable
109 | # "gradient_accumulation_step" is used to emulate the effect of a larger batch size
110 | ## batch size *= gradient_accumulation_step
111 | # batch size: 16
112 | # assume gradient_accumulation_step: 2
113 | # not identical to an actual batch size of 32, but it gives a similar effect
114 | if args.gradient_accumulation_steps > 1:
115 | loss = loss / args.gradient_accumulation_steps
116 |
117 | ## the loss computed from a batch of batch_size inputs
118 | ## reflects the characteristics of those inputs (depending on how the loss is reduced)
119 | ### loss_fct = CrossEntropyLoss(ignore_index=ignored_index, reduction = ?)
120 | ### reduction = mean : average over the input examples
121 | loss.backward()
122 | tr_loss += loss.item()
123 |
124 | # print the loss
125 | if (global_step + 1) % 50 == 0:
126 | print("{} step processed.. Current Loss : {}".format((global_step+1),loss.item()))
127 |
128 | if (step + 1) % args.gradient_accumulation_steps == 0:
129 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
130 |
131 | optimizer.step()
132 | scheduler.step() # Update learning rate schedule
133 | model.zero_grad()
134 | global_step += 1
135 |
136 | epoch_idx += 1
137 | #logger.info("***** Eval results *****")
138 | #results = evaluate(args, model, tokenizer, logger, epoch_idx = str(epoch_idx), tr_loss = loss.item())
139 |
140 | output_dir = os.path.join(args.output_dir, "model/checkpoint-{}".format(epoch_idx))
141 | if not os.path.exists(output_dir):
142 | os.makedirs(output_dir)
143 |
144 | # save the trained weights and vocab
145 | ## a model loaded with from_pretrained is saved with model.save_pretrained(...)
146 | ## a plain nn.Module model is saved with torch.save(...)
147 | ### a model that uses both must be saved in both ways!
148 | model.save_pretrained(output_dir)
149 | tokenizer.save_pretrained(output_dir)
150 | torch.save(args, os.path.join(output_dir, "training_args.bin"))
151 | logger.info("Saving model checkpoint to %s", output_dir)
152 |
153 | # # model save
154 | # if (args.logging_steps > 0 and max_acc < float(results["accuracy"])):
155 | # max_acc = float(results["accuracy"])
156 | # # create the model save directory
157 | # output_dir = os.path.join(args.output_dir, "model/checkpoint-{}".format(epoch_idx))
158 | # if not os.path.exists(output_dir):
159 | # os.makedirs(output_dir)
160 | #
161 | # # save the trained weights and vocab
162 | # ## a model loaded with from_pretrained is saved with model.save_pretrained(...)
163 | # ## a plain nn.Module model is saved with torch.save(...)
164 | # ### a model that uses both must be saved in both ways!
165 | # model.save_pretrained(output_dir)
166 | # tokenizer.save_pretrained(output_dir)
167 | # # torch.save(args, os.path.join(output_dir, "training_args.bin"))
168 | # logger.info("Saving model checkpoint to %s", output_dir)
169 |
170 | mb.write("Epoch {} done".format(epoch + 1))
171 |
172 | return global_step, tr_loss / global_step
173 |
174 | # evaluate on data that already has gold labels attached
175 | def evaluate(args, model, tokenizer, logger, epoch_idx = "", tr_loss = 1):
176 | # load the dataset
177 | ## dataset: the dataset as tensors
178 | ## examples: the original dataset in json form
179 | ## features: the examples as a list with index numbers attached
180 | dataset, examples, features = load_examples(args, tokenizer, evaluate=True, output_examples=True)
181 |
182 | # create the directory for the final output files
183 | if not os.path.exists(args.output_dir):
184 | os.makedirs(args.output_dir)
185 |
186 | # sampler and DataLoader to fetch the tokenized data in batches
187 | ## RandomSampler: samples data indices in random order
188 | ## SequentialSampler: iterates over data indices in a fixed order
189 | eval_sampler = SequentialSampler(dataset)
190 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
191 |
192 | # Eval!
193 | logger.info("***** Running evaluation {} *****".format(epoch_idx))
194 | logger.info(" Num examples = %d", len(dataset))
195 | logger.info(" Batch size = %d", args.eval_batch_size)
196 |
197 | # timer for measuring evaluation time
198 | start_time = timeit.default_timer()
199 |
200 | # predicted labels
201 | pred_logits = torch.tensor([], dtype = torch.long).to(args.device)
202 | for batch in progress_bar(eval_dataloader):
203 | # switch the model to eval mode
204 | model.eval()
205 | batch = tuple(t.to(args.device) for t in batch)
206 |
207 | with torch.no_grad():
208 | # collect the inputs needed for evaluation
209 | inputs_list = ["input_ids", "attention_mask"]
210 | if args.model_name_or_path.split("/")[-2] == "electra": inputs_list.append("token_type_ids")
211 | inputs = dict()
212 | for n, input in enumerate(inputs_list): inputs[input] = batch[n]
213 |
214 | inputs_list2 = ['hypo_word_idxs', 'prem_word_idxs', 'hypo_span', 'prem_span']
215 | for m, input in enumerate(inputs_list2): inputs[input] = batch[-(m + 1)]
216 |
217 | # outputs = (label_logits, )
218 | # label_logits: [batch_size, num_labels]
219 | outputs = model(**inputs)
220 |
221 | pred_logits = torch.cat([pred_logits,outputs[0]], dim = 0)
222 |
223 | # compare predicted labels with gold labels
224 | pred_logits= pred_logits.detach().cpu().numpy()
225 | pred_labels = np.argmax(pred_logits, axis=-1)
226 | ## gold_labels = 0 or 1 or 2
227 | gold_labels = [example.gold_label for example in examples]
228 |
229 | # print('\n\n=====================outputs=====================')
230 | # for g,p in zip(gold_labels, pred_labels):
231 | # print(str(g)+"\t"+str(p))
232 | # print('===========================================================')
233 |
234 | # elapsed evaluation time
235 | evalTime = timeit.default_timer() - start_time
236 | logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))
237 |
238 | # compute metrics from the final predictions and the original examples
239 | ## results = {"macro_precision":round(macro_precision, 4), "macro_recall":round(macro_recall, 4), "macro_f1_score":round(macro_f1_score, 4), \
240 | ## "accuracy":round(total_accuracy, 4), \
241 | ## "micro_precision":round(micro_precision, 4), "micro_recall":round(micro_recall, 4), "micro_f1":round(micro_f1_score, 4)}
242 | idx2label = {0:"entailment", 1:"contradiction", 2:"neutral"}
243 | #results = get_score(pred_labels, gold_labels, idx2label)
244 | results = get_sklearn_score(pred_labels, gold_labels, idx2label)
245 |
246 | output_dir = os.path.join( args.output_dir, 'eval')
247 |
248 | out_file_type = 'a'
249 | if not os.path.exists(output_dir):
250 | os.makedirs(output_dir)
251 | out_file_type ='w'
252 |
253 | # create the file in which evaluation results are saved
254 | if os.path.exists(args.model_name_or_path):
255 | print(args.model_name_or_path)
256 | eval_file_name = list(filter(None, args.model_name_or_path.split("/"))).pop()
257 | else: eval_file_name = "init_weight"
258 | output_eval_file = os.path.join(output_dir, "eval_result_{}.txt".format(eval_file_name))
259 |
260 | with open(output_eval_file, out_file_type, encoding='utf-8') as f:
261 | f.write("train loss: {}\n".format(tr_loss))
262 | f.write("epoch: {}\n".format(epoch_idx))
263 | for k in results.keys():
264 | f.write("{} : {}\n".format(k, results[k]))
265 | f.write("=======================================\n\n")
266 | return results
267 |
268 | def predict(args, model, tokenizer):
269 | dataset, examples, features = load_examples(args, tokenizer, evaluate=True, output_examples=True, do_predict=True)
270 |
271 | # create the directory for the final output files
272 | if not os.path.exists(args.output_dir):
273 | os.makedirs(args.output_dir)
274 |
275 | # sampler and DataLoader to fetch the tokenized data in batches
276 | ## RandomSampler: samples data indices in random order
277 | ## SequentialSampler: iterates over data indices in a fixed order
278 | eval_sampler = SequentialSampler(dataset)
279 | eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
280 |
281 | print("***** Running Prediction *****")
282 | print(" Num examples = %d", len(dataset))
283 |
284 | # predicted labels
285 | pred_logits = torch.tensor([], dtype=torch.long).to(args.device)
286 | for batch in progress_bar(eval_dataloader):
287 | # switch the model to eval mode
288 | model.eval()
289 | batch = tuple(t.to(args.device) for t in batch)
290 |
291 | with torch.no_grad():
292 | # collect the inputs needed for evaluation
293 | inputs_list = ["input_ids", "attention_mask"]
294 | if args.model_name_or_path.split("/")[-2] == "electra": inputs_list.append("token_type_ids")
295 | inputs = dict()
296 | for n, input in enumerate(inputs_list): inputs[input] = batch[n]
297 |
298 | inputs_list2 = ['hypo_word_idxs', 'prem_word_idxs', 'hypo_span', 'prem_span']
299 | for m, input in enumerate(inputs_list2): inputs[input] = batch[-(m + 1)]
300 |
301 | # outputs = (label_logits, )
302 | # label_logits: [batch_size, num_labels]
303 | outputs = model(**inputs)
304 |
305 | pred_logits = torch.cat([pred_logits, outputs[0]], dim=0)
306 |
307 | # compare predicted labels with gold labels
308 | pred_logits = pred_logits.detach().cpu().numpy()
309 | pred_labels = np.argmax(pred_logits, axis=-1)
310 | ## gold_labels = 0 or 1 or 2
311 | gold_labels = [example.gold_label for example in examples]
312 |
313 | idx2label = {0:"entailment", 1:"contradiction", 2:"neutral"}
314 | #results = get_score(pred_labels, gold_labels, idx2label)
315 | results = get_sklearn_score(pred_labels, gold_labels, idx2label)
316 |
317 | # save test-script-based results
318 | output_dir = os.path.join(args.output_dir, 'test')
319 |
320 | out_file_type = 'a'
321 | if not os.path.exists(output_dir):
322 | os.makedirs(output_dir)
323 | out_file_type = 'w'
324 |
325 | ## create the file in which test results are saved
326 | if os.path.exists(args.model_name_or_path):
327 | print(args.model_name_or_path)
328 | eval_file_name = list(filter(None, args.model_name_or_path.split("/"))).pop()
329 | else:
330 | eval_file_name = "init_weight"
331 | output_test_file = os.path.join(output_dir, "test_result_{}_incorrect.txt".format(eval_file_name))
332 |
333 | with open(output_test_file, out_file_type, encoding='utf-8') as f:
334 | print('\n\n=====================outputs=====================')
335 | for i,(g,p) in enumerate(zip(gold_labels, pred_labels)):
336 | if g != p:
337 | f.write("premise: {}\thypothesis: {}\tcorrect: {}\tpredict: {}\n".format(examples[i].premise, examples[i].hypothesis, idx2label[g], idx2label[p]))
338 | for k in results.keys():
339 | f.write("{} : {}\n".format(k, results[k]))
340 | f.write("=======================================\n\n")
341 |
342 | out = {"premise":[], "hypothesis":[], "correct":[], "predict":[]}
343 | for i,(g,p) in enumerate(zip(gold_labels, pred_labels)):
344 | #if g != p:
345 | for k,v in zip(out.keys(),[examples[i].premise, examples[i].hypothesis, idx2label[g], idx2label[p]]):
346 | out[k].append(v)
349 | df = pd.DataFrame(out)
350 | df.to_csv(os.path.join(output_dir, "test_result_{}.csv".format(eval_file_name)), index=False)
351 |
352 | return results
--------------------------------------------------------------------------------
/src/dependency/merge.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from operator import itemgetter
4 |
5 | def change_tag(change_dic, tag_list, text):
6 | #print("change_dic: " + str(change_dic))
7 | del_dic = []
8 | for key in change_dic.keys():
9 | if change_dic[key][0] in change_dic.keys():
10 | del_dic.append(change_dic[key][0])
11 | change_dic[key] = change_dic[change_dic[key][0]]
12 | #print(del_dic)
13 | for dic in set(del_dic):
14 | del change_dic[dic]
15 |
16 | dic_val_idx = [val[1] for val in change_dic.values()]
17 | for i, val_idx in enumerate(dic_val_idx):
18 | for j, val2 in zip([j for j, val2 in enumerate(dic_val_idx) if j != i], [val2 for j, val2 in enumerate(dic_val_idx) if j != i]):
19 | #print(str(i) + " " + str(j))
20 | #print(val_idx)
21 | #print(val2)
22 | if (len(set(val_idx).intersection(set(val2))) != 0) or (len(set(val2).intersection(set(val_idx))) != 0):
23 | #print(str(i)+" "+str(j))
24 | change_dic[list(change_dic.keys())[i]] = [" ".join([text.split()[k] for k in list(set(
25 | change_dic[list(change_dic.keys())[i]][1] + change_dic[list(change_dic.keys())[j]][1]))]),
26 | list(set(change_dic[list(change_dic.keys())[i]][1] + change_dic[list(change_dic.keys())[j]][1]))]
27 |
28 | change_dic[list(change_dic.keys())[j]] = [" ".join([text.split()[k] for k in list(set(
29 | change_dic[list(change_dic.keys())[i]][1] + change_dic[list(change_dic.keys())[j]][1]))]),
30 | list(set(change_dic[list(change_dic.keys())[i]][1] +
31 | change_dic[list(change_dic.keys())[j]][1]))]
32 | #print("change_dic: " + str(change_dic))
33 | dic_val_idx = [val[1] for val in change_dic.values()]
34 | #print("change_dic: " + str(change_dic))
35 | for tag_idx, tag_li in enumerate(tag_list):
36 | del_list = []
37 | for tag_i,tag_l in enumerate(tag_li):
38 | if tag_l[0][0] in change_dic.keys():
39 | tag_list[tag_idx][tag_i][2][0] = change_dic[tag_l[0][0]][1]
40 | tag_list[tag_idx][tag_i][0][0] = change_dic[tag_l[0][0]][0]
41 | if tag_l[0][1] in change_dic.keys():
42 | tag_list[tag_idx][tag_i][2][1] = change_dic[tag_l[0][1]][1]
43 | tag_list[tag_idx][tag_i][0][1] = change_dic[tag_l[0][1]][0]
44 |
45 | if tag_l[0][0] == tag_l[0][1]: del_list.append(tag_l)
46 | tag_list[tag_idx] = [x for x in tag_list[tag_idx] if x not in del_list]
47 |
48 | return tag_list, {}
49 |
50 | def merge_tag(inf_dir, outf_dir, tag_li_type = "modifier_w/_phrase"):
51 | dp = inf_dir.split("_")[0].split("/")[-1]
52 | with open(inf_dir, "r", encoding="utf-8") as inf:
53 | datas = json.load(inf)
54 |
55 | outputs = [];outputs2 = [];
56 | for id, data in tqdm(enumerate(datas)):
57 | output = [];output2 = [];
58 | # {"origin": text, dp:root, words:[[word, tag, idx], [], ... ]}
59 | texts_list = [data["premise"], data["hypothesis"]]
60 | #print("\n=================="+str(data["guid"])+"=============================")
61 | for texts in texts_list:
62 | #print("\n--------------------------------------------------------------------")
63 | text = texts["origin"]
64 | #print(text)
65 | #print("--------------------------------------------------------------------")
66 |
67 | # [{'R', 'VNP', 'L', 'VP', 'S', 'AP', 'NP', 'DP', 'IP', 'X'}, {'None', 'MOD', 'CNJ', 'AJT', 'OBJ', 'SBJ', 'CMP'}]
68 | r_list = []; l_list = []; s_list = []; x_list = []; np_list = []; dp_list = []; vp_list = [];vnp_list = []; ap_list = []; ip_list = []
69 | MOD = []; AJT = []; CMP = [];
70 | np_cnj_list = []
71 | #print(texts[dp]["words"])
72 |
73 | # preprocessing
74 | ## handles cases such as [['어떤 방에서도', '금지됩니다.'], ['NP', 'AJT'], [[0, 1], [3]]] and [['방에서도', '흡연은'], ['NP', 'MOD'], [[1], [2]]]
75 | for i, koala in enumerate(texts[dp]["words"]):
76 | for j, other in enumerate(texts[dp]["words"]):
77 | if (texts[dp]["words"][i][2][0] != other[2][0]) and (set(texts[dp]["words"][i][2][0]+other[2][0]) == set(other[2][0])):
78 | texts[dp]["words"][i][0][0] = other[0][0]
79 | texts[dp]["words"][i][2][0] = other[2][0]
80 | if (texts[dp]["words"][i][2][0] != other[2][1]) and (set(texts[dp]["words"][i][2][0]+other[2][1]) == set(other[2][1])):
81 | texts[dp]["words"][i][0][0] = other[0][1]
82 | texts[dp]["words"][i][2][0] = other[2][1]
83 | if (texts[dp]["words"][i][2][1] != other[2][0]) and (set(texts[dp]["words"][i][2][1]+other[2][0]) == set(other[2][0])):
84 | texts[dp]["words"][i][0][1] = other[0][0]
85 | texts[dp]["words"][i][2][1] = other[2][0]
86 | if (texts[dp]["words"][i][2][1] != other[2][1]) and (set(texts[dp]["words"][i][2][1]+other[2][1]) == set(other[2][1])):
87 | texts[dp]["words"][i][0][1] = other[0][1]
88 | texts[dp]["words"][i][2][1] = other[2][1]
89 | #print(texts[dp]["words"])
90 | for i, koala in enumerate(texts[dp]["words"]):
91 | #print(koala)
92 | tag = koala[1]
93 | word_idx = koala[2]
94 | if (tag[0] == "NP") and (tag[1] != "CNJ"):
95 | np_list.append(koala)
96 | elif (tag[0] == "DP"):
97 | dp_list.append(koala)
98 | elif (tag[0] == "VP"):
99 | vp_list.append(koala)
100 | elif (tag[0] == "VNP"):
101 | vnp_list.append(koala)
102 | elif (tag[0] == "AP"):
103 | ap_list.append(koala)
104 | elif (tag[0] == "IP"):
105 | ip_list.append(koala)
106 | elif (tag[0] == "R"):
107 | r_list.append(koala)
108 | elif (tag[0] == "L"):
109 | l_list.append(koala)
110 | elif (tag[0] == "S"):
111 | s_list.append(koala)
112 | elif (tag[0] == "X"):
113 | x_list.append(koala)
114 |
115 | if (tag[1] == "MOD"): MOD.append(koala);
116 | elif (tag[1] == "AJT"): AJT.append(koala);
117 | elif (tag[1] == "CMP"): CMP.append(koala);
118 |
119 |
120 | if (tag[0] == "NP") and (tag[1] == "CNJ"):
121 | np_cnj_list.append(koala)
122 |
123 | vp_list = vp_list+vnp_list
124 | tag_list = []
125 | if tag_li_type == "modifier":
126 | tag_list = [MOD+ AJT+ CMP] #3
127 | # tag_list = [MOD, AJT, CMP] #2
128 | elif tag_li_type == "phrase":
129 | tag_list = [x for x in [np_list, dp_list, vp_list, ap_list, ip_list, r_list, l_list, s_list, x_list] if len(x) != 0] # 1
130 | elif tag_li_type == "modifier_w/_phrase":
131 | tag_list = [MOD+ AJT+ CMP]+[x for x in [np_list, dp_list, vp_list, ap_list, ip_list, r_list, l_list, s_list, x_list] if len(x) != 0] #4
132 |
133 | change_dic = {}
134 | for tag_idx,tag_li in enumerate(tag_list):
135 | #print("tag_li: "+str(tag_li))
136 | conti = True
137 | while conti:
138 | tag_list, change_dic = change_tag(change_dic, tag_list, text)
139 | tag_li = tag_list[tag_idx]
140 | #print("tag_li: " + str(tag_li))
141 | new_tag_li = []
142 | other_tag_li = []
143 | for tag_l in tag_li:
144 | if abs(max(tag_l[2][1])-min(tag_l[2][0]))==1:new_tag_li.append(tag_l)
145 | else: other_tag_li.append(tag_l)
146 |
147 | # print("new_tag_li: "+str(new_tag_li))
148 |
149 | if (len(new_tag_li)==0) or (len(tag_li)==1):
150 | conti = False;
151 | else:
152 | new_tag_li = sorted(new_tag_li, key = lambda x:(max(x[2][1])))
153 | # print("new_tag_li after sorted: "+str(new_tag_li))
154 |
155 | tag_li = other_tag_li
156 | del other_tag_li
157 | #print("tag_li: " + str(tag_li))
158 |
159 | # case where the distance is 1 and the chunks are directly adjacent on either side
160 | i = 1
161 | while (i != len(new_tag_li)):
162 | if (new_tag_li[-i+1][0][1] == new_tag_li[-i][0][0]) and (new_tag_li[-i+1][2][1] == new_tag_li[-i][2][0]):
163 | if min(new_tag_li[-i][2][0]) < min(new_tag_li[-i][2][1]):
164 | change_dic.update({new_tag_li[-i][0][0]: [" ".join([new_tag_li[-i][0][0], new_tag_li[-i][0][1]]), list(set(new_tag_li[-i+1][2][1]+new_tag_li[-i][2][0]+new_tag_li[-i][2][1]))]})
165 | change_dic.update({new_tag_li[-i][0][1]: [" ".join([new_tag_li[-i][0][0], new_tag_li[-i][0][1]]), list(set(new_tag_li[-i+1][2][1]+new_tag_li[-i][2][0]+new_tag_li[-i][2][1]))]})
166 | elif min(new_tag_li[-i][2][0]) > min(new_tag_li[-i][2][1]):
167 | change_dic.update({new_tag_li[-i][0][0]: [" ".join([new_tag_li[-i][0][1], new_tag_li[-i][0][0]]), list(set(new_tag_li[-i+1][2][1]+new_tag_li[-i][2][0]+new_tag_li[-i][2][1]))]})
168 | change_dic.update({new_tag_li[-i][0][1]:[" ".join([new_tag_li[-i][0][1], new_tag_li[-i][0][0]]), list(set(new_tag_li[-i + 1][2][1] + new_tag_li[-i][2][0] + new_tag_li[-i][2][1]))]})
169 | new_tag_li[-i + 1] = [[new_tag_li[-i+1][0][0], " ".join(new_tag_li[-i][0])], new_tag_li[-i+1][1], [new_tag_li[-i+1][2][0], list(set(new_tag_li[-i+1][2][1]+new_tag_li[-i][2][0]+new_tag_li[-i][2][1]))]]
170 | new_tag_li.pop()
171 | else: i+=1
172 |
173 | # case where the distance is 2
174 | del_list = []
175 | for new_tag_l in new_tag_li:
176 | for tag_l in tag_li:
177 | if (tag_l[0][1] == new_tag_l[0][1]) and (max(tag_l[2][0])+1 == min(new_tag_l[2][0])):
178 | if min(tag_l[2][0]) < min(new_tag_l[2][0]):
179 | change_dic.update({new_tag_l[0][0]:[" ".join([tag_l[0][0], new_tag_l[0][0]]), list(set(tag_l[2][0] + new_tag_l[2][0]))]})
180 | change_dic.update({tag_l[0][0]:[" ".join([tag_l[0][0], new_tag_l[0][0]]),list(set(tag_l[2][0] + new_tag_l[2][0]))]})
181 | elif min(tag_l[2][0]) > min(new_tag_l[2][0]):
182 | change_dic.update({new_tag_l[0][0]: [" ".join([new_tag_l[0][0], tag_l[0][0]]),
183 | list(set(tag_l[2][0] + new_tag_l[2][0]))]})
184 | change_dic.update({tag_l[0][0]: [" ".join([new_tag_l[0][0], tag_l[0][0]]),
185 | list(set(tag_l[2][0] + new_tag_l[2][0]))]})
186 |
187 | del_list.append(new_tag_l)
188 |
189 | new_tag_li = [x for x in new_tag_li if x not in del_list]
190 |
191 | #print("new_tag_li: "+str(new_tag_li))
192 | if len(new_tag_li) != 0:
193 | for new_tag_l in new_tag_li:
194 | if (min(new_tag_l[2][0]) < min(new_tag_l[2][1])) and (max(new_tag_l[2][0])+1 == min(new_tag_l[2][1])):
195 | change_dic.update({new_tag_l[0][0]:[" ".join(new_tag_l[0]), list(set(sum(new_tag_l[2],[])))]})
196 | change_dic.update({new_tag_l[0][1]: [" ".join(new_tag_l[0]), list(set(sum(new_tag_l[2], [])))]})
197 | elif (min(new_tag_l[2][0]) > min(new_tag_l[2][1])) and (max(new_tag_l[2][1])+1 == min(new_tag_l[2][0])):
198 | change_dic.update({new_tag_l[0][0]:[" ".join([new_tag_l[0][1], new_tag_l[0][0]]), list(set(sum(new_tag_l[2],[])))]})
199 | change_dic.update({new_tag_l[0][1]: [" ".join([new_tag_l[0][1], new_tag_l[0][0]]), list(set(sum(new_tag_l[2], [])))]})
200 | new_tag_li = []
201 |
202 | if change_dic == {}:
203 | conti = False;
204 | #print("change_dic: " + str(change_dic))
205 |
206 | tag_list[tag_idx] = tag_li
207 | #print("change tag_li: " + str(tag_list[tag_idx]))
208 | #print("tag_list: "+ str(tag_list))
209 | tag_list = [x for x in tag_list if x != []]
210 |
211 | #for i, koala in enumerate(texts[dp]["words"]):
212 | # # print(koala)
213 | # tag = koala[1]
214 | # if (tag[0] == "NP") and (tag[1] == "CNJ"):
215 | # if (koala[2][0] != koala[2][1]):np_cnj_list.append(koala)
216 | #"""
217 | new_koalas = [[] for tag_idx, _ in enumerate(tag_list)]
218 | for tag_idx, _ in enumerate(tag_list):
219 | # NP-CNJ
220 | if (len(np_cnj_list) != 0):
221 | for cnj_koala in np_cnj_list:
222 | for ttag_list in tag_list[tag_idx]:
223 | new_koala = []
224 | if (len(set(cnj_koala[2][0]).intersection(set(ttag_list[2][0]))) != 0):
225 | new_koala = [[cnj_koala[0][0], ttag_list[0][1]],
226 | cnj_koala[1],
227 | [cnj_koala[2][0], ttag_list[2][1]]]
228 |
229 | if (len(set(cnj_koala[2][0]).intersection(set(ttag_list[2][1]))) != 0):
230 | new_koala = [[ttag_list[0][1], cnj_koala[0][1]],
231 | cnj_koala[1],
232 | [ttag_list[2][1], cnj_koala[2][1]]]
233 |
234 | if (len(set(cnj_koala[2][1]).intersection(set(ttag_list[2][0]))) != 0):
235 | new_koala = [[cnj_koala[0][0], ttag_list[0][0]],
236 | cnj_koala[1],
237 | [cnj_koala[2][0], ttag_list[2][0]]]
238 |
239 | if (len(set(cnj_koala[2][1]).intersection(set(ttag_list[2][1]))) != 0):
240 | new_koala = [[cnj_koala[0][0], ttag_list[0][1]],
241 | cnj_koala[1],
242 | [cnj_koala[2][0], ttag_list[2][1]]]
243 |
244 | if new_koala != []: new_koalas[tag_idx].append(new_koala)
245 |
246 | for k_idx, new_koala in enumerate(new_koalas):
247 | new_tag_list = tag_list[k_idx]+new_koala
248 | tag_list[k_idx] = new_tag_list
249 | #"""
250 |
251 | tag_list = sum(tag_list, [])
252 | tag_list = [tag for tag in tag_list if len(set(tag[2][0]).intersection(set(tag[2][1]))) == 0]
253 | #print("tag_list: " + str(tag_list))
254 |
255 | sub_output2 = {}
256 | for tag in tag_list:
257 | #print(tag)
258 | sub_output2[tag[0][0]] = tag[2][0]
259 | sub_output2[tag[0][1]] = tag[2][1]
260 | sub_output2 = [[key, sorted(value)] for key,value in sub_output2.items()]
261 | # print(sub_output2)
262 | sub_output2.sort(key=itemgetter(1))
263 | # print("sub_output2: " + str(sub_output2))
264 | # post-processing
265 | i = 0
266 | while (i < len(sub_output2) - 1):
267 | sub1 = sub_output2[i]
268 | sub2 = sub_output2[i + 1]
269 | if (sub1[0] != sub2[0]) and (sub1[1] != sub2[1]) and ((len(set(sub1[1]).intersection(set(sub2[1]))) != 0) or (len(set(sub2[1]).intersection(set(sub1[1]))) != 0)):
270 | #print("sub_output2: " + str(sub_output2))
271 | new_idx = sorted(list(set(sub1[1] + sub2[1])))
272 | sub_output2[i] = (" ".join(texts["origin"].split()[min(new_idx):max(new_idx) + 1]), new_idx)
273 | sub_output2 = sub_output2[:i + 1] + sub_output2[i + 2:]
274 | #print("sub_output2: " + str(sub_output2))
275 | for j, tag in enumerate(tag_list):
276 | if (tag[0][0] in [sub1[0], sub2[0]]):
277 | tag_list[j][0][0] = " ".join(texts["origin"].split()[min(new_idx):max(new_idx) + 1])
278 | tag_list[j][2][0] = new_idx
279 | elif (tag[0][1] in [sub1[0], sub2[0]]):
280 | tag_list[j][0][1] = " ".join(texts["origin"].split()[min(new_idx):max(new_idx) + 1])
281 | tag_list[j][2][1] = new_idx
282 | else: i += 1
283 |
284 | # print("sub_output2: " + str(sub_output2))
285 | # print("tag_list: " + str(tag_list))
286 |
287 | if sub_output2 == []:
288 | sub_output2 = [[" ".join(text.split()[:-1]), [i for i in range(0, len(text.split())-1)]], [text.split()[-1], [len(text.split())-1]]]
289 | tag_list = [[[sub_output2[0][0], sub_output2[1][0]], ['VP', 'MOD'], [sub_output2[0][1], sub_output2[1][1]]]]
290 |
291 | if sum([sub[1] for sub in sub_output2], []) != [i for i,_ in enumerate(text.split())]:
292 | for li in sum([[sorted(tag[2][0]), sorted(tag[2][1])] for tag in tag_list], []):
293 | if (len(set(li).intersection(set([t for t, _ in enumerate(text.split()) if t not in sum([sub[1] for sub in sub_output2], [])]))) != 0):
294 | sub_output2 += [[" ".join([text.split()[t] for t in li]), li]]
295 | sub_output2 += [[t, [i]] for i, t in enumerate(text.split()) if i not in sum([sub[1] for sub in sub_output2], [])]
296 | sub_output2.sort(key=itemgetter(1))
297 | #print("sub_output2: " + str(sub_output2))
298 | #print("tag_list: " + str(tag_list))
299 | #print(sum([sub[1] for sub in sub_output2], []))
300 | #print([[i, t] for i, t in enumerate(text.split())])
301 | assert sum([sub[1] for sub in sub_output2], []) == [i for i,_ in enumerate(text.split())]
302 |
303 | sub_output2 = [[sub[0], sorted(sub[1])] for sub in sub_output2]
304 | output2.append(sub_output2)
305 | tag_list = [[tag[0], tag[1], [sorted(tag[2][0]), sorted(tag[2][1])]] for tag in tag_list]
306 | #tag_list += np_cnj_list
307 | output.append(tag_list)
308 | #print("output2: " + str(output2))
309 | #print("output: " + str(output))
310 | outputs.append(output)
311 | # sub_output2: [('흡연자분들은', [0]), ('발코니가 있는', [1, 2]), ('방이면', [3]), ('발코니에서 흡연이', [4, 5]), ('가능합니다.', [6])]
312 | outputs2.append(output2)
313 |
314 | for i, (data, merge1, merge2) in enumerate(zip(datas, outputs, outputs2)):
315 | datas[i]["premise"]["merge"] = {"origin": merge2[0], dp: merge1[0]}
316 | datas[i]["hypothesis"]["merge"] = {"origin": merge2[1], dp: merge1[1]}
317 |
318 | for i, data in tqdm(enumerate(datas)):
319 | for sen in ["premise", "hypothesis"]:
320 | for j, merge in enumerate(data[sen]["merge"]["origin"]):
321 | if merge == ["", []]:
322 | data[sen]["merge"]["origin"] = [data[sen]["merge"]["origin"][j+1]]
323 | datas[i][sen]["merge"]["origin"] = data[sen]["merge"]["origin"]
324 | data[sen]["merge"]["parsing"][0][0][0] = data[sen]["merge"]["parsing"][0][0][1]
325 | data[sen]["merge"]["parsing"][0][2][0] = data[sen]["merge"]["parsing"][0][2][1]
326 | datas[i][sen]["merge"]["parsing"] = data[sen]["merge"]["parsing"]
327 |
328 | for sen in ["premise", "hypothesis"]:
329 | new_koala = []
330 | for k, words in enumerate(data[sen]["merge"][dp]):
331 | origin_idx = [merge[1] for merge in data[sen]["merge"]["origin"]]
332 | if (words[2][0] in origin_idx):
333 | if (words[2][1] in origin_idx):
334 | new_koala.append(words)
335 | datas[i][sen]["merge"][dp] = new_koala
336 |
337 |
338 | with open(outf_dir, 'w', encoding="utf-8") as f:
339 | json.dump(datas, f, ensure_ascii=False, indent=4)
340 | print("\n\nfinish!!")
341 |
342 | if __name__ =='__main__':
343 |
344 | inf_dirs = ["./../../data/parsing/parsing_1_klue_nli_train.json", "./../../data/parsing/parsing_1_klue_nli_dev.json"]
345 | #["./../../data/koala/koala_ver1_klue_nli_train.json", "./../../data/koala/koala_ver1_klue_nli_dev.json"]
346 | outf_dirs = ["./../../data/merge/parsing_1_klue_nli_train.json", "./../../data/merge/parsing_1_klue_nli_dev.json"]
347 | #["./../../data/merge/merge_3_klue_nli_train.json", "./../../data/merge/merge_3_klue_nli_dev.json"]
348 |
349 | for inf_dir, outf_dir in zip(inf_dirs, outf_dirs):
350 | merge_tag(inf_dir, outf_dir, tag_li_type = "modifier_w/_phrase")
351 | # merge_tag(inf_dir, outf_dir, tag_li_type = "modifier")
352 | # merge_tag(inf_dir, outf_dir, tag_li_type = "phrase")
353 |
354 |
--------------------------------------------------------------------------------
/src/functions/processor.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from functools import partial
5 | from multiprocessing import Pool, cpu_count
6 |
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 | import transformers
11 | from transformers.file_utils import is_tf_available, is_torch_available
12 | from transformers.data.processors.utils import DataProcessor
13 |
14 | if is_torch_available():
15 | import torch
16 | from torch.utils.data import TensorDataset
17 |
18 | if is_tf_available():
19 | import tensorflow as tf
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 |
25 | def klue_convert_example_to_features(example, max_seq_length, is_training, prem_max_sentence_length, hypo_max_sentence_length, language):
26 |
27 | # validate the example data
28 | # ========================================================
29 | label = None
30 | if is_training:
31 | # Get label
32 | label = example.label
33 |
34 | # if the given label does not exist in the label dictionary, return None as the feature
35 | # If the label cannot be found in the text, then skip this example.
36 | ## kind_of_label: the possible label names
37 | kind_of_label = ["entailment", "contradiction", "neutral"]
38 | actual_text = kind_of_label[label] if label < len(kind_of_label) else label
39 | if actual_text not in kind_of_label:
40 | logger.warning("Could not find label: '%s' \n not in entailment, contradiction, and neutral", actual_text)
41 | return None
42 | # ========================================================
43 |
44 | # mapping between word and token positions
45 | tok_to_orig_index = {"premise": [], "hypothesis": []} # one entry per token # index of the word each token came from
46 | orig_to_tok_index = {"premise": [], "hypothesis": []} # one entry per word # position of the first token of each word
47 | all_doc_tokens = {"premise": [], "hypothesis": []} # the original text, tokenized
48 | token_to_orig_map = {"premise": {}, "hypothesis": {}}
49 |
50 | for case in example.merge.keys():
51 | new_merge = []
52 | new_word = []
53 | idx = 0
54 | for merge_idx in example.merge[case]:
55 | for m_idx in merge_idx:
56 | new_word.append(example.doc_tokens[case][m_idx])
57 | new_word.append("")
58 | merge_idx = [m_idx+idx for m_idx in range(0,len(merge_idx))]
59 | new_merge.append(merge_idx)
60 | idx = max(merge_idx)+1
61 | new_merge.append([idx])
62 | idx+=1
63 | example.merge[case] = new_merge
64 | example.doc_tokens[case] = new_word
65 |
66 | for case in example.merge.keys():
67 | for merge_idx in example.merge[case]:
68 | for word_idx in merge_idx:
69 | # position of the first token produced by tokenizing this word
70 | orig_to_tok_index[case].append(len(tok_to_orig_index[case]))
71 | if (example.doc_tokens[case][word_idx] == ""):
72 | sub_tokens = [""]
73 | else: sub_tokens = tokenizer.tokenize(example.doc_tokens[case][word_idx])
74 | for sub_token in sub_tokens:
75 | # store the token
76 | all_doc_tokens[case].append(sub_token)
77 | # index of the word this token belongs to
78 | tok_to_orig_index[case].append(word_idx)
79 | # token_to_orig_map: {token:word}
80 | token_to_orig_map[case][len(tok_to_orig_index[case]) - 1] = len(orig_to_tok_index[case]) - 1
81 |
82 | # print("tok_to_orig_index\n"+str(tok_to_orig_index))
83 | # print("orig_to_tok_index\n"+str(orig_to_tok_index))
84 | # print("all_doc_tokens\n"+str(all_doc_tokens))
85 | # print("token_to_orig_map\n\tindex of token : index of word\n\t"+str(token_to_orig_map))
86 |
87 | # =========================================================
88 |
89 | if int(transformers.__version__[0]) <= 3:
90 | # sequence_added_tokens: 2, since [CLS] and [SEP] are added
91 | ## roberta or camembert: 3
92 | ## sen1 sen2
93 | sequence_added_tokens = (
94 | tokenizer.max_len - tokenizer.max_len_single_sentence + 1
95 | if "roberta" in language or "camembert" in language
96 | else tokenizer.max_len - tokenizer.max_len_single_sentence
97 | )
98 | #print("sequence_added_tokens(# using special token): "+str(sequence_added_tokens))
99 |
100 | # maximum number of actual premise and hypothesis tokens that can fit, excluding special tokens
101 | ## for BERT the input is '[CLS] P [SEP] H [SEP]', so
102 | ### sequence_pair_added_tokens is 3, the number of special tokens
103 | ### tokenizer.max_len = 512 & tokenizer.max_len_sentences_pair = 509
104 | ## for RoBERTa the input is composed of P H , so
105 | ### sequence_pair_added_tokens is 4, the number of special tokens
106 | ### tokenizer.max_len = 512 & tokenizer.max_len_sentences_pair = 508
107 | sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
108 | #print("sequence_pair_added_tokens(# of special token in text): "+str(sequence_pair_added_tokens))
109 |
110 | # check that the maximum length is not exceeded
111 | assert len(all_doc_tokens["premise"]) + len(all_doc_tokens["hypothesis"]) + sequence_pair_added_tokens <= tokenizer.max_len
112 |
113 | else:
114 | sequence_added_tokens = (
115 | tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
116 | if "roberta" in language or "camembert" in language
117 | else tokenizer.model_max_length - tokenizer.max_len_single_sentence
118 | )
119 | #print("sequence_added_tokens(# using special token): "+str(sequence_added_tokens))
120 |
121 | sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
122 | #print("sequence_pair_added_tokens(# of special token in text): "+str(sequence_pair_added_tokens))
123 |
124 | # check that the maximum length is not exceeded
125 | assert len(all_doc_tokens["premise"]) + len(
126 | all_doc_tokens["hypothesis"]) + sequence_pair_added_tokens <= tokenizer.model_max_length
127 |
128 | input_ids = [tokenizer.cls_token_id] + [tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["premise"]]
129 | prem_word_idxs = [0] + list(filter(lambda x: input_ids[x] == tokenizer.convert_tokens_to_ids(""),range(len(input_ids))))
130 |
131 | input_ids += [tokenizer.sep_token_id]
132 | if "roberta" in language or "camembert" in language: input_ids += [tokenizer.sep_token_id]
133 | token_type_ids = [0] * len(input_ids)
134 |
135 | input_ids += [tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["hypothesis"]] + [tokenizer.sep_token_id]
136 | hypo_word_idxs = list(filter(lambda x: [tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["hypothesis"]][x] == tokenizer.convert_tokens_to_ids(""),range(len([tokenizer.convert_tokens_to_ids(token) for token in all_doc_tokens["hypothesis"]]))))
137 |
138 | hypo_word_idxs = [len(token_type_ids)-1] + [x+len(token_type_ids) for x in hypo_word_idxs]
139 |
140 | token_type_ids = token_type_ids + [1] * (len(input_ids) - len(token_type_ids))
141 | position_ids = list(range(0, len(input_ids)))
142 |
143 | # non_padded_ids: the token ids before padding is added
144 | non_padded_ids = [i for i in input_ids]
145 |
146 | # non_padded_tokens: the tokens excluding padding
147 | non_padded_tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)
148 |
149 | attention_mask = [1]*len(input_ids)
150 |
151 | paddings = [tokenizer.pad_token_id]*(max_seq_length - len(input_ids))
152 |
153 | if tokenizer.padding_side == "right":
154 | input_ids += paddings
155 | attention_mask += [0]*len(paddings)
156 | token_type_ids += paddings
157 | position_ids += paddings
158 | else:
159 | input_ids = paddings + input_ids
160 | attention_mask = [0]*len(paddings) + attention_mask
161 | token_type_ids = paddings + token_type_ids
162 | position_ids = paddings + position_ids
163 |
164 | prem_word_idxs = [x+len(paddings) for x in prem_word_idxs]
165 | hypo_word_idxs = [x + len(paddings) for x in hypo_word_idxs]
166 |
167 | # """
168 | # mean pooling
169 | prem_not_word_list = []
170 | for k, p_idx in enumerate(prem_word_idxs[1:]):
171 | prem_not_word_idxs = [0] * len(input_ids)
172 | for j in range(prem_word_idxs[k] + 1, p_idx):
173 | prem_not_word_idxs[j] = 1 / (p_idx - prem_word_idxs[k] - 1)
174 | prem_not_word_list.append(prem_not_word_idxs)
175 | prem_not_word_list = prem_not_word_list + [[0] * len(input_ids)] * (
176 | prem_max_sentence_length - len(prem_not_word_list))
177 |
178 |
179 | hypo_not_word_list = []
180 | for k, h_idx in enumerate(hypo_word_idxs[1:]):
181 | hypo_not_word_idxs = [0] * len(input_ids)
182 | for j in range(hypo_word_idxs[k] + 1, h_idx):
183 | hypo_not_word_idxs[j] = 1 / (h_idx - hypo_word_idxs[k] - 1)
184 | hypo_not_word_list.append(hypo_not_word_idxs)
185 | hypo_not_word_list = hypo_not_word_list + [[0] * len(input_ids)] * (
186 | hypo_max_sentence_length - len(hypo_not_word_list))
187 | """
188 | # (a,b, |a-b|, a*b)
189 | prem_not_word_list = [[], []]
190 | for k, p_idx in enumerate(prem_word_idxs[1:]):
191 | prem_not_word_list[0].append(prem_word_idxs[k] + 1)
192 | prem_not_word_list[1].append(p_idx - 1)
193 | prem_not_word_list[0] = prem_not_word_list[0] + [int(hypo_word_idxs[-1]+i+2) for i in range(0, (prem_max_sentence_length - len(prem_not_word_list)))]
194 | prem_not_word_list[1] = prem_not_word_list[1] + [int(hypo_word_idxs[-1] + i + 2) for i in range(0, (prem_max_sentence_length - len(prem_not_word_list)))]
195 |
196 | hypo_not_word_list = [[],[]];
197 | for k, h_idx in enumerate(hypo_word_idxs[1:]):
198 | hypo_not_word_list[0].append(hypo_word_idxs[k]+1)
199 | hypo_not_word_list[1].append(h_idx+1)
200 | hypo_not_word_list[0] = hypo_not_word_list[0] + [int(hypo_word_idxs[-1]+i+2) for i in range(0,(hypo_max_sentence_length - len(hypo_not_word_list)))]
201 | hypo_not_word_list[1] = hypo_not_word_list[1] + [int(hypo_word_idxs[-1] + i + 2) for i in range(0, (hypo_max_sentence_length - len(hypo_not_word_list)))]
202 | """
203 |
204 | # p_mask: mask with 0 for tokens which belong to the premise or hypothesis (including the CLS token)
205 | # and with 1 otherwise.
206 | # The original TF implementation also keeps the classification token (set to 0).
207 | p_mask = np.ones_like(token_type_ids)
208 | if tokenizer.padding_side == "right":
209 | # [CLS] P [SEP] H [SEP] PADDING
210 | p_mask[:len(all_doc_tokens["premise"]) + len(all_doc_tokens["hypothesis"]) + 1] = 0
211 | else:
212 | p_mask[-(len(all_doc_tokens["premise"]) + len(all_doc_tokens["hypothesis"]) + 1): ] = 0
213 |
214 | # pad_token_indices: positions in input_ids that are padding
215 | pad_token_indices = np.array(range(len(non_padded_ids), len(input_ids)))
216 | # special_token_indices: positions of the special tokens
217 | special_token_indices = np.asarray(
218 | tokenizer.get_special_tokens_mask(input_ids, already_has_special_tokens=True)
219 | ).nonzero()
220 |
221 | p_mask[pad_token_indices] = 1
222 | p_mask[special_token_indices] = 1
223 |
224 | # Set the cls index to 0: the CLS index can be used for impossible answers
225 | # Identify the position of the CLS token
226 | cls_index = input_ids.index(tokenizer.cls_token_id)
227 |
228 | p_mask[cls_index] = 0
229 |
230 | # prem_dependency = [[premise_tail, premise_head, dependency], [], ...]
231 | # hypo_dependency = [[hypothesis_tail, hypothesis_head, dependency], [], ...]
232 | if example.dependency["premise"] == [[]]:
233 | example.dependency["premise"] = [[prem_max_sentence_length-1,prem_max_sentence_length-1,0] for _ in range(0,prem_max_sentence_length)]
234 | else:
235 | example.dependency["premise"] = example.dependency["premise"] + [[prem_max_sentence_length-1,prem_max_sentence_length-1,0] for i in range(0, abs(prem_max_sentence_length-len(example.dependency["premise"])))]
236 | if example.dependency["hypothesis"] == [[]]:
237 | example.dependency["hypothesis"] = [[hypo_max_sentence_length-1,hypo_max_sentence_length-1,0] for _ in range(0,hypo_max_sentence_length)]
238 | else:
239 | example.dependency["hypothesis"] = example.dependency["hypothesis"]+ [[hypo_max_sentence_length-1,hypo_max_sentence_length-1,0] for i in range(0, abs(hypo_max_sentence_length-len(example.dependency["hypothesis"])))]
240 |
241 | prem_dependency = example.dependency["premise"]
242 | hypo_dependency = example.dependency["hypothesis"]
243 |
244 | return KLUE_NLIFeatures(
245 | input_ids,
246 | attention_mask,
247 | token_type_ids,
248 | #position_ids,
249 | cls_index,
250 | p_mask.tolist(),
251 | example_index=0,
252 | tokens=non_padded_tokens,
253 | token_to_orig_map=token_to_orig_map,
254 | label = label,
255 | guid = example.guid,
256 | language = language,
257 | prem_dependency = prem_dependency,
258 | hypo_dependency=hypo_dependency,
259 | prem_not_word_list = prem_not_word_list,
260 | hypo_not_word_list = hypo_not_word_list,
261 | )
262 |
263 |
264 |
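# --- Illustrative sketch (not part of the original source) ---------------------
# The comments in klue_convert_example_to_features above describe how many special
# tokens the tokenizer adds around a sentence pair. A minimal, hypothetical helper
# that derives those counts for any HuggingFace tokenizer (transformers >= 4.x
# attribute names assumed):
def _sketch_special_token_counts(tokenizer):
    """Return (single-sentence, sentence-pair) special-token counts.

    BERT-style pairs '[CLS] P [SEP] H [SEP]' give a pair count of 3;
    RoBERTa-style pairs (P and H each wrapped in <s>/</s>) give 4.
    """
    single = tokenizer.model_max_length - tokenizer.max_len_single_sentence
    pair = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
    return single, pair
# --------------------------------------------------------------------------------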
265 | def klue_convert_example_to_features_init(tokenizer_for_convert):
266 | global tokenizer
267 | tokenizer = tokenizer_for_convert
268 |
269 |
270 | def klue_convert_examples_to_features(
271 | examples,
272 | tokenizer,
273 | max_seq_length,
274 | is_training,
275 | return_dataset=False,
276 | threads=1,
277 | prem_max_sentence_length = 0,
278 | hypo_max_sentence_length = 0,
279 | tqdm_enabled=True,
280 | language = None,
281 | ):
282 | """
283 | Converts a list of examples into a list of features that can be directly given as input to a model.
284 | It is model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
285 |
286 | Args:
287 | examples: list of :class:`KLUE_NLIExample`
288 | tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer`
289 | max_seq_length: The maximum sequence length of the inputs.
290 | prem_max_sentence_length: The maximum premise length in words (eojeol).
291 | hypo_max_sentence_length: The maximum hypothesis length in words (eojeol).
292 | is_training: whether to create features for model training or model evaluation.
293 | return_dataset: Default False. Either 'pt' or 'tf'.
294 | if 'pt': returns a torch.data.TensorDataset,
295 | if 'tf': returns a tf.data.Dataset
296 | threads: number of worker processes used for the conversion.
297 |
298 |
299 | Returns:
300 | list of :class:`KLUE_NLIFeatures`
301 |
302 | Example (a dataset-building sketch also follows this function)::
303 |
304 | processor = KLUE_NLIV1Processor()
305 | examples = processor.get_dev_examples(data_dir)
306 |
307 | features = klue_convert_examples_to_features(
308 | examples=examples,
309 | tokenizer=tokenizer,
310 | max_seq_length=args.max_seq_length,
311 | prem_max_sentence_length=prem_max_sentence_length,
312 | hypo_max_sentence_length=hypo_max_sentence_length,
313 | is_training=not evaluate,
314 | )
315 | """
316 |
317 | # Defining helper methods
318 | features = []
319 | threads = min(threads, cpu_count())
320 | with Pool(threads, initializer=klue_convert_example_to_features_init, initargs=(tokenizer,)) as p:
321 |
322 | # annotate_ = the features produced for a single example, collected into a list
323 | # annotate_ = list(feature1, feature2, ...)
324 | annotate_ = partial(
325 | klue_convert_example_to_features,
326 | max_seq_length=max_seq_length,
327 | prem_max_sentence_length=prem_max_sentence_length,
328 | hypo_max_sentence_length=hypo_max_sentence_length,
329 | is_training=is_training,
330 | language = language,
331 | )
332 |
333 | # apply annotate_ to every example
334 | # features = list( feature1, feature2, feature3, ... )
335 | ## len(features) == len(examples)
336 | features = list(
337 | tqdm(
338 | p.imap(annotate_, examples, chunksize=32),
339 | total=len(examples),
340 | desc="convert klue nli examples to features",
341 | disable=not tqdm_enabled,
342 | )
343 | )
344 | new_features = []
345 | example_index = 0 # id of the example ## len(features) == len(examples)
346 | for example_feature in tqdm(
347 | features, total=len(features), desc="add example index", disable=not tqdm_enabled
348 | ):
349 | if not example_feature:
350 | continue
351 |
352 | example_feature.example_index = example_index
353 | new_features.append(example_feature)
354 | example_index += 1
355 |
356 | features = new_features
357 | del new_features
358 |
359 | if return_dataset == "pt":
360 | if not is_torch_available():
361 | raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
362 |
363 | # Convert to Tensors and build dataset
364 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
365 | all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
366 |
367 | ## RoBERTa does not use token_type_ids, so there is no need to indicate which token belongs to which segment.
368 | if language == "electra":
369 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
370 | # all_position_ids = torch.tensor([f.#position_ids for f in features], dtype=torch.long)
371 |
372 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
373 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
374 |
375 | all_example_indices = torch.tensor([f.example_index for f in features], dtype=torch.long)
376 | all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long) # individual index of each feature
377 |
378 | # all_dependency = [[[premise_tail, premise_head, dependency], [], ...],[[hypothesis_tail, hypothesis_head, dependency], [], ...]], [[],[]], ... ]
379 | all_prem_dependency = torch.tensor([f.prem_dependency for f in features], dtype=torch.long)
380 | all_hypo_dependency = torch.tensor([f.hypo_dependency for f in features], dtype=torch.long)
381 |
382 | all_prem_not_word_list = torch.tensor([f.prem_not_word_list for f in features], dtype=torch.float)
383 | all_hypo_not_word_list = torch.tensor([f.hypo_not_word_list for f in features], dtype=torch.float)
384 |
385 | if not is_training:
386 |
387 | if language == "electra":
388 | dataset = TensorDataset(
389 | all_input_ids,
390 | all_attention_masks, all_token_type_ids, #all_position_ids,
391 | all_cls_index, all_p_mask, all_feature_index,
392 | all_prem_dependency, all_hypo_dependency,
393 | all_prem_not_word_list, all_hypo_not_word_list
394 | )
395 | else:
396 | dataset = TensorDataset(
397 | all_input_ids,
398 | all_attention_masks, # all_token_type_ids, all_position_ids,
399 | all_cls_index, all_p_mask, all_feature_index,
400 | all_prem_dependency, all_hypo_dependency,
401 | all_prem_not_word_list, all_hypo_not_word_list
402 | )
403 | else:
404 | all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
405 | # label_dict = {"entailment": 0, "contradiction": 1, "neutral": 2}
406 | # all_labels = torch.tensor([label_dict[f.label] for f in features], dtype=torch.long)
407 |
408 | if language == "electra":
409 | dataset = TensorDataset(
410 | all_input_ids,
411 | all_attention_masks,
412 | all_token_type_ids,
413 | #all_position_ids,
414 | all_labels,
415 | all_cls_index,
416 | all_p_mask,
417 | all_example_indices,
418 | all_feature_index,
419 | all_prem_dependency, all_hypo_dependency,
420 | all_prem_not_word_list, all_hypo_not_word_list
421 | )
422 | else:
423 | dataset = TensorDataset(
424 | all_input_ids,
425 | all_attention_masks,
426 | # all_token_type_ids,
427 | # all_position_ids,
428 | all_labels,
429 | all_cls_index,
430 | all_p_mask,
431 | all_example_indices,
432 | all_feature_index,
433 | all_prem_dependency, all_hypo_dependency,
434 | all_prem_not_word_list, all_hypo_not_word_list
435 | )
436 |
437 |
438 | return features, dataset
439 | else:
440 | return features
441 |
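# --- Illustrative sketch (not part of the original source) ---------------------
# How the conversion above is typically driven when a PyTorch TensorDataset is
# wanted. The tokenizer, examples and the two word-level maxima are assumptions
# supplied by the caller; the values shown are placeholders only.
def _sketch_build_dataset(examples, tokenizer):
    features, dataset = klue_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=512,
        is_training=True,
        return_dataset="pt",
        threads=4,
        prem_max_sentence_length=22,   # hypothetical word-level maximum (premise)
        hypo_max_sentence_length=21,   # hypothetical word-level maximum (hypothesis)
        language="roberta",
    )
    return features, dataset
# --------------------------------------------------------------------------------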
442 | class KLUE_NLIProcessor(DataProcessor):
443 | train_file = None
444 | dev_file = None
445 |
446 | def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
447 | if not evaluate:
448 | gold_label = None
449 | label = tensor_dict["gold_label"].numpy().decode("utf-8")
450 | else:
451 | gold_label = tensor_dict["gold_label"].numpy().decode("utf-8")
452 | label = None
453 |
454 | return KLUE_NLIExample(
455 | guid=tensor_dict["guid"].numpy().decode("utf-8"),
456 | genre=tensor_dict["genre"].numpy().decode("utf-8"),
457 | premise=tensor_dict["premise"]["origin"].numpy().decode("utf-8"),
458 | premise_koala=tensor_dict["premise"]["merge"]["koala"].numpy().decode("utf-8"),
459 | premise_merge=tensor_dict["premise"]["merge"]["origin"].numpy().decode("utf-8"),
460 | hypothesis=tensor_dict["hypothesis"]["origin"].numpy().decode("utf-8"),
461 | hypothesis_koala=tensor_dict["hypothesis"]["merge"]["koala"].numpy().decode("utf-8"),
462 | hypothesis_merge=tensor_dict["hypothesis"]["merge"]["origin"].numpy().decode("utf-8"),
463 | gold_label=gold_label,
464 | label=label,
465 | )
466 |
467 | def get_examples_from_dataset(self, dataset, evaluate=False):
468 | """
469 | Creates a list of :class:`KLUE_NLIExample` using a TFDS dataset.
470 |
471 | Args:
472 | dataset: The tfds dataset to read the examples from
473 | evaluate: boolean specifying if in evaluation mode or in training mode
474 |
475 | Returns:
476 | List of KLUE_NLIExample
477 |
478 | Examples::
479 |
480 | import tensorflow_datasets as tfds
481 | dataset = tfds.load("squad")
482 |
483 | training_examples = get_examples_from_dataset(dataset, evaluate=False)
484 | evaluation_examples = get_examples_from_dataset(dataset, evaluate=True)
485 | """
486 |
487 | if evaluate:
488 | dataset = dataset["validation"]
489 | else:
490 | dataset = dataset["train"]
491 |
492 | examples = []
493 | for tensor_dict in tqdm(dataset):
494 | examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
495 |
496 | return examples
497 |
498 | def get_train_examples(self, data_dir, filename=None, depend_embedding = None):
499 | """
500 | Returns the training examples from the data directory.
501 |
502 | Args:
503 | data_dir: Directory containing the data files used for training and evaluating.
504 | filename: None by default.
505 |
506 | """
507 | if data_dir is None:
508 | data_dir = ""
509 |
510 | if self.train_file is None:
511 | raise ValueError("KLUE_NLIProcessor should be instantiated via KLUE_NLIV1Processor.")
512 |
513 | with open(
514 | os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
515 | ) as reader:
516 | data_file = self.train_file if filename is None else filename
517 | if data_file.split(".")[-1] == "json":
518 | input_data = json.load(reader)
519 | elif data_file.split(".")[-1] == "jsonl":
520 | input_data = []
521 | for line in reader:
522 | data = json.loads(line)
523 | input_data.append(data)
524 | return self._create_examples(input_data, 'train', self.train_file if filename is None else filename)
525 |
526 | def get_dev_examples(self, data_dir, filename=None, depend_embedding = None):
527 | """
528 | Returns the evaluation example from the data directory.
529 |
530 | Args:
531 | data_dir: Directory containing the data files used for training and evaluating.
532 | filename: None by default.
533 | """
534 | if data_dir is None:
535 | data_dir = ""
536 |
537 | if self.dev_file is None:
538 | raise ValueError("KLUE_NLIProcessor should be instantiated via KLUE_NLIV1Processor.")
539 |
540 | with open(
541 | os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
542 | ) as reader:
543 | data_file = self.dev_file if filename is None else filename
544 | if data_file.split(".")[-1] == "json":
545 | input_data = json.load(reader)
546 | elif data_file.split(".")[-1] == "jsonl":
547 | input_data = []
548 | for line in reader:
549 | data = json.loads(line)
550 | input_data.append(data)
551 | return self._create_examples(input_data, "dev", self.dev_file if filename is None else filename)
552 |
553 | def get_example_from_input(self, input_dictionary):
554 | # guid, genre, premise, hypothesis
555 | guid=input_dictionary["guid"]
556 | premise=input_dictionary["premise"]
557 | hypothesis=input_dictionary["hypothesis"]
558 | gold_label=None
559 | label = None
560 |
561 | examples = [KLUE_NLIExample(
562 | guid=guid,
563 | genre="",
564 | premise=premise,
565 | hypothesis=hypothesis,
566 | gold_label=gold_label,
567 | label=label,
568 | )]
569 | return examples
570 |
571 | def _create_examples(self, input_data, set_type, data_file):
572 | is_training = set_type == "train"
573 | num = 0
574 | examples = []
575 | parsing = data_file.split("/")[-1].split("_")[0]
576 | if parsing == "merge": parsing = "koala"
577 | for entry in tqdm(input_data):
578 | guid = entry["guid"]
579 | if "genre" in entry: genre = entry["genre"]
580 | elif "source" in entry: genre = entry["source"]
581 | else: genre = entry["guid"].split("_")[0]
582 | premise = entry["premise"]["origin"]
583 | premise_merge = entry["premise"]["merge"]["origin"]
584 | premise_koala = entry["premise"]["merge"][parsing]
585 |
586 | hypothesis = entry["hypothesis"]["origin"]
587 | hypothesis_merge = entry["hypothesis"]["merge"]["origin"]
588 | hypothesis_koala = entry["hypothesis"]["merge"][parsing]
589 |
590 | gold_label = None
591 | label = None
592 |
593 | if is_training:
594 | label = entry["gold_label"]
595 | else:
596 | gold_label = entry["gold_label"]
597 |
598 | example = KLUE_NLIExample(
599 | guid=guid,
600 | genre=genre,
601 | premise=premise,
602 | premise_koala=premise_koala,
603 | premise_merge=premise_merge,
604 | hypothesis=hypothesis,
605 | hypothesis_koala=hypothesis_koala,
606 | hypothesis_merge=hypothesis_merge,
607 | gold_label=gold_label,
608 | label=label,
609 | )
610 | examples.append(example)
611 | # len(examples) == len(input_data)
612 | return examples
613 |
614 |
615 | class KLUE_NLIV1Processor(KLUE_NLIProcessor):
616 | train_file = "klue-nli-v1_train.json"
617 | dev_file = "klue-nli-v1_dev.json"
618 |
619 |
620 | class KLUE_NLIExample(object):
621 | def __init__(
622 | self,
623 | guid,
624 | genre,
625 | premise,
626 | premise_koala,
627 | premise_merge,
628 | hypothesis,
629 | hypothesis_koala,
630 | hypothesis_merge,
631 | gold_label=None,
632 | label=None,
633 | ):
634 | self.guid = guid
635 | self.genre = genre
636 | self.premise = premise
637 | self.premise_merge = premise_merge
638 | self.premise_koala = premise_koala
639 | self.hypothesis = hypothesis
640 | self.hypothesis_koala = hypothesis_koala
641 | self.hypothesis_merge = hypothesis_merge
642 |
643 | label_dict = {"entailment": 0, "contradiction": 1, "neutral": 2}
644 | if gold_label in label_dict.keys():
645 | gold_label = label_dict[gold_label]
646 | self.gold_label = gold_label
647 |
648 | if label in label_dict.keys():
649 | label = label_dict[label]
650 | self.label = label
651 |
652 | # doc_tokens: list of words (eojeol) obtained by splitting on whitespace
653 | ## sentence1 sentence2
654 | self.doc_tokens = {"premise":self.premise.strip().split(), "hypothesis":self.hypothesis.strip().split()}
655 |
656 | # merge: list of the start positions of the merged chunks, in word (eojeol) units
657 | merge_index = [[],[]]
658 | merge_word = [[],[]]
659 | for merge in self.premise_merge:
660 | if merge[1] != []: merge_index[0].append(merge[1])
661 | for merge in self.hypothesis_merge:
662 | if merge[1] != []: merge_index[1].append(merge[1])
663 |
664 | # dependency label inventory (phrase tag - function tag); see the sketch after this class
665 | depend2idx = {"None":0}; idx2depend ={0:"None"}
666 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']:
667 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]:
668 | depend2idx[depend1 + "-" + depend2] = len(depend2idx)
669 | idx2depend[len(idx2depend)] = depend1 + "-" + depend2
670 |
671 | if ([words for words in self.premise_koala if words[2][0] != words[2][1]] == []): merge_word[0].append([])
672 | else:
673 | for words in self.premise_koala:
674 | if words[2][0] != words[2][1]:
675 | if not [merge_index[0].index(words[2][0]), merge_index[0].index(words[2][1])] in [merge_w[:-1] for merge_w in merge_word[0]]:
676 | merge_word[0].append([merge_index[0].index(words[2][0]), merge_index[0].index(words[2][1]),depend2idx[words[1][0]+"-"+words[1][1]]])
677 | else:
678 | idx = [merge_w[:-1] for merge_w in merge_word[0]].index([merge_index[0].index(words[2][0]), merge_index[0].index(words[2][1])])
679 | tag = idx2depend[[merge_w[-1] for merge_w in merge_word[0]][idx]].split("-")
680 | if (words[1][1] in ['SBJ', 'CNJ', 'OBJ']) and(tag[1] in ['CMP', 'MOD', 'AJT', 'None', "UNDEF"]):
681 | merge_word[0][idx][2] = depend2idx[tag[0]+"-"+words[1][1]]
682 |
683 |
684 | if ([words for words in self.hypothesis_koala if words[2][0] != words[2][1]] == []): merge_word[1].append([])
685 | else:
686 | for words in self.hypothesis_koala:
687 | if words[2][0] != words[2][1]:
688 | if not [merge_index[1].index(words[2][0]), merge_index[1].index(words[2][1])] in [merge_w[:-1] for merge_w in merge_word[1]]:
689 | merge_word[1].append([merge_index[1].index(words[2][0]), merge_index[1].index(words[2][1]),depend2idx[words[1][0]+"-"+words[1][1]]])
690 | else:
691 | idx = [merge_w[:-1] for merge_w in merge_word[1]].index([merge_index[1].index(words[2][0]), merge_index[1].index(words[2][1])])
692 | tag = idx2depend[[merge_w[-1] for merge_w in merge_word[1]][idx]].split("-")
693 | if (words[1][1] in ['SBJ', 'CNJ', 'OBJ']) and(tag[1] in ['CMP', 'MOD', 'AJT', 'None', "UNDEF"]):
694 | merge_word[1][idx][2] = depend2idx[tag[0]+"-"+words[1][1]]
695 |
696 | self.merge = {"premise":merge_index[0], "hypothesis":merge_index[1]}
697 | self.dependency = {"premise":merge_word[0], "hypothesis":merge_word[1]}
698 |
699 |
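# --- Illustrative sketch (not part of the original source) ---------------------
# KLUE_NLIExample above builds its dependency-label vocabulary from 10 phrase tags
# x 8 function tags plus a "None" entry (81 indices in total) and maps the three
# NLI labels to integers. A minimal reconstruction of both tables:
def _sketch_label_tables():
    label_dict = {"entailment": 0, "contradiction": 1, "neutral": 2}
    depend2idx = {"None": 0}
    for d1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']:
        for d2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', 'UNDEF']:
            depend2idx[d1 + "-" + d2] = len(depend2idx)
    assert len(depend2idx) == 10 * 8 + 1   # 81 labels including "None"
    return label_dict, depend2idx
# --------------------------------------------------------------------------------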
700 | class KLUE_NLIFeatures(object):
701 | def __init__(
702 | self,
703 | input_ids,
704 | attention_mask,
705 | token_type_ids,
706 | #position_ids,
707 | cls_index,
708 | p_mask,
709 | example_index,
710 | token_to_orig_map,
711 | guid,
712 | tokens,
713 | label,
714 | language,
715 | prem_dependency,
716 | hypo_dependency,
717 | prem_not_word_list,
718 | hypo_not_word_list,
719 | ):
720 | self.input_ids = input_ids
721 | self.attention_mask = attention_mask
722 | if language == "electra": self.token_type_ids = token_type_ids
723 | #self.position_ids = #position_ids
724 | self.cls_index = cls_index
725 | self.p_mask = p_mask
726 |
727 | self.example_index = example_index
728 | self.token_to_orig_map = token_to_orig_map
729 | self.guid = guid
730 | self.tokens = tokens
731 |
732 | self.label = label
733 |
734 | self.prem_dependency = prem_dependency
735 | self.hypo_dependency = hypo_dependency
736 |
737 | self.prem_not_word_list = prem_not_word_list,  # the trailing comma wraps this in a 1-tuple; the extra dimension is removed by squeeze(1) in SICModel1
738 | self.hypo_not_word_list = hypo_not_word_list,  # same as above
739 |
740 |
741 | class KLUEResult(object):
742 | def __init__(self, example_index, label_logits, gold_label=None, cls_logits=None):
743 | self.label_logits = label_logits
744 | self.example_index = example_index
745 |
746 | if gold_label is not None:
747 | self.gold_label = gold_label
748 | self.cls_logits = cls_logits
--------------------------------------------------------------------------------
/src/model/model.py:
--------------------------------------------------------------------------------
1 | # model += Parsing Info Collecting Layer (PIC)
2 |
3 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
4 | import torch.nn as nn
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | #from transformers.modeling_electra import ElectraModel, ElectraPreTrainedModel
9 |
10 | from transformers import ElectraModel, RobertaModel
11 |
12 | import transformers
13 | if int(transformers.__version__[0]) <= 3:
14 | from transformers.modeling_roberta import RobertaPreTrainedModel
15 | from transformers.modeling_bert import BertPreTrainedModel
16 | from transformers.modeling_electra import ElectraModel, ElectraPreTrainedModel
17 | else:
18 | from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
19 | from transformers.models.bert.modeling_bert import BertPreTrainedModel
20 | from transformers.models.electra.modeling_electra import ElectraPreTrainedModel
21 |
22 | from src.functions.biattention import BiAttention, BiLinear
23 |
24 | class RobertaForSequenceClassification(BertPreTrainedModel):
25 |
26 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length):
27 | super().__init__(config)
28 | self.num_labels = config.num_labels
29 | self.config = config
30 | self.roberta = RobertaModel(config)
31 |
32 | # given two input tokens token1 and token2, treat (index of token1, index of token2) as a span and learn information about it
33 | self.span_info_collect = SICModel1(config)
34 | #self.span_info_collect = SICModel2(config)
35 |
36 | # combine the premise and hypothesis span information through biaffine attention, then normalize (a usage sketch follows this class)
37 | self.parsing_info_collect = PICModel1(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info + klue-biaffine attention + BiLSTM + klue-bilinear classification
38 | #self.parsing_info_collect = PICModel2(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + BiLSTM + klue-bilinear classification
39 | #self.parsing_info_collect = PICModel3(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info + BiLSTM + klue-bilinear classification
40 | #self.parsing_info_collect = PICModel4(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info (1) + BiLSTM + bilinear classification
41 | #self.parsing_info_collect = PICModel5(config, prem_max_sentence_length, hypo_max_sentence_length) # chunking + tag info + biaffine attention + BiLSTM + bilinear classification
42 |
43 | self.init_weights()
44 |
45 | def forward(
46 | self,
47 | input_ids=None,
48 | attention_mask=None,
49 | token_type_ids=None,
50 | position_ids=None,
51 | head_mask=None,
52 | inputs_embeds=None,
53 | labels=None,
54 | prem_span=None,
55 | hypo_span=None,
56 | prem_word_idxs=None,
57 | hypo_word_idxs=None,
58 | ):
59 | batch_size = input_ids.shape[0]
60 | discriminator_hidden_states = self.roberta(
61 | input_ids=input_ids,
62 | attention_mask=attention_mask,
63 | )
64 | # last-layer hidden state
65 | # sequence_output: [batch_size, seq_length, hidden_size]
66 | sequence_output = discriminator_hidden_states[0]
67 |
68 | # span info collecting layer(SIC)
69 | h_ij = self.span_info_collect(sequence_output, prem_word_idxs, hypo_word_idxs)
70 |
71 | # parser info collecting layer(PIC)
72 | logits = self.parsing_info_collect(h_ij,
73 | batch_size= batch_size,
74 | prem_span=prem_span,hypo_span=hypo_span,)
75 |
76 | outputs = (logits, ) + discriminator_hidden_states[2:]
77 |
78 | if labels is not None:
79 | if self.num_labels == 1:
80 | # We are doing regression
81 | loss_fct = MSELoss()
82 | loss = loss_fct(logits.view(-1), labels.view(-1))
83 | else:
84 | loss_fct = CrossEntropyLoss()
85 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
86 | print("loss: "+str(loss))
87 | outputs = (loss,) + outputs
88 |
89 | return outputs # (loss), logits, (hidden_states), (attentions)
90 |
91 |
92 |
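# --- Illustrative sketch (not part of the original source) ---------------------
# RobertaForSequenceClassification.forward above returns a plain tuple:
# (loss, logits, ...) when labels are given and (logits, ...) otherwise. A
# hypothetical training step showing how a caller unpacks it (model, batch_inputs,
# labels and optimizer are assumptions supplied elsewhere):
def _sketch_training_step(model, batch_inputs, labels, optimizer):
    outputs = model(**batch_inputs, labels=labels)
    loss, logits = outputs[0], outputs[1]
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item(), logits
# --------------------------------------------------------------------------------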
93 | class SICModel1(nn.Module):
94 | def __init__(self, config):
95 | super().__init__()
96 |
97 | def forward(self, hidden_states, prem_word_idxs, hypo_word_idxs):
98 | # (batch, max_pre_sen, seq_len) @ (batch, seq_len, hidden) = (batch, max_pre_sen, hidden)
99 | prem_word_idxs = prem_word_idxs.squeeze(1)
100 | hypo_word_idxs = hypo_word_idxs.squeeze(1)
101 |
102 | prem = torch.matmul(prem_word_idxs, hidden_states)
103 | hypo = torch.matmul(hypo_word_idxs, hidden_states)
104 |
105 | return [prem, hypo]
106 |
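# --- Illustrative sketch (not part of the original source) ---------------------
# SICModel1 mean-pools subword vectors into word (eojeol) vectors with one batched
# matmul: the weight rows (prem/hypo_not_word_list built in processor.py) hold 1/n
# for the n subword tokens of a word and 0 elsewhere. Shape check on random
# tensors; torch is imported at the top of this module, the sizes are assumptions:
def _sketch_word_mean_pooling(batch=2, max_prem=22, seq_len=512, hidden=768):
    weights = torch.rand(batch, max_prem, seq_len)       # (batch, max_prem, seq_len)
    hidden_states = torch.rand(batch, seq_len, hidden)   # (batch, seq_len, hidden)
    word_vectors = torch.matmul(weights, hidden_states)  # (batch, max_prem, hidden)
    return word_vectors.shape                            # torch.Size([2, 22, 768])
# --------------------------------------------------------------------------------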
107 | class SICModel2(nn.Module):
108 | def __init__(self, config):
109 | super().__init__()
110 | self.hidden_size = config.hidden_size
111 |
112 | self.W_p_1 = nn.Linear(self.hidden_size, self.hidden_size)
113 | self.W_p_2 = nn.Linear(self.hidden_size, self.hidden_size)
114 | self.W_p_3 = nn.Linear(self.hidden_size, self.hidden_size)
115 | self.W_p_4 = nn.Linear(self.hidden_size, self.hidden_size)
116 |
117 | self.W_h_1 = nn.Linear(self.hidden_size, self.hidden_size)
118 | self.W_h_2 = nn.Linear(self.hidden_size, self.hidden_size)
119 | self.W_h_3 = nn.Linear(self.hidden_size, self.hidden_size)
120 | self.W_h_4 = nn.Linear(self.hidden_size, self.hidden_size)
121 |
122 | def forward(self, hidden_states, prem_word_idxs, hypo_word_idxs):
123 | prem_word_idxs = prem_word_idxs.squeeze(1).type(torch.LongTensor).to("cuda")
124 | hypo_word_idxs = hypo_word_idxs.squeeze(1).type(torch.LongTensor).to("cuda")
125 |
126 | Wp1_h = self.W_p_1(hidden_states) # (bs, length, hidden_size)
127 | Wp2_h = self.W_p_2(hidden_states)
128 | Wp3_h = self.W_p_3(hidden_states)
129 | Wp4_h = self.W_p_4(hidden_states)
130 |
131 | Wh1_h = self.W_h_1(hidden_states) # (bs, length, hidden_size)
132 | Wh2_h = self.W_h_2(hidden_states)
133 | Wh3_h = self.W_h_3(hidden_states)
134 | Wh4_h = self.W_h_4(hidden_states)
135 |
136 | W1_hi_emb=torch.tensor([], dtype=torch.long).to("cuda")
137 | W2_hi_emb=torch.tensor([], dtype=torch.long).to("cuda")
138 | W3_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda")
139 | W3_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda")
140 | W4_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda")
141 | W4_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda")
142 | for i in range(0, hidden_states.shape[0]):
143 | sub_W1_hi_emb = torch.index_select(Wp1_h[i], 0, prem_word_idxs[i][0]) # (prem_max_seq_length, hidden_size)
144 | sub_W2_hi_emb = torch.index_select(Wp2_h[i], 0, prem_word_idxs[i][1])
145 | sub_W3_hi_start_emb = torch.index_select(Wp3_h[i], 0, prem_word_idxs[i][0])
146 | sub_W3_hi_end_emb = torch.index_select(Wp3_h[i], 0, prem_word_idxs[i][1])
147 | sub_W4_hi_start_emb = torch.index_select(Wp4_h[i], 0, prem_word_idxs[i][0])
148 | sub_W4_hi_end_emb = torch.index_select(Wp4_h[i], 0, prem_word_idxs[i][1])
149 |
150 | W1_hi_emb = torch.cat((W1_hi_emb, sub_W1_hi_emb.unsqueeze(0)))
151 | W2_hi_emb = torch.cat((W2_hi_emb, sub_W2_hi_emb.unsqueeze(0)))
152 | W3_hi_start_emb = torch.cat((W3_hi_start_emb, sub_W3_hi_start_emb.unsqueeze(0)))
153 | W3_hi_end_emb = torch.cat((W3_hi_end_emb, sub_W3_hi_end_emb.unsqueeze(0)))
154 | W4_hi_start_emb = torch.cat((W4_hi_start_emb, sub_W4_hi_start_emb.unsqueeze(0)))
155 | W4_hi_end_emb = torch.cat((W4_hi_end_emb, sub_W4_hi_end_emb.unsqueeze(0)))
156 |
157 | # [w1*hi, w2*hj, w3(hi-hj), w4(hi⊗hj)]
158 | prem_span = W1_hi_emb + W2_hi_emb + (W3_hi_start_emb - W3_hi_end_emb) + torch.mul(W4_hi_start_emb, W4_hi_end_emb) # (batch_size, prem_max_seq_length, hidden_size)
159 | prem_h_ij = torch.tanh(prem_span)
160 |
161 | W1_hi_emb = torch.tensor([], dtype=torch.long).to("cuda")
162 | W2_hi_emb = torch.tensor([], dtype=torch.long).to("cuda")
163 | W3_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda")
164 | W3_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda")
165 | W4_hi_start_emb = torch.tensor([], dtype=torch.long).to("cuda")
166 | W4_hi_end_emb = torch.tensor([], dtype=torch.long).to("cuda")
167 | for i in range(0, hidden_states.shape[0]):
168 | sub_W1_hi_emb = torch.index_select(Wh1_h[i], 0, hypo_word_idxs[i][0]) # (hypo_max_seq_length, hidden_size)
169 | sub_W2_hi_emb = torch.index_select(Wh2_h[i], 0, hypo_word_idxs[i][1])
170 | sub_W3_hi_start_emb = torch.index_select(Wh3_h[i], 0, hypo_word_idxs[i][0])
171 | sub_W3_hi_end_emb = torch.index_select(Wh3_h[i], 0, hypo_word_idxs[i][1])
172 | sub_W4_hi_start_emb = torch.index_select(Wh4_h[i], 0, hypo_word_idxs[i][0])
173 | sub_W4_hi_end_emb = torch.index_select(Wh4_h[i], 0, hypo_word_idxs[i][1])
174 |
175 | W1_hi_emb = torch.cat((W1_hi_emb, sub_W1_hi_emb.unsqueeze(0)))
176 | W2_hi_emb = torch.cat((W2_hi_emb, sub_W2_hi_emb.unsqueeze(0)))
177 | W3_hi_start_emb = torch.cat((W3_hi_start_emb, sub_W3_hi_start_emb.unsqueeze(0)))
178 | W3_hi_end_emb = torch.cat((W3_hi_end_emb, sub_W3_hi_end_emb.unsqueeze(0)))
179 | W4_hi_start_emb = torch.cat((W4_hi_start_emb, sub_W4_hi_start_emb.unsqueeze(0)))
180 | W4_hi_end_emb = torch.cat((W4_hi_end_emb, sub_W4_hi_end_emb.unsqueeze(0)))
181 |
182 | # [w1*hi, w2*hj, w3(hi-hj), w4(hi⊗hj)]
183 | hypo_span = W1_hi_emb + W2_hi_emb + (W3_hi_start_emb - W3_hi_end_emb) + torch.mul(W4_hi_start_emb, W4_hi_end_emb) # (batch_size, hypo_max_seq_length, hidden_size)
184 | hypo_h_ij = torch.tanh(hypo_span)
185 |
186 | h_ij = [prem_h_ij, hypo_h_ij]
187 |
188 | return h_ij
189 |
190 |
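# --- Illustrative sketch (not part of the original source) ---------------------
# SICModel2 above composes each span from its start/end vectors h_i, h_j as
# tanh(W1 h_i + W2 h_j + (W3 h_i - W3 h_j) + (W4 h_i) * (W4 h_j)). The same
# composition for a single pair of vectors, as a hypothetical standalone module
# (nn is imported at the top of this file):
class _SketchSpanComposer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = nn.Linear(hidden_size, hidden_size)
        self.w2 = nn.Linear(hidden_size, hidden_size)
        self.w3 = nn.Linear(hidden_size, hidden_size)
        self.w4 = nn.Linear(hidden_size, hidden_size)

    def forward(self, h_start, h_end):
        span = (self.w1(h_start) + self.w2(h_end)
                + (self.w3(h_start) - self.w3(h_end))
                + self.w4(h_start) * self.w4(h_end))
        return torch.tanh(span)
# --------------------------------------------------------------------------------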
191 | class PICModel1(nn.Module):
192 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length):
193 | super().__init__()
194 | self.hidden_size = config.hidden_size
195 | self.prem_max_sentence_length = prem_max_sentence_length
196 | self.hypo_max_sentence_length = hypo_max_sentence_length
197 | self.num_labels = config.num_labels
198 |
199 | # dependency label inventory (phrase tag - function tag)
200 | depend2idx = {"None": 0}
201 | idx2depend = {0: "None"}
202 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']:
203 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]:
204 | depend2idx[depend1 + "-" + depend2] = len(depend2idx)
205 | idx2depend[len(idx2depend)] = depend1 + "-" + depend2
206 | self.depend2idx = depend2idx
207 | self.idx2depend = idx2depend
208 | self.depend_embedding = nn.Embedding(len(idx2depend), self.hidden_size, padding_idx=0).to("cuda")
209 |
210 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
211 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
212 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 3))
213 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 3))
214 |
215 | self.biaffine1 = BiAttention(int(self.hidden_size // 3), int(self.hidden_size // 3), 100)
216 | self.biaffine2 = BiAttention(int(self.hidden_size // 3), int(self.hidden_size // 3), 100)
217 |
218 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
219 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
220 |
221 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels)
222 |
223 | # self.W_1_bilinear = nn.Bilinear(int(self.hidden_size // 3), int(self.hidden_size // 3), self.hidden_size, bias=False)
224 | # self.W_1_linear = nn.Linear(int(self.hidden_size // 3) + int(self.hidden_size // 3), self.hidden_size)
225 | # self.W_2_bilinear = nn.Bilinear(int(self.hidden_size // 3), int(self.hidden_size // 3), self.hidden_size, bias=False)
226 | # self.W_2_linear = nn.Linear(int(self.hidden_size // 3) + int(self.hidden_size // 3), self.hidden_size)
227 | #
228 | # self.bi_lism_1 = nn.LSTM(input_size=self.hidden_size, hidden_size=self.prem_max_sentence_length//2, num_layers=1, bidirectional=True)
229 | # self.bi_lism_2 = nn.LSTM(input_size=self.hidden_size, hidden_size=self.hypo_max_sentence_length//2, num_layers=1, bidirectional=True)
230 | #
231 | # self.dropout = nn.Dropout(config.hidden_dropout_prob) # dropout for better generalization
232 | # self.biaffine_W_bilinear = nn.Linear((2*(self.prem_max_sentence_length//2))*(2*(self.hypo_max_sentence_length//2)), self.num_labels, bias=False)
233 | # self.biaffine_W_linear = nn.Linear(2*(self.prem_max_sentence_length//2) + 2*(self.hypo_max_sentence_length//2), self.num_labels)
234 | # #self.biaffine_W_bilinear = nn.Bilinear(self.hidden_size, self.hidden_size, self.num_labels, bias=False)
235 | # #self.biaffine_W_linear = nn.Linear(self.hidden_size*2, self.num_labels)
236 | #
237 | # self.reset_parameters()
238 |
239 | def forward(self, hidden_states, batch_size, prem_span, hypo_span):
240 | # hidden_states: [[batch_size, word_idxs, hidden_size], []]
241 | # span: [batch_size, max_sentence_length, max_sentence_length]
242 | # word_idxs: [batch_size, seq_length]
243 | # -> sequence_outputs: [batch_size, seq_length, hidden_size]
244 |
245 | prem_hidden_states= hidden_states[0]
246 | hypo_hidden_states= hidden_states[1]
247 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape)
248 |
249 | # span: (batch, max_prem_len, 3) -> (batch, max_prem_len, 3*hidden_size)
250 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda")
251 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda")
252 |
253 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())):
254 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len)
255 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda")
256 | p_span_dep = torch.tensor([span[2] for span in p_span]).to("cuda")
257 |
258 | p_span_head = torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size)
259 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail)
260 | p_span_dep = self.depend_embedding(p_span_dep)
261 |
262 | n_p_span = p_span_head + p_span_tail + p_span_dep
263 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0)))
264 |
265 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len)
266 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda")
267 | h_span_dep = torch.tensor([span[2] for span in h_span]).to("cuda")
268 |
269 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size)
270 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail)
271 | h_span_dep = self.depend_embedding(h_span_dep)
272 |
273 | n_h_span = h_span_head + h_span_tail + h_span_dep
274 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0)))
275 |
276 | prem_span = new_prem_span
277 | hypo_span = new_hypo_span
278 |
279 | del new_prem_span
280 | del new_hypo_span
281 |
282 | # biaffine attention
283 | # hidden_states: (batch_size, max_prem_len, hidden_size)
284 | # span: (batch, max_prem_len, hidden_size)
285 | # -> biaffine_outputs: [batch_size, 100, max_prem_len, max_prem_len]
286 | prem_span = self.reduction1(prem_span)
287 | prem_hidden_states = self.reduction2(prem_hidden_states)
288 | hypo_span = self.reduction3(hypo_span)
289 | hypo_hidden_states = self.reduction4(hypo_hidden_states)
290 |
291 | prem_biaffine_outputs= self.biaffine1(prem_hidden_states, prem_span)
292 | hypo_biaffine_outputs = self.biaffine2(hypo_hidden_states, hypo_span)
293 |
294 | # outputs = self.bilinear(prem_biaffine_outputs.view(-1,self.prem_max_sentence_length*self.prem_max_sentence_length),
295 | # hypo_biaffine_outputs.view(-1,self.hypo_max_sentence_length*self.hypo_max_sentence_length))
296 |
297 | # bilstm
298 | # biaffine_outputs: [batch_size, 100, max_prem_len, max_prem_len] -> [batch_size, 100, max_prem_len] -> [max_prem_len, batch_size, 100]
299 | # -> hidden_states: [batch_size, max_sentence_length]
300 | prem_biaffine_outputs = prem_biaffine_outputs.mean(-1)
301 | hypo_biaffine_outputs = hypo_biaffine_outputs.mean(-1)
302 |
303 | prem_biaffine_outputs = prem_biaffine_outputs.transpose(1,2).transpose(0,1)
304 | hypo_biaffine_outputs = hypo_biaffine_outputs.transpose(1,2).transpose(0,1)
305 |
306 | prem_states = None
307 | hypo_states = None
308 |
309 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_biaffine_outputs)
310 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_biaffine_outputs)
311 |
312 |
313 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
314 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
315 |
316 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states)
317 |
318 | # new_prem_hidden_states = prem_hidden_states.view(-1, self.prem_max_sentence_length, 1, self.hidden_size) # (batch_size, max_prem_len, 1, hidden_size)
319 | # new_hypo_hidden_states = hypo_hidden_states.view(-1, self.hypo_max_sentence_length, 1, self.hidden_size)
320 | # new_prem_span = prem_span.view(-1, self.prem_max_sentence_length, 3, self.hidden_size)# (batch_size, max_prem_len, 3, hidden_size)
321 | # new_hypo_span = hypo_span.view(-1, self.hypo_max_sentence_length, 3, self.hidden_size)
322 | #
323 | # prem_depend = (new_prem_hidden_states.unsqueeze(-1) * new_prem_span.unsqueeze(-2)).view(-1, self.prem_max_sentence_length, 3*self.hidden_size, self.hidden_size) # (batch_size, max_prem_len, 3*hidden_size, hidden_size)
324 | # prem_depend = self.reduction1(prem_depend.transpose(2,3)).view(-1, self.prem_max_sentence_length, int(3 * self.hidden_size // 64)*self.hidden_size) # (batch_size, max_prem_len, hidden_size, 3*hidden_size) -> (batch_size, max_prem_len, hidden_size, int(3 * self.hidden_size // 64))
325 | #
326 | # hypo_depend = (new_hypo_hidden_states.unsqueeze(-1) * new_hypo_span.unsqueeze(-2)).view(-1, self.hypo_max_sentence_length, 3*self.hidden_size, self.hidden_size)
327 | # hypo_depend = self.reduction2(hypo_depend.transpose(2,3)).view(-1, self.hypo_max_sentence_length, int(3 * self.hidden_size // 64)*self.hidden_size)
328 | #
329 | # prem_biaffine_outputs= self.W_1_bilinear(prem_depend) + self.W_1_linear(torch.cat((prem_span, prem_hidden_states), dim = -1))
330 | # hypo_biaffine_outputs = self.W_2_bilinear(hypo_depend) + self.W_2_linear(torch.cat((hypo_span, hypo_hidden_states), dim = -1))
331 | #
332 | # # bilstm
333 | # # biaffine_outputs: [batch_size, max_sentence_length, hidden_size]
334 | # # -> hidden_states: [batch_size, max_sentence_length]
335 | # prem_biaffine_outputs = prem_biaffine_outputs.transpose(0,1)
336 | # hypo_biaffine_outputs = hypo_biaffine_outputs.transpose(0,1)
337 | #
338 | # prem_states = None
339 | # hypo_states = None
340 | #
341 | # prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_biaffine_outputs)
342 | # hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_biaffine_outputs)
343 | #
344 | # prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
345 | # hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
346 | # # biaffine attention
347 | # # prem_hidden_states: (batch_size, max_prem_len)
348 | # # hypo_hidden_states: (batch_size, max_hypo_len)
349 | #
350 | # prem_hypo = (prem_hidden_states.unsqueeze(-1) * hypo_hidden_states.unsqueeze(-2)).view(-1,
351 | # (2*(self.prem_max_sentence_length//2))*(2*(self.hypo_max_sentence_length//2)))
352 | #
353 | # outputs = self.biaffine_W_bilinear(prem_hypo) + self.biaffine_W_linear(torch.cat((prem_hidden_states, hypo_hidden_states), dim=-1))
354 |
355 | return outputs
356 |
357 | def reset_parameters(self):
358 | self.W_1_bilinear.reset_parameters()
359 | self.W_1_linear.reset_parameters()
360 | self.W_2_bilinear.reset_parameters()
361 | self.W_2_linear.reset_parameters()
362 |
363 | self.biaffine_W_bilinear.reset_parameters()
364 | self.biaffine_W_linear.reset_parameters()
365 |
366 |
367 |
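# --- Illustrative sketch (not part of the original source) ---------------------
# PICModel1 above turns every dependency arc (word index, word index, label index)
# into a vector by summing the two word vectors with a learned label embedding,
# before the biaffine attention and BiLSTM. A single-sentence, hypothetical version
# of that arc-composition step (word_vectors, arcs and depend_embedding are
# assumptions supplied by the caller):
def _sketch_compose_arcs(word_vectors, arcs, depend_embedding):
    # word_vectors: (max_len, hidden); arcs: (max_len, 3) long tensor;
    # depend_embedding: nn.Embedding over the 81 dependency labels.
    first = torch.index_select(word_vectors, 0, arcs[:, 0])
    second = torch.index_select(word_vectors, 0, arcs[:, 1])
    labels = depend_embedding(arcs[:, 2])
    return first + second + labels      # (max_len, hidden)
# --------------------------------------------------------------------------------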
368 | class PICModel2(nn.Module):
369 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length):
370 | super().__init__()
371 | self.hidden_size = config.hidden_size
372 | self.prem_max_sentence_length = prem_max_sentence_length
373 | self.hypo_max_sentence_length = hypo_max_sentence_length
374 | self.num_labels = config.num_labels
375 |
376 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
377 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
378 |
379 | self.bi_lism_1 = nn.LSTM(input_size=int(self.hidden_size // 3), hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
380 | self.bi_lism_2 = nn.LSTM(input_size=int(self.hidden_size // 3), hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
381 |
382 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels)
383 |
384 | def forward(self, hidden_states, batch_size, prem_span, hypo_span):
385 | # hidden_states: [[batch_size, word_idxs, hidden_size], []]
386 | # span: [batch_size, max_sentence_length, max_sentence_length]
387 | # word_idxs: [batch_size, seq_length]
388 | # -> sequence_outputs: [batch_size, seq_length, hidden_size]
389 |
390 | prem_hidden_states= hidden_states[0]
391 | hypo_hidden_states= hidden_states[1]
392 |
393 | # biLSTM
394 | # hidden_states: (batch_size, max_prem_len, hidden_size)
395 | # -> # -> hidden_states: [batch_size, hidden_size]
396 | prem_hidden_states = self.reduction1(prem_hidden_states)
397 | hypo_hidden_states = self.reduction2(hypo_hidden_states)
398 | prem_hidden_states = prem_hidden_states.transpose(0,1)
399 | hypo_hidden_states = hypo_hidden_states.transpose(0,1)
400 |
401 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_hidden_states)
402 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_hidden_states)
403 |
404 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
405 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
406 |
407 | # bilinear classification
408 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states)
409 |
410 | return outputs
411 |
412 |
413 | class PICModel3(nn.Module):
414 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length):
415 | super().__init__()
416 | self.hidden_size = config.hidden_size
417 | self.prem_max_sentence_length = prem_max_sentence_length
418 | self.hypo_max_sentence_length = hypo_max_sentence_length
419 | self.num_labels = config.num_labels
420 |
421 | # dependency label inventory (phrase tag - function tag)
422 | depend2idx = {"None": 0}
423 | idx2depend = {0: "None"}
424 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']:
425 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]:
426 | depend2idx[depend1 + "-" + depend2] = len(depend2idx)
427 | idx2depend[len(idx2depend)] = depend1 + "-" + depend2
428 | self.depend2idx = depend2idx
429 | self.idx2depend = idx2depend
430 | self.depend_embedding = nn.Embedding(len(idx2depend), self.hidden_size, padding_idx=0).to("cuda")
431 |
432 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
433 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
434 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 3))
435 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 3))
436 |
437 | self.tag1 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100)
438 | self.tag2 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100)
439 |
440 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
441 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
442 |
443 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels)
444 |
445 | def forward(self, hidden_states, batch_size, prem_span, hypo_span):
446 | # hidden_states: [[batch_size, word_idxs, hidden_size], []]
447 | # span: [batch_size, max_sentence_length, max_sentence_length]
448 | # word_idxs: [batch_size, seq_length]
449 | # -> sequence_outputs: [batch_size, seq_length, hidden_size]
450 |
451 | prem_hidden_states= hidden_states[0]
452 | hypo_hidden_states= hidden_states[1]
453 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape)
454 |
455 | # span: (batch, max_prem_len, 3) -> (batch, max_prem_len, 3*hidden_size)
456 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda")
457 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda")
458 |
459 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())):
460 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len)
461 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda")
462 | p_span_dep = torch.tensor([span[2] for span in p_span]).to("cuda")
463 |
464 | p_span_head = torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size)
465 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail)
466 | p_span_dep = self.depend_embedding(p_span_dep)
467 |
468 | n_p_span = p_span_head + p_span_tail + p_span_dep
469 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0)))
470 |
471 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len)
472 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda")
473 | h_span_dep = torch.tensor([span[2] for span in h_span]).to("cuda")
474 |
475 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size)
476 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail)
477 | h_span_dep = self.depend_embedding(h_span_dep)
478 |
479 | n_h_span = h_span_head + h_span_tail + h_span_dep
480 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0)))
481 |
482 | prem_span = new_prem_span
483 | hypo_span = new_hypo_span
484 |
485 | del new_prem_span
486 | del new_hypo_span
487 |
488 | # bilinear
489 | # hidden_states: (batch_size, max_prem_len, hidden_size)
490 | # span: (batch, max_prem_len, hidden_size)
491 | # -> bilinear_outputs: [batch_size, max_prem_len, 100]
492 | prem_span = self.reduction1(prem_span)
493 | prem_hidden_states = self.reduction2(prem_hidden_states)
494 | hypo_span = self.reduction3(hypo_span)
495 | hypo_hidden_states = self.reduction4(hypo_hidden_states)
496 |
497 | prem_bilinear_outputs= self.tag1(prem_hidden_states, prem_span)
498 | hypo_bilinear_outputs = self.tag2(hypo_hidden_states, hypo_span)
499 |
500 | # bilstm
501 | # biaffine_outputs: [batch_size, max_prem_len, 100]
502 | # -> hidden_states: [batch_size, hidden_size]
503 |
504 | prem_bilinear_outputs = prem_bilinear_outputs.transpose(0,1)
505 | hypo_bilinear_outputs = hypo_bilinear_outputs.transpose(0,1)
506 |
507 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_bilinear_outputs)
508 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_bilinear_outputs)
509 |
510 |
511 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
512 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
513 |
514 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states)
515 |
516 | return outputs
517 |
518 | class PICModel4(nn.Module):
519 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length):
520 | super().__init__()
521 | self.hidden_size = config.hidden_size
522 | self.prem_max_sentence_length = prem_max_sentence_length
523 | self.hypo_max_sentence_length = hypo_max_sentence_length
524 | self.num_labels = config.num_labels
525 |
526 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
527 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 3))
528 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 3))
529 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 3))
530 |
531 | self.tag1 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100)
532 | self.tag2 = BiLinear(int(self.hidden_size // 3), int(self.hidden_size // 3), 100)
533 |
534 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
535 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
536 |
537 | self.bilinear = BiLinear(self.hidden_size, self.hidden_size, self.num_labels)
538 |
539 | def forward(self, hidden_states, batch_size, prem_span, hypo_span):
540 | # hidden_states: [[batch_size, word_idxs, hidden_size], []]
541 | # span: [batch_size, max_sentence_length, max_sentence_length]
542 | # word_idxs: [batch_size, seq_length]
543 | # -> sequence_outputs: [batch_size, seq_length, hidden_size]
544 |
545 | prem_hidden_states= hidden_states[0]
546 | hypo_hidden_states= hidden_states[1]
547 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape)
548 |
549 | # span: (batch, max_prem_len, 3) -> (batch, max_prem_len, 3*hidden_size)
550 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda")
551 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda")
552 |
553 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())):
554 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len)
555 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda")
556 |
557 | p_span_head = torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size)
558 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail)
559 |
560 | n_p_span = p_span_head + p_span_tail
561 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0)))
562 |
563 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len)
564 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda")
565 |
566 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size)
567 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail)
568 |
569 | n_h_span = h_span_head + h_span_tail
570 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0)))
571 |
572 | prem_span = new_prem_span
573 | hypo_span = new_hypo_span
574 |
575 | del new_prem_span
576 | del new_hypo_span
577 |
578 | # bilinear
579 | # hidden_states: (batch_size, max_prem_len, hidden_size)
580 | # span: (batch, max_prem_len, hidden_size)
581 | # -> bilinear_outputs: [batch_size, max_prem_len, 100]
582 | prem_span = self.reduction1(prem_span)
583 | prem_hidden_states = self.reduction2(prem_hidden_states)
584 | hypo_span = self.reduction3(hypo_span)
585 | hypo_hidden_states = self.reduction4(hypo_hidden_states)
586 |
587 | prem_bilinear_outputs= self.tag1(prem_hidden_states, prem_span)
588 | hypo_bilinear_outputs = self.tag2(hypo_hidden_states, hypo_span)
589 |
590 | # bilstm
591 | # biaffine_outputs: [batch_size, max_prem_len, 100]
592 | # -> hidden_states: [batch_size, hidden_size]
593 |
594 | prem_bilinear_outputs = prem_bilinear_outputs.transpose(0,1)
595 | hypo_bilinear_outputs = hypo_bilinear_outputs.transpose(0,1)
596 |
597 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_bilinear_outputs)
598 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_bilinear_outputs)
599 |
600 |
601 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
602 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
603 |
604 | outputs = self.bilinear(prem_hidden_states, hypo_hidden_states)
605 |
606 | return outputs
607 |
608 |
609 | class PICModel5(nn.Module):
610 | def __init__(self, config, prem_max_sentence_length, hypo_max_sentence_length):
611 | super().__init__()
612 | self.hidden_size = config.hidden_size
613 | self.prem_max_sentence_length = prem_max_sentence_length
614 | self.hypo_max_sentence_length = hypo_max_sentence_length
615 | self.num_labels = config.num_labels
616 |
617 | # dependency label inventory (phrase tag - function tag)
618 | depend2idx = {"None": 0}
619 | idx2depend = {0: "None"}
620 | for depend1 in ['IP', 'AP', 'DP', 'VP', 'VNP', 'S', 'R', 'NP', 'L', 'X']:
621 | for depend2 in ['CMP', 'MOD', 'SBJ', 'AJT', 'CNJ', 'None', 'OBJ', "UNDEF"]:
622 | depend2idx[depend1 + "-" + depend2] = len(depend2idx)
623 | idx2depend[len(idx2depend)] = depend1 + "-" + depend2
624 | self.depend2idx = depend2idx
625 | self.idx2depend = idx2depend
626 | self.depend_embedding = nn.Embedding(len(idx2depend), self.hidden_size, padding_idx=0).to("cuda")
627 |
628 | self.reduction1 = nn.Linear(self.hidden_size , int(self.hidden_size // 6))
629 | self.reduction2 = nn.Linear(self.hidden_size , int(self.hidden_size // 6))
630 | self.reduction3 = nn.Linear(self.hidden_size, int(self.hidden_size // 6))
631 | self.reduction4 = nn.Linear(self.hidden_size, int(self.hidden_size // 6))
632 |
633 | self.W_1_bilinear = nn.Bilinear(int(self.hidden_size // 6), int(self.hidden_size // 6), 100, bias=False)
634 | self.W_1_linear1 = nn.Linear(int(self.hidden_size // 6), 100)
635 | self.W_1_linear2 = nn.Linear(int(self.hidden_size // 6), 100)
636 | self.W_2_bilinear = nn.Bilinear(int(self.hidden_size // 6), int(self.hidden_size // 6), 100, bias=False)
637 | self.W_2_linear1 = nn.Linear(int(self.hidden_size // 6), 100)
638 | self.W_2_linear2 = nn.Linear(int(self.hidden_size // 6), 100)
639 |
640 | self.bi_lism_1 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
641 | self.bi_lism_2 = nn.LSTM(input_size=100, hidden_size=self.hidden_size//2, num_layers=1, bidirectional=True)
642 |
643 | self.dropout = nn.Dropout(config.hidden_dropout_prob)  # use generalized information (regularization)
644 | self.biaffine_W_bilinear = nn.Bilinear(2 * (self.hidden_size // 2), 2 * (self.hidden_size // 2), self.num_labels, bias=False)
645 | self.biaffine_W_linear1 = nn.Linear(2 * (self.hidden_size // 2), self.num_labels)
646 | self.biaffine_W_linear2 = nn.Linear(2 * (self.hidden_size // 2), self.num_labels)
647 | self.reset_parameters()
648 |
649 | def forward(self, hidden_states, batch_size, prem_span, hypo_span):
650 | # hidden_states: [prem_hidden_states, hypo_hidden_states],
651 | #                 each [batch_size, max_sentence_length, hidden_size]
652 | # prem_span / hypo_span: [batch_size, max_sentence_length, 3], each entry = [head idx, tail idx, dependency-label idx]
653 | # -> outputs: [batch_size, num_labels]
654 |
655 | prem_hidden_states = hidden_states[0]
656 | hypo_hidden_states = hidden_states[1]
657 | #print(prem_hidden_states.shape, hypo_hidden_states.shape, prem_span.shape, hypo_span.shape)
658 |
659 | # span: (batch, max_prem_len, 3) -> (batch, max_prem_len, hidden_size): head + tail + dependency-label embedding
660 | new_prem_span = torch.tensor([], dtype=torch.long).to("cuda")
661 | new_hypo_span = torch.tensor([], dtype=torch.long).to("cuda")
662 |
663 | for i, (p_span, h_span) in enumerate(zip(prem_span.tolist(), hypo_span.tolist())):
664 | p_span_head = torch.tensor([span[0] for span in p_span]).to("cuda") #(max_prem_len)
665 | p_span_tail = torch.tensor([span[1] for span in p_span]).to("cuda")
666 | p_span_dep = torch.tensor([span[2] for span in p_span]).to("cuda")
667 |
668 | p_span_head = torch.index_select(prem_hidden_states[i], 0, p_span_head) #(max_prem_len, hidden_size)
669 | p_span_tail = torch.index_select(prem_hidden_states[i], 0, p_span_tail)
670 | p_span_dep = self.depend_embedding(p_span_dep)
671 |
672 | n_p_span = p_span_head + p_span_tail + p_span_dep
673 | new_prem_span = torch.cat((new_prem_span, n_p_span.unsqueeze(0)))
674 |
675 | h_span_head = torch.tensor([span[0] for span in h_span]).to("cuda") # (max_hypo_len)
676 | h_span_tail = torch.tensor([span[1] for span in h_span]).to("cuda")
677 | h_span_dep = torch.tensor([span[2] for span in h_span]).to("cuda")
678 |
679 | h_span_head = torch.index_select(hypo_hidden_states[i], 0, h_span_head) # (max_hypo_len, hidden_size)
680 | h_span_tail = torch.index_select(hypo_hidden_states[i], 0, h_span_tail)
681 | h_span_dep = self.depend_embedding(h_span_dep)
682 |
683 | n_h_span = h_span_head + h_span_tail + h_span_dep
684 | new_hypo_span = torch.cat((new_hypo_span, n_h_span.unsqueeze(0)))
685 |
686 | prem_span = new_prem_span
687 | hypo_span = new_hypo_span
688 |
689 | del new_prem_span
690 | del new_hypo_span
691 |
692 | # biaffine attention: W_bilinear(span, hidden) + W_linear1(span) + W_linear2(hidden)
693 | # hidden_states: (batch_size, max_prem_len, hidden_size) -> reduced to hidden_size // 6
694 | # span: (batch_size, max_prem_len, hidden_size) -> reduced to hidden_size // 6
695 | # -> biaffine_outputs: [batch_size, max_prem_len, 100]
696 | prem_span = self.reduction1(prem_span)
697 | prem_hidden_states = self.reduction2(prem_hidden_states)
698 | hypo_span = self.reduction3(hypo_span)
699 | hypo_hidden_states = self.reduction4(hypo_hidden_states)
700 |
701 | prem_biaffine_outputs = self.W_1_bilinear(prem_span, prem_hidden_states) + self.W_1_linear1(prem_span) + self.W_1_linear2(prem_hidden_states)
702 | hypo_biaffine_outputs = self.W_2_bilinear(hypo_span, hypo_hidden_states) + self.W_2_linear1(hypo_span) + self.W_2_linear2(hypo_hidden_states)
703 |
704 | # bilstm
705 | # biaffine_outputs: [batch_size, max_sentence_length, 100]
706 | # -> hidden_states: [batch_size, hidden_size] (concatenated final states of the BiLSTM)
707 | prem_biaffine_outputs = prem_biaffine_outputs.transpose(0,1)
708 | hypo_biaffine_outputs = hypo_biaffine_outputs.transpose(0,1)
709 |
710 | prem_bilstm_outputs, prem_states = self.bi_lism_1(prem_biaffine_outputs)
711 | hypo_bilstm_outputs, hypo_states = self.bi_lism_2(hypo_biaffine_outputs)
712 |
713 | prem_hidden_states = prem_states[0].transpose(0, 1).contiguous().view(batch_size, -1)  # h_n: (2, batch_size, hidden_size//2) -> (batch_size, hidden_size)
714 | hypo_hidden_states = hypo_states[0].transpose(0, 1).contiguous().view(batch_size, -1)
715 |
716 | # biaffine attention over the sentence representations
717 | # prem_hidden_states: (batch_size, hidden_size)
718 | # hypo_hidden_states: (batch_size, hidden_size)
719 | # -> outputs: (batch_size, num_labels)
720 | outputs = self.biaffine_W_bilinear(prem_hidden_states, hypo_hidden_states) + self.biaffine_W_linear1(prem_hidden_states) + self.biaffine_W_linear2(hypo_hidden_states)
721 |
722 | return outputs
723 |
724 | def reset_parameters(self):
725 | self.W_1_bilinear.reset_parameters()
726 | self.W_1_linear1.reset_parameters()
727 | self.W_1_linear2.reset_parameters()
728 | self.W_2_bilinear.reset_parameters()
729 | self.W_2_linear1.reset_parameters()
730 | self.W_2_linear2.reset_parameters()
731 |
732 | self.biaffine_W_bilinear.reset_parameters()
733 | self.biaffine_W_linear1.reset_parameters()
734 | self.biaffine_W_linear2.reset_parameters()
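735 | 
736 | 
737 | # --- Editor's sketch (not part of the original repository): a minimal shape check
738 | # for PICModel5. It assumes a CUDA device (the class hard-codes .to("cuda")) and
739 | # uses illustrative config values; in the real pipeline the hidden states would
740 | # presumably come from the RoBERTa encoder and the spans from the dependency parses.
741 | if __name__ == "__main__":
742 |     from types import SimpleNamespace
743 | 
744 |     config = SimpleNamespace(hidden_size=768, num_labels=3, hidden_dropout_prob=0.1)
745 |     batch_size, prem_len, hypo_len = 2, 30, 30
746 |     model = PICModel5(config, prem_len, hypo_len).to("cuda")
747 | 
748 |     prem_hidden = torch.rand(batch_size, prem_len, config.hidden_size).to("cuda")
749 |     hypo_hidden = torch.rand(batch_size, hypo_len, config.hidden_size).to("cuda")
750 |     # each span entry: [head word index, tail word index, dependency-label index]
751 |     prem_span = torch.zeros(batch_size, prem_len, 3, dtype=torch.long).to("cuda")
752 |     hypo_span = torch.zeros(batch_size, hypo_len, 3, dtype=torch.long).to("cuda")
753 | 
754 |     logits = model([prem_hidden, hypo_hidden], batch_size, prem_span, hypo_span)
755 |     print(logits.shape)  # expected: torch.Size([batch_size, num_labels])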
--------------------------------------------------------------------------------