├── data
├── .txt
└── .gitignore
├── img
├── captum.png
├── demo_view.png
├── entangled.png
├── seperated.png
├── captum_example.png
├── demo_example.png
└── confusion_matrix.png
├── voc_classifier
├── __pycache__
│ ├── config.cpython-38.pyc
│ └── preprocess.cpython-38.pyc
├── demo.py
├── config.py
├── .ipynb_checkpoints
│ ├── config-checkpoint.py
│ ├── preprocess-checkpoint.py
│ └── kobert_multilabel_text_classifier-checkpoint.ipynb
├── captum_tools_vocvis.py
├── inference.py
├── bert_model.py
├── preprocess.py
├── metrics_for_multilabel.py
└── kobert_multilabel_text_classifier.ipynb
├── requirements.txt
├── README.md
└── LICENSE
/data/.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/img/captum.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/captum.png
--------------------------------------------------------------------------------
/img/demo_view.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/demo_view.png
--------------------------------------------------------------------------------
/img/entangled.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/entangled.png
--------------------------------------------------------------------------------
/img/seperated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/seperated.png
--------------------------------------------------------------------------------
/img/captum_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/captum_example.png
--------------------------------------------------------------------------------
/img/demo_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/demo_example.png
--------------------------------------------------------------------------------
/img/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/img/confusion_matrix.png
--------------------------------------------------------------------------------
/voc_classifier/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/voc_classifier/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/voc_classifier/__pycache__/preprocess.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myeonghak/kobert-multi-label-VOC-classifier/HEAD/voc_classifier/__pycache__/preprocess.cpython-38.pyc
--------------------------------------------------------------------------------
/voc_classifier/demo.py:
--------------------------------------------------------------------------------
# Streamlit demo: pick a stored conversation by row number, show its raw text
# and print the tags the classifier predicts next to the true tags.
import pickle

import streamlit as st

from inference import load_model

model = load_model()

# get_model() builds the classifier, loads the trained weights and returns it.
bert = model.get_model()
bert.eval()


def _load_pickle(path):
    """Unpickle the object stored at ``path`` and close the file afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; only point this at
    # files produced by this project.
    with open(path, "rb") as handle:
        return pickle.load(handle)


raw = _load_pickle("../data/raw_text.pkl")
data = _load_pickle("../data/preprocessed_text.pkl")


def process_for_inference(num_of_text):
    """Return the display text, model input text and labels for one row.

    Args:
        num_of_text: Positional row index into ``raw`` and ``data``.

    Returns:
        Tuple ``(raw_to_show, text_input, input_text_label)``; the label list
        holds every column of the ``data`` row after the first one.

    Raises:
        IndexError: If ``num_of_text`` is outside the stored rows.
    """
    raw_to_show = raw.text.iloc[num_of_text]
    text_input = data.text.iloc[num_of_text]
    input_text_label = data.iloc[num_of_text, 1:].tolist()
    return raw_to_show, text_input, input_text_label


number_for_demo = st.text_input('태그를 생성할 텍스트의 번호를 입력해 주세요.')
if number_for_demo:
    try:
        num_of_text = int(number_for_demo)
        Raw, Input, Label = process_for_inference(num_of_text)
    except ValueError:
        # Non-numeric input used to crash the app with a traceback.
        st.error("올바른 텍스트 번호(정수)를 입력해 주세요.")
    except IndexError:
        st.error("해당 번호의 텍스트가 없습니다.")
    else:
        # Insert a line break roughly every 62 characters so st.text wraps.
        for i in range(62, len(Raw), 62):
            Raw = Raw[:i] + "\n" + Raw[i:]
        st.text(Raw)
        st.text("\n")
        st.text("==================예측 결과==================")
        st.text("\n")

        a, b = model.get_prediction_from_txt(Input, Label)
        st.text(a)
        st.text(b)
--------------------------------------------------------------------------------
/voc_classifier/config.py:
--------------------------------------------------------------------------------
1 | # config file for DeepVOC model
2 |
3 | import os
4 |
5 |
6 | DATA_PATH="../data/voc_data.xlsx"
7 |
8 |
9 |
def expand_pandas(max_rows=100, max_cols=500, width=None, max_info_cols=None):
    """Apply pandas display options; arguments that are falsy are left alone.

    Args:
        max_rows: Value for ``display.max_rows``.
        max_cols: Value for ``display.max_columns``.
        width: Value for ``display.width``.
        max_info_cols: Value for ``max_info_columns``.

    Floats are always formatted with three decimals, and ``done`` is printed
    once the options have been applied.
    """
    import pandas as pd

    requested = (
        ("display.max_rows", max_rows),
        ("display.max_columns", max_cols),
        ("display.width", width),
        ("max_info_columns", max_info_cols),
    )
    for option_name, option_value in requested:
        if option_value:
            pd.set_option(option_name, option_value)

    # Limit the number of decimals shown for floats.
    pd.set_option("display.float_format", lambda x: "%.3f" % x)
    print("done")
22 |
23 |
24 |
# Hyper-parameters shared by the training code and inference.py
# (inference.py reads "max_len" and "dr_rate").
model_config = dict(
    max_len=512,
    batch_size=5,
    warmup_ratio=0.1,
    num_epochs=200,
    max_grad_norm=1,
    learning_rate=5e-6,
    dr_rate=0.45,
)

# Label column names; inference.py maps predicted class indices back to
# these names, so the order matters.
label_cols = [
    "국내선",
    "스케줄/기종변경",
    "항공권규정",
    "사전좌석배정",
    "환불",
    "홈페이지",
    "유상변경/취소",
    "부가서비스",
    "일정표/영수증",
    "무상변경/취소",
    "무상환불",
    "대기예약",
    "유상환불",
    "운임",
    "무상신규예약",
    "무응답",
    "재발행",
]
--------------------------------------------------------------------------------
/voc_classifier/.ipynb_checkpoints/config-checkpoint.py:
--------------------------------------------------------------------------------
1 | # config file for DeepVOC model
2 |
3 | import os
4 |
5 |
6 | DATA_PATH="../data/voc_data.xlsx"
7 |
8 |
9 |
def expand_pandas(max_rows=100, max_cols=500, width=None, max_info_cols=None):
    """Apply pandas display options; arguments that are falsy are left alone.

    Args:
        max_rows: Value for ``display.max_rows``.
        max_cols: Value for ``display.max_columns``.
        width: Value for ``display.width``.
        max_info_cols: Value for ``max_info_columns``.

    Floats are always formatted with three decimals, and ``done`` is printed
    once the options have been applied.
    """
    import pandas as pd

    requested = (
        ("display.max_rows", max_rows),
        ("display.max_columns", max_cols),
        ("display.width", width),
        ("max_info_columns", max_info_cols),
    )
    for option_name, option_value in requested:
        if option_value:
            pd.set_option(option_name, option_value)

    # Limit the number of decimals shown for floats.
    pd.set_option("display.float_format", lambda x: "%.3f" % x)
    print("done")
22 |
23 |
24 |
# Hyper-parameters for the classifier (sequence length, batch size,
# schedule, optimisation and dropout settings).
model_config = dict(
    max_len=512,
    batch_size=5,
    warmup_ratio=0.1,
    num_epochs=200,
    max_grad_norm=1,
    learning_rate=5e-6,
    dr_rate=0.45,
)

# Label column names, in class-index order.
label_cols = [
    "국내선",
    "스케줄/기종변경",
    "항공권규정",
    "사전좌석배정",
    "환불",
    "홈페이지",
    "유상변경/취소",
    "부가서비스",
    "일정표/영수증",
    "무상변경/취소",
    "무상환불",
    "대기예약",
    "유상환불",
    "운임",
    "무상신규예약",
    "무응답",
    "재발행",
]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | altair==4.1.0
2 | anyio==3.3.0
3 | argcomplete==1.12.3
4 | argon2-cffi==20.1.0
5 | astor==0.8.1
6 | async-generator==1.10
7 | attrs==21.2.0
8 | Babel==2.9.1
9 | backcall==0.2.0
10 | backports.zoneinfo==0.2.1
11 | base58==2.1.0
12 | bleach==4.0.0
13 | blinker==1.4
14 | Bottleneck==1.3.2
15 | cachetools==4.2.2
16 | captum==0.4.0
17 | certifi==2021.5.30
18 | cffi==1.14.6
19 | chardet==3.0.4
20 | charset-normalizer==2.0.4
21 | click==7.1.2
22 | colorama==0.4.4
23 | cycler==0.10.0
24 | Cython==0.29.24
25 | debugpy==1.4.1
26 | decorator==5.0.9
27 | defusedxml==0.7.1
28 | entrypoints==0.3
29 | et-xmlfile==1.1.0
30 | filelock==3.0.12
31 | gitdb==4.0.7
32 | GitPython==3.1.18
33 | gluonnlp==0.10.0
34 | graphviz==0.8.4
35 | idna==2.6
36 | importlib-metadata==4.6.3
37 | ipykernel==6.1.0
38 | ipython==7.26.0
39 | ipython-genutils==0.2.0
40 | ipywidgets==7.6.3
41 | jedi==0.18.0
42 | Jinja2==3.0.1
43 | joblib==1.0.1
44 | json5==0.9.6
45 | jsonschema==3.2.0
46 | jupyter-client==6.1.12
47 | jupyter-core==4.7.1
48 | jupyter-server==1.10.2
49 | jupyterlab==3.1.6
50 | jupyterlab-pygments==0.1.2
51 | jupyterlab-server==2.7.0
52 | jupyterlab-widgets==1.0.0
53 | kiwisolver==1.3.1
54 | MarkupSafe==2.0.1
55 | matplotlib==3.4.3
56 | matplotlib-inline==0.1.2
57 | mistune==0.8.4
58 | mkl-fft==1.3.0
59 | mkl-random @ file:///C:/ci/mkl_random_1626186163140/work
60 | mkl-service==2.4.0
61 | mxnet==1.7.0.post2
62 | nbclassic==0.3.1
63 | nbclient==0.5.3
64 | nbconvert==6.1.0
65 | nbformat==5.1.3
66 | nest-asyncio==1.5.1
67 | notebook==6.4.3
68 | numexpr @ file:///C:/ci/numexpr_1618856761305/work
69 | numpy==1.17.3
70 | olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work
71 | openpyxl==3.0.7
72 | packaging==21.0
73 | pandas @ file:///C:/ci/pandas_1627570311072/work
74 | pandocfilters==1.4.3
75 | parso==0.8.2
76 | pickleshare==0.7.5
77 | Pillow @ file:///C:/ci/pillow_1625663286921/work
78 | prometheus-client==0.11.0
79 | prompt-toolkit==3.0.19
80 | protobuf==3.17.3
81 | pyarrow==5.0.0
82 | pycparser==2.20
83 | pydeck==0.6.2
84 | Pygments==2.9.0
85 | pyparsing==2.4.7
86 | pyrsistent==0.18.0
87 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
88 | pytz @ file:///tmp/build/80754af9/pytz_1612215392582/work
89 | pywin32==301
90 | pywinpty==1.1.3
91 | pyzmq==22.2.1
92 | regex==2021.8.3
93 | requests==2.18.4
94 | requests-unixsocket==0.2.0
95 | sacremoses==0.0.45
96 | scikit-learn==0.24.2
97 | scipy==1.7.1
98 | Send2Trash==1.8.0
99 | sentencepiece==0.1.96
100 | six @ file:///tmp/build/80754af9/six_1623709665295/work
101 | smmap==4.0.0
102 | sniffio==1.2.0
103 | streamlit==0.86.0
104 | terminado==0.11.0
105 | testpath==0.5.0
106 | threadpoolctl==2.2.0
107 | tokenizers==0.8.0rc4
108 | toml==0.10.2
109 | toolz==0.11.1
110 | torch==1.9.0
111 | torchaudio==0.9.0
112 | torchtext==0.10.0
113 | torchvision==0.10.0
114 | tornado==6.1
115 | tqdm==4.62.0
116 | traitlets==5.0.5
117 | transformers==3.0.1
118 | typing-extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1622748266870/work
119 | tzdata==2021.1
120 | tzlocal==3.0
121 | urllib3==1.22
122 | validators==0.18.2
123 | watchdog==2.1.3
124 | wcwidth==0.2.5
125 | webencodings==0.5.1
126 | websocket-client==1.2.0
127 | widgetsnbextension==3.5.1
128 | wincertstore==0.2
129 | zipp==3.5.0
130 |
--------------------------------------------------------------------------------
/voc_classifier/captum_tools_vocvis.py:
--------------------------------------------------------------------------------
1 | from IPython.core.display import HTML, display
2 | HAS_IPYTHON = True
3 |
def format_special_tokens(token):
    """Rewrite an angle-bracketed token such as ``<pad>`` as ``#pad``.

    Tokens that do not both start with ``<`` and end with ``>`` are returned
    unchanged.
    """
    is_special = token.startswith("<") and token.endswith(">")
    return ("#" + token.strip("<>")) if is_special else token
8 |
9 |
def format_tooltip(item, text):
    """Wrap ``item`` in a tooltip ``<div>`` whose hover text is ``text``.

    The HTML inside the original string literal had been stripped, leaving an
    unterminated string; the markup is restored following Captum's
    ``captum.attr.visualization`` helper of the same name.

    Args:
        item: Content that is always visible.
        text: Content placed in the ``tooltiptext`` span.

    Returns:
        The HTML fragment as a string.
    """
    return (
        '<div class="tooltip">{item}'
        '<span class="tooltiptext">{text}</span>'
        "</div>"
    ).format(item=item, text=text)
def format_classname(classname):
    """Return ``classname`` rendered as a bold HTML table cell.

    The cell markup had been stripped from the original literal (only
    ``'{} | '`` remained); it is restored following Captum's
    ``captum.attr.visualization.format_classname``.
    """
    return '<td><text style="padding-right:2em"><b>{}</b></text></td>'.format(
        classname
    )
18 | def _get_color(attr):
19 | # clip values to prevent CSS errors (Values should be from [-1,1])
20 | attr = max(-1, min(1, attr))
21 | if attr > 0:
22 | hue = 120
23 | sat = 75
24 | lig = 100 - int(50 * attr)
25 | else:
26 | hue = 0
27 | sat = 75
28 | lig = 100 - int(-40 * attr)
29 | return "hsl({}, {}%, {}%)".format(hue, sat, lig)
# Number of tokens kept when more tokens than importances are supplied.
min_len = 64


def format_word_importances(words, importances):
    """Render tokens as one HTML table cell, each shaded by its attribution.

    The ``<td>``/``<mark>`` markup had been stripped from the original string
    literals (leaving unterminated strings); it is restored following
    Captum's ``captum.attr.visualization.format_word_importances``.

    Args:
        words: Sequence of token strings.
        importances: Sequence of attribution scores aligned with ``words``.

    Returns:
        An HTML ``<td>`` string. An empty cell is returned when there are no
        importances or when rendering fails, so the caller can always join
        the result into a row (previously a failure returned ``None``).
    """
    try:
        if importances is None or len(importances) == 0:
            return "<td></td>"

        if len(words) > len(importances):
            # NOTE(review): truncates to the fixed min_len rather than to
            # len(importances); zip() below already stops at the shorter
            # sequence, so confirm this cap is intended.
            words = words[:min_len]

        tags = ["<td>"]
        for word, importance in zip(words, importances[: len(words)]):
            word = format_special_tokens(word)
            color = _get_color(importance)
            unwrapped_tag = (
                '<mark style="background-color: {color}; opacity:1.0; '
                'line-height:1.75"><font color="black"> {word}'
                "</font></mark>"
            ).format(color=color, word=word)
            tags.append(unwrapped_tag)
        tags.append("</td>")
        return "".join(tags)
    except Exception as e:
        # Best-effort rendering: report the problem and fall back to an
        # empty cell instead of implicitly returning None.
        print("skip it", e)
        return "<td></td>"
54 |
55 |
def visualize_text(datarecords, legend=True):
    """Display attribution records as an HTML table in IPython.

    The table/legend markup had been stripped from the original string
    literals (several were left unterminated); it is restored following
    Captum's ``captum.attr.visualization.visualize_text``.

    Args:
        datarecords: Iterable of records exposing ``true_class``,
            ``pred_class``, ``pred_prob``, ``attr_class``, ``attr_score``,
            ``raw_input`` and ``word_attributions``.
        legend: When true, a Negative/Neutral/Positive colour legend is
            rendered above the table rows.

    Records that fail to render are reported with ``print`` and skipped.
    """
    assert HAS_IPYTHON, (
        "IPython must be available to visualize text. "
        "Please run 'pip install ipython'."
    )
    dom = ["<table width: 100%>"]
    rows = [
        "<tr><th>True Label</th>"
        "<th>Predicted Label</th>"
        "<th>Attribution Label</th>"
        "<th>Attribution Score</th>"
        "<th>Word Importance</th></tr>"
    ]
    for cnt, datarecord in enumerate(datarecords, start=1):
        try:
            rows.append(
                "".join(
                    [
                        "<tr>",
                        format_classname(datarecord.true_class),
                        format_classname(
                            "{0} ({1:.2f})".format(
                                datarecord.pred_class, datarecord.pred_prob
                            )
                        ),
                        format_classname(datarecord.attr_class),
                        format_classname("{0:.2f}".format(datarecord.attr_score)),
                        format_word_importances(
                            datarecord.raw_input, datarecord.word_attributions
                        ),
                        "</tr>",
                    ]
                )
            )
        except Exception as e:
            # Keep rendering the remaining records; cnt is the 1-based
            # position of the record that failed.
            print(f"Error in {cnt}", e)

    if legend:
        dom.append(
            '<div style="border-top: 1px solid; margin-top: 5px; '
            'padding-top: 5px; display: inline-block">'
        )
        dom.append("<b>Legend: </b>")

        for value, label in zip([-1, 0, 1], ["Negative", "Neutral", "Positive"]):
            dom.append(
                '<span style="display: inline-block; width: 10px; height: 10px; '
                "border: 1px solid; background-color: "
                '{value}"></span> {label}  '.format(
                    value=_get_color(value), label=label
                )
            )
        dom.append("</div>")

    dom.append("".join(rows))
    dom.append("</table>")
    display(HTML("".join(dom)))
115 |
--------------------------------------------------------------------------------
/voc_classifier/inference.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../")
3 |
4 | # Input Data 가공 파트
5 |
6 | import torchtext
7 | import pandas as pd
8 | import numpy as np
9 |
10 | import os
11 | import re
12 |
13 | import config
14 | from config import expand_pandas
15 | from preprocess import preprocess
16 |
17 | DATA_PATH=config.DATA_PATH
18 |
19 |
20 | import warnings
21 | warnings.filterwarnings("ignore")
22 |
23 |
24 | num_class=17
25 | ver_num=2
26 |
27 | except_labels=["변경/취소","예약기타"]
28 |
29 | version_info="{:02d}".format(ver_num)
30 | # weight_path=f"../weights/Deep_Voice_{version_info}.pt"
31 |
32 |
33 | weight_path="../weights/weight_01.pt"
34 |
35 |
36 |
37 |
38 | # KoBERT 모델
39 | import torch
40 | from torch import nn
41 | import torch.nn.functional as F
42 | import torch.optim as optim
43 | from torch.utils.data import Dataset, DataLoader
44 | import gluonnlp as nlp
45 | from tqdm import tqdm, tqdm_notebook
46 |
47 | from KoBERT.kobert.utils import get_tokenizer
48 | from KoBERT.kobert.pytorch_kobert import get_pytorch_kobert_model
49 |
50 | from transformers import AdamW
51 |
class BERTClassifier(nn.Module):
    """BERT encoder followed by one linear layer that scores each class."""

    def __init__(self, bert, hidden_size = 768, num_classes = 8, dr_rate = None, params = None):
        """
        Args:
            bert: Encoder module; ``forward`` expects it to return a pair
                whose second element is the pooled output.
            hidden_size: Width of the pooled output fed to the linear layer.
            num_classes: Number of output scores.
            dr_rate: Dropout probability applied to the pooled output; no
                dropout is applied when this is ``None`` or 0.
            params: Unused; kept for signature compatibility.
        """
        super(BERTClassifier, self).__init__()
        self.bert = bert
        self.dr_rate = dr_rate

        # Fully connected layer mapping the pooled output to class scores.
        self.classifier = nn.Linear(hidden_size, num_classes)

        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def generate_attention_mask(self, token_ids, valid_length):
        """Return a float mask shaped like ``token_ids``.

        Row ``i`` holds 1 in its first ``valid_length[i]`` positions and 0
        elsewhere.
        """
        attention_mask = torch.zeros_like(token_ids)

        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return the class scores for a batch.

        Args:
            token_ids: Token id tensor passed to the encoder as ``input_ids``.
            valid_length: Per-row count of non-padding tokens.
            segment_ids: Segment ids; cast to int64 before being passed as
                ``token_type_ids``.
        """
        attention_mask = self.generate_attention_mask(token_ids, valid_length)

        # .long() only casts to int64; .to(device) is what moves a tensor.
        _, pooler = self.bert(
            input_ids = token_ids,
            token_type_ids = segment_ids.long(),
            attention_mask = attention_mask.float().to(token_ids.device),
        )

        # Previously ``out`` was only assigned when dropout was enabled, so a
        # model built without dr_rate raised UnboundLocalError here.
        out = self.dropout(pooler) if self.dr_rate else pooler

        return self.classifier(out)
