├── data ├── .txt └── .gitignore ├── img ├── captum.png ├── demo_view.png ├── entangled.png ├── seperated.png ├── captum_example.png ├── demo_example.png └── confusion_matrix.png ├── voc_classifier ├── __pycache__ │ ├── config.cpython-38.pyc │ └── preprocess.cpython-38.pyc ├── demo.py ├── config.py ├── .ipynb_checkpoints │ ├── config-checkpoint.py │ ├── preprocess-checkpoint.py │ └── kobert_multilabel_text_classifier-checkpoint.ipynb ├── captum_tools_vocvis.py ├── inference.py ├── bert_model.py ├── preprocess.py ├── metrics_for_multilabel.py └── kobert_multilabel_text_classifier.ipynb ├── requirements.txt ├── README.md └── LICENSE /data/.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /img/captum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/captum.png -------------------------------------------------------------------------------- /img/demo_view.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/demo_view.png -------------------------------------------------------------------------------- /img/entangled.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/entangled.png -------------------------------------------------------------------------------- /img/seperated.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/seperated.png -------------------------------------------------------------------------------- /img/captum_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/captum_example.png -------------------------------------------------------------------------------- /img/demo_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/demo_example.png -------------------------------------------------------------------------------- /img/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/confusion_matrix.png -------------------------------------------------------------------------------- /voc_classifier/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/voc_classifier/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /voc_classifier/__pycache__/preprocess.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/voc_classifier/__pycache__/preprocess.cpython-38.pyc -------------------------------------------------------------------------------- /voc_classifier/demo.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | from inference 
import load_model 4 | 5 | model=load_model() 6 | 7 | bert=model.get_model() 8 | 9 | bert.eval() 10 | 11 | 12 | import pickle 13 | # pickle.load can execute code from the file: only load these trusted, locally produced pickles 14 | with open("../data/raw_text.pkl","rb") as f: raw=pickle.load(f) 15 | 16 | with open("../data/preprocessed_text.pkl","rb") as f: data=pickle.load(f) 17 | # return (raw text to display, preprocessed model input, label vector = every column after "text") for row num_of_text 18 | def process_for_inference(num_of_text): 19 | raw_to_show=raw.text.iloc[num_of_text] 20 | text_input=data.text.iloc[num_of_text] 21 | input_text_label=data.iloc[num_of_text,1:].tolist() 22 | return raw_to_show, text_input, input_text_label 23 | 24 | # Streamlit UI: pick a row by index, show its raw text, then the predicted and true tags 25 | number_for_demo = st.text_input('태그를 생성할 텍스트의 번호를 입력해 주세요.') 26 | if number_for_demo: 27 | num_of_text=int(number_for_demo) 28 | Raw,Input,Label=process_for_inference(num_of_text) 29 | # wrap into 62-character lines; slicing the unmodified string avoids the index drift of inserting newlines in place 30 | Raw="\n".join(Raw[j:j+62] for j in range(0,len(Raw),62)) 31 | st.text(Raw) 32 | st.text("\n") 33 | st.text("==================예측 결과==================") 34 | st.text("\n") 35 | 36 | a,b=model.get_prediction_from_txt(Input,Label) 37 | st.text(a) 38 | st.text(b) -------------------------------------------------------------------------------- /voc_classifier/config.py: -------------------------------------------------------------------------------- 1 | # config file for DeepVOC model 2 | 3 | import os 4 | 5 | 6 | DATA_PATH="../data/voc_data.xlsx" 7 | 8 | 9 | # set pandas display options; a falsy argument leaves that option unchanged, float_format is always set 10 | def expand_pandas(max_rows=100, max_cols=500, width=None, max_info_cols=None): 11 | import pandas as pd 12 | if max_rows: 13 | pd.set_option("display.max_rows", max_rows) # max rows to display (default 100) 14 | if max_cols: 15 | pd.set_option("display.max_columns", max_cols) # max columns to display (default 500) 16 | if width: 17 | pd.set_option("display.width", width) # display width in characters 18 | if max_info_cols: 19 | pd.set_option("max_info_columns", max_info_cols) # column limit above which DataFrame.info() omits per-column details 20 | pd.set_option("display.float_format", lambda x : "%.3f" %x) # print floats with 3 decimal places 21 | print("done") 22 | 23 | 24 | 25 | model_config={"max_len" :512, 26 | "batch_size":5, 27 | "warmup_ratio": 0.1, 28 | "num_epochs": 200, 29 | "max_grad_norm":
1, 30 | "learning_rate": 5e-6, 31 | "dr_rate":0.45} 32 | 33 | # order matters: inference maps classifier output index i to label_cols[i] 34 | label_cols=['국내선', 35 | '스케줄/기종변경', 36 | '항공권규정', 37 | '사전좌석배정', 38 | '환불', 39 | '홈페이지', 40 | '유상변경/취소', 41 | '부가서비스', 42 | '일정표/영수증', 43 | '무상변경/취소', 44 | '무상환불', 45 | '대기예약', 46 | '유상환불', 47 | '운임', 48 | '무상신규예약', 49 | '무응답', 50 | '재발행'] -------------------------------------------------------------------------------- /voc_classifier/.ipynb_checkpoints/config-checkpoint.py: -------------------------------------------------------------------------------- 1 | # config file for DeepVOC model 2 | 3 | import os 4 | 5 | 6 | DATA_PATH="../data/voc_data.xlsx" 7 | 8 | 9 | # set pandas display options; a falsy argument leaves that option unchanged, float_format is always set 10 | def expand_pandas(max_rows=100, max_cols=500, width=None, max_info_cols=None): 11 | import pandas as pd 12 | if max_rows: 13 | pd.set_option("display.max_rows", max_rows) # max rows to display (default 100) 14 | if max_cols: 15 | pd.set_option("display.max_columns", max_cols) # max columns to display (default 500) 16 | if width: 17 | pd.set_option("display.width", width) # display width in characters 18 | if max_info_cols: 19 | pd.set_option("max_info_columns", max_info_cols) # column limit above which DataFrame.info() omits per-column details 20 | pd.set_option("display.float_format", lambda x : "%.3f" %x) # print floats with 3 decimal places 21 | print("done") 22 | 23 | 24 | 25 | model_config={"max_len" :512, 26 | "batch_size":5, 27 | "warmup_ratio": 0.1, 28 | "num_epochs": 200, 29 | "max_grad_norm": 1, 30 | "learning_rate": 5e-6, 31 | "dr_rate":0.45} 32 | 33 | # order matters: inference maps classifier output index i to label_cols[i] 34 | label_cols=['국내선', 35 | '스케줄/기종변경', 36 | '항공권규정', 37 | '사전좌석배정', 38 | '환불', 39 | '홈페이지', 40 | '유상변경/취소', 41 | '부가서비스', 42 | '일정표/영수증', 43 | '무상변경/취소', 44 | '무상환불', 45 | '대기예약', 46 | '유상환불', 47 | '운임', 48 | '무상신규예약', 49 | '무응답', 50 | '재발행'] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | altair==4.1.0 2 | anyio==3.3.0 3 | argcomplete==1.12.3 4 | argon2-cffi==20.1.0 5 | astor==0.8.1 6 | async-generator==1.10 7 | attrs==21.2.0 8
| Babel==2.9.1 9 | backcall==0.2.0 10 | backports.zoneinfo==0.2.1 11 | base58==2.1.0 12 | bleach==4.0.0 13 | blinker==1.4 14 | Bottleneck==1.3.2 15 | cachetools==4.2.2 16 | captum==0.4.0 17 | certifi==2021.5.30 18 | cffi==1.14.6 19 | chardet==3.0.4 20 | charset-normalizer==2.0.4 21 | click==7.1.2 22 | colorama==0.4.4 23 | cycler==0.10.0 24 | Cython==0.29.24 25 | debugpy==1.4.1 26 | decorator==5.0.9 27 | defusedxml==0.7.1 28 | entrypoints==0.3 29 | et-xmlfile==1.1.0 30 | filelock==3.0.12 31 | gitdb==4.0.7 32 | GitPython==3.1.18 33 | gluonnlp==0.10.0 34 | graphviz==0.8.4 35 | idna==2.6 36 | importlib-metadata==4.6.3 37 | ipykernel==6.1.0 38 | ipython==7.26.0 39 | ipython-genutils==0.2.0 40 | ipywidgets==7.6.3 41 | jedi==0.18.0 42 | Jinja2==3.0.1 43 | joblib==1.0.1 44 | json5==0.9.6 45 | jsonschema==3.2.0 46 | jupyter-client==6.1.12 47 | jupyter-core==4.7.1 48 | jupyter-server==1.10.2 49 | jupyterlab==3.1.6 50 | jupyterlab-pygments==0.1.2 51 | jupyterlab-server==2.7.0 52 | jupyterlab-widgets==1.0.0 53 | kiwisolver==1.3.1 54 | MarkupSafe==2.0.1 55 | matplotlib==3.4.3 56 | matplotlib-inline==0.1.2 57 | mistune==0.8.4 58 | mkl-fft==1.3.0 59 | mkl-random @ file:///C:/ci/mkl_random_1626186163140/work 60 | mkl-service==2.4.0 61 | mxnet==1.7.0.post2 62 | nbclassic==0.3.1 63 | nbclient==0.5.3 64 | nbconvert==6.1.0 65 | nbformat==5.1.3 66 | nest-asyncio==1.5.1 67 | notebook==6.4.3 68 | numexpr @ file:///C:/ci/numexpr_1618856761305/work 69 | numpy==1.17.3 70 | olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work 71 | openpyxl==3.0.7 72 | packaging==21.0 73 | pandas @ file:///C:/ci/pandas_1627570311072/work 74 | pandocfilters==1.4.3 75 | parso==0.8.2 76 | pickleshare==0.7.5 77 | Pillow @ file:///C:/ci/pillow_1625663286921/work 78 | prometheus-client==0.11.0 79 | prompt-toolkit==3.0.19 80 | protobuf==3.17.3 81 | pyarrow==5.0.0 82 | pycparser==2.20 83 | pydeck==0.6.2 84 | Pygments==2.9.0 85 | pyparsing==2.4.7 86 | pyrsistent==0.18.0 87 | 
python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work 88 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work 89 | pywin32==301 90 | pywinpty==1.1.3 91 | pyzmq==22.2.1 92 | regex==2021.8.3 93 | requests==2.18.4 94 | requests-unixsocket==0.2.0 95 | sacremoses==0.0.45 96 | scikit-learn==0.24.2 97 | scipy==1.7.1 98 | Send2Trash==1.8.0 99 | sentencepiece==0.1.96 100 | six @ file:///tmp/build/80754af9/six_1623709665295/work 101 | smmap==4.0.0 102 | sniffio==1.2.0 103 | streamlit==0.86.0 104 | terminado==0.11.0 105 | testpath==0.5.0 106 | threadpoolctl==2.2.0 107 | tokenizers==0.8.0rc4 108 | toml==0.10.2 109 | toolz==0.11.1 110 | torch==1.9.0 111 | torchaudio==0.9.0 112 | torchtext==0.10.0 113 | torchvision==0.10.0 114 | tornado==6.1 115 | tqdm==4.62.0 116 | traitlets==5.0.5 117 | transformers==3.0.1 118 | typing-extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1622748266870/work 119 | tzdata==2021.1 120 | tzlocal==3.0 121 | urllib3==1.22 122 | validators==0.18.2 123 | watchdog==2.1.3 124 | wcwidth==0.2.5 125 | webencodings==0.5.1 126 | websocket-client==1.2.0 127 | widgetsnbextension==3.5.1 128 | wincertstore==0.2 129 | zipp==3.5.0 130 | -------------------------------------------------------------------------------- /voc_classifier/captum_tools_vocvis.py: -------------------------------------------------------------------------------- 1 | from IPython.display import HTML, display # IPython.display is the public import location for HTML/display 2 | HAS_IPYTHON = True 3 | # HTML helpers that render per-token attributions as a colour-coded notebook table 4 | def format_special_tokens(token): 5 | if token.startswith("<") and token.endswith(">"): 6 | return "#" + token.strip("<>") 7 | return token 8 | 9 | 10 | def format_tooltip(item, text): 11 | return '
{item}\ 12 | {text}\ 13 |
'.format( 14 | item=item, text=text 15 | ) 16 | def format_classname(classname): 17 | return '{}'.format(classname) 18 | def _get_color(attr): # attribution in [-1,1] -> HSL colour: green shades for positive, red shades for negative, white at 0 19 | # clip values to prevent CSS errors (Values should be from [-1,1]) 20 | attr = max(-1, min(1, attr)) 21 | if attr > 0: 22 | hue = 120 23 | sat = 75 24 | lig = 100 - int(50 * attr) 25 | else: 26 | hue = 0 27 | sat = 75 28 | lig = 100 - int(-40 * attr) 29 | return "hsl({}, {}%, {}%)".format(hue, sat, lig) 30 | min_len=64 # retained module constant; format_word_importances no longer truncates to it 31 | def format_word_importances(words, importances): 32 | try: 33 | if importances is None or len(importances) == 0: 34 | return "" 35 | # drop words that have no attribution so every rendered word is paired with its own score 36 | if len(words) > len(importances): 37 | words=words[:len(importances)] 38 | 39 | 40 | tags = [""] 41 | for word, importance in zip(words, importances[: len(words)]): 42 | word = format_special_tokens(word) 43 | color = _get_color(importance) 44 | unwrapped_tag = ' {word}\ 46 | '.format( 47 | color=color, word=word 48 | ) 49 | tags.append(unwrapped_tag) 50 | tags.append("") 51 | return "".join(tags) 52 | except Exception as e: 53 | print("skip it", e) 54 | # NOTE: this path returns None, which makes visualize_text's join fail so the row is skipped 55 | 56 | def visualize_text(datarecords, legend=True): 57 | assert HAS_IPYTHON, ( 58 | "IPython must be available to visualize text. " 59 | "Please run 'pip install ipython'." 60 | ) 61 | dom = [""] 62 | rows = [ 63 | "" 64 | "" 65 | "" 66 | "" 67 | "" 68 | ] 69 | cnt=0 70 | for datarecord in datarecords: 71 | cnt+=1 72 | try: 73 | rows.append( 74 | "".join( 75 | [ 76 | "", 77 | format_classname(datarecord.true_class), 78 | format_classname( 79 | "{0} ({1:.2f})".format( 80 | datarecord.pred_class, datarecord.pred_prob 81 | ) 82 | ), 83 | format_classname(datarecord.attr_class), 84 | format_classname("{0:.2f}".format(datarecord.attr_score)), 85 | format_word_importances( 86 | datarecord.raw_input, datarecord.word_attributions 87 | ), 88 | "", 89 | ] 90 | ) 91 | ) 92 | except Exception as e: 93 | print(f"Error in {cnt}",e) 94 | 95 | if legend: 96 | dom.append( 97 | '
' 99 | ) 100 | dom.append("Legend: ") 101 | 102 | for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]): 103 | dom.append( 104 | ' {label} '.format( 107 | value=_get_color(value), label=label 108 | ) 109 | ) 110 | dom.append("
") 111 | 112 | dom.append("".join(rows)) 113 | dom.append("
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
") 114 | display(HTML("".join(dom))) 115 | -------------------------------------------------------------------------------- /voc_classifier/inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | 4 | # input-data processing section 5 | 6 | import torchtext 7 | import pandas as pd 8 | import numpy as np 9 | 10 | import os 11 | import re 12 | 13 | import config 14 | from config import expand_pandas 15 | from preprocess import preprocess 16 | 17 | DATA_PATH=config.DATA_PATH 18 | 19 | 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | # number of classifier outputs; prediction indices are looked up in config.label_cols 24 | num_class=17 25 | ver_num=2 26 | 27 | except_labels=["변경/취소","예약기타"] 28 | 29 | version_info="{:02d}".format(ver_num) 30 | # weight_path=f"../weights/Deep_Voice_{version_info}.pt" 31 | 32 | 33 | weight_path="../weights/weight_01.pt" 34 | 35 | 36 | 37 | 38 | # KoBERT model 39 | import torch 40 | from torch import nn 41 | import torch.nn.functional as F 42 | import torch.optim as optim 43 | from torch.utils.data import Dataset, DataLoader 44 | import gluonnlp as nlp 45 | from tqdm import tqdm, tqdm_notebook 46 | 47 | from KoBERT.kobert.utils import get_tokenizer 48 | from KoBERT.kobert.pytorch_kobert import get_pytorch_kobert_model 49 | 50 | from transformers import AdamW 51 | 52 | class BERTClassifier(nn.Module): 53 | def __init__(self, bert, hidden_size = 768, num_classes = 8, dr_rate = None, params = None): 54 | # linear classification head on top of a pretrained BERT backbone 55 | # bert: the backbone module, stored as self.bert 56 | # dr_rate: dropout probability for the pooled output (None/0 disables dropout); params is unused 57 | 58 | super(BERTClassifier, self).__init__() 59 | self.bert = bert 60 | self.dr_rate = dr_rate 61 | 62 | # nn.Linear is a fully connected layer (the Keras Dense equivalent): 63 | # input size is the BERT hidden size (768 here), 64 | # output size is num_classes. 65 | # (presumably this could be split into major/mid/minor category heads)
66 | 67 | 68 | # self.lstm_layer = nn.LSTM(512, 128, 2) 69 | self.classifier = nn.Linear(hidden_size, num_classes) 70 | 71 | # self.classifier=Net(hidden_size=hidden_size, num_classes=num_classes) 72 | 73 | # dropout on the pooled output, only when dr_rate is truthy 74 | if dr_rate: 75 | self.dropout = nn.Dropout(p=dr_rate) 76 | 77 | def generate_attention_mask(self, token_ids, valid_length): 78 | 79 | # build the BERT attention mask from token_ids: 80 | # in row i, positions [0, valid_length[i]) are 1 and the padding after them stays 0 81 | 82 | # zeros_like returns a zero tensor with the same shape, dtype and device as token_ids 83 | attention_mask = torch.zeros_like(token_ids) 84 | 85 | for i,v in enumerate(valid_length): 86 | attention_mask[i][:v] = 1 87 | return attention_mask.float() 88 | 89 | def forward(self, token_ids, valid_length, segment_ids): 90 | # build the attention mask, run the BERT backbone and classify its pooled output 91 | attention_mask = self.generate_attention_mask(token_ids, valid_length) 92 | 93 | # .long() casts the segment ids to int64; .to() moves the mask onto the token ids' device 94 | # the pooled output (optionally after dropout) feeds the linear classifier 95 | 96 | _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device)) 97 | out = self.dropout(pooler) if self.dr_rate else pooler # use the pooled output as-is when dropout is disabled 98 | 99 | 100 | # output=self.lstm_layer(out) 101 | 102 | return self.classifier(out) 103 | 104 | 105 | 106 | 107 | class load_model: 108 | 109 | def __init__(self): 110 | 111 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 112 | self.bertmodel, self.vocab = get_pytorch_kobert_model() 113 | self.tokenizer = get_tokenizer() 114 | self.tok = nlp.data.BERTSPTokenizer(self.tokenizer, self.vocab, lower = False) 115 | self.transform = nlp.data.BERTSentenceTransform(self.tok, max_seq_length = config.model_config["max_len"], pad=True, pair=False) 116 | 117 | 118 | def get_model(self): 119 | 120 | 121 | # wrap the KoBERT backbone in the classifier head; .to() moves the whole model to self.device (GPU if available)
122 | self.model = BERTClassifier(self.bertmodel, num_classes=num_class, dr_rate = config.model_config["dr_rate"]).to(self.device) 123 | 124 | # map_location lets weights saved on a GPU load on a CPU-only machine (and vice versa) 125 | self.model.load_state_dict(torch.load(weight_path, map_location=self.device)) 126 | self.model.eval() 127 | 128 | return self.model 129 | 130 | def get_prediction_from_txt(self, input_text, input_text_label): 131 | # top-k predicted tags for one text, with k = 1.5x the number of true tags; requires get_model() to have been called 132 | device=self.device 133 | 134 | sentences = self.transform([input_text]) 135 | true_values=np.nonzero(input_text_label)[0].tolist() 136 | num_of_true=min(max(1,round(len(true_values)*1.5)),num_class) # topk needs 1 <= k <= number of classes 137 | 138 | with torch.no_grad(): get_pred=self.model(torch.tensor(sentences[0]).long().unsqueeze(0).to(device),torch.tensor(sentences[1]).unsqueeze(0),torch.tensor(sentences[2]).unsqueeze(0).to(device)) 139 | get_pred=get_pred.topk(k=num_of_true)[1] 140 | 141 | pred=np.array(get_pred.to("cpu").detach().numpy()[0], dtype=float) 142 | pred=list(map(int,pred)) 143 | result=f"분석 결과, 대화의 예상 태그는 {[config.label_cols[i] for i in pred]} 입니다." 144 | true_label=f"실제 태그는 {[config.label_cols[i] for i in true_values]} 입니다." 145 | return result, true_label 146 | 147 | 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | KoBERT multi-label VOC classifier 2 | ============================================== 3 | 4 |
5 | 6 | 7 | **[TL;DR]** 8 | KoBERT를 사용한 PyTorch text multi-label classification 모델입니다. SKT에서 공개한 한국어 pretrained language model인 [KoBERT](https://github.com/SKTBrain/KoBERT)를 사용했습니다. 9 | 10 | ---- 11 | 12 |
13 | 14 | Contents 15 | -------- 16 | 17 | 1. [Intro](#intro) 18 | 2. [Structure](#structure) 19 | 3. [Embedding Visualization](#embedding) 20 | 4. [XAI using PyTorch Captum](#captum) 21 | 5. [Streamlit Demo](#demo) 22 | 23 |
24 | 25 | 26 | ## Intro 27 | 28 | 큰 기업에서는 매일 수천 건, 통화로 생성되는 STT 결과 데이터를 포함하면 수만 건에 달하는 VoC(Voice of Customers)가 발생합니다. 이때 고객의 불만 사항을 실시간으로 분류/관리하여 트렌드를 추적해 주는 시스템이 있다면, 운영 부서에서 미처 예상치 못한 서비스 장애를 진단하여 조기에 대응할 수도 있고, 나아가 고객 불만을 보상하는 프로모션을 제공한다면 고객 이탈을 방지함과 동시에 세일즈 KPI에 직접적인 효과를 얻을 수 있겠죠. 이러한 맥락에서 고안된 VOC 자동 분류 모델입니다. 29 | 30 | 사용된 데이터는 모 항공사의 VOC 데이터이고, 약 2,000건의 raw 데이터를 가지고 있습니다. 전체 레이블의 수는 70여 개이며, 모델링에 사용된 수는 17개입니다. 극히 소수의 데이터에서 실제 현업에 적용 가능한 수준의 성능을 검증하는 것이 해당 모델의 구현 목적이었습니다. 라벨 데이터는 현업 전문가에 의해 태깅되었고, 상담사와 고객의 대화를 보고 관련되었다고 판단되는 태그를 달아 주었습니다. 데이터와 학습된 weight file은 보안상의 이유로 공개할 수 없음을 양해 바랍니다. 31 | 32 | 33 |
34 | 35 | 36 | 37 | ## Structure 38 | 39 | 40 | ```bash 41 | 42 | ├── data # 모델 학습 및 데모에 사용되는 데이터셋 43 | ├── KoBERT # 과거 버전의 KoBERT 레포지터리를 클론한 폴더 44 | ├── voc_classifier 45 | │ ├── bert_model.py # dataloader, bert 모델 및 학습 관련 util 46 | │ ├── captum_tools_vocvis.py # PyTorch XAI를 위한 Captum 관련 util 47 | │ ├── config.py 48 | │ ├── demo.py # 텍스트 번호를 입력하면 모델의 예측 태그와 실제 태그를 보여 주는 streamlit web app 49 | │ ├── inference.py # 추론 모듈 50 | │ ├── kobert_multilabel_text_classifier.ipynb 51 | │ ├── metrics_for_multilabel.py # multilabel 모델 평가를 위한 metrics 52 | │ └── preprocess.py # 전처리 모듈 53 | └── weights # 학습 모델 가중치 54 | 55 | ``` 56 | 57 | Python 3.7.11 버전에서 구현되었습니다. conda 가상환경을 권장합니다. 사용 패키지는 requirements.txt를 참조해 주세요. 58 | 59 | 60 | 61 | ### Model Performance 62 | 63 |
64 | 65 | 해당 모델은 Multi-Label classification 모델로, 전체 2,000여 건의 샘플 데이터를 train 85%, test 15%로 분할하여 테스트했습니다. 66 | 67 | 68 | | methods | NDCG@17| Micro f1 score| Macro f1 score | 69 | |-------------------------|:---------:|:---:|:---:| 70 | | KoBERT | **0.841** | **0.615** | **0.534**| 71 | 72 | 73 |
74 | 75 | 76 | 77 | ### Embedding visualization 78 | 79 | 80 | 81 | 프로젝트 초기에는 multi-class task로 접근하여 모델링을 진행했습니다. 그런데, 500여 개의 데이터셋으로 나왔던 1차 성능에 비해 샘플이 더 추가된 데이터셋으로 만든 2차 모델의 성능이 더 떨어지는 현상이 발생했고 (8개 클래스 77% -> 73%), 원인 파악을 위해 오분류 샘플을 조사했습니다. 82 | 83 | | | 카테고리 명| 84 | |-------------------------|:---------:| 85 | | 예측 정확도 70% 이상 | 무상 변경/취소, 유상 변경/취소, 기내 서비스 등 | 86 | | 예측 정확도 40% 이하 | **예약 기타**, **변경/취소** | 87 | 88 | 위와 같이, 기타 클래스의 특징을 모호하게 포함하고 있는 클래스의 성능이 매우 낮은 것을 확인할 수 있었습니다. 사람이 직접 의미적으로 판단해도 모호한 경우가 많았습니다. 이 클래스에 포함된 샘플들은 모델의 최적화 과정에서 모호한 시그널을 제공함으로써 파라미터 최적화에 악영향을 미칠 것이라고 직관적으로 생각했고, 이와 같은 내용이 버트 분류 모델의 예측에 사용되는 마지막 CLS 토큰의 representation을 low dimension에 mapping 했을 때 확인 가능할 것이라고 가정했습니다. 89 | 90 | 아래는 실제 CLS 토큰의 임베딩에 t-SNE를 적용한 결과입니다. 모호한 라벨을 가진 샘플들에 의해 임베딩 스페이스가 다소 entangled된 형태를 보이는 것을 알 수 있습니다. 91 | 92 | 93 |
94 |
drawing
95 |
96 | 97 | 그렇다면 이 라벨들을 제거하면, 버트 representation 이후의 레이어가 결정 경계를 손쉽게 그을 수 있도록 임베딩이 학습되지 않을까요? 그러한 질문에 답한 것이 다음과 같은 이미지였습니다. 98 | 99 | 100 |
101 |
drawing
102 |
103 | 104 | 예쁘게 잘 정리됐네요. 이와 같은 결과가 말해 주듯이, 데이터셋을 수작업으로 레이블링할 때 모델이 혼동하지 않는 기준을 세우는 것이 중요하다는 결론을 내릴 수 있었습니다. 아래는 수정 후 모델의 confusion matrix입니다. 85:15로 stratified sampling을 해 주었습니다. 105 | 106 | 107 |
108 |
drawing
109 |
110 | 111 | 112 | 113 | 114 | ### XAI using PyTorch Captum 115 | 116 | 117 | [Captum](https://captum.ai/)은 PyTorch 모델의 interpretability를 위한 라이브러리입니다. 이 중 자연어 분류 모델의 판단에 긍정적, 부정적으로 영향을 미친 토큰을 시각화해 주는 [예제](https://github.com/pytorch/captum/blob/master/tutorials/IMDB_TorchText_Interpret.ipynb)가 있어 본 문제에 적용해 보았습니다. 아래는 시각화 결과입니다. 118 | 119 | 120 | 121 |
122 |
drawing
123 |
124 | 125 | "기내 서비스" 라는 레이블을 예측하는 데 positive한 영향을 준 토큰은 녹색으로, negative한 영향을 준 (즉 라벨 예측에 혼동을 준) 토큰은 붉은색으로 시각화해 줍니다. 우리의 경우에서는 토큰 시각화가 직관에 다소 부합하지 않는 결과를 보이기도 했으나, 이는 소수 샘플로 인한 특정 토큰의 영향에 의한 것일 수도, 한글 토큰의 인코딩의 문제일 수도 있습니다. 126 | 127 | 128 | 129 | 130 | 131 | ### Streamlit Demo 132 | 133 | [streamlit](https://streamlit.io/)은 웹/앱 개발에 익숙지 않은 데이터 사이언티스트들이 손쉽게 웹앱 데모를 구현할 수 있도록 도와주는 high-level data app 라이브러리입니다. 입출력을 현업에게 빠르게 보여 주기 위해 다음과 같은 데모를 만들었습니다. 불과 몇 분의 투자로 모델의 I/O를 보여 줄 수 있는 매우 간편한 기능을 제공합니다. 134 | 135 |
136 |
drawing
137 |
138 | 139 |
140 |
drawing
141 |
142 | 143 | 144 | 다음 커맨드로 간단하게 실행할 수 있습니다. 145 | 146 | ``` 147 | streamlit run demo.py 148 | ``` 149 | -------------------------------------------------------------------------------- /voc_classifier/bert_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("../") 3 | 4 | # KoBERT model 5 | 6 | import config 7 | 8 | import pandas as pd 9 | import numpy as np 10 | from sklearn.preprocessing import OneHotEncoder 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from torch.utils.data import Dataset, DataLoader 17 | import gluonnlp as nlp 18 | from tqdm import tqdm, tqdm_notebook 19 | 20 | 21 | 22 | from KoBERT.kobert.utils import get_tokenizer 23 | from KoBERT.kobert.pytorch_kobert import get_pytorch_kobert_model 24 | 25 | from transformers import AdamW 26 | # from transformers.optimization import WarmupLinearSchedule 27 | 28 | from transformers import get_linear_schedule_with_warmup 29 | 30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 31 | 32 | bertmodel, vocab = get_pytorch_kobert_model() 33 | 34 | # build the tokenizer: 35 | # get_tokenizer() from the kobert package supplies the tokenizer model that splits the corpus into tokens, 36 | # and kobert's vocab is the vocabulary used for tokenization. 37 | # lower = False: the text is not lowercased before tokenization 38 | 39 | tokenizer = get_tokenizer() 40 | tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower = False) 41 | print(f'device using: {device}') 42 | 43 | 44 | model_config=config.model_config 45 | 46 | 47 | class Data_for_BERT(Dataset): 48 | def __init__(self, dataset, max_len, pad, pair, label_cols): 49 | 50 | # gluonnlp's data.BERTSentenceTransform with tok as the BERT tokenizer turns each text into 51 | # a (token ids, valid length, segment ids) tuple; __getitem__ appends the label vector to it 52 | # max_len: maximum number of (sub-word) tokens kept per sentence 53 | # pad: whether to pad sentences shorter than max_len 54 | # pair: whether the input is a single sentence (False) or a sentence pair (True)
55 | 56 | transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = max_len, pad=pad, pair=pair) 57 | self.sentences = [transform([txt]) for txt in dataset.text] 58 | # self.sentences_Customer = [transform([txt]) for txt in dataset.Customer] 59 | # self.labels = [np.int32(i) for i in dataset.label] 60 | self.labels=dataset[label_cols].values 61 | 62 | # ohe = OneHotEncoder().fit(pd.Series(self.labels).values.reshape(-1,1)) 63 | # self.labels = ohe.transform(pd.Series(self.labels).values.reshape(-1,1)).toarray() 64 | 65 | # target.bcat 66 | # self.labels = b_ohe.fit_transform(pd.Series(self.labels).values.reshape(-1,1)) 67 | 68 | def __getitem__(self,i): 69 | return (self.sentences[i] + (self.labels[i],)) 70 | 71 | def __len__(self): 72 | return(len(self.labels)) 73 | 74 | 75 | class BERTClassifier(nn.Module): 76 | def __init__(self, hidden_size = 768, num_classes = 8, dr_rate = None, params = None): 77 | # linear classification head on top of the module-level KoBERT backbone (bertmodel), stored as self.bert 78 | # dr_rate: dropout probability for the pooled output (None/0 disables dropout) 79 | # params is unused 80 | 81 | super(BERTClassifier, self).__init__() 82 | self.bert = bertmodel 83 | self.dr_rate = dr_rate 84 | 85 | # nn.Linear is a fully connected layer (the Keras Dense equivalent): 86 | # input size is the BERT hidden size (768 here), 87 | # output size is num_classes. 88 | 89 | # self.lstm_layer = nn.LSTM(512, 128, 2) 90 | self.classifier = nn.Linear(hidden_size, num_classes) 91 | 92 | # dropout on the pooled output, only when dr_rate is truthy 93 | if dr_rate: 94 | self.dropout = nn.Dropout(p=dr_rate) 95 | 96 | def generate_attention_mask(self, token_ids, valid_length): 97 | 98 | # build the BERT attention mask from token_ids:
99 | # in row i, positions [0, valid_length[i]) are 1 and the padding after them stays 0 100 | 101 | # zeros_like returns a zero tensor with the same shape, dtype and device as token_ids 102 | attention_mask = torch.zeros_like(token_ids) 103 | 104 | for i,v in enumerate(valid_length): 105 | attention_mask[i][:v] = 1 106 | return attention_mask.float() 107 | 108 | def forward(self, token_ids, valid_length, segment_ids): 109 | # build the attention mask, run the BERT backbone and classify its pooled output 110 | attention_mask = self.generate_attention_mask(token_ids, valid_length) 111 | 112 | # .long() casts the segment ids to int64; .to() moves the mask onto the token ids' device 113 | # the pooled output (optionally after dropout) feeds the linear classifier 114 | 115 | _, pooler = self.bert(input_ids = token_ids, token_type_ids = segment_ids.long(), attention_mask = attention_mask.float().to(token_ids.device)) 116 | out = self.dropout(pooler) if self.dr_rate else pooler # use the pooled output as-is when dropout is disabled 117 | 118 | 119 | # output=self.lstm_layer(out) 120 | 121 | return self.classifier(out) 122 | 123 | 124 | class EarlyStopping: 125 | """Early stops the training if validation loss doesn't improve after a given patience.""" 126 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 127 | """ 128 | Args: 129 | patience (int): How long to wait after last time validation loss improved. 130 | Default: 7 131 | verbose (bool): If True, prints a message for each validation loss improvement. 132 | Default: False 133 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 134 | Default: 0 135 | path (str): Path for the checkpoint to be saved to. 136 | Default: 'checkpoint.pt' 137 | trace_func (function): trace print function.
138 | Default: print 139 | """ 140 | self.patience = patience 141 | self.verbose = verbose 142 | self.counter = 0 143 | self.best_score = None 144 | self.early_stop = False 145 | self.val_loss_min = np.inf # np.Inf alias was removed in NumPy 2.0; np.inf works in every version 146 | self.delta = delta 147 | self.path = path 148 | self.trace_func = trace_func 149 | 150 | def __call__(self, val_loss, model): 151 | # score is the negated validation loss, so a larger score is better 152 | score = -val_loss 153 | 154 | if self.best_score is None: 155 | self.best_score = score 156 | self.save_checkpoint(val_loss, model) 157 | elif score < self.best_score + self.delta: 158 | self.counter += 1 159 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 160 | if self.counter >= self.patience: 161 | self.early_stop = True 162 | else: 163 | self.best_score = score 164 | self.save_checkpoint(val_loss, model) 165 | self.counter = 0 166 | 167 | def save_checkpoint(self, val_loss, model): 168 | '''Saves model when validation loss decrease.''' 169 | if self.verbose: 170 | self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).
Saving model ...') 171 | torch.save(model.state_dict(), self.path) 172 | self.val_loss_min = val_loss 173 | -------------------------------------------------------------------------------- /voc_classifier/preprocess.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import os 5 | import re 6 | 7 | import config 8 | 9 | DATA_PATH=config.DATA_PATH 10 | 11 | 12 | # text-cleaning helpers 13 | def remove_linebreak(string): 14 | return string.replace('\r',' ').replace('\n',' ') 15 | 16 | def split_table_string(string): # split at the last "PNR :" marker -> (text before it, table part); no marker -> (string, "") 17 | idx=string.rfind("PNR :") 18 | if idx==-1: return (string, "") # rfind returns -1 when absent; slicing with -1 would move the last character into the table part 19 | return (string[:idx], string[idx:]) 20 | 21 | def remove_multispace(x): 22 | x = str(x).strip() 23 | x = re.sub(' +', ' ',x) 24 | return x 25 | 26 | 27 | 28 | 29 | 30 | class preprocess: 31 | 32 | def __init__(self): 33 | self.voc_total=pd.read_excel(DATA_PATH,sheet_name=None, engine='openpyxl') # sheet_name=None loads every sheet into a dict keyed by sheet name 34 | 35 | 36 | def make_table(self): 37 | # group the chat rows of sheet "종합본" into one record per case (joined text + 8 category labels) and store it in self.table 38 | voc_data=self.voc_total["종합본"] # reuse the workbook loaded in __init__ instead of re-reading the Excel file 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | voc_data=voc_data.drop(["Unnamed: 11","Unnamed: 12"],axis=1) 47 | 48 | voc_data.columns=["case_no","sender_type","message","cat1","cat2","cat3","cat4","cat5","cat6","cat7","cat8"] 49 | 50 | case_start_idx=voc_data.loc[:,"case_no"][voc_data.loc[:,"case_no"].notnull()].index.tolist() 51 | 52 | voc_data.loc[case_start_idx,"seq"]=1 # a non-null case_no marks the first row of a case 53 | 54 | 55 | 56 | result=[] 57 | 58 | cnt_case=0 59 | for _,txt in voc_data.iterrows(): 60 | if (cnt_case==0) and (txt.seq==1): 61 | new_chat_corpus=[] 62 | 63 | cat1=txt.cat1 64 | cat2=txt.cat2 65 | cat3=txt.cat3 66 | cat4=txt.cat4 67 | cat5=txt.cat5 68 | cat6=txt.cat6 69 | cat7=txt.cat7 70 | cat8=txt.cat8 71 | 72 | new_chat_corpus.append([txt.sender_type,txt.message]) 73 | cnt_case+=1 74 | 75 | elif
txt.seq==1: 76 | 77 | result.append([new_chat_corpus,cat1,cat2,cat3,cat4,cat5,cat6,cat7,cat8]) 78 | 79 | cnt_case+=1 80 | # the previous conversation was appended to result above; 81 | # start collecting the new one 82 | new_chat_corpus=[] 83 | 84 | cat1=txt.cat1 85 | cat2=txt.cat2 86 | cat3=txt.cat3 87 | cat4=txt.cat4 88 | cat5=txt.cat5 89 | cat6=txt.cat6 90 | cat7=txt.cat7 91 | cat8=txt.cat8 92 | 93 | new_chat_corpus.append([txt.sender_type,txt.message]) 94 | 95 | else: 96 | new_chat_corpus.append([txt.sender_type,txt.message]) 97 | if cnt_case>0: # the loop only appends a case when the next one starts, so append the final case here 98 | result.append([new_chat_corpus,cat1,cat2,cat3,cat4,cat5,cat6,cat7,cat8]) 99 | total_result=[] 100 | 101 | for i in range(len(result)): 102 | talk=" ".join([str(txt[1]) for txt in result[i][0]]) 103 | label=result[i][-8:] 104 | total_result.append([talk,label]) 105 | 106 | 107 | 108 | result_data=pd.DataFrame(total_result) 109 | 110 | 111 | 112 | result_data[["label_cat1","label_cat2","label_cat3","label_cat4","label_cat5","label_cat6","label_cat7","label_cat8"]]=pd.DataFrame(result_data[1].tolist(), index=result_data.index) 113 | 114 | result_data=result_data.drop(1,axis=1) 115 | 116 | result_data.columns=["text","label_cat1","label_cat2","label_cat3","label_cat4","label_cat5","label_cat6","label_cat7","label_cat8"] 117 | 118 | result_data=result_data[result_data["label_cat1"].notna()] 119 | 120 | result_data.label_cat1=result_data.label_cat1.fillna("결측") 121 | result_data.label_cat1=result_data.label_cat1.apply(str.strip) 122 | result_data.label_cat2=result_data.label_cat2.fillna("결측") 123 | result_data.label_cat2=result_data.label_cat2.apply(str.strip) 124 | result_data.label_cat3=result_data.label_cat3.fillna("결측") 125 | result_data.label_cat3=result_data.label_cat3.apply(str.strip) 126 | result_data.label_cat4=result_data.label_cat4.fillna("결측") 127 | result_data.label_cat4=result_data.label_cat4.apply(str.strip) 128 | result_data.label_cat5=result_data.label_cat5.fillna("결측") 129 | result_data.label_cat5=result_data.label_cat5.apply(str.strip) 130 | result_data.label_cat6=result_data.label_cat6.fillna("결측") 131 |
result_data.label_cat6=result_data.label_cat6.apply(str.strip) 132 | result_data.label_cat7=result_data.label_cat7.fillna("결측") 133 | result_data.label_cat7=result_data.label_cat7.apply(str.strip) 134 | result_data.label_cat8=result_data.label_cat8.fillna("결측") 135 | result_data.label_cat8=result_data.label_cat8.apply(str.strip) 136 | 137 | result_data.label_cat1=result_data.label_cat1.apply(str.replace, args=(" ","")) 138 | result_data.label_cat2=result_data.label_cat2.apply(str.replace, args=(" ","")) 139 | result_data.label_cat3=result_data.label_cat3.apply(str.replace, args=(" ","")) 140 | result_data.label_cat4=result_data.label_cat4.apply(str.replace, args=(" ","")) 141 | result_data.label_cat5=result_data.label_cat5.apply(str.replace, args=(" ","")) 142 | result_data.label_cat6=result_data.label_cat6.apply(str.replace, args=(" ","")) 143 | result_data.label_cat7=result_data.label_cat7.apply(str.replace, args=(" ","")) 144 | result_data.label_cat8=result_data.label_cat8.apply(str.replace, args=(" ","")) 145 | 146 | self.table=result_data 147 | return True 148 | # one-hot encode the most frequent labels (after except_labels) as int columns and store the frame in self.data 149 | def label_process(self, num_labels=17 ,except_labels=None): 150 | result_data=self.table 151 | # every distinct label string across the 8 category columns 152 | total_label_cases=set(result_data.label_cat1.unique().tolist())|set(result_data.label_cat2.unique().tolist())|set(result_data.label_cat3.unique().tolist())|\ 153 | set(result_data.label_cat4.unique().tolist())|set(result_data.label_cat5.unique().tolist())|set(result_data.label_cat6.unique().tolist())|\ 154 | set(result_data.label_cat7.unique().tolist())|set(result_data.label_cat8.unique().tolist()) 155 | # sorted so labels with equal counts rank identically on every run (set order depends on string hashing) 156 | total_label_cases=sorted(total_label_cases) 157 | 158 | final_label=[] 159 | # "|"-delimited label string per row; the trailing "|" lets the f"|{col}|" lookups also match label_cat8 160 | for _,txt in result_data.iterrows(): 161 | label_sum="|"+txt.label_cat1+"|"+txt.label_cat2+"|"+txt.label_cat3+"|"+txt.label_cat4+"|"+txt.label_cat5+"|"+txt.label_cat6+"|"+txt.label_cat7+"|"+txt.label_cat8+"|" 162 | final_label.append([_,label_sum]) 163 | 164 | total_label_cases_dict={} 165 | 166 | for col in
total_label_cases: 167 | total_label_cases_dict[col]=len([i[0] for i in final_label if f"|{col}|" in i[1]]) 168 | 169 | label_cases_dict_cnt=[case for case in {k: v for k, v in sorted(total_label_cases_dict.items(), key=lambda item: item[1], reverse=True)}.items()] 170 | 171 | label_cases_sorted=[case[0] for case in {k: v for k, v in sorted(total_label_cases_dict.items(), key=lambda item: item[1], reverse=True)}.items()] 172 | 173 | if except_labels: 174 | for label in except_labels: 175 | label_cases_sorted.remove(label) 176 | 177 | label_cases_sorted_target=label_cases_sorted[1:num_labels+1] 178 | 179 | label_cols=label_cases_sorted_target 180 | self.label_cols=label_cases_sorted_target 181 | 182 | for col in label_cases_sorted_target: 183 | result_data.loc[[i[0] for i in final_label if f"|{col}|" in i[1]], col]=1 184 | result_data[col]=result_data[col].fillna(0) 185 | result_data[col]=result_data[col].astype(int) 186 | 187 | 188 | result_data=result_data.drop(["label_cat1","label_cat2","label_cat3","label_cat4","label_cat5","label_cat6","label_cat7","label_cat8"],axis=1) 189 | # 단 하나의 답도 없는 경우에는 일단 제거 190 | 191 | result_data=result_data[(result_data.iloc[:,1:].eq(0).sum(axis=1)!=num_labels)] 192 | 193 | self.data=result_data 194 | 195 | return True 196 | 197 | -------------------------------------------------------------------------------- /voc_classifier/.ipynb_checkpoints/preprocess-checkpoint.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import os 5 | import re 6 | 7 | import config 8 | 9 | DATA_PATH=config.DATA_PATH 10 | 11 | 12 | 13 | def remove_linebreak(string): 14 | return string.replace('\r',' ').replace('\n',' ') 15 | 16 | def split_table_string(string): 17 | trimmedTableString=string[string.rfind("PNR :"):] 18 | string=string[:string.rfind("PNR :")] 19 | return (string, trimmedTableString) 20 | 21 | def remove_multispace(x): 22 | x = str(x).strip() 23 | x = 
re.sub(' +', ' ',x) 24 | return x 25 | 26 | 27 | 28 | 29 | 30 | class preprocess: 31 | 32 | def __init__(self): 33 | self.voc_total=pd.read_excel(DATA_PATH,sheet_name=None, engine='openpyxl') 34 | 35 | 36 | def make_table(self): 37 | 38 | voc_data=self.voc_total["종합본"] 39 | 40 | DATA_PATH=config.DATA_PATH 41 | 42 | voc_total=pd.read_excel(DATA_PATH,sheet_name=None, engine='openpyxl') 43 | 44 | voc_data=voc_total["종합본"] 45 | 46 | voc_data=voc_data.drop(["Unnamed: 11","Unnamed: 12"],axis=1) 47 | 48 | voc_data.columns=["case_no","sender_type","message","cat1","cat2","cat3","cat4","cat5","cat6","cat7","cat8"] 49 | 50 | case_start_idx=voc_data.loc[:,"case_no"][voc_data.loc[:,"case_no"].notnull()].index.tolist() 51 | 52 | voc_data.loc[case_start_idx,"seq"]=1 53 | 54 | nan_val=voc_data.iloc[2][0] 55 | 56 | result=[] 57 | case_no=False 58 | cnt_case=0 59 | for _,txt in voc_data.iterrows(): 60 | if (cnt_case==0)&(txt.seq==1): 61 | new_chat_corpus=[] 62 | 63 | cat1=txt.cat1 64 | cat2=txt.cat2 65 | cat3=txt.cat3 66 | cat4=txt.cat4 67 | cat5=txt.cat5 68 | cat6=txt.cat6 69 | cat7=txt.cat7 70 | cat8=txt.cat8 71 | 72 | new_chat_corpus.append([txt.sender_type,txt.message]) 73 | cnt_case+=1 74 | 75 | elif txt.seq==1: 76 | 77 | result.append([new_chat_corpus,cat1,cat2,cat3,cat4,cat5,cat6,cat7,cat8]) 78 | 79 | cnt_case+=1 80 | # 기존의 말뭉치 append 81 | # 새로운 말뭉치 구분 시작 82 | new_chat_corpus=[] 83 | 84 | cat1=txt.cat1 85 | cat2=txt.cat2 86 | cat3=txt.cat3 87 | cat4=txt.cat4 88 | cat5=txt.cat5 89 | cat6=txt.cat6 90 | cat7=txt.cat7 91 | cat8=txt.cat8 92 | 93 | new_chat_corpus.append([txt.sender_type,txt.message]) 94 | 95 | else: 96 | new_chat_corpus.append([txt.sender_type,txt.message]) 97 | 98 | 99 | total_result=[] 100 | 101 | for i in range(len(result)): 102 | talk=" ".join([str(txt[1]) for txt in result[i][0]]) 103 | label=result[i][-8:] 104 | total_result.append([talk,label]) 105 | 106 | 107 | 108 | result_data=pd.DataFrame(total_result) 109 | 110 | label_total=result_data[1] 111 | 112 | 
result_data[["label_cat1","label_cat2","label_cat3","label_cat4","label_cat5","label_cat6","label_cat7","label_cat8"]]=pd.DataFrame(result_data[1].tolist(), index=result_data.index) 113 | 114 | result_data=result_data.drop(1,axis=1) 115 | 116 | result_data.columns=["text","label_cat1","label_cat2","label_cat3","label_cat4","label_cat5","label_cat6","label_cat7","label_cat8"] 117 | 118 | result_data=result_data[result_data["label_cat1"].isna()==False] 119 | 120 | result_data.label_cat1=result_data.label_cat1.fillna("결측") 121 | result_data.label_cat1=result_data.label_cat1.apply(str.strip) 122 | result_data.label_cat2=result_data.label_cat2.fillna("결측") 123 | result_data.label_cat2=result_data.label_cat2.apply(str.strip) 124 | result_data.label_cat3=result_data.label_cat3.fillna("결측") 125 | result_data.label_cat3=result_data.label_cat3.apply(str.strip) 126 | result_data.label_cat4=result_data.label_cat4.fillna("결측") 127 | result_data.label_cat4=result_data.label_cat4.apply(str.strip) 128 | result_data.label_cat5=result_data.label_cat5.fillna("결측") 129 | result_data.label_cat5=result_data.label_cat5.apply(str.strip) 130 | result_data.label_cat6=result_data.label_cat6.fillna("결측") 131 | result_data.label_cat6=result_data.label_cat6.apply(str.strip) 132 | result_data.label_cat7=result_data.label_cat7.fillna("결측") 133 | result_data.label_cat7=result_data.label_cat7.apply(str.strip) 134 | result_data.label_cat8=result_data.label_cat8.fillna("결측") 135 | result_data.label_cat8=result_data.label_cat8.apply(str.strip) 136 | 137 | result_data.label_cat1=result_data.label_cat1.apply(str.replace, args=(" ","")) 138 | result_data.label_cat2=result_data.label_cat2.apply(str.replace, args=(" ","")) 139 | result_data.label_cat3=result_data.label_cat3.apply(str.replace, args=(" ","")) 140 | result_data.label_cat4=result_data.label_cat4.apply(str.replace, args=(" ","")) 141 | result_data.label_cat5=result_data.label_cat5.apply(str.replace, args=(" ","")) 142 | 
result_data.label_cat6=result_data.label_cat6.apply(str.replace, args=(" ","")) 143 | result_data.label_cat7=result_data.label_cat7.apply(str.replace, args=(" ","")) 144 | result_data.label_cat8=result_data.label_cat8.apply(str.replace, args=(" ","")) 145 | 146 | self.table=result_data 147 | return True 148 | 149 | def label_process(self, num_labels=17 ,except_labels=None): 150 | result_data=self.table 151 | 152 | total_label_cases=set(result_data.label_cat1.unique().tolist())|set(result_data.label_cat2.unique().tolist())|set(result_data.label_cat3.unique().tolist())|\ 153 | set(result_data.label_cat4.unique().tolist())|set(result_data.label_cat5.unique().tolist())|set(result_data.label_cat6.unique().tolist())|\ 154 | set(result_data.label_cat7.unique().tolist())|set(result_data.label_cat8.unique().tolist()) 155 | 156 | total_label_cases=list(total_label_cases) 157 | 158 | final_label=[] 159 | 160 | for _,txt in result_data.iterrows(): 161 | label_sum="|"+txt.label_cat1+"|"+txt.label_cat2+"|"+txt.label_cat3+"|"+txt.label_cat4+"|"+txt.label_cat5+"|"+txt.label_cat6+"|"+txt.label_cat7+"|"+txt.label_cat8 162 | final_label.append([_,label_sum]) 163 | 164 | total_label_cases_dict={} 165 | 166 | for col in total_label_cases: 167 | total_label_cases_dict[col]=len([i[0] for i in final_label if f"|{col}|" in i[1]]) 168 | 169 | label_cases_dict_cnt=[case for case in {k: v for k, v in sorted(total_label_cases_dict.items(), key=lambda item: item[1], reverse=True)}.items()] 170 | 171 | label_cases_sorted=[case[0] for case in {k: v for k, v in sorted(total_label_cases_dict.items(), key=lambda item: item[1], reverse=True)}.items()] 172 | 173 | if except_labels: 174 | for label in except_labels: 175 | label_cases_sorted.remove(label) 176 | 177 | label_cases_sorted_target=label_cases_sorted[1:num_labels+1] 178 | 179 | label_cols=label_cases_sorted_target 180 | self.label_cols=label_cases_sorted_target 181 | 182 | for col in label_cases_sorted_target: 183 | result_data.loc[[i[0] for 
i in final_label if f"|{col}|" in i[1]], col]=1 184 | result_data[col]=result_data[col].fillna(0) 185 | result_data[col]=result_data[col].astype(int) 186 | 187 | 188 | result_data=result_data.drop(["label_cat1","label_cat2","label_cat3","label_cat4","label_cat5","label_cat6","label_cat7","label_cat8"],axis=1) 189 | # 단 하나의 답도 없는 경우에는 일단 제거 190 | 191 | result_data=result_data[(result_data.iloc[:,1:].eq(0).sum(axis=1)!=num_labels)] 192 | 193 | self.data=result_data 194 | 195 | return True 196 | 197 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /voc_classifier/metrics_for_multilabel.py: -------------------------------------------------------------------------------- 1 | # https://github.com/iliaschalkidis/lmtc-eurlex57k/blob/master/metrics.py 2 | 3 | from sklearn.metrics import accuracy_score 4 | from sklearn.metrics import precision_score 5 | from sklearn.metrics import recall_score 6 | from sklearn.metrics import f1_score 7 | 8 | 9 | import numpy as np 10 | 11 | 12 | def mean_precision_k(y_true, y_score, k=10): 13 | """Mean precision at rank k 14 | Parameters 15 | ---------- 16 | y_true : array-like, shape = [n_samples] 17 | Ground truth (true relevance labels). 18 | y_score : array-like, shape = [n_samples] 19 | Predicted scores. 20 | k : int 21 | Rank. 
22 | Returns 23 | ------- 24 | mean precision @k : float 25 | """ 26 | 27 | p_ks = [] 28 | for y_t, y_s in zip(y_true, y_score): 29 | if np.sum(y_t == 1): 30 | p_ks.append(ranking_precision_score(y_t, y_s, k=k)) 31 | 32 | return np.mean(p_ks) 33 | 34 | 35 | def mean_recall_k(y_true, y_score, k=10): 36 | """Mean recall at rank k 37 | Parameters 38 | ---------- 39 | y_true : array-like, shape = [n_samples] 40 | Ground truth (true relevance labels). 41 | y_score : array-like, shape = [n_samples] 42 | Predicted scores. 43 | k : int 44 | Rank. 45 | Returns 46 | ------- 47 | mean recall @k : float 48 | """ 49 | 50 | r_ks = [] 51 | for y_t, y_s in zip(y_true, y_score): 52 | if np.sum(y_t == 1): 53 | r_ks.append(ranking_recall_score(y_t, y_s, k=k)) 54 | 55 | return np.mean(r_ks) 56 | 57 | 58 | def mean_ndcg_score(y_true, y_score, k=10, gains="exponential"): 59 | """Normalized discounted cumulative gain (NDCG) at rank k 60 | Parameters 61 | ---------- 62 | y_true : array-like, shape = [n_samples] 63 | Ground truth (true relevance labels). 64 | y_score : array-like, shape = [n_samples] 65 | Predicted scores. 66 | k : int 67 | Rank. 68 | gains : str 69 | Whether gains should be "exponential" (default) or "linear". 70 | Returns 71 | ------- 72 | Mean NDCG @k : float 73 | """ 74 | 75 | ndcg_s = [] 76 | for y_t, y_s in zip(y_true, y_score): 77 | if np.sum(y_t == 1): 78 | ndcg_s.append(ndcg_score(y_t, y_s, k=k, gains=gains)) 79 | 80 | return np.mean(ndcg_s) 81 | 82 | 83 | def mean_rprecision_k(y_true, y_score, k=10): 84 | """Mean precision at rank k 85 | Parameters 86 | ---------- 87 | y_true : array-like, shape = [n_samples] 88 | Ground truth (true relevance labels). 89 | y_score : array-like, shape = [n_samples] 90 | Predicted scores. 91 | k : int 92 | Rank. 
93 | Returns 94 | ------- 95 | mean precision @k : float 96 | """ 97 | 98 | p_ks = [] 99 | for y_t, y_s in zip(y_true, y_score): 100 | if np.sum(y_t == 1): 101 | p_ks.append(ranking_rprecision_score(y_t, y_s, k=k)) 102 | 103 | return np.mean(p_ks) 104 | 105 | 106 | def ranking_recall_score(y_true, y_score, k=10): 107 | # https://ils.unc.edu/courses/2013_spring/inls509_001/lectures/10-EvaluationMetrics.pdf 108 | """Recall at rank k 109 | Parameters 110 | ---------- 111 | y_true : array-like, shape = [n_samples] 112 | Ground truth (true relevance labels). 113 | y_score : array-like, shape = [n_samples] 114 | Predicted scores. 115 | k : int 116 | Rank. 117 | Returns 118 | ------- 119 | precision @k : float 120 | """ 121 | unique_y = np.unique(y_true) 122 | 123 | if len(unique_y) == 1: 124 | return ValueError("The score cannot be approximated.") 125 | elif len(unique_y) > 2: 126 | raise ValueError("Only supported for two relevance levels.") 127 | 128 | pos_label = unique_y[1] 129 | n_pos = np.sum(y_true == pos_label) 130 | 131 | order = np.argsort(y_score)[::-1] 132 | y_true = np.take(y_true, order[:k]) 133 | n_relevant = np.sum(y_true == pos_label) 134 | 135 | return float(n_relevant) / n_pos 136 | 137 | 138 | def ranking_precision_score(y_true, y_score, k=10): 139 | """Precision at rank k 140 | Parameters 141 | ---------- 142 | y_true : array-like, shape = [n_samples] 143 | Ground truth (true relevance labels). 144 | y_score : array-like, shape = [n_samples] 145 | Predicted scores. 146 | k : int 147 | Rank. 
148 | Returns 149 | ------- 150 | precision @k : float 151 | """ 152 | unique_y = np.unique(y_true) 153 | 154 | if len(unique_y) == 1: 155 | return ValueError("The score cannot be approximated.") 156 | elif len(unique_y) > 2: 157 | raise ValueError("Only supported for two relevance levels.") 158 | 159 | pos_label = unique_y[1] 160 | 161 | order = np.argsort(y_score)[::-1] 162 | y_true = np.take(y_true, order[:k]) 163 | n_relevant = np.sum(y_true == pos_label) 164 | 165 | return float(n_relevant) / k 166 | 167 | 168 | def ranking_rprecision_score(y_true, y_score, k=10): 169 | """Precision at rank k 170 | Parameters 171 | ---------- 172 | y_true : array-like, shape = [n_samples] 173 | Ground truth (true relevance labels). 174 | y_score : array-like, shape = [n_samples] 175 | Predicted scores. 176 | k : int 177 | Rank. 178 | Returns 179 | ------- 180 | precision @k : float 181 | """ 182 | unique_y = np.unique(y_true) 183 | 184 | if len(unique_y) == 1: 185 | return ValueError("The score cannot be approximated.") 186 | elif len(unique_y) > 2: 187 | raise ValueError("Only supported for two relevance levels.") 188 | 189 | pos_label = unique_y[1] 190 | n_pos = np.sum(y_true == pos_label) 191 | 192 | order = np.argsort(y_score)[::-1] 193 | y_true = np.take(y_true, order[:k]) 194 | n_relevant = np.sum(y_true == pos_label) 195 | 196 | # Divide by min(n_pos, k) such that the best achievable score is always 1.0. 197 | return float(n_relevant) / min(k, n_pos) 198 | 199 | 200 | def average_precision_score(y_true, y_score, k=10): 201 | """Average precision at rank k 202 | Parameters 203 | ---------- 204 | y_true : array-like, shape = [n_samples] 205 | Ground truth (true relevance labels). 206 | y_score : array-like, shape = [n_samples] 207 | Predicted scores. 208 | k : int 209 | Rank. 
210 | Returns 211 | ------- 212 | average precision @k : float 213 | """ 214 | unique_y = np.unique(y_true) 215 | 216 | if len(unique_y) == 1: 217 | return ValueError("The score cannot be approximated.") 218 | elif len(unique_y) > 2: 219 | raise ValueError("Only supported for two relevance levels.") 220 | 221 | pos_label = unique_y[1] 222 | n_pos = np.sum(y_true == pos_label) 223 | 224 | order = np.argsort(y_score)[::-1][:min(n_pos, k)] 225 | y_true = np.asarray(y_true)[order] 226 | 227 | score = 0 228 | for i in range(len(y_true)): 229 | if y_true[i] == pos_label: 230 | # Compute precision up to document i 231 | # i.e, percentage of relevant documents up to document i. 232 | prec = 0 233 | for j in range(0, i + 1): 234 | if y_true[j] == pos_label: 235 | prec += 1.0 236 | prec /= (i + 1.0) 237 | score += prec 238 | 239 | if n_pos == 0: 240 | return 0 241 | 242 | return score / n_pos 243 | 244 | 245 | def dcg_score(y_true, y_score, k=10, gains="exponential"): 246 | """Discounted cumulative gain (DCG) at rank k 247 | Parameters 248 | ---------- 249 | y_true : array-like, shape = [n_samples] 250 | Ground truth (true relevance labels). 251 | y_score : array-like, shape = [n_samples] 252 | Predicted scores. 253 | k : int 254 | Rank. 255 | gains : str 256 | Whether gains should be "exponential" (default) or "linear". 
257 | Returns 258 | ------- 259 | DCG @k : float 260 | """ 261 | order = np.argsort(y_score)[::-1] 262 | y_true = np.take(y_true, order[:k]) 263 | 264 | if gains == "exponential": 265 | gains = 2 ** y_true - 1 266 | elif gains == "linear": 267 | gains = y_true 268 | else: 269 | raise ValueError("Invalid gains option.") 270 | 271 | # highest rank is 1 so +2 instead of +1 272 | discounts = np.log2(np.arange(len(y_true)) + 2) 273 | return np.sum(gains / discounts) 274 | 275 | 276 | def ndcg_score(y_true, y_score, k=10, gains="exponential"): 277 | """Normalized discounted cumulative gain (NDCG) at rank k 278 | Parameters 279 | ---------- 280 | y_true : array-like, shape = [n_samples] 281 | Ground truth (true relevance labels). 282 | y_score : array-like, shape = [n_samples] 283 | Predicted scores. 284 | k : int 285 | Rank. 286 | gains : str 287 | Whether gains should be "exponential" (default) or "linear". 288 | Returns 289 | ------- 290 | NDCG @k : float 291 | """ 292 | best = dcg_score(y_true, y_true, k, gains) 293 | actual = dcg_score(y_true, y_score, k, gains) 294 | return actual / best 295 | 296 | 297 | # Alternative API. 298 | 299 | def dcg_from_ranking(y_true, ranking): 300 | """Discounted cumulative gain (DCG) at rank k 301 | Parameters 302 | ---------- 303 | y_true : array-like, shape = [n_samples] 304 | Ground truth (true relevance labels). 305 | ranking : array-like, shape = [k] 306 | Document indices, i.e., 307 | ranking[0] is the index of top-ranked document, 308 | ranking[1] is the index of second-ranked document, 309 | ... 310 | k : int 311 | Rank. 
312 | Returns 313 | ------- 314 | DCG @k : float 315 | """ 316 | y_true = np.asarray(y_true) 317 | ranking = np.asarray(ranking) 318 | rel = y_true[ranking] 319 | gains = 2 ** rel - 1 320 | discounts = np.log2(np.arange(len(ranking)) + 2) 321 | return np.sum(gains / discounts) 322 | 323 | 324 | def ndcg_from_ranking(y_true, ranking): 325 | """Normalized discounted cumulative gain (NDCG) at rank k 326 | Parameters 327 | ---------- 328 | y_true : array-like, shape = [n_samples] 329 | Ground truth (true relevance labels). 330 | ranking : array-like, shape = [k] 331 | Document indices, i.e., 332 | ranking[0] is the index of top-ranked document, 333 | ranking[1] is the index of second-ranked document, 334 | ... 335 | k : int 336 | Rank. 337 | Returns 338 | ------- 339 | NDCG @k : float 340 | """ 341 | k = len(ranking) 342 | best_ranking = np.argsort(y_true)[::-1] 343 | best = dcg_from_ranking(y_true, best_ranking[:k]) 344 | return dcg_from_ranking(y_true, ranking) / best 345 | 346 | def colwise_accuracy(y_true,y_pred): 347 | y_pred=y_pred.T 348 | y_true=y_true.T 349 | acc_list=[] 350 | for cate in range(0,y_pred.shape[0]): 351 | acc_list.append(accuracy_score(y_pred[cate],y_true[cate])) 352 | return sum(acc_list)/len(acc_list) 353 | 354 | def calculate_metrics(pred, target, threshold=0.5): 355 | 356 | pred = np.array(pred > threshold, dtype=float) 357 | 358 | return {'Accuracy': accuracy_score(y_true=target, y_pred=pred), 359 | 'Column-wise Accuracy': colwise_accuracy(y_true=target, y_pred=pred), 360 | 'micro/precision': precision_score(y_true=target, y_pred=pred, average='micro'), 361 | 'micro/recall': recall_score(y_true=target, y_pred=pred, average='micro'), 362 | 'micro/f1': f1_score(y_true=target, y_pred=pred, average='micro'), 363 | 'macro/precision': precision_score(y_true=target, y_pred=pred, average='macro'), 364 | 'macro/recall': recall_score(y_true=target, y_pred=pred, average='macro'), 365 | 'macro/f1': f1_score(y_true=target, y_pred=pred, average='macro'), 
366 | 'samples/precision': precision_score(y_true=target, y_pred=pred, average='samples'), 367 | 'samples/recall': recall_score(y_true=target, y_pred=pred, average='samples'), 368 | 'samples/f1': f1_score(y_true=target, y_pred=pred, average='samples'), 369 | } -------------------------------------------------------------------------------- /voc_classifier/.ipynb_checkpoints/kobert_multilabel_text_classifier-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "
\n", 8 | "
" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "
\n", 16 | "

KoBERT Multi-label text classifier

\n", 17 | "

By: Myeonghak Lee

\n", 18 | "
" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "
\n", 26 | "
" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": { 33 | "tags": [] 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "# Input Data 가공 파트\n", 38 | "\n", 39 | "# import torchtext\n", 40 | "import pandas as pd\n", 41 | "import numpy as np\n", 42 | "\n", 43 | "import os\n", 44 | "import re\n", 45 | "\n", 46 | "import config\n", 47 | "from config import expand_pandas\n", 48 | "from preprocess import preprocess\n", 49 | "\n", 50 | "DATA_PATH=config.DATA_PATH\n", 51 | "\n", 52 | "model_config=config.model_config" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import warnings\n", 62 | "warnings.filterwarnings(\"ignore\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "done\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "config.expand_pandas(max_rows=100, max_cols=100,width=1000,max_info_cols=500)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "### **configs**" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "num_class=17\n", 96 | "ver_num=1\n", 97 | "\n", 98 | "except_labels=[\"변경/취소\",\"예약기타\"]\n", 99 | "\n", 100 | "version_info=\"{:02d}\".format(ver_num)\n", 101 | "weight_path=f\"../weights/weight_{version_info}.pt\"" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "### **preprocess**" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 12, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "data=preprocess()" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 13, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "data": { 127 | 
"text/plain": [ 128 | "True" 129 | ] 130 | }, 131 | "execution_count": 13, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | } 135 | ], 136 | "source": [ 137 | "# data_orig=data.voc_total[\"종합본\"]\n", 138 | "data.make_table()\n", 139 | "\n", 140 | "# put labels\n", 141 | "data.label_process(num_labels=num_class, except_labels=except_labels)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 14, 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "orig=data.voc_total[\"종합본\"]" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 15, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "label_cols=data.label_cols" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 16, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "df=data.data.copy()" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 17, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "voc_dataset=df.reset_index(drop=True)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "# Modeling part" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "import torch\n", 194 | "from torch import nn\n", 195 | "\n", 196 | "from metrics_for_multilabel import calculate_metrics, colwise_accuracy\n", 197 | "\n", 198 | "from bert_model import Data_for_BERT, BERTClassifier, EarlyStopping\n", 199 | "\n", 200 | "from transformers import get_linear_schedule_with_warmup, AdamW" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 15, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [ 216 | "device = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu')" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "metadata": {}, 223 | "outputs": [], 224 | "source": [] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 24, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "ename": "NameError", 233 | "evalue": "name 'Data_for_BERT' is not defined", 234 | "output_type": "error", 235 | "traceback": [ 236 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 237 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 238 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mdata_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mdata_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 239 | "\u001b[0;31mNameError\u001b[0m: name 'Data_for_BERT' is not defined" 240 | ] 241 | } 242 | ], 243 | "source": [ 244 | "from sklearn.model_selection import train_test_split\n", 245 | "train_input, test_input, train_target, test_target = train_test_split(voc_dataset, voc_dataset[\"국내선\"], test_size = 0.25, random_state = 42)\n", 246 | "\n", 247 | "# train=pd.concat([train_input,train_target],axis=1)\n", 248 | "# test=pd.concat([test_input,test_target],axis=1)\n", 249 | "\n", 250 | "train=train_input.copy()\n", 251 | "test=test_input.copy()\n", 252 | "\n", 253 | "train=train.reset_index(drop=True)\n", 254 | "test=test.reset_index(drop=True)\n", 255 | "\n", 256 | "data_train = Data_for_BERT(train, model_config[\"max_len\"], True, False, label_cols=label_cols)\n", 257 | "data_test = Data_for_BERT(test, model_config[\"max_len\"], True, False, label_cols=label_cols)\n", 258 | "\n", 259 | "# 파이토치 모델에 넣을 수 있도록 데이터를 처리함. \n", 260 | "# data_train을 넣어주고, 이 테이터를 batch_size에 맞게 잘라줌. num_workers는 사용할 subprocess의 개수를 의미함(병렬 프로그래밍)\n", 261 | "\n", 262 | "train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=model_config[\"batch_size\"], num_workers=0)\n", 263 | "test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=model_config[\"batch_size\"], num_workers=0)" 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "execution_count": null, 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 18, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [ 279 | "# KoBERT 라이브러리에서 bertmodel을 호출함. 
.to() 메서드는 모델 전체를 GPU 디바이스에 옮겨 줌.\n", 280 | "model = BERTClassifier(num_classes=num_class, dr_rate = model_config[\"dr_rate\"]).to(device)\n", 281 | "\n", 282 | "# 옵티마이저와 스케쥴 준비 (linear warmup과 decay)\n", 283 | "no_decay = ['bias', 'LayerNorm.weight']\n", 284 | "\n", 285 | "# no_decay에 해당하는 파라미터명을 가진 레이어들은 decay에서 배제하기 위해 weight_decay를 0으로 셋팅, 그 외에는 0.01로 decay\n", 286 | "# weight decay란 l2 norm으로 파라미터 값을 정규화해주는 기법을 의미함\n", 287 | "optimizer_grouped_parameters = [\n", 288 | " {\"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay' : 0.01},\n", 289 | " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 290 | "]\n", 291 | "\n", 292 | "\n", 293 | "# 옵티마이저는 AdamW, 손실함수는 BCE\n", 294 | "# optimizer_grouped_parameters는 최적화할 파라미터의 그룹을 의미함\n", 295 | "optimizer = AdamW(optimizer_grouped_parameters, lr= model_config[\"learning_rate\"])\n", 296 | "# loss_fn = nn.CrossEntropyLoss()\n", 297 | "loss_fn=nn.BCEWithLogitsLoss()\n", 298 | "\n", 299 | "\n", 300 | "# t_total = train_dataloader.dataset.labels.shape[0] * num_epochs\n", 301 | "# linear warmup을 사용해 학습 초기 단계(배치 초기)의 learning rate를 조금씩 증가시켜 나가다, 어느 지점에 이르면 constant하게 유지\n", 302 | "# 초기 학습 단계에서의 변동성을 줄여줌.\n", 303 | "\n", 304 | "t_total = len(train_dataloader) * model_config[\"num_epochs\"]\n", 305 | "warmup_step = int(t_total * model_config[\"warmup_ratio\"])\n", 306 | "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)\n", 307 | "\n", 308 | "\n", 309 | "# model_save_name = 'classifier'\n", 310 | "# model_file='.pt'\n", 311 | "# path = f\"./bert_weights/{model_save_name}_{model_file}\" " 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 23, 324 | "metadata": {}, 325 | "outputs": [], 326 | 
"source": [ 327 | "def train_model(model, batch_size, patience, n_epochs,path):\n", 328 | " \n", 329 | " # to track the training loss as the model trains\n", 330 | " train_losses = []\n", 331 | " # to track the validation loss as the model trains\n", 332 | " valid_losses = []\n", 333 | " # to track the average training loss per epoch as the model trains\n", 334 | " avg_train_losses = []\n", 335 | " # to track the average validation loss per epoch as the model trains\n", 336 | " avg_valid_losses = [] \n", 337 | "\n", 338 | " early_stopping = EarlyStopping(patience=patience, verbose=True, path=path)\n", 339 | "\n", 340 | " for epoch in range(1, n_epochs + 1):\n", 341 | " \n", 342 | " # initialize the early_stopping object\n", 343 | " model.train()\n", 344 | " train_epoch_pred=[]\n", 345 | " train_loss_record=[]\n", 346 | "\n", 347 | " for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(train_dataloader):\n", 348 | " optimizer.zero_grad()\n", 349 | "\n", 350 | " token_ids = token_ids.long().to(device)\n", 351 | " segment_ids = segment_ids.long().to(device)\n", 352 | " valid_length= valid_length\n", 353 | " \n", 354 | " # label = label.long().to(device)\n", 355 | " label = label.float().to(device)\n", 356 | "\n", 357 | " out= model(token_ids, valid_length, segment_ids)#.squeeze(1)\n", 358 | " \n", 359 | " loss = loss_fn(out, label)\n", 360 | "\n", 361 | " train_loss_record.append(loss)\n", 362 | "\n", 363 | " train_pred=out.detach().cpu().numpy()\n", 364 | " train_real=label.detach().cpu().numpy()\n", 365 | "\n", 366 | " train_batch_result = calculate_metrics(np.array(train_pred), np.array(train_real))\n", 367 | " \n", 368 | " if batch_id%50==0:\n", 369 | " print(f\"batch number {batch_id}, train col-wise accuracy is : {train_batch_result['Column-wise Accuracy']}\")\n", 370 | " \n", 371 | "\n", 372 | " # save prediction result for calculation of accuracy per batch\n", 373 | " train_epoch_pred.append(train_pred)\n", 374 | "\n", 375 | " \n", 376 | " 
loss.backward()\n", 377 | " torch.nn.utils.clip_grad_norm_(model.parameters(), model_config[\"max_grad_norm\"])\n", 378 | " optimizer.step()\n", 379 | " scheduler.step() # Update learning rate schedule\n", 380 | "\n", 381 | " train_losses.append(loss.item())\n", 382 | "\n", 383 | " train_epoch_pred=np.concatenate(train_epoch_pred)\n", 384 | " train_epoch_target=train_dataloader.dataset.labels\n", 385 | " train_epoch_result=calculate_metrics(target=train_epoch_target, pred=train_epoch_pred)\n", 386 | " \n", 387 | " print(f\"=====Training Report: mean loss is {sum(train_loss_record)/len(train_loss_record)}=====\")\n", 388 | " print(train_epoch_result)\n", 389 | " \n", 390 | " print(\"=====train done!=====\")\n", 391 | "\n", 392 | " # if e % log_interval == 0:\n", 393 | " # print(\"epoch {} batch id {} loss {} train acc {}\".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))\n", 394 | "\n", 395 | " # print(\"epoch {} train acc {}\".format(e+1, train_acc / (batch_id+1)))\n", 396 | " test_epoch_pred=[]\n", 397 | " test_loss_record=[]\n", 398 | "\n", 399 | " model.eval()\n", 400 | " with torch.no_grad():\n", 401 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n", 402 | " \n", 403 | " token_ids = token_ids.long().to(device)\n", 404 | " segment_ids = segment_ids.long().to(device)\n", 405 | " valid_length = valid_length\n", 406 | " \n", 407 | " # test_label = test_label.long().to(device)\n", 408 | " test_label = test_label.float().to(device)\n", 409 | "\n", 410 | " test_out = model(token_ids, valid_length, segment_ids)\n", 411 | "\n", 412 | " test_loss = loss_fn(test_out, test_label)\n", 413 | "\n", 414 | " test_loss_record.append(test_loss)\n", 415 | " \n", 416 | " valid_losses.append(test_loss.item())\n", 417 | "\n", 418 | " test_pred=test_out.detach().cpu().numpy()\n", 419 | " test_real=test_label.detach().cpu().numpy()\n", 420 | "\n", 421 | " test_batch_result = 
calculate_metrics(np.array(test_pred), np.array(test_real))\n", 422 | "\n", 423 | " if batch_id%50==0:\n", 424 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n", 425 | "\n", 426 | " # save prediction result for calculation of accuracy per epoch\n", 427 | " test_epoch_pred.append(test_pred)\n", 428 | "\n", 429 | " test_epoch_pred=np.concatenate(test_epoch_pred)\n", 430 | " test_epoch_target=test_dataloader.dataset.labels\n", 431 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n", 432 | "\n", 433 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n", 434 | " print(test_epoch_result)\n", 435 | "\n", 436 | " train_loss = np.average(train_losses)\n", 437 | " valid_loss = np.average(valid_losses)\n", 438 | " avg_train_losses.append(train_loss)\n", 439 | " avg_valid_losses.append(valid_loss)\n", 440 | "\n", 441 | " # clear lists to track next epoch\n", 442 | " train_losses = []\n", 443 | " valid_losses = []\n", 444 | "\n", 445 | " # early_stopping needs the validation loss to check if it has decresed, \n", 446 | " # and if it has, it will make a checkpoint of the current model\n", 447 | " early_stopping(valid_loss, model)\n", 448 | "\n", 449 | " if early_stopping.early_stop:\n", 450 | " print(\"Early stopping\")\n", 451 | " break\n", 452 | "\n", 453 | " # load the last checkpoint with the best model\n", 454 | " model.load_state_dict(torch.load(path))\n", 455 | "\n", 456 | " return model, avg_train_losses, avg_valid_losses\n", 457 | " " 458 | ] 459 | }, 460 | { 461 | "cell_type": "code", 462 | "execution_count": null, 463 | "metadata": { 464 | "tags": [ 465 | "outputPrepend" 466 | ] 467 | }, 468 | "outputs": [], 469 | "source": [ 470 | "# early stopping patience; how long to wait after last time validation loss improved.\n", 471 | "patience = 10\n", 472 | "model, train_loss, valid_loss = train_model(model, \n", 473 | " 
model_config[\"batch_size\"],\n", 474 | " patience, \n", 475 | " model_config[\"num_epochs\"], \n", 476 | " path=weight_path)\n" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [] 492 | }, 493 | { 494 | "cell_type": "markdown", 495 | "metadata": {}, 496 | "source": [ 497 | "# test performance" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 54, 503 | "metadata": {}, 504 | "outputs": [], 505 | "source": [ 506 | "weight_path=\"../weights/weight_01.pt\"" 507 | ] 508 | }, 509 | { 510 | "cell_type": "code", 511 | "execution_count": 55, 512 | "metadata": {}, 513 | "outputs": [ 514 | { 515 | "data": { 516 | "text/plain": [ 517 | "" 518 | ] 519 | }, 520 | "execution_count": 55, 521 | "metadata": {}, 522 | "output_type": "execute_result" 523 | } 524 | ], 525 | "source": [ 526 | "model.load_state_dict(torch.load(weight_path))" 527 | ] 528 | }, 529 | { 530 | "cell_type": "code", 531 | "execution_count": 56, 532 | "metadata": {}, 533 | "outputs": [ 534 | { 535 | "name": "stdout", 536 | "output_type": "stream", 537 | "text": [ 538 | "batch number 0, test col-wise accuracy is : 0.9294117647058825\n", 539 | "batch number 50, test col-wise accuracy is : 0.9411764705882353\n", 540 | "batch number 100, test col-wise accuracy is : 0.9058823529411765\n", 541 | "=====Testing Report: mean loss is 0.20872297883033752=====\n", 542 | "{'Accuracy': 0.22437137330754353, 'Column-wise Accuracy': 0.9210376607122539, 'micro/precision': 0.7973273942093542, 'micro/recall': 0.372528616024974, 'micro/f1': 0.5078014184397163, 'macro/precision': 0.6019785504362263, 'macro/recall': 0.28350515905377016, 'macro/f1': 0.34598051554393844, 'samples/precision': 0.563023855577047, 'samples/recall': 0.3934235976789168, 'samples/f1': 0.4393478861563968}\n" 
543 | ] 544 | } 545 | ], 546 | "source": [ 547 | "test_epoch_pred=[] \n", 548 | "test_loss_record=[] \n", 549 | "valid_losses=[]\n", 550 | "\n", 551 | "model.eval() \n", 552 | "with torch.no_grad(): \n", 553 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n", 554 | "\n", 555 | " token_ids = token_ids.long().to(device)\n", 556 | " segment_ids = segment_ids.long().to(device)\n", 557 | " valid_length = valid_length\n", 558 | " \n", 559 | " # test_label = test_label.long().to(device)\n", 560 | " test_label = test_label.float().to(device)\n", 561 | "\n", 562 | " test_out = model(token_ids, valid_length, segment_ids)\n", 563 | "\n", 564 | " test_loss = loss_fn(test_out, test_label)\n", 565 | "\n", 566 | " test_loss_record.append(test_loss)\n", 567 | " \n", 568 | " valid_losses.append(test_loss.item())\n", 569 | "\n", 570 | " test_pred=test_out.detach().cpu().numpy()\n", 571 | " test_real=test_label.detach().cpu().numpy()\n", 572 | "\n", 573 | " test_batch_result = calculate_metrics(np.array(test_pred), np.array(test_real))\n", 574 | "\n", 575 | " if batch_id%50==0:\n", 576 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n", 577 | "\n", 578 | " # save prediction result for calculation of accuracy per epoch\n", 579 | " test_epoch_pred.append(test_pred)\n", 580 | "\n", 581 | " # if batch_id%10==0:\n", 582 | " # print(test_batch_result[\"Accuracy\"])\n", 583 | " test_epoch_pred=np.concatenate(test_epoch_pred) \n", 584 | " test_epoch_target=test_dataloader.dataset.labels \n", 585 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n", 586 | "\n", 587 | " # print(test_epoch_pred)\n", 588 | " # print(test_epoch_target)\n", 589 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n", 590 | " print(test_epoch_result)" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | 
"execution_count": null, 596 | "metadata": {}, 597 | "outputs": [], 598 | "source": [] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": 58, 603 | "metadata": {}, 604 | "outputs": [], 605 | "source": [ 606 | "import metrics_for_multilabel as metrics" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 59, 612 | "metadata": {}, 613 | "outputs": [ 614 | { 615 | "data": { 616 | "text/plain": [ 617 | "0.840774483390301" 618 | ] 619 | }, 620 | "execution_count": 59, 621 | "metadata": {}, 622 | "output_type": "execute_result" 623 | } 624 | ], 625 | "source": [ 626 | "metrics.mean_ndcg_score(test_epoch_target,test_epoch_pred, k=17)" 627 | ] 628 | }, 629 | { 630 | "cell_type": "code", 631 | "execution_count": null, 632 | "metadata": {}, 633 | "outputs": [], 634 | "source": [] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": 60, 639 | "metadata": {}, 640 | "outputs": [ 641 | { 642 | "name": "stdout", 643 | "output_type": "stream", 644 | "text": [ 645 | "accuracy: 0.6367182462927145\n" 646 | ] 647 | } 648 | ], 649 | "source": [ 650 | "acc_cnt=0\n", 651 | "for n in range(test_epoch_pred.shape[0]):\n", 652 | " tar_cnt=np.count_nonzero(test_epoch_target[n])\n", 653 | " pred_=test_epoch_pred[n].argsort()[-tar_cnt:]\n", 654 | " tar_=test_epoch_target[n].argsort()[-tar_cnt:]\n", 655 | " acc_cnt+=len(set(pred_)&set(tar_))/len(pred_)\n", 656 | "print(f\"accuracy: {acc_cnt/test_epoch_pred.shape[0]}\")" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": null, 662 | "metadata": {}, 663 | "outputs": [], 664 | "source": [ 665 | "calculate_metrics(target=test_epoch_target, pred=test_epoch_pred, threshold=-1)" 666 | ] 667 | }, 668 | { 669 | "cell_type": "code", 670 | "execution_count": 62, 671 | "metadata": {}, 672 | "outputs": [], 673 | "source": [ 674 | "label_cases_sorted_target=data.label_cols" 675 | ] 676 | }, 677 | { 678 | "cell_type": "code", 679 | "execution_count": 63, 680 | "metadata": {}, 
681 | "outputs": [], 682 | "source": [ 683 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = max_len, pad=True, pair=False)\n", 684 | "\n", 685 | "def get_prediction_from_txt(input_text, threshold=0.0):\n", 686 | " sentences = transform([input_text])\n", 687 | " get_pred=model(torch.tensor(sentences[0]).long().unsqueeze(0).to(device),torch.tensor(sentences[1]).unsqueeze(0),torch.tensor(sentences[2]).to(device))\n", 688 | " pred=np.array(get_pred.to(\"cpu\").detach().numpy()[0] > threshold, dtype=float)\n", 689 | " pred=np.nonzero(pred)[0].tolist()\n", 690 | " print(f\"분석 결과, 대화의 예상 태그는 {[label_cases_sorted_target[i] for i in pred]} 입니다.\")\n", 691 | " true=np.nonzero(input_text_label)[0].tolist()\n", 692 | " print(f\"실제 태그는 {[label_cases_sorted_target[i] for i in true]} 입니다.\")\n", 693 | "\n" 694 | ] 695 | }, 696 | { 697 | "cell_type": "code", 698 | "execution_count": 64, 699 | "metadata": {}, 700 | "outputs": [], 701 | "source": [ 702 | "input_text_num=17\n", 703 | "input_text=voc_dataset.iloc[input_text_num,0]\n", 704 | "# input_text=test.iloc[input_text_num,0]\n", 705 | "input_text_label=voc_dataset.iloc[input_text_num,1:].tolist()" 706 | ] 707 | }, 708 | { 709 | "cell_type": "code", 710 | "execution_count": 65, 711 | "metadata": {}, 712 | "outputs": [ 713 | { 714 | "name": "stdout", 715 | "output_type": "stream", 716 | "text": [ 717 | "분석 결과, 대화의 예상 태그는 ['대기예약', '무상신규예약'] 입니다.\n", 718 | "실제 태그는 ['무상신규예약'] 입니다.\n" 719 | ] 720 | } 721 | ], 722 | "source": [ 723 | "get_prediction_from_txt(input_text, -1)" 724 | ] 725 | }, 726 | { 727 | "cell_type": "code", 728 | "execution_count": null, 729 | "metadata": {}, 730 | "outputs": [], 731 | "source": [] 732 | }, 733 | { 734 | "cell_type": "markdown", 735 | "metadata": {}, 736 | "source": [ 737 | "# XAI" 738 | ] 739 | }, 740 | { 741 | "cell_type": "code", 742 | "execution_count": 69, 743 | "metadata": {}, 744 | "outputs": [], 745 | "source": [ 746 | "from captum_tools_vocvis import *" 747 | ] 748 | }, 
749 | { 750 | "cell_type": "code", 751 | "execution_count": null, 752 | "metadata": {}, 753 | "outputs": [], 754 | "source": [ 755 | "from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization\n", 756 | "\n", 757 | "# model = BERTClassifier(bertmodel, dr_rate = 0.4).to(device)\n", 758 | "# model.load_state_dict(torch.load(os.getcwd()+\"/chat_voc_model.pt\", map_location=device))\n", 759 | "model.eval()" 760 | ] 761 | }, 762 | { 763 | "cell_type": "code", 764 | "execution_count": 94, 765 | "metadata": {}, 766 | "outputs": [], 767 | "source": [ 768 | "PAD_IND = tok.vocab.padding_token\n", 769 | "PAD_IND = tok.convert_tokens_to_ids(PAD_IND)\n", 770 | "token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)\n", 771 | "lig = LayerIntegratedGradients(model,model.bert.embeddings)" 772 | ] 773 | }, 774 | { 775 | "cell_type": "code", 776 | "execution_count": 95, 777 | "metadata": {}, 778 | "outputs": [], 779 | "source": [ 780 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = 64, pad=True, pair=False)\n", 781 | "\n", 782 | "voc_label_dict_inverse={ele:label_cols.index(ele) for ele in label_cols}\n", 783 | "\n", 784 | "voc_label_dict={label_cols.index(ele):ele for ele in label_cols}" 785 | ] 786 | }, 787 | { 788 | "cell_type": "code", 789 | "execution_count": null, 790 | "metadata": {}, 791 | "outputs": [], 792 | "source": [] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": 96, 797 | "metadata": {}, 798 | "outputs": [], 799 | "source": [ 800 | "def forward_with_sigmoid_for_bert(input,valid_length,segment_ids):\n", 801 | " return torch.sigmoid(model(input,valid_length,segment_ids))\n" 802 | ] 803 | }, 804 | { 805 | "cell_type": "code", 806 | "execution_count": 97, 807 | "metadata": {}, 808 | "outputs": [], 809 | "source": [ 810 | "def forward_for_bert(input,valid_length,segment_ids):\n", 811 | " return torch.nn.functional.softmax(model(input,valid_length,segment_ids),dim=1)" 812 | ] 813 | }, 814 | { 815 | 
"cell_type": "code", 816 | "execution_count": null, 817 | "metadata": {}, 818 | "outputs": [], 819 | "source": [] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": 109, 824 | "metadata": {}, 825 | "outputs": [], 826 | "source": [ 827 | "# accumalate couple samples in this array for visualization purposes\n", 828 | "vis_data_records_ig = []\n", 829 | "\n", 830 | "def interpret_sentence(model, sentence, min_len = 64, label = 0, n_steps=10):\n", 831 | " # text = [token for token in tok.sentencepiece(sentence)]\n", 832 | " # if len(text) < min_len:\n", 833 | " # text += ['pad'] * (min_len - len(text))\n", 834 | " # indexed = tok.convert_tokens_to_ids(text)\n", 835 | " # print(text)\n", 836 | " \n", 837 | " # 토크나이징, 시퀀스 생성\n", 838 | " seq_tokens=transform([sentence])\n", 839 | " indexed=torch.tensor(seq_tokens[0]).long()#.to(device)\n", 840 | " valid_length=torch.tensor(seq_tokens[1]).long().unsqueeze(0)\n", 841 | " segment_ids=torch.tensor(seq_tokens[2]).long().unsqueeze(0).to(device)\n", 842 | " sentence=[token for token in tok.sentencepiece(sentence)]\n", 843 | " \n", 844 | "\n", 845 | " with torch.no_grad():\n", 846 | " model.zero_grad()\n", 847 | "\n", 848 | " input_indices = torch.tensor(indexed, device=device)\n", 849 | " input_indices = input_indices.unsqueeze(0)\n", 850 | " \n", 851 | " seq_length = min_len\n", 852 | "\n", 853 | " # predict\n", 854 | " pred = forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids).detach().cpu().numpy().argmax().item()\n", 855 | " print(forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids))\n", 856 | " pred_ind = round(pred)\n", 857 | " \n", 858 | " # generate reference indices for each sample\n", 859 | " reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)\n", 860 | "\n", 861 | " # compute attributions and approximation delta using layer integrated gradients\n", 862 | " attributions_ig, delta = lig.attribute(input_indices, 
reference_indices,\\\n", 863 | " n_steps=n_steps, return_convergence_delta=True,target=label,\\\n", 864 | " additional_forward_args=(valid_length,segment_ids))\n", 865 | "\n", 866 | " print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))\n", 867 | "\n", 868 | " add_attributions_to_visualizer(attributions_ig, sentence, pred, pred_ind, label, delta, vis_data_records_ig)" 869 | ] 870 | }, 871 | { 872 | "cell_type": "code", 873 | "execution_count": 110, 874 | "metadata": {}, 875 | "outputs": [], 876 | "source": [ 877 | "def add_attributions_to_visualizer(attributions, input_text, pred, pred_ind, label, delta, vis_data_records):\n", 878 | " attributions = attributions.sum(dim=2).squeeze(0)\n", 879 | " attributions = attributions / torch.norm(attributions)\n", 880 | " attributions = attributions.cpu().detach().numpy()\n", 881 | "\n", 882 | " # storing couple samples in an array for visualization purposes\n", 883 | " vis_data_records.append(visualization.VisualizationDataRecord(\n", 884 | " attributions,\n", 885 | " pred,\n", 886 | " voc_label_dict[pred_ind], #Label.vocab.itos[pred_ind],\n", 887 | " voc_label_dict[label], # Label.vocab.itos[label],\n", 888 | " 100, # Label.vocab.itos[1],\n", 889 | " attributions.sum(), \n", 890 | " input_text,\n", 891 | " delta))" 892 | ] 893 | }, 894 | { 895 | "cell_type": "code", 896 | "execution_count": 126, 897 | "metadata": {}, 898 | "outputs": [ 899 | { 900 | "data": { 901 | "text/html": [ 902 | "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
" 903 | ], 904 | "text/plain": [ 905 | "" 906 | ] 907 | }, 908 | "metadata": {}, 909 | "output_type": "display_data" 910 | } 911 | ], 912 | "source": [ 913 | "sentence=voc_dataset.iloc[22].text\n", 914 | "\n", 915 | "visualize_text(vis_data_records_ig)" 916 | ] 917 | } 918 | ], 919 | "metadata": { 920 | "accelerator": "GPU", 921 | "colab": { 922 | "collapsed_sections": [], 923 | "machine_shape": "hm", 924 | "name": "KoBERT_PoC_1211.ipynb", 925 | "provenance": [] 926 | }, 927 | "kernelspec": { 928 | "display_name": "Python 3", 929 | "language": "python", 930 | "name": "python3" 931 | }, 932 | "language_info": { 933 | "codemirror_mode": { 934 | "name": "ipython", 935 | "version": 3 936 | }, 937 | "file_extension": ".py", 938 | "mimetype": "text/x-python", 939 | "name": "python", 940 | "nbconvert_exporter": "python", 941 | "pygments_lexer": "ipython3", 942 | "version": "3.8.8" 943 | } 944 | }, 945 | "nbformat": 4, 946 | "nbformat_minor": 4 947 | } 948 | -------------------------------------------------------------------------------- /voc_classifier/kobert_multilabel_text_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "
\n", 8 | "
" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "metadata": {}, 14 | "source": [ 15 | "
\n", 16 | "

KoBERT Multi-label text classifier

\n", 17 | "

By: Myeonghak Lee

\n", 18 | "
" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "
\n", 26 | "
" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 1, 32 | "metadata": { 33 | "tags": [] 34 | }, 35 | "outputs": [], 36 | "source": [ 37 | "# Input Data 가공 파트\n", 38 | "\n", 39 | "# import torchtext\n", 40 | "import pandas as pd\n", 41 | "import numpy as np\n", 42 | "\n", 43 | "import os\n", 44 | "import re\n", 45 | "\n", 46 | "import config\n", 47 | "from config import expand_pandas\n", 48 | "from preprocess import preprocess\n", 49 | "\n", 50 | "DATA_PATH=config.DATA_PATH\n", 51 | "\n", 52 | "model_config=config.model_config" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import warnings\n", 62 | "warnings.filterwarnings(\"ignore\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "done\n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "config.expand_pandas(max_rows=100, max_cols=100,width=1000,max_info_cols=500)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "### **configs**" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 4, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "num_class=17\n", 96 | "ver_num=1\n", 97 | "\n", 98 | "except_labels=[\"변경/취소\",\"예약기타\"]\n", 99 | "\n", 100 | "version_info=\"{:02d}\".format(ver_num)\n", 101 | "weight_path=f\"../weights/weight_{version_info}.pt\"" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "### **preprocess**" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 12, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "data=preprocess()" 125 | ] 126 | }, 
127 | { 128 | "cell_type": "code", 129 | "execution_count": 13, 130 | "metadata": {}, 131 | "outputs": [ 132 | { 133 | "data": { 134 | "text/plain": [ 135 | "True" 136 | ] 137 | }, 138 | "execution_count": 13, 139 | "metadata": {}, 140 | "output_type": "execute_result" 141 | } 142 | ], 143 | "source": [ 144 | "# data_orig=data.voc_total[\"종합본\"]\n", 145 | "data.make_table()\n", 146 | "\n", 147 | "# put labels\n", 148 | "data.label_process(num_labels=num_class, except_labels=except_labels)" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 14, 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "orig=data.voc_total[\"종합본\"]" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 15, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "label_cols=data.label_cols" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 16, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "df=data.data.copy()" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 17, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "voc_dataset=df.reset_index(drop=True)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "# Modeling part" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "import torch\n", 201 | "from torch import nn\n", 202 | "\n", 203 | "from metrics_for_multilabel import calculate_metrics, colwise_accuracy\n", 204 | "\n", 205 | "from bert_model import Data_for_BERT, BERTClassifier, EarlyStopping\n", 206 | "\n", 207 | "from transformers import get_linear_schedule_with_warmup, AdamW" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "metadata": {}, 214 | "outputs": [], 215 | "source": [] 216 | }, 217 | { 218 | "cell_type": "code", 219 | 
"execution_count": 15, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" 224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": null, 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 24, 236 | "metadata": {}, 237 | "outputs": [ 238 | { 239 | "ename": "NameError", 240 | "evalue": "name 'Data_for_BERT' is not defined", 241 | "output_type": "error", 242 | "traceback": [ 243 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 244 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 245 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mdata_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mdata_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 246 | "\u001b[0;31mNameError\u001b[0m: name 'Data_for_BERT' is not defined" 247 | ] 248 | } 249 | ], 250 | "source": [ 251 | "from sklearn.model_selection import train_test_split\n", 252 | "train_input, test_input, train_target, test_target = train_test_split(voc_dataset, voc_dataset[\"국내선\"], test_size = 0.25, random_state = 42)\n", 253 | "\n", 254 | "# train=pd.concat([train_input,train_target],axis=1)\n", 255 | "# test=pd.concat([test_input,test_target],axis=1)\n", 256 | "\n", 257 | "train=train_input.copy()\n", 258 | "test=test_input.copy()\n", 259 | "\n", 260 | "train=train.reset_index(drop=True)\n", 261 | "test=test.reset_index(drop=True)\n", 262 | "\n", 263 | "data_train = Data_for_BERT(train, model_config[\"max_len\"], True, False, label_cols=label_cols)\n", 264 | "data_test = Data_for_BERT(test, model_config[\"max_len\"], True, False, label_cols=label_cols)\n", 265 | "\n", 266 | "# 파이토치 모델에 넣을 수 있도록 데이터를 처리함. \n", 267 | "# data_train을 넣어주고, 이 테이터를 batch_size에 맞게 잘라줌. 
num_workers는 사용할 subprocess의 개수를 의미함(병렬 프로그래밍)\n", 268 | "\n", 269 | "train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=model_config[\"batch_size\"], num_workers=0)\n", 270 | "test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=model_config[\"batch_size\"], num_workers=0)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 18, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "# KoBERT 라이브러리에서 bertmodel을 호출함. .to(device)는 모델 전체를 device(GPU 사용 가능 시 GPU, 아니면 CPU)로 옮겨 줌.\n", 287 | "model = BERTClassifier(num_classes=num_class, dr_rate = model_config[\"dr_rate\"]).to(device)\n", 288 | "\n", 289 | "# 옵티마이저와 스케줄러 준비 (linear warmup 후 linear decay)\n", 290 | "no_decay = ['bias', 'LayerNorm.weight']\n", 291 | "\n", 292 | "# no_decay에 해당하는 파라미터명을 가진 레이어들은 decay에서 배제하기 위해 weight_decay를 0으로 셋팅, 그 외에는 0.01로 decay\n", 293 | "# weight decay란 매 스텝 파라미터 값을 일정 비율로 줄여 크기를 억제하는 정규화 기법이며, AdamW는 이를 gradient 기반 L2 페널티와 분리(decoupled)하여 적용함\n", 294 | "optimizer_grouped_parameters = [\n", 295 | " {\"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay' : 0.01},\n", 296 | " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n", 297 | "]\n", 298 | "\n", 299 | "\n", 300 | "# 옵티마이저는 AdamW, 손실함수는 BCE (BCEWithLogitsLoss: sigmoid와 BCE를 결합)\n", 301 | "# optimizer_grouped_parameters는 최적화할 파라미터의 그룹을 의미함\n", 302 | "optimizer = AdamW(optimizer_grouped_parameters, lr= model_config[\"learning_rate\"])\n", 303 | "# loss_fn = nn.CrossEntropyLoss()\n", 304 | "loss_fn=nn.BCEWithLogitsLoss()\n", 305 | "\n", 306 | "\n", 307 | "# t_total = train_dataloader.dataset.labels.shape[0] * num_epochs\n", 308 | "# linear warmup을 사용해 학습 초기 단계(warmup_step까지)에는 learning rate를 0부터 설정값까지 선형으로 증가시키고, 이후 학습 종료(t_total)까지 0을 향해 선형으로 감소시킴\n", 309 | "# 초기 학습 단계에서의 변동성을 줄여줌.\n", 310 | "\n", 311 | "t_total = len(train_dataloader) * model_config[\"num_epochs\"]\n", 312
| "warmup_step = int(t_total * model_config[\"warmup_ratio\"])\n", 313 | "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)\n", 314 | "\n", 315 | "\n", 316 | "# model_save_name = 'classifier'\n", 317 | "# model_file='.pt'\n", 318 | "# path = f\"./bert_weights/{model_save_name}_{model_file}\" " 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 23, 331 | "metadata": {}, 332 | "outputs": [], 333 | "source": [ 334 | "def train_model(model, batch_size, patience, n_epochs,path):\n", 335 | " \n", 336 | " # to track the training loss as the model trains\n", 337 | " train_losses = []\n", 338 | " # to track the validation loss as the model trains\n", 339 | " valid_losses = []\n", 340 | " # to track the average training loss per epoch as the model trains\n", 341 | " avg_train_losses = []\n", 342 | " # to track the average validation loss per epoch as the model trains\n", 343 | " avg_valid_losses = [] \n", 344 | "\n", 345 | " early_stopping = EarlyStopping(patience=patience, verbose=True, path=path)\n", 346 | "\n", 347 | " for epoch in range(1, n_epochs + 1):\n", 348 | " \n", 349 | " # switch back to training mode every epoch (the evaluation pass below calls model.eval())\n", 350 | " model.train()\n", 351 | " train_epoch_pred=[]\n", 352 | " train_loss_record=[]\n", 353 | "\n", 354 | " for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(train_dataloader):\n", 355 | " optimizer.zero_grad()\n", 356 | "\n", 357 | " token_ids = token_ids.long().to(device)\n", 358 | " segment_ids = segment_ids.long().to(device)\n", 359 | " valid_length= valid_length\n", 360 | " \n", 361 | " # label = label.long().to(device)\n", 362 | " label = label.float().to(device)\n", 363 | "\n", 364 | " out= model(token_ids, valid_length, segment_ids)#.squeeze(1)\n", 365 | " \n", 366 | " loss = loss_fn(out, label)\n",
367 | "\n", 368 | " train_loss_record.append(loss)\n", 369 | "\n", 370 | " train_pred=out.detach().cpu().numpy()\n", 371 | " train_real=label.detach().cpu().numpy()\n", 372 | "\n", 373 | " train_batch_result = calculate_metrics(np.array(train_pred), np.array(train_real))\n", 374 | " \n", 375 | " if batch_id%50==0:\n", 376 | " print(f\"batch number {batch_id}, train col-wise accuracy is : {train_batch_result['Column-wise Accuracy']}\")\n", 377 | " \n", 378 | "\n", 379 | " # save prediction result for calculation of accuracy per epoch\n", 380 | " train_epoch_pred.append(train_pred)\n", 381 | "\n", 382 | " \n", 383 | " loss.backward()\n", 384 | " torch.nn.utils.clip_grad_norm_(model.parameters(), model_config[\"max_grad_norm\"])\n", 385 | " optimizer.step()\n", 386 | " scheduler.step() # Update learning rate schedule\n", 387 | "\n", 388 | " train_losses.append(loss.item())\n", 389 | "\n", 390 | " train_epoch_pred=np.concatenate(train_epoch_pred)\n", 391 | " train_epoch_target=train_dataloader.dataset.labels\n", 392 | " train_epoch_result=calculate_metrics(target=train_epoch_target, pred=train_epoch_pred)\n", 393 | " \n", 394 | " print(f\"=====Training Report: mean loss is {sum(train_loss_record)/len(train_loss_record)}=====\")\n", 395 | " print(train_epoch_result)\n", 396 | " \n", 397 | " print(\"=====train done!=====\")\n", 398 | "\n", 399 | " # if e % log_interval == 0:\n", 400 | " # print(\"epoch {} batch id {} loss {} train acc {}\".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))\n", 401 | "\n", 402 | " # print(\"epoch {} train acc {}\".format(e+1, train_acc / (batch_id+1)))\n", 403 | " test_epoch_pred=[]\n", 404 | " test_loss_record=[]\n", 405 | "\n", 406 | " model.eval()\n", 407 | " with torch.no_grad():\n", 408 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n", 409 | " \n", 410 | " token_ids = token_ids.long().to(device)\n", 411 | " segment_ids = segment_ids.long().to(device)\n",
412 | " valid_length = valid_length\n", 413 | " \n", 414 | " # test_label = test_label.long().to(device)\n", 415 | " test_label = test_label.float().to(device)\n", 416 | "\n", 417 | " test_out = model(token_ids, valid_length, segment_ids)\n", 418 | "\n", 419 | " test_loss = loss_fn(test_out, test_label)\n", 420 | "\n", 421 | " test_loss_record.append(test_loss)\n", 422 | " \n", 423 | " valid_losses.append(test_loss.item())\n", 424 | "\n", 425 | " test_pred=test_out.detach().cpu().numpy()\n", 426 | " test_real=test_label.detach().cpu().numpy()\n", 427 | "\n", 428 | " test_batch_result = calculate_metrics(np.array(test_pred), np.array(test_real))\n", 429 | "\n", 430 | " if batch_id%50==0:\n", 431 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n", 432 | "\n", 433 | " # save prediction result for calculation of accuracy per epoch\n", 434 | " test_epoch_pred.append(test_pred)\n", 435 | "\n", 436 | " test_epoch_pred=np.concatenate(test_epoch_pred)\n", 437 | " test_epoch_target=test_dataloader.dataset.labels\n", 438 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n", 439 | "\n", 440 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n", 441 | " print(test_epoch_result)\n", 442 | "\n", 443 | " train_loss = np.average(train_losses)\n", 444 | " valid_loss = np.average(valid_losses)\n", 445 | " avg_train_losses.append(train_loss)\n", 446 | " avg_valid_losses.append(valid_loss)\n", 447 | "\n", 448 | " # clear lists to track next epoch\n", 449 | " train_losses = []\n", 450 | " valid_losses = []\n", 451 | "\n", 452 | " # early_stopping needs the validation loss to check if it has decreased,\n", 453 | " # and if it has, it will make a checkpoint of the current model\n", 454 | " early_stopping(valid_loss, model)\n", 455 | "\n", 456 | " if early_stopping.early_stop:\n", 457 | " print(\"Early stopping\")\n", 458 | " break\n", 459
| "\n", 460 | " # load the last checkpoint with the best model\n", 461 | " model.load_state_dict(torch.load(path))\n", 462 | "\n", 463 | " return model, avg_train_losses, avg_valid_losses\n", 464 | " " 465 | ] 466 | }, 467 | { 468 | "cell_type": "code", 469 | "execution_count": null, 470 | "metadata": { 471 | "tags": [ 472 | "outputPrepend" 473 | ] 474 | }, 475 | "outputs": [], 476 | "source": [ 477 | "# early stopping patience; how long to wait after last time validation loss improved.\n", 478 | "patience = 10\n", 479 | "model, train_loss, valid_loss = train_model(model, \n", 480 | " model_config[\"batch_size\"],\n", 481 | " patience, \n", 482 | " model_config[\"num_epochs\"], \n", 483 | " path=weight_path)\n" 484 | ] 485 | }, 486 | { 487 | "cell_type": "code", 488 | "execution_count": null, 489 | "metadata": {}, 490 | "outputs": [], 491 | "source": [] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": {}, 497 | "outputs": [], 498 | "source": [] 499 | }, 500 | { 501 | "cell_type": "markdown", 502 | "metadata": {}, 503 | "source": [ 504 | "# test performance" 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "execution_count": 54, 510 | "metadata": {}, 511 | "outputs": [], 512 | "source": [ 513 | "weight_path=\"../weights/weight_01.pt\"" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": 55, 519 | "metadata": {}, 520 | "outputs": [ 521 | { 522 | "data": { 523 | "text/plain": [ 524 | "" 525 | ] 526 | }, 527 | "execution_count": 55, 528 | "metadata": {}, 529 | "output_type": "execute_result" 530 | } 531 | ], 532 | "source": [ 533 | "model.load_state_dict(torch.load(weight_path))" 534 | ] 535 | }, 536 | { 537 | "cell_type": "code", 538 | "execution_count": 56, 539 | "metadata": {}, 540 | "outputs": [ 541 | { 542 | "name": "stdout", 543 | "output_type": "stream", 544 | "text": [ 545 | "batch number 0, test col-wise accuracy is : 0.9294117647058825\n", 546 | "batch number 50, test col-wise 
accuracy is : 0.9411764705882353\n", 547 | "batch number 100, test col-wise accuracy is : 0.9058823529411765\n", 548 | "=====Testing Report: mean loss is 0.20872297883033752=====\n", 549 | "{'Accuracy': 0.22437137330754353, 'Column-wise Accuracy': 0.9210376607122539, 'micro/precision': 0.7973273942093542, 'micro/recall': 0.372528616024974, 'micro/f1': 0.5078014184397163, 'macro/precision': 0.6019785504362263, 'macro/recall': 0.28350515905377016, 'macro/f1': 0.34598051554393844, 'samples/precision': 0.563023855577047, 'samples/recall': 0.3934235976789168, 'samples/f1': 0.4393478861563968}\n" 550 | ] 551 | } 552 | ], 553 | "source": [ 554 | "test_epoch_pred=[] \n", 555 | "test_loss_record=[] \n", 556 | "valid_losses=[]\n", 557 | "\n", 558 | "model.eval() \n", 559 | "with torch.no_grad(): \n", 560 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n", 561 | "\n", 562 | " token_ids = token_ids.long().to(device)\n", 563 | " segment_ids = segment_ids.long().to(device)\n", 564 | " valid_length = valid_length\n", 565 | " \n", 566 | " # test_label = test_label.long().to(device)\n", 567 | " test_label = test_label.float().to(device)\n", 568 | "\n", 569 | " test_out = model(token_ids, valid_length, segment_ids)\n", 570 | "\n", 571 | " test_loss = loss_fn(test_out, test_label)\n", 572 | "\n", 573 | " test_loss_record.append(test_loss)\n", 574 | " \n", 575 | " valid_losses.append(test_loss.item())\n", 576 | "\n", 577 | " test_pred=test_out.detach().cpu().numpy()\n", 578 | " test_real=test_label.detach().cpu().numpy()\n", 579 | "\n", 580 | " test_batch_result = calculate_metrics(np.array(test_pred), np.array(test_real))\n", 581 | "\n", 582 | " if batch_id%50==0:\n", 583 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n", 584 | "\n", 585 | " # save prediction result for calculation of accuracy per epoch\n", 586 | " test_epoch_pred.append(test_pred)\n", 587 | "\n", 588 
| " # if batch_id%10==0:\n", 589 | " # print(test_batch_result[\"Accuracy\"])\n", 590 | " test_epoch_pred=np.concatenate(test_epoch_pred) \n", 591 | " test_epoch_target=test_dataloader.dataset.labels \n", 592 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n", 593 | "\n", 594 | " # print(test_epoch_pred)\n", 595 | " # print(test_epoch_target)\n", 596 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n", 597 | " print(test_epoch_result)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "code", 602 | "execution_count": null, 603 | "metadata": {}, 604 | "outputs": [], 605 | "source": [] 606 | }, 607 | { 608 | "cell_type": "code", 609 | "execution_count": 58, 610 | "metadata": {}, 611 | "outputs": [], 612 | "source": [ 613 | "import metrics_for_multilabel as metrics" 614 | ] 615 | }, 616 | { 617 | "cell_type": "code", 618 | "execution_count": 59, 619 | "metadata": {}, 620 | "outputs": [ 621 | { 622 | "data": { 623 | "text/plain": [ 624 | "0.840774483390301" 625 | ] 626 | }, 627 | "execution_count": 59, 628 | "metadata": {}, 629 | "output_type": "execute_result" 630 | } 631 | ], 632 | "source": [ 633 | "metrics.mean_ndcg_score(test_epoch_target,test_epoch_pred, k=17)" 634 | ] 635 | }, 636 | { 637 | "cell_type": "code", 638 | "execution_count": null, 639 | "metadata": {}, 640 | "outputs": [], 641 | "source": [] 642 | }, 643 | { 644 | "cell_type": "code", 645 | "execution_count": 60, 646 | "metadata": {}, 647 | "outputs": [ 648 | { 649 | "name": "stdout", 650 | "output_type": "stream", 651 | "text": [ 652 | "accuracy: 0.6367182462927145\n" 653 | ] 654 | } 655 | ], 656 | "source": [ 657 | "acc_cnt=0\n", 658 | "for n in range(test_epoch_pred.shape[0]):\n", 659 | " tar_cnt=np.count_nonzero(test_epoch_target[n])\n", 660 | " pred_=test_epoch_pred[n].argsort()[-tar_cnt:]\n", 661 | " tar_=test_epoch_target[n].argsort()[-tar_cnt:]\n", 662 | " acc_cnt+=len(set(pred_)&set(tar_))/len(pred_)\n", 663 
| "print(f\"accuracy: {acc_cnt/test_epoch_pred.shape[0]}\")" 664 | ] 665 | }, 666 | { 667 | "cell_type": "code", 668 | "execution_count": null, 669 | "metadata": {}, 670 | "outputs": [], 671 | "source": [ 672 | "calculate_metrics(target=test_epoch_target, pred=test_epoch_pred, threshold=-1)" 673 | ] 674 | }, 675 | { 676 | "cell_type": "code", 677 | "execution_count": 62, 678 | "metadata": {}, 679 | "outputs": [], 680 | "source": [ 681 | "label_cases_sorted_target=data.label_cols" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": 63, 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = max_len, pad=True, pair=False)\n", 691 | "\n", 692 | "def get_prediction_from_txt(input_text, threshold=0.0):\n", 693 | " sentences = transform([input_text])\n", 694 | " get_pred=model(torch.tensor(sentences[0]).long().unsqueeze(0).to(device),torch.tensor(sentences[1]).unsqueeze(0),torch.tensor(sentences[2]).to(device))\n", 695 | " pred=np.array(get_pred.to(\"cpu\").detach().numpy()[0] > threshold, dtype=float)\n", 696 | " pred=np.nonzero(pred)[0].tolist()\n", 697 | " print(f\"분석 결과, 대화의 예상 태그는 {[label_cases_sorted_target[i] for i in pred]} 입니다.\")\n", 698 | " true=np.nonzero(input_text_label)[0].tolist()\n", 699 | " print(f\"실제 태그는 {[label_cases_sorted_target[i] for i in true]} 입니다.\")\n", 700 | "\n" 701 | ] 702 | }, 703 | { 704 | "cell_type": "code", 705 | "execution_count": 64, 706 | "metadata": {}, 707 | "outputs": [], 708 | "source": [ 709 | "input_text_num=17\n", 710 | "input_text=voc_dataset.iloc[input_text_num,0]\n", 711 | "# input_text=test.iloc[input_text_num,0]\n", 712 | "input_text_label=voc_dataset.iloc[input_text_num,1:].tolist()" 713 | ] 714 | }, 715 | { 716 | "cell_type": "code", 717 | "execution_count": 65, 718 | "metadata": {}, 719 | "outputs": [ 720 | { 721 | "name": "stdout", 722 | "output_type": "stream", 723 | "text": [ 724 | "분석 결과, 대화의 예상 태그는 ['대기예약', 
'무상신규예약'] 입니다.\n", 725 | "실제 태그는 ['무상신규예약'] 입니다.\n" 726 | ] 727 | } 728 | ], 729 | "source": [ 730 | "get_prediction_from_txt(input_text, -1)" 731 | ] 732 | }, 733 | { 734 | "cell_type": "code", 735 | "execution_count": null, 736 | "metadata": {}, 737 | "outputs": [], 738 | "source": [] 739 | }, 740 | { 741 | "cell_type": "markdown", 742 | "metadata": {}, 743 | "source": [ 744 | "# XAI" 745 | ] 746 | }, 747 | { 748 | "cell_type": "code", 749 | "execution_count": 69, 750 | "metadata": {}, 751 | "outputs": [], 752 | "source": [ 753 | "from captum_tools_vocvis import *" 754 | ] 755 | }, 756 | { 757 | "cell_type": "code", 758 | "execution_count": null, 759 | "metadata": {}, 760 | "outputs": [], 761 | "source": [ 762 | "from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization\n", 763 | "\n", 764 | "# model = BERTClassifier(bertmodel, dr_rate = 0.4).to(device)\n", 765 | "# model.load_state_dict(torch.load(os.getcwd()+\"/chat_voc_model.pt\", map_location=device))\n", 766 | "model.eval()" 767 | ] 768 | }, 769 | { 770 | "cell_type": "code", 771 | "execution_count": 94, 772 | "metadata": {}, 773 | "outputs": [], 774 | "source": [ 775 | "PAD_IND = tok.vocab.padding_token\n", 776 | "PAD_IND = tok.convert_tokens_to_ids(PAD_IND)\n", 777 | "token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)\n", 778 | "lig = LayerIntegratedGradients(model,model.bert.embeddings)" 779 | ] 780 | }, 781 | { 782 | "cell_type": "code", 783 | "execution_count": 95, 784 | "metadata": {}, 785 | "outputs": [], 786 | "source": [ 787 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = 64, pad=True, pair=False)\n", 788 | "\n", 789 | "voc_label_dict_inverse={ele:label_cols.index(ele) for ele in label_cols}\n", 790 | "\n", 791 | "voc_label_dict={label_cols.index(ele):ele for ele in label_cols}" 792 | ] 793 | }, 794 | { 795 | "cell_type": "code", 796 | "execution_count": null, 797 | "metadata": {}, 798 | "outputs": [], 799 | "source": [] 800 | }, 801 | { 
802 | "cell_type": "code", 803 | "execution_count": 96, 804 | "metadata": {}, 805 | "outputs": [], 806 | "source": [ 807 | "def forward_with_sigmoid_for_bert(input,valid_length,segment_ids):\n", 808 | " return torch.sigmoid(model(input,valid_length,segment_ids))\n" 809 | ] 810 | }, 811 | { 812 | "cell_type": "code", 813 | "execution_count": 97, 814 | "metadata": {}, 815 | "outputs": [], 816 | "source": [ 817 | "def forward_for_bert(input,valid_length,segment_ids):\n", 818 | " return torch.nn.functional.softmax(model(input,valid_length,segment_ids),dim=1)" 819 | ] 820 | }, 821 | { 822 | "cell_type": "code", 823 | "execution_count": null, 824 | "metadata": {}, 825 | "outputs": [], 826 | "source": [] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": 109, 831 | "metadata": {}, 832 | "outputs": [], 833 | "source": [ 834 | "# accumalate couple samples in this array for visualization purposes\n", 835 | "vis_data_records_ig = []\n", 836 | "\n", 837 | "def interpret_sentence(model, sentence, min_len = 64, label = 0, n_steps=10):\n", 838 | " # text = [token for token in tok.sentencepiece(sentence)]\n", 839 | " # if len(text) < min_len:\n", 840 | " # text += ['pad'] * (min_len - len(text))\n", 841 | " # indexed = tok.convert_tokens_to_ids(text)\n", 842 | " # print(text)\n", 843 | " \n", 844 | " # 토크나이징, 시퀀스 생성\n", 845 | " seq_tokens=transform([sentence])\n", 846 | " indexed=torch.tensor(seq_tokens[0]).long()#.to(device)\n", 847 | " valid_length=torch.tensor(seq_tokens[1]).long().unsqueeze(0)\n", 848 | " segment_ids=torch.tensor(seq_tokens[2]).long().unsqueeze(0).to(device)\n", 849 | " sentence=[token for token in tok.sentencepiece(sentence)]\n", 850 | " \n", 851 | "\n", 852 | " with torch.no_grad():\n", 853 | " model.zero_grad()\n", 854 | "\n", 855 | " input_indices = torch.tensor(indexed, device=device)\n", 856 | " input_indices = input_indices.unsqueeze(0)\n", 857 | " \n", 858 | " seq_length = min_len\n", 859 | "\n", 860 | " # predict\n", 861 | " pred = 
forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids).detach().cpu().numpy().argmax().item()\n", 862 | " print(forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids))\n", 863 | " pred_ind = round(pred)\n", 864 | " \n", 865 | " # generate reference indices for each sample\n", 866 | " reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)\n", 867 | "\n", 868 | " # compute attributions and approximation delta using layer integrated gradients\n", 869 | " attributions_ig, delta = lig.attribute(input_indices, reference_indices,\\\n", 870 | " n_steps=n_steps, return_convergence_delta=True,target=label,\\\n", 871 | " additional_forward_args=(valid_length,segment_ids))\n", 872 | "\n", 873 | " print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))\n", 874 | "\n", 875 | " add_attributions_to_visualizer(attributions_ig, sentence, pred, pred_ind, label, delta, vis_data_records_ig)" 876 | ] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": 110, 881 | "metadata": {}, 882 | "outputs": [], 883 | "source": [ 884 | "def add_attributions_to_visualizer(attributions, input_text, pred, pred_ind, label, delta, vis_data_records):\n", 885 | " attributions = attributions.sum(dim=2).squeeze(0)\n", 886 | " attributions = attributions / torch.norm(attributions)\n", 887 | " attributions = attributions.cpu().detach().numpy()\n", 888 | "\n", 889 | " # storing couple samples in an array for visualization purposes\n", 890 | " vis_data_records.append(visualization.VisualizationDataRecord(\n", 891 | " attributions,\n", 892 | " pred,\n", 893 | " voc_label_dict[pred_ind], #Label.vocab.itos[pred_ind],\n", 894 | " voc_label_dict[label], # Label.vocab.itos[label],\n", 895 | " 100, # Label.vocab.itos[1],\n", 896 | " attributions.sum(), \n", 897 | " input_text,\n", 898 | " delta))" 899 | ] 900 | }, 901 | { 902 | "cell_type": "code", 903 | "execution_count": 126, 904 | "metadata": {}, 
905 | "outputs": [ 906 | { 907 | "data": { 908 | "text/html": [ 909 | "
Legend: Negative Neutral Positive
True LabelPredicted LabelAttribution LabelAttribution ScoreWord Importance
" 910 | ], 911 | "text/plain": [ 912 | "" 913 | ] 914 | }, 915 | "metadata": {}, 916 | "output_type": "display_data" 917 | } 918 | ], 919 | "source": [ 920 | "sentence=voc_dataset.iloc[22].text\n", 921 | "\n", 922 | "visualize_text(vis_data_records_ig)" 923 | ] 924 | } 925 | ], 926 | "metadata": { 927 | "accelerator": "GPU", 928 | "colab": { 929 | "collapsed_sections": [], 930 | "machine_shape": "hm", 931 | "name": "KoBERT_PoC_1211.ipynb", 932 | "provenance": [] 933 | }, 934 | "kernelspec": { 935 | "display_name": "Python 3", 936 | "language": "python", 937 | "name": "python3" 938 | }, 939 | "language_info": { 940 | "codemirror_mode": { 941 | "name": "ipython", 942 | "version": 3 943 | }, 944 | "file_extension": ".py", 945 | "mimetype": "text/x-python", 946 | "name": "python", 947 | "nbconvert_exporter": "python", 948 | "pygments_lexer": "ipython3", 949 | "version": "3.8.8" 950 | } 951 | }, 952 | "nbformat": 4, 953 | "nbformat_minor": 4 954 | } 955 | --------------------------------------------------------------------------------