├── tests ├── __init__.py └── test_first.py ├── requirements.txt ├── data ├── dev_example.txt └── train_example.txt ├── chosung_translator ├── __init__.py ├── utils.py ├── config.py └── data.py ├── pyproject.toml ├── .vscode └── extensions.json ├── requirements-dev.txt ├── .github ├── ISSUE_TEMPLATE │ ├── okr.md │ ├── proposal.md │ ├── feature_request.md │ └── bug_report.md ├── pull_request_template.md └── workflows │ └── lint-and-format.yml ├── setup.cfg ├── tox.ini ├── setup.py ├── pyrightconfig.json ├── README.md ├── .gitignore ├── run_inference.py └── run_train.py /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | torch 3 | tqdm 4 | -------------------------------------------------------------------------------- /data/dev_example.txt: -------------------------------------------------------------------------------- 1 | 안녕하세여 2 | 감사해여 3 | 잘 있어여 4 | 다시 만나여! 5 | -------------------------------------------------------------------------------- /data/train_example.txt: -------------------------------------------------------------------------------- 1 | 안녕하세요 2 | 감사해요 3 | 잘 있어요 4 | 다시 만나요! 5 | -------------------------------------------------------------------------------- /tests/test_first.py: -------------------------------------------------------------------------------- 1 | def test_first(): 2 | assert 1 == 1 3 | -------------------------------------------------------------------------------- /chosung_translator/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | __author__ = "ScatterLab" 3 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py37'] 4 | include = '\.py$' 5 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-pyright.pyright", 4 | "ms-python.python" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # dev dependency 2 | isort 3 | black 4 | flake8 5 | 6 | pytest 7 | pytest-cov 8 | 9 | coveralls 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/okr.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: ✍️ OKR 이슈 3 | about: OKR과 연동되는 이슈는 이 템플릿을 이용해주세요! 4 | title: "" 5 | --- 6 | 7 | ## 담당자 & Due Date 8 | 9 | * TO: @ 10 | * CC: @ 11 | * DUE: 00/00(N요일) 12 | 13 | ## 추가 메모 (선택사항) 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203, W503 4 | 5 | [tool:isort] 6 | line_length = 120 7 | multi_line_output = 3 8 | include_trailing_comma = True 9 | 10 | [tool:pytest] 11 | addopts = -ra -v -l 12 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37 3 | 4 | [testenv] 5 | deps = 6 | -r requirements.txt 7 | -r requirements-dev.txt 8 | commands = 9 | black --check chosung_translator tests 10 | flake8 chosung_translator tests 11 | isort -c chosung_translator tests 12 | pytest 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/proposal.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🙌 제안 사항 3 | about: 이것을 고쳤으면 좋겠다! 4 | title: "" 5 | --- 6 | 7 | ## 어떻게 바꾸면 좋겠나요? 🤩 8 | 9 | 최대한 간결하고 명확하게 원하는 방법을 작성해주세요! 10 | 11 | ## 왜 바꾸면 좋겠나요? 🤔 (선택사항) 12 | 13 | 최대한 간결하고 명확하게 현재의 문제점 또는 부족한 점을 작성해주세요. 14 | 15 | ## 추가로 알아야 할 것을 알려주세요! 🥺 (선택사항) 16 | 17 | 다른 사람이 안다면 좋을 정보를 여기에 적어주세요! 18 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## 무엇이 변경되었나요? 🎉 2 | 3 | - 화면 XX에서 YY 추가 4 | - `kk` 파일에서 `zz` 추가/제거/수정 5 | 6 | ## 관련된 이슈 혹은 PR은 무엇인가요? 🔍 7 | 8 | [GitHub 문서의 "Closing Issues Using Keywords"](https://help.github.com/en/articles/closing-issues-using-keywords)를 참고해주세요! 9 | 10 | ## 추가로 알아야 할 것을 알려주세요! 🥺 (선택사항) 11 | 12 | 다른 사람이 안다면 좋을 정보를 여기에 적어주세요! 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="chosung-translator", 5 | version="0.0.1", 6 | description="초성해석기", 7 | install_requires=[], 8 | url="https://github.com/noowad93/chosung-translator", 9 | author="Dawoon Jung", 10 | author_email="dawoon@scatterlab.co.kr", 11 | packages=find_packages(exclude=["tests"]), 12 | ) 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🙏 기능 추가 요청 3 | about: 이것 좀 만들어 주실래요? 4 | title: "" 5 | --- 6 | 7 | ## 현재의 문제점을 적어주세요! 🤩 8 | 9 | 최대한 간결하고 명확하게 현재의 문제점을 작성해주세요. 10 | 11 | ## 어떻게 해결할 수 있을까요? 🤔 12 | 13 | 최대한 간결하고 명확하게 원하는 해결 방법을 작성해주세요! 14 | 15 | ## 만약 다르게 해결할 수 있다면? 🤭 (선택사항) 16 | 17 | 만약 다른 해결 방법이 있다면 무엇인지 명확하게 작성해주세요! 18 | 19 | ## 추가로 알아야 할 것을 알려주세요! 🥺 (선택사항) 20 | 21 | 다른 사람이 안다면 좋을 정보를 여기에 적어주세요! 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐞 버그 리포트 3 | about: 버그 리포트는 여기에 적어주세요. 4 | title: "" 5 | --- 6 | 7 | ## 어떤 버그인가요? 🤯 8 | 9 | 버그를 간단하고 최대한 명확하게 요약해주세요! 10 | 11 | ## 어떻게 버그를 재현하나요? 🤔 12 | 13 | 버그를 재현하는 방법: 14 | 15 | 1. '...'으로 가서 16 | 2. '....'을 누르고 17 | 3. '....'으로 이동하면 18 | 4. 쨘! 버그다! 🤗 19 | 20 | ## 원래는 어떻게 동작해야 했을까요? 😢 21 | 22 | 정상적으로 작동한다면 어떻게 되어야 할지, 최대한 간결하고 명확한 설명을 작성해주세요! 23 | 24 | ## 추가로 알아야 할 것을 알려주세요! 🥺 (선택사항) 25 | 26 | 다른 사람이 안다면 좋을 정보를 여기에 적어주세요! 27 | -------------------------------------------------------------------------------- /chosung_translator/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | CHOSUNG_LIST = ["ㄱ", "ㄲ", "ㄴ", "ㄷ", "ㄸ", "ㄹ", "ㅁ", "ㅂ", "ㅃ", "ㅅ", "ㅆ", "ㅇ", "ㅈ", "ㅉ", "ㅊ", "ㅋ", "ㅌ", "ㅍ", "ㅎ"] 4 | 5 | 6 | def load_data(file_path: str) -> List[str]: 7 | texts = [] 8 | with open(file_path, "r") as f: 9 | for line in f: 10 | texts.append(line.strip()) 11 | return texts 12 | 13 | 14 | def convert_text_to_chosung(text: str) -> str: 15 | chosung_text = "" 16 | for c in text: 17 | if ord("가") <= ord(c) <= ord("힣"): 18 | chosung_text += CHOSUNG_LIST[(ord(c) - ord("가")) // 588] 19 | else: 20 | chosung_text += c 21 | return chosung_text 22 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "include": ["chosung_translator"], 3 | "venvPath": "env", 4 | 5 | "reportUnknownVariableType": true, 6 | "reportUnknownMemberType": true, 7 | "reportUnusedImport": true, 8 | "reportUnusedVariable": true, 9 | "reportUnusedClass": true, 10 | "reportUnusedFunction": true, 11 | "reportImportCycles": true, 12 | "reportTypeshedErrors": true, 13 | "reportOptionalMemberAccess": true, 14 | "reportUntypedBaseClass": true, 15 | "reportPrivateUsage": true, 16 | "reportConstantRedefinition": true, 17 | "reportInvalidStringEscapeSequence": true, 18 | "reportUnnecessaryIsInstance": true, 19 | "reportUnnecessaryCast": true, 20 | "reportAssertAlwaysTrue": true, 21 | "reportSelfClsParameterName": true, 22 | 23 | "pythonVersion": "3.7", 24 | "pythonPlatform": "Linux" 25 | } 26 | -------------------------------------------------------------------------------- /.github/workflows/lint-and-format.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Format Python 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.7 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.7 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt -r requirements-dev.txt 23 | - name: Lint with flake8 24 | run: | 25 | flake8 chosung_translator tests run_train.py run_inference.py 26 | - name: Check an order of import statements with isort 27 | run: | 28 | isort -c chosung_translator tests run_train.py run_inference.py 29 | - name: Check the code formatting with black 30 | run: | 31 | black --check chosung_translator tests run_train.py run_inference.py 32 | -------------------------------------------------------------------------------- /chosung_translator/config.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | 4 | class TrainConfig(NamedTuple): 5 | """ 6 | Training Hyperparameters 7 | """ 8 | 9 | #: random seed 10 | seed: int = 42 11 | #: 사용할 gpu 갯수 12 | gpus: int = 1 13 | #: epoch 도는 횟수 14 | num_epochs: int = 5 15 | #: 훈련 시의 batch size 16 | batch_size: int = 64 17 | #: learning rate 18 | learning_rate: float = 5e-5 19 | #: warm up 20 | warmup_ratio: float = 0.0 21 | #: num workers 22 | num_workers: int = 20 23 | #: max seq len 24 | max_seq_len: int = 48 25 | 26 | train_log_interval: int = 100 27 | dev_log_interval: int = 1000 28 | save_interval: int = 1000 29 | """ 30 | Data Hyperparameters 31 | """ 32 | #: training data 파일 경로 33 | train_file_path: str = "./data/train_example.txt" 34 | #: dev data 파일 경로 35 | dev_file_path: str = "./data/dev_example.txt" 36 | pretrained_model_name: str = "hyunwoongko/kobart" 37 | #: 모델이 저장될 경로 38 | save_model_file_prefix: str = "./checkpoints/chosung_translator" 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 초성 해석기 2 | ## 개요 3 | 한국어 초성만으로 이루어진 문장을 입력하면, 완성된 문장을 예측하는 초성 해석기입니다. 4 | ```text 5 | 초성: ㄴㄴ ㄴㄹ ㅈㅇㅎ 6 | 예측 문장: 나는 너를 좋아해 7 | ``` 8 | ## 모델 9 | 모델은 SKT-AI에서 공개한 [Ko-BART](https://github.com/SKT-AI/KoBART)를 이용합니다. 10 | ## 데이터 11 | 문장 단위로 이루어진 아무 코퍼스나 사용가능합니다. 단, 모델의 추론 성능은 데이터의 도메인이나 데이터의 양에 크게 의존하기 때문에 원하는 모델 성능에 맞는 코퍼스를 사용해주세요. 12 | `./data` 디렉토리에 더미 데이터셋을 추가해두었으니, 더미 데이터셋과 동일한 형식의 코퍼스를 준비해두시면 됩니다. 13 | ## 학습 14 | 15 | ```sh 16 | python run_train.py 17 | ``` 18 | 19 | ## 추론 20 | ```sh 21 | python run_inference.py --finetuned-model-path $FINETUNED_MODEL_PATH 22 | ``` 23 | ## 예시 24 | - 공개된 코퍼스로 학습한 모델의 추론 결과입니다. 25 | ```text 26 | 초성: ㅂㄱㅍㄷ 예측 문장: 배고픈데 27 | 초성: ㅂㄱㅍㄷ 예측 문장: 배고프다 28 | 초성: ㅂㄱㅍㄷ 예측 문장: 배고프대 29 | 30 | 초성: ㄴㅁㄴㅁ ㅅㄹㅎㅇ 예측 문장: 너무너무 사랑해요 31 | 초성: ㄴㅁㄴㅁ ㅅㄹㅎㅇ 예측 문장: 너무너무 사랑했어 32 | 초성: ㄴㅁㄴㅁ ㅅㄹㅎㅇ 예측 문장: 나만너무 사랑해요 33 | 34 | 초성: ㄴㄴ ㄴㄹ ㅈㅇㅎ 예측 문장: 나는 너를 좋아해 35 | 초성: ㄴㄴ ㄴㄹ ㅈㅇㅎ 예측 문장: 누나 나랑 좋아해 36 | 초성: ㄴㄴ ㄴㄹ ㅈㅇㅎ 예측 문장: 너는 나를 좋아해 37 | ``` 38 | 39 | ## Notes 40 | - 본 레포는 별도의 학습 데이터를 포함하고 있지 않습니다. 41 | - 본 레포의 라이센스는 [Ko-BART](https://github.com/SKT-AI/KoBART)의 `modified-MIT` 라이센스를 따릅니다. 42 | 43 | ## Todo 44 | - 테스트 코드 추가 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # editor specific 107 | .vscode/settings.json 108 | .vscode/sftp.json -------------------------------------------------------------------------------- /chosung_translator/data.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from transformers import PreTrainedTokenizerFast 6 | 7 | from chosung_translator.utils import convert_text_to_chosung 8 | 9 | 10 | class ChosungTranslatorDataset(Dataset): 11 | def __init__(self, texts: List[str], tokenizer: PreTrainedTokenizerFast, max_seq_len: int = 48): 12 | self.texts = texts 13 | self.max_seq_len = max_seq_len 14 | self.tokenizer = tokenizer 15 | 16 | def __getitem__(self, index: int) -> Tuple[torch.Tensor, ...]: 17 | chosung_text = convert_text_to_chosung(self.texts[index]) 18 | 19 | tokenized_chosung_text = self.tokenizer.tokenize(chosung_text) 20 | encoder_input_ids = self.tokenizer.convert_tokens_to_ids(tokenized_chosung_text) 21 | encoder_input_ids = encoder_input_ids[: self.max_seq_len - 1] 22 | 23 | tokenized_text = self.tokenizer.tokenize(self.texts[index]) 24 | token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text) 25 | token_ids = token_ids[: self.max_seq_len - 1] 26 | 27 | decoder_input_ids = [self.tokenizer.bos_token_id] + token_ids 28 | decoder_output_ids = token_ids + [self.tokenizer.eos_token_id] 29 | 30 | padded_encoder_input_ids = torch.tensor( 31 | encoder_input_ids + [self.tokenizer.pad_token_id] * (self.max_seq_len - len(encoder_input_ids)), 32 | dtype=torch.long, 33 | ) 34 | padded_decoder_input_ids = torch.tensor( 35 | decoder_input_ids + [self.tokenizer.pad_token_id] * (self.max_seq_len - len(decoder_input_ids)), 36 | dtype=torch.long, 37 | ) 38 | padded_decoder_output_ids = torch.tensor( 39 | decoder_output_ids + [-100] * (self.max_seq_len - len(decoder_output_ids)), 40 | dtype=torch.long, 41 | ) 42 | encoder_attention_mask = torch.tensor( 43 | [1] * len(encoder_input_ids) + [0] * (self.max_seq_len - len(encoder_input_ids)), dtype=torch.long 44 | ) 45 | decoder_attention_mask = torch.tensor( 46 | [1] * len(decoder_input_ids) + [0] * (self.max_seq_len - len(decoder_input_ids)), dtype=torch.long 47 | ) 48 | return tuple( 49 | ( 50 | padded_encoder_input_ids, 51 | encoder_attention_mask, 52 | padded_decoder_input_ids, 53 | padded_decoder_output_ids, 54 | decoder_attention_mask, 55 | ) 56 | ) 57 | 58 | def __len__(self): 59 | return len(self.texts) 60 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast 5 | 6 | from chosung_translator.utils import convert_text_to_chosung 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--finetuned-model-path", type=str, help="Path to fine-tuned model", required=True) 10 | parser.add_argument( 11 | "--decoding-method", 12 | default="beam_search", 13 | type=str, 14 | help="Decoding method (beam_search or top_p)", 15 | ) 16 | 17 | 18 | def main(args): 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/kobart") 21 | 22 | model = BartForConditionalGeneration.from_pretrained(args.finetuned_model_path) 23 | model.eval() 24 | model.to(device) 25 | 26 | examples = ["배고프다", "너무너무 사랑해요", "나는 너를 좋아해", "저의 취미는 축구입니다", "어제 무슨 영화 봤어?", "짜장면 짬뽕 탕수육 먹었어"] 27 | 28 | for example in examples: 29 | chosung_example = convert_text_to_chosung(example) 30 | 31 | input_ids = ( 32 | torch.tensor(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(chosung_example))).unsqueeze(0).to(device) 33 | ) 34 | 35 | if args.decoding_method == "top_p": 36 | outputs = model.generate( 37 | input_ids=input_ids, 38 | max_length=48, 39 | temperature=1.0, 40 | do_sample=True, 41 | top_p=0.8, 42 | pad_token_id=tokenizer.pad_token_id, 43 | bos_token_id=tokenizer.bos_token_id, 44 | eos_token_id=tokenizer.eos_token_id, 45 | decoder_start_token_id=tokenizer.bos_token_id, 46 | num_return_sequences=5, 47 | ) 48 | elif args.decoding_method == "beam_search": 49 | outputs = model.generate( 50 | input_ids=input_ids, 51 | max_length=48, 52 | num_beams=10, 53 | pad_token_id=tokenizer.pad_token_id, 54 | bos_token_id=tokenizer.bos_token_id, 55 | eos_token_id=tokenizer.eos_token_id, 56 | decoder_start_token_id=tokenizer.bos_token_id, 57 | num_return_sequences=5, 58 | ) 59 | else: 60 | raise ValueError("Enter the right decoding method (top_p or beam_search)") 61 | 62 | for output in outputs.tolist(): 63 | answer = tokenizer.decode(output) 64 | print(f"초성: {chosung_example} \t 예측 문장: {answer}") 65 | 66 | 67 | if __name__ == "__main__": 68 | args = parser.parse_args() 69 | main(args) 70 | -------------------------------------------------------------------------------- /run_train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import sys 4 | from typing import Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.optim.adam import Adam 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast 12 | 13 | from chosung_translator.config import TrainConfig 14 | from chosung_translator.data import ChosungTranslatorDataset 15 | from chosung_translator.utils import load_data 16 | 17 | 18 | def train( 19 | config: TrainConfig, 20 | model: BartForConditionalGeneration, 21 | train_dataloader: DataLoader, 22 | dev_dataloader: DataLoader, 23 | optimizer: Adam, 24 | logger: logging.Logger, 25 | device=torch.device, 26 | ): 27 | """ 지정된 Epoch만큼 모델을 학습시키는 함수입니다. """ 28 | model.to(device) 29 | global_step = 0 30 | for epoch in range(1, config.num_epochs + 1): 31 | model.train() 32 | loss_sum = 0.0 33 | for data in train_dataloader: 34 | global_step += 1 35 | data = _change_device(data, device) 36 | optimizer.zero_grad() 37 | output = model.forward( 38 | input_ids=data[0], 39 | attention_mask=data[1], 40 | decoder_input_ids=data[2], 41 | labels=data[3], 42 | decoder_attention_mask=data[4], 43 | return_dict=True, 44 | ) 45 | loss = output["loss"] 46 | loss.backward() 47 | loss_sum += loss.item() 48 | 49 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 50 | optimizer.step() 51 | 52 | if global_step % config.train_log_interval == 0: 53 | mean_loss = loss_sum / config.train_log_interval 54 | logger.info( 55 | f"Epoch {epoch} Step {global_step} " f"Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}" 56 | ) 57 | loss_sum = 0.0 58 | if global_step % config.dev_log_interval == 0: 59 | _validate(model, dev_dataloader, logger, device) 60 | if global_step % config.save_interval == 0: 61 | model.save_pretrained(f"{config.save_model_file_prefix}_{global_step}") 62 | 63 | 64 | def _validate( 65 | model: BartForConditionalGeneration, 66 | dev_dataloader: DataLoader, 67 | logger: logging.Logger, 68 | device: torch.device, 69 | ): 70 | model.eval() 71 | loss_sum = 0.0 72 | with torch.no_grad(): 73 | for data in tqdm(dev_dataloader): 74 | data = _change_device(data, device) 75 | output = model.forward( 76 | input_ids=data[0], 77 | attention_mask=data[1], 78 | decoder_input_ids=data[2], 79 | labels=data[3], 80 | decoder_attention_mask=data[4], 81 | return_dict=True, 82 | ) 83 | loss = output["loss"] 84 | loss_sum += loss.item() 85 | mean_loss = loss_sum / len(dev_dataloader) 86 | logger.info(f"[Validation] Loss {mean_loss:.4f} Perplexity {math.exp(mean_loss):8.2f}") 87 | model.train() 88 | 89 | 90 | def _change_device(data: Tuple[torch.Tensor, ...], device: torch.device): 91 | return tuple((data[0].to(device), data[1].to(device), data[2].to(device), data[3].to(device), data[4].to(device))) 92 | 93 | 94 | def main(): 95 | # Config 96 | config = TrainConfig() 97 | 98 | # Logger 99 | logger = logging.getLogger() 100 | logger.setLevel(logging.INFO) 101 | handler = logging.StreamHandler(sys.stdout) 102 | formatter = logging.Formatter("[%(asctime)s] %(message)s") 103 | handler.setFormatter(formatter) 104 | logger.addHandler(handler) 105 | 106 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 107 | # Data Loading... 108 | raw_train_instances = load_data(config.train_file_path) 109 | raw_dev_instances = load_data(config.dev_file_path) 110 | logger.info(f"훈련용 예시 개수:{len(raw_train_instances)}\t 검증용 예시 개수:{len(raw_dev_instances)}") 111 | 112 | tokenizer = PreTrainedTokenizerFast.from_pretrained(config.pretrained_model_name) 113 | 114 | train_dataset = ChosungTranslatorDataset(raw_train_instances, tokenizer, config.max_seq_len) 115 | dev_dataset = ChosungTranslatorDataset(raw_dev_instances, tokenizer, config.max_seq_len) 116 | 117 | train_dataloader = DataLoader( 118 | train_dataset, 119 | batch_size=config.batch_size, 120 | shuffle=True, 121 | num_workers=config.num_workers, 122 | ) 123 | dev_dataloader = DataLoader( 124 | dev_dataset, 125 | batch_size=config.batch_size, 126 | num_workers=config.num_workers, 127 | ) 128 | 129 | model = BartForConditionalGeneration.from_pretrained(config.pretrained_model_name) 130 | 131 | # Train 132 | optimizer = Adam(model.parameters(), lr=config.learning_rate) 133 | train(config, model, train_dataloader, dev_dataloader, optimizer, logger, device) 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | --------------------------------------------------------------------------------