├── .github
    ├── ISSUE_TEMPLATE
    │   ├── bug-report.md
    │   └── feature.md
    └── PULL_REQUEST_TEMPLATE.md
├── .gitignore
├── LICENSE
├── README.md
├── docker
    └── Dockerfile
├── examples
    ├── README.md
    └── nsmc.py
├── imgs
    ├── bart.png
    └── kobart_summ.png
├── kobart
    ├── __init__.py
    ├── pytorch_kobart.py
    └── utils
    │   ├── __init__.py
    │   ├── aws_s3_downloader.py
    │   └── utils.py
├── requirements.txt
└── setup.py


/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Report a bug.
title: "[BUG] "
labels: bug
assignees: ""
---

## 🐛 Bug


## To Reproduce


Please describe the steps needed to reproduce the bug.

1. -
2. -
3. -

## Expected behavior


## Environment


## Additional context


--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature.md:
--------------------------------------------------------------------------------
---
name: Feature
about: Describe the feature to be developed.
title: "[FEATURE] "
labels: enhancement
assignees: ""
---

## 🚀 Feature


## Motivation


## Pitch


## Additional context


--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
# Pull Request
Thank you for contributing to this repository.

Before submitting this PR, please confirm that the following are done:
- [ ] Does your code build without any errors or warnings?
- [ ] Have you tested it sufficiently?

## 1. What does this PR do?


## 2. Are there any issues related to this PR?

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.cache

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Modified MIT License

Software Copyright (c) 2020 SK telecom

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
associated documentation files (the "Software"), to deal in the Software without restriction,
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
The above copyright notice and this permission notice need not be included
with content created by the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🤣 KoBART

* [🤣 KoBART](#-kobart)
  * [How to install](#how-to-install)
  * [Data](#data)
  * [Tokenizer](#tokenizer)
  * [Model](#model)
    * [Performances](#performances)
      * [Classification or Regression](#classification-or-regression)
      * [Summarization](#summarization)
  * [Demos](#demos)
  * [Examples](#examples)
  * [Release](#release)
  * [Contacts](#contacts)
  * [License](#license)

[**BART**](https://arxiv.org/pdf/1910.13461.pdf) (**B**idirectional and **A**uto-**R**egressive **T**ransformers) is trained as an `autoencoder`: noise is added to part of the input text and the model learns to restore the original text. Korean BART (hereafter **KoBART**) is a Korean `encoder-decoder` language model trained on more than **40GB** of Korean text using the `Text Infilling` noise function from the paper. The resulting `KoBART-base` model is released here.

![bart](imgs/bart.png)

## How to install

```bash
pip install git+https://github.com/SKT-AI/KoBART#egg=kobart
```

## Data

| Data         | # of Sentences |
| ------------ | -------------: |
| Korean Wiki  |             5M |
| Other corpus |          0.27B |

In addition to Korean Wikipedia, a variety of data such as news, books, and [Modu Corpus v1.0 (dialogue, news, ...)](https://corpus.korean.go.kr/) was used to train the model.

## Tokenizer

The tokenizer was trained with the `Character BPE tokenizer` from the [`tokenizers`](https://github.com/huggingface/tokenizers) package.

The `vocab` size is 30,000. Emoticons and emoji that are common in conversation, such as the ones below, were added to the vocabulary so that the model recognizes them better.
> 😀, 😁, 😆, 😅, 🤣, .. , `:-)`, `:)`, `-)`, `(-:`...

A range of unused tokens is also defined, so that they can be assigned freely depending on the `subtasks` that need them.

```python
>>> from kobart import get_kobart_tokenizer
>>> kobart_tokenizer = get_kobart_tokenizer()
>>> kobart_tokenizer.tokenize("안녕하세요. 한국어 BART 입니다.🤣:)l^o")
['▁안녕하', '세요.', '▁한국어', '▁B', 'A', 'R', 'T', '▁입', '니다.', '🤣', ':)', 'l^o']
```

## Model

| Model         | # of params |  Type   | # of layers | # of heads | ffn_dim | hidden_dims |
| ------------- | :---------: | :-----: | ----------: | ---------: | ------: | ----------: |
| `KoBART-base` |    124M     | Encoder |           6 |         16 |    3072 |         768 |
|               |             | Decoder |           6 |         16 |    3072 |         768 |

```python
>>> from transformers import BartModel
>>> from kobart import get_pytorch_kobart_model, get_kobart_tokenizer
>>> kobart_tokenizer = get_kobart_tokenizer()
>>> model = BartModel.from_pretrained(get_pytorch_kobart_model())
>>> inputs = kobart_tokenizer(['안녕하세요.'], return_tensors='pt')
>>> model(inputs['input_ids'])
Seq2SeqModelOutput(last_hidden_state=tensor([[[-0.4418, -4.3673, 3.2404, ..., 5.8832, 4.0629, 3.5540],
         [-0.1316, -4.6446, 2.5955, ..., 6.0093, 2.7467, 3.0007]]],
       grad_fn=), past_key_values=((tensor([[[[-9.7980e-02, -6.6584e-01, -1.8089e+00, ..., 9.6023e-01, -1.8818e-01, -1.3252e+00],
```

### Performances

#### Classification or Regression

|                 | [NSMC](https://github.com/e9t/nsmc)(acc) | [KorSTS](https://github.com/kakaobrain/KorNLUDatasets)(spearman) | [Question Pair](https://github.com/aisolab/nlp_classification/tree/master/BERT_pairwise_text_classification/qpair)(acc) |
| --------------- | ---------------------------------------- | ---------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------ |
| **KoBART-base** | 90.24                                    | 81.66                                                            | 94.34                                                                                                                    |

#### Summarization

*To be updated*

## Demos

* [Summarization demo](https://huggingface.co/spaces/gogamza/kobart-summarization)

![kobart_summ](imgs/kobart_summ.png)

*The example above is a summary of a [ZDNET article](https://zdnet.co.kr/view/?no=20201125093328).*

## Examples

* [NSMC Classification](https://github.com/SKT-AI/KoBART/tree/main/examples)
* [KoBART ChitChatBot](https://github.com/haven-jeon/KoBART-chatbot)
* [KoBART Summarization](https://github.com/seujung/KoBART-summarization)
* [KoBART Translation](https://github.com/seujung/KoBART-translation)
* [LegalQA using Sentence**KoBART**](https://github.com/haven-jeon/LegalQA)
* [KoBART Question Generation](https://github.com/Seoneun/KoBART-Question-Generation)

*If you have an interesting example that uses KoBART, please send a PR!*

## Release

* v0.5.1
  * guide default 'import statements'
* v0.5
  * download large files from `aws s3`
* v0.4
  * Update model binary
* v0.3
  * Fixed a tokenizer bug that caused a token to be dropped
* v0.2
  * Updated the `KoBART` model (better sample efficiency on subtasks)
  * Stated which version of `Modu Corpus` was used
  * Fixed a downloader bug
  * Added `pip` installation support

## Contacts

Please post `KoBART`-related issues [here](https://github.com/SKT-AI/KoBART/issues).

## License

`KoBART` is released under the `modified MIT` license. Please comply with the license terms when using the model and code. The full license text is in the `LICENSE` file.

--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
FROM nvcr.io/nvidia/pytorch:21.05-py3

WORKDIR $HOME/KoBART/examples

RUN pip install pytorch-lightning==1.2.1 transformers==4.3.3 boto3

ENTRYPOINT [ "/bin/sh", "-c", "/bin/bash" ]

--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
# NSMC classification example

## Run on a GPU machine

- build a docker image

```bash
docker build -t kobart -f docker/Dockerfile .
```

- run a docker container

```bash
cd ~/KoBART  # root directory of this repository
docker run --gpus '"device=0"' --rm -it \
       -v $HOME/KoBART:$HOME/KoBART \
       -e PYTHONPATH="$HOME/KoBART" \
       -w "$HOME/KoBART/examples" \
       --name "kobart" \
       kobart /bin/sh
```

- finetune KoBART model with NSMC

  - :warning: run on the docker container

  - finetune

```bash
python nsmc.py --gpus 1 --max_epochs 3 --default_root_dir .cache --gradient_clip_val 1.0
```

--------------------------------------------------------------------------------
/examples/nsmc.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Modified MIT License

# Software Copyright (c) 2020 SK telecom

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# The above copyright notice and this permission notice need not be included
# with content created by the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.

import argparse
import logging
import os
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from transformers.optimization import AdamW, get_cosine_schedule_with_warmup
from transformers import BartForSequenceClassification

from kobart import get_kobart_tokenizer, get_pytorch_kobart_model
from kobart import download


logger = logging.getLogger()
logger.setLevel(logging.INFO)


class ArgsBase:
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        # NOTE: `--batch_size` shares the dest `batch_size` with the
        # `--batch-size` option defined in `Classification`.
        parser.add_argument("--batch_size", type=int, default=128, help="")
        parser.add_argument("--max_seq_len", type=int, default=128, help="")
        return parser


class NSMCDataset(Dataset):
    def __init__(self, filepath, max_seq_len=128):
        self.filepath = filepath
        self.data = pd.read_csv(self.filepath, sep="\t")
        self.max_seq_len = max_seq_len
        self.tokenizer = get_kobart_tokenizer()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        record = self.data.iloc[index]
        document, label = str(record["document"]), int(record["label"])
        # Wrap the tokenized review in BOS/EOS tokens.
        tokens = (
            [self.tokenizer.bos_token]
            + self.tokenizer.tokenize(document)
            + [self.tokenizer.eos_token]
        )
        encoder_input_id = self.tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1] * len(encoder_input_id)
        if len(encoder_input_id) < self.max_seq_len:
            # Pad short sequences up to max_seq_len.
            while len(encoder_input_id) < self.max_seq_len:
                encoder_input_id += [self.tokenizer.pad_token_id]
                attention_mask += [0]
        else:
            # Truncate long sequences but keep EOS as the last token.
            encoder_input_id = encoder_input_id[: self.max_seq_len - 1] + [
                self.tokenizer.eos_token_id
            ]
            attention_mask = attention_mask[: self.max_seq_len]
        return {
            "input_ids": np.array(encoder_input_id, dtype=np.int_),
            "attention_mask": np.array(attention_mask, dtype=float),
            "labels": np.array(label, dtype=np.int_),
        }


class NSMCDataModule(pl.LightningDataModule):
    def __init__(self, max_seq_len=128, batch_size=32, cachedir=".cache"):
        super().__init__()
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len

        s3_train_file = {
            "url": "s3://skt-lsl-nlp-model/KoBART/datasets/nsmc/ratings_train.txt",
            "chksum": None,
        }
        s3_test_file = {
            "url": "s3://skt-lsl-nlp-model/KoBART/datasets/nsmc/ratings_test.txt",
            "chksum": None,
        }

        # The cache directory is passed in explicitly instead of being read
        # from the module-level `args`, which only exists when run as a script.
        os.makedirs(cachedir, exist_ok=True)
        self.train_file_path, is_cached = download(
            s3_train_file["url"], s3_train_file["chksum"], cachedir=cachedir
        )
        self.test_file_path, is_cached = download(
            s3_test_file["url"], s3_test_file["chksum"], cachedir=cachedir
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        return parser

    # OPTIONAL, called for every GPU/machine (assigning state is OK)
    def setup(self, stage):
        # NSMC ships only train and test files; the test file is also used for validation.
        self.nsmc_train = NSMCDataset(self.train_file_path, self.max_seq_len)
        self.nsmc_test = NSMCDataset(self.test_file_path, self.max_seq_len)

    # return the dataloader for each split
    def train_dataloader(self):
        nsmc_train = DataLoader(
            self.nsmc_train, batch_size=self.batch_size, num_workers=5, shuffle=True
        )
        return nsmc_train

    def val_dataloader(self):
        nsmc_val = DataLoader(
            self.nsmc_test, batch_size=self.batch_size, num_workers=5, shuffle=False
        )
        return nsmc_val

    def test_dataloader(self):
        nsmc_test = DataLoader(
            self.nsmc_test, batch_size=self.batch_size, num_workers=5, shuffle=False
        )
        return nsmc_test


class Classification(pl.LightningModule):
    def __init__(self, hparams, **kwargs) -> None:
        super(Classification, self).__init__()
        self.hparams = hparams

    @staticmethod
    def add_model_specific_args(parent_parser):
        # add model specific args
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

        parser.add_argument(
            "--batch-size",
            type=int,
            default=32,
            help="batch size for training (default: 32)",
        )

        parser.add_argument(
            "--lr", type=float, default=5e-5, help="The initial learning rate"
        )

        parser.add_argument(
            "--warmup_ratio", type=float, default=0.1, help="warmup ratio"
        )

        return parser

    def configure_optimizers(self):
        # Prepare optimizer: no weight decay for biases and LayerNorm parameters.
        param_optimizer = list(self.model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.01,
            },
            {
                "params": [
                    p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters, lr=self.hparams.lr, correct_bias=False
        )
        # warm up lr
        num_workers = (self.hparams.gpus if self.hparams.gpus is not None else 1) * (
            self.hparams.num_nodes if self.hparams.num_nodes is not None else 1
        )
        data_len = len(self.train_dataloader().dataset)
        logging.info(f"number of workers {num_workers}, data length {data_len}")
        num_train_steps = int(
            data_len
            / (
                self.hparams.batch_size
                * num_workers
                * self.hparams.accumulate_grad_batches
            )
            * self.hparams.max_epochs
        )
        logging.info(f"num_train_steps : {num_train_steps}")
        num_warmup_steps = int(num_train_steps * self.hparams.warmup_ratio)
        logging.info(f"num_warmup_steps : {num_warmup_steps}")
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps,
        )
        lr_scheduler = {
            "scheduler": scheduler,
            "monitor": "loss",
            "interval": "step",
            "frequency": 1,
        }
        return [optimizer], [lr_scheduler]


class KoBARTClassification(Classification):
    def __init__(self, hparams, **kwargs):
        super(KoBARTClassification, self).__init__(hparams, **kwargs)
        self.model = BartForSequenceClassification.from_pretrained(
            get_pytorch_kobart_model()
        )
        self.model.train()
        self.metric_acc = pl.metrics.classification.Accuracy()

    def forward(self, input_ids, attention_mask, labels=None):
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    def training_step(self, batch, batch_idx):
        outs = self(batch["input_ids"], batch["attention_mask"], batch["labels"])
        loss = outs.loss
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        pred = self(batch["input_ids"], batch["attention_mask"])
        labels = batch["labels"]
        accuracy = self.metric_acc(
            torch.nn.functional.softmax(pred.logits, dim=1), labels
        )
        self.log("accuracy", accuracy)
        result = {"accuracy": accuracy}
        # Per-batch accuracy; averaged into `val_acc` in validation_epoch_end,
        # which is the metric the checkpoint callback monitors.
        return result

    def validation_epoch_end(self, outputs):
        val_acc = torch.stack([i["accuracy"] for i in outputs]).mean()
        self.log("val_acc", val_acc, prog_bar=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="subtask for KoBART")
    parser.add_argument(
        "--cachedir", type=str, default=os.path.join(os.getcwd(), ".cache")
    )
    parser.add_argument("--subtask", type=str, default="NSMC", help="NSMC")
    parser = Classification.add_model_specific_args(parser)
    parser = ArgsBase.add_model_specific_args(parser)
    parser = NSMCDataModule.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    logging.info(args)

    if args.default_root_dir is None:
        args.default_root_dir = args.cachedir

    # init model
    model = KoBARTClassification(args)

    if args.subtask == "NSMC":
        # init data
        dm = NSMCDataModule(
            batch_size=args.batch_size,
            max_seq_len=args.max_seq_len,
            cachedir=args.cachedir,
        )
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            monitor="val_acc",
            dirpath=args.default_root_dir,
            filename="model_chp/{epoch:02d}-{val_acc:.3f}",
            verbose=True,
            save_last=True,
            mode="max",
            save_top_k=-1,
            prefix=f"{args.subtask}",
        )
    else:
        # add more subtasks
        raise NotImplementedError(f"unsupported subtask: {args.subtask}")
    tb_logger = pl_loggers.TensorBoardLogger(
        os.path.join(args.default_root_dir, "tb_logs")
    )
    # train
    lr_logger = pl.callbacks.LearningRateMonitor()
    trainer = pl.Trainer.from_argparse_args(
        args, logger=tb_logger, callbacks=[checkpoint_callback, lr_logger]
    )
    trainer.fit(model, dm)

--------------------------------------------------------------------------------
/imgs/bart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SKT-AI/KoBART/eec563bfccf723cae8fd0fff02d5b2b09e847516/imgs/bart.png

--------------------------------------------------------------------------------
/imgs/kobart_summ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SKT-AI/KoBART/eec563bfccf723cae8fd0fff02d5b2b09e847516/imgs/kobart_summ.png

--------------------------------------------------------------------------------
/kobart/__init__.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Modified MIT License

# Software Copyright (c) 2020 SK telecom

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# The above copyright notice and this permission notice need not be included
# with content created by the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.

from kobart.utils.utils import download
from kobart.pytorch_kobart import get_pytorch_kobart_model, get_kobart_tokenizer

__all__ = ("download", "get_kobart_tokenizer", "get_pytorch_kobart_model")

--------------------------------------------------------------------------------
/kobart/pytorch_kobart.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Modified MIT License

# Software Copyright (c) 2020 SK telecom

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# The above copyright notice and this permission notice need not be included
# with content created by the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.

import os
import shutil
from zipfile import ZipFile
from transformers import PreTrainedTokenizerFast

from kobart import download


def get_pytorch_kobart_model(ctx="cpu", cachedir=".cache"):
    """Download the KoBART model archive if needed and return the extracted model directory."""
    pytorch_kobart = {
        "url": "s3://skt-lsl-nlp-model/KoBART/models/kobart_base_cased_ff4bda5738.zip",
        "chksum": "ff4bda5738",
    }
    model_zip, is_cached = download(
        pytorch_kobart["url"], pytorch_kobart["chksum"], cachedir=cachedir
    )
    cachedir_full = os.path.join(os.getcwd(), cachedir)
    model_path = os.path.join(cachedir_full, "kobart_from_pretrained")
    # Extract when the directory is missing or the archive was just re-downloaded.
    if not os.path.exists(model_path) or not is_cached:
        if not is_cached:
            shutil.rmtree(model_path, ignore_errors=True)
        zipf = ZipFile(os.path.expanduser(model_zip))
        zipf.extractall(path=cachedir_full)
    return model_path


def get_kobart_tokenizer(cachedir=".cache"):
    """Download the KoBART tokenizer if needed and return it as a PreTrainedTokenizerFast."""
    tokenizer = {
        "url": "s3://skt-lsl-nlp-model/KoBART/tokenizers/kobart_base_tokenizer_cased_cf74400bce.zip",
        "chksum": "cf74400bce",
    }
    file_path, is_cached = download(
        tokenizer["url"], tokenizer["chksum"], cachedir=cachedir
    )
    cachedir_full = os.path.expanduser(cachedir)
    if (
        not os.path.exists(os.path.join(cachedir_full, "emji_tokenizer"))
        or not is_cached
    ):
        if not is_cached:
            shutil.rmtree(
                os.path.join(cachedir_full, "emji_tokenizer"), ignore_errors=True
            )
        zipf = ZipFile(os.path.expanduser(file_path))
        zipf.extractall(path=cachedir_full)
    tok_path = os.path.join(cachedir_full, "emji_tokenizer/model.json")
    # The special-token strings were empty in this dump (angle-bracket tokens
    # stripped); restored here to the standard BART special tokens.
    tokenizer_obj = PreTrainedTokenizerFast(
        tokenizer_file=tok_path,
        bos_token="<s>",
        eos_token="</s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
    )
    return tokenizer_obj


if __name__ == "__main__":
    # pip install git+https://github.com/SKT-AI/KoBART#egg=kobart
    from transformers import BartModel
    from kobart import get_pytorch_kobart_model, get_kobart_tokenizer

    kobart_tokenizer = get_kobart_tokenizer()
    print(kobart_tokenizer.tokenize("안녕하세요. 한국어 BART 입니다.🤣:)l^o"))

    model = BartModel.from_pretrained(get_pytorch_kobart_model())
    inputs = kobart_tokenizer(["안녕하세요."], return_tensors="pt")
    print(model(inputs["input_ids"]))

--------------------------------------------------------------------------------
/kobart/utils/__init__.py:
--------------------------------------------------------------------------------
from kobart.utils.utils import download

--------------------------------------------------------------------------------
/kobart/utils/aws_s3_downloader.py:
--------------------------------------------------------------------------------
import boto3
import os
import sys
from botocore import UNSIGNED
from botocore.client import Config


class AwsS3Downloader(object):
    def __init__(
        self,
        aws_access_key_id=None,
        aws_secret_access_key=None,
    ):
        self.resource = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        ).resource("s3")
        # Unsigned requests: the bucket is read without credentials.
        self.client = boto3.client(
            "s3",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            config=Config(signature_version=UNSIGNED),
        )

    def __split_url(self, url: str):
        # "s3://bucket/path/to/key" -> ("bucket", "path/to/key")
        if url.startswith("s3://"):
            url = url.replace("s3://", "")
        bucket, key = url.split("/", maxsplit=1)
        return bucket, key

    def download(self, url: str, local_dir: str):
        bucket, key = self.__split_url(url)
        filename = os.path.basename(key)
        file_path = os.path.join(local_dir, filename)

        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        meta_data = self.client.head_object(Bucket=bucket, Key=key)
        total_length = int(meta_data.get("ContentLength", 0))

        downloaded = 0

        def progress(chunk):
            nonlocal downloaded
            downloaded += chunk
            # Guard against a zero ContentLength to avoid division by zero.
            done = int(50 * downloaded / total_length) if total_length else 50
            sys.stdout.write(
                "\r{}[{}{}]".format(file_path, "█" * done, "." * (50 - done))
            )
            sys.stdout.flush()

        try:
            with open(file_path, "wb") as f:
                self.client.download_fileobj(bucket, key, f, Callback=progress)
            sys.stdout.write("\n")
            sys.stdout.flush()
        except Exception as e:
            # Keep the original error as the cause instead of swallowing it.
            raise Exception(f"failed to download file: {url}") from e
        return file_path


if __name__ == "__main__":
    s3 = AwsS3Downloader()

    s3.download(
        url="s3://skt-lsl-nlp-model/KoBART/models/kobart_base_cased_ff4bda5738.zip",
        local_dir=".cache",
    )

--------------------------------------------------------------------------------
/kobart/utils/utils.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Modified MIT License

# Software Copyright (c) 2020 SK telecom

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# The above copyright notice and this permission notice need not be included
# with content created by the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.

import hashlib
import os

from kobart.utils.aws_s3_downloader import AwsS3Downloader


def download(url, chksum=None, cachedir=".cache"):
    cachedir_full = os.path.join(os.getcwd(), cachedir)
    os.makedirs(cachedir_full, exist_ok=True)
    filename = os.path.basename(url)
    file_path = os.path.join(cachedir_full, filename)
    if os.path.isfile(file_path):
        # Reuse the cached file only when its MD5 prefix matches `chksum`.
        with open(file_path, "rb") as f:
            cached_chksum = hashlib.md5(f.read()).hexdigest()[:10]
        if cached_chksum == chksum:
            print(f"using cached model. {file_path}")
            return file_path, True

    s3 = AwsS3Downloader()
    file_path = s3.download(url, cachedir_full)
    if chksum:
        # Verify the freshly downloaded file against the expected MD5 prefix.
        with open(file_path, "rb") as f:
            downloaded_chksum = hashlib.md5(f.read()).hexdigest()[:10]
        assert (
            chksum == downloaded_chksum
        ), "corrupted file!"
    return file_path, False

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
boto3
pandas
pytorch-lightning == 1.2.1
torch == 1.7.1
transformers == 4.3.3

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Modified MIT License

# Software Copyright (c) 2020 SK telecom

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# The above copyright notice and this permission notice need not be included
# with content created by the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.

from setuptools import find_packages, setup


def install_requires():
    # Read one requirement per line, skipping blank lines.
    with open("requirements.txt") as f:
        lines = f.read().splitlines()
    return [line.strip() for line in lines if line.strip()]


setup(
    name="kobart",
    version="0.5.1",
    url="https://github.com/SKT-AI/KoBART.git",
    license="modified MIT",
    author="Heewon Jeon",
    author_email="madjakarta@gmail.com",
    description="KoBART (Korean BART)",
    packages=find_packages(where=".", exclude=("tests", "scripts", "examples")),
    long_description=open("README.md", encoding="utf-8").read(),
    zip_safe=False,
    include_package_data=True,
    python_requires=">=3.6",
    install_requires=install_requires(),
)

--------------------------------------------------------------------------------
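
The `download` helper in `kobart/utils/utils.py` reuses a cached file only when the first 10 hex characters of its MD5 digest equal `chksum`. Below is a minimal sketch of that check, assuming a chunked read so that a large model archive is not loaded into memory at once. `md5_prefix` and `is_cached` are hypothetical names used for illustration; they are not part of the `kobart` package.

```python
import hashlib
import os
import tempfile


def md5_prefix(path, length=10, chunk_size=1 << 20):
    """Return the first `length` hex characters of the file's MD5 digest."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read fixed-size chunks until read() returns b"" at end of file.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()[:length]


def is_cached(path, chksum):
    """True when `path` exists and its MD5 prefix equals `chksum`."""
    return (
        chksum is not None
        and os.path.isfile(path)
        and md5_prefix(path) == chksum
    )


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        sample = os.path.join(tmp, "sample.bin")
        with open(sample, "wb") as f:
            f.write(b"KoBART" * 1000)
        expected = md5_prefix(sample)
        print(is_cached(sample, expected))  # prefix computed from the same file
        print(is_cached(sample, None))      # no checksum supplied
```

As in the original comparison, a `chksum` of `None` never counts as a cache hit, so the NSMC files that `examples/nsmc.py` requests with `"chksum": None` would be downloaded again on every run.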