├── etc ├── img │ ├── logo.png │ ├── smoke.mp4 │ ├── hk_figure.png │ ├── concat_data.png │ ├── hk_figure2.png │ ├── after_preprocess.png │ ├── back_translation.png │ ├── max_length_after.png │ ├── paragraph_figure.png │ ├── retrieval_result.PNG │ ├── back_translation2.png │ ├── back_translation3.png │ ├── before_preprocess.png │ ├── before_wiki_split.png │ ├── conv1d_model_idea.png │ ├── max_length_before.png │ ├── wiki_split_tradeoff.png │ ├── QAModel_architecture.png │ ├── post_processing_img1.png │ ├── post_processing_img2.png │ ├── QAConvModel_architecture.png │ ├── conv1d_model_structure.png │ └── retrieval_model_result.png ├── .gitignore └── my_stop_dic.txt ├── 1st solution presentation.pdf ├── code ├── script │ ├── retrieval_prepare_dataset.sh │ ├── retrieval_inference.sh │ ├── retrieval_train.sh │ ├── inference.sh │ ├── pretrain.sh │ ├── train.sh │ └── run_elastic_search.sh ├── elasticsearch_retrieval.py ├── question_labeling │ ├── data_set.py │ ├── question_labeling.py │ └── train.py ├── retrieval_model.py ├── model │ ├── ConvModel.py │ ├── QueryAttentionModel.py │ ├── QAConvModelV1.py │ └── QAConvModelV2.py ├── requirements.txt ├── arguments.py ├── run_elastic_search.py ├── mask.py ├── trainer_qa.py ├── data_processing.py ├── mk_retrieval_dataset.py ├── retrieval_dataset.py ├── retrieval_inference.py ├── prepare_dataset.py ├── retrieval_train.py ├── inference.py ├── train_mrc.py └── utils_qa.py └── README.md /etc/img/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/logo.png -------------------------------------------------------------------------------- /etc/img/smoke.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/smoke.mp4 -------------------------------------------------------------------------------- /etc/img/hk_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/hk_figure.png -------------------------------------------------------------------------------- /etc/.gitignore: -------------------------------------------------------------------------------- 1 | */wandb 2 | */__pycache__ 3 | */.ipynb_checkpoints 4 | data/*.pkl 5 | data/*.arrow 6 | data/*.json 7 | output -------------------------------------------------------------------------------- /etc/img/concat_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/concat_data.png -------------------------------------------------------------------------------- /etc/img/hk_figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/hk_figure2.png -------------------------------------------------------------------------------- /etc/img/after_preprocess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/after_preprocess.png -------------------------------------------------------------------------------- /etc/img/back_translation.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/back_translation.png -------------------------------------------------------------------------------- /etc/img/max_length_after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/max_length_after.png -------------------------------------------------------------------------------- /etc/img/paragraph_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/paragraph_figure.png -------------------------------------------------------------------------------- /etc/img/retrieval_result.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/retrieval_result.PNG -------------------------------------------------------------------------------- /1st solution presentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/1st solution presentation.pdf -------------------------------------------------------------------------------- /etc/img/back_translation2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/back_translation2.png -------------------------------------------------------------------------------- /etc/img/back_translation3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/back_translation3.png -------------------------------------------------------------------------------- /etc/img/before_preprocess.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/before_preprocess.png -------------------------------------------------------------------------------- /etc/img/before_wiki_split.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/before_wiki_split.png -------------------------------------------------------------------------------- /etc/img/conv1d_model_idea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/conv1d_model_idea.png -------------------------------------------------------------------------------- /etc/img/max_length_before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/max_length_before.png -------------------------------------------------------------------------------- /etc/img/wiki_split_tradeoff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/wiki_split_tradeoff.png 
-------------------------------------------------------------------------------- /etc/img/QAModel_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/QAModel_architecture.png -------------------------------------------------------------------------------- /etc/img/post_processing_img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/post_processing_img1.png -------------------------------------------------------------------------------- /etc/img/post_processing_img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/post_processing_img2.png -------------------------------------------------------------------------------- /etc/img/QAConvModel_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/QAConvModel_architecture.png -------------------------------------------------------------------------------- /etc/img/conv1d_model_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/conv1d_model_structure.png -------------------------------------------------------------------------------- /etc/img/retrieval_model_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TEAM-IKYO/Open-Domain-Question-Answering/HEAD/etc/img/retrieval_model_result.png -------------------------------------------------------------------------------- /code/script/retrieval_prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | python3 mk_retrieval_dataset.py --top_k 20 \ 2 | --save_path ../data/retrieval_dataset \ 3 | --index_name nori-index -------------------------------------------------------------------------------- /code/script/retrieval_inference.sh: -------------------------------------------------------------------------------- 1 | python3 retrieval_inference.py --model_checkpoint bert-base-multilingual-cased \ 2 | --run_name best_dense_retrieval \ 3 | --es_top_k 70 \ 4 | --dr_top_k 70 \ 5 | --index_name nori-index -------------------------------------------------------------------------------- /code/script/retrieval_train.sh: -------------------------------------------------------------------------------- 1 | python3 retrieval_train.py --output_dir ../retrieval_output/ \ 2 | --model_checkpoint bert-base-multilingual-cased \ 3 | --seed 2021 \ 4 | --epoch 1 \ 5 | --learning_rate 1e-5 \ 6 | --gradient_accumulation_steps 1 \ 7 | --top_k 20 \ 8 | --run_name best_dense_retrieval -------------------------------------------------------------------------------- /code/script/inference.sh: -------------------------------------------------------------------------------- 1 | python3 inference.py --output_dir ../output/baseline_train/submission \ 2 | --model_name_or_path ../output/baseline_train/baseline_train.pt \ 3 | --tokenizer_name deepset/xlm-roberta-large-squad2 \ 4 | --config_name deepset/xlm-roberta-large-squad2 \ 5 | --retrieval_type elastic_sentence_transformer 
\ 6 | --retrieval_elastic_index wiki-index-split-800 \ 7 | --retrieval_elastic_num 35 \ 8 | --use_custom_model QAConvModelV2 \ 9 | --do_predict -------------------------------------------------------------------------------- /code/script/pretrain.sh: -------------------------------------------------------------------------------- 1 | python3 train_mrc.py --output_dir ../output \ 2 | --model_name_or_path deepset/xlm-roberta-large-squad2 \ 3 | --tokenizer_name deepset/xlm-roberta-large-squad2 \ 4 | --config_name deepset/xlm-roberta-large-squad2 \ 5 | --learning_rate 0.000005 \ 6 | --num_train_epoch 2 \ 7 | --per_device_train_batch_size 16 \ 8 | --per_device_eval_batch_size 16 \ 9 | --dataset_name ai_hub \ 10 | --use_custom_model QAConvModelV2 \ 11 | --run_name baseline_pretrain -------------------------------------------------------------------------------- /code/script/train.sh: -------------------------------------------------------------------------------- 1 | python3 train_mrc.py --output_dir ../output \ 2 | --model_name_or_path baseline_pretrain \ 3 | --tokenizer_name deepset/xlm-roberta-large-squad2 \ 4 | --config_name deepset/xlm-roberta-large-squad2 \ 5 | --learning_rate 0.000005 \ 6 | --num_train_epoch 2 \ 7 | --per_device_train_batch_size 16 \ 8 | --per_device_eval_batch_size 16 \ 9 | --dataset_name question_type \ 10 | --use_custom_model QAConvModelV2 \ 11 | --use_pretrained_model \ 12 | --run_name baseline_train -------------------------------------------------------------------------------- /code/elasticsearch_retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | from elasticsearch import Elasticsearch 6 | from datasets import load_from_disk 7 | from torch.utils.data import DataLoader, TensorDataset 8 | from subprocess import Popen, PIPE, STDOUT 9 | from tqdm import tqdm 10 | 11 | def elastic_setting(index_name='wiki-index'): 12 | config = {'host':'localhost', 'port':9200} 13 | es = Elasticsearch([config]) 14 | 15 | return es, index_name 16 | 17 | 18 | def search_es(es_obj, index_name, question_text, n_results): 19 | # search query 20 | query = { 21 | 'query': { 22 | 'match': { 23 | 'document_text': question_text 24 | } 25 | } 26 | } 27 | # n_result => 상위 몇개를 선택? 28 | res = es_obj.search(index=index_name, body=query, size=n_results) 29 | 30 | return res 31 | 32 | 33 | def elastic_retrieval(es, index_name, question_text, n_results): 34 | res = search_es(es, index_name, question_text, n_results) 35 | # 매칭된 context만 list형태로 만든다. 36 | context_list = list((hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']) 37 | return context_list 38 | -------------------------------------------------------------------------------- /code/script/run_elastic_search.sh: -------------------------------------------------------------------------------- 1 | # elasticsearch-7.6.2 설치 2 | wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.6.2-linux-x86_64.tar.gz -q -P ../etc/ 3 | tar -xzf ../etc/elasticsearch-7.6.2-linux-x86_64.tar.gz -C ../etc/ 4 | chown -R daemon:daemon ../etc/elasticsearch-7.6.2 5 | 6 | # Python Library 설치 7 | pip install elasticsearch 8 | pip install tqdm 9 | 10 | # nori Tokenizer 설치 11 | ../etc/elasticsearch-7.6.2/bin/elasticsearch-plugin install analysis-nori 12 | 13 | # elastic search stop word 설정 14 | mkdir ../etc/elasticsearch-7.6.2/config/user_dic 15 | cp ../etc/my_stop_dic.txt ../etc/elasticsearch-7.6.2/config/user_dic/. 
16 | 17 | # python script file 실행 18 | python3 run_elastic_search.py --path_to_elastic ../etc/elasticsearch-7.6.2/bin/elasticsearch --index_name wiki-index 19 | python3 run_elastic_search.py --path_to_elastic ../etc/elasticsearch-7.6.2/bin/elasticsearch --index_name wiki-index-split-400 20 | python3 run_elastic_search.py --path_to_elastic ../etc/elasticsearch-7.6.2/bin/elasticsearch --index_name wiki-index-split-800 21 | python3 run_elastic_search.py --path_to_elastic ../etc/elasticsearch-7.6.2/bin/elasticsearch --index_name wiki-index-split-1000 22 | 23 | # elastic search 실행 여부 확인 24 | ps -ef | grep elastic -------------------------------------------------------------------------------- /code/question_labeling/data_set.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import torch 4 | from torch.utils.data import Dataset 5 | from torch.utils.data import DataLoader 6 | from transformers import AutoTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, BertConfig, ElectraForSequenceClassification, AdamW 7 | 8 | class RE_Dataset(Dataset): 9 | def __init__(self, tokenized_dataset, labels): 10 | self.tokenized_dataset = tokenized_dataset 11 | self.labels = labels 12 | 13 | def __getitem__(self, idx): 14 | item = {key: torch.tensor(val[idx]) for key, val in self.tokenized_dataset.items()} 15 | item['labels'] = torch.tensor(self.labels[idx]) 16 | 17 | 18 | input_ids = item["input_ids"] 19 | attention_mask = item["attention_mask"] 20 | label = item["labels"] 21 | 22 | return input_ids, attention_mask, label 23 | 24 | def __len__(self): 25 | return len(self.labels) 26 | 27 | def tokenized_dataset(dataset, tokenizer): 28 | 29 | label = dataset["question_type"] 30 | tokenized_sentences = tokenizer( 31 | dataset["question"], 32 | return_tensors="pt", 33 | padding=True, 34 | truncation=True, 35 | max_length=50, 36 | add_special_tokens=True, 37 | ) 38 | 39 | return tokenized_sentences, label 40 | 41 | def tokenized_testset(dataset, tokenizer): 42 | label = [0 for i in range(len(dataset))] 43 | tokenized_sentences = tokenizer( 44 | dataset["question"], 45 | return_tensors="pt", 46 | padding=True, 47 | truncation=True, 48 | max_length=50, 49 | add_special_tokens=True, 50 | ) 51 | 52 | return tokenized_sentences, label -------------------------------------------------------------------------------- /code/retrieval_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | 5 | from torch import nn 6 | from transformers import AutoModel, AutoConfig 7 | 8 | class BertPooler(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 12 | self.activation = nn.Tanh() 13 | 14 | def forward(self, hidden_states): 15 | # We "pool" the model by simply taking the hidden state corresponding 16 | # to the first token. 
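# (Note: ElectraModel does not return a pooled output, so the Encoder below applies this BertPooler to the [CLS] hidden state to obtain a fixed-size embedding for the dense retriever.)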
17 | first_token_tensor = hidden_states[:, 0] 18 | pooled_output = self.dense(first_token_tensor) 19 | pooled_output = self.activation(pooled_output) 20 | return pooled_output 21 | 22 | class Encoder(nn.Module): 23 | def __init__(self, model_checkpoint): 24 | super(Encoder, self).__init__() 25 | self.model_checkpoint = model_checkpoint 26 | config = AutoConfig.from_pretrained(self.model_checkpoint) 27 | 28 | if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator': 29 | self.pooler = BertPooler(config) 30 | config = AutoConfig.from_pretrained(self.model_checkpoint) 31 | self.model = AutoModel.from_pretrained(self.model_checkpoint, config=config) 32 | 33 | def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None): 34 | outputs = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids) 35 | if self.model_checkpoint == 'monologg/koelectra-base-v3-discriminator': 36 | sequence_output = outputs[0] 37 | pooled_output = self.pooler(sequence_output) 38 | else: 39 | pooled_output = outputs[1] 40 | return pooled_output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📖 Stage 3 - MRC (Machine Reading Comprehension) ✏️ 2 | > More explanations of our codes are in 🔎[TEAM-IKYO's WIKI](https://github.com/TEAM-IKYO/Open-Domain-Question-Answering/wiki) 3 | 4 | 5 | ![TEAM-IKYO’s ODQA Final Score_v3](https://user-images.githubusercontent.com/45359366/121469986-6ea03e80-c9f8-11eb-9fd1-355e50537450.png) 6 | 7 | ## 💻 Task Description 8 | > More Description can be found at [AI Stages](http://boostcamp.stages.ai/competitions/31/overview/description) 9 | 10 | When you have a question like "What is the oldest tree in Korea?", you have probably typed it into a search engine, and these days search engines return surprisingly accurate answers. How is that possible? 11 | 12 | Question Answering is a field of research that builds artificial intelligence models that answer various kinds of questions. Among them, Open-Domain Question Answering is a more challenging task, because the model has to find the documents that can answer a question using only a pre-built knowledge resource. 13 | 14 | ![image](https://user-images.githubusercontent.com/59340911/119260267-118d4600-bc0d-11eb-95bc-6ea68f7b0df4.png) 15 | 16 | The model we build in this competition is made up of two stages. The first stage is called the "retriever", which finds the documents related to a question; the second stage is called the "reader", which reads the documents found in the first stage and extracts the answer from them. If we connect these two stages properly, we can build a question answering system that can answer even the toughest questions. The team that builds the model with the most accurate answers wins this stage.
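As a rough sketch of how the two stages fit together (an illustration only, not the project's actual `inference.py` pipeline), the snippet below assumes the Elasticsearch index has already been built with `run_elastic_search.sh`, that it is run from the `code/` directory, and that the reader is the public `deepset/xlm-roberta-large-squad2` checkpoint used off the shelf rather than our fine-tuned custom model:

```python
from transformers import pipeline

from elasticsearch_retrieval import elastic_setting, elastic_retrieval

# Stage 1 (retriever): fetch the top-k candidate passages from the Elasticsearch index.
es, index_name = elastic_setting(index_name="wiki-index")
question = "What is the oldest tree in Korea?"
candidates = elastic_retrieval(es, index_name, question, n_results=5)  # list of (passage_text, score)

# Stage 2 (reader): extract an answer span from each passage and keep the highest-scoring one.
reader = pipeline("question-answering", model="deepset/xlm-roberta-large-squad2")
best = max(
    (reader(question=question, context=passage) for passage, _ in candidates),
    key=lambda prediction: prediction["score"],
)
print(best["answer"])
```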
17 | ![image](https://user-images.githubusercontent.com/59340911/119260915-f1ab5180-bc0f-11eb-9ddc-cad4585bc8ce.png) 18 | 19 | --- 20 | 21 | ## 🗂 Directory 22 | ``` 23 | p3-mrc-team-ikyo 24 | ├── code 25 | │ ├── arguments.py 26 | │ ├── data_processing.py 27 | │ ├── elasticsearch_retrieval.py 28 | │ ├── inference.py 29 | │ ├── mask.py 30 | │ ├── mk_retrieval_dataset.py 31 | │ ├── model 32 | │ │ ├── ConvModel.py 33 | │ │ ├── QAConvModelV1.py 34 | │ │ ├── QAConvModelV2.py 35 | │ │ └── QueryAttentionModel.py 36 | │ ├── prepare_dataset.py 37 | │ ├── question_labeling 38 | │ │ ├── data_set.py 39 | │ │ ├── question_labeling.py 40 | │ │ └── train.py 41 | │ ├── requirements.txt 42 | │ ├── retrieval_dataset.py 43 | │ ├── retrieval_inference.py 44 | │ ├── retrieval_model.py 45 | │ ├── retrieval_train.py 46 | │ ├── run_elastic_search.py 47 | │ ├── script 48 | │ │ ├── inference.sh 49 | │ │ ├── pretrain.sh 50 | │ │ ├── retrieval_inference.sh 51 | │ │ ├── retrieval_prepare_dataset.sh 52 | │ │ ├── retrieval_train.sh 53 | │ │ ├── run_elastic_search.sh 54 | │ │ └── train.sh 55 | │ ├── train_mrc.py 56 | │ ├── trainer_qa.py 57 | │ └── utils_qa.py 58 | └── etc 59 | └── my_stop_dic.txt 60 | ``` 61 | -------------------------------------------------------------------------------- /code/model/ConvModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn, optim 4 | from torch.nn import functional as F 5 | from transformers import AutoTokenizer, AutoModel 6 | 7 | class ConvModel(nn.Module): 8 | def __init__(self, model_name, model_config, tokenizer_name): 9 | super().__init__() 10 | self.model_name = model_name 11 | self.tokenizer_name = tokenizer_name 12 | self.backbone_model = AutoModel.from_pretrained(model_name, config=model_config) 13 | self.conv1d_layer1 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=1) 14 | self.conv1d_layer3 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=3, padding=1) 15 | self.conv1d_layer5 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=5, padding=2) 16 | self.dropout = nn.Dropout(0.3) 17 | self.dense_layer = nn.Linear(1024 * 3, 2, bias=True) 18 | 19 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=None, output_hidden_states=None, return_dict=None): 20 | if "xlm" in self.tokenizer_name: 21 | outputs = self.backbone_model( 22 | input_ids, 23 | attention_mask=attention_mask, 24 | token_type_ids=token_type_ids, 25 | head_mask=head_mask, 26 | inputs_embeds=inputs_embeds, 27 | output_attentions=output_attentions, 28 | output_hidden_states=output_hidden_states, 29 | return_dict=return_dict, 30 | ) 31 | 32 | else: 33 | outputs = self.backbone_model( 34 | input_ids, 35 | attention_mask=attention_mask, 36 | token_type_ids=token_type_ids, 37 | position_ids=position_ids, 38 | head_mask=head_mask, 39 | inputs_embeds=inputs_embeds, 40 | output_attentions=output_attentions, 41 | output_hidden_states=output_hidden_states, 42 | return_dict=return_dict, 43 | ) 44 | 45 | sequence_output = outputs[0] # Convolution 연산을 위해 Transpose (B * hidden_size * max_seq_legth) 46 | conv_input = sequence_output.transpose(1, 2) # Conv 연산을 위한 Transpose (B * hidden_size * max_seq_length) 47 | conv_output1 = F.relu(self.conv1d_layer1(conv_input)) # Conv연산의 결과 (B * num_conv_filter * max_seq_legth) 48 | conv_output3 = F.relu(self.conv1d_layer3(conv_input)) # 
Convolution output (B * num_conv_filter * max_seq_length) 49 | conv_output5 = F.relu(self.conv1d_layer5(conv_input)) # Convolution output (B * num_conv_filter * max_seq_length) 50 | concat_output = torch.cat((conv_output1, conv_output3, conv_output5), dim=1) # Concatenation (B * num_conv_filter x 3 * max_seq_length) 51 | logits = self.dense_layer(self.dropout(concat_output.transpose(1, 2))) # Dropout, then project the concatenated features to start/end logits (B * max_seq_length * 2) 52 | start_logits, end_logits = logits.split(1, dim=-1) 53 | start_logits = start_logits.squeeze(-1) 54 | end_logits = end_logits.squeeze(-1) 55 | 56 | return {"start_logits" : start_logits, "end_logits" : end_logits, "hidden_states" : outputs.hidden_states, "attentions" : outputs.attentions} -------------------------------------------------------------------------------- /code/requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==2.2.0 2 | argon2-cffi==20.1.0 3 | async-generator==1.10 4 | attrs==20.3.0 5 | Babel==2.9.0 6 | backcall==0.2.0 7 | bcrypt==3.2.0 8 | beautifulsoup4==4.6.0 9 | bleach==3.3.0 10 | boto3==1.17.57 11 | botocore==1.20.57 12 | certifi==2020.12.5 13 | cffi==1.14.0 14 | chardet==3.0.4 15 | click==7.1.2 16 | colorama==0.4.4 17 | conda==4.10.1 18 | conda-build==3.18.11 19 | conda-package-handling==1.7.0 20 | configparser==5.0.2 21 | cryptography==2.9.2 22 | cycler==0.10.0 23 | datasets==1.5.0 24 | decorator==4.4.2 25 | defusedxml==0.7.1 26 | deprecation==2.1.0 27 | dill==0.3.3 28 | docker-pycreds==0.4.0 29 | elasticsearch==7.12.1 30 | entrypoints==0.3 31 | faiss==1.7.0 32 | faiss-cpu==1.7.0 33 | filelock==3.0.12 34 | fsspec==2021.4.0 35 | future==0.18.2 36 | gensim==4.0.1 37 | gevent==21.1.2 38 | gitdb==4.0.7 39 | GitPython==3.1.17 40 | glob2==0.7 41 | greenlet==1.0.0 42 | huggingface-hub==0.0.8 43 | idna==2.9 44 | importlib-metadata==4.0.1 45 | inotify-simple==1.2.1 46 | ipykernel==5.5.3 47 | ipython==7.16.1 48 | ipython-genutils==0.2.0 49 | ipywidgets==7.6.3 50 | jedi==0.17.1 51 | Jinja2==2.11.2 52 | jmespath==0.10.0 53 | joblib==1.0.1 54 | JPype1==1.2.1 55 | json5==0.9.5 56 | jsonschema==3.2.0 57 | jupyter-client==6.1.12 58 | jupyter-core==4.7.1 59 | jupyter-packaging==0.9.2 60 | jupyter-server==1.6.4 61 | jupyterlab==3.0.14 62 | jupyterlab-pygments==0.1.2 63 | jupyterlab-server==2.4.0 64 | jupyterlab-widgets==1.0.0 65 | kiwisolver==1.3.1 66 | konlpy==0.5.2 67 | kss==2.5.0 68 | libarchive-c==2.9 69 | lxml==4.6.3 70 | MarkupSafe==1.1.1 71 | matplotlib==3.4.2 72 | mecab-python===0.996-ko-0.9.2 73 | mistune==0.8.4 74 | mkl-fft==1.1.0 75 | mkl-random==1.1.1 76 | mkl-service==2.3.0 77 | multiprocess==0.70.11.1 78 | nbclassic==0.2.7 79 | nbclient==0.5.3 80 | nbconvert==6.0.7 81 | nbformat==5.1.3 82 | nest-asyncio==1.5.1 83 | network==0.1 84 | networkx==2.5.1 85 | nltk==3.6.2 86 | node2vec==0.4.3 87 | notebook==6.3.0 88 | numpy==1.18.5 89 | oauthlib==3.1.0 90 | olefile==0.46 91 | packaging==20.9 92 | pandas==1.1.4 93 | pandocfilters==1.4.3 94 | paramiko==2.7.2 95 | parso==0.7.0 96 | pathtools==0.1.2 97 | pexpect==4.8.0 98 | pickleshare==0.7.5 99 | Pillow==7.2.0 100 | pkginfo==1.5.0.1 101 | prometheus-client==0.10.1 102 | promise==2.3 103 | prompt-toolkit==3.0.5 104 | protobuf==3.15.8 105 | psutil==5.7.0 106 | ptyprocess==0.6.0 107 | pyarrow==3.0.0 108 | pycosat==0.6.3 109 | pycparser==2.20 110 | Pygments==2.6.1 111 | PyNaCl==1.4.0 112 | pyOpenSSL==19.1.0 113 | pyparsing==2.4.7 114 | pyrsistent==0.17.3 115 | PySocks==1.7.1 116 | python-dateutil==2.8.1 117 | pytz==2020.1 118 | PyYAML==5.3.1 119 | pyzmq==22.0.3 120 | regex==2021.4.4 121 | requests==2.23.0 122 | requests-oauthlib==1.3.0 123 | retrying==1.3.3 124 |
ruamel-yaml==0.15.87 125 | s3transfer==0.4.2 126 | sacremoses==0.0.45 127 | sagemaker-training==3.9.1 128 | scikit-learn==0.24.1 129 | scipy==1.6.3 130 | Send2Trash==1.5.0 131 | sentence-transformers==1.1.1 132 | sentencepiece==0.1.95 133 | sentry-sdk==1.1.0 134 | shortuuid==1.0.1 135 | six==1.14.0 136 | smart-open==5.0.0 137 | smmap==4.0.0 138 | sniffio==1.2.0 139 | soupsieve==2.0.1 140 | subprocess32==3.5.4 141 | terminado==0.9.4 142 | testpath==0.4.4 143 | threadpoolctl==2.1.0 144 | tokenizers==0.10.2 145 | tomlkit==0.7.0 146 | torch==1.6.0 147 | torchvision==0.7.0 148 | tornado==6.1 149 | tqdm==4.41.1 150 | traitlets==4.3.3 151 | transformers==4.6.0 152 | tweepy==3.10.0 153 | typing-extensions==3.7.4.3 154 | urllib3==1.25.8 155 | wandb==0.10.30 156 | wcwidth==0.2.5 157 | webencodings==0.5.1 158 | Werkzeug==1.0.1 159 | widgetsnbextension==3.5.1 160 | xxhash==2.0.2 161 | zipp==3.4.1 162 | zope.event==4.5.0 163 | zope.interface==5.4.0 164 | -------------------------------------------------------------------------------- /code/question_labeling/question_labeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import pandas as pd 4 | from tqdm import tqdm 5 | from data_set import * 6 | from datasets import load_metric, load_from_disk, load_dataset, Features, Value, Sequence, DatasetDict, Dataset 7 | 8 | def get_pickle(pickle_path): 9 | '''Custom Dataset을 Load하기 위한 함수''' 10 | f = open(pickle_path, "rb") 11 | dataset = pickle.load(f) 12 | f.close() 13 | 14 | return dataset 15 | 16 | 17 | def get_data(): 18 | tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large") 19 | data = get_pickle("../../data/concat_train.pkl") 20 | 21 | train_token, train_labels = tokenized_testset(data["train"], tokenizer) 22 | val_token, val_labels = tokenized_testset(data["validation"], tokenizer) 23 | 24 | train_set = RE_Dataset(train_token, train_labels) 25 | val_set = RE_Dataset(val_token, val_labels) 26 | 27 | train_iter = DataLoader(train_set, batch_size=1) 28 | val_iter = DataLoader(val_set, batch_size=1) 29 | 30 | return train_iter, val_iter 31 | 32 | 33 | def question_labeling(model, train_iter, val_iter): 34 | train_file = get_pickle("../../data/concat_train.pkl")["train"] 35 | validation_file = get_pickle("../../data/concat_train.pkl")["validation"] 36 | 37 | train_qa = [{"id" : train_file[i]["id"], "question" : train_file[i]["question"], "answers" : train_file[i]["answers"], "context" : train_file[i]["context"]} for i in range(len(train_file))] 38 | validation_qa = [{"id" : validation_file[i]["id"], "question" : validation_file[i]["question"], "answers" : validation_file[i]["answers"], "context" : validation_file[i]["context"]} for i in range(len(validation_file))] 39 | 40 | device = "cuda:0" 41 | for step, (input_ids, attention_mask, labels) in tqdm(enumerate(train_iter), total=len(train_iter), position=0, leave=True): 42 | score = model(input_ids.to(device), attention_mask=attention_mask.to(device))[0] 43 | pred = torch.argmax(score, 1).detach().cpu().numpy() 44 | train_qa[step]["question_type"] = pred 45 | 46 | for step, (input_ids, attention_mask, labels) in tqdm(enumerate(val_iter), total=len(val_iter), position=0, leave=True): 47 | score = model(input_ids.to(device), attention_mask=attention_mask.to(device))[0] 48 | pred = torch.argmax(score, 1).detach().cpu().numpy() 49 | validation_qa[step]["question_type"] = pred 50 | 51 | train_df = pd.DataFrame(train_qa) 52 | val_df = pd.DataFrame(validation_qa) 53 | 54 | return 
train_df, val_df 55 | 56 | 57 | def save_data(train_df, val_df): 58 | train_f = Features({'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None), 59 | 'context': Value(dtype='string', id=None), 60 | 'id': Value(dtype='string', id=None), 61 | 'question': Value(dtype='string', id=None), 62 | 'question_type' : Value(dtype='int32', id=None)}) 63 | 64 | train_datasets = DatasetDict({'train': Dataset.from_pandas(train_df, features=train_f), 'validation': Dataset.from_pandas(val_df, features=train_f)}) 65 | file = open("../../data/question_type.pkl", "wb") 66 | pickle.dump(train_datasets, file) 67 | file.close() 68 | 69 | 70 | def main(): 71 | model = torch.load("../../output/question_model.pt") 72 | train_iter, val_iter = get_data() 73 | train_df, val_df = question_labeling(model, train_iter, val_iter) 74 | save_data(train_df, val_df) 75 | 76 | if __name__ == "__main__": 77 | main() 78 | data_set = get_pickle("../../data/question_type.pkl") 79 | print(data_set) 80 | -------------------------------------------------------------------------------- /code/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from typing import Any, Dict, List, Optional 3 | 4 | @dataclass 5 | class ModelArguments: 6 | """ 7 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 8 | """ 9 | model_name_or_path: str = field( 10 | default="xlm-roberta-large", 11 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 12 | ) 13 | use_custom_model: str = field( 14 | default='QAConvModelV2', 15 | metadata={"help": "Choose one of ['ConvModel', 'QueryAttentionModel', 'QAConvModelV1', 'QAConvModelV2']"} 16 | ) 17 | use_pretrained_model: bool = field( 18 | default=False, 19 | metadata={"help": "use_pretrained_koquard_model"} 20 | ) 21 | config_name: Optional[str] = field( 22 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 23 | ) 24 | tokenizer_name: Optional[str] = field( 25 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 26 | ) 27 | retrieval_type: Optional[str] = field( 28 | default="elastic", metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 29 | ) 30 | retrieval_elastic_index: Optional[str] = field( 31 | default="wiki-index-split-800", metadata={"help": "Elastic search index name[wiki-index, wiki-index-split-400, wiki-index-split-800(best), wiki-index-split-1000]"} 32 | ) 33 | retrieval_elastic_num: Optional[int] = field( 34 | default=35, 35 | metadata={"help": "The number of context or passage from Elastic search"}, 36 | ) 37 | 38 | @dataclass 39 | class DataTrainingArguments: 40 | """ 41 | Arguments pertaining to what data we are going to input our model for training and eval. 
42 | """ 43 | dataset_name: Optional[str] = field( 44 | default="question_type", metadata={"help": "Choose one of ['basic', 'preprocessed', 'concat', 'korquad', 'only_korquad', 'quetion_type', 'ai_hub', 'random_masking', 'token_masking']"} 45 | ) 46 | overwrite_cache: bool = field( 47 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 48 | ) 49 | preprocessing_num_workers: Optional[int] = field( 50 | default=None, 51 | metadata={"help": "The number of processes to use for the preprocessing."}, 52 | ) 53 | max_seq_length: int = field( 54 | default=384, 55 | metadata={ 56 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 57 | "than this will be truncated, sequences shorter will be padded." 58 | }, 59 | ) 60 | pad_to_max_length: bool = field( 61 | default=False, 62 | metadata={ 63 | "help": "Whether to pad all samples to `max_seq_length`. " 64 | "If False, will pad the samples dynamically when batching to the maximum length in the batch (which can " 65 | "be faster on GPU but will be slower on TPU)." 66 | }, 67 | ) 68 | doc_stride: int = field( 69 | default=128, 70 | metadata={"help": "When splitting up a long document into chunks, how much stride to take between chunks."}, 71 | ) 72 | max_answer_length: int = field( 73 | default=30, 74 | metadata={ 75 | "help": "The maximum length of an answer that can be generated. This is needed because the start " 76 | "and end predictions are not conditioned on one another." 77 | }, 78 | ) 79 | train_retrieval: bool = field( 80 | default=True, 81 | metadata={"help": "Whether to train sparse/dense embedding (prepare for retrieval)."}, 82 | ) 83 | eval_retrieval: bool = field( 84 | default=True, 85 | metadata={"help":"Whether to run passage retrieval using sparse/dense embedding )."}, 86 | ) 87 | 88 | -------------------------------------------------------------------------------- /code/model/QueryAttentionModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn, optim 4 | from torch.nn import functional as F 5 | from transformers import AutoTokenizer, AutoModel 6 | 7 | class QueryAttentionModel(nn.Module): 8 | def __init__(self, model_name, model_config, tokenizer_name): 9 | super().__init__() 10 | self.model_name = model_name 11 | self.model_config = model_config 12 | self.tokenizer_name = tokenizer_name 13 | self.backbone = AutoModel.from_pretrained(model_name, config=model_config) 14 | self.query_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 15 | self.query_calssify_layer = nn.Linear(model_config.hidden_size, 6, bias=True) 16 | self.key_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 17 | self.value_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 18 | self.gelu = nn.GELU() 19 | self.drop_out = nn.Dropout(0.7) 20 | self.classify_layer = nn.Linear(model_config.hidden_size, 2, bias=True) 21 | 22 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=None, output_hidden_states=None, return_dict=None, question_type=None): 23 | if "xlm" in self.tokenizer_name: 24 | outputs = self.backbone( 25 | input_ids, 26 | attention_mask=attention_mask, 27 | token_type_ids=token_type_ids, 28 | head_mask=head_mask, 29 | inputs_embeds=inputs_embeds, 30 | 
output_attentions=output_attentions, 31 | output_hidden_states=output_hidden_states, 32 | return_dict=return_dict, 33 | ) 34 | 35 | else: 36 | outputs = self.backbone( 37 | input_ids, 38 | attention_mask=attention_mask, 39 | token_type_ids=token_type_ids, 40 | position_ids=position_ids, 41 | head_mask=head_mask, 42 | inputs_embeds=inputs_embeds, 43 | output_attentions=output_attentions, 44 | output_hidden_states=output_hidden_states, 45 | return_dict=return_dict, 46 | ) 47 | 48 | sequence_output = outputs[0] # (B * 384 * 1024) 49 | 50 | if not token_type_ids : 51 | token_type_ids = self.make_token_type_ids(input_ids) 52 | 53 | embedded_query = sequence_output * (token_type_ids==0) # 전체 Text 중 query에 해당하는 Embedded Vector만 남김. 54 | embedded_query = self.query_layer(embedded_query) # Dense Layer를 통과 시킴. (B * max_seq_length * hidden_size) 55 | embedded_query = torch.mean(embedded_query, 1, keepdim=True) # Query에 해당하는 Token Embedding을 평균냄. (B * 1 * hidden_size) 56 | query_logits = self.query_calssify_layer(embedded_query.squeeze(1)) # Query의 종류를 예측하는 Branch (B * 6) 57 | 58 | embedded_key = self.key_layer(sequence_output) # (B * max_seq_length * hidden_size) 59 | embedded_value = self.value_layer(sequence_output) # (B * max_seq_length * hidden_size) 60 | 61 | attention_rate = torch.matmul(embedded_key, torch.transpose(embedded_query, 1, 2)) # Context의 Value Vector와 Quetion의 Query Vector를 사용 62 | attention_rate = F.softmax(attention_rate, 1) # Question과 Context의 Attention Rate를 구함. (B * max_seq_length * 1) 63 | 64 | logits = embedded_value * attention_rate # Attention Rate를 활용해서 Output 값을 변경함. 65 | logits = self.gelu(logits) # Activation Function 통과 66 | logits = self.drop_out(logits) # dropout 통과 67 | logits = self.classify_layer(logits) # Classifier Layer를 통해 최종 Logit을 얻음. 
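# Summary: the averaged question embedding serves as a single query vector; attention_rate = softmax(key @ query^T) re-weights the value vectors over the sequence, the re-weighted values are projected to start/end logits below, and query_logits is an auxiliary head that predicts the question type.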
68 | 69 | start_logits, end_logits = logits.split(1, dim=-1) 70 | start_logits = start_logits.squeeze(-1) 71 | end_logits = end_logits.squeeze(-1) 72 | 73 | return {"start_logits" : start_logits, "end_logits" : end_logits, "hidden_states" : outputs.hidden_states, "attentions" : outputs.attentions, "query_logits" : query_logits} 74 | 75 | def make_token_type_ids(self, input_ids) : 76 | token_type_ids = [] 77 | for i, input_id in enumerate(input_ids): 78 | sep_idx = np.where(input_id.cpu().numpy() == self.sep_token_id) 79 | token_type_id = [0]*sep_idx[0][0] + [1]*(len(input_id)-sep_idx[0][0]) 80 | token_type_ids.append(token_type_id) 81 | return torch.tensor(token_type_ids).cuda() -------------------------------------------------------------------------------- /code/run_elastic_search.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | from subprocess import Popen, PIPE, STDOUT 6 | 7 | from datasets import load_from_disk 8 | from elasticsearch import Elasticsearch 9 | from prepare_dataset import make_custom_dataset 10 | from torch.utils.data import DataLoader, TensorDataset 11 | from tqdm import tqdm 12 | 13 | 14 | def populate_index(es_obj, index_name, evidence_corpus): 15 | 16 | for i, rec in enumerate(tqdm(evidence_corpus)): 17 | try: 18 | index_status = es_obj.index(index=index_name, id=i, body=rec) 19 | except: 20 | print(f'Unable to load document {i}.') 21 | 22 | n_records = es_obj.count(index=index_name)['count'] 23 | print(f'Succesfully loaded {n_records} into {index_name}') 24 | return 25 | 26 | 27 | def set_datas(args) : 28 | if not os.path.isfile("../data/preprocess_train.pkl") : 29 | make_custom_dataset("../data/preprocess_train.pkl") 30 | # train_file = load_from_disk("../data/train_dataset")["train"] 31 | # validation_file = load_from_disk("../data/train_dataset")["validation"] 32 | train_file = load_from_disk("../data/train_dataset")["train"] 33 | validation_file = load_from_disk("../data/train_dataset")["validation"] 34 | 35 | #[wiki-index, wiki-index-split-400, wiki-index-split-800, wiki-index-split-1000] 36 | if args.index_name == 'wiki-index': 37 | dataset_path = "../data/preprocess_wiki.json" 38 | elif args.index_name == 'wiki-index-split-400': 39 | dataset_path = "../data/split_wiki_400.json" 40 | elif args.index_name == 'wiki-index-split-800': 41 | dataset_path = "../data/split_wiki_800.json" 42 | elif args.index_name == 'wiki-index-split-1000': 43 | dataset_path = "../data/split_wiki_1000.json" 44 | 45 | if not os.path.isfile(dataset_path) : 46 | print(dataset_path) 47 | make_custom_dataset(dataset_path) 48 | 49 | with open(dataset_path, "r") as f: 50 | wiki = json.load(f) 51 | wiki_contexts = list(dict.fromkeys([v['text'] for v in wiki.values()])) 52 | 53 | qa_records = [{"example_id" : train_file[i]["id"], "document_title" : train_file[i]["title"], "question_text" : train_file[i]["question"], "answer" : train_file[i]["answers"]} for i in range(len(train_file))] 54 | wiki_articles = [{"document_text" : wiki_contexts[i]} for i in range(len(wiki_contexts))] 55 | return qa_records, wiki_articles 56 | 57 | 58 | def set_index_and_server(args) : 59 | es_server = Popen([args.path_to_elastic], 60 | stdout=PIPE, stderr=STDOUT, 61 | preexec_fn=lambda: os.setuid(1) # as daemon 62 | ) 63 | time.sleep(30) 64 | 65 | config = {'host':'localhost', 'port':9200} 66 | es = Elasticsearch([config]) 67 | 68 | index_config = { 69 | "settings": { 70 | "analysis": { 71 | "filter":{ 72 | 
"my_stop_filter": { 73 | "type" : "stop", 74 | "stopwords_path" : "user_dic/my_stop_dic.txt" 75 | } 76 | }, 77 | "analyzer": { 78 | "nori_analyzer": { 79 | "type": "custom", 80 | "tokenizer": "nori_tokenizer", 81 | "decompound_mode": "mixed", 82 | "filter" : ["my_stop_filter"] 83 | } 84 | } 85 | } 86 | }, 87 | "mappings": { 88 | "dynamic": "strict", 89 | "properties": { 90 | "document_text": {"type": "text", "analyzer": "nori_analyzer"} 91 | } 92 | } 93 | } 94 | 95 | print('elastic serach ping :', es.ping()) 96 | print(es.indices.create(index=args.index_name, body=index_config, ignore=400)) 97 | 98 | return es 99 | 100 | 101 | def main(args) : 102 | print('Start to Set Elastic Search') 103 | _, wiki_articles = set_datas(args) 104 | es = set_index_and_server(args) 105 | populate_index(es_obj=es, index_name=args.index_name, evidence_corpus=wiki_articles) 106 | print('Finish') 107 | 108 | 109 | if __name__ == '__main__' : 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--path_to_elastic', type=str, default='elasticsearch-7.6.2/bin/elasticsearch', help='Path to Elastic search') 112 | parser.add_argument('--index_name', type=str, default='wiki-index', help='Elastic search index name[wiki-index, wiki-index-split-400, wiki-index-split-800, wiki-index-split-1000]') 113 | 114 | args = parser.parse_args() 115 | main(args) -------------------------------------------------------------------------------- /code/mask.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import re 4 | from numpy import dot 5 | from numpy.linalg import norm 6 | from scipy.spatial import distance 7 | from scipy.stats import pearsonr 8 | from sentence_transformers import SentenceTransformer, util 9 | from datasets import Dataset 10 | 11 | def cos_sim(A, B): 12 | return dot(A, B) / (norm(A) * norm(B)) 13 | 14 | 15 | def make_word_index_dict(tokens, cls_token, sep_token): 16 | word_start = False 17 | word_index = {} 18 | word = '' 19 | index = [] 20 | 21 | for i, t in enumerate(tokens): 22 | if t == cls_token: 23 | continue 24 | elif t == sep_token: 25 | break 26 | if t.startswith('▁') and not word_start: 27 | word_start = True 28 | word += t 29 | index.append(i) 30 | if tokens[i+1].startswith('▁'): 31 | word_start = False 32 | word_index[word.replace('▁', '')] = index 33 | word = '' 34 | index = [] 35 | if not t.startswith('▁') and word_start: 36 | word += t 37 | index.append(i) 38 | if i < 383 and (tokens[i+1].startswith('▁') or tokens[i+1] == sep_token): 39 | word_start = False 40 | word_index[word.replace('▁', '')] = index 41 | word = '' 42 | index = [] 43 | 44 | return word_index 45 | 46 | 47 | def mask_to_tokens(batch, tokenizer, top_k, model): 48 | ''' 49 | Span 단위로 Random Masking을 적용하는 함수 50 | ''' 51 | mask_token = tokenizer.mask_token_id 52 | 53 | for i, input_id in enumerate(batch["input_ids"]): 54 | sep_idx = np.where(input_id.numpy() == tokenizer.sep_token_id)[0][0] 55 | pad_idx = 0 56 | if tokenizer.pad_token_id in input_id.numpy(): 57 | pad_idx = np.where(input_id.numpy() == tokenizer.pad_token_id)[0][0] 58 | tokenizer.pad_token_id 59 | question = tokenizer.decode(input_id[1:sep_idx]) # sep_idx[0][0]: 첫 번째 sep 토큰 위치 60 | answer = tokenizer.decode(input_id[batch['start_positions'][i]:batch['end_positions'][i]+1]) 61 | context = None 62 | if pad_idx == 0: 63 | context = tokenizer.decode(input_id[sep_idx+2:-1]) 64 | else: 65 | context = tokenizer.decode(input_id[sep_idx+2:pad_idx-1]) 66 | q_emb = model.encode(question) 67 
| tokens = tokenizer.convert_ids_to_tokens(input_id) 68 | 69 | word_dict = make_word_index_dict(tokens, tokenizer, answer) 70 | 71 | sim_dict = {} 72 | for word in word_dict.keys(): 73 | sim = cos_sim(q_emb, model.encode(word)) 74 | if sim > 0.35: 75 | sim_dict[sim] = word_dict[word] 76 | 77 | ordered_sim_dict = sorted(sim_dict.items(), reverse=True) 78 | tokens_to_mask = [] 79 | if len(ordered_sim_dict) < top_k: 80 | for val in ordered_sim_dict: 81 | tokens_to_mask.extend(val[1]) 82 | else: 83 | for val in ordered_sim_dict[:top_k]: 84 | tokens_to_mask.extend(val[1]) 85 | 86 | for token_idx in list(tokens_to_mask): 87 | input_id[token_idx] = mask_token 88 | 89 | batch["input_ids"][i] = input_id 90 | 91 | return batch 92 | 93 | 94 | def mask_to_random(dataset): 95 | context_list = [] 96 | question_list = [] 97 | id_list = [] 98 | answer_list = [] 99 | train_dataset = dataset["train"] 100 | 101 | for i in tqdm(range(train_dataset.num_rows)): 102 | question = train_dataset["question"] 103 | 104 | for word, pos in mecab.pos(text): 105 | # first_word = True 106 | # 첫번째 단어는 무조건 Masking(질문 중 가장 중요한 의미를 가지고 있다고 생각) 107 | # 두번째 단어부터는 20% 확률로 Masking 108 | # 하나의 단어만 Masking 109 | if pos in {"NNG", "NNP"} and (random.random() > 0.8): 110 | context_list.append(train_dataset["context"]) 111 | question_list.append(re.sub(word, "MASK", question)) # tokenizer.mask_token 112 | id_list.append(train_dataset[i]["id"]) 113 | answer_list.append(train_dataset[i]["answers"]) 114 | 115 | random.Random(2021).shuffle(context_list) 116 | random.Random(2021).shuffle(question_list) 117 | random.Random(2021).shuffle(id_list) 118 | random.Random(2021).shuffle(answer_list) 119 | 120 | dataset["train"] = Dataset.from_dict({"id" : id_list, 121 | "context": context_list, 122 | "question": question_list, 123 | "answers": answer_list}) 124 | 125 | return dataset["train"] -------------------------------------------------------------------------------- /code/trainer_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """ 16 | A subclass of `Trainer` specific to Question-Answering tasks 17 | """ 18 | 19 | from transformers import Trainer, is_datasets_available, is_torch_tpu_available 20 | from transformers.trainer_utils import PredictionOutput 21 | 22 | 23 | if is_datasets_available(): 24 | import datasets 25 | 26 | if is_torch_tpu_available(): 27 | import torch_xla.core.xla_model as xm 28 | import torch_xla.debug.metrics as met 29 | 30 | 31 | class QuestionAnsweringTrainer(Trainer): 32 | def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs): 33 | super().__init__(*args, **kwargs) 34 | self.eval_examples = eval_examples 35 | self.post_process_function = post_process_function 36 | 37 | def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None): 38 | eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset 39 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 40 | eval_examples = self.eval_examples if eval_examples is None else eval_examples 41 | 42 | # Temporarily disable metric computation, we will do it in the loop here. 43 | compute_metrics = self.compute_metrics 44 | self.compute_metrics = None 45 | try: 46 | output = self.prediction_loop( 47 | eval_dataloader, 48 | description="Evaluation", 49 | # No point gathering the predictions if there are no metrics, otherwise we defer to 50 | # self.args.prediction_loss_only 51 | prediction_loss_only=True if compute_metrics is None else None, 52 | ignore_keys=ignore_keys, 53 | ) 54 | finally: 55 | self.compute_metrics = compute_metrics 56 | 57 | # We might have removed columns from the dataset so we put them back. 58 | if isinstance(eval_dataset, datasets.Dataset): 59 | eval_dataset.set_format( 60 | type=eval_dataset.format["type"], 61 | columns=list(eval_dataset.features.keys()), 62 | ) 63 | 64 | if self.post_process_function is not None and self.compute_metrics is not None: 65 | eval_preds = self.post_process_function( 66 | eval_examples, eval_dataset, output.predictions, self.args 67 | ) 68 | metrics = self.compute_metrics(eval_preds) 69 | 70 | self.log(metrics) 71 | else: 72 | metrics = {} 73 | 74 | if self.args.tpu_metrics_debug or self.args.debug: 75 | # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 76 | xm.master_print(met.metrics_report()) 77 | 78 | self.control = self.callback_handler.on_evaluate( 79 | self.args, self.state, self.control, metrics 80 | ) 81 | return metrics 82 | 83 | def predict(self, test_dataset, test_examples, ignore_keys=None): 84 | test_dataloader = self.get_test_dataloader(test_dataset) 85 | 86 | # Temporarily disable metric computation, we will do it in the loop here. 87 | compute_metrics = self.compute_metrics 88 | self.compute_metrics = None 89 | try: 90 | output = self.prediction_loop( 91 | test_dataloader, 92 | description="Evaluation", 93 | # No point gathering the predictions if there are no metrics, otherwise we defer to 94 | # self.args.prediction_loss_only 95 | prediction_loss_only=True if compute_metrics is None else None, 96 | ignore_keys=ignore_keys, 97 | ) 98 | finally: 99 | self.compute_metrics = compute_metrics 100 | 101 | if self.post_process_function is None or self.compute_metrics is None: 102 | return output 103 | 104 | # We might have removed columns from the dataset so we put them back. 
105 | if isinstance(test_dataset, datasets.Dataset): 106 | test_dataset.set_format( 107 | type=test_dataset.format["type"], 108 | columns=list(test_dataset.features.keys()), 109 | ) 110 | 111 | predictions = self.post_process_function( 112 | test_examples, test_dataset, output.predictions, self.args 113 | ) 114 | return predictions 115 | -------------------------------------------------------------------------------- /code/data_processing.py: -------------------------------------------------------------------------------- 1 | from datasets import load_metric, load_from_disk 2 | from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer 3 | 4 | class DataProcessor(): 5 | def __init__(self, tokenizer, max_length = 384, doc_stride = 128): 6 | self.tokenizer = tokenizer 7 | self.max_length = max_length 8 | self.doc_stride = doc_stride 9 | 10 | def prepare_train_features(self, examples): 11 | tokenized_examples = self.tokenizer( 12 | examples["question"], 13 | examples["context"], 14 | truncation="only_second", 15 | max_length=self.max_length, 16 | stride=self.doc_stride, 17 | return_overflowing_tokens=True, 18 | return_offsets_mapping=True, 19 | padding="max_length", 20 | ) 21 | 22 | sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") 23 | offset_mapping = tokenized_examples.pop("offset_mapping") 24 | 25 | tokenized_examples["start_positions"] = [] 26 | tokenized_examples["end_positions"] = [] 27 | if 'question_type' in examples.keys() : 28 | tokenized_examples['question_type'] = [] 29 | 30 | for i, offsets in enumerate(offset_mapping): 31 | input_ids = tokenized_examples["input_ids"][i] 32 | sequence_ids = tokenized_examples.sequence_ids(i) 33 | cls_index = input_ids.index(self.tokenizer.cls_token_id) 34 | 35 | sample_index = sample_mapping[i] 36 | answers = examples["answers"][sample_index] 37 | 38 | if 'question_type' in examples.keys() : 39 | tokenized_examples['question_type'].append(examples['question_type'][sample_index]) 40 | 41 | if len(answers["answer_start"]) == 0: 42 | tokenized_examples["start_positions"].append(cls_index) 43 | tokenized_examples["end_positions"].append(cls_index) 44 | else: 45 | start_char = answers["answer_start"][0] 46 | end_char = start_char + len(answers["text"][0]) 47 | 48 | token_start_index = 0 49 | while sequence_ids[token_start_index] != (1): 50 | token_start_index += 1 51 | 52 | token_end_index = len(input_ids) - 1 53 | while sequence_ids[token_end_index] != (1): 54 | token_end_index -= 1 55 | 56 | if not ( 57 | offsets[token_start_index][0] <= start_char 58 | and offsets[token_end_index][1] >= end_char 59 | ): 60 | tokenized_examples["start_positions"].append(cls_index) 61 | tokenized_examples["end_positions"].append(cls_index) 62 | else: 63 | while ( 64 | token_start_index < len(offsets) 65 | and offsets[token_start_index][0] <= start_char 66 | ): 67 | token_start_index += 1 68 | tokenized_examples["start_positions"].append(token_start_index - 1) 69 | while offsets[token_end_index][1] >= end_char: 70 | token_end_index -= 1 71 | tokenized_examples["end_positions"].append(token_end_index + 1) 72 | 73 | return tokenized_examples 74 | 75 | def prepare_validation_features(self, examples): 76 | tokenized_examples = self.tokenizer( 77 | examples["question"], 78 | examples["context"], 79 | truncation="only_second", 80 | max_length=self.max_length, 81 | stride=self.doc_stride, 82 | return_overflowing_tokens=True, 83 | return_offsets_mapping=True, 84 | padding="max_length", 85 | ) 86 | sample_mapping = 
tokenized_examples.pop("overflow_to_sample_mapping") 87 | tokenized_examples["example_id"] = [] 88 | 89 | if 'question_type' in examples.keys() : 90 | tokenized_examples['question_type'] = [] 91 | 92 | for i in range(len(tokenized_examples['input_ids'])): 93 | sequence_ids = tokenized_examples.sequence_ids(i) 94 | context_index = 1 95 | sample_index = sample_mapping[i] 96 | tokenized_examples["example_id"].append(examples["id"][sample_index]) 97 | 98 | if 'question_type' in examples.keys() : 99 | tokenized_examples['question_type'].append(examples['question_type'][sample_index]) 100 | 101 | tokenized_examples["offset_mapping"][i] = [ 102 | (o if sequence_ids[k] == context_index else None) 103 | for k, o in enumerate(tokenized_examples["offset_mapping"][i]) 104 | ] 105 | 106 | return tokenized_examples 107 | 108 | def train_tokenizer(self, train_dataset, column_names): 109 | train_dataset = train_dataset.map( 110 | self.prepare_train_features, 111 | batched=True, 112 | num_proc=4, 113 | remove_columns=column_names, 114 | ) 115 | 116 | return train_dataset 117 | 118 | def val_tokenzier(self, val_dataset, column_names): 119 | val_dataset = val_dataset.map( 120 | self.prepare_validation_features, 121 | batched=True, 122 | num_proc=4, 123 | remove_columns=column_names, 124 | ) 125 | 126 | return val_dataset -------------------------------------------------------------------------------- /code/question_labeling/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import warnings 4 | 5 | import pickle 6 | import torch 7 | import random 8 | import numpy as np 9 | import pandas as pd 10 | import torch.nn.functional as F 11 | 12 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, ElectraForSequenceClassification, AdamW 13 | from torch import nn, optim 14 | from torch.utils.data import DataLoader 15 | from torch.cuda.amp import autocast, GradScaler 16 | from tqdm import tqdm 17 | 18 | from data_set import * 19 | 20 | def seed_everything(seed): 21 | random.seed(seed) 22 | os.environ['PYTHONHASHSEED'] = str(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = True 28 | 29 | def get_pickle(pickle_path): 30 | '''Custom Dataset을 Load하기 위한 함수''' 31 | f = open(pickle_path, "rb") 32 | dataset = pickle.load(f) 33 | f.close() 34 | 35 | return dataset 36 | 37 | def get_data(): 38 | tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large") 39 | ai_hub = get_pickle("../../data/ai_hub_dataset.pkl") 40 | train_token, train_label = tokenized_dataset(ai_hub["train"], tokenizer) 41 | val_token, val_label = tokenized_dataset(ai_hub["validation"], tokenizer) 42 | 43 | train_set = RE_Dataset(train_token, train_label) 44 | val_set = RE_Dataset(val_token, val_label) 45 | 46 | train_iter = DataLoader(train_set, batch_size=16, shuffle=True) 47 | val_iter = DataLoader(val_set, batch_size=16, shuffle=True) 48 | 49 | return train_iter, val_iter 50 | 51 | def get_model(): 52 | network = AutoModelForSequenceClassification.from_pretrained("xlm-roberta-large", num_labels=6, hidden_dropout_prob=0.0).to("cuda:0") 53 | optimizer = AdamW(network.parameters(), lr=5e-6) 54 | scaler = GradScaler() 55 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=10, eta_min=1e-6) 56 | criterion = nn.CrossEntropyLoss().to("cuda:0") 57 | 58 | return network, 
optimizer, scaler, scheduler, criterion 59 | 60 | def training_per_step(model, loss_fn, optimizer, scaler, input_ids, attention_mask, labels, device): 61 | '''매 step마다 학습을 하는 함수''' 62 | model.train() 63 | with autocast(): 64 | labels = labels.to(device) 65 | 66 | preds = model(input_ids.to(device), attention_mask = attention_mask.to(device))[0] 67 | loss = loss_fn(preds, labels) 68 | 69 | scaler.scale(loss).backward() 70 | scaler.step(optimizer) 71 | scaler.update() 72 | optimizer.zero_grad() 73 | 74 | return loss 75 | 76 | def validating_per_steps(epoch, model, loss_fn, test_loader, device): 77 | '''특정 step마다 검증을 하는 함수''' 78 | model.eval() 79 | 80 | loss_sum = 0 81 | sample_num = 0 82 | preds_all = [] 83 | targets_all = [] 84 | 85 | pbar = tqdm(test_loader, total=len(test_loader), position=0, leave=True) 86 | for input_ids, attention_mask, labels in pbar : 87 | labels = labels.to(device) 88 | 89 | preds = model(input_ids.to(device), attention_mask = attention_mask.to(device))[0] 90 | 91 | preds_all += [torch.argmax(preds, 1).detach().cpu().numpy()] 92 | targets_all += [labels.detach().cpu().numpy()] 93 | 94 | loss = loss_fn(preds, labels) 95 | 96 | loss_sum += loss.item()*labels.shape[0] 97 | sample_num += labels.shape[0] 98 | 99 | description = f"epoch {epoch + 1} loss: {loss_sum/sample_num:.4f}" 100 | pbar.set_description(description) 101 | 102 | preds_all = np.concatenate(preds_all) 103 | targets_all = np.concatenate(targets_all) 104 | accuracy = (preds_all == targets_all).mean() 105 | 106 | print(" test accuracy = {:.4f}".format(accuracy)) 107 | 108 | return accuracy 109 | 110 | def train(model, loss_fn, optimizer, scaler, train_loader, test_loader, scheduler, device): 111 | '''training과 validating을 진행하는 함수''' 112 | prev_acc = 0 113 | global_steps = 0 114 | for epoch in range(1): 115 | running_loss = 0 116 | sample_num = 0 117 | preds_all = [] 118 | targets_all = [] 119 | 120 | pbar = tqdm(enumerate(train_loader), total=len(train_loader), position=0, leave=True) 121 | for step, (input_ids, attention_mask, labels) in pbar: 122 | # training phase 123 | loss = training_per_step(model, loss_fn, optimizer, scaler, input_ids, attention_mask, labels, device) 124 | running_loss += loss.item()*labels.shape[0] 125 | sample_num += labels.shape[0] 126 | 127 | global_steps += 1 128 | description = f"{epoch+1}epoch {global_steps: >4d}step | loss: {running_loss/sample_num: .4f} " 129 | pbar.set_description(description) 130 | 131 | # validating phase 132 | if global_steps % 500 == 0 : 133 | with torch.no_grad(): 134 | acc = validating_per_steps(epoch, model, loss_fn, test_loader, device) 135 | if acc > prev_acc: 136 | torch.save(model, "../../output/question_model.pt") 137 | prev_acc = acc 138 | 139 | if scheduler is not None : 140 | scheduler.step() 141 | 142 | def main(): 143 | seed_everything(2021) 144 | train_iter, val_iter = get_data() 145 | network, optimizer, scaler, scheduler, criterion = get_model() 146 | train(network, criterion, optimizer, scaler, train_iter, val_iter, scheduler, "cuda:0") 147 | 148 | if __name__ == "__main__": 149 | main() 150 | -------------------------------------------------------------------------------- /code/model/QAConvModelV1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn, optim 4 | from torch.nn import functional as F 5 | from transformers import AutoTokenizer, AutoModel 6 | import math 7 | 8 | class QAConvModelV1(nn.Module): 9 | def __init__(self, model_name, 
model_config, tokenizer_name): 10 | super().__init__() 11 | self.model_name = model_name 12 | self.model_config = model_config 13 | self.sep_token_id = AutoTokenizer.from_pretrained(tokenizer_name).sep_token_id 14 | self.backbone = AutoModel.from_pretrained(model_name, config=model_config) 15 | self.query_drop_out = nn.Dropout(0.1) 16 | self.query_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 17 | self.query_calssify_layer = nn.Linear(model_config.hidden_size, 6, bias=True) 18 | self.key_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 19 | self.value_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 20 | self.conv1d_layer1 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=1) 21 | self.conv1d_layer3 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=3, padding=1) 22 | self.conv1d_layer5 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=5, padding=2) 23 | self.drop_out = nn.Dropout(0.3) 24 | self.classify_layer = nn.Linear(1024*3, 2, bias=True) 25 | 26 | 27 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=None, output_hidden_states=None, return_dict=None, question_type=None): 28 | if "xlm" in self.model_name: 29 | outputs = self.backbone( 30 | input_ids, 31 | attention_mask=attention_mask, 32 | token_type_ids=token_type_ids, 33 | head_mask=head_mask, 34 | inputs_embeds=inputs_embeds, 35 | output_attentions=output_attentions, 36 | output_hidden_states=output_hidden_states, 37 | return_dict=return_dict, 38 | ) 39 | 40 | else: 41 | outputs = self.backbone( 42 | input_ids, 43 | attention_mask=attention_mask, 44 | token_type_ids=token_type_ids, 45 | position_ids=position_ids, 46 | head_mask=head_mask, 47 | inputs_embeds=inputs_embeds, 48 | output_attentions=output_attentions, 49 | output_hidden_states=output_hidden_states, 50 | return_dict=return_dict, 51 | ) 52 | 53 | sequence_output = outputs[0] # (B * 384 * 1024) 54 | 55 | if not token_type_ids : 56 | token_type_ids = self.make_token_type_ids(input_ids) 57 | 58 | embedded_query = sequence_output * (token_type_ids.unsqueeze(dim=-1)==0) # 전체 Text 중 query에 해당하는 Embedded Vector만 남김. 59 | embedded_query = self.query_drop_out(F.relu(embedded_query)) # Activation Function 및 Dropout Layer 통과 60 | embedded_query = self.query_layer(embedded_query) # Dense Layer를 통과 시킴. (B * max_seq_length * hidden_size) 61 | embedded_query = torch.mean(embedded_query, 1, keepdim=True) # Query에 해당하는 Token Embedding을 평균냄. (B * 1 * hidden_size) 62 | query_logits = self.query_calssify_layer(embedded_query.squeeze(1)) # Query의 종류를 예측하는 Branch (B * 6) 63 | 64 | embedded_key = sequence_output * (token_type_ids.unsqueeze(dim=-1)==1) # 전체 Text 중 context에 해당하는 Embedded Vector만 남김. 65 | embedded_key = self.key_layer(embedded_key) # (B * max_seq_length * hidden_size) 66 | embedded_value = self.value_layer(sequence_output) # (B * max_seq_length * hidden_size) 67 | attention_rate = torch.matmul(embedded_key, torch.transpose(embedded_query, 1, 2)) # Context의 Value Vector와 Quetion의 Query Vector를 사용 (B * max_seq_legth * 1) 68 | attention_rate = attention_rate / math.sqrt(embedded_key.shape[-1]) # hidden size의 표준편차로 나눠줌. (B * max_seq_legth * 1) 69 | attention_rate = attention_rate / 10 # Temperature로 나눠줌. 
(B * max_seq_legth * 1) 70 | attention_rate = F.softmax(attention_rate, 1) # softmax를 통과시켜서 확률값으로 변경해, Question과 Context의 Attention Rate를 구함. (B * max_seq_legth * 1) 71 | embedded_value = embedded_value * attention_rate # Attention Rate를 활용해서 Output 값을 변경함. (B * max_seq_legth * hidden_size) 72 | 73 | conv_input = embedded_value.transpose(1, 2) # Convolution 연산을 위해 Transpose (B * hidden_size * max_seq_legth) 74 | conv_output1 = F.relu(self.conv1d_layer1(conv_input)) # Conv연산의 결과 (B * num_conv_filter * max_seq_legth) 75 | conv_output3 = F.relu(self.conv1d_layer3(conv_input)) # Conv연산의 결과 (B * num_conv_filter * max_seq_legth) 76 | conv_output5 = F.relu(self.conv1d_layer5(conv_input)) # Conv연산의 결과 (B * num_conv_filter * max_seq_legth) 77 | concat_output = torch.cat((conv_output1, conv_output3, conv_output5), dim=1) # Concatenation (B * num_conv_filter x 3 * max_seq_legth) 78 | 79 | concat_output = concat_output.transpose(1, 2) # Dense Layer에 입력을 위해 Transpose (B * max_seq_legth * num_conv_filter x 3) 80 | concat_output = self.drop_out(concat_output) # dropout 통과 81 | logits = self.classify_layer(concat_output) # Classifier Layer를 통해 최종 Logit을 얻음. (B * max_seq_legth * 2) 82 | 83 | start_logits, end_logits = logits.split(1, dim=-1) 84 | start_logits = start_logits.squeeze(-1) 85 | end_logits = end_logits.squeeze(-1) 86 | 87 | return {"start_logits" : start_logits, "end_logits" : end_logits, "hidden_states" : outputs.hidden_states, "attentions" : outputs.attentions, "query_logits" : query_logits} 88 | 89 | 90 | def make_token_type_ids(self, input_ids) : 91 | token_type_ids = [] 92 | for i, input_id in enumerate(input_ids): 93 | sep_idx = np.where(input_id.cpu().numpy() == self.sep_token_id) 94 | token_type_id = [0]*sep_idx[0][0] + [1]*(len(input_id)-sep_idx[0][0]) 95 | token_type_ids.append(token_type_id) 96 | return torch.tensor(token_type_ids).cuda() 97 | -------------------------------------------------------------------------------- /code/mk_retrieval_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import os 5 | import re 6 | import time 7 | import json 8 | import torch 9 | import pickle 10 | import argparse 11 | 12 | from konlpy.tag import Mecab 13 | from transformers import AutoTokenizer 14 | from tqdm import tqdm 15 | from tqdm.notebook import tqdm, trange 16 | from elasticsearch import Elasticsearch 17 | from subprocess import Popen, PIPE, STDOUT 18 | from torch.utils.data import DataLoader, TensorDataset 19 | from datasets import load_metric, load_from_disk, load_dataset, Features, Value, Sequence, DatasetDict, Dataset 20 | 21 | # 엘라스틱 서치 노트북 파일 (es_retrieval.ipynb 를 먼저 실행하여 index 등록후 사용해야합니다. ) 22 | def elastic_setting(index_name): 23 | config = {'host':'localhost', 'port':9200} 24 | es = Elasticsearch([config]) 25 | return es 26 | 27 | 28 | def search_es(es_obj, index_name, question_text, n_results): 29 | # search query 30 | query = { 31 | 'query': { 32 | 'match': { 33 | 'document_text': question_text 34 | } 35 | } 36 | } 37 | # n_result => 상위 몇개를 선택? 38 | res = es_obj.search(index=index_name, body=query, size=n_results) 39 | 40 | return res 41 | 42 | 43 | def elastic_retrieval(es, index_name, question_text, n_results): 44 | res = search_es(es, index_name, question_text, n_results) 45 | # 매칭된 context만 list형태로 만든다. 
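# Each hit in res['hits']['hits'] carries the stored passage under
# hit['_source']['document_text'] and its relevance score (BM25 by default)
# under hit['_score'], so the function below returns the top n_results
# passages as (text, score) pairs, already sorted by Elasticsearch from
# most to least relevant.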
46 | context_list = list((hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']) 47 | return context_list 48 | 49 | def preprocess(text): 50 | text = re.sub(r'\n', ' ', text) 51 | text = re.sub(r"\\n", " ", text) 52 | text = re.sub(r"\s+", " ", text) 53 | text = re.sub(r'#', ' ', text) 54 | text = re.sub(r"[^a-zA-Z0-9가-힣ㄱ-ㅎㅏ-ㅣぁ-ゔァ-ヴー々〆〤一-龥<>()\s\.\?!》《≪≫\'<>〈〉:‘’%,『』「」<>・\"-“”∧]", "", text) 55 | return text 56 | 57 | def mk_new_file(mode, files, top_k, es, index_name): 58 | if mode == 'test': 59 | new_files = {'id':[], 'question':[], 'top_k':[]} 60 | for file in files: 61 | question_text = file['question'] 62 | 63 | top_list = elastic_retrieval(es, index_name, question_text, top_k) 64 | top_list = [text for text, score in top_list] 65 | 66 | new_files['id'].append(file['id']) 67 | new_files['question'].append(question_text) 68 | new_files['top_k'].append(top_list) 69 | return new_files 70 | 71 | else: 72 | new_files = {'context':[], 'id':[], 'question':[], 'top_k':[], 'answer_idx':[], 'answer':[], 'start_idx':[]} 73 | for file in files: 74 | start_ids = file["answers"]["answer_start"][0] 75 | 76 | before = file["context"][:start_ids] 77 | after = file["context"][start_ids:] 78 | 79 | process_before = preprocess(before) 80 | process_after = preprocess(after) 81 | new_context = process_before + process_after 82 | 83 | start_idx = start_ids - len(before) + len(process_before) 84 | 85 | question_text = file['question'] 86 | top_list = elastic_retrieval(es, index_name, question_text, top_k) 87 | top_list = [text for text, score in top_list] 88 | 89 | if not new_context in top_list: 90 | top_list = top_list[:-1] + [new_context] 91 | answer_idx = top_k-1 92 | else: 93 | answer_idx = top_list.index(new_context) 94 | 95 | answer = file['answers']['text'][0] 96 | 97 | new_files['context'].append(new_context) 98 | new_files['id'].append(file['id']) 99 | new_files['question'].append(question_text) 100 | new_files['top_k'].append(top_list) 101 | new_files['answer_idx'].append(answer_idx) 102 | new_files['answer'].append(answer) 103 | new_files['start_idx'].append(start_idx) 104 | return new_files 105 | 106 | def save_pickle(save_path, data_set): 107 | file = open(save_path, "wb") 108 | pickle.dump(data_set, file) 109 | file.close() 110 | 111 | def get_pickle(pickle_path): 112 | f = open(pickle_path, "rb") 113 | dataset = pickle.load(f) 114 | f.close() 115 | return dataset 116 | 117 | def main(args): 118 | train_file = load_from_disk("../data/train_dataset")["train"] 119 | validation_file = load_from_disk("../data/train_dataset")["validation"] 120 | test_file = load_from_disk("../data/test_dataset")["validation"] 121 | 122 | es = elastic_setting(args.index_name) 123 | 124 | print('wait...', end='\r') 125 | new_train_file = mk_new_file('train', train_file, args.top_k, es, args.index_name) 126 | print('make train dataset!!') 127 | save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_train.pkl'), new_train_file) 128 | 129 | print('wait...', end='\r') 130 | new_valid_file = mk_new_file('valid', validation_file, args.top_k, es, args.index_name) 131 | print('make validation dataset!!') 132 | save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_valid.pkl'), new_valid_file) 133 | 134 | print('wait...', end='\r') 135 | new_test_file = mk_new_file('test', test_file, args.top_k, es, args.index_name) 136 | print('make test dataset!!') 137 | save_pickle(os.path.join(args.save_path, f'Top{args.top_k}_preprocess_test.pkl'), new_test_file) 138 | 139 | 
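# As the note at the top of this file says, the Elasticsearch index
# (default name "nori-index") must already be registered before main() runs,
# e.g. via run_elastic_search.py or the es_retrieval notebook. The sketch
# below only illustrates what such a registration could look like with the
# elasticsearch Python client; the helper name and analyzer settings are
# assumptions for illustration, not the project's actual configuration.
def register_wiki_index_sketch(es, index_name, wiki_texts):
    index_config = {
        "settings": {
            "analysis": {
                "analyzer": {
                    # assumes the analysis-nori plugin is installed
                    "nori_analyzer": {"type": "custom", "tokenizer": "nori_tokenizer"}
                }
            }
        },
        "mappings": {
            # the retrieval code above searches the "document_text" field
            "properties": {"document_text": {"type": "text", "analyzer": "nori_analyzer"}}
        },
    }
    if not es.indices.exists(index=index_name):
        es.indices.create(index=index_name, body=index_config)
    for doc_id, text in enumerate(wiki_texts):
        es.index(index=index_name, id=doc_id, body={"document_text": text})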
print('complete!!') 140 | 141 | if __name__ == '__main__': 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--top_k', type=int, default=20) 144 | parser.add_argument('--save_path', type=str, default='../data/retrieval_dataset') 145 | parser.add_argument('--index_name', type=str, default="nori-index") 146 | 147 | args = parser.parse_args() 148 | 149 | if not os.path.exists(args.save_path): 150 | os.mkdir(args.save_path) 151 | 152 | print(f'TOP K ::: {args.top_k}') 153 | print(f'SAVE PATH ::: {args.save_path}') 154 | print(f'INDEX NAME ::: {args.index_name}') 155 | 156 | main(args) 157 | -------------------------------------------------------------------------------- /code/model/QAConvModelV2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch import nn, optim 4 | from torch.nn import functional as F 5 | from transformers import AutoTokenizer, AutoModel 6 | import math 7 | 8 | class QAConvModelV2(nn.Module): 9 | def __init__(self, model_name, model_config, tokenizer_name): 10 | super().__init__() 11 | self.model_name = model_name 12 | self.model_config = model_config 13 | self.sep_token_id = AutoTokenizer.from_pretrained(tokenizer_name).sep_token_id 14 | self.backbone = AutoModel.from_pretrained(model_name, config=model_config) 15 | self.query_layer = nn.Linear(50*model_config.hidden_size, model_config.hidden_size, bias=True) 16 | self.query_calssify_layer = nn.Linear(model_config.hidden_size, 6, bias=True) 17 | self.key_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 18 | self.value_layer = nn.Linear(model_config.hidden_size, model_config.hidden_size, bias=True) 19 | self.conv1d_layer1 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=1) 20 | self.conv1d_layer3 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=3, padding=1) 21 | self.conv1d_layer5 = nn.Conv1d(model_config.hidden_size, 1024, kernel_size=5, padding=2) 22 | self.drop_out = nn.Dropout(0.3) 23 | self.classify_layer = nn.Linear(1024*3, 2, bias=True) 24 | 25 | 26 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, start_positions=None, end_positions=None, output_attentions=None, output_hidden_states=None, return_dict=None, question_type=None): 27 | if "xlm" in self.model_name: 28 | outputs = self.backbone( 29 | input_ids, 30 | attention_mask=attention_mask, 31 | token_type_ids=token_type_ids, 32 | head_mask=head_mask, 33 | inputs_embeds=inputs_embeds, 34 | output_attentions=output_attentions, 35 | output_hidden_states=output_hidden_states, 36 | return_dict=return_dict, 37 | ) 38 | 39 | else: 40 | outputs = self.backbone( 41 | input_ids, 42 | attention_mask=attention_mask, 43 | token_type_ids=token_type_ids, 44 | position_ids=position_ids, 45 | head_mask=head_mask, 46 | inputs_embeds=inputs_embeds, 47 | output_attentions=output_attentions, 48 | output_hidden_states=output_hidden_states, 49 | return_dict=return_dict, 50 | ) 51 | 52 | sequence_output = outputs[0] # (B * 384 * 1024) 53 | 54 | if not token_type_ids : 55 | token_type_ids = self.make_token_type_ids(input_ids) 56 | 57 | embedded_query = sequence_output * (token_type_ids.unsqueeze(dim=-1)==0) # 전체 Text 중 query에 해당하는 Embedded Vector만 남김. 58 | embedded_query = nn.Dropout(0.1)(F.relu(embedded_query)) # Activation Function 및 Dropout Layer 통과 59 | embedded_query = embedded_query[:, :50, :] # 질문에 해당하는 Embedding만 남김. 
(B * 50 * hidden_size) 60 | embedded_query = embedded_query.reshape((sequence_output.shape[0], 1, -1)) # Token의 Embedding을 Hidden Dim축으로 Concat함. (B * 1 * 50 x hidden_size) 61 | embedded_query = self.query_layer(embedded_query) # Dense Layer를 통과 시킴. (B * 1 * hidden_size) 62 | query_logits = F.softmax(self.query_calssify_layer(embedded_query.squeeze(1)), -1) # Query의 종류를 예측하는 Branch (B * 6) 63 | 64 | embedded_key = sequence_output * (token_type_ids.unsqueeze(dim=-1)==1) # 전체 Text 중 context에 해당하는 Embedded Vector만 남김. 65 | embedded_key = self.key_layer(embedded_key) # (B * max_seq_length * hidden_size) 66 | attention_rate = torch.matmul(embedded_key, torch.transpose(embedded_query, 1, 2)) # Context의 Value Vector와 Quetion의 Query Vector를 사용 (B * max_seq_legth * 1) 67 | attention_rate = attention_rate / math.sqrt(embedded_key.shape[-1]) # hidden size의 표준편차로 나눠줌. (B * max_seq_legth * 1) 68 | attention_rate = attention_rate / 10 # Temperature로 나눠줌. (B * max_seq_legth * 1) 69 | attention_rate = F.softmax(attention_rate, 1) # softmax를 통과시켜서 확률값으로 변경해, Question과 Context의 Attention Rate를 구함. (B * max_seq_legth * 1) 70 | 71 | embedded_value = self.value_layer(sequence_output) # (B * max_seq_length * hidden_size) 72 | embedded_value = embedded_value * attention_rate # Attention Rate를 활용해서 Output 값을 변경함. (B * max_seq_legth * hidden_size) 73 | 74 | conv_input = embedded_value.transpose(1, 2) # Convolution 연산을 위해 Transpose (B * hidden_size * max_seq_legth) 75 | conv_output1 = F.relu(self.conv1d_layer1(conv_input)) # Conv연산의 결과 (B * num_conv_filter * max_seq_legth) 76 | conv_output3 = F.relu(self.conv1d_layer3(conv_input)) # Conv연산의 결과 (B * num_conv_filter * max_seq_legth) 77 | conv_output5 = F.relu(self.conv1d_layer5(conv_input)) # Conv연산의 결과 (B * num_conv_filter * max_seq_legth) 78 | concat_output = torch.cat((conv_output1, conv_output3, conv_output5), dim=1) # Concatenation (B * num_conv_filter x 3 * max_seq_legth) 79 | 80 | concat_output = concat_output.transpose(1, 2) # Dense Layer에 입력을 위해 Transpose (B * max_seq_legth * num_conv_filter x 3) 81 | concat_output = self.drop_out(concat_output) # dropout 통과 82 | logits = self.classify_layer(concat_output) # Classifier Layer를 통해 최종 Logit을 얻음. 
(B * max_seq_legth * 2) 83 | 84 | start_logits, end_logits = logits.split(1, dim=-1) 85 | start_logits = start_logits.squeeze(-1) 86 | end_logits = end_logits.squeeze(-1) 87 | 88 | return {"start_logits" : start_logits, "end_logits" : end_logits, "hidden_states" : outputs.hidden_states, "attentions" : outputs.attentions, "query_logits" : query_logits} 89 | 90 | 91 | def make_token_type_ids(self, input_ids) : 92 | token_type_ids = [] 93 | for i, input_id in enumerate(input_ids): 94 | sep_idx = np.where(input_id.cpu().numpy() == self.sep_token_id) 95 | token_type_id = [0]*sep_idx[0][0] + [1]*(len(input_id)-sep_idx[0][0]) 96 | token_type_ids.append(token_type_id) 97 | return torch.tensor(token_type_ids).cuda() 98 | -------------------------------------------------------------------------------- /code/retrieval_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import kss 5 | import torch 6 | import random 7 | from tqdm.notebook import tqdm 8 | 9 | class TrainRetrievalDataset(torch.utils.data.Dataset): 10 | def __init__(self, dataset, p_tokenizer, q_tokenizer): 11 | self.dataset = dataset 12 | self.p_tokenizer = p_tokenizer 13 | self.q_tokenizer = q_tokenizer 14 | 15 | def __getitem__(self, idx): 16 | question = self.dataset['question'][idx] 17 | top_context = self.dataset['top_k'][idx] 18 | target = self.dataset['answer_idx'][idx] 19 | 20 | p_seqs = self.p_tokenizer(top_context, 21 | padding='max_length', 22 | truncation=True, 23 | return_tensors='pt') 24 | q_seqs = self.q_tokenizer(question, 25 | padding='max_length', 26 | truncation=True, 27 | return_tensors='pt') 28 | 29 | p_input_ids = p_seqs['input_ids'] 30 | p_attention_mask = p_seqs['attention_mask'] 31 | p_token_type_ids = p_seqs['token_type_ids'] 32 | 33 | q_input_ids = q_seqs['input_ids'] 34 | q_attention_mask = q_seqs['attention_mask'] 35 | q_token_type_ids = q_seqs['token_type_ids'] 36 | 37 | p_input_ids_list = torch.Tensor([]) 38 | p_attention_mask_list = torch.Tensor([]) 39 | p_token_type_ids_list = torch.Tensor([]) 40 | for i in range(len(p_attention_mask)): 41 | str_idx, end_idx = self._select_range(p_attention_mask[i]) 42 | 43 | p_input_ids_tmp = torch.cat([torch.Tensor([101]), p_input_ids[i][str_idx:end_idx], torch.Tensor([102])]).int().long() 44 | p_attention_mask_tmp = p_attention_mask[i][str_idx-1:end_idx+1].int().long() 45 | p_token_type_ids_tmp = p_token_type_ids[i][str_idx-1:end_idx+1].int().long() 46 | 47 | p_input_ids_list = torch.cat([p_input_ids_list, p_input_ids_tmp.unsqueeze(0)]).int().long() 48 | p_attention_mask_list = torch.cat([p_attention_mask_list, p_attention_mask_tmp.unsqueeze(0)]).int().long() 49 | p_token_type_ids_list = torch.cat([p_token_type_ids_list, p_token_type_ids_tmp.unsqueeze(0)]).int().long() 50 | 51 | return p_input_ids_list, p_attention_mask_list, p_token_type_ids_list, q_input_ids, q_attention_mask, q_token_type_ids, target 52 | 53 | def __len__(self): 54 | return len(self.dataset['question']) 55 | 56 | def _select_range(self, attention_mask): 57 | sent_len = len([i for i in attention_mask if i != 0]) 58 | if sent_len <= 512: 59 | return 1, 511 60 | else: 61 | start_idx = random.randint(1, sent_len-511) 62 | end_idx = start_idx + 510 63 | return start_idx, end_idx 64 | 65 | class ValidRetrievalDataset(torch.utils.data.Dataset): 66 | def __init__(self, dataset, p_tokenizer, q_tokenizer): 67 | self.dataset = dataset 68 | self.p_tokenizer = p_tokenizer 69 | self.q_tokenizer = q_tokenizer 70 | 71 | 
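# Unlike the training dataset above, which keeps a single (randomly
# positioned) 512-token window per passage, validation keeps every window:
# _select_range (defined below) slides a 510-token window with a 255-token
# stride over the tokenized passage, and [CLS]/[SEP] are re-attached to each
# slice in __getitem__, giving 512-token inputs.
#
# Worked example for a passage of 1200 non-padding tokens (a sketch of the
# arithmetic in _select_range, for illustration only):
#   num = 1200 // 255 = 4, res = 1200 % 255 = 180
#   windows: (1, 511), (256, 766), (511, 1021), (689, 1199)
# The last window is anchored to the passage end because res > 0, so every
# token is covered, and `target` becomes the list of all window indices that
# came from the gold passage.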
def __getitem__(self, idx): 72 | question = self.dataset['question'][idx] 73 | top_context = self.dataset['top_k'][idx] 74 | target = self.dataset['answer_idx'][idx] 75 | 76 | p_seqs = self.p_tokenizer(top_context, 77 | padding='max_length', 78 | truncation=True, 79 | return_tensors='pt') 80 | q_seqs = self.q_tokenizer(question, 81 | padding='max_length', 82 | truncation=True, 83 | return_tensors='pt') 84 | 85 | p_input_ids = p_seqs['input_ids'] 86 | p_attention_mask = p_seqs['attention_mask'] 87 | p_token_type_ids = p_seqs['token_type_ids'] 88 | 89 | q_input_ids = q_seqs['input_ids'] 90 | q_attention_mask = q_seqs['attention_mask'] 91 | q_token_type_ids = q_seqs['token_type_ids'] 92 | 93 | p_input_ids_list = torch.Tensor([]) 94 | p_attention_mask_list = torch.Tensor([]) 95 | p_token_type_ids_list = torch.Tensor([]) 96 | 97 | top_k_id = [] 98 | for i in range(len(p_attention_mask)): 99 | ids_list = self._select_range(p_attention_mask[i]) 100 | if i == target: 101 | target = list(range(len(p_input_ids_list), len(p_input_ids_list)+len(ids_list))) 102 | for str_idx, end_idx in ids_list: 103 | p_input_ids_tmp = torch.cat([torch.Tensor([101]), p_input_ids[i][str_idx:end_idx], torch.Tensor([102])]).int().long() 104 | p_attention_mask_tmp = p_attention_mask[i][str_idx-1:end_idx+1].int().long() 105 | p_token_type_ids_tmp = p_token_type_ids[i][str_idx-1:end_idx+1].int().long() 106 | 107 | p_input_ids_list = torch.cat([p_input_ids_list, p_input_ids_tmp.unsqueeze(0)]).int().long() 108 | p_attention_mask_list = torch.cat([p_attention_mask_list, p_attention_mask_tmp.unsqueeze(0)]).int().long() 109 | p_token_type_ids_list = torch.cat([p_token_type_ids_list, p_token_type_ids_tmp.unsqueeze(0)]).int().long() 110 | top_k_id.append(i) 111 | 112 | return p_input_ids_list, p_attention_mask_list, p_token_type_ids_list, q_input_ids, q_attention_mask, q_token_type_ids, target, top_k_id 113 | 114 | def __len__(self): 115 | return len(self.dataset['question']) 116 | 117 | def _select_range(self, attention_mask): 118 | sent_len = len([i for i in attention_mask if i != 0]) 119 | if sent_len <= 512: 120 | return [(1,511)] 121 | else: 122 | num = sent_len // 255 123 | res = sent_len % 255 124 | if res == 0: 125 | num -= 1 126 | ids_list = [] 127 | for n in range(num): 128 | if res > 0 and n == num-1: 129 | end_idx = sent_len-1 130 | start_idx = end_idx - 510 131 | else: 132 | start_idx = n*255+1 133 | end_idx = start_idx + 510 134 | ids_list.append((start_idx, end_idx)) 135 | return ids_list -------------------------------------------------------------------------------- /code/retrieval_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import logging 5 | import os 6 | import sys 7 | import time 8 | import json 9 | 10 | import torch 11 | import random 12 | import numpy as np 13 | import pandas as pd 14 | import os 15 | import pickle 16 | 17 | from scipy.special import log_softmax 18 | 19 | from tqdm import tqdm 20 | from datasets import load_metric, load_from_disk, Sequence, Value, Features, Dataset, DatasetDict 21 | from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer, AdamW 22 | from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer 23 | from torch.utils.data import DataLoader, TensorDataset 24 | from konlpy.tag import Mecab 25 | from konlpy.tag import Kkma 26 | from konlpy.tag import Hannanum 27 | from elasticsearch import 
Elasticsearch 28 | from subprocess import Popen, PIPE, STDOUT 29 | 30 | from transformers import ( 31 | DataCollatorWithPadding, 32 | EvalPrediction, 33 | HfArgumentParser, 34 | TrainingArguments, 35 | set_seed, 36 | ) 37 | from retrieval_model import Encoder 38 | 39 | # 엘라스틱 서치 노트북 파일 (es_retrieval.ipynb 를 먼저 실행하여 index 등록후 사용해야합니다. ) 40 | def elastic_setting(index_name): 41 | config = {'host':'localhost', 'port':9200} 42 | es = Elasticsearch([config]) 43 | return es 44 | 45 | 46 | def search_es(es_obj, index_name, question_text, n_results): 47 | # search query 48 | query = { 49 | 'query': { 50 | 'match': { 51 | 'document_text': question_text 52 | } 53 | } 54 | } 55 | # n_result => 상위 몇개를 선택? 56 | res = es_obj.search(index=index_name, body=query, size=n_results) 57 | 58 | return res 59 | 60 | def elastic_retrieval(es, index_name, question_text, n_results): 61 | res = search_es(es, index_name, question_text, n_results) 62 | # 매칭된 context만 list형태로 만든다. 63 | context_list = list((hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']) 64 | return context_list 65 | 66 | 67 | def get_pickle(pickle_path): 68 | '''Custom Dataset을 Load하기 위한 함수''' 69 | f = open(pickle_path, "rb") 70 | dataset = pickle.load(f) 71 | f.close() 72 | return dataset 73 | 74 | 75 | def save_pickle(save_path, data_set): 76 | file = open(save_path, "wb") 77 | pickle.dump(data_set, file) 78 | file.close() 79 | return None 80 | 81 | 82 | def select_range(attention_mask): 83 | sent_len = len([i for i in attention_mask if i != 0]) 84 | if sent_len <= 512: 85 | return [(1,511)] 86 | else: 87 | num = sent_len // 255 88 | res = sent_len % 255 89 | if res == 0: 90 | num -= 1 91 | ids_list = [] 92 | for n in range(num): 93 | if res > 0 and n == num-1: 94 | end_idx = sent_len-1 95 | start_idx = end_idx - 510 96 | else: 97 | start_idx = n*255+1 98 | end_idx = start_idx + 510 99 | ids_list.append((start_idx, end_idx)) 100 | return ids_list 101 | 102 | def inference(args, p_encoder, q_encoder, question_texts, p_tokenizer, q_tokenizer): 103 | es = elastic_setting(args.index_name) 104 | 105 | p_encoder.eval() 106 | q_encoder.eval() 107 | 108 | dense_retrieval_result = {} 109 | for question_text in tqdm(question_texts): 110 | es_context_list = elastic_retrieval(es, args.index_name, question_text, args.es_top_k = 70) 111 | es_context_list = [context for context, score in es_context_list] 112 | 113 | p_seqs = p_tokenizer(es_context_list, 114 | padding='max_length', 115 | truncation=True, 116 | return_tensors='pt') 117 | 118 | q_seqs = q_tokenizer(question_text, 119 | padding='max_length', 120 | truncation=True, 121 | return_tensors='pt') 122 | 123 | p_input_ids = p_seqs['input_ids'] 124 | p_attention_mask = p_seqs['attention_mask'] 125 | p_token_type_ids = p_seqs['token_type_ids'] 126 | 127 | q_input_ids = q_seqs['input_ids'] 128 | q_attention_mask = q_seqs['attention_mask'] 129 | q_token_type_ids = q_seqs['token_type_ids'] 130 | 131 | p_input_ids_list = torch.Tensor([]) 132 | p_attention_mask_list = torch.Tensor([]) 133 | p_token_type_ids_list = torch.Tensor([]) 134 | 135 | top_k_id = [] 136 | for i in range(len(p_attention_mask)): 137 | ids_list = select_range(p_attention_mask[i]) 138 | for str_idx, end_idx in ids_list: 139 | p_input_ids_tmp = torch.cat([torch.Tensor([101]), p_input_ids[i][str_idx:end_idx], torch.Tensor([102])]).int().long() 140 | p_attention_mask_tmp = p_attention_mask[i][str_idx-1:end_idx+1].int().long() 141 | p_token_type_ids_tmp = p_token_type_ids[i][str_idx-1:end_idx+1].int().long() 142 | 143 | 
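# Note: the hardcoded 101 / 102 above are the [CLS] and [SEP] token ids of
# the BERT WordPiece vocabulary used by the default
# `bert-base-multilingual-cased` checkpoint. If --model_checkpoint is changed
# to a model with a different vocabulary, these should instead come from
# p_tokenizer.cls_token_id / p_tokenizer.sep_token_id.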
p_input_ids_list = torch.cat([p_input_ids_list, p_input_ids_tmp.unsqueeze(0)]).int().long() 144 | p_attention_mask_list = torch.cat([p_attention_mask_list, p_attention_mask_tmp.unsqueeze(0)]).int().long() 145 | p_token_type_ids_list = torch.cat([p_token_type_ids_list, p_token_type_ids_tmp.unsqueeze(0)]).int().long() 146 | top_k_id.append(i) 147 | 148 | batch_num = 20 149 | if len(p_input_ids_list) % batch_num == 0: 150 | num = len(p_input_ids_list) // batch_num 151 | else: 152 | num = len(p_input_ids_list) // batch_num + 1 153 | 154 | p_output_list = [] 155 | for i in range(num): 156 | p_input_ids = p_input_ids_list[i*batch_num:(i+1)*batch_num] 157 | p_attention_mask = p_attention_mask_list[i*batch_num:(i+1)*batch_num] 158 | p_token_type_ids =p_token_type_ids_list[i*batch_num:(i+1)*batch_num] 159 | 160 | batch = (p_input_ids, p_attention_mask, p_token_type_ids) 161 | p_inputs = {'input_ids' : batch[0].to('cuda'), 162 | 'attention_mask' : batch[1].to('cuda'), 163 | 'token_type_ids': batch[2].to('cuda')} 164 | p_outputs = p_encoder(**p_inputs).cpu() 165 | p_output_list.extend(p_outputs.cpu().tolist()) 166 | p_output_list = np.array(p_output_list) 167 | 168 | batch = (q_input_ids, q_attention_mask, q_token_type_ids) 169 | q_inputs = {'input_ids' : batch[0].to('cuda'), 170 | 'attention_mask' : batch[1].to('cuda'), 171 | 'token_type_ids': batch[2].to('cuda')} 172 | q_outputs = q_encoder(**q_inputs).cpu() # (N, E) 173 | q_outputs = np.array(q_outputs.cpu().tolist()) 174 | 175 | sim_scores = np.matmul(q_outputs, np.transpose(p_output_list, [1, 0])) # (1, E) x (E, N) = (1, N) 176 | sim_scores = log_softmax(sim_scores, axis=1) 177 | 178 | class_0 = np.array([1 if i == 0 else 0 for idx, i in enumerate(top_k_id)]) 179 | w = np.sum(sim_scores, axis=1) * 1/np.shape(sim_scores)[1] 180 | sim_scores = sim_scores[0] - w[0]*class_0 181 | 182 | preds_idx = np.argsort(-1*sim_scores, axis=0) 183 | 184 | top_idx_list = [] 185 | top_k_list = [] 186 | for idx in preds_idx: 187 | top_idx = top_k_id[idx] 188 | if top_idx in top_idx_list: 189 | continue 190 | top_idx_list.append(top_idx) 191 | top_k_list.append((es_context_list[top_idx], sim_scores[idx])) 192 | dense_retrieval_result[question_text] = top_k_list[:args.dr_top_k] 193 | return dense_retrieval_result 194 | 195 | 196 | def main(args): 197 | 198 | text_data = load_from_disk('../../data/test_dataset') 199 | question_texts = text_data["validation"]["question"] 200 | 201 | p_tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint) 202 | p_tokenizer.model_max_length = 1536 203 | q_tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint) 204 | 205 | p_encoder = Encoder(args.model_checkpoint) 206 | q_encoder = Encoder(args.model_checkpoint) 207 | 208 | p_encoder.load_state_dict(torch.load(f'../retrieval_output/{args.run_name}/model/p_{args.run_name}.pt')) 209 | q_encoder.load_state_dict(torch.load(f'../retrieval_output/{args.run_name}/model/q_{args.run_name}.pt')) 210 | 211 | if torch.cuda.is_available(): 212 | p_encoder.to('cuda') 213 | q_encoder.to('cuda') 214 | print('GPU enabled') 215 | 216 | dense_retrieval_result = inference(args, p_encoder, q_encoder, question_texts, p_tokenizer, q_tokenizer) 217 | 218 | save_path = f'../data/test_ex{args.es_top_k}_dr{args.dr_top_k}_dense.pkl' 219 | save_pickle(save_path, dense_retrieval_result) 220 | print('complete !!') 221 | 222 | if __name__ == '__main__': 223 | parser = argparse.ArgumentParser() 224 | 225 | parser.add_argument('--model_checkpoint', type=str, default='bert-base-multilingual-cased') 226 
| parser.add_argument('--run_name', type=str, default='best_dense_retrieval') 227 | parser.add_argument('--es_top_k', type=int, default=70) 228 | parser.add_argument('--dr_top_k', type=int, default=70) 229 | parser.add_argument('--index_name', type=str, default="nori-index") 230 | 231 | args = parser.parse_args() 232 | 233 | if not os.path.exists(args.output_dir): 234 | os.mkdir(args.output_dir) 235 | args.output_dir = os.path.join(args.output_dir, args.run_name) 236 | 237 | print(f'Model Checkpoint ::: {args.model_checkpoint}') 238 | print(f'Run Name ::: {args.run_name}') 239 | print(f'Top k Number of Elastic Retrieval ::: {args.es_top_k}') 240 | print(f'Top k Number of Dense Retrieval ::: {args.dr_top_k}') 241 | print(f'Index Name ::: {args.index_name}') 242 | 243 | if args.es_top_k < args.dr_top_k: 244 | raise ValueError(f' Top k number of elastic retrieval must be greater than Top k number of dense retrieval >>> [ Top k number of elastic retrieval : {args.es_top_k} / Top k number of dense retrieval : {args.dr_top_k} ]') 245 | 246 | main(args) 247 | 248 | 249 | 250 | -------------------------------------------------------------------------------- /etc/my_stop_dic.txt: -------------------------------------------------------------------------------- 1 | 가 2 | 같이 3 | 같이나 4 | 같이는 5 | 같이는야 6 | 같이는커녕 7 | 같이도 8 | 같이만 9 | 같인 10 | 고 11 | 과 12 | 과는 13 | 과는커녕 14 | 과도 15 | 과를 16 | 과만 17 | 과만은 18 | 과의 19 | 까지 20 | 까지가 21 | 까지나 22 | 까지나마 23 | 까지는 24 | 까지는야 25 | 까지는커녕 26 | 까지도 27 | 까지든지 28 | 까지라고 29 | 까지라고는 30 | 까지라고만은 31 | 까지라도 32 | 까지로 33 | 까지로나 34 | 까지로나마 35 | 까지로는 36 | 까지로는야 37 | 까지로는커녕 38 | 까지로도 39 | 까지로든 40 | 까지로든지 41 | 까지로라서 42 | 까지로라야 43 | 까지로만 44 | 까지로만은 45 | 까지로서 46 | 까지로써 47 | 까지를 48 | 까지만 49 | 까지만은 50 | 까지만이라도 51 | 까지야 52 | 까지야말로 53 | 까지에 54 | 까지와 55 | 까지의 56 | 까지조차 57 | 까지조차도 58 | 까진 59 | 께옵서 60 | 께옵서는 61 | 께옵서는야 62 | 께옵서는커녕 63 | 께옵서도 64 | 께옵서만 65 | 께옵서만은 66 | 께옵서만이 67 | 께옵선 68 | 나 69 | 나마 70 | 는 71 | 는야 72 | 는커녕 73 | 니 74 | 다 75 | 다가 76 | 다가는 77 | 다가도 78 | 다간 79 | 대로 80 | 대로가 81 | 대로는 82 | 대로의 83 | 더러 84 | 더러는 85 | 더러만은 86 | 도 87 | 든 88 | 든지 89 | 라 90 | 라고 91 | 라고까지 92 | 라고까지는 93 | 라고는 94 | 라고만은 95 | 라곤 96 | 라도 97 | 라든지 98 | 라서 99 | 라야 100 | 라야만 101 | 라오 102 | 라지 103 | 라지요 104 | 랑 105 | 랑은 106 | 로고 107 | 로구나 108 | 로구려 109 | 로구먼 110 | 로군 111 | 로군요 112 | 로다 113 | 로되 114 | 로세 115 | 를 116 | 마다 117 | 마다라도 118 | 마다를 119 | 마다에게 120 | 마다의 121 | 마따나 122 | 마저 123 | 마저나마라도 124 | 마저도 125 | 마저라도 126 | 마저야 127 | 만 128 | 만도 129 | 만에 130 | 만으로 131 | 만으로는 132 | 만으로도 133 | 만으로라도 134 | 만으로써 135 | 만으론 136 | 만은 137 | 만을 138 | 만의 139 | 만이 140 | 만이라도 141 | 만치 142 | 만큼 143 | 만큼도 144 | 만큼만 145 | 만큼씩 146 | 만큼은 147 | 만큼의 148 | 만큼이나 149 | 만큼이라도 150 | 만큼이야 151 | 말고 152 | 말고는 153 | 말고도 154 | 며 155 | 밖에 156 | 밖에는 157 | 밖에도 158 | 밖엔 159 | 보고 160 | 보고는 161 | 보고도 162 | 보고만 163 | 보고만은 164 | 보고만이라도 165 | 보곤 166 | 보다 167 | 보다는 168 | 보다는야 169 | 보다도 170 | 보다만 171 | 보다야 172 | 보단 173 | 부터 174 | 부터가 175 | 부터나마 176 | 부터는 177 | 부터도 178 | 부터라도 179 | 부터를 180 | 부터만 181 | 부터만은 182 | 부터서는 183 | 부터야말로 184 | 부터의 185 | 부턴 186 | 아 187 | 야 188 | 야말로 189 | 에 190 | 에게 191 | 에게가 192 | 에게까지 193 | 에게까지는 194 | 에게까지는커녕 195 | 에게까지도 196 | 에게까지만 197 | 에게까지만은 198 | 에게나 199 | 에게는 200 | 에게는커녕 201 | 에게다 202 | 에게도 203 | 에게든 204 | 에게든지 205 | 에게라도 206 | 에게로 207 | 에게로는 208 | 에게마다 209 | 에게만 210 | 에게며 211 | 에게보다 212 | 에게보다는 213 | 에게부터 214 | 에게서 215 | 에게서가 216 | 에게서까지 217 | 에게서나 218 | 에게서는 219 | 에게서도 220 | 에게서든지 221 | 에게서라도 222 | 에게서만 223 | 에게서보다 224 | 에게서부터 225 | 에게서야 226 | 에게서와 227 | 에게서의 228 | 에게서처럼 229 | 에게선 230 | 에게야 231 | 에게와 232 | 에게의 233 | 에게처럼 234 | 
에게하고 235 | 에게하며 236 | 에겐 237 | 에까지 238 | 에까지는 239 | 에까지도 240 | 에까지든지 241 | 에까지라도 242 | 에까지만 243 | 에까지만은 244 | 에까진 245 | 에나 246 | 에는 247 | 에다 248 | 에다가 249 | 에다가는 250 | 에다간 251 | 에도 252 | 에든 253 | 에든지 254 | 에라도 255 | 에로 256 | 에로의 257 | 에를 258 | 에만 259 | 에만은 260 | 에부터 261 | 에서 262 | 에서가 263 | 에서까지 264 | 에서까지도 265 | 에서나 266 | 에서나마 267 | 에서는 268 | 에서도 269 | 에서든지 270 | 에서라도 271 | 에서만 272 | 에서만도 273 | 에서만이 274 | 에서만큼 275 | 에서만큼은 276 | 에서보다 277 | 에서부터 278 | 에서부터는 279 | 에서부터도 280 | 에서부터라도 281 | 에서부터만 282 | 에서부터만은 283 | 에서야 284 | 에서와 285 | 에서와는 286 | 에서와의 287 | 에서의 288 | 에서조차 289 | 에서처럼 290 | 에선 291 | 에야 292 | 에의 293 | 에조차도 294 | 에하며 295 | 엔 296 | 엔들 297 | 엘 298 | 엘랑 299 | 여 300 | 와 301 | 와는 302 | 와도 303 | 와라도 304 | 와를 305 | 와만 306 | 와만은 307 | 와에만 308 | 와의 309 | 와처럼 310 | 와한테 311 | 요 312 | 으로 313 | 으로가 314 | 으로까지 315 | 으로까지만은 316 | 으로나 317 | 으로나든지 318 | 으로는 319 | 으로도 320 | 으로든지 321 | 으로라도 322 | 으로랑 323 | 으로만 324 | 으로만은 325 | 으로부터 326 | 으로부터는 327 | 으로부터는커녕 328 | 으로부터도 329 | 으로부터만 330 | 으로부터만은 331 | 으로부터서는 332 | 으로부터서도 333 | 으로부터서만 334 | 으로부터의 335 | 으로서 336 | 으로서가 337 | 으로서나 338 | 으로서는 339 | 으로서도 340 | 으로서든지 341 | 으로서라도 342 | 으로서만 343 | 으로서만도 344 | 으로서만은 345 | 으로서야 346 | 으로서의 347 | 으로선 348 | 으로써 349 | 으로써나 350 | 으로써는 351 | 으로써라도 352 | 으로써만 353 | 으로써야 354 | 으로야 355 | 으로의 356 | 으론 357 | 은 358 | 은커녕 359 | 을 360 | 의 361 | 이 362 | 이고 363 | 이나 364 | 이나마 365 | 이니 366 | 이다 367 | 이든 368 | 이든지 369 | 이라 370 | 이라고 371 | 이라고는 372 | 이라고도 373 | 이라고만은 374 | 이라곤 375 | 이라는 376 | 이라도 377 | 이라든지 378 | 이라서 379 | 이라야 380 | 이라야만 381 | 이랑 382 | 이랑은 383 | 이며 384 | 이며에게 385 | 이며조차도 386 | 이야 387 | 이야말로 388 | 이여 389 | 인들 390 | 인즉 391 | 인즉슨 392 | 일랑 393 | 일랑은 394 | 조차 395 | 조차가 396 | 조차도 397 | 조차를 398 | 조차의 399 | 처럼 400 | 처럼과 401 | 처럼도 402 | 처럼만 403 | 처럼만은 404 | 처럼은 405 | 처럼이라도 406 | 처럼이야 407 | 치고 408 | 치고는 409 | 커녕 410 | 커녕은 411 | 커니와 412 | 토록 413 | 하고 414 | 하고가 415 | 하고는 416 | 하고는커녕 417 | 하고도 418 | 하고라도 419 | 하고마저 420 | 하고만 421 | 하고만은 422 | 하고야 423 | 하고에게 424 | 하고의 425 | 하고조차 426 | 하고조차도 427 | 하곤 428 | 거나 429 | 거늘 430 | 거니 431 | 거니와 432 | 거드면 433 | 거드면은 434 | 거든 435 | 거들랑 436 | 거들랑은 437 | 건 438 | 건대 439 | 건댄 440 | 건마는 441 | 건만 442 | 것다 443 | 게 444 | 게끔 445 | 게나 446 | 게나마 447 | 게는 448 | 게도 449 | 게라도 450 | 게만 451 | 게만은 452 | 게시리 453 | 게요 454 | 고 455 | 고는 456 | 고도 457 | 고만 458 | 고말고 459 | 고서 460 | 고서는 461 | 고서도 462 | 고선 463 | 고야 464 | 고요 465 | 고자 466 | 곤 467 | 관데 468 | 구나 469 | 구려 470 | 구료 471 | 구먼 472 | 군 473 | 군요 474 | 기 475 | 기까지 476 | 기까지는 477 | 기까지도 478 | 기까지만 479 | 기까지만은 480 | 기로 481 | 기로서 482 | 기로서니 483 | 기로선들 484 | 기에 485 | 긴 486 | 길 487 | 나 488 | 나니 489 | 나마 490 | 나요 491 | 나이까 492 | 나이다 493 | 냐 494 | 냐고 495 | 냐는 496 | 냐라고 497 | 냐라고도 498 | 냐라고만 499 | 냐에 500 | 네 501 | 네만 502 | 네요 503 | 노 504 | 노라 505 | 노라고 506 | 노라니 507 | 노라면 508 | 느냐 509 | 느냐고 510 | 느냐는 511 | 느냐라고 512 | 느냐라고는 513 | 느냐라고도 514 | 느냐라고만 515 | 느냐라고만은 516 | 느냐에 517 | 느뇨 518 | 느니 519 | 느니라 520 | 느니만 521 | 느라 522 | 느라고 523 | 는 524 | 는가 525 | 는가라고 526 | 는가라는 527 | 는가를 528 | 는가에 529 | 는걸 530 | 는고 531 | 는구나 532 | 는구려 533 | 는구료 534 | 는구먼 535 | 는군 536 | 는다 537 | 는다거나 538 | 는다고 539 | 는다고는 540 | 는다는 541 | 는다는데 542 | 는다니 543 | 는다니까 544 | 는다든지 545 | 는다마는 546 | 는다만 547 | 는다만은 548 | 는다며 549 | 는다며는 550 | 는다면 551 | 는다면서 552 | 는다면은 553 | 는단다 554 | 는담 555 | 는답니까 556 | 는답니다 557 | 는답디까 558 | 는답디다 559 | 는답시고 560 | 는대 561 | 는대로 562 | 는대서 563 | 는대서야 564 | 는대야 565 | 는대요 566 | 는데 567 | 는데는 568 | 는데다 569 | 는데도 570 | 는데서 571 | 는만큼 572 | 는만큼만 573 | 는바 574 | 는지 575 | 는지가 576 | 는지고 577 | 는지는 578 | 는지도 579 | 는지라 580 | 는지를 581 | 는지만 582 | 는지에 583 | 는지요 584 | 는지의 585 | 니 586 | 니까 587 | 니까는 
588 | 니깐 589 | 니라 590 | 니만치 591 | 니만큼 592 | 다 593 | 다가 594 | 다가는 595 | 다가도 596 | 다간 597 | 다거나 598 | 다고 599 | 다고까지 600 | 다고까지는 601 | 다고까지도 602 | 다고까지라도 603 | 다고까지만 604 | 다고까지만은 605 | 다고는 606 | 다고도 607 | 다고만 608 | 다고만은 609 | 다고요 610 | 다곤 611 | 다느냐 612 | 다느니 613 | 다는 614 | 다는데 615 | 다니 616 | 다마는 617 | 다마다 618 | 다만 619 | 다만은 620 | 다며 621 | 다며는 622 | 다면 623 | 다면서 624 | 다면서도 625 | 다면야 626 | 다면은 627 | 다시피 628 | 다오 629 | 단 630 | 단다 631 | 담 632 | 답시고 633 | 더구나 634 | 더구려 635 | 더구먼 636 | 더군 637 | 더군요 638 | 더냐 639 | 더니 640 | 더니라 641 | 더니마는 642 | 더니만 643 | 더라 644 | 더라도 645 | 더라며는 646 | 더라면 647 | 더란 648 | 더면 649 | 던 650 | 던가 651 | 던가요 652 | 던걸 653 | 던걸요 654 | 던고 655 | 던데 656 | 던데다 657 | 던데요 658 | 던들 659 | 던지 660 | 데 661 | 데도 662 | 데요 663 | 도록 664 | 도록까지 665 | 도록까지도 666 | 도록까지만 667 | 도록까지만요 668 | 도록까지만은 669 | 되 670 | 든 671 | 든지 672 | 듯 673 | 듯이 674 | 디 675 | 라 676 | 라고 677 | 라고까지 678 | 라고까지는 679 | 라고까지도 680 | 라고까지만 681 | 라고까지만은 682 | 라고는 683 | 라고도 684 | 라고만 685 | 라고만은 686 | 라곤 687 | 라느니 688 | 라는 689 | 라는데 690 | 라는데도 691 | 라는데요 692 | 라니 693 | 라니까 694 | 라니까요 695 | 라도 696 | 라든지 697 | 라며 698 | 라면 699 | 라면서 700 | 라면서까지 701 | 라면서까지도 702 | 라면서도 703 | 라면서요 704 | 란 705 | 란다 706 | 란다고 707 | 람 708 | 랍니까 709 | 랍니다 710 | 랍디까 711 | 랍디다 712 | 랍시고 713 | 래 714 | 래도 715 | 랴 716 | 랴마는 717 | 러 718 | 러니 719 | 러니라 720 | 러니이까 721 | 러니이다 722 | 러만 723 | 러만은 724 | 러이까 725 | 러이다 726 | 런가 727 | 런들 728 | 려 729 | 려거든 730 | 려고 731 | 려고까지 732 | 려고까지도 733 | 려고까지만 734 | 려고까지만은 735 | 려고는 736 | 려고도 737 | 려고만 738 | 려고만은 739 | 려고요 740 | 려기에 741 | 려나 742 | 려네 743 | 려느냐 744 | 려는 745 | 려는가 746 | 려는데 747 | 려는데요 748 | 려는지 749 | 려니 750 | 려니까 751 | 려니와 752 | 려다 753 | 려다가 754 | 려다가는 755 | 려다가도 756 | 려다가요 757 | 려더니 758 | 려더니만 759 | 려던 760 | 려면 761 | 려면요 762 | 려면은 763 | 려무나 764 | 련 765 | 련마는 766 | 련만 767 | 렴 768 | 렷다 769 | 리 770 | 리까 771 | 리니 772 | 리니라 773 | 리다 774 | 리라 775 | 리라는 776 | 리란 777 | 리로다 778 | 리만치 779 | 리만큼 780 | 리요 781 | 리요마는 782 | 마 783 | 매 784 | 며 785 | 며는 786 | 면 787 | 면서 788 | 면서까지 789 | 면서까지도 790 | 면서까지만은 791 | 면서도 792 | 면서부터 793 | 면서부터는 794 | 면요 795 | 면은 796 | 므로 797 | 사 798 | 사오이다 799 | 사옵니까 800 | 사옵니다 801 | 사옵디까 802 | 사옵디다 803 | 사외다 804 | 세 805 | 세요 806 | 소 807 | 소서 808 | 소이다 809 | 쇠다 810 | 습니까 811 | 습니다 812 | 습니다마는 813 | 습니다만 814 | 습디까 815 | 습디다 816 | 습디다마는 817 | 습디다만 818 | 아 819 | 아다 820 | 아다가 821 | 아도 822 | 아라 823 | 아서 824 | 아서까지 825 | 아서는 826 | 아서도 827 | 아서만 828 | 아서요 829 | 아선 830 | 아야 831 | 아야만 832 | 아요 833 | 어 834 | 어다 835 | 어다가 836 | 어도 837 | 어라 838 | 어서 839 | 어서까지 840 | 어서는 841 | 어서도 842 | 어서만 843 | 어서만은 844 | 어선 845 | 어야 846 | 어야만 847 | 어야지 848 | 어야지만 849 | 어요 850 | 어지이다 851 | 언정 852 | 엇다 853 | 오 854 | 오리까 855 | 오리까마는 856 | 오리까만 857 | 오리다 858 | 오이다 859 | 올습니다 860 | 올습니다마는 861 | 올습니다만 862 | 올시다 863 | 옵나이까 864 | 옵나이다 865 | 옵니까 866 | 옵니다 867 | 옵니다만 868 | 옵디까 869 | 옵디다 870 | 외다 871 | 요 872 | 으나 873 | 으나마 874 | 으냐 875 | 으냐고 876 | 으니 877 | 으니까 878 | 으니까는 879 | 으니깐 880 | 으니라 881 | 으니만치 882 | 으니만큼 883 | 으라 884 | 으라고 885 | 으라고까지 886 | 으라고까지는 887 | 으라고까지도 888 | 으라고까지만은 889 | 으라고는 890 | 으라고도 891 | 으라고만 892 | 으라고만은 893 | 으라고요 894 | 으라느니 895 | 으라는 896 | 으라니 897 | 으라니까 898 | 으라든지 899 | 으라며 900 | 으라면 901 | 으라면서 902 | 으라면은 903 | 으란 904 | 으람 905 | 으랍니까 906 | 으랍니다 907 | 으래 908 | 으래서 909 | 으래서야 910 | 으래야 911 | 으래요 912 | 으랴 913 | 으랴마는 914 | 으러 915 | 으러까지 916 | 으러까지도 917 | 으려 918 | 으려거든 919 | 으려고 920 | 으려고까지 921 | 으려고까지는 922 | 으려고까지도 923 | 으려고까지만 924 | 으려고까지만은 925 | 으려고는 926 | 으려고도 927 | 으려고만 928 | 으려고만은 929 | 으려고요 930 | 으려기에 931 | 으려나 932 | 으려느냐 933 | 으려느냐는 934 | 으려는 935 | 으려는가 936 | 으려는데 937 | 으려는데도 938 | 으려는데요 939 | 으려는지 940 | 으려니 941 | 
으려니까 942 | 으려니와 943 | 으려다 944 | 으려다가 945 | 으려다가는 946 | 으려다가요 947 | 으려다간 948 | 으려더니 949 | 으려면 950 | 으려면야 951 | 으려면은 952 | 으려무나 953 | 으려서야 954 | 으려오 955 | 으련 956 | 으련다 957 | 으련마는 958 | 으련만 959 | 으련만은 960 | 으렴 961 | 으렵니까 962 | 으렵니다 963 | 으렷다 964 | 으리 965 | 으리까 966 | 으리니 967 | 으리니라 968 | 으리다 969 | 으리라 970 | 으리로다 971 | 으리만치 972 | 으리만큼 973 | 으리요 974 | 으마 975 | 으매 976 | 으며 977 | 으면 978 | 으면서 979 | 으면서까지 980 | 으면서까지도 981 | 으면서까지만 982 | 으면서까지만은 983 | 으면서는 984 | 으면서도 985 | 으면서부터 986 | 으면서부터까지 987 | 으면서부터까지도 988 | 으면서부터는 989 | 으면서요 990 | 으면요 991 | 으면은 992 | 으므로 993 | 으세요 994 | 으셔요 995 | 으소서 996 | 으시어요 997 | 으오 998 | 으오리까 999 | 으오리다 1000 | 으오이다 1001 | 으옵니까 1002 | 으옵니다 1003 | 으옵니다만 1004 | 으옵디까 1005 | 으옵디다 1006 | 으외다 1007 | 으이 1008 | 은 1009 | 은가 1010 | 은가를 1011 | 은가에 1012 | 은가에도 1013 | 은가에만 1014 | 은가요 1015 | 은걸 1016 | 은걸요 1017 | 은고 1018 | 은다고 1019 | 은다고까지 1020 | 은다고까지도 1021 | 은다고는 1022 | 은다는 1023 | 은다는데 1024 | 은다니 1025 | 은다니까 1026 | 은다든지 1027 | 은다마는 1028 | 은다면 1029 | 은다면서 1030 | 은다면서도 1031 | 은다면요 1032 | 은다면은 1033 | 은단다 1034 | 은담 1035 | 은답니까 1036 | 은답니다 1037 | 은답디까 1038 | 은답디다 1039 | 은답시고 1040 | 은대 1041 | 은대서 1042 | 은대서야 1043 | 은대야 1044 | 은대요 1045 | 은데 1046 | 은데는 1047 | 은데다 1048 | 은데도 1049 | 은데도요 1050 | 은데서 1051 | 은들 1052 | 은만큼 1053 | 은만큼도 1054 | 은만큼만은 1055 | 은만큼은 1056 | 은바 1057 | 은즉 1058 | 은즉슨 1059 | 은지 1060 | 은지가 1061 | 은지고 1062 | 은지는 1063 | 은지도 1064 | 은지라 1065 | 은지라도 1066 | 은지를 1067 | 은지만 1068 | 은지만은 1069 | 은지요 1070 | 을 1071 | 을거나 1072 | 을거냐 1073 | 을거다 1074 | 을거야 1075 | 을거지요 1076 | 을걸 1077 | 을까 1078 | 을까마는 1079 | 을까봐 1080 | 을까요 1081 | 을께 1082 | 을께요 1083 | 을꼬 1084 | 을는지 1085 | 을는지요 1086 | 을라 1087 | 을라고 1088 | 을라고까지 1089 | 을라고까지도 1090 | 을라고까지만 1091 | 을라고는 1092 | 을라고도 1093 | 을라고만 1094 | 을라고만은 1095 | 을라고요 1096 | 을라요 1097 | 을라치면 1098 | 을락 1099 | 을래 1100 | 을래도 1101 | 을래요 1102 | 을러니 1103 | 을러라 1104 | 을런가 1105 | 을런고 1106 | 을레 1107 | 을레라 1108 | 을만한 1109 | 을망정 1110 | 을밖에 1111 | 을밖에요 1112 | 을뿐더러 1113 | 을새 1114 | 을세라 1115 | 을세말이지 1116 | 을소냐 1117 | 을수록 1118 | 을쏘냐 1119 | 을이만큼 1120 | 을작이면 1121 | 을지 1122 | 을지가 1123 | 을지나 1124 | 을지니 1125 | 을지니라 1126 | 을지도 1127 | 을지라 1128 | 을지라도 1129 | 을지어다 1130 | 을지언정 1131 | 을지요 1132 | 을진대 1133 | 을진댄 1134 | 을진저 1135 | 을테다 1136 | 을텐데 1137 | 음 1138 | 음세 1139 | 음에도 1140 | 음에랴 1141 | 읍쇼 1142 | 읍시다 1143 | 읍시다요 1144 | 읍시오 1145 | 자 1146 | 자고 1147 | 자고까지 1148 | 자고까지는 1149 | 자고까지라도 1150 | 자고는 1151 | 자고도 1152 | 자고만 1153 | 자고만은 1154 | 자꾸나 1155 | 자는 1156 | 자마자 1157 | 자면 1158 | 자면요 1159 | 잔 1160 | 잘 1161 | 지 1162 | 지는 1163 | 지도 1164 | 지를 1165 | 지마는 1166 | 지만 1167 | 지요 1168 | 진 1169 | 질 1170 |  -------------------------------------------------------------------------------- /code/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import pickle 5 | import kss 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from elasticsearch import Elasticsearch 9 | from torch.utils.data import DataLoader, TensorDataset 10 | from datasets import load_metric, load_from_disk, load_dataset, Features, Value, Sequence, DatasetDict, Dataset 11 | from sentence_transformers import SentenceTransformer, util 12 | from data_processing import * 13 | from mask import mask_to_tokens 14 | 15 | 16 | def save_pickle(save_path, data_set): 17 | file = open(save_path, "wb") 18 | pickle.dump(data_set, file) 19 | file.close() 20 | return None 21 | 22 | 23 | def get_pickle(pickle_path): 24 | f = open(pickle_path, "rb") 25 | dataset = pickle.load(f) 26 | f.close() 27 | return dataset 28 | 29 | 30 | def save_data(data_path, new_wiki): 31 
| with open(data_path, 'w', encoding='utf-8') as make_file: 32 | json.dump(new_wiki, make_file, indent="\t", ensure_ascii=False) 33 | 34 | 35 | def passage_split_400(text): 36 | num = len(text) // 400 37 | count = 1 38 | split_datas = kss.split_sentences(text) 39 | data_list = [] 40 | data = "" 41 | for split_data in split_datas: 42 | if abs(len(data) - 400) > abs(len(data) + len(split_data) - 400) and count < num: 43 | if len(data) == 0: 44 | data += split_data 45 | else: 46 | data += (" " + split_data) 47 | elif count < num: 48 | data_list.append(data) 49 | count += 1 50 | data = "" 51 | data += split_data 52 | else: 53 | data += split_data 54 | 55 | data_list.append(data) 56 | return data_list, len(data_list) 57 | 58 | 59 | def passage_split(text): 60 | length = len(text) // 2 61 | split_datas = kss.split_sentences(text) 62 | data_1 = "" 63 | data_2 = "" 64 | for split_data in split_datas: 65 | if abs(len(data_1) - length) > abs(len(data_1) + len(split_data) - length): 66 | if len(data_1) == 0: 67 | data_1 += split_data 68 | else: 69 | data_1 += (" " + split_data) 70 | else: 71 | if len(data_2) == 0: 72 | data_2 += split_data 73 | else: 74 | data_2 += (" " + split_data) 75 | 76 | return data_1, data_2 77 | 78 | 79 | def preprocess(text): 80 | text = re.sub(r'\n', ' ', text) 81 | text = re.sub(r"\\n", " ", text) 82 | text = re.sub(r"\s+", " ", text) 83 | text = re.sub(r'#', ' ', text) 84 | text = re.sub(r"[^a-zA-Z0-9가-힣ㄱ-ㅎㅏ-ㅣぁ-ゔァ-ヴー々〆〤一-龥<>()\s\.\?!》《≪≫\'<>〈〉:‘’%,『』「」<>・\"-“”∧]", "", text) 85 | return text 86 | 87 | 88 | def run_preprocess(data_dict): 89 | context = data_dict["context"] 90 | start_ids = data_dict["answers"]["answer_start"][0] 91 | before = data_dict["context"][:start_ids] 92 | after = data_dict["context"][start_ids:] 93 | process_before = preprocess(before) 94 | process_after = preprocess(after) 95 | process_data = process_before + process_after 96 | ids_move = len(before) - len(process_before) 97 | data_dict["context"] = process_data 98 | data_dict["answers"]["answer_start"][0] = start_ids - ids_move 99 | return data_dict 100 | 101 | 102 | def run_preprocess_to_wiki(data_dict): 103 | context = data_dict["text"] 104 | process_data = preprocess(context) 105 | data_dict["text"] = process_data 106 | return data_dict 107 | 108 | 109 | def search_es(es_obj, index_name, question_text, n_results): 110 | query = { 111 | 'query': { 112 | 'match': { 113 | 'document_text': question_text 114 | } 115 | } 116 | } 117 | res = es_obj.search(index=index_name, body=query, size=n_results) 118 | return res 119 | 120 | 121 | def make_custom_dataset(dataset_path) : 122 | if not (os.path.isdir("../data/train_dataset") or 123 | os.path.isdir("../data/wikipedia_documents.json")) : 124 | raise Exception ("Set the original data path to '../data'") 125 | 126 | train_f = Features({'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None), 127 | 'context': Value(dtype='string', id=None), 128 | 'id': Value(dtype='string', id=None), 129 | 'question': Value(dtype='string', id=None)}) 130 | 131 | if not os.path.isfile("../data/preprocess_wiki.json") : 132 | with open("../data/wikipedia_documents.json", "r") as f: 133 | wiki = json.load(f) 134 | new_wiki = dict() 135 | for ids in range(len(wiki)): 136 | new_wiki[str(ids)] = run_preprocess_to(wiki[str(ids)]) 137 | with open('../data/preprocess_wiki.json', 'w', encoding='utf-8') as make_file: 138 | json.dump(new_wiki, make_file, indent="\t", ensure_ascii=False) 139 | 140 | if not 
os.path.isfile("/opt/ml/input/data/preprocess_train.pkl"): 141 | train_dataset = load_from_disk("../data/train_dataset")['train'] 142 | val_dataset = load_from_disk("../data/train_dataset")['validation'] 143 | 144 | new_train_data, new_val_data = [], [] 145 | for data in train_dataset: 146 | new_data = run_preprocess(data) 147 | new_train_data.append(new_data) 148 | for data in val_dataset: 149 | new_data = run_preprocess(data) 150 | new_val_data.append(new_data) 151 | 152 | train_df = pd.DataFrame(new_train_data) 153 | val_df = pd.DataFrame(new_val_data) 154 | dataset = DatasetDict({'train': Dataset.from_pandas(train_df, features=train_f), 155 | 'validation': Dataset.from_pandas(val_df, features=train_f)}) 156 | save_pickle(dataset_path, dataset) 157 | 158 | if 'preprocess' in dataset_path: 159 | return dataset 160 | 161 | if 'squad' in dataset_path : 162 | train_data = get_pickle("../data/preprocess_train.pkl")["train"] 163 | val_data = get_pickle("../data/preprocess_train.pkl")["validation"] 164 | korquad_data = load_dataset("squad_kor_v1")["train"] 165 | 166 | df_train_data = pd.DataFrame(train_data) 167 | df_val_data = pd.DataFrame(val_data) 168 | df_korquad_data = pd.DataFrame(korquad_data, columns=['answers', 'context', 'id', 'question']) 169 | df_total_train = pd.concat([df_train_data, df_korquad_data]) 170 | 171 | dataset = DatasetDict({'train': Dataset.from_pandas(df_total_train, features=train_f), 172 | 'validation': Dataset.from_pandas(df_val_data, features=train_f)}) 173 | save_pickle("../data/korquad_train.pkl", dataset) 174 | return train_dataset 175 | 176 | if 'concat' in dataset_path : 177 | base_dataset = get_pickle("../data/preprocess_train.pkl") 178 | train_dataset, val_dataset = base_dataset["train"], base_dataset["validation"] 179 | 180 | train_data = [{"id" : train_dataset[i]["id"], "question" : train_dataset[i]["question"], 181 | "answers" : train_dataset[i]["answers"], "context" : train_dataset[i]["context"]} 182 | for i in range(len(train_dataset))] 183 | val_data = [{"id" : val_dataset[i]["id"], "question" : val_dataset[i]["question"], 184 | "answers" : val_dataset[i]["answers"], "context" : val_dataset[i]["context"]} 185 | for i in range(len(val_dataset))] 186 | 187 | config = {'host':'localhost', 'port':9200} 188 | es = Elasticsearch([config]) 189 | 190 | k = 5 # k : how many contexts to concatenate 191 | for idx, train in enumerate(train_data): 192 | res = search_es(es, "wiki-index", question["question"], k) 193 | context_list = [(hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']] 194 | contexts = train["context"] 195 | count = 0 196 | for context in context_list: 197 | # if same context already exists, don't concatenate 198 | if train["context"] == context[0]: 199 | continue 200 | contexts += " " + context[0] 201 | count += 1 202 | if count == (k-1): 203 | break 204 | train_data[idx]["context"] = contexts 205 | 206 | for idx, val in enumerate(val_data): 207 | res = search_es(es, "wiki-index", question["question"], k) 208 | context_list = [(hit['_source']['document_text'], hit['_score']) for hit in res['hits']['hits']] 209 | contexts = val["context"] 210 | count = 0 211 | for context in context_list: 212 | if val["context"] == context[0]: 213 | continue 214 | contexts += " " + context[0] 215 | count += 1 216 | if count == (k-1): 217 | break 218 | val_data[idx]["context"] = contexts 219 | 220 | train_df = pd.DataFrame(train_data) 221 | val_df = pd.DataFrame(val_data) 222 | dataset = DatasetDict({'train': 
Dataset.from_pandas(train_df, features=train_f), 223 | 'validation': Dataset.from_pandas(val_df, features=train_f)}) 224 | save_pickle(dataset_path, dataset) 225 | return dataset 226 | 227 | if "split_wiki_400" in dataset_path: 228 | with open("/opt/ml/input/data/preprocess_wiki.json", "r") as f: 229 | wiki = json.load(f) 230 | new_wiki = dict() 231 | for i in tqdm(range(len(wiki))): 232 | if len(wiki[str(i)]["text"]) < 800: 233 | new_wiki[str(i)] = wiki[str(i)] 234 | continue 235 | data_list, count = passage_split_400(wiki[str(i)]["text"]) 236 | for j in range(count): 237 | new_wiki[str(i) + f"_{j}"] = {"text" : data_list[j], "corpus_source" : wiki[str(i)]["corpus_source"], 238 | "url" : wiki[str(i)]["url"], "domain" : wiki[str(i)]["domain"], 239 | "title" : wiki[str(i)]["title"], "author" : wiki[str(i)]["author"], 240 | "html" : wiki[str(i)]["html"],"document_id" : wiki[str(i)]["document_id"]} 241 | 242 | save_data("../data/wiki-index-split-400.json", new_wiki) 243 | 244 | if "split_wiki" in dataset_path and dataset_path != "split_wiki_400": 245 | with open("/opt/ml/input/data/preprocess_wiki.json", "r") as f: 246 | wiki = json.load(f) 247 | 248 | limit = 0 249 | if "800" in dataset_path: 250 | limit = 800 251 | if "1000" in dataset_path: 252 | limit = 1000 253 | 254 | new_wiki = dict() 255 | for i in tqdm(range(len(wiki))): 256 | if len(wiki[str(i)]["text"]) < limit: 257 | new_wiki[str(i)] = wiki[str(i)] 258 | continue 259 | data_1, data_2 = passage_split(wiki[str(i)]["text"]) 260 | new_wiki[str(i) + f"_1"] = {"text" : data_1, "corpus_source" : wiki[str(i)]["corpus_source"], "url" : wiki[str(i)]["url"], 261 | "domain" : wiki[str(i)]["domain"], "title" : wiki[str(i)]["title"], "author" : wiki[str(i)]["author"], 262 | "html" : wiki[str(i)]["html"], "document_id" : wiki[str(i)]["document_id"]} 263 | new_wiki[str(i) + f"_2"] = {"text" : data_2, "corpus_source" : wiki[str(i)]["corpus_source"], "url" : wiki[str(i)]["url"], 264 | "domain" : wiki[str(i)]["domain"], "title" : wiki[str(i)]["title"], 265 | "author" : wiki[str(i)]["author"], "html" : wiki[str(i)]["html"], "document_id" : wiki[str(i)]["document_id"]} 266 | 267 | save_data(f"../data/split_wiki_{limit}.json") 268 | 269 | 270 | def make_mask_dataset(dataset_path, k, tokenizer): 271 | base_dataset, opt = None, None 272 | if 'default' in dataset_path: 273 | base_dataset = get_pickle("../data/preprocess_train.pkl") 274 | if 'concat' in dataset_path: 275 | base_dataset = get_pickle("../data/concat_train.pkl") 276 | k = int(re.findall("\d", dataset_path)[0]) 277 | 278 | data_processor = DataProcessor(tokenizer) 279 | train_dataset, val_dataset = base_dataset['train'], base_dataset['val'] 280 | column_names = train_dataset.column_names 281 | train_dataset = data_processor.train_tokenizer(train_dataset, column_names) 282 | val_dataset = data_processor.val_tokenizer(val_dataset, column_names) 283 | 284 | model = SentenceTransformer('sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens') 285 | 286 | mask_dataset = mask_to_tokens(train_dataset, tokenizer, k, model) 287 | 288 | dataset = DatasetDict({'train': mask_dataset, 289 | 'validation': val_dataset}) 290 | 291 | save_pickle(dataset_path, dataset) 292 | return dataset 293 | 294 | 295 | 296 | 297 | -------------------------------------------------------------------------------- /code/retrieval_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import os 8 | import 
copy 9 | import time 10 | import json 11 | import random 12 | import pickle 13 | import argparse 14 | import numpy as np 15 | 16 | from konlpy.tag import Mecab 17 | from tqdm import tqdm, trange 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | 22 | from torch import nn, optim 23 | from torch.cuda.amp import autocast, GradScaler 24 | from torch.utils.tensorboard import SummaryWriter 25 | from torch.utils.data import DataLoader, RandomSampler, TensorDataset 26 | 27 | from datasets import load_dataset, load_from_disk 28 | from transformers import (AutoTokenizer, 29 | AdamW, 30 | TrainingArguments, 31 | get_linear_schedule_with_warmup, 32 | set_seed) 33 | 34 | from retrieval_model import Encoder 35 | from retrieval_dataset import TrainRetrievalDataset, ValidRetrievalDataset 36 | 37 | 38 | # In[2]: 39 | 40 | 41 | def seed_everything(seed): 42 | random.seed(seed) 43 | os.environ['PYTHONHASHSEED'] = str(seed) 44 | np.random.seed(seed) 45 | torch.manual_seed(seed) 46 | torch.cuda.manual_seed(seed) 47 | torch.backends.cudnn.deterministic = True 48 | torch.backends.cudnn.benchmark = True 49 | set_seed(seed) 50 | 51 | def get_pickle(pickle_path): 52 | f = open(pickle_path, "rb") 53 | dataset = pickle.load(f) 54 | f.close() 55 | return dataset 56 | 57 | def one_step_train(args, batch_list, p_encoder, q_encoder, criterion, scaler): 58 | p_input_ids = batch_list[0] 59 | p_attention_mask = batch_list[1] 60 | p_token_type_ids = batch_list[2] 61 | q_input_ids = batch_list[3] 62 | q_attention_mask = batch_list[4] 63 | q_token_type_ids = batch_list[5] 64 | targets_batch = batch_list[6] 65 | 66 | batch_loss, batch_acc = 0, 0 67 | for i in range(args.per_device_train_batch_size): 68 | batch = (p_input_ids[i], 69 | p_attention_mask[i], 70 | p_token_type_ids[i], 71 | q_input_ids[i], 72 | q_attention_mask[i], 73 | q_token_type_ids[i]) 74 | 75 | targets = torch.tensor([targets_batch[i]]).long() 76 | batch = tuple(t.to('cuda') for t in batch) 77 | p_inputs = {'input_ids' : batch[0], 78 | 'attention_mask' : batch[1], 79 | 'token_type_ids': batch[2]} 80 | 81 | q_inputs = {'input_ids' : batch[3], 82 | 'attention_mask' : batch[4], 83 | 'token_type_ids': batch[5]} 84 | 85 | p_outputs = p_encoder(**p_inputs) # (20, E) 86 | q_outputs = q_encoder(**q_inputs) # (1, E) 87 | 88 | # Calculate similarity score & loss 89 | sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1)) # (1, E) x (E, N) = (1, 20) 90 | # target : position of positive samples = diagonal element 91 | if torch.cuda.is_available(): 92 | targets = targets.to('cuda') 93 | sim_scores = F.log_softmax(sim_scores, dim=1) 94 | _, preds = torch.max(sim_scores, 1) 95 | 96 | loss = criterion(sim_scores, targets) 97 | scaler.scale(loss).backward() 98 | 99 | batch_loss += loss.cpu().item() 100 | batch_acc += torch.sum(preds.cpu() == targets.cpu()) 101 | return p_encoder, q_encoder, batch_loss, batch_acc 102 | 103 | def training(args, epoch, train_dataloader, p_encoder, q_encoder, criterion, scaler, optimizer, scheduler, logger): 104 | ## train 105 | epoch_iterator = tqdm(train_dataloader, desc="train Iteration") 106 | p_encoder.to('cuda').train() 107 | q_encoder.to('cuda').train() 108 | 109 | running_loss, running_acc, num_cnt = 0, 0, 0 110 | with torch.set_grad_enabled(True): 111 | for step, batch_list in enumerate(epoch_iterator): 112 | p_encoder, q_encoder, batch_loss, batch_acc = one_step_train(args, 113 | batch_list, 114 | p_encoder, 115 | q_encoder, 116 | criterion, 117 | scaler) 118 | running_loss += 
batch_loss/args.per_device_train_batch_size 119 | running_acc += batch_acc/args.per_device_train_batch_size 120 | num_cnt += 1 121 | 122 | if (step+1) % args.gradient_accumulation_steps == 0: 123 | log_step = epoch*len(epoch_iterator) + step 124 | scaler.step(optimizer) 125 | scaler.update() 126 | if scheduler is not None: 127 | scheduler.step() 128 | optimizer.zero_grad() 129 | p_encoder.zero_grad() 130 | q_encoder.zero_grad() 131 | 132 | logger.add_scalar(f"Train/loss", batch_loss/args.per_device_train_batch_size, log_step) 133 | logger.add_scalar(f"Train/accuracy", batch_acc/args.per_device_train_batch_size*100, log_step) 134 | 135 | epoch_loss = float(running_loss / num_cnt) 136 | epoch_acc = float((running_acc.double() / num_cnt).cpu()*100) 137 | print(f'global step-{log_step} | Loss: {epoch_loss:.4f} Accuracy: {epoch_acc:.2f}') 138 | return p_encoder, q_encoder, scaler, optimizer, scheduler 139 | 140 | def validation(args, epoch, valid_dataloader, p_encoder, q_encoder, logger, best_acc, run_name): 141 | ## valid 142 | epoch_iterator = tqdm(valid_dataloader, desc="valid Iteration") 143 | p_encoder.to('cuda').eval() 144 | q_encoder.to('cuda').eval() 145 | 146 | running_loss, running_acc, num_cnt = 0, 0, 0 147 | for step, batch in enumerate(epoch_iterator): 148 | with torch.set_grad_enabled(False): 149 | batch = tuple(t.squeeze(0) if i < 6 else t for i, t in enumerate(batch)) 150 | 151 | targets, top_k_id = batch[-2], batch[-1] 152 | if torch.cuda.is_available(): 153 | batch = tuple(t.to('cuda') for t in batch[:-2]) 154 | 155 | p_inputs = {'input_ids' : batch[0], 156 | 'attention_mask' : batch[1], 157 | 'token_type_ids': batch[2]} 158 | 159 | q_inputs = {'input_ids' : batch[3], 160 | 'attention_mask' : batch[4], 161 | 'token_type_ids': batch[5]} 162 | 163 | p_outputs = p_encoder(**p_inputs) # (N, E) 164 | q_outputs = q_encoder(**q_inputs) # (1, E) 165 | 166 | # Calculate similarity score & loss 167 | sim_scores = torch.matmul(q_outputs, torch.transpose(p_outputs, 0, 1)) # (1, E) x (E, N) = (1, N) 168 | sim_scores = F.log_softmax(sim_scores, dim=1) 169 | 170 | class_0 = torch.Tensor([1 if i.item() == 0 else 0 for idx, i in enumerate(top_k_id)]) 171 | w = (torch.sum(sim_scores, dim=1)*1/sim_scores.size()[1]).item() 172 | sim_scores -= w*class_0.unsqueeze(0).cuda() 173 | 174 | _, preds = torch.max(sim_scores, 1) 175 | if preds.item() in targets: 176 | running_acc += 1 177 | num_cnt += 1 178 | 179 | epoch_acc = float((running_acc / num_cnt)*100) 180 | logger.add_scalar(f"Val/accuracy", epoch_acc, epoch) 181 | print(f'Epoch-{epoch} | Accuracy: {epoch_acc:.2f}') 182 | 183 | if epoch_acc > best_acc: 184 | best_idx = epoch 185 | best_acc = epoch_acc 186 | 187 | save_path = os.path.join(args.output_dir, 'model') 188 | if not os.path.exists(save_path): 189 | os.mkdir(save_path) 190 | torch.save(p_encoder.cpu().state_dict(), os.path.join(save_path, f'p_{run_name}.pt')) 191 | torch.save(q_encoder.cpu().state_dict(), os.path.join(save_path, f'q_{run_name}.pt')) 192 | print(f'\t==> best model saved - {best_idx} / Accuracy: {best_acc:.2f}') 193 | return best_acc 194 | 195 | def train(args, p_encoder, q_encoder, train_dataloader, valid_dataloader, criterion, scaler, optimizer, scheduler, logger, run_name): 196 | # Start training! 
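# The epoch loop below alternates training() and validation() and saves the encoder
# pair whenever the validation accuracy improves.  As a rough illustration of the
# in-batch scoring done inside one_step_train above (a sketch only, not part of the
# original script; assumes a hypothetical top-k of 3 candidate passages per question):
#     q = q_encoder(**q_inputs)                 # (1, E) question embedding
#     p = p_encoder(**p_inputs)                 # (3, E) passage embeddings
#     sim = torch.matmul(q, p.T)                # (1, 3) dot-product similarities
#     loss = criterion(F.log_softmax(sim, dim=1), torch.tensor([0]).to('cuda'))
# where the target index is the position of the gold passage among the candidates.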
197 | best_acc = 0.0 198 | 199 | train_iterator = trange(int(args.num_train_epochs), desc='Epoch') 200 | for epoch in train_iterator: 201 | optimizer.zero_grad() 202 | p_encoder.zero_grad() 203 | q_encoder.zero_grad() 204 | torch.cuda.empty_cache() 205 | 206 | p_encoder, q_encoder, scaler, optimizer, scheduler = training(args, epoch, train_dataloader, p_encoder, q_encoder, criterion, scaler, optimizer, scheduler, logger) 207 | best_acc = validation(args, epoch, valid_dataloader, p_encoder, q_encoder, logger, best_acc, run_name) 208 | return p_encoder, q_encoder 209 | 210 | def main(args): 211 | seed_everything(seed=args.seed) 212 | 213 | p_tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint) 214 | p_tokenizer.model_max_length = 1536 215 | q_tokenizer = AutoTokenizer.from_pretrained(args.model_checkpoint) 216 | 217 | training_dataset = get_pickle(f"../data/retrieval_dataset/Top{args.top_k}_preprocess_train.pkl") 218 | validation_dataset = get_pickle(f"../data/retrieval_dataset/Top{args.top_k}_preprocess_valid.pkl") 219 | 220 | train_dataset = TrainRetrievalDataset(training_dataset, p_tokenizer, q_tokenizer) 221 | valid_dataset = ValidRetrievalDataset(validation_dataset, p_tokenizer, q_tokenizer) 222 | 223 | p_encoder = Encoder(args.model_checkpoint) 224 | q_encoder = Encoder(args.model_checkpoint) 225 | 226 | if torch.cuda.is_available(): 227 | p_encoder.to('cuda') 228 | q_encoder.to('cuda') 229 | print('GPU enabled') 230 | 231 | training_args = TrainingArguments(output_dir=args.output_dir, 232 | evaluation_strategy='epoch', 233 | learning_rate=args.learning_rate, 234 | per_device_train_batch_size=16, 235 | per_device_eval_batch_size=1, 236 | gradient_accumulation_steps=args.gradient_accumulation_steps, 237 | num_train_epochs=args.epoch, 238 | weight_decay=0.01) 239 | 240 | # Dataloader 241 | train_sampler = RandomSampler(train_dataset) 242 | train_dataloader = DataLoader(train_dataset, 243 | sampler=train_sampler, 244 | batch_size=training_args.per_device_train_batch_size) 245 | 246 | valid_sampler = RandomSampler(valid_dataset) 247 | valid_dataloader = DataLoader(valid_dataset, 248 | sampler=valid_sampler, 249 | batch_size=training_args.per_device_eval_batch_size) 250 | 251 | # Optimizer 252 | no_decay = ['bias', 'LayerNorm.weight'] 253 | optimizer_grouped_parameters = [{'params': [p for n, p in p_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': training_args.weight_decay}, 254 | {'params': [p for n, p in p_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, 255 | {'params': [p for n, p in q_encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': training_args.weight_decay}, 256 | {'params': [p for n, p in q_encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}, 257 | ] 258 | optimizer = AdamW(optimizer_grouped_parameters, lr=training_args.learning_rate) 259 | scaler = GradScaler() 260 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=10, eta_min=1e-6) 261 | criterion = nn.NLLLoss() 262 | 263 | # -- logging 264 | log_dir = os.path.join(training_args.output_dir) 265 | if not os.path.exists(log_dir): 266 | os.mkdir(log_dir) 267 | else: 268 | raise NameError(f'Already Exists Directory >>> [ Path : {log_dir} ]') 269 | logger = SummaryWriter(log_dir=log_dir) 270 | 271 | p_encoder, q_encoder = train(training_args, p_encoder, q_encoder, train_dataloader, valid_dataloader, criterion, scaler, optimizer, scheduler, logger, 
args.run_name) 272 | print('complete !!') 273 | 274 | if __name__ == '__main__': 275 | parser = argparse.ArgumentParser() 276 | 277 | parser.add_argument('--output_dir', type=str, default='../retrieval_output/') 278 | parser.add_argument('--model_checkpoint', type=str, default='bert-base-multilingual-cased') 279 | parser.add_argument('--seed', type=int, default=2021) 280 | parser.add_argument('--epoch', type=int, default=10) 281 | parser.add_argument('--learning_rate', type=float, default=1e-5) 282 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1) 283 | parser.add_argument('--top_k', type=int, default=20) 284 | parser.add_argument('--run_name', type=str, default='best_dense_retrieval') 285 | 286 | args = parser.parse_args() 287 | 288 | if not os.path.exists(args.output_dir): 289 | os.mkdir(args.output_dir) 290 | args.output_dir = os.path.join(args.output_dir, args.run_name) 291 | 292 | print(f'Output Dir ::: {args.output_dir}') 293 | print(f'Model Checkpoint ::: {args.model_checkpoint}') 294 | print(f'Seed ::: {args.seed}') 295 | print(f'Epoch ::: {args.epoch}') 296 | print(f'Learning rate ::: {args.learning_rate}') 297 | print(f'Gradient Accumulation Steps ::: {args.gradient_accumulation_steps}') 298 | print(f'Dataset K Number ::: {args.top_k}') 299 | print(f'Run Name ::: {args.run_name}') 300 | 301 | main(args) -------------------------------------------------------------------------------- /code/inference.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import time 5 | import json 6 | 7 | import torch 8 | import random 9 | import numpy as np 10 | import pandas as pd 11 | import os 12 | import pickle 13 | 14 | from tqdm import tqdm 15 | from datasets import load_metric, load_from_disk, Sequence, Value, Features, Dataset, DatasetDict 16 | from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoder, DPRQuestionEncoderTokenizer, AdamW 17 | from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer 18 | from torch.utils.data import DataLoader, TensorDataset 19 | from konlpy.tag import Mecab 20 | from konlpy.tag import Kkma 21 | from konlpy.tag import Hannanum 22 | from sentence_transformers import SentenceTransformer 23 | import kss 24 | 25 | from transformers import ( 26 | DataCollatorWithPadding, 27 | EvalPrediction, 28 | HfArgumentParser, 29 | TrainingArguments, 30 | set_seed, 31 | ) 32 | 33 | from elasticsearch_retrieval import * 34 | from data_processing import DataProcessor 35 | from utils_qa import postprocess_qa_predictions, check_no_error, tokenize, cos_sim 36 | from trainer_qa import QuestionAnsweringTrainer 37 | from arguments import ( 38 | ModelArguments, 39 | DataTrainingArguments, 40 | ) 41 | 42 | def get_pickle(pickle_path): 43 | '''Custom Dataset을 Load하기 위한 함수''' 44 | f = open(pickle_path, "rb") 45 | dataset = pickle.load(f) 46 | f.close() 47 | return dataset 48 | 49 | def get_config(): 50 | """ 51 | get config 52 | 53 | Returns: 54 | model_args: model arguments 55 | data_args: data arguments 56 | training_args: training arguments 57 | """ 58 | parser = HfArgumentParser( 59 | (ModelArguments, DataTrainingArguments, TrainingArguments) 60 | ) 61 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 62 | 63 | return model_args, data_args, training_args 64 | 65 | def fix_seed(seed): 66 | """ 67 | fix_seed 68 | 69 | Args: 70 | seed (int): seed number 71 | """ 72 | random.seed(seed) 73 | 
os.environ['PYTHONHASHSEED'] = str(seed) 74 | np.random.seed(seed) 75 | torch.manual_seed(seed) 76 | torch.cuda.manual_seed(seed) 77 | torch.backends.cudnn.deterministic = True 78 | torch.backends.cudnn.benchmark = True 79 | set_seed(seed) 80 | 81 | 82 | def get_model(model_args, training_args): 83 | """ 84 | get model 85 | 86 | Args: 87 | model_args : model arguments 88 | training_args : training arguments 89 | 90 | Returns: 91 | tokenizer, model 92 | """ 93 | tokenizer = AutoTokenizer.from_pretrained( 94 | model_args.tokenizer_name, 95 | use_fast=True 96 | ) 97 | model = torch.load(model_args.model_name_or_path) 98 | 99 | return tokenizer, model 100 | 101 | 102 | def run_elasticsearch(text_data, concat_num, model_args, is_sentence_trainformer): 103 | """ 104 | run elasticsearch and filter sentences 105 | 106 | Args: 107 | text_data 108 | concat_num: number of texts to import from elasticsearch 109 | is_sentence_trainformer: whether sentence trainformer is used or not 110 | 111 | Returns: 112 | datasets: test data 113 | scores: elasticsearch scores 114 | """ 115 | # elastic setting & load index 116 | es, index_name = elastic_setting(model_args.retrieval_elastic_index) 117 | # load sentence transformer model 118 | if is_sentence_trainformer: 119 | model = SentenceTransformer('sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens') 120 | question_texts = text_data["validation"]["question"] 121 | total = [] 122 | scores = [] 123 | 124 | pbar = tqdm(enumerate(question_texts), total=len(question_texts), position=0, leave=True) 125 | for step, question_text in pbar: 126 | # concat_num만큼 context 검색 127 | context_list = elastic_retrieval(es, index_name, question_text, concat_num) 128 | score = [] 129 | concat_context = [] 130 | 131 | if is_sentence_trainformer: 132 | # question embedding 133 | question_embedding = model.encode(question_text) 134 | # use sentence transformer 135 | for i in range(len(context_list)): 136 | temp_context = [] 137 | # separate context by sentence 138 | for sent in kss.split_sentences(context_list[i][0]): 139 | # question embedding과 sentence embedding의 cosine similarity 계산 140 | # -0.2 보다 높은 sentence만 append 141 | if cos_sim(question_embedding, model.encode(sent)) > -0.2: 142 | temp_context.append(sent) 143 | 144 | concat_context.append(" ".join(temp_context)) 145 | else: 146 | # not use sentence transformer 147 | for i in range(len(context_list)): 148 | concat_context.append(context_list[i][0]) 149 | 150 | tmp = { 151 | "question" : question_text, 152 | "id" : text_data["validation"]["id"][step], 153 | "context" : " ".join(concat_context) if is_sentence_trainformer else " ".join(concat_context) 154 | } 155 | 156 | score.append(context_list[0][1]) 157 | total.append(tmp) 158 | scores.append(score) 159 | 160 | df = pd.DataFrame(total) 161 | f = Features({'context': Value(dtype='string', id=None), 162 | 'id': Value(dtype='string', id=None), 163 | 'question': Value(dtype='string', id=None)}) 164 | datasets = DatasetDict({'validation': Dataset.from_pandas(df, features=f)}) 165 | 166 | return datasets, scores 167 | 168 | def run_concat_dense_retrival(text_data, concat_num): 169 | test_data = get_pickle("../data/test_ex70_dr70_dense.pkl") 170 | question_texts = text_data["validation"]["question"] 171 | total = [] 172 | scores = [] 173 | 174 | pbar = tqdm(enumerate(question_texts), total=len(question_texts), position=0, leave=True) 175 | for step, question_text in pbar: 176 | context_list = test_data[question_text][:concat_num] 177 | score = [] 178 | 
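# Only difference from the Elasticsearch path above: the top-k passages returned by
# the dense retriever are concatenated as-is below, without sentence-level filtering.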
concat_context = "" 179 | # 유일하게 다른 부분 : context list를 concat 시켜주는 부분 180 | for i in range(len(context_list)): 181 | if i == 0 : 182 | concat_context += context_list[i][0] 183 | else: 184 | concat_context += " " + context_list[i][0] 185 | 186 | tmp = { 187 | "question" : question_text, 188 | "id" : text_data["validation"]["id"][step], 189 | "context" : concat_context 190 | } 191 | 192 | score.append(context_list[0][1]) 193 | total.append(tmp) 194 | scores.append(score) 195 | 196 | df = pd.DataFrame(total) 197 | f = Features({'context': Value(dtype='string', id=None), 198 | 'id': Value(dtype='string', id=None), 199 | 'question': Value(dtype='string', id=None)}) 200 | datasets = DatasetDict({'validation': Dataset.from_pandas(df, features=f)}) 201 | 202 | return datasets, scores 203 | 204 | def get_data(model_args, training_args, tokenizer, text_data_path = "../data/test_dataset"): # 경로 변경 ../data/test_dataset 205 | """ 206 | get data 207 | 208 | Args: 209 | model_args: model arguments 210 | training_args: training arguments 211 | tokenizer: tokenizer 212 | text_data_path: Defaults to "../data/test_dataset" 213 | 214 | Returns: 215 | text_data, val_iter, val_dataset, scores 216 | """ 217 | text_data = load_from_disk(text_data_path) 218 | 219 | # run_ lasticsearch 220 | if "elastic" in model_args.retrieval_type: 221 | is_sentence_trainformer = False 222 | if "sentence_trainformer" in model_args.retrieval_type: 223 | is_sentence_trainformer = True 224 | # number of text to concat 225 | concat_num = model_args.retrieval_elastic_num 226 | text_data, scores = run_elasticsearch(text_data, concat_num, model_args, is_sentence_trainformer) 227 | elif model_args.retrieval_type == "dense": 228 | concat_num = model_args.retrieval_elastic_num 229 | text_data, scores = run_concat_dense_retrival(text_data, concat_num) 230 | 231 | column_names = text_data["validation"].column_names 232 | 233 | data_collator = ( 234 | DataCollatorWithPadding( 235 | tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None 236 | ) 237 | ) 238 | # 데이터 tokenize(mrc 모델안에 들어 갈 수 있도록) 239 | data_processor = DataProcessor(tokenizer) 240 | val_text = text_data["validation"] 241 | val_dataset = data_processor.val_tokenzier(val_text, column_names) 242 | val_iter = DataLoader(val_dataset, collate_fn = data_collator, batch_size=1) 243 | 244 | return text_data, val_iter, val_dataset, scores 245 | 246 | 247 | def post_processing_function(features, predictions, text_data, data_args, training_args): 248 | """ 249 | post processing 250 | 251 | Args: 252 | features, predictions, text_data, data_args, training_args 253 | 254 | Returns: 255 | inference or evaluation results 256 | """ 257 | predictions = postprocess_qa_predictions( 258 | examples=text_data["validation"], 259 | features=features, 260 | predictions=predictions, 261 | max_answer_length=data_args.max_answer_length, 262 | output_dir=training_args.output_dir, 263 | ) 264 | 265 | formatted_predictions = [ 266 | {"id": k, "prediction_text": v} for k, v in predictions.items() 267 | ] 268 | if training_args.do_predict: 269 | return formatted_predictions 270 | 271 | elif training_args.do_eval: 272 | references = [ 273 | {"id": ex["id"], "answers": ex["answers"].strip()} 274 | for ex in text_data["validation"] 275 | ] 276 | return EvalPrediction(predictions=formatted_predictions, label_ids=references) 277 | 278 | 279 | def create_and_fill_np_array(start_or_end_logits, dataset, max_len): 280 | step = 0 281 | 282 | logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64) 
283 | 284 | for i, output_logit in enumerate(start_or_end_logits): 285 | batch_size = output_logit.shape[0] 286 | cols = output_logit.shape[1] 287 | 288 | if step + batch_size < len(dataset): 289 | logits_concat[step : step + batch_size, :cols] = output_logit 290 | else: 291 | logits_concat[step:, :cols] = output_logit[: len(dataset) - step] 292 | 293 | step += batch_size 294 | 295 | return logits_concat 296 | 297 | 298 | def predict(model, text_data, test_loader, test_dataset, model_args, data_args, training_args, device): 299 | """ 300 | Create prediction json using MRC model 301 | 302 | Args: 303 | model, text_data, test_loader, test_dataset, model_args, data_args, training_args, device 304 | """ 305 | 306 | metric = load_metric("squad") 307 | # xlm의 input 예외처리 308 | if "xlm" in model_args.tokenizer_name: 309 | test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids"]) 310 | else: 311 | test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"]) 312 | 313 | model.eval() 314 | 315 | all_start_logits = [] 316 | all_end_logits = [] 317 | 318 | t = time.time() 319 | # start predic 320 | pbar = tqdm(enumerate(test_loader), total=len(test_loader), position=0, leave=True) 321 | for step, batch in pbar: 322 | batch = batch.to(device) 323 | outputs = model(**batch) 324 | 325 | if model_args.use_custom_model: 326 | start_logits = outputs["start_logits"] 327 | end_logits = outputs["end_logits"] 328 | else: 329 | start_logits = outputs.start_logits 330 | end_logits = outputs.end_logits 331 | 332 | 333 | all_start_logits.append(start_logits.detach().cpu().numpy()) 334 | all_end_logits.append(end_logits.detach().cpu().numpy()) 335 | 336 | max_len = max(x.shape[1] for x in all_start_logits) 337 | 338 | start_logits_concat = create_and_fill_np_array(all_start_logits, test_dataset, max_len) 339 | end_logits_concat = create_and_fill_np_array(all_end_logits, test_dataset, max_len) 340 | 341 | del all_start_logits 342 | del all_end_logits 343 | 344 | test_dataset.set_format(type=None, columns=list(test_dataset.features.keys())) 345 | output_numpy = (start_logits_concat, end_logits_concat) 346 | prediction = post_processing_function(test_dataset, output_numpy, text_data, data_args, training_args) 347 | 348 | 349 | def remove_particle(training_args): 350 | """ 351 | remove particle 352 | 353 | Args: 354 | training_args 355 | """ 356 | # load tokenizer 357 | mecab = Mecab() 358 | kkma = Kkma() 359 | hannanum = Hannanum() 360 | # load prediction file 361 | with open(os.path.join(training_args.output_dir, "predictions.json"), "r") as f: 362 | prediction_json = json.load(f) 363 | 364 | prediction_dict = dict() 365 | for mrc_id in prediction_json.keys(): 366 | final_predictions = prediction_json[mrc_id] 367 | pos_tag = mecab.pos(final_predictions) 368 | 369 | # 조사가 있는 경우 삭제 370 | if final_predictions[-1] == "의": 371 | min_len = min(len(kkma.pos(final_predictions)[-1][0]), len(mecab.pos(final_predictions)[-1][0]), len(hannanum.pos(final_predictions)[-1][0])) 372 | if min_len == 1: 373 | final_predictions = final_predictions[:-1] 374 | elif pos_tag[-1][-1] in {"JX", "JKB", "JKO", "JKS", "ETM", "VCP", "JC"}: 375 | final_predictions = final_predictions[:-len(pos_tag[-1][0])] 376 | 377 | prediction_dict[str(mrc_id)] = final_predictions 378 | 379 | # save final results 380 | with open(os.path.join(training_args.output_dir, "final_predictions.json"), 'w', encoding='utf-8') as make_file: 381 | json.dump(prediction_dict, make_file, indent="\t", ensure_ascii=False) 
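# predictions.json keeps the raw predicted spans, while final_predictions.json stores
# the same answers with any trailing Korean particle (josa) stripped off.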
382 | print(prediction_dict) 383 | 384 | 385 | def main(): 386 | # get arguments 387 | model_args, data_args, training_args = get_config() 388 | # fix seed 389 | fix_seed(training_args.seed) 390 | # set device 391 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 392 | 393 | # get tokenizer, model 394 | tokenizer, model = get_model(model_args, training_args) 395 | model.cuda() 396 | 397 | if not os.path.isdir(training_args.output_dir) : 398 | os.mkdir(training_args.output_dir) 399 | 400 | # load data 401 | text_data, test_loader, test_dataset, scores = get_data(model_args, training_args, tokenizer) 402 | # prediction 403 | predict(model, text_data, test_loader, test_dataset, model_args, data_args, training_args, device) 404 | # remove particle 405 | remove_particle(training_args) 406 | 407 | if __name__ == "__main__": 408 | main() -------------------------------------------------------------------------------- /code/train_mrc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pickle 5 | import random 6 | import logging 7 | 8 | import wandb 9 | import torch 10 | import numpy as np 11 | import pandas as pd 12 | from tqdm import tqdm 13 | from torch import nn 14 | from torch.utils.data import DataLoader 15 | from torch.cuda.amp import autocast, GradScaler 16 | from datasets import load_metric, load_from_disk, load_dataset 17 | from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, AdamW, get_cosine_with_hard_restarts_schedule_with_warmup 18 | from transformers import ( 19 | DataCollatorWithPadding, 20 | EvalPrediction, 21 | HfArgumentParser, 22 | TrainingArguments, 23 | set_seed, 24 | ) 25 | 26 | from model.ConvModel import ConvModel 27 | from model.QueryAttentionModel import QueryAttentionModel 28 | from model.QAConvModelV1 import QAConvModelV1 29 | from model.QAConvModelV2 import QAConvModelV2 30 | from utils_qa import postprocess_qa_predictions, check_no_error, tokenize, AverageMeter, last_processing 31 | from trainer_qa import QuestionAnsweringTrainer 32 | from arguments import ModelArguments, DataTrainingArguments 33 | from data_processing import DataProcessor 34 | # from prepare_dataset import make_custom_dataset 35 | 36 | 37 | def get_args() : 38 | '''훈련 시 입력한 각종 Argument를 반환하는 함수''' 39 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 40 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 41 | 42 | return model_args, data_args, training_args 43 | 44 | 45 | def set_seed_everything(seed): 46 | '''Random Seed를 고정하는 함수''' 47 | random.seed(seed) 48 | os.environ['PYTHONHASHSEED'] = str(seed) 49 | np.random.seed(seed) 50 | torch.manual_seed(seed) 51 | torch.cuda.manual_seed(seed) 52 | torch.backends.cudnn.deterministic = True 53 | torch.backends.cudnn.benchmark = True 54 | set_seed(seed) 55 | 56 | return None 57 | 58 | 59 | def get_model(model_args, training_args) : 60 | '''tokenizer, model_config, model, optimizer, scaler, shceduler를 반환하는 함수''' 61 | # Load pretrained model and tokenizer 62 | model_config = AutoConfig.from_pretrained( 63 | model_args.config_name 64 | if model_args.config_name 65 | else model_args.model_name_or_path, 66 | ) 67 | tokenizer = AutoTokenizer.from_pretrained( 68 | model_args.tokenizer_name 69 | if model_args.tokenizer_name 70 | else model_args.model_name_or_path, 71 | use_fast=True, 72 | ) 73 | 74 | if model_args.use_custom_model == 'ConvModel' : 75 | model = 
ConvModel(model_args.config_name, model_config, model_args.tokenizer_name) 76 | elif model_args.use_custom_model == 'QueryAttentionModel' : 77 | model = QueryAttentionModel(model_args.config_name, model_config, model_args.tokenizer_name) 78 | elif model_args.use_custom_model == 'QAConvModelV1' : 79 | model = QAConvModelV1(model_args.config_name, model_config, model_args.tokenizer_name) 80 | elif model_args.use_custom_model == 'QAConvModelV2' : 81 | model = QAConvModelV2(model_args.config_name, model_config, model_args.tokenizer_name) 82 | else: 83 | model = AutoModelForQuestionAnswering.from_pretrained( 84 | model_args.model_name_or_path, 85 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 86 | config=model_config, 87 | ) 88 | 89 | if model_args.use_pretrained_model: 90 | pretrained_model = torch.load(f'/opt/ml/output/{model_args.model_name_or_path}/{model_args.model_name_or_path}.pt') 91 | pretrained_model_state = deepcopy(pretrained_model.state_dict()) 92 | model.load_state_dict(pretrained_model_state) 93 | del pretrained_model 94 | 95 | optimizer = AdamW(model.parameters(), lr=training_args.learning_rate) 96 | scaler = GradScaler() 97 | scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=1000, num_training_steps=12820, num_cycles=2) 98 | 99 | return tokenizer, model_config, model, optimizer, scaler, scheduler 100 | 101 | 102 | def get_pickle(pickle_path): 103 | '''Custom Dataset을 Load하기 위한 함수''' 104 | f = open(pickle_path, "rb") 105 | dataset = pickle.load(f) 106 | f.close() 107 | 108 | return dataset 109 | 110 | def get_data(data_args, training_args, tokenizer) : 111 | '''train과 validation의 dataloader와 dataset를 반환하는 함수''' 112 | if data_args.dataset_name == 'basic' : 113 | if os.path.isdir("../data/train_dataset") : 114 | dataset = load_from_disk("../data/train_dataset") 115 | else : 116 | raise Exception ("Set the data path to 'p3-mrc-team-ikyo/data/.'") 117 | elif data_args.dataset_name == 'preprocessed' : 118 | if os.path.isfile("../data/preprocess_train.pkl") : 119 | dataset = get_pickle("../data/preprocess_train.pkl") 120 | else : 121 | dataset = make_custom_dataset("../data/preprocess_train.pkl") 122 | elif data_args.dataset_name == 'concat' : 123 | if os.path.isfile("../data/concat_train.pkl") : 124 | dataset = get_pickle("../data/concat_train.pkl") 125 | else : 126 | dataset = make_custom_dataset("../data/concat_train.pkl") 127 | elif data_args.dataset_name == 'korquad' : 128 | if os.path.isfile("../data/korquad_train.pkl") : 129 | dataset = get_pickle("../data/korquad_train.pkl") 130 | else : 131 | dataset = make_custom_dataset("../data/korquad_train.pkl") 132 | elif data_args.dataset_name == "question_type": 133 | if os.path.isfile("../data/question_type.pkl") : 134 | dataset = get_pickle("../data/question_type.pkl") 135 | else : 136 | dataset = make_custom_dataset("../data/question_type.pkl") 137 | elif data_args.dataset_name == "ai_hub": 138 | if os.path.isfile("../data/ai_hub_dataset.pkl") : 139 | dataset = get_pickle("../data/ai_hub_dataset.pkl") 140 | else : 141 | dataset = make_custom_dataset("../data/ai_hub_dataset.pkl") 142 | elif data_args.dataset_name == "only_korquad": 143 | dataset = load_dataset("squad_kor_v1") 144 | elif data_args.dataset_name == "random_masking": 145 | if os.path.isfile("../data/random_mask_train.pkl") : 146 | dataset = get_pickle("../data/random_mask_train.pkl") 147 | else : 148 | dataset = make_custom_dataset("../data/random_mask_train.pkl") 149 | elif data_args.dataset_name == 
"token_masking": 150 | if os.path.isfile("../data/concat_token_mask_top_3.pkl") : 151 | dataset = get_pickle("../data/concat_token_mask_top_3.pkl") 152 | else : 153 | dataset = make_mask_dataset("../data/concat_token_mask_top_3.pkl", tokenizer) 154 | train_dataset = dataset['train'] 155 | val_dataset = dataset['validation'] 156 | else : 157 | raise Exception ("dataset_name have to be one of ['basic', 'preprocessed', 'concat', 'korquad', 'only_korquad', 'question_type', 'ai_hub', 'random_masking', 'token_masking']") 158 | 159 | if data_args.dataset_name != "token_masking": 160 | train_dataset = dataset['train'] 161 | val_dataset = dataset['validation'] 162 | train_column_names = train_dataset.column_names 163 | val_column_names = val_dataset.column_names 164 | 165 | data_processor = DataProcessor(tokenizer, data_args.max_seq_length, data_args.doc_stride) 166 | train_dataset = data_processor.train_tokenizer(train_dataset, train_column_names) 167 | val_dataset = data_processor.val_tokenzier(val_dataset, val_column_names) 168 | 169 | data_collator = (DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)) 170 | train_iter = DataLoader(train_dataset, collate_fn = data_collator, batch_size=training_args.per_device_train_batch_size) 171 | val_iter = DataLoader(val_dataset, collate_fn = data_collator, batch_size=training_args.per_device_eval_batch_size) 172 | 173 | return dataset, train_iter, val_iter, train_dataset, val_dataset 174 | 175 | 176 | def post_processing_function(examples, features, predictions, text_data, data_args, training_args): 177 | '''Model의 Prediction을 Text 형태로 변환하는 함수''' 178 | predictions = postprocess_qa_predictions( 179 | examples=examples, 180 | features=features, 181 | predictions=predictions, 182 | max_answer_length=data_args.max_answer_length, 183 | output_dir=training_args.output_dir, 184 | ) 185 | 186 | formatted_predictions = [ 187 | {"id": k, "prediction_text": last_processing(v)} for k, v in predictions.items() 188 | ] 189 | if training_args.do_predict: 190 | return formatted_predictions 191 | 192 | references = [ 193 | {"id": ex["id"], "answers": ex["answers"]} 194 | for ex in text_data["validation"] 195 | ] 196 | return EvalPrediction(predictions=formatted_predictions, label_ids=references) 197 | 198 | 199 | def create_and_fill_np_array(start_or_end_logits, dataset, max_len): 200 | '''Model의 Logit을 Context 단위로 연결하기 위한 함수''' 201 | step = 0 202 | logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64) 203 | 204 | for i, output_logit in enumerate(start_or_end_logits): 205 | batch_size = output_logit.shape[0] 206 | cols = output_logit.shape[1] 207 | if step + batch_size < len(dataset): 208 | logits_concat[step : step + batch_size, :cols] = output_logit 209 | else: 210 | logits_concat[step:, :cols] = output_logit[: len(dataset) - step] 211 | step += batch_size 212 | 213 | return logits_concat 214 | 215 | 216 | def custom_to_mask(batch, tokenizer): 217 | '''Question 부분에 Random Masking을 적용하는 함수''' 218 | mask_token = tokenizer.mask_token_id 219 | 220 | for i in range(len(batch["input_ids"])): 221 | # sep 토큰으로 question과 context가 나뉘어져 있다. 
222 | sep_idx = np.where(batch["input_ids"][i].numpy() == tokenizer.sep_token_id) 223 | # q_ids = > 첫번째 sep 토큰위치 224 | q_ids = sep_idx[0][0] 225 | mask_idxs = set() 226 | while len(mask_idxs) < 1: 227 | # 1 ~ q_ids까지가 Question 위치 228 | ids = random.randrange(1, q_ids) 229 | mask_idxs.add(ids) 230 | 231 | for mask_idx in list(mask_idxs): 232 | batch["input_ids"][i][mask_idx] = mask_token 233 | 234 | return batch 235 | 236 | 237 | def cal_loss(start_positions, end_positions, start_logits, end_logits): 238 | '''MRC Task에서 Loss를 계산하는 함수''' 239 | total_loss =None 240 | if start_positions is not None and end_positions is not None: 241 | # If we are on multi-GPU, split add a dimension 242 | if len(start_positions.size()) > 1: 243 | start_positions = start_positions.squeeze(-1) 244 | if len(end_positions.size()) > 1: 245 | end_positions = end_positions.squeeze(-1) 246 | 247 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 248 | ignored_index = start_logits.size(1) 249 | start_positions.clamp_(0, ignored_index) 250 | end_positions.clamp_(0, ignored_index) 251 | 252 | loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index) 253 | start_loss = loss_fct(start_logits, start_positions) 254 | end_loss = loss_fct(end_logits, end_positions) 255 | total_loss = (start_loss + end_loss) / 2 256 | return total_loss 257 | 258 | def cal_query_loss(question_type, query_logits) : 259 | '''Sub Task에서 Loss를 계산하는 함수''' 260 | return nn.CrossEntropyLoss()(query_logits, question_type)/5 261 | 262 | def training_per_step(model, optimizer, scaler, batch, model_args, data_args, training_args, tokenizer, device): 263 | '''매 step마다 학습을 하는 함수''' 264 | model.train() 265 | with autocast(): 266 | mask_props = 0.8 267 | mask_p = random.random() 268 | if mask_p < mask_props: 269 | # 확률 안에 들면 mask 적용 270 | batch = custom_to_mask(batch, tokenizer) 271 | 272 | batch = batch.to(device) 273 | outputs = model(**batch) 274 | 275 | # output안에 loss가 들어있는 형태 276 | if model_args.use_custom_model: 277 | loss = cal_loss(batch["start_positions"], batch["end_positions"], outputs["start_logits"], outputs["end_logits"]) 278 | if 'query_logits' in outputs.keys() and 'question_type' in batch.keys() : 279 | loss += cal_query_loss(batch['question_type'], outputs['query_logits']) 280 | else: 281 | loss = outputs.loss 282 | scaler.scale(loss).backward() 283 | scaler.step(optimizer) 284 | scaler.update() 285 | optimizer.zero_grad() 286 | 287 | return loss.item() 288 | 289 | 290 | def validating_per_steps(epoch, model, text_data, test_loader, test_dataset, model_args, data_args, training_args, device): 291 | '''특정 step마다 검증을 하는 함수''' 292 | metric = load_metric("squad") 293 | if "xlm" in model_args.tokenizer_name: 294 | test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids"]) 295 | else: 296 | test_dataset.set_format(type="torch", columns=["attention_mask", "input_ids", "token_type_ids"]) 297 | 298 | model.eval() 299 | all_start_logits = [] 300 | all_end_logits = [] 301 | 302 | for batch in test_loader : 303 | batch = batch.to(device) 304 | outputs = model(**batch) 305 | if model_args.use_custom_model: 306 | start_logits = outputs["start_logits"] 307 | end_logits = outputs["end_logits"] 308 | else: 309 | start_logits = outputs.start_logits 310 | end_logits = outputs.end_logits 311 | 312 | all_start_logits.append(start_logits.detach().cpu().numpy()) 313 | all_end_logits.append(end_logits.detach().cpu().numpy()) 314 | 315 | max_len = max(x.shape[1] for x in all_start_logits) 316 | 317 | 
start_logits_concat = create_and_fill_np_array(all_start_logits, test_dataset, max_len) 318 | end_logits_concat = create_and_fill_np_array(all_end_logits, test_dataset, max_len) 319 | 320 | del all_start_logits 321 | del all_end_logits 322 | 323 | test_dataset.set_format(type=None, columns=list(test_dataset.features.keys())) 324 | output_numpy = (start_logits_concat, end_logits_concat) 325 | prediction = post_processing_function(text_data["validation"], test_dataset, output_numpy, text_data, data_args, training_args) 326 | val_metric = metric.compute(predictions=prediction.predictions, references=prediction.label_ids) 327 | 328 | return val_metric 329 | 330 | 331 | def train_mrc(model, optimizer, scaler, text_data, train_loader, test_loader, train_dataset, test_dataset, scheduler, model_args, data_args, training_args, tokenizer, device): 332 | '''training과 validating을 진행하는 함수''' 333 | prev_f1 = 0 334 | prev_em = 0 335 | global_steps = 0 336 | train_loss = AverageMeter() 337 | for epoch in range(int(training_args.num_train_epochs)): 338 | pbar = tqdm(enumerate(train_loader), total=len(train_loader), position=0, leave=True) 339 | for step, batch in pbar: 340 | # training phase 341 | loss = training_per_step(model, optimizer, scaler, batch, model_args, data_args, training_args, tokenizer, device) 342 | train_loss.update(loss, len(batch['input_ids'])) 343 | global_steps += 1 344 | description = f"{epoch+1}epoch {global_steps: >5d}step | loss: {train_loss.avg: .4f} | best_f1: {prev_f1: .4f} | em : {prev_em: .4f}" 345 | pbar.set_description(description) 346 | if scheduler is not None : 347 | scheduler.step() 348 | 349 | # validating phase 350 | if global_steps % training_args.logging_steps == 0 : 351 | with torch.no_grad(): 352 | val_metric = validating_per_steps(epoch, model, text_data, test_loader, test_dataset, model_args, data_args, training_args, device) 353 | if val_metric["f1"] > prev_f1: 354 | torch.save(model, training_args.output_dir + f"/{training_args.run_name}.pt") 355 | prev_f1 = val_metric["f1"] 356 | prev_em = val_metric["exact_match"] 357 | wandb.log({ 358 | 'train/loss' : train_loss.avg, 359 | 'train/learning_rate' : scheduler.get_last_lr()[0] if scheduler is not None else training_args.learning_rate, 360 | 'eval/exact_match' : val_metric['exact_match'], 361 | 'eval/f1_score' : val_metric['f1'], 362 | 'global_steps': global_steps 363 | }) 364 | train_loss.reset() 365 | else : 366 | wandb.log({'global_steps':global_steps}) 367 | 368 | 369 | def main(): 370 | '''각종 설정 이후 train_mrc를 실행하는 함수''' 371 | model_args, data_args, training_args = get_args() 372 | training_args.output_dir = os.path.join(training_args.output_dir, training_args.run_name) 373 | set_seed_everything(training_args.seed) 374 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 375 | 376 | tokenizer, model_config, model, optimizer, scaler, scheduler = get_model(model_args, training_args) 377 | text_data, train_loader, val_loader, train_dataset, val_dataset = get_data(data_args, training_args, tokenizer) 378 | model.cuda() 379 | 380 | if not os.path.isdir(training_args.output_dir) : 381 | os.mkdir(training_args.output_dir) 382 | 383 | # set wandb 384 | os.environ['WANDB_LOG_MODEL'] = 'true' 385 | os.environ['WANDB_WATCH'] = 'all' 386 | os.environ['WANDB_SILENT'] = 'true' 387 | wandb.login() 388 | wandb.init(project='P3-MRC', entity='team-ikyo', name=training_args.run_name) 389 | 390 | train_mrc(model, optimizer, scaler, text_data, train_loader, val_loader, train_dataset, val_dataset, scheduler, 
model_args, data_args, training_args, tokenizer, device) 391 | 392 | 393 | if __name__ == "__main__": 394 | main() 395 | -------------------------------------------------------------------------------- /code/utils_qa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The HuggingFace Team All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Pre-processing 17 | Post-processing utilities for question answering. 18 | """ 19 | import collections 20 | import json 21 | import logging 22 | import os 23 | from numpy import dot 24 | from numpy.linalg import norm 25 | from typing import Optional, Tuple 26 | import math 27 | import re 28 | from datasets import Dataset 29 | 30 | import numpy as np 31 | from tqdm.auto import tqdm 32 | from konlpy.tag import Mecab, Kkma, Hannanum 33 | 34 | import torch 35 | from torch.optim.lr_scheduler import _LRScheduler 36 | import random 37 | from transformers import is_torch_available, PreTrainedTokenizerFast 38 | from transformers.trainer_utils import get_last_checkpoint 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | mecab = Mecab() 43 | def tokenize(text): 44 | """ 45 | """ 46 | return mecab.morphs(text) 47 | 48 | def cos_sim(A, B): 49 | ''' 50 | A,B의 cosine similarity 51 | :return: A,B의 cosine similarity 52 | ''' 53 | return dot(A, B) / (norm(A) * norm(B)) 54 | 55 | def random_masking(datasets): 56 | context_list = [] 57 | question_list = [] 58 | id_list = [] 59 | answer_list = [] 60 | question_type_list = [] 61 | 62 | # train 갯수만큼 iteration 63 | for i in tqdm(range(datasets["train"].num_rows)): 64 | text = datasets["train"][i]["question"] 65 | 66 | # 단어 기준 Masking 67 | for word, pos in mecab.pos(text): 68 | # 하나의 단어만 30% 확률로 Masking 69 | if pos in {"NNG", "NNP"} and (random.random() > 0.7): 70 | context_list.append(datasets["train"][i]["context"]) 71 | question_list.append(re.sub(word, "MASK", text)) # tokenizer.mask_token 72 | id_list.append(datasets["train"][i]["id"]) 73 | answer_list.append(datasets["train"][i]["answers"]) 74 | question_type_list.append(datasets["train"][i]["question_type"]) 75 | 76 | random.Random(2021).shuffle(context_list) 77 | random.Random(2021).shuffle(question_list) 78 | random.Random(2021).shuffle(id_list) 79 | random.Random(2021).shuffle(answer_list) 80 | random.Random(2021).shuffle(question_type_list) 81 | 82 | 83 | # list를 Dataset 형태로 변환 84 | datasets["train"] = Dataset.from_dict({"id" : id_list, 85 | "context": context_list, 86 | "question": question_list, 87 | "answers": answer_list, 88 | "question_type" : question_type_list}) 89 | 90 | return datasets["train"] # 3000 => 20000 91 | 92 | def set_seed(seed: int): 93 | """ 94 | Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if 95 | installed). 96 | 97 | Args: 98 | seed (:obj:`int`): The seed to set. 
99 | """ 100 | random.seed(seed) 101 | np.random.seed(seed) 102 | os.environ['PYTHONHASHSEED'] = str(seed) 103 | if is_torch_available(): 104 | torch.manual_seed(seed) 105 | torch.cuda.manual_seed(seed) 106 | torch.cuda.manual_seed_all(seed) # if use multi-GPU 107 | torch.backends.cudnn.deterministic = True 108 | torch.backends.cudnn.benchmark = False 109 | 110 | 111 | def postprocess_qa_predictions( 112 | examples, 113 | features, 114 | predictions: Tuple[np.ndarray, np.ndarray], 115 | version_2_with_negative: bool = False, 116 | n_best_size: int = 20, 117 | max_answer_length: int = 30, 118 | null_score_diff_threshold: float = 0.0, 119 | output_dir: Optional[str] = None, 120 | prefix: Optional[str] = None, 121 | is_world_process_zero: bool = True, 122 | ): 123 | """ 124 | Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the 125 | original contexts. This is the base postprocessing functions for models that only return start and end logits. 126 | 127 | Args: 128 | examples: The non-preprocessed dataset (see the main script for more information). 129 | features: The processed dataset (see the main script for more information). 130 | predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): 131 | The predictions of the model: two arrays containing the start logits and the end logits respectively. Its 132 | first dimension must match the number of elements of :obj:`features`. 133 | version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): 134 | Whether or not the underlying dataset contains examples with no answers. 135 | n_best_size (:obj:`int`, `optional`, defaults to 20): 136 | The total number of n-best predictions to generate when looking for an answer. 137 | max_answer_length (:obj:`int`, `optional`, defaults to 30): 138 | The maximum length of an answer that can be generated. This is needed because the start and end predictions 139 | are not conditioned on one another. 140 | null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): 141 | The threshold used to select the null answer: if the best answer has a score that is less than the score of 142 | the null answer minus this threshold, the null answer is selected for this example (note that the score of 143 | the null answer for an example giving several features is the minimum of the scores for the null answer on 144 | each feature: all features must be aligned on the fact they `want` to predict a null answer). 145 | 146 | Only useful when :obj:`version_2_with_negative` is :obj:`True`. 147 | output_dir (:obj:`str`, `optional`): 148 | If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if 149 | :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null 150 | answers, are saved in `output_dir`. 151 | prefix (:obj:`str`, `optional`): 152 | If provided, the dictionaries mentioned above are saved with `prefix` added to their names. 153 | is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): 154 | Whether this process is the main process or not (used to determine if logging/saves should be done). 155 | """ 156 | assert ( 157 | len(predictions) == 2 158 | ), "`predictions` should be a tuple with two elements (start_logits, end_logits)." 159 | all_start_logits, all_end_logits = predictions 160 | 161 | assert len(predictions[0]) == len( 162 | features 163 | ), f"Got {len(predictions[0])} predictions and {len(features)} features." 
164 | 165 | # Build a map example to its corresponding features. 166 | example_id_to_index = {k: i for i, k in enumerate(examples["id"])} 167 | features_per_example = collections.defaultdict(list) 168 | for i, feature in enumerate(features): 169 | features_per_example[example_id_to_index[feature["example_id"]]].append(i) 170 | 171 | # The dictionaries we have to fill. 172 | all_predictions = collections.OrderedDict() 173 | all_nbest_json = collections.OrderedDict() 174 | if version_2_with_negative: 175 | scores_diff_json = collections.OrderedDict() 176 | 177 | # Logging. 178 | logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN) 179 | logger.info( 180 | f"Post-processing {len(examples)} example predictions split into {len(features)} features." 181 | ) 182 | 183 | # Let's loop over all the examples! 184 | for example_index, example in enumerate(examples): 185 | # Those are the indices of the features associated to the current example. 186 | feature_indices = features_per_example[example_index] 187 | 188 | min_null_prediction = None 189 | prelim_predictions = [] 190 | 191 | # Looping through all the features associated to the current example. 192 | for feature_index in feature_indices: 193 | # We grab the predictions of the model for this feature. 194 | start_logits = all_start_logits[feature_index] 195 | end_logits = all_end_logits[feature_index] 196 | # This is what will allow us to map some the positions in our logits to span of texts in the original 197 | # context. 198 | offset_mapping = features[feature_index]["offset_mapping"] 199 | # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context 200 | # available in the current feature. 201 | token_is_max_context = features[feature_index].get( 202 | "token_is_max_context", None 203 | ) 204 | 205 | # Update minimum null prediction. 206 | feature_null_score = start_logits[0] + end_logits[0] 207 | if ( 208 | min_null_prediction is None 209 | or min_null_prediction["score"] > feature_null_score 210 | ): 211 | min_null_prediction = { 212 | "offsets": (0, 0), 213 | "score": feature_null_score, 214 | "start_logit": start_logits[0], 215 | "end_logit": end_logits[0], 216 | } 217 | 218 | # Go through all possibilities for the `n_best_size` greater start and end logits. 219 | start_indexes = np.argsort(start_logits)[ 220 | -1 : -n_best_size - 1 : -1 221 | ].tolist() 222 | end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() 223 | for start_index in start_indexes: 224 | for end_index in end_indexes: 225 | # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond 226 | # to part of the input_ids that are not in the context. 227 | if ( 228 | start_index >= len(offset_mapping) 229 | or end_index >= len(offset_mapping) 230 | or offset_mapping[start_index] is None 231 | or offset_mapping[end_index] is None 232 | ): 233 | continue 234 | # Don't consider answers with a length that is either < 0 or > max_answer_length. 235 | if ( 236 | end_index < start_index 237 | or end_index - start_index + 1 > max_answer_length 238 | ): 239 | continue 240 | # Don't consider answer that don't have the maximum context available (if such information is 241 | # provided). 
242 | if ( 243 | token_is_max_context is not None 244 | and not token_is_max_context.get(str(start_index), False) 245 | ): 246 | continue 247 | prelim_predictions.append( 248 | { 249 | "offsets": ( 250 | offset_mapping[start_index][0], 251 | offset_mapping[end_index][1], 252 | ), 253 | "score": start_logits[start_index] + end_logits[end_index], 254 | "start_logit": start_logits[start_index], 255 | "end_logit": end_logits[end_index], 256 | } 257 | ) 258 | if version_2_with_negative: 259 | # Add the minimum null prediction 260 | prelim_predictions.append(min_null_prediction) 261 | null_score = min_null_prediction["score"] 262 | 263 | # Only keep the best `n_best_size` predictions. 264 | predictions = sorted( 265 | prelim_predictions, key=lambda x: x["score"], reverse=True 266 | )[:n_best_size] 267 | 268 | # Add back the minimum null prediction if it was removed because of its low score. 269 | if version_2_with_negative and not any( 270 | p["offsets"] == (0, 0) for p in predictions 271 | ): 272 | predictions.append(min_null_prediction) 273 | 274 | # Use the offsets to gather the answer text in the original context. 275 | context = example["context"] 276 | for pred in predictions: 277 | offsets = pred.pop("offsets") 278 | pred["text"] = context[offsets[0] : offsets[1]] 279 | 280 | # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid 281 | # failure. 282 | if len(predictions) == 0 or ( 283 | len(predictions) == 1 and predictions[0]["text"] == "" 284 | ): 285 | predictions.insert( 286 | 0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0} 287 | ) 288 | 289 | # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using 290 | # the LogSumExp trick). 291 | scores = np.array([pred.pop("score") for pred in predictions]) 292 | exp_scores = np.exp(scores - np.max(scores)) 293 | probs = exp_scores / exp_scores.sum() 294 | 295 | # Include the probabilities in our predictions. 296 | for prob, pred in zip(probs, predictions): 297 | pred["probability"] = prob 298 | 299 | # Pick the best prediction. If the null answer is not possible, this is easy. 300 | if not version_2_with_negative: 301 | all_predictions[example["id"]] = predictions[0]["text"] 302 | else: 303 | # Otherwise we first need to find the best non-empty prediction. 304 | i = 0 305 | while predictions[i]["text"] == "": 306 | i += 1 307 | best_non_null_pred = predictions[i] 308 | 309 | # Then we compare to the null prediction using the threshold. 310 | score_diff = ( 311 | null_score 312 | - best_non_null_pred["start_logit"] 313 | - best_non_null_pred["end_logit"] 314 | ) 315 | scores_diff_json[example["id"]] = float( 316 | score_diff 317 | ) # To be JSON-serializable. 318 | if score_diff > null_score_diff_threshold: 319 | all_predictions[example["id"]] = "" 320 | else: 321 | all_predictions[example["id"]] = best_non_null_pred["text"] 322 | 323 | # Make `predictions` JSON-serializable by casting np.float back to float. 324 | all_nbest_json[example["id"]] = [ 325 | { 326 | k: ( 327 | float(v) 328 | if isinstance(v, (np.float16, np.float32, np.float64)) 329 | else v 330 | ) 331 | for k, v in pred.items() 332 | } 333 | for pred in predictions 334 | ] 335 | 336 | # If we have an output_dir, let's save all those dicts. 337 | if output_dir is not None: 338 | assert os.path.isdir(output_dir), f"{output_dir} is not a directory." 
339 | 340 | prediction_file = os.path.join( 341 | output_dir, 342 | "predictions.json" if prefix is None else f"predictions_{prefix}".json, 343 | ) 344 | nbest_file = os.path.join( 345 | output_dir, 346 | "nbest_predictions.json" 347 | if prefix is None 348 | else f"nbest_predictions_{prefix}".json, 349 | ) 350 | if version_2_with_negative: 351 | null_odds_file = os.path.join( 352 | output_dir, 353 | "null_odds.json" if prefix is None else f"null_odds_{prefix}".json, 354 | ) 355 | 356 | logger.info(f"Saving predictions to {prediction_file}.") 357 | with open(prediction_file, "w") as writer: 358 | writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n") 359 | logger.info(f"Saving nbest_preds to {nbest_file}.") 360 | with open(nbest_file, "w") as writer: 361 | writer.write(json.dumps(all_nbest_json, indent=4, ensure_ascii=False) + "\n") 362 | if version_2_with_negative: 363 | logger.info(f"Saving null_odds to {null_odds_file}.") 364 | with open(null_odds_file, "w") as writer: 365 | writer.write(json.dumps(scores_diff_json, indent=4, ensure_ascii=False) + "\n") 366 | 367 | return all_predictions 368 | 369 | 370 | def check_no_error(training_args, data_args, tokenizer, datasets): 371 | # Detecting last checkpoint. 372 | last_checkpoint = None 373 | if ( 374 | os.path.isdir(training_args.output_dir) 375 | and training_args.do_train 376 | and not training_args.overwrite_output_dir 377 | ): 378 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 379 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 380 | raise ValueError( 381 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 382 | "Use --overwrite_output_dir to overcome." 383 | ) 384 | elif last_checkpoint is not None: 385 | logger.info( 386 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 387 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 388 | ) 389 | 390 | # Tokenizer check: this script requires a fast tokenizer. 391 | if not isinstance(tokenizer, PreTrainedTokenizerFast): 392 | raise ValueError( 393 | "This example script only works for models that have a fast tokenizer. Checkout the big table of models " 394 | "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this " 395 | "requirement" 396 | ) 397 | 398 | if data_args.max_seq_length > tokenizer.model_max_length: 399 | logger.warn( 400 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 401 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 
402 | ) 403 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 404 | 405 | if "validation" not in datasets: 406 | raise ValueError("--do_eval requires a validation dataset") 407 | return last_checkpoint, max_seq_length 408 | 409 | 410 | class AverageMeter(object): 411 | """Computes and stores the average and current value""" 412 | def __init__(self): 413 | self.reset() 414 | def reset(self): 415 | self.val = 0 416 | self.avg = 0 417 | self.sum = 0 418 | self.count = 0 419 | def update(self, val, n=1): 420 | self.val = val 421 | self.sum += val * n 422 | self.count += n 423 | self.avg = self.sum / self.count 424 | 425 | 426 | def last_processing(text): 427 | """ 428 | Strip the trailing particle (josa) from a predicted answer 429 | Args: 430 | text (str): text that may end with a particle 431 | Returns: 432 | [str]: text with the unnecessary trailing particle removed 433 | """ 434 | mecab = Mecab() 435 | kkma = Kkma() 436 | hannanum = Hannanum() 437 | 438 | pos_tag = mecab.pos(text) 439 | 440 | # If the last word is morphologically analysed as a particle (J*), drop it 441 | if not pos_tag : 442 | pass 443 | elif text[-1] == "의": 444 | min_len = min(len(kkma.pos(text)[-1][0]), len(mecab.pos(text)[-1][0]), len(hannanum.pos(text)[-1][0])) 445 | if min_len == 1: 446 | text = text[:-1] 447 | elif pos_tag[-1][-1] in {"JX", "JKB", "JKO", "JKS", "ETM", "VCP", "JC"}: 448 | text = text[:-len(pos_tag[-1][0])] 449 | return text --------------------------------------------------------------------------------
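A minimal usage sketch (not part of the repository) of the particle-stripping helper
above; the example strings and the import path are assumptions:

    # run from the code/ directory so that utils_qa.py is importable
    from utils_qa import last_processing

    # a trailing nominative particle such as "이" is typically tagged JKS and dropped
    print(last_processing("세종대왕이"))   # expected -> "세종대왕"
    # answers without a trailing particle are returned unchanged
    print(last_processing("한글"))         # -> "한글"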