103 |
104 |
105 |
106 |
class load_model:
    """Builds the KoBERT tokenizer pipeline and classifier used for inference."""

    def __init__(self):
        """Select the device and prepare the tokenizer and sentence transform."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bertmodel, self.vocab = get_pytorch_kobert_model()
        self.tokenizer = get_tokenizer()
        self.tok = nlp.data.BERTSPTokenizer(self.tokenizer, self.vocab, lower = False)
        self.transform = nlp.data.BERTSentenceTransform(
            self.tok,
            max_seq_length = config.model_config["max_len"],
            pad = True,
            pair = False,
        )
        # Set by get_model(); lets get_prediction_from_txt() load it on demand.
        self.model = None

    def get_model(self):
        """Build the classifier, load weights from ``weight_path`` and return it.

        The model is moved to ``self.device`` and left in eval mode.
        """
        self.model = BERTClassifier(
            self.bertmodel,
            num_classes = num_class,
            dr_rate = config.model_config["dr_rate"],
        ).to(self.device)

        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        self.model.load_state_dict(torch.load(weight_path, map_location=self.device))
        self.model.eval()

        return self.model

    def get_prediction_from_txt(self, input_text, input_text_label):
        """Predict tags for one text and describe them next to the true tags.

        Args:
            input_text: Text passed through the sentence transform.
            input_text_label: Sequence of label indicators; positions with a
                non-zero value are treated as the true labels.

        Returns:
            Tuple ``(result, true_label)`` of two formatted strings: the
            predicted tag names and the true tag names. The number of
            predicted tags is ``round(1.5 * number of true labels)``, capped
            at the number of classes the model outputs.
        """
        if self.model is None:
            # Previously this raised AttributeError unless get_model() had
            # been called first.
            self.get_model()

        device = self.device

        sentences = self.transform([input_text])
        true_values = np.nonzero(input_text_label)[0].tolist()

        token_ids = torch.tensor(sentences[0]).long().unsqueeze(0).to(device)
        valid_length = torch.tensor(sentences[1]).unsqueeze(0)
        segment_ids = torch.tensor(sentences[2]).to(device)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            scores = self.model(token_ids, valid_length, segment_ids)

        # topk raises when k exceeds the number of classes, so cap it.
        num_of_true = min(round(len(true_values) * 1.5), scores.size(-1))
        pred = scores.topk(k=num_of_true)[1][0].tolist()

        result=f"분석 결과, 대화의 예상 태그는 {[config.label_cols[i] for i in pred]} 입니다."
        true_label=f"실제 태그는 {[config.label_cols[i] for i in true_values]} 입니다."
        return result, true_label
146 |
147 |
148 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | KoBERT multi-label VOC classifier
2 | ==============================================
3 |
4 |
5 |
6 |
7 | **[TL-DR]**
8 | KoBERT를 사용한 PyTorch text multi-label classification 모델입니다. SKT에서 공개한 한국어 pretrain embedding model인 [KoBERT](https://github.com/SKTBrain/KoBERT)를 사용했습니다.
9 |
10 | ----
11 |
12 |
13 |
14 | Contents
15 | --------
16 |
17 | 1. [Intro](#intro)
18 | 2. [Structure](#structure)
19 | 3. [Embedding Visualization](#embedding-visualization)
20 | 4. [XAI using PyTorch Captum](#xai-using-pytorch-captum)
21 | 5. [Streamlit Demo](#streamlit-demo)
22 |
23 |
24 |
25 |
26 | ## Intro
27 |
28 | 큰 기업에서는 매일 수 천건, 통화로 생성되는 STT 결과 데이터를 포함하면 수 만건에 달하는 VoC(Voice of Customers)가 발생합니다. 이 때 고객의 불만 사항을 실시간으로 분류/관리하여 트렌드를 추적해주는 시스템이 있다면, 운영 부서에서 미처 예상치 못한 서비스 장애를 진단하여 조기에 대응할 수도 있고, 나아가 고객 불만을 보상하는 프로모션을 제공한다면 고객 이탈을 방지함과 동시에 세일즈 KPI에 직접적인 효과를 얻을 수 있겠죠. 이러한 맥락에서 고안된 VOC 자동 분류 모델입니다.
29 |
30 | 사용된 데이터는 모 항공사의 VOC 데이터이고, 약 2,000건의 raw 데이터를 가지고 있습니다. 전체 레이블의 수는 70여개이며, 모델링에 사용된 수는 17개 입니다. 극히 소수의 데이터에서 실제 현업에 적용 가능한 수준의 성능을 검증하는 것이 해당 모델의 구현 목적이었습니다. 라벨 데이터는 현업 전문가에 의해 태깅되었고, 상담사와 고객의 대화를 보고 관련되었다고 판단되는 태그를 달아 주었습니다. 데이터와 학습된 weight file은 보안상의 이유로 공개할 수 없음을 양해바랍니다.
31 |
32 |
33 |
34 |
35 |
36 |
37 | ## Structure
38 |
39 |
40 | ```bash
41 |
42 | ├── data # 모델 학습 및 데모에 사용되는 데이터셋
43 | ├── KoBERT # 과거 버전의 KoBERT 레포지터리를 클론한 폴더
44 | ├── voc_classifier
45 | │ ├── bert_model.py # dataloader, bert 모델 및 학습 관련 util
46 | │ ├── captum_tools_vocvis.py # PyTorch XAI를 위한 Captum 관련 util
47 | │ ├── config.py
48 | │ ├── demo.py # 텍스트 입력을 넣으면 모델 예측 결과를 반환하는 streamlit web app
49 | │ ├── inference.py # 추론 모듈
50 | │ ├── kobert_multilabel_text_classifier.ipynb
51 | │ ├── metrics_for_multilabel.py # multilabel 모델 평가를 위한 metrics
52 | │ └── preprocess.py # 전처리 모듈
53 | └── weights # 학습 모델 가중치
54 |
55 | ```
56 |
57 | Python 3.7.11 버전에서 구현되었습니다. conda 가상환경을 권장합니다. 사용 패키지는 requirements.txt를 참조해주세요.
58 |
59 |
60 |
61 | ### Model Performance
62 |
63 |
64 |
65 | 해당 모델은 Multi-Label classification 모델로, 전체 2,000여 건의 샘플 데이터를 train 85%, test 15%로 분할하여 테스트했습니다.
66 |
67 |
68 | | methods | NDCG@17| Micro f1 score| Macro f1 score |
69 | |-------------------------|:---------:|:---:|:---:|
70 | | KoBERT | **0.841** | **0.615** | **0.534**|
71 |
72 |
73 |
74 |
75 |
76 |
77 | ### Embedding visualization
78 |
79 |
80 |
81 | 프로젝트의 초기에는 multi-class task로 접근하여 모델링을 진행했습니다. 그런데, 500여개의 데이터 셋으로 나왔던 1차 성능에 비해 샘플이 더 추가된 데이터 셋으로 만든 2차 모델의 성능이 더 떨어지는 현상이 발생했고 (8개 클래스 77% -> 73%), 원인 파악을 위해 오분류 샘플을 조사했습니다.
82 |
83 | | | 카테고리 명|
84 | |-------------------------|:---------:|
85 | | 예측 정확도 70% 이상 | 무상 변경/취소, 유상 변경/취소, 기내 서비스 등 |
86 | | 예측 정확도 40% 이하 | **예약 기타**, **변경/취소** |
87 |
88 | 위와 같이, 기타 클래스의 특징을 모호하게 포함하고 있는 클래스의 성능이 매우 낮은 것을 확인할 수 있었습니다. 사람이 직접 의미적으로 판단해도 모호한 경우가 많았습니다. 이 클래스에 포함된 샘플들은 모델의 최적화 과정에서 모호한 시그널을 제공함으로써 파라미터 최적화에 악영향을 미칠 것이라고 직관적으로 생각했고, 이와 같은 내용이 버트 분류 모델의 예측에 사용되는 마지막 CLS 토큰의 representation을 low dimension에 mapping 했을 때 확인 가능할 것이라고 가정했습니다.
89 |
90 | 아래는 실제 CLS 토큰의 임베딩에 t-SNE를 적용한 결과입니다. 모호한 라벨을 가진 샘플들에 의해 임베딩 스페이스가 다소 entangled된 형태를 보이는 것을 알 수 있습니다.
91 |
92 |
93 | ![entangled embedding space](img/entangled.png)
94 |
95 |
96 |
97 | 그렇다면 이 라벨들을 제거해 준다면, 버트 representation 이후의 레이어가 결정 경계를 손쉽게 그을 수 있도록 임베딩이 학습되지 않을까요? 그러한 질문에 답한 것이 다음과 같은 이미지였습니다.
98 |
99 |
100 |
101 | ![separated embedding space](img/seperated.png)
102 |
103 |
104 | 예쁘게 잘 정리 됐네요. 이와 같은 결과가 말해주듯이, 데이터셋을 수작업으로 레이블링할 때 모델이 혼동하지 않는 기준을 세우는 것이 중요하다는 결론을 내릴 수 있었습니다. 아래는 수정 후 모델의 confusion matrix입니다. 85:15로 stratified sampling을 해 주었습니다.
105 |
106 |
107 |
108 | ![confusion matrix](img/confusion_matrix.png)
109 |
110 |
111 |
112 |
113 |
114 | ### XAI using pytorch Captum
115 |
116 |
117 | [Captum](https://captum.ai/)은 PyTorch 모델의 interpretability를 위한 라이브러리입니다. 이 중 자연어 분류 모델의 판단에 긍정적, 부정적으로 영향을 미친 토큰을 시각화해주는 [예제](https://github.com/pytorch/captum/blob/master/tutorials/IMDB_TorchText_Interpret.ipynb)가 있어 본 문제에 적용해 보았습니다. 아래는 시각화 결과입니다.
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 | "기내 서비스"라는 레이블을 예측하는 데 positive한 영향을 준 토큰은 녹색으로, negative한 영향을 준 (즉 라벨 예측에 혼동을 준) 토큰은 붉은색으로 시각화해 줍니다. 우리의 경우에서는 토큰 시각화가 직관에 다소 부합하지 않는 결과를 보이기도 했으나, 이는 소수 샘플로 인한 특정 토큰의 영향에 의한 것일 수도, 한글 토큰의 인코딩 문제일 수도 있습니다.
126 |
127 |
128 |
129 |
130 |
131 | ### Streamlit Demo
132 |
133 | [streamlit](https://streamlit.io/)은 웹/앱 개발에 익숙지 않은 데이터 사이언티스트들이 손쉽게 웹앱 데모를 구현할 수 있도록 도와주는 high-level data app 라이브러리입니다. 입출력을 현업에게 빠르게 보여주기 위해 다음과 같은 데모를 만들었습니다. 불과 몇 분의 투자로 모델의 I/O를 보여줄 수 있는 매우 간편한 기능을 제공합니다.
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 | `voc_classifier` 디렉터리에서 다음 커맨드로 간단하게 실행할 수 있습니다. (demo.py는 `../data`, `../weights` 상대 경로로 데이터와 가중치를 읽습니다.)
145 |
146 | ```
147 | streamlit run demo.py
148 | ```
149 |
--------------------------------------------------------------------------------
/voc_classifier/bert_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("../")
3 |
4 | # KoBERT 모델
5 |
6 | import config
7 |
8 | import pandas as pd
9 | import numpy as np
10 | from sklearn.preprocessing import OneHotEncoder
11 |
12 | import torch
13 | from torch import nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | from torch.utils.data import Dataset, DataLoader
17 | import gluonnlp as nlp
18 | from tqdm import tqdm, tqdm_notebook
19 |
20 |
21 |
22 | from KoBERT.kobert.utils import get_tokenizer
23 | from KoBERT.kobert.pytorch_kobert import get_pytorch_kobert_model
24 |
25 | from transformers import AdamW
26 | # from transformers.optimization import WarmupLinearSchedule
27 |
28 | from transformers import get_linear_schedule_with_warmup
29 |
30 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31 |
32 | bertmodel, vocab = get_pytorch_kobert_model()
33 |
34 | # 토크나이저 메서드를 tokenizer에 호출
35 | # 코퍼스를 토큰으로 만드는 과정을 수행, 이 때 토크나이저는 kobert 패키지에 있는 get_tokenizer()를 사용하고,
36 | # 토큰화를 위해 필요한 단어 사전은 kobert의 vocab을 사용함.
37 | # uncased로 투입해야 하므로 lower = False
38 |
39 | tokenizer = get_tokenizer()
40 | tok = nlp.data.BERTSPTokenizer(tokenizer, vocab, lower = False)
41 | print(f'device using: {device}')
42 |
43 |
44 | model_config=config.model_config
45 |
46 |
class Data_for_BERT(Dataset):
    """Dataset pairing BERT-transformed sentences with multi-label targets."""

    def __init__(self, dataset, max_len, pad, pair, label_cols):
        """
        Args:
            dataset: Table exposing a ``text`` column and the ``label_cols``
                columns (indexed as ``dataset[label_cols]``).
            max_len: Passed to the sentence transform as ``max_seq_length``.
            pad: Passed to the sentence transform as ``pad``.
            pair: Passed to the sentence transform as ``pair``.
            label_cols: Names of the label columns.
        """
        sentence_transform = nlp.data.BERTSentenceTransform(
            tok, max_seq_length=max_len, pad=pad, pair=pair
        )

        encoded = []
        for raw_text in dataset.text:
            encoded.append(sentence_transform([raw_text]))
        self.sentences = encoded

        # One row of label values per sample.
        self.labels = dataset[label_cols].values

    def __getitem__(self, i):
        """Return sample ``i``: the transformed sentence tuple plus its labels."""
        # presumably the transform yields (token ids, valid length, segment
        # ids) — confirm against gluonnlp's BERTSentenceTransform.
        features = self.sentences[i]
        target = (self.labels[i],)
        return features + target

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)
73 |
74 |
class BERTClassifier(nn.Module):
    """Module-level KoBERT encoder followed by one linear scoring layer."""

    def __init__(self, hidden_size = 768, num_classes = 8, dr_rate = None, params = None):
        """
        Args:
            hidden_size: Width of the pooled output fed to the linear layer.
            num_classes: Number of output scores.
            dr_rate: Dropout probability applied to the pooled output; no
                dropout is applied when this is ``None`` or 0.
            params: Unused; kept for signature compatibility.
        """
        super(BERTClassifier, self).__init__()
        # Uses the encoder loaded once at module import time.
        self.bert = bertmodel
        self.dr_rate = dr_rate

        # Fully connected layer mapping the pooled output to class scores.
        self.classifier = nn.Linear(hidden_size, num_classes)

        if dr_rate:
            self.dropout = nn.Dropout(p=dr_rate)

    def generate_attention_mask(self, token_ids, valid_length):
        """Return a float mask shaped like ``token_ids``.

        Row ``i`` holds 1 in its first ``valid_length[i]`` positions and 0
        elsewhere.
        """
        attention_mask = torch.zeros_like(token_ids)

        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()

    def forward(self, token_ids, valid_length, segment_ids):
        """Return the class scores for a batch.

        Args:
            token_ids: Token id tensor passed to the encoder as ``input_ids``.
            valid_length: Per-row count of non-padding tokens.
            segment_ids: Segment ids; cast to int64 before being passed as
                ``token_type_ids``.
        """
        attention_mask = self.generate_attention_mask(token_ids, valid_length)

        # .long() only casts to int64; .to(device) is what moves a tensor.
        _, pooler = self.bert(
            input_ids = token_ids,
            token_type_ids = segment_ids.long(),
            attention_mask = attention_mask.float().to(token_ids.device),
        )

        # Previously ``out`` was only assigned when dropout was enabled, so a
        # model built without dr_rate raised UnboundLocalError here.
        out = self.dropout(pooler) if self.dr_rate else pooler

        return self.classifier(out)
122 |
123 |
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf rather than np.Inf: the capitalised alias was removed in
        # NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation loss; checkpoint on improvement.

        Sets ``early_stop`` once ``patience`` consecutive calls have failed
        to improve on the best score by at least ``delta``.
        """
        # Higher is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
