├── requirements.txt ├── download_original_data.sh ├── tokenizer-v1 ├── special_tokens_map.json └── tokenizer_config.json ├── tokenizer-v3 ├── special_tokens_map.json └── tokenizer_config.json ├── data └── label.txt ├── test_pipeline.py ├── config ├── koelectra-base-v1.json ├── koelectra-base-v3.json ├── koelectra-small-v1.json └── koelectra-small-v3.json ├── model.py ├── multilabel_pipeline.py ├── utils.py ├── .gitignore ├── translate_data.py ├── README.md ├── data_loader.py ├── LICENSE └── run_goemotions.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | transformers==3.5.1 3 | git+https://github.com/ssut/py-googletrans 4 | attrdict==2.0.1 -------------------------------------------------------------------------------- /download_original_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Fetch only the goemotions/ directory of google-research into ./goemotions (GitHub no longer serves Subversion, so `svn export` stopped working). 3 | set -e 4 | git clone --depth 1 --filter=blob:none --sparse https://github.com/google-research/google-research.git google-research-tmp 5 | git -C google-research-tmp sparse-checkout set goemotions 6 | mv google-research-tmp/goemotions ./goemotions 7 | rm -rf google-research-tmp -------------------------------------------------------------------------------- /tokenizer-v1/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "unk_token": "[UNK]", 3 | "sep_token": "[SEP]", 4 | "pad_token": "[PAD]", 5 | "cls_token": "[CLS]", 6 | "mask_token": "[MASK]", 7 | "additional_special_tokens": [ 8 | "[NAME]", 9 | "[RELIGION]" 10 | ] 11 | } -------------------------------------------------------------------------------- /tokenizer-v3/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "unk_token": "[UNK]", 3 | "sep_token": "[SEP]", 4 | "pad_token": "[PAD]", 5 | "cls_token": "[CLS]", 6 | "mask_token": "[MASK]", 7 | "additional_special_tokens": [ 8 | "[NAME]", 9 | "[RELIGION]" 10 | ] 11 | } -------------------------------------------------------------------------------- /tokenizer-v1/tokenizer_config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": false, 3 | "max_len": 512, 4 | "unk_token": "[UNK]", 5 | "sep_token": "[SEP]", 6 | "pad_token": "[PAD]", 7 | "cls_token": "[CLS]", 8 | "mask_token": "[MASK]", 9 | "additional_special_tokens": [ 10 | "[NAME]", 11 | "[RELIGION]" 12 | ] 13 | } -------------------------------------------------------------------------------- /tokenizer-v3/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_lower_case": false, 3 | "max_len": 512, 4 | "unk_token": "[UNK]", 5 | "sep_token": "[SEP]", 6 | "pad_token": "[PAD]", 7 | "cls_token": "[CLS]", 8 | "mask_token": "[MASK]", 9 | "additional_special_tokens": [ 10 | "[NAME]", 11 | "[RELIGION]" 12 | ] 13 | } -------------------------------------------------------------------------------- /data/label.txt: -------------------------------------------------------------------------------- 1 | admiration 2 | amusement 3 | anger 4 | annoyance 5 | approval 6 | caring 7 | confusion 8 | curiosity 9 | desire 10 | disappointment 11 | disapproval 12 | disgust 13 | embarrassment 14 | excitement 15 | fear 16 | gratitude 17 | grief 18 | joy 19 | love 20 | nervousness 21 | optimism 22 | pride 23 | realization 24 | relief 25 | remorse 26 | sadness 27 | surprise 28 | neutral -------------------------------------------------------------------------------- /test_pipeline.py: -------------------------------------------------------------------------------- 1 | from multilabel_pipeline import MultiLabelPipeline 2 | from transformers import ElectraTokenizer 3 | from model import ElectraForMultiLabelClassification 4 | from pprint import pprint 5 | 6 | 7 | tokenizer = ElectraTokenizer.from_pretrained("monologg/koelectra-base-v3-goemotions") 8 | model = ElectraForMultiLabelClassification.from_pretrained("monologg/koelectra-base-v3-goemotions") 9 | 10 | goemotions = MultiLabelPipeline( 11 | 
model=model, 12 | tokenizer=tokenizer, 13 | threshold=0.3 14 | ) 15 | 16 | texts = [ 17 | "전혀 재미 있지 않습니다 ...", 18 | "나는 “지금 가장 큰 두려움은 내 상자 안에 사는 것” 이라고 말했다.", 19 | "곱창... 한시간반 기다릴 맛은 아님!", 20 | "애정하는 공간을 애정하는 사람들로 채울때", 21 | "너무 좋아", 22 | "딥러닝을 짝사랑중인 학생입니다!", 23 | "마음이 급해진다.", 24 | "아니 진짜 다들 미쳤나봨ㅋㅋㅋ", 25 | "개노잼" 26 | ] 27 | 28 | pprint(goemotions(texts)) -------------------------------------------------------------------------------- /config/koelectra-base-v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "goemotions", 3 | "data_dir": "data", 4 | "ckpt_dir": "ckpt", 5 | "train_file": "train.tsv", 6 | "dev_file": "dev.tsv", 7 | "test_file": "test.tsv", 8 | "label_file": "label.txt", 9 | "evaluate_test_during_training": false, 10 | "eval_all_checkpoints": true, 11 | "save_optimizer": false, 12 | "do_lower_case": false, 13 | "do_train": true, 14 | "do_eval": true, 15 | "max_seq_len": 50, 16 | "num_train_epochs": 15, 17 | "weight_decay": 0.0, 18 | "gradient_accumulation_steps": 1, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "max_steps": -1, 22 | "max_grad_norm": 1.0, 23 | "no_cuda": false, 24 | "model_type": "koelectra-base-v1", 25 | "model_name_or_path": "monologg/koelectra-base-discriminator", 26 | "output_dir": "koelectra-base-v1-goemotions-ckpt", 27 | "seed": 42, 28 | "train_batch_size": 32, 29 | "eval_batch_size": 64, 30 | "logging_steps": 500, 31 | "save_steps": 500, 32 | "learning_rate": 5e-5, 33 | "threshold": 0.3, 34 | "tokenizer_dir": "tokenizer-v1" 35 | } -------------------------------------------------------------------------------- /config/koelectra-base-v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "goemotions", 3 | "data_dir": "data", 4 | "ckpt_dir": "ckpt", 5 | "train_file": "train.tsv", 6 | "dev_file": "dev.tsv", 7 | "test_file": "test.tsv", 8 | "label_file": "label.txt", 9 | "evaluate_test_during_training": false, 10 | 
"eval_all_checkpoints": true, 11 | "save_optimizer": false, 12 | "do_lower_case": false, 13 | "do_train": true, 14 | "do_eval": true, 15 | "max_seq_len": 50, 16 | "num_train_epochs": 15, 17 | "weight_decay": 0.0, 18 | "gradient_accumulation_steps": 1, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "max_steps": -1, 22 | "max_grad_norm": 1.0, 23 | "no_cuda": false, 24 | "model_type": "koelectra-base-v3", 25 | "model_name_or_path": "monologg/koelectra-base-v3-discriminator", 26 | "output_dir": "koelectra-base-v3-goemotions-ckpt", 27 | "seed": 42, 28 | "train_batch_size": 32, 29 | "eval_batch_size": 64, 30 | "logging_steps": 500, 31 | "save_steps": 500, 32 | "learning_rate": 5e-5, 33 | "threshold": 0.3, 34 | "tokenizer_dir": "tokenizer-v3" 35 | } -------------------------------------------------------------------------------- /config/koelectra-small-v1.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "goemotions", 3 | "data_dir": "data", 4 | "ckpt_dir": "ckpt", 5 | "train_file": "train.tsv", 6 | "dev_file": "dev.tsv", 7 | "test_file": "test.tsv", 8 | "label_file": "label.txt", 9 | "evaluate_test_during_training": false, 10 | "eval_all_checkpoints": true, 11 | "save_optimizer": false, 12 | "do_lower_case": false, 13 | "do_train": true, 14 | "do_eval": true, 15 | "max_seq_len": 50, 16 | "num_train_epochs": 15, 17 | "weight_decay": 0.0, 18 | "gradient_accumulation_steps": 1, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "max_steps": -1, 22 | "max_grad_norm": 1.0, 23 | "no_cuda": false, 24 | "model_type": "koelectra-small-v1", 25 | "model_name_or_path": "monologg/koelectra-small-discriminator", 26 | "output_dir": "koelectra-small-v1-goemotions-ckpt", 27 | "seed": 42, 28 | "train_batch_size": 32, 29 | "eval_batch_size": 64, 30 | "logging_steps": 500, 31 | "save_steps": 500, 32 | "learning_rate": 5e-5, 33 | "threshold": 0.3, 34 | "tokenizer_dir": "tokenizer-v1" 35 | } 
-------------------------------------------------------------------------------- /config/koelectra-small-v3.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "goemotions", 3 | "data_dir": "data", 4 | "ckpt_dir": "ckpt", 5 | "train_file": "train.tsv", 6 | "dev_file": "dev.tsv", 7 | "test_file": "test.tsv", 8 | "label_file": "label.txt", 9 | "evaluate_test_during_training": false, 10 | "eval_all_checkpoints": true, 11 | "save_optimizer": false, 12 | "do_lower_case": false, 13 | "do_train": true, 14 | "do_eval": true, 15 | "max_seq_len": 50, 16 | "num_train_epochs": 15, 17 | "weight_decay": 0.0, 18 | "gradient_accumulation_steps": 1, 19 | "adam_epsilon": 1e-8, 20 | "warmup_steps": 0, 21 | "max_steps": -1, 22 | "max_grad_norm": 1.0, 23 | "no_cuda": false, 24 | "model_type": "koelectra-small-v3", 25 | "model_name_or_path": "monologg/koelectra-small-v3-discriminator", 26 | "output_dir": "koelectra-small-v3-goemotions-ckpt", 27 | "seed": 42, 28 | "train_batch_size": 32, 29 | "eval_batch_size": 64, 30 | "logging_steps": 500, 31 | "save_steps": 500, 32 | "learning_rate": 5e-5, 33 | "threshold": 0.3, 34 | "tokenizer_dir": "tokenizer-v3" 35 | } 36 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn import BCEWithLogitsLoss 3 | from transformers.modeling_electra import ElectraModel, ElectraPreTrainedModel 4 | 5 | 6 | class ElectraForMultiLabelClassification(ElectraPreTrainedModel): 7 | def __init__(self, config): 8 | super().__init__(config) 9 | self.num_labels = config.num_labels 10 | 11 | self.electra = ElectraModel(config) 12 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 13 | self.classifier = nn.Linear(config.hidden_size, self.config.num_labels) 14 | self.loss_fct = BCEWithLogitsLoss() 15 | 16 | self.init_weights() 17 | 18 | def forward( 19 | 
self, 20 | input_ids=None, 21 | attention_mask=None, 22 | token_type_ids=None, 23 | position_ids=None, 24 | head_mask=None, 25 | inputs_embeds=None, 26 | labels=None, 27 | ): 28 | discriminator_hidden_states = self.electra( 29 | input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds 30 | ) 31 | pooled_output = discriminator_hidden_states[0][:, 0] 32 | 33 | pooled_output = self.dropout(pooled_output) 34 | logits = self.classifier(pooled_output) 35 | 36 | outputs = (logits,) + discriminator_hidden_states[1:] # add hidden states and attention if they are here 37 | 38 | if labels is not None: 39 | loss = self.loss_fct(logits, labels) 40 | outputs = (loss,) + outputs 41 | 42 | return outputs # (loss), logits, (hidden_states), (attentions) 43 | -------------------------------------------------------------------------------- /multilabel_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | import numpy as np 4 | from transformers.pipelines import ArgumentHandler 5 | from transformers import ( 6 | Pipeline, 7 | PreTrainedTokenizer, 8 | ModelCard 9 | ) 10 | 11 | 12 | class MultiLabelPipeline(Pipeline): 13 | def __init__( 14 | self, 15 | model: Union["PreTrainedModel", "TFPreTrainedModel"], 16 | tokenizer: PreTrainedTokenizer, 17 | modelcard: Optional[ModelCard] = None, 18 | framework: Optional[str] = None, 19 | task: str = "", 20 | args_parser: ArgumentHandler = None, 21 | device: int = -1, 22 | binary_output: bool = False, 23 | threshold: float = 0.3 24 | ): 25 | super().__init__( 26 | model=model, 27 | tokenizer=tokenizer, 28 | modelcard=modelcard, 29 | framework=framework, 30 | args_parser=args_parser, 31 | device=device, 32 | binary_output=binary_output, 33 | task=task 34 | ) 35 | 36 | self.threshold = threshold 37 | 38 | def __call__(self, *args, **kwargs): 39 | outputs = super().__call__(*args, **kwargs) 40 | scores = 1 / (1 + np.exp(-outputs)) # Sigmoid 
41 | results = [] 42 | for item in scores: 43 | labels = [] 44 | scores = [] 45 | for idx, s in enumerate(item): 46 | if s > self.threshold: 47 | labels.append(self.model.config.id2label[idx]) 48 | scores.append(s) 49 | results.append({"labels": labels, "scores": scores}) 50 | return results 51 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import logging 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from sklearn.metrics import precision_recall_fscore_support, accuracy_score 8 | 9 | from model import ElectraForMultiLabelClassification 10 | 11 | from transformers import ( 12 | ElectraConfig, 13 | ElectraTokenizer, 14 | ) 15 | 16 | CONFIG_CLASSES = { 17 | "koelectra-small-v1": ElectraConfig, 18 | "koelectra-base-v1": ElectraConfig, 19 | "koelectra-small-v3": ElectraConfig, 20 | "koelectra-base-v3": ElectraConfig 21 | } 22 | 23 | TOKENIZER_CLASSES = { 24 | "koelectra-small-v1": ElectraTokenizer, 25 | "koelectra-base-v1": ElectraTokenizer, 26 | "koelectra-small-v3": ElectraTokenizer, 27 | "koelectra-base-v3": ElectraTokenizer 28 | } 29 | 30 | MODEL_CLASSES = { 31 | "koelectra-small-v1": ElectraForMultiLabelClassification, 32 | "koelectra-base-v1": ElectraForMultiLabelClassification, 33 | "koelectra-small-v3": ElectraForMultiLabelClassification, 34 | "koelectra-base-v3": ElectraForMultiLabelClassification 35 | } 36 | 37 | 38 | def init_logger(): 39 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 40 | datefmt='%m/%d/%Y %H:%M:%S', 41 | level=logging.INFO) 42 | 43 | 44 | def set_seed(args): 45 | random.seed(args.seed) 46 | np.random.seed(args.seed) 47 | torch.manual_seed(args.seed) 48 | if not args.no_cuda and torch.cuda.is_available(): 49 | torch.cuda.manual_seed_all(args.seed) 50 | 51 | 52 | def compute_metrics(labels, preds): 53 | assert len(preds) == len(labels) 54 | results = 
dict() 55 | 56 | results["accuracy"] = accuracy_score(labels, preds) 57 | results["macro_precision"], results["macro_recall"], results[ 58 | "macro_f1"], _ = precision_recall_fscore_support( 59 | labels, preds, average="macro") 60 | results["micro_precision"], results["micro_recall"], results[ 61 | "micro_f1"], _ = precision_recall_fscore_support( 62 | labels, preds, average="micro") 63 | results["weighted_precision"], results["weighted_recall"], results[ 64 | "weighted_f1"], _ = precision_recall_fscore_support( 65 | labels, preds, average="weighted") 66 | 67 | return results 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | ########### 132 | .vscode/ 133 | .idea/ 134 | 135 | goemotions/ 136 | 137 | cached* 138 | ckpt/ 139 | 140 | koelectra-base-finetuned-goemotions/ 141 | koelectra-small-finetuned-goemotions/ -------------------------------------------------------------------------------- /translate_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | from tqdm import tqdm 5 | from googletrans import Translator as GoogleTranslator 6 | 7 | ORIG_DATA_DIR = os.path.join("goemotions", "data") 8 | DATA_DIR = "data" 9 | 10 | TRAIN_FILE = "train.tsv" 11 | DEV_FILE = "dev.tsv" 12 | TEST_FILE = "test.tsv" 13 | 14 | TEXT_MAX_LENGTH = 5000 # google translate allows maximum size of 5000 for one request 15 | GOOGLE_TIME_TO_SLEEP = 1.5 16 | 17 | 18 | def make_chunks(sentence_lst): 19 | """ 20 | Chunk is a sentence that is not longer than TEXT_MAX_LENGTH 21 | By looping the list of sentences, we will make a new chunk which is not longer than TEXT_MAX_LENGTH, while as long as possible 22 | """ 23 | input_chunk_lst = [] 24 | chunk = "" 25 | for sentence in sentence_lst: 26 | sentence = sentence.strip() 27 | # https://www.reddit.com/r/OutOfTheLoop/comments/9abjhm/what_does_x200b_mean/ 28 | sentence = sentence.replace("​", "") # This one makes error 29 | sentence = sentence + "\r\n" 30 | if len((chunk.rstrip() + 
sentence).encode('utf-8')) > TEXT_MAX_LENGTH: 31 | input_chunk_lst.append(chunk.rstrip()) 32 | chunk = sentence 33 | else: 34 | chunk = chunk + sentence 35 | input_chunk_lst.append(chunk.rstrip()) 36 | return input_chunk_lst 37 | 38 | 39 | def get_sentence_lst(file_path): 40 | sentence_lst = [] 41 | label_lst = [] 42 | with open(file_path, "r", encoding="utf-8") as f: 43 | for line in f: 44 | line = line.strip() 45 | items = line.split("\t") 46 | sentence = items[0].strip() 47 | label = items[1] 48 | sentence_lst.append(sentence) 49 | label_lst.append(label) 50 | return sentence_lst, label_lst 51 | 52 | 53 | def google_translate(sentence_lst): 54 | input_chunk_lst = make_chunks(sentence_lst) 55 | trans = GoogleTranslator() 56 | translated_sentence_lst = [] 57 | 58 | for en_chunk in tqdm(input_chunk_lst): 59 | kr_chunk = trans.translate(en_chunk, src='en', dest='ko') 60 | kr_chunk = kr_chunk.text 61 | kr_sentences = kr_chunk.split("\r\n") 62 | if kr_sentences[-1] == "": 63 | kr_sentences = kr_sentences[:-1] 64 | time.sleep(GOOGLE_TIME_TO_SLEEP) 65 | 66 | translated_sentence_lst.extend(kr_sentences) 67 | 68 | return translated_sentence_lst 69 | 70 | 71 | def make_translate_data(orig_file_path, translated_file_path): 72 | sentence_lst, label_lst = get_sentence_lst(orig_file_path) 73 | translate_sentence_lst = google_translate(sentence_lst) 74 | 75 | assert len(translate_sentence_lst) == len(label_lst) 76 | 77 | with open(translated_file_path, "w", encoding="utf-8") as f: 78 | for (translated_sent, label) in zip(translate_sentence_lst, label_lst): 79 | f.write("{}\t{}\n".format(translated_sent, label)) 80 | 81 | print("Translating {} done".format(orig_file_path)) 82 | 83 | 84 | if __name__ == "__main__": 85 | if not os.path.exists(DATA_DIR): 86 | os.mkdir(DATA_DIR) 87 | 88 | make_translate_data( 89 | os.path.join(ORIG_DATA_DIR, TRAIN_FILE), 90 | os.path.join(DATA_DIR, TRAIN_FILE) 91 | ) 92 | 93 | make_translate_data( 94 | os.path.join(ORIG_DATA_DIR, DEV_FILE), 95 | 
os.path.join(DATA_DIR, DEV_FILE) 96 | ) 97 | 98 | make_translate_data( 99 | os.path.join(ORIG_DATA_DIR, TEST_FILE), 100 | os.path.join(DATA_DIR, TEST_FILE) 101 | ) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GoEmotions-Korean 2 | 3 | [GoEmotions](https://github.com/google-research/google-research/tree/master/goemotions) 데이터셋을 한국어로 번역한 후, [KoELECTRA](https://github.com/monologg/KoELECTRA)로 학습 4 | 5 | ## Updates 6 | 7 | **June 19, 2020** - Transformers v2.9.1 기준으로 모델 학습 시 `[NAME]`, `[RELIGION]`과 같은 Special token을 추가하였음에도 pipeline에서 다시 사용할 때 적용이 되지 않는 이슈가 있었으나, Transformers v2.11.0에서 해당 이슈가 해결되었습니다. 8 | 9 | **Feb 9, 2021** - Transformers v3.5.1 기준으로 `KoELECTRA-v1`, `KoELECTRA-v3`를 가지고 학습하여 새로 모델을 업로드 하였습니다. 10 | 11 | ## GoEmotions 12 | 13 | **58000개의 Reddit comments**를 **28개의 emotion**으로 라벨링한 데이터셋 14 | 15 | - admiration, amusement, anger, annoyance, approval, caring, confusion, curiosity, desire, disappointment, disapproval, disgust, embarrassment, excitement, fear, gratitude, grief, joy, love, nervousness, optimism, pride, realization, relief, remorse, sadness, surprise, neutral 16 | 17 | ## Requirements 18 | 19 | - torch==1.7.1 20 | - transformers=3.5.1 21 | - googletrans==2.4.1 22 | - attrdict==2.0.1 23 | 24 | ```bash 25 | $ pip3 install -r requirements.txt 26 | ``` 27 | 28 | ## Translated Data 29 | 30 | 🚨 **Reddit 댓글로 만든 데이터여서 번역된 결과물의 품질이 좋지 않습니다.** 🚨 31 | 32 | - [pygoogletrans](https://github.com/ssut/py-googletrans)를 사용하여 한국어 데이터 생성 33 | - `pygoogletrans v2.4.1`이 pypi에 업데이트되지 않은 관계로 repository에서 곧바로 라이브러리를 설치하는 것을 권장 (`requirements.txt`에 명시되어 있음) 34 | - API 호출 간에 1.5초의 간격을 주었습니다. 35 | - 한 번의 request에 최대 5000자를 넣을 수 있는 점을 고려하여 문장들을 `\r\n`으로 이어 붙여 input으로 넣었습니다. 36 | - `​​​`(Zero-width space)가 번역 문장 안에 있으면 번역이 되지 않는 오류가 있어서 이는 제거하였습니다. 37 | - **번역을 완료한 데이터는 `data` 디렉토리에 이미 있습니다.** 혹여나 직접 번역을 돌리고 싶다면 아래의 명령어를 실행하면 됩니다. 
38 | 39 | ```bash 40 | $ bash download_original_data.sh 41 | $ pip3 install git+https://github.com/ssut/py-googletrans 42 | $ python3 translate_data.py 43 | ``` 44 | 45 | ## Tokenizer 46 | 47 | - 데이터셋에 `[NAME]`, `[RELIGION]`의 Special Token이 존재하여, 이를 `vocab.txt`의 `[unused0]`와 `[unused1]`에 각각 할당하였습니다. 48 | 49 | ## Train & Evaluation 50 | 51 | - Sigmoid를 적용한 Multi-label classification (**threshold는 0.3으로 지정**) 52 | - `model.py`의 `ElectraForMultiLabelClassification` 참고 53 | - config의 경우 `config` 디렉토리의 json 파일에서 변경하면 됩니다. 54 | 55 | ```bash 56 | $ python3 run_goemotions.py --config_file koelectra-base-v1.json # or koelectra-base-v3.json 57 | $ python3 run_goemotions.py --config_file koelectra-small-v1.json # or koelectra-small-v3.json 58 | ``` 59 | 60 | ## Results 61 | 62 | `Macro F1`을 기준으로 결과 측정 (Best result) 63 | 64 | | Macro F1 (%) | Dev | Test | 65 | | ---------------------- | :---: | :-------: | 66 | | **KoELECTRA-small-v1** | 39.99 | **41.02** | 67 | | **KoELECTRA-base-v1** | 42.18 | **44.03** | 68 | | **KoELECTRA-small-v3** | 40.27 | **40.85** | 69 | | **KoELECTRA-base-v3** | 42.85 | **42.28** | 70 | 71 | ## Pipeline 72 | 73 | - `MultiLabelPipeline` 클래스를 새로 만들어 Multi-label classification에 대한 inference가 가능하게 하였습니다. 74 | - Huggingface s3에 모델을 업로드하였습니다.
75 | - `monologg/koelectra-small-v1-goemotions` 76 | - `monologg/koelectra-base-v1-goemotions` 77 | - `monologg/koelectra-small-v3-goemotions` 78 | - `monologg/koelectra-base-v3-goemotions` 79 | 80 | ```python 81 | from multilabel_pipeline import MultiLabelPipeline 82 | from transformers import ElectraTokenizer 83 | from model import ElectraForMultiLabelClassification 84 | from pprint import pprint 85 | 86 | 87 | tokenizer = ElectraTokenizer.from_pretrained("monologg/koelectra-base-v3-goemotions") 88 | model = ElectraForMultiLabelClassification.from_pretrained("monologg/koelectra-base-v3-goemotions") 89 | 90 | goemotions = MultiLabelPipeline( 91 | model=model, 92 | tokenizer=tokenizer, 93 | threshold=0.3 94 | ) 95 | 96 | texts = [ 97 | "전혀 재미 있지 않습니다 ...", 98 | "나는 “지금 가장 큰 두려움은 내 상자 안에 사는 것” 이라고 말했다.", 99 | "곱창... 한시간반 기다릴 맛은 아님!", 100 | "애정하는 공간을 애정하는 사람들로 채울때", 101 | "너무 좋아", 102 | "딥러닝을 짝사랑중인 학생입니다!", 103 | "마음이 급해진다.", 104 | "아니 진짜 다들 미쳤나봨ㅋㅋㅋ", 105 | "개노잼" 106 | ] 107 | 108 | pprint(goemotions(texts)) 109 | 110 | # Output 111 | [{'labels': ['disapproval'], 'scores': [0.97151965]}, 112 | {'labels': ['fear'], 'scores': [0.9519822]}, 113 | {'labels': ['disapproval', 'neutral'], 'scores': [0.452921, 0.5345312]}, 114 | {'labels': ['love'], 'scores': [0.8750478]}, 115 | {'labels': ['admiration'], 'scores': [0.93127275]}, 116 | {'labels': ['love'], 'scores': [0.9093589]}, 117 | {'labels': ['nervousness', 'neutral'], 'scores': [0.76960915, 0.33462417]}, 118 | {'labels': ['disapproval'], 'scores': [0.95657086]}, 119 | {'labels': ['annoyance', 'disgust'], 'scores': [0.39240348, 0.7896941]}] 120 | ``` 121 | 122 | ## Reference 123 | 124 | - [GoEmotions](https://github.com/google-research/google-research/tree/master/goemotions) 125 | - [KoELECTRA](https://github.com/monologg/KoELECTRA) 126 | - [googletrans](https://github.com/ssut/py-googletrans) 127 | -------------------------------------------------------------------------------- /data_loader.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import json 4 | import logging 5 | 6 | import torch 7 | from torch.utils.data import TensorDataset 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class InputExample(object): 13 | """ 14 | A single training/test example for simple sequence classification. 15 | """ 16 | 17 | def __init__(self, guid, text_a, text_b, label): 18 | self.guid = guid 19 | self.text_a = text_a 20 | self.text_b = text_b 21 | self.label = label 22 | 23 | def __repr__(self): 24 | return str(self.to_json_string()) 25 | 26 | def to_dict(self): 27 | """Serializes this instance to a Python dictionary.""" 28 | output = copy.deepcopy(self.__dict__) 29 | return output 30 | 31 | def to_json_string(self): 32 | """Serializes this instance to a JSON string.""" 33 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 34 | 35 | 36 | class InputFeatures(object): 37 | """A single set of features of data.""" 38 | 39 | def __init__(self, input_ids, attention_mask, token_type_ids, label): 40 | self.input_ids = input_ids 41 | self.attention_mask = attention_mask 42 | self.token_type_ids = token_type_ids 43 | self.label = label 44 | 45 | def __repr__(self): 46 | return str(self.to_json_string()) 47 | 48 | def to_dict(self): 49 | """Serializes this instance to a Python dictionary.""" 50 | output = copy.deepcopy(self.__dict__) 51 | return output 52 | 53 | def to_json_string(self): 54 | """Serializes this instance to a JSON string.""" 55 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 56 | 57 | 58 | def convert_examples_to_features( 59 | args, 60 | examples, 61 | tokenizer, 62 | max_length, 63 | ): 64 | processor = GoEmotionsProcessor(args) 65 | label_list_len = len(processor.get_labels()) 66 | 67 | def convert_to_one_hot_label(label): 68 | one_hot_label = [0] * label_list_len 69 | for l in label: 70 | one_hot_label[l] = 1 71 | return one_hot_label 72 | 
73 | labels = [convert_to_one_hot_label(example.label) for example in examples] 74 | 75 | batch_encoding = tokenizer.batch_encode_plus( 76 | [(example.text_a, example.text_b) for example in examples], max_length=max_length, pad_to_max_length=True 77 | ) 78 | 79 | features = [] 80 | for i in range(len(examples)): 81 | inputs = {k: batch_encoding[k][i] for k in batch_encoding} 82 | 83 | feature = InputFeatures(**inputs, label=labels[i]) 84 | features.append(feature) 85 | 86 | for i, example in enumerate(examples[:10]): 87 | logger.info("*** Example ***") 88 | logger.info("guid: {}".format(example.guid)) 89 | logger.info("sentence: {}".format(example.text_a)) 90 | logger.info("tokens: {}".format(" ".join([str(x) for x in tokenizer.tokenize(example.text_a)]))) 91 | logger.info("input_ids: {}".format(" ".join([str(x) for x in features[i].input_ids]))) 92 | logger.info("attention_mask: {}".format(" ".join([str(x) for x in features[i].attention_mask]))) 93 | logger.info("token_type_ids: {}".format(" ".join([str(x) for x in features[i].token_type_ids]))) 94 | logger.info("label: {}".format(" ".join([str(x) for x in features[i].label]))) 95 | 96 | return features 97 | 98 | 99 | class GoEmotionsProcessor(object): 100 | """Processor for the GoEmotions data set """ 101 | 102 | def __init__(self, args): 103 | self.args = args 104 | 105 | def get_labels(self): 106 | labels = [] 107 | with open(os.path.join(self.args.data_dir, self.args.label_file), "r", encoding="utf-8") as f: 108 | for line in f: 109 | labels.append(line.rstrip()) 110 | return labels 111 | 112 | @classmethod 113 | def _read_file(cls, input_file): 114 | """Reads a tab separated value file.""" 115 | with open(input_file, "r", encoding="utf-8") as f: 116 | return f.readlines() 117 | 118 | def _create_examples(self, lines, set_type): 119 | """Creates examples for the training and dev sets.""" 120 | examples = [] 121 | for (i, line) in enumerate(lines): 122 | guid = "%s-%s" % (set_type, i) 123 | line = line.strip() 
124 | items = line.split("\t") 125 | text_a = items[0] 126 | label = list(map(int, items[1].split(","))) 127 | if i % 5000 == 0: 128 | logger.info(line) 129 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) 130 | return examples 131 | 132 | def get_examples(self, mode): 133 | """ 134 | Args: 135 | mode: train, dev, test 136 | """ 137 | file_to_read = None 138 | if mode == 'train': 139 | file_to_read = self.args.train_file 140 | elif mode == 'dev': 141 | file_to_read = self.args.dev_file 142 | elif mode == 'test': 143 | file_to_read = self.args.test_file 144 | 145 | logger.info("LOOKING AT {}".format(os.path.join(self.args.data_dir, file_to_read))) 146 | return self._create_examples(self._read_file(os.path.join(self.args.data_dir, 147 | file_to_read)), mode) 148 | 149 | 150 | def load_and_cache_examples(args, tokenizer, mode): 151 | processor = GoEmotionsProcessor(args) 152 | # Load data features from cache or dataset file 153 | cached_features_file = os.path.join( 154 | args.data_dir, 155 | "cached_{}_{}_{}_{}".format( 156 | str(args.task), 157 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 158 | str(args.max_seq_len), 159 | mode 160 | ) 161 | ) 162 | if os.path.exists(cached_features_file): 163 | logger.info("Loading features from cached file %s", cached_features_file) 164 | features = torch.load(cached_features_file) 165 | else: 166 | logger.info("Creating features from dataset file at %s", args.data_dir) 167 | if mode == "train": 168 | examples = processor.get_examples("train") 169 | elif mode == "dev": 170 | examples = processor.get_examples("dev") 171 | elif mode == "test": 172 | examples = processor.get_examples("test") 173 | else: 174 | raise ValueError("For mode, only train, dev, test is available") 175 | features = convert_examples_to_features( 176 | args, examples, tokenizer, max_length=args.max_seq_len 177 | ) 178 | logger.info("Saving features into cached file %s", cached_features_file) 179 | 
torch.save(features, cached_features_file) 180 | 181 | # Convert to Tensors and build dataset 182 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 183 | all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 184 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 185 | all_labels = torch.tensor([f.label for f in features], dtype=torch.float) 186 | 187 | dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels) 188 | return dataset 189 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /run_goemotions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import glob 6 | 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 10 | from tqdm import tqdm, trange 11 | from attrdict import AttrDict 12 | 13 | from transformers import ( 14 | AdamW, 15 | get_linear_schedule_with_warmup 16 | ) 17 | 18 | from utils import ( 19 | CONFIG_CLASSES, 20 | TOKENIZER_CLASSES, 21 | MODEL_CLASSES, 22 | init_logger, 23 | set_seed, 24 | compute_metrics 25 | ) 26 | 27 | from data_loader import ( 28 | load_and_cache_examples, 29 | GoEmotionsProcessor 30 | ) 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | def train(args, 36 | model, 37 | tokenizer, 38 | train_dataset, 39 | dev_dataset=None, 40 | test_dataset=None): 41 | train_sampler = 
RandomSampler(train_dataset) 42 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 43 | if args.max_steps > 0: 44 | t_total = args.max_steps 45 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 46 | else: 47 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 48 | 49 | # Prepare optimizer and schedule (linear warmup and decay) 50 | no_decay = ['bias', 'LayerNorm.weight'] 51 | optimizer_grouped_parameters = [ 52 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 53 | 'weight_decay': args.weight_decay}, 54 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 55 | ] 56 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 57 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 58 | 59 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 60 | os.path.join(args.model_name_or_path, "scheduler.pt") 61 | ): 62 | # Load optimizer and scheduler states 63 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 64 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 65 | 66 | # Train! 
67 | logger.info("***** Running training *****") 68 | logger.info(" Num examples = %d", len(train_dataset)) 69 | logger.info(" Num Epochs = %d", args.num_train_epochs) 70 | logger.info(" Total train batch size = %d", args.train_batch_size) 71 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 72 | logger.info(" Total optimization steps = %d", t_total) 73 | logger.info(" Logging steps = %d", args.logging_steps) 74 | logger.info(" Save steps = %d", args.save_steps) 75 | 76 | global_step = 0 77 | tr_loss = 0.0 78 | 79 | model.zero_grad() 80 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch") 81 | for _ in train_iterator: 82 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 83 | for step, batch in enumerate(epoch_iterator): 84 | model.train() 85 | batch = tuple(t.to(args.device) for t in batch) 86 | inputs = { 87 | "input_ids": batch[0], 88 | "attention_mask": batch[1], 89 | "labels": batch[3] 90 | } 91 | if args.model_type not in ["distilkobert", "xlm-roberta"]: 92 | inputs["token_type_ids"] = batch[2] 93 | outputs = model(**inputs) 94 | 95 | loss = outputs[0] 96 | 97 | if args.gradient_accumulation_steps > 1: 98 | loss = loss / args.gradient_accumulation_steps 99 | 100 | loss.backward() 101 | tr_loss += loss.item() 102 | if (step + 1) % args.gradient_accumulation_steps == 0 or ( 103 | len(train_dataloader) <= args.gradient_accumulation_steps 104 | and (step + 1) == len(train_dataloader) 105 | ): 106 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 107 | 108 | optimizer.step() 109 | scheduler.step() 110 | model.zero_grad() 111 | global_step += 1 112 | 113 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 114 | if args.evaluate_test_during_training: 115 | evaluate(args, model, test_dataset, "test", global_step) 116 | else: 117 | evaluate(args, model, dev_dataset, "dev", global_step) 118 | 119 | if args.save_steps > 0 and global_step % args.save_steps == 0: 120 | # 
Save model checkpoint 121 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 122 | if not os.path.exists(output_dir): 123 | os.makedirs(output_dir) 124 | model_to_save = ( 125 | model.module if hasattr(model, "module") else model 126 | ) 127 | model_to_save.save_pretrained(output_dir) 128 | tokenizer.save_pretrained(output_dir) 129 | 130 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 131 | logger.info("Saving model checkpoint to {}".format(output_dir)) 132 | 133 | if args.save_optimizer: 134 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 135 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 136 | logger.info("Saving optimizer and scheduler states to {}".format(output_dir)) 137 | 138 | if args.max_steps > 0 and global_step > args.max_steps: 139 | break 140 | 141 | if args.max_steps > 0 and global_step > args.max_steps: 142 | break 143 | 144 | return global_step, tr_loss / global_step 145 | 146 | 147 | def evaluate(args, model, eval_dataset, mode, global_step=None): 148 | results = {} 149 | eval_sampler = SequentialSampler(eval_dataset) 150 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 151 | 152 | # Eval! 
153 | if global_step != None: 154 | logger.info("***** Running evaluation on {} dataset ({} step) *****".format(mode, global_step)) 155 | else: 156 | logger.info("***** Running evaluation on {} dataset *****".format(mode)) 157 | logger.info(" Num examples = {}".format(len(eval_dataset))) 158 | logger.info(" Eval Batch size = {}".format(args.eval_batch_size)) 159 | eval_loss = 0.0 160 | nb_eval_steps = 0 161 | preds = None 162 | out_label_ids = None 163 | 164 | for batch in tqdm(eval_dataloader, desc="Evaluating"): 165 | model.eval() 166 | batch = tuple(t.to(args.device) for t in batch) 167 | 168 | with torch.no_grad(): 169 | inputs = { 170 | "input_ids": batch[0], 171 | "attention_mask": batch[1], 172 | "labels": batch[3] 173 | } 174 | if args.model_type not in ["distilkobert", "xlm-roberta"]: 175 | inputs["token_type_ids"] = batch[2] 176 | outputs = model(**inputs) 177 | tmp_eval_loss, logits = outputs[:2] 178 | 179 | eval_loss += tmp_eval_loss.mean().item() 180 | nb_eval_steps += 1 181 | if preds is None: 182 | preds = 1 / (1 + np.exp(-logits.detach().cpu().numpy())) # Sigmoid 183 | out_label_ids = inputs["labels"].detach().cpu().numpy() 184 | else: 185 | preds = np.append(preds, 1 / (1 + np.exp(-logits.detach().cpu().numpy())), axis=0) # Sigmoid 186 | out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0) 187 | 188 | eval_loss = eval_loss / nb_eval_steps 189 | results = { 190 | "loss": eval_loss 191 | } 192 | preds[preds > args.threshold] = 1 193 | preds[preds <= args.threshold] = 0 194 | result = compute_metrics(out_label_ids, preds) 195 | results.update(result) 196 | 197 | output_dir = os.path.join(args.output_dir, mode) 198 | if not os.path.exists(output_dir): 199 | os.makedirs(output_dir) 200 | 201 | output_eval_file = os.path.join(output_dir, "{}-{}.txt".format(mode, global_step) if global_step else "{}.txt".format(mode)) 202 | with open(output_eval_file, "w") as f_w: 203 | logger.info("***** Eval results on {} dataset 
*****".format(mode)) 204 | for key in sorted(results.keys()): 205 | logger.info(" {} = {}".format(key, str(results[key]))) 206 | f_w.write(" {} = {}\n".format(key, str(results[key]))) 207 | 208 | return results 209 | 210 | 211 | def main(cli_args): 212 | # Read from config file and make args 213 | with open(os.path.join(cli_args.config_dir, cli_args.config_file)) as f: 214 | args = AttrDict(json.load(f)) 215 | logger.info("Training/evaluation parameters {}".format(args)) 216 | 217 | args.output_dir = os.path.join(args.ckpt_dir, args.output_dir) 218 | 219 | init_logger() 220 | set_seed(args) 221 | 222 | processor = GoEmotionsProcessor(args) 223 | label_list = processor.get_labels() 224 | 225 | config = CONFIG_CLASSES[args.model_type].from_pretrained( 226 | args.model_name_or_path, 227 | num_labels=len(label_list), 228 | finetuning_task=args.task, 229 | id2label={str(i): label for i, label in enumerate(label_list)}, 230 | label2id={label: i for i, label in enumerate(label_list)} 231 | ) 232 | tokenizer = TOKENIZER_CLASSES[args.model_type].from_pretrained( 233 | args.tokenizer_dir, 234 | ) 235 | tokenizer.add_special_tokens({"additional_special_tokens": ["[NAME]", "[RELIGION]"]}) 236 | model = MODEL_CLASSES[args.model_type].from_pretrained( 237 | args.model_name_or_path, 238 | config=config 239 | ) 240 | 241 | # GPU or CPU 242 | args.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" 243 | model.to(args.device) 244 | 245 | # Load dataset 246 | train_dataset = load_and_cache_examples(args, tokenizer, mode="train") if args.train_file else None 247 | dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev") if args.dev_file else None 248 | test_dataset = load_and_cache_examples(args, tokenizer, mode="test") if args.test_file else None 249 | 250 | if dev_dataset == None: 251 | args.evaluate_test_during_training = True # If there is no dev dataset, only use testset 252 | 253 | if args.do_train: 254 | global_step, tr_loss = train(args, 
model, tokenizer, train_dataset, dev_dataset, test_dataset) 255 | logger.info(" global_step = {}, average loss = {}".format(global_step, tr_loss)) 256 | 257 | results = {} 258 | if args.do_eval: 259 | checkpoints = list( 260 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + "pytorch_model.bin", recursive=True)) 261 | ) 262 | if not args.eval_all_checkpoints: 263 | checkpoints = checkpoints[-1:] 264 | else: 265 | logging.getLogger("transformers.configuration_utils").setLevel(logging.WARN) # Reduce logging 266 | logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 267 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 268 | for checkpoint in checkpoints: 269 | global_step = checkpoint.split("-")[-1] 270 | model = MODEL_CLASSES[args.model_type].from_pretrained(checkpoint) 271 | model.to(args.device) 272 | result = evaluate(args, model, test_dataset, mode="test", global_step=global_step) 273 | result = dict((k + "_{}".format(global_step), v) for k, v in result.items()) 274 | results.update(result) 275 | 276 | output_eval_file = os.path.join(args.output_dir, "eval_results.txt") 277 | with open(output_eval_file, "w") as f_w: 278 | for key in sorted(results.keys()): 279 | f_w.write("{} = {}\n".format(key, str(results[key]))) 280 | 281 | 282 | if __name__ == '__main__': 283 | cli_parser = argparse.ArgumentParser() 284 | 285 | cli_parser.add_argument("--config_dir", type=str, default="config") 286 | cli_parser.add_argument("--config_file", type=str, required=True) 287 | 288 | cli_args = cli_parser.parse_args() 289 | 290 | main(cli_args) 291 | --------------------------------------------------------------------------------