173 |
--------------------------------------------------------------------------------
/voc_classifier/preprocess.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | import os
5 | import re
6 |
7 | import config
8 |
9 | DATA_PATH=config.DATA_PATH
10 |
11 |
12 |
def remove_linebreak(string):
    """Return ``string`` with each carriage return and newline turned into a space."""
    without_cr = string.replace('\r', ' ')
    without_breaks = without_cr.replace('\n', ' ')
    return without_breaks
15 |
def split_table_string(string):
    """Split ``string`` at the last occurrence of the ``"PNR :"`` marker.

    Returns:
        Tuple ``(text before the marker, text from the marker to the end)``.
        When the marker is absent the whole string is returned unchanged with
        an empty second element. Previously ``rfind`` returning -1 caused the
        last character to be split off instead.
    """
    marker_at = string.rfind("PNR :")
    if marker_at == -1:
        return (string, "")
    return (string[:marker_at], string[marker_at:])
20 |
def remove_multispace(x):
    """Return ``str(x)`` stripped at both ends with runs of spaces collapsed to one."""
    stripped = str(x).strip()
    # Only runs of the space character are collapsed; tabs and other
    # whitespace inside the text are kept.
    return re.sub(r" {2,}", " ", stripped)
25 |
26 |
27 |
28 |
29 |
class preprocess:
    """Turn the raw VOC chat workbook into a multi-label table.

    ``make_table`` groups the chat rows of the "종합본" sheet into one row per
    case (stored in ``self.table``); ``label_process`` then one-hot encodes the
    most frequent labels (stored in ``self.data`` / ``self.label_cols``).
    """

    # Name of the worksheet that is read from the workbook.
    _SHEET = "종합본"
    # Placeholder written into empty label cells.
    _MISSING = "결측"
    _CAT_COLUMNS = [f"cat{i}" for i in range(1, 9)]
    _RAW_COLUMNS = ["case_no", "sender_type", "message"] + _CAT_COLUMNS
    _LABEL_COLUMNS = [f"label_cat{i}" for i in range(1, 9)]

    def __init__(self):
        # sheet_name=None loads every sheet into a {sheet name: DataFrame} dict.
        self.voc_total = pd.read_excel(DATA_PATH, sheet_name=None, engine='openpyxl')

    def make_table(self):
        """Build ``self.table``: one row per case with its text and 8 labels.

        A row whose ``case_no`` is not null starts a new case and carries that
        case's eight category labels; the following rows (null ``case_no``)
        add further messages to the same case. All messages of a case are
        joined with single spaces into ``text``. Cases without a first label
        are dropped, empty labels become "결측", and labels are stripped and
        have their inner spaces removed.

        Returns
        -------
        bool
            Always ``True``.
        """
        # Use the workbook loaded in __init__ instead of reading the file a
        # second time.
        voc_data = self.voc_total[self._SHEET]
        # Two trailing unnamed columns are discarded when present.
        voc_data = voc_data.drop(["Unnamed: 11", "Unnamed: 12"], axis=1, errors="ignore")
        voc_data.columns = list(self._RAW_COLUMNS)

        # One (messages, labels) pair per case, in sheet order.
        cases = []
        for row in voc_data.itertuples(index=False):
            if pd.notna(row.case_no):
                # Start of a new conversation: remember its labels.
                cases.append(([], [getattr(row, col) for col in self._CAT_COLUMNS]))
            if not cases:
                # Message rows before the first case cannot be attributed.
                continue
            cases[-1][0].append(str(row.message))
        # Because every case is stored as soon as it starts, the last case of
        # the sheet is kept as well.

        table = pd.DataFrame(
            [[" ".join(messages)] + labels for messages, labels in cases],
            columns=["text"] + self._LABEL_COLUMNS,
        )

        # A case without a first label is treated as unlabeled and removed.
        table = table[table["label_cat1"].notna()].copy()

        for col in self._LABEL_COLUMNS:
            filled = table[col].astype(object).fillna(self._MISSING).astype(str)
            table[col] = filled.str.strip().str.replace(" ", "", regex=False)

        self.table = table
        return True

    def label_process(self, num_labels=17, except_labels=None):
        """One-hot encode the most frequent labels of ``self.table``.

        Parameters
        ----------
        num_labels : int
            Maximum number of label columns to create.
        except_labels : iterable of str, optional
            Labels that must not become columns.

        Returns
        -------
        bool
            Always ``True``. The result is stored in ``self.data`` (``text``
            plus one 0/1 column per kept label) and the kept label names in
            ``self.label_cols``. ``self.table`` is left unmodified.
        """
        table = self.table
        label_frame = table[self._LABEL_COLUMNS]

        # Number of cases in which each label occurs at least once, looking at
        # all eight label columns. dict.fromkeys de-duplicates within a case
        # while keeping a deterministic first-seen order for ties.
        counts = {}
        for labels in label_frame.itertuples(index=False, name=None):
            for label in dict.fromkeys(labels):
                counts[label] = counts.get(label, 0) + 1

        ranked = sorted(counts, key=counts.get, reverse=True)

        if except_labels:
            excluded = set(except_labels)
            ranked = [label for label in ranked if label not in excluded]

        # NOTE(review): the top-ranked entry is skipped, presumably because it
        # is expected to be the "결측" filler -- confirm against the data.
        targets = ranked[1:num_labels + 1]
        self.label_cols = targets

        result = table.drop(self._LABEL_COLUMNS, axis=1)
        for label in targets:
            result[label] = label_frame.eq(label).any(axis=1).astype(int)

        # Cases that carry none of the kept labels are removed for now.
        result = result[result[targets].any(axis=1)]

        self.data = result

        return True
196 |
197 |
--------------------------------------------------------------------------------
/voc_classifier/.ipynb_checkpoints/preprocess-checkpoint.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | import os
5 | import re
6 |
7 | import config
8 |
9 | DATA_PATH=config.DATA_PATH
10 |
11 |
12 |
def remove_linebreak(string):
    """Return ``string`` with each carriage return and line feed replaced by a space."""
    without_cr = string.replace('\r', ' ')
    return without_cr.replace('\n', ' ')
15 |
def split_table_string(string):
    """Split ``string`` at the last occurrence of the ``"PNR :"`` marker.

    Parameters
    ----------
    string : str
        Text that may end with a block starting with ``"PNR :"``.

    Returns
    -------
    (head, tail) : tuple of str
        ``head`` is everything before the last ``"PNR :"``; ``tail`` starts at
        that marker and runs to the end of the input. When the marker is
        absent, ``head`` is the whole input and ``tail`` is empty.
    """
    marker_at = string.rfind("PNR :")
    if marker_at == -1:
        # rfind reports "not found" as -1; slicing with it would silently
        # move the last character of the text into the tail.
        return (string, "")
    return (string[:marker_at], string[marker_at:])
20 |
def remove_multispace(x):
    """Stringify ``x``, trim surrounding whitespace and squeeze runs of spaces to one."""
    trimmed = str(x).strip()
    return re.sub(' +', ' ', trimmed)
25 |
26 |
27 |
28 |
29 |
class preprocess:
    """Turn the raw VOC chat workbook into a multi-label table.

    ``make_table`` groups the chat rows of the "종합본" sheet into one row per
    case (stored in ``self.table``); ``label_process`` then one-hot encodes the
    most frequent labels (stored in ``self.data`` / ``self.label_cols``).
    """

    # Name of the worksheet that is read from the workbook.
    _SHEET = "종합본"
    # Placeholder written into empty label cells.
    _MISSING = "결측"
    _CAT_COLUMNS = [f"cat{i}" for i in range(1, 9)]
    _RAW_COLUMNS = ["case_no", "sender_type", "message"] + _CAT_COLUMNS
    _LABEL_COLUMNS = [f"label_cat{i}" for i in range(1, 9)]

    def __init__(self):
        # sheet_name=None loads every sheet into a {sheet name: DataFrame} dict.
        self.voc_total = pd.read_excel(DATA_PATH, sheet_name=None, engine='openpyxl')

    def make_table(self):
        """Build ``self.table``: one row per case with its text and 8 labels.

        A row whose ``case_no`` is not null starts a new case and carries that
        case's eight category labels; the following rows (null ``case_no``)
        add further messages to the same case. All messages of a case are
        joined with single spaces into ``text``. Cases without a first label
        are dropped, empty labels become "결측", and labels are stripped and
        have their inner spaces removed.

        Returns
        -------
        bool
            Always ``True``.
        """
        # Use the workbook loaded in __init__ instead of reading the file a
        # second time.
        voc_data = self.voc_total[self._SHEET]
        # Two trailing unnamed columns are discarded when present.
        voc_data = voc_data.drop(["Unnamed: 11", "Unnamed: 12"], axis=1, errors="ignore")
        voc_data.columns = list(self._RAW_COLUMNS)

        # One (messages, labels) pair per case, in sheet order.
        cases = []
        for row in voc_data.itertuples(index=False):
            if pd.notna(row.case_no):
                # Start of a new conversation: remember its labels.
                cases.append(([], [getattr(row, col) for col in self._CAT_COLUMNS]))
            if not cases:
                # Message rows before the first case cannot be attributed.
                continue
            cases[-1][0].append(str(row.message))
        # Because every case is stored as soon as it starts, the last case of
        # the sheet is kept as well.

        table = pd.DataFrame(
            [[" ".join(messages)] + labels for messages, labels in cases],
            columns=["text"] + self._LABEL_COLUMNS,
        )

        # A case without a first label is treated as unlabeled and removed.
        table = table[table["label_cat1"].notna()].copy()

        for col in self._LABEL_COLUMNS:
            filled = table[col].astype(object).fillna(self._MISSING).astype(str)
            table[col] = filled.str.strip().str.replace(" ", "", regex=False)

        self.table = table
        return True

    def label_process(self, num_labels=17, except_labels=None):
        """One-hot encode the most frequent labels of ``self.table``.

        Parameters
        ----------
        num_labels : int
            Maximum number of label columns to create.
        except_labels : iterable of str, optional
            Labels that must not become columns.

        Returns
        -------
        bool
            Always ``True``. The result is stored in ``self.data`` (``text``
            plus one 0/1 column per kept label) and the kept label names in
            ``self.label_cols``. ``self.table`` is left unmodified.
        """
        table = self.table
        label_frame = table[self._LABEL_COLUMNS]

        # Number of cases in which each label occurs at least once, looking at
        # all eight label columns. dict.fromkeys de-duplicates within a case
        # while keeping a deterministic first-seen order for ties.
        counts = {}
        for labels in label_frame.itertuples(index=False, name=None):
            for label in dict.fromkeys(labels):
                counts[label] = counts.get(label, 0) + 1

        ranked = sorted(counts, key=counts.get, reverse=True)

        if except_labels:
            excluded = set(except_labels)
            ranked = [label for label in ranked if label not in excluded]

        # NOTE(review): the top-ranked entry is skipped, presumably because it
        # is expected to be the "결측" filler -- confirm against the data.
        targets = ranked[1:num_labels + 1]
        self.label_cols = targets

        result = table.drop(self._LABEL_COLUMNS, axis=1)
        for label in targets:
            result[label] = label_frame.eq(label).any(axis=1).astype(int)

        # Cases that carry none of the kept labels are removed for now.
        result = result[result[targets].any(axis=1)]

        self.data = result

        return True
196 |
197 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/voc_classifier/metrics_for_multilabel.py:
--------------------------------------------------------------------------------
1 | # https://github.com/iliaschalkidis/lmtc-eurlex57k/blob/master/metrics.py
2 |
3 | from sklearn.metrics import accuracy_score
4 | from sklearn.metrics import precision_score
5 | from sklearn.metrics import recall_score
6 | from sklearn.metrics import f1_score
7 |
8 |
9 | import numpy as np
10 |
11 |
def mean_precision_k(y_true, y_score, k=10):
    """Mean precision at rank k over samples
    Parameters
    ----------
    y_true : iterable of arrays
        One ground-truth label array per sample.
    y_score : iterable of arrays
        One predicted-score array per sample, paired with ``y_true``.
    k : int
        Rank.
    Returns
    -------
    mean precision @k : float
        Average of ``ranking_precision_score`` over the samples that contain
        at least one label equal to 1; other samples are skipped.
    """
    per_sample = [
        ranking_precision_score(truth, scores, k=k)
        for truth, scores in zip(y_true, y_score)
        if np.sum(truth == 1)
    ]
    return np.mean(per_sample)
33 |
34 |
def mean_recall_k(y_true, y_score, k=10):
    """Mean recall at rank k over samples
    Parameters
    ----------
    y_true : iterable of arrays
        One ground-truth label array per sample.
    y_score : iterable of arrays
        One predicted-score array per sample, paired with ``y_true``.
    k : int
        Rank.
    Returns
    -------
    mean recall @k : float
        Average of ``ranking_recall_score`` over the samples that contain at
        least one label equal to 1; other samples are skipped.
    """
    per_sample = [
        ranking_recall_score(truth, scores, k=k)
        for truth, scores in zip(y_true, y_score)
        if np.sum(truth == 1)
    ]
    return np.mean(per_sample)
56 |
57 |
def mean_ndcg_score(y_true, y_score, k=10, gains="exponential"):
    """Mean normalized discounted cumulative gain (NDCG) at rank k over samples
    Parameters
    ----------
    y_true : iterable of arrays
        One ground-truth label array per sample.
    y_score : iterable of arrays
        One predicted-score array per sample, paired with ``y_true``.
    k : int
        Rank.
    gains : str
        Whether gains should be "exponential" (default) or "linear".
    Returns
    -------
    Mean NDCG @k : float
        Average of ``ndcg_score`` over the samples that contain at least one
        label equal to 1; other samples are skipped.
    """
    per_sample = [
        ndcg_score(truth, scores, k=k, gains=gains)
        for truth, scores in zip(y_true, y_score)
        if np.sum(truth == 1)
    ]
    return np.mean(per_sample)
81 |
82 |
def mean_rprecision_k(y_true, y_score, k=10):
    """Mean of ``ranking_rprecision_score`` at rank k over samples
    Parameters
    ----------
    y_true : iterable of arrays
        One ground-truth label array per sample.
    y_score : iterable of arrays
        One predicted-score array per sample, paired with ``y_true``.
    k : int
        Rank.
    Returns
    -------
    mean R-precision @k : float
        Average of ``ranking_rprecision_score`` over the samples that contain
        at least one label equal to 1; other samples are skipped.
    """
    per_sample = [
        ranking_rprecision_score(truth, scores, k=k)
        for truth, scores in zip(y_true, y_score)
        if np.sum(truth == 1)
    ]
    return np.mean(per_sample)
104 |
105 |
def ranking_recall_score(y_true, y_score, k=10):
    """Recall at rank k
    Reference:
    https://ils.unc.edu/courses/2013_spring/inls509_001/lectures/10-EvaluationMetrics.pdf
    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    y_score : array-like, shape = [n_samples]
        Predicted scores.
    k : int
        Rank.
    Returns
    -------
    recall @k : float
        Share of all positive items that appear among the k highest scores.
    Raises
    ------
    ValueError
        If ``y_true`` holds a single relevance level or more than two.
    """
    # Lists are accepted too; element-wise comparison needs an array.
    y_true = np.asarray(y_true)
    unique_y = np.unique(y_true)

    if len(unique_y) == 1:
        # Raise instead of handing the exception object back as a "score".
        raise ValueError("The score cannot be approximated.")
    elif len(unique_y) > 2:
        raise ValueError("Only supported for two relevance levels.")

    # np.unique sorts, so the larger of the two levels is the positive one.
    pos_label = unique_y[1]
    n_pos = np.sum(y_true == pos_label)

    order = np.argsort(y_score)[::-1]
    top_k = np.take(y_true, order[:k])
    n_relevant = np.sum(top_k == pos_label)

    return float(n_relevant) / n_pos
136 |
137 |
def ranking_precision_score(y_true, y_score, k=10):
    """Precision at rank k
    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    y_score : array-like, shape = [n_samples]
        Predicted scores.
    k : int
        Rank.
    Returns
    -------
    precision @k : float
        Number of positive items among the k highest scores, divided by k.
    Raises
    ------
    ValueError
        If ``y_true`` holds a single relevance level or more than two.
    """
    # Lists are accepted too; element-wise comparison needs an array.
    y_true = np.asarray(y_true)
    unique_y = np.unique(y_true)

    if len(unique_y) == 1:
        # Raise instead of handing the exception object back as a "score".
        raise ValueError("The score cannot be approximated.")
    elif len(unique_y) > 2:
        raise ValueError("Only supported for two relevance levels.")

    # np.unique sorts, so the larger of the two levels is the positive one.
    pos_label = unique_y[1]

    order = np.argsort(y_score)[::-1]
    top_k = np.take(y_true, order[:k])
    n_relevant = np.sum(top_k == pos_label)

    return float(n_relevant) / k
166 |
167 |
def ranking_rprecision_score(y_true, y_score, k=10):
    """R-precision at rank k
    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    y_score : array-like, shape = [n_samples]
        Predicted scores.
    k : int
        Rank.
    Returns
    -------
    R-precision @k : float
        Number of positive items among the k highest scores, divided by
        ``min(k, number of positives)``.
    Raises
    ------
    ValueError
        If ``y_true`` holds a single relevance level or more than two.
    """
    # Lists are accepted too; element-wise comparison needs an array.
    y_true = np.asarray(y_true)
    unique_y = np.unique(y_true)

    if len(unique_y) == 1:
        # Raise instead of handing the exception object back as a "score".
        raise ValueError("The score cannot be approximated.")
    elif len(unique_y) > 2:
        raise ValueError("Only supported for two relevance levels.")

    # np.unique sorts, so the larger of the two levels is the positive one.
    pos_label = unique_y[1]
    n_pos = np.sum(y_true == pos_label)

    order = np.argsort(y_score)[::-1]
    top_k = np.take(y_true, order[:k])
    n_relevant = np.sum(top_k == pos_label)

    # Divide by min(n_pos, k) such that the best achievable score is always 1.0.
    return float(n_relevant) / min(k, n_pos)
198 |
199 |
def average_precision_score(y_true, y_score, k=10):
    """Average precision at rank k
    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    y_score : array-like, shape = [n_samples]
        Predicted scores.
    k : int
        Rank.
    Returns
    -------
    average precision @k : float
        Sum of the precision values at each positive item among the top
        ``min(k, number of positives)`` scores, divided by the number of
        positives.
    Raises
    ------
    ValueError
        If ``y_true`` holds a single relevance level or more than two.
    """
    y_true = np.asarray(y_true)
    unique_y = np.unique(y_true)

    if len(unique_y) == 1:
        # Raise instead of handing the exception object back as a "score".
        raise ValueError("The score cannot be approximated.")
    elif len(unique_y) > 2:
        raise ValueError("Only supported for two relevance levels.")

    # np.unique sorts, so the larger of the two levels is the positive one.
    # With exactly two levels present there is always at least one positive.
    pos_label = unique_y[1]
    n_pos = np.sum(y_true == pos_label)

    order = np.argsort(y_score)[::-1][:min(n_pos, k)]
    ranked = y_true[order]

    score = 0.0
    hits = 0
    for rank, label in enumerate(ranked, start=1):
        if label == pos_label:
            # Precision up to this document: relevant documents seen so far
            # divided by the current rank (kept as a running count).
            hits += 1
            score += hits / float(rank)

    return score / n_pos
243 |
244 |
def dcg_score(y_true, y_score, k=10, gains="exponential"):
    """Discounted cumulative gain (DCG) at rank k
    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    y_score : array-like, shape = [n_samples]
        Predicted scores.
    k : int
        Rank.
    gains : str
        Whether gains should be "exponential" (default) or "linear".
    Returns
    -------
    DCG @k : float
    Raises
    ------
    ValueError
        If ``gains`` is neither "exponential" nor "linear".
    """
    ranked = np.argsort(y_score)[::-1]
    relevance = np.take(y_true, ranked[:k])

    if gains == "exponential":
        gain_values = 2 ** relevance - 1
    elif gains == "linear":
        gain_values = relevance
    else:
        raise ValueError("Invalid gains option.")

    # Positions are 0-based here while the top rank is 1, hence log2(pos + 2).
    positions = np.arange(len(relevance))
    return np.sum(gain_values / np.log2(positions + 2))
274 |
275 |
def ndcg_score(y_true, y_score, k=10, gains="exponential"):
    """Normalized discounted cumulative gain (NDCG) at rank k
    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    y_score : array-like, shape = [n_samples]
        Predicted scores.
    k : int
        Rank.
    gains : str
        Whether gains should be "exponential" (default) or "linear".
    Returns
    -------
    NDCG @k : float
        DCG of the predicted ordering divided by the DCG obtained when the
        true labels themselves are used as scores.
    """
    # Ranking by the true labels gives the normalizer.
    ideal_dcg = dcg_score(y_true, y_true, k, gains)
    achieved_dcg = dcg_score(y_true, y_score, k, gains)
    return achieved_dcg / ideal_dcg
295 |
296 |
297 | # Alternative API.
298 |
def dcg_from_ranking(y_true, ranking):
    """Discounted cumulative gain (DCG) of an explicit ranking
    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    ranking : array-like, shape = [k]
        Document indices, i.e.,
        ranking[0] is the index of top-ranked document,
        ranking[1] is the index of second-ranked document,
        ...
    Returns
    -------
    DCG @k : float
        Uses exponential gains; k is the length of ``ranking``.
    """
    labels = np.asarray(y_true)
    positions = np.asarray(ranking)
    relevance = labels[positions]
    exp_gains = 2 ** relevance - 1
    # Top rank is 1 while arange starts at 0, hence the +2 inside the log.
    log_discounts = np.log2(np.arange(len(positions)) + 2)
    return np.sum(exp_gains / log_discounts)
322 |
323 |
def ndcg_from_ranking(y_true, ranking):
    """Compute the normalized DCG (NDCG) of an explicit ranking.

    The DCG of ``ranking`` is divided by the DCG of the best possible
    ranking of the same length, obtained by sorting the documents by
    descending true relevance.

    Parameters
    ----------
    y_true : array-like, shape = [n_samples]
        Ground truth (true relevance labels).
    ranking : array-like, shape = [k]
        Document indices, i.e.,
        ranking[0] is the index of top-ranked document,
        ranking[1] is the index of second-ranked document,
        ...

    Returns
    -------
    NDCG @k : float
        ``0.0`` when the ideal DCG is zero (for example when no document
        is relevant), instead of the ``nan`` a 0/0 division would give.
    """
    # The cutoff is implied by how many documents were ranked.
    k = len(ranking)
    best_ranking = np.argsort(y_true)[::-1]
    best = dcg_from_ranking(y_true, best_ranking[:k])
    actual = dcg_from_ranking(y_true, ranking)

    # Guard the normalization: with an ideal DCG of zero the ratio is
    # undefined and numpy would return nan and emit a RuntimeWarning.
    if best == 0:
        return 0.0
    return actual / best
345 |
def colwise_accuracy(y_true, y_pred):
    """Return the mean of the per-column accuracies.

    Both inputs are transposed so that each row holds one column of the
    original data; ``accuracy_score`` is computed for every such row and
    the scores are averaged with equal weight.

    Parameters
    ----------
    y_true : array-like with ``.T``
        Ground-truth labels.
        NOTE(review): presumably shape (n_samples, n_labels) - confirm.
    y_pred : array-like with ``.T``
        Predicted labels, laid out like ``y_true``.

    Returns
    -------
    float
        Unweighted average of the per-column accuracy scores.
    """
    pred_by_label = y_pred.T
    true_by_label = y_true.T
    n_labels = pred_by_label.shape[0]

    per_label_scores = [
        accuracy_score(pred_by_label[idx], true_by_label[idx])
        for idx in range(n_labels)
    ]
    return sum(per_label_scores) / len(per_label_scores)
353 |
def calculate_metrics(pred, target, threshold=0.5):
    """Binarize predictions and compute a dictionary of multi-label metrics.

    ``pred`` is compared directly against ``threshold`` (strictly greater
    than); no other transformation is applied to it in this function.

    Parameters
    ----------
    pred : array-like
        Raw prediction scores.
        NOTE(review): presumably shape (n_samples, n_labels) - confirm.
    target : array-like
        Ground-truth binary labels, laid out like ``pred``.
    threshold : float
        Scores strictly above this value are turned into 1.0, the rest
        into 0.0.

    Returns
    -------
    dict
        Keys are 'Accuracy', 'Column-wise Accuracy' and
        '<average>/precision', '<average>/recall', '<average>/f1' for
        each average in 'micro', 'macro' and 'samples'.
    """
    binarized = np.array(pred > threshold, dtype=float)

    report = {
        'Accuracy': accuracy_score(y_true=target, y_pred=binarized),
        'Column-wise Accuracy': colwise_accuracy(y_true=target, y_pred=binarized),
    }

    # Same three scores for every averaging strategy, inserted in the
    # order micro -> macro -> samples.
    for average in ('micro', 'macro', 'samples'):
        report[average + '/precision'] = precision_score(
            y_true=target, y_pred=binarized, average=average)
        report[average + '/recall'] = recall_score(
            y_true=target, y_pred=binarized, average=average)
        report[average + '/f1'] = f1_score(
            y_true=target, y_pred=binarized, average=average)

    return report
--------------------------------------------------------------------------------
/voc_classifier/.ipynb_checkpoints/kobert_multilabel_text_classifier-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | " \n",
8 | "
"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "\n",
16 | "
KoBERT Multi-label text classifier
\n",
17 | " By: Myeonghak Lee
\n",
18 | ""
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | " \n",
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {
33 | "tags": []
34 | },
35 | "outputs": [],
36 | "source": [
37 | "# Input Data 가공 파트\n",
38 | "\n",
39 | "# import torchtext\n",
40 | "import pandas as pd\n",
41 | "import numpy as np\n",
42 | "\n",
43 | "import os\n",
44 | "import re\n",
45 | "\n",
46 | "import config\n",
47 | "from config import expand_pandas\n",
48 | "from preprocess import preprocess\n",
49 | "\n",
50 | "DATA_PATH=config.DATA_PATH\n",
51 | "\n",
52 | "model_config=config.model_config"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 2,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "import warnings\n",
62 | "warnings.filterwarnings(\"ignore\")"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 3,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "done\n"
75 | ]
76 | }
77 | ],
78 | "source": [
79 | "config.expand_pandas(max_rows=100, max_cols=100,width=1000,max_info_cols=500)"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "### **configs**"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 4,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "num_class=17\n",
96 | "ver_num=1\n",
97 | "\n",
98 | "except_labels=[\"변경/취소\",\"예약기타\"]\n",
99 | "\n",
100 | "version_info=\"{:02d}\".format(ver_num)\n",
101 | "weight_path=f\"../weights/weight_{version_info}.pt\""
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {},
107 | "source": [
108 | "### **preprocess**"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 12,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "data=preprocess()"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 13,
123 | "metadata": {},
124 | "outputs": [
125 | {
126 | "data": {
127 | "text/plain": [
128 | "True"
129 | ]
130 | },
131 | "execution_count": 13,
132 | "metadata": {},
133 | "output_type": "execute_result"
134 | }
135 | ],
136 | "source": [
137 | "# data_orig=data.voc_total[\"종합본\"]\n",
138 | "data.make_table()\n",
139 | "\n",
140 | "# put labels\n",
141 | "data.label_process(num_labels=num_class, except_labels=except_labels)"
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": 14,
147 | "metadata": {},
148 | "outputs": [],
149 | "source": [
150 | "orig=data.voc_total[\"종합본\"]"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": 15,
156 | "metadata": {},
157 | "outputs": [],
158 | "source": [
159 | "label_cols=data.label_cols"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 16,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "df=data.data.copy()"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 17,
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "voc_dataset=df.reset_index(drop=True)"
178 | ]
179 | },
180 | {
181 | "cell_type": "markdown",
182 | "metadata": {},
183 | "source": [
184 | "# Modeling part"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {},
191 | "outputs": [],
192 | "source": [
193 | "import torch\n",
194 | "from torch import nn\n",
195 | "\n",
196 | "from metrics_for_multilabel import calculate_metrics, colwise_accuracy\n",
197 | "\n",
198 | "from bert_model import Data_for_BERT, BERTClassifier, EarlyStopping\n",
199 | "\n",
200 | "from transformers import get_linear_schedule_with_warmup, AdamW"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "metadata": {},
207 | "outputs": [],
208 | "source": []
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 15,
213 | "metadata": {},
214 | "outputs": [],
215 | "source": [
216 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "metadata": {},
223 | "outputs": [],
224 | "source": []
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 24,
229 | "metadata": {},
230 | "outputs": [
231 | {
232 | "ename": "NameError",
233 | "evalue": "name 'Data_for_BERT' is not defined",
234 | "output_type": "error",
235 | "traceback": [
236 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
237 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
238 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mdata_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mdata_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
239 | "\u001b[0;31mNameError\u001b[0m: name 'Data_for_BERT' is not defined"
240 | ]
241 | }
242 | ],
243 | "source": [
244 | "from sklearn.model_selection import train_test_split\n",
245 | "train_input, test_input, train_target, test_target = train_test_split(voc_dataset, voc_dataset[\"국내선\"], test_size = 0.25, random_state = 42)\n",
246 | "\n",
247 | "# train=pd.concat([train_input,train_target],axis=1)\n",
248 | "# test=pd.concat([test_input,test_target],axis=1)\n",
249 | "\n",
250 | "train=train_input.copy()\n",
251 | "test=test_input.copy()\n",
252 | "\n",
253 | "train=train.reset_index(drop=True)\n",
254 | "test=test.reset_index(drop=True)\n",
255 | "\n",
256 | "data_train = Data_for_BERT(train, model_config[\"max_len\"], True, False, label_cols=label_cols)\n",
257 | "data_test = Data_for_BERT(test, model_config[\"max_len\"], True, False, label_cols=label_cols)\n",
258 | "\n",
259 | "# 파이토치 모델에 넣을 수 있도록 데이터를 처리함. \n",
260 | "# data_train을 넣어주고, 이 테이터를 batch_size에 맞게 잘라줌. num_workers는 사용할 subprocess의 개수를 의미함(병렬 프로그래밍)\n",
261 | "\n",
262 | "train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=model_config[\"batch_size\"], num_workers=0)\n",
263 | "test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=model_config[\"batch_size\"], num_workers=0)"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": null,
269 | "metadata": {},
270 | "outputs": [],
271 | "source": []
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 18,
276 | "metadata": {},
277 | "outputs": [],
278 | "source": [
279 | "# KoBERT 라이브러리에서 bertmodel을 호출함. .to() 메서드는 모델 전체를 GPU 디바이스에 옮겨 줌.\n",
280 | "model = BERTClassifier(num_classes=num_class, dr_rate = model_config[\"dr_rate\"]).to(device)\n",
281 | "\n",
282 | "# 옵티마이저와 스케쥴 준비 (linear warmup과 decay)\n",
283 | "no_decay = ['bias', 'LayerNorm.weight']\n",
284 | "\n",
285 | "# no_decay에 해당하는 파라미터명을 가진 레이어들은 decay에서 배제하기 위해 weight_decay를 0으로 셋팅, 그 외에는 0.01로 decay\n",
286 | "# weight decay란 l2 norm으로 파라미터 값을 정규화해주는 기법을 의미함\n",
287 | "optimizer_grouped_parameters = [\n",
288 | " {\"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay' : 0.01},\n",
289 | " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
290 | "]\n",
291 | "\n",
292 | "\n",
293 | "# 옵티마이저는 AdamW, 손실함수는 BCE\n",
294 | "# optimizer_grouped_parameters는 최적화할 파라미터의 그룹을 의미함\n",
295 | "optimizer = AdamW(optimizer_grouped_parameters, lr= model_config[\"learning_rate\"])\n",
296 | "# loss_fn = nn.CrossEntropyLoss()\n",
297 | "loss_fn=nn.BCEWithLogitsLoss()\n",
298 | "\n",
299 | "\n",
300 | "# t_total = train_dataloader.dataset.labels.shape[0] * num_epochs\n",
301 | "# linear warmup을 사용해 학습 초기 단계(배치 초기)의 learning rate를 조금씩 증가시켜 나가다, 어느 지점에 이르면 constant하게 유지\n",
302 | "# 초기 학습 단계에서의 변동성을 줄여줌.\n",
303 | "\n",
304 | "t_total = len(train_dataloader) * model_config[\"num_epochs\"]\n",
305 | "warmup_step = int(t_total * model_config[\"warmup_ratio\"])\n",
306 | "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)\n",
307 | "\n",
308 | "\n",
309 | "# model_save_name = 'classifier'\n",
310 | "# model_file='.pt'\n",
311 | "# path = f\"./bert_weights/{model_save_name}_{model_file}\" "
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": null,
317 | "metadata": {},
318 | "outputs": [],
319 | "source": []
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 23,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "def train_model(model, batch_size, patience, n_epochs,path):\n",
328 | " \n",
329 | " # to track the training loss as the model trains\n",
330 | " train_losses = []\n",
331 | " # to track the validation loss as the model trains\n",
332 | " valid_losses = []\n",
333 | " # to track the average training loss per epoch as the model trains\n",
334 | " avg_train_losses = []\n",
335 | " # to track the average validation loss per epoch as the model trains\n",
336 | " avg_valid_losses = [] \n",
337 | "\n",
338 | " early_stopping = EarlyStopping(patience=patience, verbose=True, path=path)\n",
339 | "\n",
340 | " for epoch in range(1, n_epochs + 1):\n",
341 | " \n",
342 | " # initialize the early_stopping object\n",
343 | " model.train()\n",
344 | " train_epoch_pred=[]\n",
345 | " train_loss_record=[]\n",
346 | "\n",
347 | " for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(train_dataloader):\n",
348 | " optimizer.zero_grad()\n",
349 | "\n",
350 | " token_ids = token_ids.long().to(device)\n",
351 | " segment_ids = segment_ids.long().to(device)\n",
352 | " valid_length= valid_length\n",
353 | " \n",
354 | " # label = label.long().to(device)\n",
355 | " label = label.float().to(device)\n",
356 | "\n",
357 | " out= model(token_ids, valid_length, segment_ids)#.squeeze(1)\n",
358 | " \n",
359 | " loss = loss_fn(out, label)\n",
360 | "\n",
361 | " train_loss_record.append(loss)\n",
362 | "\n",
363 | " train_pred=out.detach().cpu().numpy()\n",
364 | " train_real=label.detach().cpu().numpy()\n",
365 | "\n",
366 | " train_batch_result = calculate_metrics(np.array(train_pred), np.array(train_real))\n",
367 | " \n",
368 | " if batch_id%50==0:\n",
369 | " print(f\"batch number {batch_id}, train col-wise accuracy is : {train_batch_result['Column-wise Accuracy']}\")\n",
370 | " \n",
371 | "\n",
372 | " # save prediction result for calculation of accuracy per batch\n",
373 | " train_epoch_pred.append(train_pred)\n",
374 | "\n",
375 | " \n",
376 | " loss.backward()\n",
377 | " torch.nn.utils.clip_grad_norm_(model.parameters(), model_config[\"max_grad_norm\"])\n",
378 | " optimizer.step()\n",
379 | " scheduler.step() # Update learning rate schedule\n",
380 | "\n",
381 | " train_losses.append(loss.item())\n",
382 | "\n",
383 | " train_epoch_pred=np.concatenate(train_epoch_pred)\n",
384 | " train_epoch_target=train_dataloader.dataset.labels\n",
385 | " train_epoch_result=calculate_metrics(target=train_epoch_target, pred=train_epoch_pred)\n",
386 | " \n",
387 | " print(f\"=====Training Report: mean loss is {sum(train_loss_record)/len(train_loss_record)}=====\")\n",
388 | " print(train_epoch_result)\n",
389 | " \n",
390 | " print(\"=====train done!=====\")\n",
391 | "\n",
392 | " # if e % log_interval == 0:\n",
393 | " # print(\"epoch {} batch id {} loss {} train acc {}\".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))\n",
394 | "\n",
395 | " # print(\"epoch {} train acc {}\".format(e+1, train_acc / (batch_id+1)))\n",
396 | " test_epoch_pred=[]\n",
397 | " test_loss_record=[]\n",
398 | "\n",
399 | " model.eval()\n",
400 | " with torch.no_grad():\n",
401 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n",
402 | " \n",
403 | " token_ids = token_ids.long().to(device)\n",
404 | " segment_ids = segment_ids.long().to(device)\n",
405 | " valid_length = valid_length\n",
406 | " \n",
407 | " # test_label = test_label.long().to(device)\n",
408 | " test_label = test_label.float().to(device)\n",
409 | "\n",
410 | " test_out = model(token_ids, valid_length, segment_ids)\n",
411 | "\n",
412 | " test_loss = loss_fn(test_out, test_label)\n",
413 | "\n",
414 | " test_loss_record.append(test_loss)\n",
415 | " \n",
416 | " valid_losses.append(test_loss.item())\n",
417 | "\n",
418 | " test_pred=test_out.detach().cpu().numpy()\n",
419 | " test_real=test_label.detach().cpu().numpy()\n",
420 | "\n",
421 | " test_batch_result = calculate_metrics(np.array(test_pred), np.array(test_real))\n",
422 | "\n",
423 | " if batch_id%50==0:\n",
424 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n",
425 | "\n",
426 | " # save prediction result for calculation of accuracy per epoch\n",
427 | " test_epoch_pred.append(test_pred)\n",
428 | "\n",
429 | " test_epoch_pred=np.concatenate(test_epoch_pred)\n",
430 | " test_epoch_target=test_dataloader.dataset.labels\n",
431 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n",
432 | "\n",
433 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n",
434 | " print(test_epoch_result)\n",
435 | "\n",
436 | " train_loss = np.average(train_losses)\n",
437 | " valid_loss = np.average(valid_losses)\n",
438 | " avg_train_losses.append(train_loss)\n",
439 | " avg_valid_losses.append(valid_loss)\n",
440 | "\n",
441 | " # clear lists to track next epoch\n",
442 | " train_losses = []\n",
443 | " valid_losses = []\n",
444 | "\n",
445 | " # early_stopping needs the validation loss to check if it has decreased, \n",
446 | " # and if it has, it will make a checkpoint of the current model\n",
447 | " early_stopping(valid_loss, model)\n",
448 | "\n",
449 | " if early_stopping.early_stop:\n",
450 | " print(\"Early stopping\")\n",
451 | " break\n",
452 | "\n",
453 | " # load the last checkpoint with the best model\n",
454 | " model.load_state_dict(torch.load(path))\n",
455 | "\n",
456 | " return model, avg_train_losses, avg_valid_losses\n",
457 | " "
458 | ]
459 | },
460 | {
461 | "cell_type": "code",
462 | "execution_count": null,
463 | "metadata": {
464 | "tags": [
465 | "outputPrepend"
466 | ]
467 | },
468 | "outputs": [],
469 | "source": [
470 | "# early stopping patience; how long to wait after last time validation loss improved.\n",
471 | "patience = 10\n",
472 | "model, train_loss, valid_loss = train_model(model, \n",
473 | " model_config[\"batch_size\"],\n",
474 | " patience, \n",
475 | " model_config[\"num_epochs\"], \n",
476 | " path=weight_path)\n"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": null,
482 | "metadata": {},
483 | "outputs": [],
484 | "source": []
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": null,
489 | "metadata": {},
490 | "outputs": [],
491 | "source": []
492 | },
493 | {
494 | "cell_type": "markdown",
495 | "metadata": {},
496 | "source": [
497 | "# test performance"
498 | ]
499 | },
500 | {
501 | "cell_type": "code",
502 | "execution_count": 54,
503 | "metadata": {},
504 | "outputs": [],
505 | "source": [
506 | "weight_path=\"../weights/weight_01.pt\""
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": 55,
512 | "metadata": {},
513 | "outputs": [
514 | {
515 | "data": {
516 | "text/plain": [
517 | ""
518 | ]
519 | },
520 | "execution_count": 55,
521 | "metadata": {},
522 | "output_type": "execute_result"
523 | }
524 | ],
525 | "source": [
526 | "model.load_state_dict(torch.load(weight_path))"
527 | ]
528 | },
529 | {
530 | "cell_type": "code",
531 | "execution_count": 56,
532 | "metadata": {},
533 | "outputs": [
534 | {
535 | "name": "stdout",
536 | "output_type": "stream",
537 | "text": [
538 | "batch number 0, test col-wise accuracy is : 0.9294117647058825\n",
539 | "batch number 50, test col-wise accuracy is : 0.9411764705882353\n",
540 | "batch number 100, test col-wise accuracy is : 0.9058823529411765\n",
541 | "=====Testing Report: mean loss is 0.20872297883033752=====\n",
542 | "{'Accuracy': 0.22437137330754353, 'Column-wise Accuracy': 0.9210376607122539, 'micro/precision': 0.7973273942093542, 'micro/recall': 0.372528616024974, 'micro/f1': 0.5078014184397163, 'macro/precision': 0.6019785504362263, 'macro/recall': 0.28350515905377016, 'macro/f1': 0.34598051554393844, 'samples/precision': 0.563023855577047, 'samples/recall': 0.3934235976789168, 'samples/f1': 0.4393478861563968}\n"
543 | ]
544 | }
545 | ],
546 | "source": [
547 | "test_epoch_pred=[] \n",
548 | "test_loss_record=[] \n",
549 | "valid_losses=[]\n",
550 | "\n",
551 | "model.eval() \n",
552 | "with torch.no_grad(): \n",
553 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n",
554 | "\n",
555 | " token_ids = token_ids.long().to(device)\n",
556 | " segment_ids = segment_ids.long().to(device)\n",
557 | " valid_length = valid_length\n",
558 | " \n",
559 | " # test_label = test_label.long().to(device)\n",
560 | " test_label = test_label.float().to(device)\n",
561 | "\n",
562 | " test_out = model(token_ids, valid_length, segment_ids)\n",
563 | "\n",
564 | " test_loss = loss_fn(test_out, test_label)\n",
565 | "\n",
566 | " test_loss_record.append(test_loss)\n",
567 | " \n",
568 | " valid_losses.append(test_loss.item())\n",
569 | "\n",
570 | " test_pred=test_out.detach().cpu().numpy()\n",
571 | " test_real=test_label.detach().cpu().numpy()\n",
572 | "\n",
573 | " test_batch_result = calculate_metrics(np.array(test_pred), np.array(test_real))\n",
574 | "\n",
575 | " if batch_id%50==0:\n",
576 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n",
577 | "\n",
578 | " # save prediction result for calculation of accuracy per epoch\n",
579 | " test_epoch_pred.append(test_pred)\n",
580 | "\n",
581 | " # if batch_id%10==0:\n",
582 | " # print(test_batch_result[\"Accuracy\"])\n",
583 | " test_epoch_pred=np.concatenate(test_epoch_pred) \n",
584 | " test_epoch_target=test_dataloader.dataset.labels \n",
585 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n",
586 | "\n",
587 | " # print(test_epoch_pred)\n",
588 | " # print(test_epoch_target)\n",
589 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n",
590 | " print(test_epoch_result)"
591 | ]
592 | },
593 | {
594 | "cell_type": "code",
595 | "execution_count": null,
596 | "metadata": {},
597 | "outputs": [],
598 | "source": []
599 | },
600 | {
601 | "cell_type": "code",
602 | "execution_count": 58,
603 | "metadata": {},
604 | "outputs": [],
605 | "source": [
606 | "import metrics_for_multilabel as metrics"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "execution_count": 59,
612 | "metadata": {},
613 | "outputs": [
614 | {
615 | "data": {
616 | "text/plain": [
617 | "0.840774483390301"
618 | ]
619 | },
620 | "execution_count": 59,
621 | "metadata": {},
622 | "output_type": "execute_result"
623 | }
624 | ],
625 | "source": [
626 | "metrics.mean_ndcg_score(test_epoch_target,test_epoch_pred, k=17)"
627 | ]
628 | },
629 | {
630 | "cell_type": "code",
631 | "execution_count": null,
632 | "metadata": {},
633 | "outputs": [],
634 | "source": []
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": 60,
639 | "metadata": {},
640 | "outputs": [
641 | {
642 | "name": "stdout",
643 | "output_type": "stream",
644 | "text": [
645 | "accuracy: 0.6367182462927145\n"
646 | ]
647 | }
648 | ],
649 | "source": [
650 | "acc_cnt=0\n",
651 | "for n in range(test_epoch_pred.shape[0]):\n",
652 | " tar_cnt=np.count_nonzero(test_epoch_target[n])\n",
653 | " pred_=test_epoch_pred[n].argsort()[-tar_cnt:]\n",
654 | " tar_=test_epoch_target[n].argsort()[-tar_cnt:]\n",
655 | " acc_cnt+=len(set(pred_)&set(tar_))/len(pred_)\n",
656 | "print(f\"accuracy: {acc_cnt/test_epoch_pred.shape[0]}\")"
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": null,
662 | "metadata": {},
663 | "outputs": [],
664 | "source": [
665 | "calculate_metrics(target=test_epoch_target, pred=test_epoch_pred, threshold=-1)"
666 | ]
667 | },
668 | {
669 | "cell_type": "code",
670 | "execution_count": 62,
671 | "metadata": {},
672 | "outputs": [],
673 | "source": [
674 | "label_cases_sorted_target=data.label_cols"
675 | ]
676 | },
677 | {
678 | "cell_type": "code",
679 | "execution_count": 63,
680 | "metadata": {},
681 | "outputs": [],
682 | "source": [
683 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = max_len, pad=True, pair=False)\n",
684 | "\n",
685 | "def get_prediction_from_txt(input_text, threshold=0.0):\n",
686 | " sentences = transform([input_text])\n",
687 | " get_pred=model(torch.tensor(sentences[0]).long().unsqueeze(0).to(device),torch.tensor(sentences[1]).unsqueeze(0),torch.tensor(sentences[2]).to(device))\n",
688 | " pred=np.array(get_pred.to(\"cpu\").detach().numpy()[0] > threshold, dtype=float)\n",
689 | " pred=np.nonzero(pred)[0].tolist()\n",
690 | " print(f\"분석 결과, 대화의 예상 태그는 {[label_cases_sorted_target[i] for i in pred]} 입니다.\")\n",
691 | " true=np.nonzero(input_text_label)[0].tolist()\n",
692 | " print(f\"실제 태그는 {[label_cases_sorted_target[i] for i in true]} 입니다.\")\n",
693 | "\n"
694 | ]
695 | },
696 | {
697 | "cell_type": "code",
698 | "execution_count": 64,
699 | "metadata": {},
700 | "outputs": [],
701 | "source": [
702 | "input_text_num=17\n",
703 | "input_text=voc_dataset.iloc[input_text_num,0]\n",
704 | "# input_text=test.iloc[input_text_num,0]\n",
705 | "input_text_label=voc_dataset.iloc[input_text_num,1:].tolist()"
706 | ]
707 | },
708 | {
709 | "cell_type": "code",
710 | "execution_count": 65,
711 | "metadata": {},
712 | "outputs": [
713 | {
714 | "name": "stdout",
715 | "output_type": "stream",
716 | "text": [
717 | "분석 결과, 대화의 예상 태그는 ['대기예약', '무상신규예약'] 입니다.\n",
718 | "실제 태그는 ['무상신규예약'] 입니다.\n"
719 | ]
720 | }
721 | ],
722 | "source": [
723 | "get_prediction_from_txt(input_text, -1)"
724 | ]
725 | },
726 | {
727 | "cell_type": "code",
728 | "execution_count": null,
729 | "metadata": {},
730 | "outputs": [],
731 | "source": []
732 | },
733 | {
734 | "cell_type": "markdown",
735 | "metadata": {},
736 | "source": [
737 | "# XAI"
738 | ]
739 | },
740 | {
741 | "cell_type": "code",
742 | "execution_count": 69,
743 | "metadata": {},
744 | "outputs": [],
745 | "source": [
746 | "from captum_tools_vocvis import *"
747 | ]
748 | },
749 | {
750 | "cell_type": "code",
751 | "execution_count": null,
752 | "metadata": {},
753 | "outputs": [],
754 | "source": [
755 | "from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization\n",
756 | "\n",
757 | "# model = BERTClassifier(bertmodel, dr_rate = 0.4).to(device)\n",
758 | "# model.load_state_dict(torch.load(os.getcwd()+\"/chat_voc_model.pt\", map_location=device))\n",
759 | "model.eval()"
760 | ]
761 | },
762 | {
763 | "cell_type": "code",
764 | "execution_count": 94,
765 | "metadata": {},
766 | "outputs": [],
767 | "source": [
768 | "PAD_IND = tok.vocab.padding_token\n",
769 | "PAD_IND = tok.convert_tokens_to_ids(PAD_IND)\n",
770 | "token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)\n",
771 | "lig = LayerIntegratedGradients(model,model.bert.embeddings)"
772 | ]
773 | },
774 | {
775 | "cell_type": "code",
776 | "execution_count": 95,
777 | "metadata": {},
778 | "outputs": [],
779 | "source": [
780 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = 64, pad=True, pair=False)\n",
781 | "\n",
782 | "voc_label_dict_inverse={ele:label_cols.index(ele) for ele in label_cols}\n",
783 | "\n",
784 | "voc_label_dict={label_cols.index(ele):ele for ele in label_cols}"
785 | ]
786 | },
787 | {
788 | "cell_type": "code",
789 | "execution_count": null,
790 | "metadata": {},
791 | "outputs": [],
792 | "source": []
793 | },
794 | {
795 | "cell_type": "code",
796 | "execution_count": 96,
797 | "metadata": {},
798 | "outputs": [],
799 | "source": [
800 | "def forward_with_sigmoid_for_bert(input,valid_length,segment_ids):\n",
801 | " return torch.sigmoid(model(input,valid_length,segment_ids))\n"
802 | ]
803 | },
804 | {
805 | "cell_type": "code",
806 | "execution_count": 97,
807 | "metadata": {},
808 | "outputs": [],
809 | "source": [
810 | "def forward_for_bert(input,valid_length,segment_ids):\n",
811 | " return torch.nn.functional.softmax(model(input,valid_length,segment_ids),dim=1)"
812 | ]
813 | },
814 | {
815 | "cell_type": "code",
816 | "execution_count": null,
817 | "metadata": {},
818 | "outputs": [],
819 | "source": []
820 | },
821 | {
822 | "cell_type": "code",
823 | "execution_count": 109,
824 | "metadata": {},
825 | "outputs": [],
826 | "source": [
827 | "# accumulate a couple of samples in this array for visualization purposes\n",
828 | "vis_data_records_ig = []\n",
829 | "\n",
830 | "def interpret_sentence(model, sentence, min_len = 64, label = 0, n_steps=10):\n",
831 | " # text = [token for token in tok.sentencepiece(sentence)]\n",
832 | " # if len(text) < min_len:\n",
833 | " # text += ['pad'] * (min_len - len(text))\n",
834 | " # indexed = tok.convert_tokens_to_ids(text)\n",
835 | " # print(text)\n",
836 | " \n",
837 | " # 토크나이징, 시퀀스 생성\n",
838 | " seq_tokens=transform([sentence])\n",
839 | " indexed=torch.tensor(seq_tokens[0]).long()#.to(device)\n",
840 | " valid_length=torch.tensor(seq_tokens[1]).long().unsqueeze(0)\n",
841 | " segment_ids=torch.tensor(seq_tokens[2]).long().unsqueeze(0).to(device)\n",
842 | " sentence=[token for token in tok.sentencepiece(sentence)]\n",
843 | " \n",
844 | "\n",
845 | " with torch.no_grad():\n",
846 | " model.zero_grad()\n",
847 | "\n",
848 | " input_indices = torch.tensor(indexed, device=device)\n",
849 | " input_indices = input_indices.unsqueeze(0)\n",
850 | " \n",
851 | " seq_length = min_len\n",
852 | "\n",
853 | " # predict\n",
854 | " pred = forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids).detach().cpu().numpy().argmax().item()\n",
855 | " print(forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids))\n",
856 | " pred_ind = round(pred)\n",
857 | " \n",
858 | " # generate reference indices for each sample\n",
859 | " reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)\n",
860 | "\n",
861 | " # compute attributions and approximation delta using layer integrated gradients\n",
862 | " attributions_ig, delta = lig.attribute(input_indices, reference_indices,\\\n",
863 | " n_steps=n_steps, return_convergence_delta=True,target=label,\\\n",
864 | " additional_forward_args=(valid_length,segment_ids))\n",
865 | "\n",
866 | " print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))\n",
867 | "\n",
868 | " add_attributions_to_visualizer(attributions_ig, sentence, pred, pred_ind, label, delta, vis_data_records_ig)"
869 | ]
870 | },
871 | {
872 | "cell_type": "code",
873 | "execution_count": 110,
874 | "metadata": {},
875 | "outputs": [],
876 | "source": [
877 | "def add_attributions_to_visualizer(attributions, input_text, pred, pred_ind, label, delta, vis_data_records):\n",
878 | " attributions = attributions.sum(dim=2).squeeze(0)\n",
879 | " attributions = attributions / torch.norm(attributions)\n",
880 | " attributions = attributions.cpu().detach().numpy()\n",
881 | "\n",
882 | " # storing couple samples in an array for visualization purposes\n",
883 | " vis_data_records.append(visualization.VisualizationDataRecord(\n",
884 | " attributions,\n",
885 | " pred,\n",
886 | " voc_label_dict[pred_ind], #Label.vocab.itos[pred_ind],\n",
887 | " voc_label_dict[label], # Label.vocab.itos[label],\n",
888 | " 100, # Label.vocab.itos[1],\n",
889 | " attributions.sum(), \n",
890 | " input_text,\n",
891 | " delta))"
892 | ]
893 | },
894 | {
895 | "cell_type": "code",
896 | "execution_count": 126,
897 | "metadata": {},
898 | "outputs": [
899 | {
900 | "data": {
901 | "text/html": [
902 | "Legend: Negative Neutral Positive
| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|
"
903 | ],
904 | "text/plain": [
905 | ""
906 | ]
907 | },
908 | "metadata": {},
909 | "output_type": "display_data"
910 | }
911 | ],
912 | "source": [
913 | "sentence=voc_dataset.iloc[22].text\n",
914 | "\n",
915 | "visualize_text(vis_data_records_ig)"
916 | ]
917 | }
918 | ],
919 | "metadata": {
920 | "accelerator": "GPU",
921 | "colab": {
922 | "collapsed_sections": [],
923 | "machine_shape": "hm",
924 | "name": "KoBERT_PoC_1211.ipynb",
925 | "provenance": []
926 | },
927 | "kernelspec": {
928 | "display_name": "Python 3",
929 | "language": "python",
930 | "name": "python3"
931 | },
932 | "language_info": {
933 | "codemirror_mode": {
934 | "name": "ipython",
935 | "version": 3
936 | },
937 | "file_extension": ".py",
938 | "mimetype": "text/x-python",
939 | "name": "python",
940 | "nbconvert_exporter": "python",
941 | "pygments_lexer": "ipython3",
942 | "version": "3.8.8"
943 | }
944 | },
945 | "nbformat": 4,
946 | "nbformat_minor": 4
947 | }
948 |
--------------------------------------------------------------------------------
/voc_classifier/kobert_multilabel_text_classifier.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | " \n",
8 | "
"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "metadata": {},
14 | "source": [
15 | "\n",
16 | "# KoBERT Multi-label text classifier\n",
17 | "By: Myeonghak Lee\n",
18 | ""
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | " \n",
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 1,
32 | "metadata": {
33 | "tags": []
34 | },
35 | "outputs": [],
36 | "source": [
37 | "# Input Data 가공 파트\n",
38 | "\n",
39 | "# import torchtext\n",
40 | "import pandas as pd\n",
41 | "import numpy as np\n",
42 | "\n",
43 | "import os\n",
44 | "import re\n",
45 | "\n",
46 | "import config\n",
47 | "from config import expand_pandas\n",
48 | "from preprocess import preprocess\n",
49 | "\n",
50 | "DATA_PATH=config.DATA_PATH\n",
51 | "\n",
52 | "model_config=config.model_config"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 2,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "import warnings\n",
62 | "warnings.filterwarnings(\"ignore\")"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 3,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "done\n"
75 | ]
76 | }
77 | ],
78 | "source": [
79 | "config.expand_pandas(max_rows=100, max_cols=100,width=1000,max_info_cols=500)"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "metadata": {},
85 | "source": [
86 | "### **configs**"
87 | ]
88 | },
89 | {
90 | "cell_type": "code",
91 | "execution_count": 4,
92 | "metadata": {},
93 | "outputs": [],
94 | "source": [
95 | "num_class=17\n",
96 | "ver_num=1\n",
97 | "\n",
98 | "except_labels=[\"변경/취소\",\"예약기타\"]\n",
99 | "\n",
100 | "version_info=\"{:02d}\".format(ver_num)\n",
101 | "weight_path=f\"../weights/weight_{version_info}.pt\""
102 | ]
103 | },
104 | {
105 | "cell_type": "markdown",
106 | "metadata": {},
107 | "source": [
108 | "### **preprocess**"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": []
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": 12,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "data=preprocess()"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 13,
130 | "metadata": {},
131 | "outputs": [
132 | {
133 | "data": {
134 | "text/plain": [
135 | "True"
136 | ]
137 | },
138 | "execution_count": 13,
139 | "metadata": {},
140 | "output_type": "execute_result"
141 | }
142 | ],
143 | "source": [
144 | "# data_orig=data.voc_total[\"종합본\"]\n",
145 | "data.make_table()\n",
146 | "\n",
147 | "# put labels\n",
148 | "data.label_process(num_labels=num_class, except_labels=except_labels)"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": 14,
154 | "metadata": {},
155 | "outputs": [],
156 | "source": [
157 | "orig=data.voc_total[\"종합본\"]"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 15,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "label_cols=data.label_cols"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 16,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "df=data.data.copy()"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": 17,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "voc_dataset=df.reset_index(drop=True)"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {},
190 | "source": [
191 | "# Modeling part"
192 | ]
193 | },
194 | {
195 | "cell_type": "code",
196 | "execution_count": null,
197 | "metadata": {},
198 | "outputs": [],
199 | "source": [
200 | "import torch\n",
201 | "from torch import nn\n",
202 | "\n",
203 | "from metrics_for_multilabel import calculate_metrics, colwise_accuracy\n",
204 | "\n",
205 | "from bert_model import Data_for_BERT, BERTClassifier, EarlyStopping\n",
206 | "\n",
207 | "from transformers import get_linear_schedule_with_warmup, AdamW"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": null,
213 | "metadata": {},
214 | "outputs": [],
215 | "source": []
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 15,
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": null,
229 | "metadata": {},
230 | "outputs": [],
231 | "source": []
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 24,
236 | "metadata": {},
237 | "outputs": [
238 | {
239 | "ename": "NameError",
240 | "evalue": "name 'Data_for_BERT' is not defined",
241 | "output_type": "error",
242 | "traceback": [
243 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
244 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
245 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mtest\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset_index\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdrop\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mdata_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0mdata_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mData_for_BERT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_config\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"max_len\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabel_cols\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabel_cols\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
246 | "\u001b[0;31mNameError\u001b[0m: name 'Data_for_BERT' is not defined"
247 | ]
248 | }
249 | ],
250 | "source": [
251 | "from sklearn.model_selection import train_test_split\n",
252 | "train_input, test_input, train_target, test_target = train_test_split(voc_dataset, voc_dataset[\"국내선\"], test_size = 0.25, random_state = 42)\n",
253 | "\n",
254 | "# train=pd.concat([train_input,train_target],axis=1)\n",
255 | "# test=pd.concat([test_input,test_target],axis=1)\n",
256 | "\n",
257 | "train=train_input.copy()\n",
258 | "test=test_input.copy()\n",
259 | "\n",
260 | "train=train.reset_index(drop=True)\n",
261 | "test=test.reset_index(drop=True)\n",
262 | "\n",
263 | "data_train = Data_for_BERT(train, model_config[\"max_len\"], True, False, label_cols=label_cols)\n",
264 | "data_test = Data_for_BERT(test, model_config[\"max_len\"], True, False, label_cols=label_cols)\n",
265 | "\n",
266 | "# 파이토치 모델에 넣을 수 있도록 데이터를 처리함. \n",
267 | "# data_train을 넣어주고, 이 데이터를 batch_size에 맞게 잘라줌. num_workers는 사용할 subprocess의 개수를 의미함(병렬 프로그래밍)\n",
268 | "\n",
269 | "train_dataloader = torch.utils.data.DataLoader(data_train, batch_size=model_config[\"batch_size\"], num_workers=0)\n",
270 | "test_dataloader = torch.utils.data.DataLoader(data_test, batch_size=model_config[\"batch_size\"], num_workers=0)"
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": null,
276 | "metadata": {},
277 | "outputs": [],
278 | "source": []
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 18,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": [
286 | "# KoBERT 라이브러리에서 bertmodel을 호출함. .to() 메서드는 모델 전체를 GPU 디바이스에 옮겨 줌.\n",
287 | "model = BERTClassifier(num_classes=num_class, dr_rate = model_config[\"dr_rate\"]).to(device)\n",
288 | "\n",
289 | "# 옵티마이저와 스케쥴 준비 (linear warmup과 decay)\n",
290 | "no_decay = ['bias', 'LayerNorm.weight']\n",
291 | "\n",
292 | "# no_decay에 해당하는 파라미터명을 가진 레이어들은 decay에서 배제하기 위해 weight_decay를 0으로 셋팅, 그 외에는 0.01로 decay\n",
293 | "# weight decay란 l2 norm으로 파라미터 값을 정규화해주는 기법을 의미함\n",
294 | "optimizer_grouped_parameters = [\n",
295 | " {\"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay' : 0.01},\n",
296 | " {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
297 | "]\n",
298 | "\n",
299 | "\n",
300 | "# 옵티마이저는 AdamW, 손실함수는 BCE\n",
301 | "# optimizer_grouped_parameters는 최적화할 파라미터의 그룹을 의미함\n",
302 | "optimizer = AdamW(optimizer_grouped_parameters, lr= model_config[\"learning_rate\"])\n",
303 | "# loss_fn = nn.CrossEntropyLoss()\n",
304 | "loss_fn=nn.BCEWithLogitsLoss()\n",
305 | "\n",
306 | "\n",
307 | "# t_total = train_dataloader.dataset.labels.shape[0] * num_epochs\n",
308 | "# linear warmup을 사용해 학습 초기 단계(배치 초기)의 learning rate를 조금씩 증가시켜 나가다, warmup 이후에는 남은 step 동안 0까지 선형으로 감소시킴\n",
309 | "# 초기 학습 단계에서의 변동성을 줄여줌.\n",
310 | "\n",
311 | "t_total = len(train_dataloader) * model_config[\"num_epochs\"]\n",
312 | "warmup_step = int(t_total * model_config[\"warmup_ratio\"])\n",
313 | "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_step, num_training_steps=t_total)\n",
314 | "\n",
315 | "\n",
316 | "# model_save_name = 'classifier'\n",
317 | "# model_file='.pt'\n",
318 | "# path = f\"./bert_weights/{model_save_name}_{model_file}\" "
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": null,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": []
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": 23,
331 | "metadata": {},
332 | "outputs": [],
333 | "source": [
334 | "def train_model(model, batch_size, patience, n_epochs,path):\n",
335 | " \n",
336 | " # to track the training loss as the model trains\n",
337 | " train_losses = []\n",
338 | " # to track the validation loss as the model trains\n",
339 | " valid_losses = []\n",
340 | " # to track the average training loss per epoch as the model trains\n",
341 | " avg_train_losses = []\n",
342 | " # to track the average validation loss per epoch as the model trains\n",
343 | " avg_valid_losses = [] \n",
344 | "\n",
345 | " early_stopping = EarlyStopping(patience=patience, verbose=True, path=path)\n",
346 | "\n",
347 | " for epoch in range(1, n_epochs + 1):\n",
348 | " \n",
349 | " # put the model in training mode at the start of each epoch\n",
350 | " model.train()\n",
351 | " train_epoch_pred=[]\n",
352 | " train_loss_record=[]\n",
353 | "\n",
354 | " for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(train_dataloader):\n",
355 | " optimizer.zero_grad()\n",
356 | "\n",
357 | " token_ids = token_ids.long().to(device)\n",
358 | " segment_ids = segment_ids.long().to(device)\n",
359 | " valid_length= valid_length\n",
360 | " \n",
361 | " # label = label.long().to(device)\n",
362 | " label = label.float().to(device)\n",
363 | "\n",
364 | " out= model(token_ids, valid_length, segment_ids)#.squeeze(1)\n",
365 | " \n",
366 | " loss = loss_fn(out, label)\n",
367 | "\n",
368 | " train_loss_record.append(loss)\n",
369 | "\n",
370 | " train_pred=out.detach().cpu().numpy()\n",
371 | " train_real=label.detach().cpu().numpy()\n",
372 | "\n",
373 | " train_batch_result = calculate_metrics(np.array(train_pred), np.array(train_real))\n",
374 | " \n",
375 | " if batch_id%50==0:\n",
376 | " print(f\"batch number {batch_id}, train col-wise accuracy is : {train_batch_result['Column-wise Accuracy']}\")\n",
377 | " \n",
378 | "\n",
379 | " # save prediction result for calculation of accuracy per epoch\n",
380 | " train_epoch_pred.append(train_pred)\n",
381 | "\n",
382 | " \n",
383 | " loss.backward()\n",
384 | " torch.nn.utils.clip_grad_norm_(model.parameters(), model_config[\"max_grad_norm\"])\n",
385 | " optimizer.step()\n",
386 | " scheduler.step() # Update learning rate schedule\n",
387 | "\n",
388 | " train_losses.append(loss.item())\n",
389 | "\n",
390 | " train_epoch_pred=np.concatenate(train_epoch_pred)\n",
391 | " train_epoch_target=train_dataloader.dataset.labels\n",
392 | " train_epoch_result=calculate_metrics(target=train_epoch_target, pred=train_epoch_pred)\n",
393 | " \n",
394 | " print(f\"=====Training Report: mean loss is {sum(train_loss_record)/len(train_loss_record)}=====\")\n",
395 | " print(train_epoch_result)\n",
396 | " \n",
397 | " print(\"=====train done!=====\")\n",
398 | "\n",
399 | " # if e % log_interval == 0:\n",
400 | " # print(\"epoch {} batch id {} loss {} train acc {}\".format(e+1, batch_id+1, loss.data.cpu().numpy(), train_acc / (batch_id+1)))\n",
401 | "\n",
402 | " # print(\"epoch {} train acc {}\".format(e+1, train_acc / (batch_id+1)))\n",
403 | " test_epoch_pred=[]\n",
404 | " test_loss_record=[]\n",
405 | "\n",
406 | " model.eval()\n",
407 | " with torch.no_grad():\n",
408 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n",
409 | " \n",
410 | " token_ids = token_ids.long().to(device)\n",
411 | " segment_ids = segment_ids.long().to(device)\n",
412 | " valid_length = valid_length\n",
413 | " \n",
414 | " # test_label = test_label.long().to(device)\n",
415 | " test_label = test_label.float().to(device)\n",
416 | "\n",
417 | " test_out = model(token_ids, valid_length, segment_ids)\n",
418 | "\n",
419 | " test_loss = loss_fn(test_out, test_label)\n",
420 | "\n",
421 | " test_loss_record.append(test_loss)\n",
422 | " \n",
423 | " valid_losses.append(test_loss.item())\n",
424 | "\n",
425 | " test_pred=test_out.detach().cpu().numpy()\n",
426 | " test_real=test_label.detach().cpu().numpy()\n",
427 | "\n",
428 | " test_batch_result = calculate_metrics(np.array(test_pred), np.array(test_real))\n",
429 | "\n",
430 | " if batch_id%50==0:\n",
431 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n",
432 | "\n",
433 | " # save prediction result for calculation of accuracy per epoch\n",
434 | " test_epoch_pred.append(test_pred)\n",
435 | "\n",
436 | " test_epoch_pred=np.concatenate(test_epoch_pred)\n",
437 | " test_epoch_target=test_dataloader.dataset.labels\n",
438 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n",
439 | "\n",
440 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n",
441 | " print(test_epoch_result)\n",
442 | "\n",
443 | " train_loss = np.average(train_losses)\n",
444 | " valid_loss = np.average(valid_losses)\n",
445 | " avg_train_losses.append(train_loss)\n",
446 | " avg_valid_losses.append(valid_loss)\n",
447 | "\n",
448 | " # clear lists to track next epoch\n",
449 | " train_losses = []\n",
450 | " valid_losses = []\n",
451 | "\n",
452 | " # early_stopping needs the validation loss to check if it has decreased, \n",
453 | " # and if it has, it will make a checkpoint of the current model\n",
454 | " early_stopping(valid_loss, model)\n",
455 | "\n",
456 | " if early_stopping.early_stop:\n",
457 | " print(\"Early stopping\")\n",
458 | " break\n",
459 | "\n",
460 | " # load the last checkpoint with the best model\n",
461 | " model.load_state_dict(torch.load(path))\n",
462 | "\n",
463 | " return model, avg_train_losses, avg_valid_losses\n",
464 | " "
465 | ]
466 | },
467 | {
468 | "cell_type": "code",
469 | "execution_count": null,
470 | "metadata": {
471 | "tags": [
472 | "outputPrepend"
473 | ]
474 | },
475 | "outputs": [],
476 | "source": [
477 | "# early stopping patience; how long to wait after last time validation loss improved.\n",
478 | "patience = 10\n",
479 | "model, train_loss, valid_loss = train_model(model, \n",
480 | " model_config[\"batch_size\"],\n",
481 | " patience, \n",
482 | " model_config[\"num_epochs\"], \n",
483 | " path=weight_path)\n"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": null,
489 | "metadata": {},
490 | "outputs": [],
491 | "source": []
492 | },
493 | {
494 | "cell_type": "code",
495 | "execution_count": null,
496 | "metadata": {},
497 | "outputs": [],
498 | "source": []
499 | },
500 | {
501 | "cell_type": "markdown",
502 | "metadata": {},
503 | "source": [
504 | "# test performance"
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "execution_count": 54,
510 | "metadata": {},
511 | "outputs": [],
512 | "source": [
513 | "weight_path=\"../weights/weight_01.pt\""
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "execution_count": 55,
519 | "metadata": {},
520 | "outputs": [
521 | {
522 | "data": {
523 | "text/plain": [
524 | ""
525 | ]
526 | },
527 | "execution_count": 55,
528 | "metadata": {},
529 | "output_type": "execute_result"
530 | }
531 | ],
532 | "source": [
533 | "model.load_state_dict(torch.load(weight_path))"
534 | ]
535 | },
536 | {
537 | "cell_type": "code",
538 | "execution_count": 56,
539 | "metadata": {},
540 | "outputs": [
541 | {
542 | "name": "stdout",
543 | "output_type": "stream",
544 | "text": [
545 | "batch number 0, test col-wise accuracy is : 0.9294117647058825\n",
546 | "batch number 50, test col-wise accuracy is : 0.9411764705882353\n",
547 | "batch number 100, test col-wise accuracy is : 0.9058823529411765\n",
548 | "=====Testing Report: mean loss is 0.20872297883033752=====\n",
549 | "{'Accuracy': 0.22437137330754353, 'Column-wise Accuracy': 0.9210376607122539, 'micro/precision': 0.7973273942093542, 'micro/recall': 0.372528616024974, 'micro/f1': 0.5078014184397163, 'macro/precision': 0.6019785504362263, 'macro/recall': 0.28350515905377016, 'macro/f1': 0.34598051554393844, 'samples/precision': 0.563023855577047, 'samples/recall': 0.3934235976789168, 'samples/f1': 0.4393478861563968}\n"
550 | ]
551 | }
552 | ],
553 | "source": [
554 | "test_epoch_pred=[] \n",
555 | "test_loss_record=[] \n",
556 | "valid_losses=[]\n",
557 | "\n",
558 | "model.eval() \n",
559 | "with torch.no_grad(): \n",
560 | " for batch_id, (token_ids, valid_length, segment_ids, test_label) in enumerate(test_dataloader):\n",
561 | "\n",
562 | " token_ids = token_ids.long().to(device)\n",
563 | " segment_ids = segment_ids.long().to(device)\n",
564 | " valid_length = valid_length\n",
565 | " \n",
566 | " # test_label = test_label.long().to(device)\n",
567 | " test_label = test_label.float().to(device)\n",
568 | "\n",
569 | " test_out = model(token_ids, valid_length, segment_ids)\n",
570 | "\n",
571 | " test_loss = loss_fn(test_out, test_label)\n",
572 | "\n",
573 | " test_loss_record.append(test_loss)\n",
574 | " \n",
575 | " valid_losses.append(test_loss.item())\n",
576 | "\n",
577 | " test_pred=test_out.detach().cpu().numpy()\n",
578 | " test_real=test_label.detach().cpu().numpy()\n",
579 | "\n",
580 | " test_batch_result = calculate_metrics(np.array(test_pred), np.array(test_real))\n",
581 | "\n",
582 | " if batch_id%50==0:\n",
583 | " print(f\"batch number {batch_id}, test col-wise accuracy is : {test_batch_result['Column-wise Accuracy']}\")\n",
584 | "\n",
585 | " # save prediction result for calculation of accuracy per epoch\n",
586 | " test_epoch_pred.append(test_pred)\n",
587 | "\n",
588 | " # if batch_id%10==0:\n",
589 | " # print(test_batch_result[\"Accuracy\"])\n",
590 | " test_epoch_pred=np.concatenate(test_epoch_pred) \n",
591 | " test_epoch_target=test_dataloader.dataset.labels \n",
592 | " test_epoch_result=calculate_metrics(target=test_epoch_target, pred=test_epoch_pred)\n",
593 | "\n",
594 | " # print(test_epoch_pred)\n",
595 | " # print(test_epoch_target)\n",
596 | " print(f\"=====Testing Report: mean loss is {sum(test_loss_record)/len(test_loss_record)}=====\")\n",
597 | " print(test_epoch_result)"
598 | ]
599 | },
600 | {
601 | "cell_type": "code",
602 | "execution_count": null,
603 | "metadata": {},
604 | "outputs": [],
605 | "source": []
606 | },
607 | {
608 | "cell_type": "code",
609 | "execution_count": 58,
610 | "metadata": {},
611 | "outputs": [],
612 | "source": [
613 | "import metrics_for_multilabel as metrics"
614 | ]
615 | },
616 | {
617 | "cell_type": "code",
618 | "execution_count": 59,
619 | "metadata": {},
620 | "outputs": [
621 | {
622 | "data": {
623 | "text/plain": [
624 | "0.840774483390301"
625 | ]
626 | },
627 | "execution_count": 59,
628 | "metadata": {},
629 | "output_type": "execute_result"
630 | }
631 | ],
632 | "source": [
633 | "metrics.mean_ndcg_score(test_epoch_target,test_epoch_pred, k=17)"
634 | ]
635 | },
636 | {
637 | "cell_type": "code",
638 | "execution_count": null,
639 | "metadata": {},
640 | "outputs": [],
641 | "source": []
642 | },
643 | {
644 | "cell_type": "code",
645 | "execution_count": 60,
646 | "metadata": {},
647 | "outputs": [
648 | {
649 | "name": "stdout",
650 | "output_type": "stream",
651 | "text": [
652 | "accuracy: 0.6367182462927145\n"
653 | ]
654 | }
655 | ],
656 | "source": [
657 | "acc_cnt=0\n",
658 | "for n in range(test_epoch_pred.shape[0]):\n",
659 | " tar_cnt=np.count_nonzero(test_epoch_target[n])\n",
660 | " pred_=test_epoch_pred[n].argsort()[-tar_cnt:]\n",
661 | " tar_=test_epoch_target[n].argsort()[-tar_cnt:]\n",
662 | " acc_cnt+=len(set(pred_)&set(tar_))/len(pred_)\n",
663 | "print(f\"accuracy: {acc_cnt/test_epoch_pred.shape[0]}\")"
664 | ]
665 | },
666 | {
667 | "cell_type": "code",
668 | "execution_count": null,
669 | "metadata": {},
670 | "outputs": [],
671 | "source": [
672 | "calculate_metrics(target=test_epoch_target, pred=test_epoch_pred, threshold=-1)"
673 | ]
674 | },
675 | {
676 | "cell_type": "code",
677 | "execution_count": 62,
678 | "metadata": {},
679 | "outputs": [],
680 | "source": [
681 | "label_cases_sorted_target=data.label_cols"
682 | ]
683 | },
684 | {
685 | "cell_type": "code",
686 | "execution_count": 63,
687 | "metadata": {},
688 | "outputs": [],
689 | "source": [
690 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = max_len, pad=True, pair=False)\n",
691 | "\n",
692 | "def get_prediction_from_txt(input_text, threshold=0.0):\n",
693 | " sentences = transform([input_text])\n",
694 | " get_pred=model(torch.tensor(sentences[0]).long().unsqueeze(0).to(device),torch.tensor(sentences[1]).unsqueeze(0),torch.tensor(sentences[2]).to(device))\n",
695 | " pred=np.array(get_pred.to(\"cpu\").detach().numpy()[0] > threshold, dtype=float)\n",
696 | " pred=np.nonzero(pred)[0].tolist()\n",
697 | " print(f\"분석 결과, 대화의 예상 태그는 {[label_cases_sorted_target[i] for i in pred]} 입니다.\")\n",
698 | " true=np.nonzero(input_text_label)[0].tolist()\n",
699 | " print(f\"실제 태그는 {[label_cases_sorted_target[i] for i in true]} 입니다.\")\n",
700 | "\n"
701 | ]
702 | },
703 | {
704 | "cell_type": "code",
705 | "execution_count": 64,
706 | "metadata": {},
707 | "outputs": [],
708 | "source": [
709 | "input_text_num=17\n",
710 | "input_text=voc_dataset.iloc[input_text_num,0]\n",
711 | "# input_text=test.iloc[input_text_num,0]\n",
712 | "input_text_label=voc_dataset.iloc[input_text_num,1:].tolist()"
713 | ]
714 | },
715 | {
716 | "cell_type": "code",
717 | "execution_count": 65,
718 | "metadata": {},
719 | "outputs": [
720 | {
721 | "name": "stdout",
722 | "output_type": "stream",
723 | "text": [
724 | "분석 결과, 대화의 예상 태그는 ['대기예약', '무상신규예약'] 입니다.\n",
725 | "실제 태그는 ['무상신규예약'] 입니다.\n"
726 | ]
727 | }
728 | ],
729 | "source": [
730 | "get_prediction_from_txt(input_text, -1)"
731 | ]
732 | },
733 | {
734 | "cell_type": "code",
735 | "execution_count": null,
736 | "metadata": {},
737 | "outputs": [],
738 | "source": []
739 | },
740 | {
741 | "cell_type": "markdown",
742 | "metadata": {},
743 | "source": [
744 | "# XAI"
745 | ]
746 | },
747 | {
748 | "cell_type": "code",
749 | "execution_count": 69,
750 | "metadata": {},
751 | "outputs": [],
752 | "source": [
753 | "from captum_tools_vocvis import *"
754 | ]
755 | },
756 | {
757 | "cell_type": "code",
758 | "execution_count": null,
759 | "metadata": {},
760 | "outputs": [],
761 | "source": [
762 | "from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization\n",
763 | "\n",
764 | "# model = BERTClassifier(bertmodel, dr_rate = 0.4).to(device)\n",
765 | "# model.load_state_dict(torch.load(os.getcwd()+\"/chat_voc_model.pt\", map_location=device))\n",
766 | "model.eval()"
767 | ]
768 | },
769 | {
770 | "cell_type": "code",
771 | "execution_count": 94,
772 | "metadata": {},
773 | "outputs": [],
774 | "source": [
775 | "PAD_IND = tok.vocab.padding_token\n",
776 | "PAD_IND = tok.convert_tokens_to_ids(PAD_IND)\n",
777 | "token_reference = TokenReferenceBase(reference_token_idx=PAD_IND)\n",
778 | "lig = LayerIntegratedGradients(model,model.bert.embeddings)"
779 | ]
780 | },
781 | {
782 | "cell_type": "code",
783 | "execution_count": 95,
784 | "metadata": {},
785 | "outputs": [],
786 | "source": [
787 | "transform = nlp.data.BERTSentenceTransform(tok, max_seq_length = 64, pad=True, pair=False)\n",
788 | "\n",
789 | "voc_label_dict_inverse={ele:label_cols.index(ele) for ele in label_cols}\n",
790 | "\n",
791 | "voc_label_dict={label_cols.index(ele):ele for ele in label_cols}"
792 | ]
793 | },
794 | {
795 | "cell_type": "code",
796 | "execution_count": null,
797 | "metadata": {},
798 | "outputs": [],
799 | "source": []
800 | },
801 | {
802 | "cell_type": "code",
803 | "execution_count": 96,
804 | "metadata": {},
805 | "outputs": [],
806 | "source": [
807 | "def forward_with_sigmoid_for_bert(input,valid_length,segment_ids):\n",
808 | " return torch.sigmoid(model(input,valid_length,segment_ids))\n"
809 | ]
810 | },
811 | {
812 | "cell_type": "code",
813 | "execution_count": 97,
814 | "metadata": {},
815 | "outputs": [],
816 | "source": [
817 | "def forward_for_bert(input,valid_length,segment_ids):\n",
818 | " return torch.nn.functional.softmax(model(input,valid_length,segment_ids),dim=1)"
819 | ]
820 | },
821 | {
822 | "cell_type": "code",
823 | "execution_count": null,
824 | "metadata": {},
825 | "outputs": [],
826 | "source": []
827 | },
828 | {
829 | "cell_type": "code",
830 | "execution_count": 109,
831 | "metadata": {},
832 | "outputs": [],
833 | "source": [
834 | "# accumulate a couple of samples in this array for visualization purposes\n",
835 | "vis_data_records_ig = []\n",
836 | "\n",
837 | "def interpret_sentence(model, sentence, min_len = 64, label = 0, n_steps=10):\n",
838 | " # text = [token for token in tok.sentencepiece(sentence)]\n",
839 | " # if len(text) < min_len:\n",
840 | " # text += ['pad'] * (min_len - len(text))\n",
841 | " # indexed = tok.convert_tokens_to_ids(text)\n",
842 | " # print(text)\n",
843 | " \n",
844 | " # 토크나이징, 시퀀스 생성\n",
845 | " seq_tokens=transform([sentence])\n",
846 | " indexed=torch.tensor(seq_tokens[0]).long()#.to(device)\n",
847 | " valid_length=torch.tensor(seq_tokens[1]).long().unsqueeze(0)\n",
848 | " segment_ids=torch.tensor(seq_tokens[2]).long().unsqueeze(0).to(device)\n",
849 | " sentence=[token for token in tok.sentencepiece(sentence)]\n",
850 | " \n",
851 | "\n",
852 | " with torch.no_grad():\n",
853 | " model.zero_grad()\n",
854 | "\n",
855 | " input_indices = torch.tensor(indexed, device=device)\n",
856 | " input_indices = input_indices.unsqueeze(0)\n",
857 | " \n",
858 | " seq_length = min_len\n",
859 | "\n",
860 | " # predict\n",
861 | " pred = forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids).detach().cpu().numpy().argmax().item()\n",
862 | " print(forward_with_sigmoid_for_bert(input_indices,valid_length,segment_ids))\n",
863 | " pred_ind = round(pred)\n",
864 | " \n",
865 | " # generate reference indices for each sample\n",
866 | " reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)\n",
867 | "\n",
868 | " # compute attributions and approximation delta using layer integrated gradients\n",
869 | " attributions_ig, delta = lig.attribute(input_indices, reference_indices,\\\n",
870 | " n_steps=n_steps, return_convergence_delta=True,target=label,\\\n",
871 | " additional_forward_args=(valid_length,segment_ids))\n",
872 | "\n",
873 | " print('pred: ', Label.vocab.itos[pred_ind], '(', '%.2f'%pred, ')', ', delta: ', abs(delta))\n",
874 | "\n",
875 | " add_attributions_to_visualizer(attributions_ig, sentence, pred, pred_ind, label, delta, vis_data_records_ig)"
876 | ]
877 | },
878 | {
879 | "cell_type": "code",
880 | "execution_count": 110,
881 | "metadata": {},
882 | "outputs": [],
883 | "source": [
884 | "def add_attributions_to_visualizer(attributions, input_text, pred, pred_ind, label, delta, vis_data_records):\n",
885 | " attributions = attributions.sum(dim=2).squeeze(0)\n",
886 | " attributions = attributions / torch.norm(attributions)\n",
887 | " attributions = attributions.cpu().detach().numpy()\n",
888 | "\n",
889 | " # storing a couple of samples in an array for visualization purposes\n",
890 | " vis_data_records.append(visualization.VisualizationDataRecord(\n",
891 | " attributions,\n",
892 | " pred,\n",
893 | " voc_label_dict[pred_ind], #Label.vocab.itos[pred_ind],\n",
894 | " voc_label_dict[label], # Label.vocab.itos[label],\n",
895 | " 100, # Label.vocab.itos[1],\n",
896 | " attributions.sum(), \n",
897 | " input_text,\n",
898 | " delta))"
899 | ]
900 | },
901 | {
902 | "cell_type": "code",
903 | "execution_count": 126,
904 | "metadata": {},
905 | "outputs": [
906 | {
907 | "data": {
908 | "text/html": [
909 | "Legend: Negative Neutral Positive
| True Label | Predicted Label | Attribution Label | Attribution Score | Word Importance |
|---|
"
910 | ],
911 | "text/plain": [
912 | ""
913 | ]
914 | },
915 | "metadata": {},
916 | "output_type": "display_data"
917 | }
918 | ],
919 | "source": [
920 | "sentence=voc_dataset.iloc[22].text\n",
921 | "\n",
922 | "visualize_text(vis_data_records_ig)"
923 | ]
924 | }
925 | ],
926 | "metadata": {
927 | "accelerator": "GPU",
928 | "colab": {
929 | "collapsed_sections": [],
930 | "machine_shape": "hm",
931 | "name": "KoBERT_PoC_1211.ipynb",
932 | "provenance": []
933 | },
934 | "kernelspec": {
935 | "display_name": "Python 3",
936 | "language": "python",
937 | "name": "python3"
938 | },
939 | "language_info": {
940 | "codemirror_mode": {
941 | "name": "ipython",
942 | "version": 3
943 | },
944 | "file_extension": ".py",
945 | "mimetype": "text/x-python",
946 | "name": "python",
947 | "nbconvert_exporter": "python",
948 | "pygments_lexer": "ipython3",
949 | "version": "3.8.8"
950 | }
951 | },
952 | "nbformat": 4,
953 | "nbformat_minor": 4
954 | }
955 |
--------------------------------------------------------------------------------