├── ratsnlp
│   ├── __init__.py
│   └── nlpbook
│       ├── __init__.py
│       ├── paircls
│       │   ├── __init__.py
│       │   ├── corpus.py
│       │   ├── deploy.py
│       │   └── index.html
│       ├── qa
│       │   ├── __init__.py
│       │   ├── arguments.py
│       │   ├── corpus.py
│       │   ├── deploy.py
│       │   ├── task.py
│       │   └── index.html
│       ├── ner
│       │   ├── __init__.py
│       │   ├── arguments.py
│       │   ├── corpus.py
│       │   ├── deploy.py
│       │   ├── task.py
│       │   └── index.html
│       ├── generation
│       │   ├── __init__.py
│       │   ├── arguments.py
│       │   ├── corpus.py
│       │   ├── deploy.py
│       │   ├── task.py
│       │   └── index.html
│       ├── classification
│       │   ├── __init__.py
│       │   ├── arguments.py
│       │   ├── corpus.py
│       │   ├── deploy.py
│       │   ├── task.py
│       │   └── index.html
│       ├── metrics.py
│       ├── trainer.py
│       ├── data_utils.py
│       └── utils.py
├── .github
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── ISSUE_TEMPLATE
│   │   ├── feature.md
│   │   └── bug-report.md
│   └── workflows
│       └── lint.yaml
├── apply.sh
├── dist.sh
├── requirements.txt
├── .gitmessage.txt
├── .gitignore
├── README.md
├── LICENSE
└── setup.py
/ratsnlp/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/apply.sh:
--------------------------------------------------------------------------------
1 | git config --local commit.template .gitmessage.txt
2 |
--------------------------------------------------------------------------------
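
Usage: run once after cloning so that `git commit` opens pre-filled with the shared template:

    $ bash apply.sh        # sets commit.template to .gitmessage.txt for this clone
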
/README.md:
--------------------------------------------------------------------------------
1 | # ratsnlp
2 |
3 | A package for hands-on natural language processing exercises. It is written to run in the Google Colab environment.
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/paircls/__init__.py:
--------------------------------------------------------------------------------
1 | from .corpus import *
2 | from .deploy import get_web_service_app
3 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .trainer import *
3 | from .data_utils import *
4 |
--------------------------------------------------------------------------------
/dist.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | VERSION=$1
4 |
5 | python setup.py sdist bdist_wheel
6 | python -m twine upload dist/*$VERSION*
--------------------------------------------------------------------------------
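
The script expects the release version as its first argument; a usage sketch (assuming twine credentials are already configured):

    $ bash dist.sh 1.0.53   # builds sdist/wheel, then uploads only the dist/ files matching that version
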
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.6.1
2 | transformers==4.28.1
3 | Korpora>=0.2.0
4 | flask>=1.1.4
5 | flask_ngrok>=0.0.25
6 | flask_cors>=3.0.10
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/qa/__init__.py:
--------------------------------------------------------------------------------
1 | from .arguments import QATrainArguments, QADeployArguments
2 | from .task import QATask
3 | from .corpus import *
4 | from .deploy import get_web_service_app
5 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/ner/__init__.py:
--------------------------------------------------------------------------------
1 | from .arguments import NERTrainArguments, NERDeployArguments
2 | from .corpus import *
3 | from .task import NERTask
4 | from .deploy import get_web_service_app
5 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/generation/__init__.py:
--------------------------------------------------------------------------------
1 | from .corpus import *
2 | from .task import GenerationTask
3 | from .arguments import GenerationTrainArguments, GenerationDeployArguments
4 | from .deploy import get_web_service_app
5 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/classification/__init__.py:
--------------------------------------------------------------------------------
1 | from .task import ClassificationTask
2 | from .deploy import get_web_service_app
3 | from .corpus import *
4 | from .arguments import ClassificationTrainArguments, ClassificationDeployArguments
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Pull Request
2 | Thank you for contributing to this repository.
3 |
4 | Before submitting this PR, please make sure the following items are complete:
5 | - [ ] Does your code build without any errors or warnings?
6 | - [ ] Have you tested your changes sufficiently?
7 |
8 | ## 1. What does this PR do?
9 |
10 |
11 | ## 2. Is there an issue related to this PR?
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature
3 | about: Describe a feature you would like to see developed.
4 | title: "[FEATURE] "
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## 🚀 Feature
11 |
12 |
13 | ## Motivation
14 |
15 |
16 | ## Pitch
17 |
18 |
19 | ## Alternatives
20 |
21 |
22 | ## Additional context
23 |
24 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Report a bug.
4 | title: "[BUG] "
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | ## 🐛 Bug
11 |
12 |
13 | ## To Reproduce
14 |
15 | Please describe the steps needed to reproduce the bug.
16 |
17 | 1.
18 | 2.
19 | 3.
20 |
21 | ## Expected behavior
22 |
23 |
24 |
25 | ## Environment
26 |
27 | Please fill out the checklist below so we can identify the environment you ran in.
28 |
29 | - OS (e.g., Linux):
30 | - GPU model:
31 | - Any other relevant information:
32 |
33 | ## Additional context
34 |
35 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def accuracy(preds, labels, ignore_index=None):
5 | with torch.no_grad():
6 | assert preds.shape[0] == len(labels)
7 | correct = torch.sum(preds == labels)
8 | total = torch.sum(torch.ones_like(labels))
9 | if ignore_index is not None:
10 | # exclude from the correct count the predictions that merely matched the ignore index
11 | correct -= torch.sum(torch.logical_and(preds == ignore_index, preds == labels))
12 | # exclude ignore-index positions from the accuracy denominator
13 | total -= torch.sum(labels == ignore_index)
14 | return correct.to(dtype=torch.float) / total.to(dtype=torch.float)
15 |
--------------------------------------------------------------------------------
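
A minimal sanity check of `accuracy` (the tensors below are made up for illustration):

    import torch
    from ratsnlp.nlpbook.metrics import accuracy

    preds = torch.tensor([1, 0, 2, 0])
    labels = torch.tensor([1, 0, 1, 0])
    print(accuracy(preds, labels))                  # tensor(0.7500): 3 of 4 positions match
    # treating label 0 as padding: 2 positions remain, of which 1 is correct
    print(accuracy(preds, labels, ignore_index=0))  # tensor(0.5000)
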
/ratsnlp/nlpbook/ner/deploy.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request, jsonify, render_template
2 |
3 |
4 | def get_web_service_app(inference_fn, is_colab=True):
5 |
6 | app = Flask(__name__, template_folder='')
7 | if is_colab:
8 | from flask_ngrok import run_with_ngrok
9 | run_with_ngrok(app)
10 | else:
11 | from flask_cors import CORS
12 | CORS(app)
13 |
14 | @app.route('/')
15 | def index():
16 | return render_template('index.html')
17 |
18 | @app.route('/api', methods=['POST'])
19 | def api():
20 | query_sentence = request.json
21 | output_data = inference_fn(query_sentence)
22 | response = jsonify(output_data)
23 | return response
24 |
25 | return app
26 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/classification/deploy.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request, jsonify, render_template
2 |
3 |
4 | def get_web_service_app(inference_fn, is_colab=True):
5 |
6 | app = Flask(__name__, template_folder='')
7 | if is_colab:
8 | from flask_ngrok import run_with_ngrok
9 | run_with_ngrok(app)
10 | else:
11 | from flask_cors import CORS
12 | CORS(app)
13 |
14 | @app.route('/')
15 | def index():
16 | return render_template('index.html')
17 |
18 | @app.route('/api', methods=['POST'])
19 | def api():
20 | query_sentence = request.json
21 | output_data = inference_fn(query_sentence)
22 | response = jsonify(output_data)
23 | return response
24 |
25 | return app
26 |
--------------------------------------------------------------------------------
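
All five deploy modules share this factory shape; a minimal sketch of the intended wiring (the `inference_fn` below is a hypothetical stub, not part of the package; a real one tokenizes the input and runs the fine-tuned model):

    from ratsnlp.nlpbook.classification import get_web_service_app

    def inference_fn(sentence):
        # stub: return any JSON-serializable dict built from the model's output
        return {"sentence": sentence, "prediction": "긍정"}

    app = get_web_service_app(inference_fn, is_colab=False)  # CORS mode, no ngrok
    app.run(host="0.0.0.0", port=5000)
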
/ratsnlp/nlpbook/qa/deploy.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request, jsonify, render_template
2 |
3 |
4 | def get_web_service_app(inference_fn, is_colab=True):
5 |
6 | app = Flask(__name__, template_folder='')
7 | if is_colab:
8 | from flask_ngrok import run_with_ngrok
9 | run_with_ngrok(app)
10 | else:
11 | from flask_cors import CORS
12 | CORS(app)
13 |
14 | @app.route('/')
15 | def index():
16 | return render_template('index.html')
17 |
18 | @app.route('/api', methods=['POST'])
19 | def api():
20 | query = request.json
21 | output_data = inference_fn(query["question"], query["context"])
22 | response = jsonify(output_data)
23 | return response
24 |
25 | return app
26 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/paircls/deploy.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request, jsonify, render_template
2 |
3 |
4 | def get_web_service_app(inference_fn, is_colab=True):
5 |
6 | app = Flask(__name__, template_folder='')
7 | if is_colab:
8 | from flask_ngrok import run_with_ngrok
9 | run_with_ngrok(app)
10 | else:
11 | from flask_cors import CORS
12 | CORS(app)
13 |
14 | @app.route('/')
15 | def index():
16 | return render_template('index.html')
17 |
18 | @app.route('/api', methods=['POST'])
19 | def api():
20 | query = request.json
21 | output_data = inference_fn(query["premise"], query["hypothesis"])
22 | response = jsonify(output_data)
23 | return response
24 |
25 | return app
26 |
--------------------------------------------------------------------------------
/.gitmessage.txt:
--------------------------------------------------------------------------------
1 |
2 | <type>: <subject>
3 |
4 | <footer>
5 |
6 | ##### Subject: 50 characters max ################ -> |
7 |
8 |
9 | # Write the body above
10 | ######## Body: wrap each line at 72 characters ########################## -> |
11 |
12 | # Write the footer below, e.g., Refs: #<issue number>
13 |
14 | # --- COMMIT END ---
15 | # <type> list
16 | #   feat    : a new feature
17 | #   fix     : a bug fix
18 | #   refactor: refactoring
19 | #   style   : formatting, semicolons, etc. (no change to business logic)
20 | #   docs    : documentation (added, updated, or removed docs)
21 | #   test    : tests (added, updated, or removed test code; no change to business logic)
22 | #   chore   : everything else (build script changes, etc.)
23 | # ------------------
24 | # Capitalize the first letter of the subject
25 | # Write the subject in the imperative mood
26 | # Do not end the subject with a period (.)
27 | # Separate the subject from the body with a blank line
28 | # The body should explain "what" and "why", not "how"
29 | # Separate multiple body lines with "-"
30 | # ------------------
31 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/generation/deploy.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, request, jsonify, render_template
2 |
3 |
4 | def get_web_service_app(inference_fn, is_colab=True):
5 |
6 | app = Flask(__name__, template_folder='')
7 | if is_colab:
8 | from flask_ngrok import run_with_ngrok
9 | run_with_ngrok(app)
10 | else:
11 | from flask_cors import CORS
12 | CORS(app)
13 |
14 | @app.route('/')
15 | def index():
16 | return render_template('index.html')
17 |
18 | @app.route('/api', methods=['POST'])
19 | def api():
20 | query = request.json
21 | output_data = inference_fn(
22 | query["prompt"],
23 | query["min_length"],
24 | query["max_length"],
25 | query["top_p"],
26 | query["top_k"],
27 | query["repetition_penalty"],
28 | query["no_repeat_ngram_size"],
29 | query["temperature"],
30 | )
31 | response = jsonify(output_data)
32 | return response
33 |
34 | return app
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 gichang.lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import setuptools
3 |
4 |
5 | def requirements():
6 | with open(os.path.join(os.path.dirname(__file__), 'requirements.txt'), encoding='utf-8') as f:
7 | return f.read().splitlines()
8 |
9 |
10 | setuptools.setup(
11 | name="ratsnlp",
12 | version="1.0.53",
13 | license='MIT',
14 | author="ratsgo",
15 | author_email="ratsgo@naver.com",
16 | description="tools for Natural Language Processing",
17 | long_description=open('README.md', encoding='utf-8').read(),
18 | url="https://github.com/ratsgo/ratsnlp",
19 | packages=setuptools.find_packages(),
20 | include_package_data=True,
21 | package_data={
22 | 'ratsnlp.nlpbook.classification': ['*.html'],
23 | 'ratsnlp.nlpbook.ner': ['*.html'],
24 | 'ratsnlp.nlpbook.qa': ['*.html'],
25 | 'ratsnlp.nlpbook.paircls': ['*.html'],
26 | 'ratsnlp.nlpbook.generation': ['*.html'],
27 | },
28 | install_requires=requirements(),
29 | classifiers=[
30 | "Programming Language :: Python :: 3.7",
31 | "License :: OSI Approved :: MIT License",
32 | "Operating System :: OS Independent"
33 | ],
34 | )
--------------------------------------------------------------------------------
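
Because `include_package_data`/`package_data` bundle each task's index.html into the distribution, a plain install is enough to serve the web demos:

    $ pip install ratsnlp==1.0.53
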
/ratsnlp/nlpbook/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from pytorch_lightning import Trainer
4 | from pytorch_lightning.callbacks import ModelCheckpoint
5 |
6 |
7 | def get_trainer(args, return_trainer_only=True):
8 | ckpt_path = os.path.abspath(args.downstream_model_dir)
9 | os.makedirs(ckpt_path, exist_ok=True)
10 | checkpoint_callback = ModelCheckpoint(
11 | dirpath=ckpt_path,
12 | save_top_k=args.save_top_k,
13 | monitor=args.monitor.split()[1],
14 | mode=args.monitor.split()[0],
15 | filename='{epoch}-{val_loss:.2f}',
16 | )
17 | trainer = Trainer(
18 | max_epochs=args.epochs,
19 | fast_dev_run=args.test_mode,
20 | num_sanity_val_steps=None if args.test_mode else 0,
21 | callbacks=[checkpoint_callback],
22 | default_root_dir=ckpt_path,
23 | # For GPU Setup
24 | deterministic=torch.cuda.is_available() and args.seed is not None,
25 | gpus=torch.cuda.device_count() if torch.cuda.is_available() else None,
26 | precision=16 if args.fp16 else 32,
27 | # For TPU Setup
28 | tpu_cores=args.tpu_cores if args.tpu_cores else None,
29 | )
30 | if return_trainer_only:
31 | return trainer
32 | else:
33 | return checkpoint_callback, trainer
34 |
--------------------------------------------------------------------------------
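
A sketch of the intended call pattern for `get_trainer` (the dataloaders and task are elided; `downstream_model_dir` is a hypothetical path):

    from ratsnlp import nlpbook
    from ratsnlp.nlpbook.classification import ClassificationTrainArguments

    args = ClassificationTrainArguments(
        downstream_model_dir="/content/checkpoint-doccls",
        monitor="min val_loss",  # "<mode> <metric>", split apart by get_trainer
        epochs=1,
    )
    trainer = nlpbook.get_trainer(args)
    # trainer.fit(task, train_dataloaders=..., val_dataloaders=...)
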
/ratsnlp/nlpbook/generation/task.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedModel
2 | from transformers.optimization import AdamW
3 | from pytorch_lightning import LightningModule
4 | from ratsnlp.nlpbook.generation.arguments import GenerationTrainArguments
5 | from torch.optim.lr_scheduler import ExponentialLR
6 |
7 |
8 | class GenerationTask(LightningModule):
9 |
10 | def __init__(self,
11 | model: PreTrainedModel,
12 | args: GenerationTrainArguments,
13 | ):
14 | super().__init__()
15 | self.model = model
16 | self.args = args
17 |
18 | def configure_optimizers(self):
19 | optimizer = AdamW(self.parameters(), lr=self.args.learning_rate)
20 | scheduler = ExponentialLR(optimizer, gamma=0.9)
21 | return {
22 | 'optimizer': optimizer,
23 | 'lr_scheduler': scheduler,  # Lightning reads the 'lr_scheduler' key; an unrecognized 'scheduler' key is ignored
24 | }
25 |
26 | def training_step(self, inputs, batch_idx):
27 | # outputs: CausalLMOutputWithCrossAttentions
28 | outputs = self.model(**inputs)
29 | self.log("loss", outputs.loss, prog_bar=False, logger=True, on_step=True, on_epoch=False)
30 | return outputs.loss
31 |
32 | def validation_step(self, inputs, batch_idx):
33 | # outputs: CausalLMOutputWithCrossAttentions
34 | outputs = self.model(**inputs)
35 | self.log("val_loss", outputs.loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
36 | return outputs.loss
37 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/classification/task.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedModel
2 | from transformers.optimization import AdamW
3 | from ratsnlp.nlpbook.metrics import accuracy
4 | from pytorch_lightning import LightningModule
5 | from torch.optim.lr_scheduler import ExponentialLR
6 | from ratsnlp.nlpbook.classification.arguments import ClassificationTrainArguments
7 |
8 |
9 | class ClassificationTask(LightningModule):
10 |
11 | def __init__(self,
12 | model: PreTrainedModel,
13 | args: ClassificationTrainArguments,
14 | ):
15 | super().__init__()
16 | self.model = model
17 | self.args = args
18 |
19 | def configure_optimizers(self):
20 | optimizer = AdamW(self.parameters(), lr=self.args.learning_rate)
21 | scheduler = ExponentialLR(optimizer, gamma=0.9)
22 | return {
23 | 'optimizer': optimizer,
24 | 'lr_scheduler': scheduler,  # Lightning reads the 'lr_scheduler' key; an unrecognized 'scheduler' key is ignored
25 | }
26 |
27 | def training_step(self, inputs, batch_idx):
28 | # outputs: SequenceClassifierOutput
29 | outputs = self.model(**inputs)
30 | preds = outputs.logits.argmax(dim=-1)
31 | labels = inputs["labels"]
32 | acc = accuracy(preds, labels)
33 | self.log("loss", outputs.loss, prog_bar=False, logger=True, on_step=True, on_epoch=False)
34 | self.log("acc", acc, prog_bar=True, logger=True, on_step=True, on_epoch=False)
35 | return outputs.loss
36 |
37 | def validation_step(self, inputs, batch_idx):
38 | # outputs: SequenceClassifierOutput
39 | outputs = self.model(**inputs)
40 | preds = outputs.logits.argmax(dim=-1)
41 | labels = inputs["labels"]
42 | acc = accuracy(preds, labels)
43 | self.log("val_loss", outputs.loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
44 | self.log("val_acc", acc, prog_bar=True, logger=True, on_step=False, on_epoch=True)
45 | return outputs.loss
46 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/ner/task.py:
--------------------------------------------------------------------------------
1 | from transformers.optimization import AdamW
2 | from ratsnlp.nlpbook.metrics import accuracy
3 | from pytorch_lightning import LightningModule
4 | from transformers import BertPreTrainedModel
5 | from ratsnlp.nlpbook.ner import NERTrainArguments, NER_PAD_ID
6 | from torch.optim.lr_scheduler import ExponentialLR
7 |
8 |
9 | class NERTask(LightningModule):
10 |
11 | def __init__(self,
12 | model: BertPreTrainedModel,
13 | args: NERTrainArguments,
14 | ):
15 | super().__init__()
16 | self.model = model
17 | self.args = args
18 |
19 | def configure_optimizers(self):
20 | optimizer = AdamW(self.parameters(), lr=self.args.learning_rate)
21 | scheduler = ExponentialLR(optimizer, gamma=0.9)
22 | return {
23 | 'optimizer': optimizer,
24 | 'lr_scheduler': scheduler,  # Lightning reads the 'lr_scheduler' key; an unrecognized 'scheduler' key is ignored
25 | }
26 |
27 | def training_step(self, inputs, batch_idx):
28 | # outputs: TokenClassifierOutput
29 | outputs = self.model(**inputs)
30 | preds = outputs.logits.argmax(dim=-1)
31 | labels = inputs["labels"]
32 | acc = accuracy(preds, labels, ignore_index=NER_PAD_ID)
33 | self.log("loss", outputs.loss, prog_bar=False, logger=True, on_step=True, on_epoch=False)
34 | self.log("acc", acc, prog_bar=True, logger=True, on_step=True, on_epoch=False)
35 | return outputs.loss
36 |
37 | def validation_step(self, inputs, batch_idx):
38 | # outputs: TokenClassifierOutput
39 | outputs = self.model(**inputs)
40 | preds = outputs.logits.argmax(dim=-1)
41 | labels = inputs["labels"]
42 | acc = accuracy(preds, labels, ignore_index=NER_PAD_ID)
43 | self.log("val_loss", outputs.loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
44 | self.log("val_acc", acc, prog_bar=True, logger=True, on_step=False, on_epoch=True)
45 | return outputs.loss
46 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | # This is a basic workflow to help you get started with Actions
2 |
3 | name: Lint
4 |
5 | # Controls when the action will run. Triggers the workflow on push or pull request
6 | # events but only for the master branch
7 | on:
8 | pull_request:
9 | branches:
10 | - master
11 |
12 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
13 | jobs:
14 | black-and-isort:
15 | # The type of runner that the job will run on
16 | runs-on: ubuntu-latest
17 | strategy:
18 | matrix:
19 | python-version:
20 | - 3.x
21 | timeout-minutes: 60
22 |
23 | # Steps represent a sequence of tasks that will be executed as part of the job
24 | steps:
25 | # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it
26 | - uses: actions/checkout@v2
27 | with:
28 | ref: ${{ github.event.pull_request.head.sha }}
29 |
30 | - name: Prerequisites
31 | run: |
32 | pip3 install black --no-cache --user
33 | pip3 install isort --no-cache --user
34 | pip3 install -r requirements.txt --user
35 |
36 | - name: apply isort
37 | run: python3 -m isort .
38 |
39 | - name: apply black
40 | run: python3 -m black .
41 |
42 | - name: commit
43 | run: |
44 | git config --local user.email "action@github.com"
45 | git config --local user.name "GitHub Action"
46 | git add -A && git diff-index --cached --quiet HEAD || git commit -m 'style: isort/black'
47 |
48 | - name: push
49 | uses: ad-m/github-push-action@master
50 | with:
51 | github_token: ${{ secrets.GITHUB_TOKEN }}
52 | branch: ${{ github.head_ref }}
53 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/qa/task.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedModel
2 | from transformers.optimization import AdamW
3 | from ratsnlp.nlpbook.metrics import accuracy
4 | from pytorch_lightning import LightningModule
5 | from ratsnlp.nlpbook.qa import QATrainArguments
6 | from torch.optim.lr_scheduler import ExponentialLR
7 |
8 |
9 | class QATask(LightningModule):
10 |
11 | def __init__(self,
12 | model: PreTrainedModel,
13 | args: QATrainArguments,
14 | ):
15 | super().__init__()
16 | self.model = model
17 | self.args = args
18 |
19 | def configure_optimizers(self):
20 | optimizer = AdamW(self.parameters(), lr=self.args.learning_rate)
21 | scheduler = ExponentialLR(optimizer, gamma=0.9)
22 | return {
23 | 'optimizer': optimizer,
24 | 'lr_scheduler': scheduler,  # Lightning reads the 'lr_scheduler' key; an unrecognized 'scheduler' key is ignored
25 | }
26 |
27 | def training_step(self, inputs, batch_idx):
28 | # outputs: QuestionAnsweringModelOutput
29 | outputs = self.model(**inputs)
30 | start_preds = outputs.start_logits.argmax(dim=-1)
31 | start_positions = inputs["start_positions"]
32 | end_preds = outputs.end_logits.argmax(dim=-1)
33 | end_positions = inputs["end_positions"]
34 | acc = (accuracy(start_preds, start_positions) + accuracy(end_preds, end_positions)) / 2
35 | self.log("loss", outputs.loss, prog_bar=False, logger=True, on_step=True, on_epoch=False)
36 | self.log("acc", acc, prog_bar=True, logger=True, on_step=True, on_epoch=False)
37 | return outputs.loss
38 |
39 | def validation_step(self, inputs, batch_idx):
40 | # outputs: QuestionAnsweringModelOutput
41 | outputs = self.model(**inputs)
42 | start_preds = outputs.start_logits.argmax(dim=-1)
43 | start_positions = inputs["start_positions"]
44 | end_preds = outputs.end_logits.argmax(dim=-1)
45 | end_positions = inputs["end_positions"]
46 | acc = (accuracy(start_preds, start_positions) + accuracy(end_preds, end_positions)) / 2
47 | self.log("val_loss", outputs.loss, prog_bar=True, logger=True, on_step=False, on_epoch=True)
48 | self.log("val_acc", acc, prog_bar=True, logger=True, on_step=False, on_epoch=True)
49 | return outputs.loss
50 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def data_collator(features):
5 | """
6 | Very simple data collator that:
7 | - simply collates batches of dict-like objects
8 | - Performs special handling for potential keys named:
9 | - `label`: handles a single value (int or float) per object
10 | - `label_ids`: handles a list of values per object
11 | - does not do any additional preprocessing
12 |
13 | i.e., Property names of the input object will be used as corresponding inputs to the model.
14 | See glue and ner for example of how it's useful.
15 | """
16 |
17 | # In this function we'll make the assumption that all `features` in the batch
18 | # have the same attributes.
19 | # So we will look at the first element as a proxy for what attributes exist
20 | # on the whole batch.
21 | if not isinstance(features[0], dict):
22 | features = [vars(f) for f in features]
23 |
24 | first = features[0]
25 | batch = {}
26 |
27 | # Special handling for labels.
28 | # Ensure that tensor is created with the correct type
29 | # (it should be automatically the case, but let's make sure of it.)
30 | if "label" in first and first["label"] is not None:
31 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
32 | dtype = torch.long if isinstance(label, int) else torch.float
33 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
34 | elif "label_ids" in first and first["label_ids"] is not None:
35 | if isinstance(first["label_ids"], torch.Tensor):
36 | batch["labels"] = torch.stack([f["label_ids"] for f in features])
37 | else:
38 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
39 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
40 |
41 | # Handling of all other possible keys.
42 | # Again, we will use the first element to figure out which key/values are not None for this model.
43 | for k, v in first.items():
44 | if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
45 | if isinstance(v, torch.Tensor):
46 | batch[k] = torch.stack([f[k] for f in features])
47 | else:
48 | batch[k] = torch.tensor([f[k] for f in features], dtype=torch.long)
49 |
50 | return batch
51 |
--------------------------------------------------------------------------------
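
A small illustration of the collator using the package's own feature dataclass (the values are made up):

    from ratsnlp.nlpbook import data_collator
    from ratsnlp.nlpbook.classification.corpus import ClassificationFeatures

    features = [
        ClassificationFeatures(input_ids=[101, 7, 102], attention_mask=[1, 1, 1], label=1),
        ClassificationFeatures(input_ids=[101, 9, 102], attention_mask=[1, 1, 1], label=0),
    ]
    batch = data_collator(features)
    print(batch["labels"])           # tensor([1, 0]); int labels map to torch.long
    print(batch["input_ids"].shape)  # torch.Size([2, 3])
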
/.gitignore:
--------------------------------------------------------------------------------
1 | # custom
2 | cache/
3 | .idea/
4 | .DS_Store
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/classification/index.html:
--------------------------------------------------------------------------------
[Markup stripped during extraction; nothing recoverable remains. The file is the web-demo template rendered by classification/deploy.py.]
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/ner/index.html:
--------------------------------------------------------------------------------
[Markup stripped during extraction. Recoverable content of the NER demo page: title "개체명 인식 (Named Entity Recognition)"; description "Identifies named entities within a given sentence."; a results table with columns #, token, tag, probability; footer "If you want to see in detail how this web demo works, see ratsgo's nlpbook. Copyright © 2020 Gichang LEE. Distributed under a CC BY-NC-SA 3.0 license."]
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/paircls/corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | from ratsnlp.nlpbook.classification.corpus import ClassificationExample
5 |
6 |
7 | logger = logging.getLogger("ratsnlp")
8 |
9 |
10 | class KlueNLICorpus:
11 |
12 | def __init__(self):
13 | pass
14 |
15 | def _create_examples(self, data_path):
16 | examples = []
17 | data = json.load(open(data_path, "r", encoding="utf-8"))
18 | for el in data:
19 | example = ClassificationExample(
20 | text_a=el["premise"],
21 | text_b=el["hypothesis"],
22 | label=el["gold_label"],
23 | )
24 | examples.append(example)
25 | return examples
26 |
27 | def get_examples(self, data_path, mode):
28 | if mode == "train":
29 | data_fpath = os.path.join(data_path, "klue_nli_train.json")
30 | else:
31 | data_fpath = os.path.join(data_path, "klue_nli_dev.json")
32 | logger.info(f"loading {mode} data... LOOKING AT {data_fpath}")
33 | examples = self._create_examples(data_fpath)
34 | return examples
35 |
36 | def get_labels(self):
37 | return ["entailment", "contradiction", "neutral"]
38 |
39 | @property
40 | def num_labels(self):
41 | return len(self.get_labels())
42 |
43 |
44 |
45 | class KorNLICorpus:
46 |
47 | def __init__(self):
48 | pass
49 |
50 | def _create_examples(self, data_path):
51 | examples = []
52 | corpus = open(data_path, "r", encoding="utf-8").readlines()
53 | lines = [line.strip().split("\t") for line in corpus]
54 | for (i, line) in enumerate(lines):
55 | if i == 0:
56 | continue
57 | text_a, text_b, label = line
58 | examples.append(ClassificationExample(text_a=text_a, text_b=text_b, label=label))
59 | return examples
60 |
61 | def get_examples(self, data_path, mode):
62 | logger.info(f"loading {mode} data... LOOKING AT {data_path}")
63 | if mode == "train":
64 | multinli_train_data_fpath = os.path.join(data_path, "multinli.train.ko.tsv")
65 | multinli_train_data = self._create_examples(multinli_train_data_fpath)
66 | snli_train_data_fpath = os.path.join(data_path, "snli_1.0_train.ko.tsv")
67 | snli_train_data = self._create_examples(snli_train_data_fpath)
68 | examples = multinli_train_data + snli_train_data
69 | elif mode == "val":
70 | valid_data_fpath = os.path.join(data_path, "xnli.dev.ko.tsv")
71 | examples = self._create_examples(valid_data_fpath)
72 | else:
73 | test_data_fpath = os.path.join(data_path, "xnli.test.ko.tsv")
74 | examples = self._create_examples(test_data_fpath)
75 | return examples
76 |
77 | def get_labels(self):
78 | return ["entailment", "contradiction", "neutral"]
79 |
80 | @property
81 | def num_labels(self):
82 | return len(self.get_labels())
83 |
--------------------------------------------------------------------------------
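
For reference, `KlueNLICorpus._create_examples` assumes the JSON file decodes to a list of records with these keys (a hypothetical record is shown):

    # hypothetical shape of klue_nli_train.json after json.load
    data = [
        {"premise": "전제 문장", "hypothesis": "가설 문장", "gold_label": "entailment"},
    ]
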
/ratsnlp/nlpbook/qa/index.html:
--------------------------------------------------------------------------------
[Markup stripped during extraction. Recoverable content of the QA demo page: title "질의/응답 (Question Answering)"; description "Given a context and a question, finds an appropriate answer to the question within the context."; input fields for context and question; output fields for context, question, and answer; the same nlpbook / CC BY-NC-SA 3.0 footer as the other templates.]
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/paircls/index.html:
--------------------------------------------------------------------------------
[Markup stripped during extraction. Recoverable content of the NLI demo page: title "자연어 추론 (Natural Language Inference)"; description "Judges whether a hypothesis is true (entailment), false (contradiction), or neutral with respect to a premise."; input fields for premise and hypothesis; an output area scoring entailment/contradiction/neutral; the same nlpbook / CC BY-NC-SA 3.0 footer as the other templates.]
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/generation/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from dataclasses import dataclass, field
4 |
5 |
6 | @dataclass
7 | class GenerationTrainArguments:
8 |
9 | pretrained_model_name: str = field(
10 | default="kogpt2",
11 | metadata={"help": "pretrained model name"}
12 | )
13 | downstream_task_name: str = field(
14 | default="sentence-generation",
15 | metadata={"help": "The name of the downstream task."}
16 | )
17 | downstream_corpus_name: str = field(
18 | default="nsmc",
19 | metadata={"help": "The name of the downstream data."}
20 | )
21 | downstream_corpus_root_dir: str = field(
22 | default="/content/Korpora",
23 | metadata={"help": "The root directory of the downstream data."}
24 | )
25 | downstream_model_dir: str = field(
26 | default="/gdrive/My Drive/nlpbook/checkpoint-generation",
27 | metadata={"help": "The output model dir."}
28 | )
29 | max_seq_length: int = field(
30 | default=32,
31 | metadata={
32 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
33 | "than this will be truncated, sequences shorter will be padded."
34 | }
35 | )
36 | save_top_k: int = field(
37 | default=1,
38 | metadata={"help": "save top k model checkpoints."}
39 | )
40 | monitor: str = field(
41 | default="min val_loss",
42 | metadata={"help": "monitor condition (save top k)"}
43 | )
44 | seed: int = field(
45 | default=None,
46 | metadata={"help": "random seed."}
47 | )
48 | overwrite_cache: bool = field(
49 | default=False,
50 | metadata={"help": "Overwrite the cached training and evaluation sets"}
51 | )
52 | force_download: bool = field(
53 | default=False,
54 | metadata={"help": "force download of downstream data and pretrained models."}
55 | )
56 | test_mode: bool = field(
57 | default=False,
58 | metadata={"help": "Test Mode enables `fast_dev_run`"}
59 | )
60 | learning_rate: float = field(
61 | default=5e-5,
62 | metadata={"help": "learning rate"}
63 | )
64 | epochs: int = field(
65 | default=3,
66 | metadata={"help": "max epochs"}
67 | )
68 | batch_size: int = field(
69 | default=96,
70 | metadata={"help": "batch size. if 0, let PyTorch Lightning find the best batch size"}
71 | )
72 | cpu_workers: int = field(
73 | default=os.cpu_count(),
74 | metadata={"help": "number of CPU workers"}
75 | )
76 | fp16: bool = field(
77 | default=False,
78 | metadata={"help": "Enable train on FP16"}
79 | )
80 | tpu_cores: int = field(
81 | default=0,
82 | metadata={"help": "Enable TPU with 1 core or 8 cores"}
83 | )
84 |
85 |
86 | @dataclass
87 | class GenerationDeployArguments:
88 |
89 | def __init__(
90 | self,
91 | pretrained_model_name=None,
92 | downstream_model_dir=None,
93 | downstream_model_checkpoint_fpath=None,
94 | ):
95 | self.pretrained_model_name = pretrained_model_name
96 | if downstream_model_checkpoint_fpath is not None:
97 | self.downstream_model_checkpoint_fpath = downstream_model_checkpoint_fpath
98 | elif downstream_model_dir is not None:
99 | ckpt_file_names = glob(os.path.join(downstream_model_dir, "*.ckpt"))
100 | ckpt_file_names = [el for el in ckpt_file_names if "temp" not in el and "tmp" not in el]
101 | if len(ckpt_file_names) == 0:
102 | raise Exception(f"downstream_model_dir \"{downstream_model_dir}\" is not valid")
103 | selected_fname = ckpt_file_names[-1]
104 | min_val_loss = os.path.split(selected_fname)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
105 | try:
106 | for ckpt_file_name in ckpt_file_names:
107 | val_loss = os.path.split(ckpt_file_name)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
108 | if float(val_loss) < float(min_val_loss):
109 | selected_fname = ckpt_file_name
110 | min_val_loss = val_loss
111 | except:
112 | raise Exception(f"the ckpt file name of downstream_model_directory \"{downstream_model_dir}\" is not valid")
113 | self.downstream_model_checkpoint_fpath = selected_fname
114 | else:
115 | raise Exception("Either downstream_model_dir or downstream_model_checkpoint_fpath must be entered.")
116 | print(f"downstream_model_checkpoint_fpath: {self.downstream_model_checkpoint_fpath}")
117 |
--------------------------------------------------------------------------------
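
The checkpoint selection above relies on the `'{epoch}-{val_loss:.2f}'` filename pattern set in `get_trainer`; a worked example of the parsing (hypothetical filename):

    fname = "epoch=2-val_loss=0.83.ckpt"          # as written by ModelCheckpoint
    stem = fname.replace(".ckpt", "")             # 'epoch=2-val_loss=0.83'
    val_loss = stem.split("=")[-1].split("-")[0]  # '0.83'
    # float(val_loss) is compared across candidates; the smallest wins
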
/ratsnlp/nlpbook/classification/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from dataclasses import dataclass, field
4 |
5 |
6 | @dataclass
7 | class ClassificationTrainArguments:
8 |
9 | pretrained_model_name: str = field(
10 | default="beomi/kcbert-base",
11 | metadata={"help": "pretrained model name"}
12 | )
13 | downstream_task_name: str = field(
14 | default="document-classification",
15 | metadata={"help": "The name of the downstream task."}
16 | )
17 | downstream_corpus_name: str = field(
18 | default=None,
19 | metadata={"help": "The name of the downstream data."}
20 | )
21 | downstream_corpus_root_dir: str = field(
22 | default="/content/Korpora",
23 | metadata={"help": "The root directory of the downstream data."}
24 | )
25 | downstream_model_dir: str = field(
26 | default=None,
27 | metadata={"help": "The output model dir."}
28 | )
29 | max_seq_length: int = field(
30 | default=128,
31 | metadata={
32 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
33 | "than this will be truncated, sequences shorter will be padded."
34 | }
35 | )
36 | save_top_k: int = field(
37 | default=1,
38 | metadata={"help": "save top k model checkpoints."}
39 | )
40 | monitor: str = field(
41 | default="min val_loss",
42 | metadata={"help": "monitor condition (save top k)"}
43 | )
44 | seed: int = field(
45 | default=None,
46 | metadata={"help": "random seed."}
47 | )
48 | overwrite_cache: bool = field(
49 | default=False,
50 | metadata={"help": "Overwrite the cached training and evaluation sets"}
51 | )
52 | force_download: bool = field(
53 | default=False,
54 | metadata={"help": "force download of downstream data and pretrained models."}
55 | )
56 | test_mode: bool = field(
57 | default=False,
58 | metadata={"help": "Test Mode enables `fast_dev_run`"}
59 | )
60 | learning_rate: float = field(
61 | default=5e-5,
62 | metadata={"help": "learning rate"}
63 | )
64 | epochs: int = field(
65 | default=3,
66 | metadata={"help": "max epochs"}
67 | )
68 | batch_size: int = field(
69 | default=32,
70 | metadata={"help": "batch size. if 0, let PyTorch Lightning find the best batch size"}
71 | )
72 | cpu_workers: int = field(
73 | default=os.cpu_count(),
74 | metadata={"help": "number of CPU workers"}
75 | )
76 | fp16: bool = field(
77 | default=False,
78 | metadata={"help": "Enable train on FP16"}
79 | )
80 | tpu_cores: int = field(
81 | default=0,
82 | metadata={"help": "Enable TPU with 1 core or 8 cores"}
83 | )
84 |
85 |
86 | @dataclass
87 | class ClassificationDeployArguments:
88 |
89 | def __init__(
90 | self,
91 | pretrained_model_name=None,
92 | downstream_model_dir=None,
93 | downstream_model_checkpoint_fpath=None,
94 | max_seq_length=128,
95 | ):
96 | self.pretrained_model_name = pretrained_model_name
97 | self.max_seq_length = max_seq_length
98 | if downstream_model_checkpoint_fpath is not None:
99 | self.downstream_model_checkpoint_fpath = downstream_model_checkpoint_fpath
100 | elif downstream_model_dir is not None:
101 | ckpt_file_names = glob(os.path.join(downstream_model_dir, "*.ckpt"))
102 | ckpt_file_names = [el for el in ckpt_file_names if "temp" not in el and "tmp" not in el]
103 | if len(ckpt_file_names) == 0:
104 | raise Exception(f"downstream_model_dir \"{downstream_model_dir}\" is not valid")
105 | selected_fname = ckpt_file_names[-1]
106 | min_val_loss = os.path.split(selected_fname)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
107 | try:
108 | for ckpt_file_name in ckpt_file_names:
109 | val_loss = os.path.split(ckpt_file_name)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
110 | if float(val_loss) < float(min_val_loss):
111 | selected_fname = ckpt_file_name
112 | min_val_loss = val_loss
113 | except:
114 | raise Exception(f"the ckpt file name of downstream_model_directory \"{downstream_model_dir}\" is not valid")
115 | self.downstream_model_checkpoint_fpath = selected_fname
116 | else:
117 | raise Exception("Either downstream_model_dir or downstream_model_checkpoint_fpath must be entered.")
118 | print(f"downstream_model_checkpoint_fpath: {self.downstream_model_checkpoint_fpath}")
119 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/ner/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from dataclasses import dataclass, field
4 |
5 |
6 | @dataclass
7 | class NERTrainArguments:
8 |
9 | pretrained_model_name: str = field(
10 | default="beomi/kcbert-base",
11 | metadata={"help": "pretrained model name"}
12 | )
13 | downstream_task_name: str = field(
14 | default="named-entity-recognition",
15 | metadata={"help": "The name of the downstream task."}
16 | )
17 | downstream_corpus_name: str = field(
18 | default="ner",
19 | metadata={"help": "The name of the downstream data."}
20 | )
21 | downstream_corpus_root_dir: str = field(
22 | default="/content/Korpora",
23 | metadata={"help": "The root directory of the downstream data."}
24 | )
25 | downstream_model_dir: str = field(
26 | default="/gdrive/My Drive/nlpbook/checkpoint-ner",
27 | metadata={"help": "The output model dir."}
28 | )
29 | max_seq_length: int = field(
30 | default=128,
31 | metadata={
32 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
33 | "than this will be truncated, sequences shorter will be padded."
34 | }
35 | )
36 | save_top_k: int = field(
37 | default=1,
38 | metadata={"help": "save top k model checkpoints."}
39 | )
40 | monitor: str = field(
41 | default="min val_loss",
42 | metadata={"help": "monitor condition (save top k)"}
43 | )
44 | seed: int = field(
45 | default=None,
46 | metadata={"help": "random seed."}
47 | )
48 | overwrite_cache: bool = field(
49 | default=False,
50 | metadata={"help": "Overwrite the cached training and evaluation sets"}
51 | )
52 | force_download: bool = field(
53 | default=False,
54 | metadata={"help": "force download of downstream data and pretrained models."}
55 | )
56 | test_mode: bool = field(
57 | default=False,
58 | metadata={"help": "Test Mode enables `fast_dev_run`"}
59 | )
60 | learning_rate: float = field(
61 | default=5e-5,
62 | metadata={"help": "learning rate"}
63 | )
64 | epochs: int = field(
65 | default=3,
66 | metadata={"help": "max epochs"}
67 | )
68 | batch_size: int = field(
69 | default=32,
70 | metadata={"help": "batch size. if 0, let PyTorch Lightning find the best batch size"}
71 | )
72 | cpu_workers: int = field(
73 | default=os.cpu_count(),
74 | metadata={"help": "number of CPU workers"}
75 | )
76 | fp16: bool = field(
77 | default=False,
78 | metadata={"help": "Enable train on FP16"}
79 | )
80 | tpu_cores: int = field(
81 | default=0,
82 | metadata={"help": "Enable TPU with 1 core or 8 cores"}
83 | )
84 |
85 |
86 | @dataclass
87 | class NERDeployArguments:
88 |
89 | def __init__(
90 | self,
91 | pretrained_model_name=None,
92 | downstream_model_dir=None,
93 | downstream_model_checkpoint_fpath=None,
94 | downstream_model_labelmap_fpath=None,
95 | max_seq_length=128,
96 | ):
97 | self.pretrained_model_name = pretrained_model_name
98 | self.max_seq_length = max_seq_length
99 | if downstream_model_checkpoint_fpath is not None and downstream_model_labelmap_fpath is not None:
100 | self.downstream_model_checkpoint_fpath = downstream_model_checkpoint_fpath
101 | self.downstream_model_labelmap_fpath = downstream_model_labelmap_fpath
102 | elif downstream_model_dir is not None:
103 | ckpt_file_names = glob(os.path.join(downstream_model_dir, "*.ckpt"))
104 | ckpt_file_names = [el for el in ckpt_file_names if "temp" not in el and "tmp" not in el]
105 | if len(ckpt_file_names) == 0:
106 | raise Exception(f"downstream_model_dir \"{downstream_model_dir}\" is not valid")
107 | selected_fname = ckpt_file_names[-1]
108 | min_val_loss = os.path.split(selected_fname)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
109 | try:
110 | for ckpt_file_name in ckpt_file_names:
111 | val_loss = os.path.split(ckpt_file_name)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
112 | if float(val_loss) < float(min_val_loss):
113 | selected_fname = ckpt_file_name
114 | min_val_loss = val_loss
115 | except:
116 | raise Exception(f"the ckpt file name of downstream_model_directory \"{downstream_model_dir}\" is not valid")
117 | self.downstream_model_checkpoint_fpath = selected_fname
118 | self.downstream_model_labelmap_fpath = os.path.join(downstream_model_dir, "label_map.txt")
119 | else:
120 | raise Exception("Either downstream_model_dir or downstream_model_checkpoint_fpath must be entered.")
121 | print(f"downstream_model_checkpoint_fpath: {self.downstream_model_checkpoint_fpath}")
122 | print(f"downstream_model_labelmap_fpath: {self.downstream_model_labelmap_fpath}")
123 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/generation/index.html:
--------------------------------------------------------------------------------
[Markup stripped during extraction. Recoverable content of the generation demo page: title "문장 생성 (Sentence Generation)"; description "Generates a sentence continuing the prompt; the parentheses in the input fields below show default values."; parameter inputs corresponding to the /api payload read by generation/deploy.py (prompt, min_length, max_length, top_p, top_k, repetition_penalty, no_repeat_ngram_size, temperature); the same nlpbook / CC BY-NC-SA 3.0 footer as the other templates.]
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/qa/arguments.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from dataclasses import dataclass, field
4 |
5 |
6 | @dataclass
7 | class QATrainArguments:
8 |
9 | pretrained_model_name: str = field(
10 | default="beomi/kcbert-base",
11 | metadata={"help": "pretrained model name"}
12 | )
13 | downstream_corpus_name: str = field(
14 | default="korquad-v1",
15 | metadata={"help": "The name of the downstream data."}
16 | )
17 | downstream_corpus_root_dir: str = field(
18 | default="/content/Korpora",
19 | metadata={"help": "The root directory of the downstream data."}
20 | )
21 | downstream_model_dir: str = field(
22 | default="/gdrive/My Drive/nlpbook/checkpoint-qa",
23 | metadata={"help": "The output model dir."}
24 | )
25 | max_seq_length: int = field(
26 | default=128,
27 | metadata={
28 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
29 | "than this will be truncated, sequences shorter will be padded."
30 | }
31 | )
32 | doc_stride: int = field(
33 | default=64,
34 | metadata={
35 | "help": "When splitting up a long document into chunks, how much stride to take between chunks."
36 | }
37 | )
38 | max_query_length: int = field(
39 | default=32,
40 | metadata={
41 | "help": "The maximum number of tokens for the question. Questions longer than this will "
42 | "be truncated to this length."
43 | }
44 | )
45 | threads: int = field(
46 | default=4,
47 | metadata={
48 | "help": "the number of threads, using for preprocessing"
49 | }
50 | )
51 | cpu_workers: int = field(
52 | default=os.cpu_count(),
53 | metadata={"help": "number of CPU workers"}
54 | )
55 | save_top_k: int = field(
56 | default=1,
57 | metadata={"help": "save top k model checkpoints."}
58 | )
59 | monitor: str = field(
60 | default="min val_loss",
61 | metadata={"help": "monitor condition (save top k)"}
62 | )
63 | seed: int = field(
64 | default=None,
65 | metadata={"help": "random seed."}
66 | )
67 | overwrite_cache: bool = field(
68 | default=False,
69 | metadata={"help": "Overwrite the cached training and evaluation sets"}
70 | )
71 | force_download: bool = field(
72 | default=False,
73 | metadata={"help": "force download of downstream data and pretrained models."}
74 | )
75 | test_mode: bool = field(
76 | default=False,
77 | metadata={"help": "Test Mode enables `fast_dev_run`"}
78 | )
79 | learning_rate: float = field(
80 | default=5e-5,
81 | metadata={"help": "learning rate"}
82 | )
83 | epochs: int = field(
84 | default=3,
85 | metadata={"help": "max epochs"}
86 | )
87 | batch_size: int = field(
88 | default=32,
89 | metadata={"help": "batch size. if 0, let PyTorch Lightning find the best batch size"}
90 | )
91 | fp16: bool = field(
92 | default=False,
93 | metadata={"help": "Enable train on FP16"}
94 | )
95 | tpu_cores: int = field(
96 | default=0,
97 | metadata={"help": "Enable TPU with 1 core or 8 cores"}
98 | )
99 | tqdm_enabled: bool = field(
100 | default=True,
101 | metadata={"help": "whether to enable tqdm progress bars"}
102 | )
103 |
104 |
105 | @dataclass
106 | class QADeployArguments:
107 |
108 | def __init__(
109 | self,
110 | pretrained_model_name=None,
111 | downstream_model_dir=None,
112 | downstream_model_checkpoint_fpath=None,
113 | max_seq_length=128,
114 | max_query_length=32,
115 | ):
116 | self.pretrained_model_name = pretrained_model_name
117 | self.max_seq_length = max_seq_length
118 | self.max_query_length = max_query_length
119 | if downstream_model_checkpoint_fpath is not None:
120 | self.downstream_model_checkpoint_fpath = downstream_model_checkpoint_fpath
121 | elif downstream_model_dir is not None:
122 | ckpt_file_names = glob(os.path.join(downstream_model_dir, "*.ckpt"))
123 | ckpt_file_names = [el for el in ckpt_file_names if "temp" not in el and "tmp" not in el]
124 | if len(ckpt_file_names) == 0:
125 | raise Exception(f"downstream_model_dir \"{downstream_model_dir}\" is not valid")
126 | selected_fname = ckpt_file_names[-1]
127 | min_val_loss = os.path.split(selected_fname)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
128 | try:
129 | for ckpt_file_name in ckpt_file_names:
130 | val_loss = os.path.split(ckpt_file_name)[-1].replace(".ckpt", "").split("=")[-1].split("-")[0]
131 | if float(val_loss) < float(min_val_loss):
132 | selected_fname = ckpt_file_name
133 | min_val_loss = val_loss
134 | except:
135 | raise Exception(f"the ckpt file name of downstream_model_directory \"{downstream_model_dir}\" is not valid")
136 | self.downstream_model_checkpoint_fpath = selected_fname
137 | else:
138 | raise Exception("Either downstream_model_dir or downstream_model_checkpoint_fpath must be entered.")
139 | print(f"downstream_model_checkpoint_fpath: {self.downstream_model_checkpoint_fpath}")
140 |
--------------------------------------------------------------------------------
/ratsnlp/nlpbook/generation/corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import time
4 | import torch
5 | import logging
6 | from filelock import FileLock
7 | from dataclasses import dataclass
8 | from typing import List, Optional
9 | from torch.utils.data.dataset import Dataset
10 | from transformers import PreTrainedTokenizerFast
11 | from ratsnlp.nlpbook.generation.arguments import GenerationTrainArguments
12 |
13 |
14 | logger = logging.getLogger("ratsnlp")
15 |
16 |
17 | @dataclass
18 | class GenerationExample:
19 | text: str
20 |
21 |
22 | @dataclass
23 | class GenerationFeatures:
24 | input_ids: List[int]
25 | attention_mask: Optional[List[int]] = None
26 | token_type_ids: Optional[List[int]] = None
27 | labels: Optional[List[int]] = None
28 |
29 |
30 | class NsmcCorpus:
31 |
32 | def __init__(self):
33 | pass
34 |
35 | def _read_corpus(self, input_file, quotechar='"'):
36 | with open(input_file, "r", encoding="utf-8") as f:
37 | return list(csv.reader(f, delimiter="\t", quotechar=quotechar))
38 |
39 | def _create_examples(self, lines):
40 | examples = []
41 | for (i, line) in enumerate(lines):
42 | if i == 0:
43 | continue
44 | _, review_sentence, sentiment = line
45 | sentiment = "긍정" if sentiment == "1" else "부정"
46 | text = sentiment + " " + review_sentence
47 | examples.append(GenerationExample(text=text))
48 | return examples
49 |
50 | def get_examples(self, data_root_path, mode):
51 | data_fpath = os.path.join(data_root_path, f"ratings_{mode}.txt")
52 | logger.info(f"loading {mode} data... LOOKING AT {data_fpath}")
53 | return self._create_examples(self._read_corpus(data_fpath))
54 |
55 |
56 | def _convert_examples_to_generation_features(
57 | examples: List[GenerationExample],
58 | tokenizer: PreTrainedTokenizerFast,
59 | args: GenerationTrainArguments,
60 | ):
61 |
62 | logger.info(
63 | "tokenize sentences, it could take a lot of time..."
64 | )
65 | start = time.time()
66 | batch_encoding = tokenizer(
67 | [example.text for example in examples],
68 | max_length=args.max_seq_length,
69 | padding="max_length",
70 | truncation=True,
71 | )
72 | logger.info(
73 | "tokenize sentences [took %.3f s]", time.time() - start
74 | )
75 |
76 | features = []
77 | for i in range(len(examples)):
78 | inputs = {k: batch_encoding[k][i] for k in batch_encoding}
79 | feature = GenerationFeatures(**inputs, labels=batch_encoding["input_ids"][i])
80 | features.append(feature)
81 |
82 | for i, example in enumerate(examples[:5]):
83 | logger.info("*** Example ***")
84 | logger.info("sentence: %s" % (example.text))
85 | logger.info("tokens: %s" % (" ".join(tokenizer.convert_ids_to_tokens(features[i].input_ids))))
86 | logger.info("features: %s" % features[i])
87 |
88 | return features
89 |
90 |
91 | class GenerationDataset(Dataset):
92 |
93 | def __init__(
94 | self,
95 | args: GenerationTrainArguments,
96 | tokenizer: PreTrainedTokenizerFast,
97 | corpus,
98 | mode: Optional[str] = "train",
99 | convert_examples_to_features_fn=_convert_examples_to_generation_features,
100 | ):
101 | if corpus is not None:
102 | self.corpus = corpus
103 | else:
104 | raise KeyError("corpus is not valid")
105 | if mode not in ["train", "val", "test"]:
106 | raise KeyError(f"mode({mode}) is not a valid split name")
107 | # Load data features from cache or dataset file
108 | cached_features_file = os.path.join(
109 | args.downstream_corpus_root_dir,
110 | args.downstream_corpus_name,
111 | "cached_{}_{}_{}_{}_{}".format(
112 | mode,
113 | tokenizer.__class__.__name__,
114 | str(args.max_seq_length),
115 | args.downstream_corpus_name,
116 | args.downstream_task_name,
117 | ),
118 | )
119 |
120 | # Make sure only the first process in distributed training processes the dataset,
121 | # and the others will use the cache.
122 | lock_path = cached_features_file + ".lock"
123 | with FileLock(lock_path):
124 |
125 | if os.path.exists(cached_features_file) and not args.overwrite_cache:
126 | start = time.time()
127 | self.features = torch.load(cached_features_file)
128 | logger.info(
129 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
130 | )
131 | else:
132 | corpus_path = os.path.join(
133 | args.downstream_corpus_root_dir,
134 | args.downstream_corpus_name,
135 | )
136 | logger.info(f"Creating features from dataset file at {corpus_path}")
137 | examples = self.corpus.get_examples(corpus_path, mode)
138 | tokenizer.pad_token = tokenizer.eos_token
139 | self.features = convert_examples_to_features_fn(
140 | examples,
141 | tokenizer,
142 | args,
143 | )
144 | start = time.time()
145 | logger.info(
146 | "Saving features into cached file, it could take a lot of time..."
147 | )
148 | torch.save(self.features, cached_features_file)
149 | logger.info(
150 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
151 | )
152 |
153 | def __len__(self):
154 | return len(self.features)
155 |
156 | def __getitem__(self, i):
157 | return self.features[i]
158 |
159 | def get_labels(self):
160 | return self.corpus.get_labels()
161 |
--------------------------------------------------------------------------------
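A minimal usage sketch for the generation dataset above. The argument values and the checkpoint name are illustrative assumptions, not fixed by this file; the constructor signature and corpus class are taken from generation/corpus.py as shown:

    from transformers import PreTrainedTokenizerFast
    from ratsnlp.nlpbook.generation import GenerationTrainArguments
    from ratsnlp.nlpbook.generation.corpus import NsmcCorpus, GenerationDataset

    # hypothetical argument values; field names follow their use in corpus.py above
    args = GenerationTrainArguments(
        pretrained_model_name="skt/kogpt2-base-v2",  # assumed checkpoint name
        downstream_corpus_name="nsmc",
        max_seq_length=32,
    )
    tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_model_name)
    dataset = GenerationDataset(args=args, tokenizer=tokenizer, corpus=NsmcCorpus(), mode="train")
    print(dataset[0].input_ids[:10])  # first ten token ids of the first feature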
/ratsnlp/nlpbook/classification/corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 | import time
4 | import torch
5 | import logging
6 | from filelock import FileLock
7 | from dataclasses import dataclass
8 | from typing import List, Optional
9 | from torch.utils.data.dataset import Dataset
10 | from transformers import PreTrainedTokenizer
11 | from ratsnlp.nlpbook.classification.arguments import ClassificationTrainArguments
12 |
13 |
14 | logger = logging.getLogger("ratsnlp")
15 |
16 |
17 | @dataclass
18 | class ClassificationExample:
19 | text_a: str
20 | text_b: Optional[str] = None
21 | label: Optional[str] = None
22 |
23 |
24 | @dataclass
25 | class ClassificationFeatures:
26 | input_ids: List[int]
27 | attention_mask: Optional[List[int]] = None
28 | token_type_ids: Optional[List[int]] = None
29 | label: Optional[int] = None
30 |
31 |
32 | class NsmcCorpus:
33 |
34 | def __init__(self):
35 | pass
36 |
37 | def get_examples(self, data_root_path, mode):
38 | data_fpath = os.path.join(data_root_path, f"ratings_{mode}.txt")
39 | logger.info(f"loading {mode} data... LOOKING AT {data_fpath}")
40 | lines = list(csv.reader(open(data_fpath, "r", encoding="utf-8"), delimiter="\t", quotechar='"'))
41 | examples = []
42 | for (i, line) in enumerate(lines):
43 | if i == 0:
44 | continue
45 | _, text_a, label = line
46 | examples.append(ClassificationExample(text_a=text_a, text_b=None, label=label))
47 | return examples
48 |
49 | def get_labels(self):
50 | return ["0", "1"]
51 |
52 | @property
53 | def num_labels(self):
54 | return len(self.get_labels())
55 |
56 |
57 | def _convert_examples_to_classification_features(
58 | examples: List[ClassificationExample],
59 | tokenizer: PreTrainedTokenizer,
60 | args: ClassificationTrainArguments,
61 | label_list: List[str],
62 | ):
63 | label_map = {label: i for i, label in enumerate(label_list)}
64 | labels = [label_map[example.label] for example in examples]
65 |
66 | logger.info(
67 | "tokenize sentences, it could take a lot of time..."
68 | )
69 | start = time.time()
70 | batch_encoding = tokenizer(
71 | [(example.text_a, example.text_b) for example in examples],
72 | max_length=args.max_seq_length,
73 | padding="max_length",
74 | truncation=True,
75 | )
76 | logger.info(
77 | "tokenize sentences [took %.3f s]", time.time() - start
78 | )
79 |
80 | features = []
81 | for i in range(len(examples)):
82 | inputs = {k: batch_encoding[k][i] for k in batch_encoding}
83 | feature = ClassificationFeatures(**inputs, label=labels[i])
84 | features.append(feature)
85 |
86 | for i, example in enumerate(examples[:5]):
87 | logger.info("*** Example ***")
88 | if example.text_b is None:
89 | logger.info("sentence: %s" % (example.text_a))
90 | else:
91 | sentence = example.text_a + " + " + example.text_b
92 | logger.info("sentence A, B: %s" % (sentence))
93 | logger.info("tokens: %s" % (" ".join(tokenizer.convert_ids_to_tokens(features[i].input_ids))))
94 | logger.info("label: %s" % (example.label))
95 | logger.info("features: %s" % features[i])
96 |
97 | return features
98 |
99 |
100 | class ClassificationDataset(Dataset):
101 |
102 | def __init__(
103 | self,
104 | args: ClassificationTrainArguments,
105 | tokenizer: PreTrainedTokenizer,
106 | corpus,
107 | mode: Optional[str] = "train",
108 | convert_examples_to_features_fn=_convert_examples_to_classification_features,
109 | ):
110 | if corpus is not None:
111 | self.corpus = corpus
112 | else:
113 | raise KeyError("corpus is not valid")
114 |         if mode not in ["train", "val", "test"]:
115 | raise KeyError(f"mode({mode}) is not a valid split name")
116 | # Load data features from cache or dataset file
117 | cached_features_file = os.path.join(
118 | args.downstream_corpus_root_dir,
119 | args.downstream_corpus_name,
120 | "cached_{}_{}_{}_{}_{}".format(
121 | mode,
122 | tokenizer.__class__.__name__,
123 | str(args.max_seq_length),
124 | args.downstream_corpus_name,
125 | args.downstream_task_name,
126 | ),
127 | )
128 |
129 | # Make sure only the first process in distributed training processes the dataset,
130 | # and the others will use the cache.
131 | lock_path = cached_features_file + ".lock"
132 | with FileLock(lock_path):
133 |
134 | if os.path.exists(cached_features_file) and not args.overwrite_cache:
135 | start = time.time()
136 | self.features = torch.load(cached_features_file)
137 | logger.info(
138 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
139 | )
140 | else:
141 | corpus_path = os.path.join(
142 | args.downstream_corpus_root_dir,
143 | args.downstream_corpus_name,
144 | )
145 | logger.info(f"Creating features from dataset file at {corpus_path}")
146 | examples = self.corpus.get_examples(corpus_path, mode)
147 | self.features = convert_examples_to_features_fn(
148 | examples,
149 | tokenizer,
150 | args,
151 | label_list=self.corpus.get_labels(),
152 | )
153 | start = time.time()
154 | logger.info(
155 | "Saving features into cached file, it could take a lot of time..."
156 | )
157 | torch.save(self.features, cached_features_file)
158 | logger.info(
159 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
160 | )
161 |
162 | def __len__(self):
163 | return len(self.features)
164 |
165 | def __getitem__(self, i):
166 | return self.features[i]
167 |
168 | def get_labels(self):
169 | return self.corpus.get_labels()
170 |
--------------------------------------------------------------------------------
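The label handling above reduces to a simple string-to-index mapping; a self-contained sketch (pure Python, mirroring _convert_examples_to_classification_features):

    # label_list as returned by NsmcCorpus.get_labels() above
    label_list = ["0", "1"]
    label_map = {label: i for i, label in enumerate(label_list)}
    assert label_map["1"] == 1  # an NSMC positive review maps to class index 1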
/ratsnlp/nlpbook/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tqdm
4 | import logging
5 | import requests
6 | from transformers import HfArgumentParser
7 |
8 |
9 | REMOTE_DATA_MAP = {
10 | "nsmc": {
11 | "train": {
12 | "web_url": "https://github.com/e9t/nsmc/raw/master/ratings_train.txt",
13 | "fname": "train.txt",
14 | },
15 | "val": {
16 | "web_url": "https://github.com/e9t/nsmc/raw/master/ratings_test.txt",
17 | "fname": "val.txt",
18 | },
19 | },
20 | "klue-nli": {
21 | "train": {
22 | "googledrive_file_id": "18LhrHaPEW0VITMPfnwKXJ6bNuklBdi4U",
23 | "fname": "klue_nli_train.json",
24 | },
25 | "val": {
26 | "googledrive_file_id": "1UKIDAFOFuDSah7A66FZXSA8XUWUHhBAd",
27 | "fname": "klue_nli_dev.json",
28 | }
29 | },
30 | "ner": {
31 | "train": {
32 | "googledrive_file_id": "1RP764owqs1kZeHcjFnCX7zXt2EcjGY1i",
33 | "fname": "train.txt",
34 | },
35 | "val": {
36 | "googledrive_file_id": "1bEPNWT5952rD3xjg0LfJBy3hLHry3yUL",
37 | "fname": "val.txt",
38 | },
39 | },
40 | "korquad-v1": {
41 | "train": {
42 | "web_url": "https://korquad.github.io/dataset/KorQuAD_v1.0_train.json",
43 | "fname": "KorQuAD_v1.0_train.json",
44 | },
45 | "val": {
46 | "web_url": "https://korquad.github.io/dataset/KorQuAD_v1.0_dev.json",
47 | "fname": "KorQuAD_v1.0_dev.json",
48 | }
49 | }
50 | }
51 |
52 | REMOTE_MODEL_MAP = {
53 | "kogpt2": {
54 | "merges": {
55 | "googledrive_file_id": "19-vpk-RAPhmIM1pPJ66F2Kbj4dW5V5sV",
56 | "fname": "merges.txt",
57 | },
58 | "vocab": {
59 | "googledrive_file_id": "19vjuxYOmlNTfg8kYKOPOUlZERm-QoTnj",
60 | "fname": "vocab.json",
61 | },
62 | "model": {
63 | "googledrive_file_id": "1dDGtsMy1NsfpuvgX8XobBsCYyctn5Xex",
64 | "fname": "pytorch_model.bin",
65 | },
66 | "config": {
67 | "googledrive_file_id": "1z6obNRWPHoVrMzT9THElblebdovuDLUZ",
68 | "fname": "config.json",
69 | },
70 | },
71 | }
72 | GOOGLE_DRIVE_URL = "https://docs.google.com/uc?export=download"
73 | logger = logging.getLogger("ratsnlp") # pylint: disable=invalid-name
74 |
75 |
76 | def save_response_content(response, save_path):
77 | with open(save_path, "wb") as f:
78 | content_length = response.headers.get("Content-Length")
79 | total = int(content_length) if content_length is not None else None
80 | progress = tqdm.tqdm(
81 | unit="B",
82 | unit_scale=True,
83 | total=total,
84 | initial=0,
85 | desc="Downloading",
86 | disable=bool(logger.getEffectiveLevel() == logging.NOTSET),
87 | )
88 | for chunk in response.iter_content(chunk_size=1024):
89 | if chunk: # filter out keep-alive new chunks
90 | progress.update(len(chunk))
91 | f.write(chunk)
92 | progress.close()
93 |
94 |
95 | def get_valid_path(cache_dir, save_fname, make_dir=True):
96 |     # resolve the cache directory to an absolute path
97 | if cache_dir.startswith("~"):
98 | cache_dir = os.path.expanduser(cache_dir)
99 | else:
100 | cache_dir = os.path.abspath(cache_dir)
101 | if make_dir:
102 | os.makedirs(cache_dir, exist_ok=True)
103 | valid_save_path = os.path.join(cache_dir, save_fname)
104 | return valid_save_path
105 |
106 |
107 | def google_download(file_id,
108 | save_fname,
109 | cache_dir="~/cache",
110 | force_download=False):
111 | def get_confirm_token(response):
112 | for key, value in response.cookies.items():
113 | if key.startswith('download_warning'):
114 | return value
115 | return None
116 | valid_save_path = get_valid_path(cache_dir, save_fname)
117 |     # use the cache if the file already exists
118 | if os.path.exists(valid_save_path) and not force_download:
119 | logger.info(f"cache file({valid_save_path}) exists, using cache!")
120 | return valid_save_path
121 | # init a HTTP session
122 | session = requests.Session()
123 | # make a request
124 | response = session.get(GOOGLE_DRIVE_URL, params={'id': file_id}, stream=True)
125 | # get confirmation token
126 | token = get_confirm_token(response)
127 | if token:
128 | params = {'id': file_id, 'confirm': token}
129 | response = session.get(GOOGLE_DRIVE_URL, params=params, stream=True)
130 | # download to disk
131 | save_response_content(response, valid_save_path)
132 | return valid_save_path
133 |
134 |
135 | def web_download(url,
136 | save_fname,
137 | cache_dir="~/cache",
138 | proxies=None,
139 | etag_timeout=10,
140 | force_download=False):
141 | """
142 |     Download function, adapted from the Hugging Face and SK T-Brain download utilities:
143 | https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
144 | https://github.com/SKTBrain/KoBERT/blob/master/kobert/utils.py
145 | """
146 | valid_save_path = get_valid_path(cache_dir, save_fname)
147 |     # use the cache if the file already exists
148 | if os.path.exists(valid_save_path) and not force_download:
149 | logger.info(f"cache file({valid_save_path}) exists, using cache!")
150 | return valid_save_path
151 |     # check that the URL is valid
152 | # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
153 | etag = None
154 | try:
155 | response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
156 | if response.status_code == 200:
157 | etag = response.headers.get("ETag")
158 | except (EnvironmentError, requests.exceptions.Timeout):
159 | pass
160 | if etag is None:
161 | raise ValueError(f"not valid URL({url}), cannot download resources")
162 | response = requests.get(url, stream=True)
163 | save_response_content(response, valid_save_path)
164 | return valid_save_path
165 |
166 |
167 | def download_downstream_dataset(args):
168 | data_name = args.downstream_corpus_name.lower()
169 | if data_name in REMOTE_DATA_MAP.keys():
170 | cache_dir = os.path.join(args.downstream_corpus_root_dir, data_name)
171 | for value in REMOTE_DATA_MAP[data_name].values():
172 | if "web_url" in value.keys():
173 | web_download(
174 | url=value["web_url"],
175 | save_fname=value["fname"],
176 | cache_dir=cache_dir,
177 | force_download=args.force_download,
178 | )
179 | else:
180 | google_download(
181 | file_id=value["googledrive_file_id"],
182 | save_fname=value["fname"],
183 | cache_dir=cache_dir,
184 | force_download=args.force_download
185 | )
186 | else:
187 | raise ValueError(f"not valid data name({data_name}), cannot download resources")
188 |
189 |
190 | def download_pretrained_model(args, config_only=False):
191 | pretrained_model_name = args.pretrained_model_name.lower()
192 | if pretrained_model_name in REMOTE_MODEL_MAP.keys():
193 | for key, value in REMOTE_MODEL_MAP[pretrained_model_name].items():
194 |             if not config_only or key == "config":
195 | if "web_url" in value.keys():
196 | web_download(
197 | url=value["web_url"],
198 | save_fname=value["fname"],
199 | cache_dir=args.pretrained_model_cache_dir,
200 | force_download=args.force_download,
201 | )
202 | else:
203 | google_download(
204 | file_id=value["googledrive_file_id"],
205 | save_fname=value["fname"],
206 | cache_dir=args.pretrained_model_cache_dir,
207 | force_download=args.force_download,
208 | )
209 | else:
210 | raise ValueError(f"not valid model name({pretrained_model_name}), cannot download resources")
211 |
212 |
213 | def set_logger(args):
214 | import torch
215 | if torch.cuda.is_available():
216 | stream_handler = logging.StreamHandler()
217 | formatter = logging.Formatter(
218 | fmt="%(levelname)s:%(name)s:%(message)s",
219 | )
220 | stream_handler.setFormatter(formatter)
221 | logger.addHandler(stream_handler)
222 | logger.setLevel(logging.INFO)
223 | logger.info("Training/evaluation parameters %s", args)
224 |
225 |
226 | def set_seed(args):
227 | if args.seed is not None:
228 |         # TODO: extend to pytorch-lightning's seed_everything in the future
229 | from transformers import set_seed
230 | set_seed(args.seed)
231 | print(f"set seed: {args.seed}")
232 | else:
233 | print("not fixed seed")
234 |
235 |
236 | def load_arguments(argument_class, json_file_path=None):
237 | parser = HfArgumentParser(argument_class)
238 | if json_file_path is not None:
239 | args, = parser.parse_json_file(json_file=json_file_path)
240 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
241 | args, = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
242 | else:
243 | args, = parser.parse_args_into_dataclasses()
244 | return args
--------------------------------------------------------------------------------
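A hedged sketch of how the helpers in utils.py are typically combined; the JSON path and the ClassificationTrainArguments import are illustrative assumptions, while the function signatures follow utils.py above:

    from ratsnlp.nlpbook import load_arguments, download_downstream_dataset, set_logger, set_seed
    from ratsnlp.nlpbook.classification import ClassificationTrainArguments  # assumed export

    # parse arguments from a JSON file (hypothetical path)
    args = load_arguments(ClassificationTrainArguments, json_file_path="train_config.json")
    set_logger(args)
    set_seed(args)
    download_downstream_dataset(args)  # resolves args.downstream_corpus_name via REMOTE_DATA_MAP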
/ratsnlp/nlpbook/ner/corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 | import torch
5 | import logging
6 | from filelock import FileLock
7 | from typing import List, Optional
8 | from dataclasses import dataclass
9 | from transformers import BertTokenizer
10 | from torch.utils.data.dataset import Dataset
11 | from ratsnlp.nlpbook.ner import NERTrainArguments
12 | from transformers.tokenization_utils_base import PaddingStrategy, TruncationStrategy
13 |
14 |
15 | logger = logging.getLogger("ratsnlp")
16 |
17 |
18 | # Label/ID scheme for building label sequences from the in-house NER corpus, e.g.:
19 | # 나 는 삼성 에 입사 했다
20 | # O O 기관 O O O > [CLS] O O 기관 O O O [SEP] [PAD] [PAD] ...
21 | NER_CLS_TOKEN = "[CLS]"
22 | NER_SEP_TOKEN = "[SEP]"
23 | NER_PAD_TOKEN = "[PAD]"
24 | NER_MASK_TOKEN = "[MASK]"
25 | NER_PAD_ID = 2
26 |
27 |
28 | @dataclass
29 | class NERExample:
30 | text: str
31 | label: Optional[str] = None
32 |
33 |
34 | @dataclass
35 | class NERFeatures:
36 | input_ids: List[int]
37 | attention_mask: Optional[List[int]] = None
38 | token_type_ids: Optional[List[int]] = None
39 | label_ids: Optional[List[int]] = None
40 |
41 |
42 | class NERCorpus:
43 |
44 | def __init__(
45 | self,
46 | args: NERTrainArguments
47 | ):
48 | self.args = args
49 |
50 | def get_examples(self, data_root_path, mode):
51 | data_fpath = os.path.join(data_root_path, f"{mode}.txt")
52 | logger.info(f"loading {mode} data... LOOKING AT {data_fpath}")
53 | examples = []
54 | for line in open(data_fpath, "r", encoding="utf-8").readlines():
55 | text, label = line.split("\u241E")
56 | examples.append(NERExample(text=text, label=label))
57 | return examples
58 |
59 | def get_labels(self):
60 | label_map_path = os.path.join(
61 | self.args.downstream_model_dir,
62 | "label_map.txt",
63 | )
64 | if not os.path.exists(label_map_path):
65 | logger.info("processing NER tag dictionary...")
66 | os.makedirs(self.args.downstream_model_dir, exist_ok=True)
67 | ner_tags = []
68 | regex_ner = re.compile('<(.+?):[A-Z]{3}>')
69 | train_corpus_path = os.path.join(
70 | self.args.downstream_corpus_root_dir,
71 | self.args.downstream_corpus_name,
72 | "train.txt",
73 | )
74 | target_sentences = [line.split("\u241E")[1].strip()
75 | for line in open(train_corpus_path, "r", encoding="utf-8").readlines()]
76 | for target_sentence in target_sentences:
77 | regex_filter_res = regex_ner.finditer(target_sentence)
78 | for match_item in regex_filter_res:
79 | ner_tag = match_item[0][-4:-1]
80 | if ner_tag not in ner_tags:
81 | ner_tags.append(ner_tag)
82 | b_tags = [f"B-{ner_tag}" for ner_tag in ner_tags]
83 | i_tags = [f"I-{ner_tag}" for ner_tag in ner_tags]
84 | labels = [NER_CLS_TOKEN, NER_SEP_TOKEN, NER_PAD_TOKEN, NER_MASK_TOKEN, "O"] + b_tags + i_tags
85 | with open(label_map_path, "w", encoding="utf-8") as f:
86 | for tag in labels:
87 | f.writelines(tag + "\n")
88 | else:
89 | labels = [tag.strip() for tag in open(label_map_path, "r", encoding="utf-8").readlines()]
90 | return labels
91 |
92 | @property
93 | def num_labels(self):
94 | return len(self.get_labels())
95 |
96 |
97 | def _process_target_sentence(
98 | tokens: List[str],
99 | origin_sentence: str,
100 | target_sentence: str,
101 | max_length: int,
102 | label_map: dict,
103 | tokenizer: BertTokenizer,
104 | cls_token_at_end: Optional[bool] = False,
105 | ):
106 | """
107 | target_sentence = "―<효진:PER> 역의 <김환희:PER>(<14:NOH>)가 특히 인상적이었다."
108 | tokens = ["―", "효", "##진", "역", "##의", "김", "##환", "##희",
109 | "(", "14", ")", "가", "특히", "인상", "##적이", "##었다", "."]
110 | label_sequence = ['O', 'B-PER', 'I-PER', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O',
111 | 'B-NOH', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
112 | """
113 | if "[UNK]" in tokens:
114 | processed_tokens = []
115 | basic_tokens = tokenizer.basic_tokenizer.tokenize(origin_sentence)
116 | for basic_token in basic_tokens:
117 | current_tokens = tokenizer.tokenize(basic_token)
118 | if "[UNK]" in current_tokens:
119 |                 # restore [UNK] tokens to their original surface form
120 | processed_tokens.append(basic_token)
121 | else:
122 | processed_tokens.extend(current_tokens)
123 | else:
124 | processed_tokens = tokens
125 |
126 |     prefix_sum_of_token_start_index, total = [0], 0
127 |     for i, token in enumerate(processed_tokens):
128 |         if token.startswith("##"):
129 |             total += len(token) - 2
130 |         else:
131 |             total += len(token)
132 |         prefix_sum_of_token_start_index.append(total)
133 |
134 |     regex_ner = re.compile('<(.+?):[A-Z]{3}>')  # change {3} to {2} if NER tags are two characters long (e.g. LOC -> LC)
135 | regex_filter_res = regex_ner.finditer(target_sentence.replace(" ", ""))
136 |
137 | list_of_ner_tag = []
138 | list_of_ner_text = []
139 | list_of_tuple_ner_start_end = []
140 |
141 | count_of_match = 0
142 | for match_item in regex_filter_res:
143 | ner_tag = match_item[0][-4:-1] # <4일간:DUR> -> DUR
144 | ner_text = match_item[1] # <4일간:DUR> -> 4일간
145 |         start_index = match_item.start() - 6 * count_of_match  # subtract the 6 markup chars ('<', ':', 3-char tag, '>') of each previous match
146 | end_index = match_item.end() - 6 - 6 * count_of_match
147 |
148 | list_of_ner_tag.append(ner_tag)
149 | list_of_ner_text.append(ner_text)
150 | list_of_tuple_ner_start_end.append((start_index, end_index))
151 | count_of_match += 1
152 |
153 | label_sequence = []
154 | entity_index = 0
155 | is_entity_still_B = True
156 |
157 | for tup in zip(processed_tokens, prefix_sum_of_token_start_index):
158 | token, index = tup
159 |
160 | if entity_index < len(list_of_tuple_ner_start_end):
161 | start, end = list_of_tuple_ner_start_end[entity_index]
162 |
163 |             if end < index:  # if the current position is past this entity's span, move on to the next entity
164 | is_entity_still_B = True
165 | entity_index = entity_index + 1 if entity_index + 1 < len(list_of_tuple_ner_start_end) else entity_index
166 | start, end = list_of_tuple_ner_start_end[entity_index]
167 |
168 |             if start <= index and index < end:  # cases like <13일:DAT>까지 -> ('▁13', 10, 'B-DAT') ('일까지', 12, 'I-DAT') are included; to exclude them, token lengths would also have to be checked
169 | entity_tag = list_of_ner_tag[entity_index]
170 | if is_entity_still_B is True:
171 | entity_tag = 'B-' + entity_tag
172 | label_sequence.append(entity_tag)
173 | is_entity_still_B = False
174 | else:
175 | entity_tag = 'I-' + entity_tag
176 | label_sequence.append(entity_tag)
177 | else:
178 | is_entity_still_B = True
179 | entity_tag = 'O'
180 | label_sequence.append(entity_tag)
181 | else:
182 | entity_tag = 'O'
183 | label_sequence.append(entity_tag)
184 |
185 | # truncation
186 | label_sequence = label_sequence[:max_length - 2]
187 |
188 | # add special tokens
189 | if cls_token_at_end:
190 | label_sequence = label_sequence + [NER_CLS_TOKEN, NER_SEP_TOKEN]
191 | else:
192 | label_sequence = [NER_CLS_TOKEN] + label_sequence + [NER_SEP_TOKEN]
193 |
194 | # padding
195 | pad_length = max(max_length - len(label_sequence), 0)
196 | pad_sequence = [NER_PAD_TOKEN] * pad_length
197 | label_sequence += pad_sequence
198 |
199 | # encoding
200 | label_ids = [label_map[label] for label in label_sequence]
201 | return label_ids
202 |
203 |
204 | def _convert_examples_to_ner_features(
205 | examples: List[NERExample],
206 | tokenizer: BertTokenizer,
207 | args: NERTrainArguments,
208 | label_list: List[str],
209 | cls_token_at_end: Optional[bool] = False,
210 | ):
211 | """
212 | `cls_token_at_end` define the location of the CLS token:
213 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
214 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
215 | """
216 | label_map = {label: i for i, label in enumerate(label_list)}
217 | id_to_label = {i: label for i, label in enumerate(label_list)}
218 |
219 | features = []
220 | for example in examples:
221 | tokens = tokenizer.tokenize(example.text)
222 | inputs = tokenizer._encode_plus(
223 | tokens,
224 | max_length=args.max_seq_length,
225 | truncation_strategy=TruncationStrategy.LONGEST_FIRST,
226 | padding_strategy=PaddingStrategy.MAX_LENGTH,
227 | )
228 | label_ids = _process_target_sentence(
229 | tokens=tokens,
230 | origin_sentence=example.text,
231 | target_sentence=example.label,
232 | max_length=args.max_seq_length,
233 | label_map=label_map,
234 | tokenizer=tokenizer,
235 | cls_token_at_end=cls_token_at_end,
236 | )
237 | features.append(NERFeatures(**inputs, label_ids=label_ids))
238 |
239 | for i, example in enumerate(examples[:5]):
240 | logger.info("*** Example ***")
241 | logger.info("sentence: %s" % (example.text))
242 | logger.info("target: %s" % (example.label))
243 | logger.info("tokens: %s" % (" ".join(tokenizer.convert_ids_to_tokens(features[i].input_ids))))
244 | logger.info("label: %s" % (" ".join([id_to_label[label_id] for label_id in features[i].label_ids])))
245 | logger.info("features: %s" % features[i])
246 |
247 | return features
248 |
249 |
250 | class NERDataset(Dataset):
251 |
252 | def __init__(
253 | self,
254 | args: NERTrainArguments,
255 | tokenizer: BertTokenizer,
256 | corpus: NERCorpus,
257 | mode: Optional[str] = "train",
258 | convert_examples_to_features_fn=_convert_examples_to_ner_features,
259 | ):
260 | if corpus is not None:
261 | self.corpus = corpus
262 | else:
263 | raise KeyError("corpus is not valid")
265 |         if mode not in ["train", "val", "test"]:
265 | raise KeyError(f"mode({mode}) is not a valid split name")
266 | # Load data features from cache or dataset file
267 | cached_features_file = os.path.join(
268 | args.downstream_corpus_root_dir,
269 | args.downstream_corpus_name,
270 | "cached_{}_{}_{}_{}_{}".format(
271 | mode,
272 | tokenizer.__class__.__name__,
273 | str(args.max_seq_length),
274 | args.downstream_corpus_name,
275 | args.downstream_task_name,
276 | ),
277 | )
278 |
279 | # Make sure only the first process in distributed training processes the dataset,
280 | # and the others will use the cache.
281 | lock_path = cached_features_file + ".lock"
282 | with FileLock(lock_path):
283 |
284 | if os.path.exists(cached_features_file) and not args.overwrite_cache:
285 | start = time.time()
286 | self.features = torch.load(cached_features_file)
287 | logger.info(
288 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
289 | )
290 | else:
291 | corpus_path = os.path.join(
292 | args.downstream_corpus_root_dir,
293 | args.downstream_corpus_name,
294 | )
295 | logger.info(f"Creating features from dataset file at {corpus_path}")
296 | examples = self.corpus.get_examples(corpus_path, mode)
297 | self.features = convert_examples_to_features_fn(
298 | examples,
299 | tokenizer,
300 | args,
301 | label_list=self.corpus.get_labels(),
302 | )
303 | start = time.time()
304 | logger.info(
305 | "Saving features into cached file, it could take a lot of time..."
306 | )
307 | torch.save(self.features, cached_features_file)
308 | logger.info(
309 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
310 | )
311 |
312 | def __len__(self):
313 | return len(self.features)
314 |
315 | def __getitem__(self, i):
316 | return self.features[i]
317 |
318 | def get_labels(self):
319 | return self.corpus.get_labels()
320 |
--------------------------------------------------------------------------------
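The tag inventory that NERCorpus.get_labels builds comes from a single regex pass over the training corpus; a pure-Python sketch using the example sentence from the _process_target_sentence docstring above:

    import re

    regex_ner = re.compile('<(.+?):[A-Z]{3}>')  # same pattern as in corpus.py
    target_sentence = "―<효진:PER> 역의 <김환희:PER>(<14:NOH>)가 특히 인상적이었다."
    # m[0] is a whole match such as "<효진:PER>"; its [-4:-1] slice is the 3-char tag "PER"
    tags = {m[0][-4:-1] for m in regex_ner.finditer(target_sentence)}
    print(sorted(tags))  # ['NOH', 'PER']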
/ratsnlp/nlpbook/qa/corpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | import torch
5 | import logging
6 | from tqdm import tqdm
7 | from functools import partial
8 | from filelock import FileLock
9 | from dataclasses import dataclass
10 | from typing import List, Optional
11 | from multiprocessing import Pool, cpu_count
12 | from transformers import PreTrainedTokenizer
13 | from torch.utils.data.dataset import Dataset
14 | from ratsnlp.nlpbook.qa import QATrainArguments
15 |
16 |
17 | logger = logging.getLogger("ratsnlp")
18 |
19 |
20 | @dataclass
21 | class QAExample:
22 |     # question, e.g.: 임종석이 여의도 농민 폭력 시위를 주도한 혐의로 지명수배 된 날은?
23 |     question_text: str
24 |     # passage the answer is drawn from, e.g.: 1989년 2월 15일 여의도 농민 폭력 시위를 주도한 혐의 ... 서울지방경찰청 공안분실로 인계되었다.
25 |     context_text: str
26 |     # answer, e.g.: 1989년 2월 15일
27 |     answer_text: str
28 |     # start position of the answer, in characters (syllables), e.g.: 0
29 | start_position_character: Optional[int] = None
30 |
31 |
32 | class QACorpus:
33 |
34 | def __init__(self):
35 | pass
36 |
37 | def get_examples(self, corpus_dir, mode):
38 | """
39 | :return: List[QAExample]
40 | """
41 | raise NotImplementedError
42 |
43 |
44 | class KorQuADV1Corpus(QACorpus):
45 |
46 | def __init__(self):
47 | super().__init__()
48 | self.train_file = "KorQuAD_v1.0_train.json"
49 | self.val_file = "KorQuAD_v1.0_dev.json"
50 |
51 | def get_examples(self, corpus_dir, mode):
52 | examples = []
53 | if mode == "train":
54 | corpus_fpath = os.path.join(corpus_dir, self.train_file)
55 | elif mode == "val":
56 | corpus_fpath = os.path.join(corpus_dir, self.val_file)
57 | else:
58 | raise KeyError(f"mode({mode}) is not a valid split name")
59 | json_data = json.load(open(corpus_fpath, "r", encoding="utf-8"))["data"]
60 | for entry in tqdm(json_data):
61 | for paragraph in entry["paragraphs"]:
62 | context_text = paragraph["context"]
63 | for qa in paragraph["qas"]:
64 | question_text = qa["question"]
65 | for answer in qa["answers"]:
66 | answer_text = answer["text"]
67 | start_position_character = answer["answer_start"]
68 |                         if question_text and answer_text and context_text and start_position_character is not None:
69 | example = QAExample(
70 | question_text=question_text,
71 | context_text=context_text,
72 | answer_text=answer_text,
73 | start_position_character=start_position_character,
74 | )
75 | examples.append(example)
76 | return examples
77 |
78 |
79 | @dataclass
80 | class QAFeatures:
81 | input_ids: List[int]
82 | attention_mask: List[int]
83 | token_type_ids: List[int]
84 |     # start_positions: position of the answer's first token in the passage (wordpiece units)
85 |     start_positions: int
86 |     # end_positions: position of the answer's last token in the passage (wordpiece units)
87 |     end_positions: int
88 |
89 |
90 | def _squad_convert_example_to_features_init(tokenizer_for_convert):
91 | global tokenizer
92 | tokenizer = tokenizer_for_convert
93 |
94 |
95 | def _is_whitespace(c):
96 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
97 | return True
98 | return False
99 |
100 |
101 | def _whitespace_tokenize(text):
102 | """Runs basic whitespace cleaning and splitting on a piece of text."""
103 | text = text.strip()
104 | if not text:
105 | return []
106 | tokens = text.split()
107 | return tokens
108 |
109 |
110 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
111 | """Returns tokenized answer spans that better match the annotated answer."""
112 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
113 | for new_start in range(input_start, input_end + 1):
114 | for new_end in range(input_end, new_start - 1, -1):
115 | text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
116 | if text_span == tok_answer_text:
117 | return new_start, new_end
118 | return input_start, input_end
119 |
120 |
121 | def _squad_convert_example_to_features(example, max_seq_length, doc_stride, max_query_length):
122 | features = []
123 |
124 | doc_tokens, char_to_word_offset = [], []
125 | prev_is_whitespace = True
126 | # Split on whitespace so that different tokens may be attributed to their original position.
127 | for c in example.context_text:
128 | if _is_whitespace(c):
129 | prev_is_whitespace = True
130 | else:
131 | if prev_is_whitespace:
132 | doc_tokens.append(c)
133 | else:
134 | doc_tokens[-1] += c
135 | prev_is_whitespace = False
136 | char_to_word_offset.append(len(doc_tokens) - 1)
137 |
138 | # Get start and end position
139 |     # start/end positions of the answer, in whitespace-word (eojeol) units
140 | start_position = char_to_word_offset[example.start_position_character]
141 | end_position = char_to_word_offset[
142 | min(example.start_position_character + len(example.answer_text) - 1, len(char_to_word_offset) - 1)
143 | ]
144 |
145 | # If the answer cannot be found in the text, then skip this example.
146 |     # actual_text: the answer span in word units (usually contains cleaned_answer_text), e.g. 베토벤의 교향곡 9번을
147 |     actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
148 |     # cleaned_answer_text: the human-annotated answer span, e.g. 베토벤의 교향곡 9번
149 |     cleaned_answer_text = " ".join(_whitespace_tokenize(example.answer_text))
150 |     # find() returns a non-negative index if actual_text contains cleaned_answer_text,
151 |     # and -1 otherwise (e.g. when actual_text is "베토벤 교향곡 9번")
152 | if actual_text.find(cleaned_answer_text) == -1:
153 | logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
154 | return []
155 |
156 |     # doc_tokens: the whitespace-separated words of context_text
157 |     # all_doc_tokens: the token list obtained by running wordpiece over each word in doc_tokens
158 |     # tok_to_orig_index: for each token in all_doc_tokens, the index of the word in context_text it came from
159 |     # orig_to_tok_index: for each word in context_text, the index in all_doc_tokens of its first token
160 |     # e.g. if context_text is "아이스크림케이크 좋아하는 사람 있나요?" and
161 |     # doc_tokens is ["아이스크림케이크", "좋아하는", "사람", "있나요?"], then
162 |     # all_doc_tokens = ['아이', '##스크', '##림', '##케이', '##크', '좋아하는', '사람', '있나요', '?']
163 |     # tok_to_orig_index = [0, 0, 0, 0, 0, 1, 2, 3, 3] means that
164 |     # tokens 0-4 of all_doc_tokens ('아이', '##스크', '##림', '##케이', '##크') all sit in word 0 of context_text
165 |     # orig_to_tok_index = [0, 5, 6, 7] means that
166 |     # word 0 of context_text (아이스크림케이크) starts at token 0 of all_doc_tokens,
167 |     # word 1 of context_text (좋아하는) starts at token 5 of all_doc_tokens,
168 |     # ...
169 | tok_to_orig_index = []
170 | orig_to_tok_index = []
171 | all_doc_tokens = []
172 | for (i, token) in enumerate(doc_tokens):
173 | orig_to_tok_index.append(len(all_doc_tokens))
174 | sub_tokens = tokenizer.tokenize(token)
175 | for sub_token in sub_tokens:
176 | tok_to_orig_index.append(i)
177 | all_doc_tokens.append(sub_token)
178 |
179 |     # Training is done over wordpiece tokens, not whitespace words,
180 |     # but the annotated label is a span a human marked, not a wordpiece span.
181 |     # The code below therefore derives the answer span in wordpiece positions.
182 |     # all_doc_tokens[tok_start_position:tok_end_position]
183 |     # > ['베', '##토', '##벤', '##의', '교', '##향', '##곡', '9', '##번']
184 |     # start_position: index of the word in context_text where the answer starts
185 |     # end_position: index of the word in context_text where the answer ends
186 |     # tok_start_position: position in all_doc_tokens of the first token of the start_position-th word
187 |     # tok_end_position: position in all_doc_tokens of the last token of the end_position-th word
188 | tok_start_position = orig_to_tok_index[start_position]
189 | if end_position < len(doc_tokens) - 1:
190 | tok_end_position = orig_to_tok_index[end_position + 1] - 1
191 | else:
192 |         tok_end_position = len(all_doc_tokens) - 1
193 |
194 | (tok_start_position, tok_end_position) = _improve_answer_span(
195 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
196 | )
197 |
198 | spans = []
199 |
200 | truncated_query = tokenizer.encode(
201 | example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length
202 | )
203 | sequence_added_tokens = (
204 | tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1
205 | if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer))
206 | else tokenizer.model_max_length - tokenizer.max_len_single_sentence
207 | )
208 |
209 |     # [CLS] question [SEP] context [SEP] -> 3 added special tokens in total
210 | sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair
211 |
212 | span_doc_tokens = all_doc_tokens
213 | while len(spans) * doc_stride < len(all_doc_tokens):
214 |         # if padding_side == "right", encode as question + [SEP] + context
215 |         # if padding_side == "left", encode as context + [SEP] + question
216 |         # truncated_query: token id sequence, List[int]
217 |         # span_doc_tokens: token sequence, List[str]
218 |         # the stride argument of encode_plus determines, when the input is longer than
219 |         # max_seq_length, how many tokens overlap between the truncated encoding (input_ids)
220 |         # and the leftover token sequence (overflowing_tokens):
221 |         # stride = 0 means the two share no tokens,
222 |         # stride = max_seq_length makes them overlap completely.
223 |         # Because the value passed here subtracts the TrainArguments' doc_stride from max_seq_length,
224 |         # each successive chunk effectively advances by doc_stride tokens.
225 | encoded_dict = tokenizer.encode_plus(
226 | truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
227 | span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
228 | truncation="only_second" if tokenizer.padding_side == "right" else "only_first",
229 | padding="max_length",
230 | max_length=max_seq_length,
231 | return_overflowing_tokens=True,
232 | stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
233 | return_token_type_ids=True,
234 | )
235 |
236 | paragraph_len = min(
237 | len(all_doc_tokens) - len(spans) * doc_stride,
238 | max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
239 | )
240 |
241 | encoded_dict["start"] = len(spans) * doc_stride
242 | encoded_dict["length"] = paragraph_len
243 |
244 | spans.append(encoded_dict)
245 |
246 | if "overflowing_tokens" not in encoded_dict or (
247 | "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0
248 | ):
249 | break
250 |         # with return_overflowing_tokens=True, tokenizer.encode_plus also returns the tokens
251 |         # left over after truncation; these are fed back in as span_doc_tokens and reprocessed.
252 |         # This is done because context_text is usually longer than max_seq_length,
253 |         # so several training instances are replicated from the same question-context pair by striding.
254 | span_doc_tokens = encoded_dict["overflowing_tokens"]
255 |
256 | for span in spans:
257 | # Identify the position of the CLS token
258 | cls_index = span["input_ids"].index(tokenizer.cls_token_id)
259 | # For training, if our document chunk does not contain an annotation
260 | # we throw it out, since there is nothing to predict.
261 | doc_start = span["start"]
262 | doc_end = span["start"] + span["length"] - 1
263 | out_of_span = False
264 |
265 | if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
266 | out_of_span = True
267 |
268 | if out_of_span:
269 | start_position = cls_index
270 | end_position = cls_index
271 | else:
272 | if tokenizer.padding_side == "left":
273 | doc_offset = 0
274 | else:
275 | doc_offset = len(truncated_query) + sequence_added_tokens
276 |
277 | start_position = tok_start_position - doc_start + doc_offset
278 | end_position = tok_end_position - doc_start + doc_offset
279 |
280 | feature = QAFeatures(
281 | input_ids=span["input_ids"],
282 | attention_mask=span["attention_mask"],
283 | token_type_ids=span["token_type_ids"],
284 | start_positions=start_position,
285 | end_positions=end_position,
286 | )
287 |
288 | features.append(feature)
289 |
290 | return features
291 |
292 |
293 | def _squad_convert_examples_to_features(
294 | examples: List[QAExample],
295 | tokenizer: PreTrainedTokenizer,
296 | args: QATrainArguments,
297 | ):
298 | threads = min(args.threads, cpu_count())
299 | with Pool(threads, initializer=_squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
300 | annotate_ = partial(
301 | _squad_convert_example_to_features,
302 | max_seq_length=args.max_seq_length,
303 | doc_stride=args.doc_stride,
304 | max_query_length=args.max_query_length,
305 | )
306 | features = list(
307 | tqdm(
308 | p.imap(annotate_, examples, chunksize=32),
309 | total=len(examples),
310 | desc="convert squad examples to features",
311 | disable=not args.tqdm_enabled,
312 | )
313 | )
314 | new_features = []
315 | for feature in features:
316 | if not feature:
317 | continue
318 | for f in feature:
319 | new_features.append(f)
320 | features = new_features
321 | del new_features
322 |
323 | for i, example in enumerate(examples[:10]):
324 | logger.info("*** Example ***")
325 | logger.info("question & context: %s" % (" ".join(tokenizer.convert_ids_to_tokens(features[i].input_ids))))
326 | logger.info("answer: %s" % (" ".join(tokenizer.convert_ids_to_tokens(features[i].input_ids[features[i].start_positions:features[i].end_positions + 1]))))
327 | logger.info("features: %s" % features[i])
328 |
329 | return features
330 |
331 |
332 | class QADataset(Dataset):
333 |
334 | def __init__(
335 | self,
336 | args: QATrainArguments,
337 | tokenizer: PreTrainedTokenizer,
338 | corpus: QACorpus,
339 | mode: Optional[str] = "train",
340 | convert_examples_to_features_fn=_squad_convert_examples_to_features,
341 | ):
342 | if corpus is not None:
343 | self.corpus = corpus
344 | else:
345 | raise KeyError("corpus is not valid")
347 |         if mode not in ["train", "val", "test"]:
347 | raise KeyError(f"mode({mode}) is not a valid split name")
348 | # Load data features from cache or dataset file
349 | cached_features_file = os.path.join(
350 | args.downstream_corpus_root_dir,
351 | args.downstream_corpus_name,
352 | "cached_{}_{}_{}_{}_{}_{}_{}".format(
353 | mode,
354 | tokenizer.__class__.__name__,
355 | f"maxlen-{args.max_seq_length}",
356 | f"maxquerylen-{args.max_query_length}",
357 | f"docstride-{args.doc_stride}",
358 | args.downstream_corpus_name,
359 | "question-answering",
360 | ),
361 | )
362 |
363 | # Make sure only the first process in distributed training processes the dataset,
364 | # and the others will use the cache.
365 | lock_path = cached_features_file + ".lock"
366 | with FileLock(lock_path):
367 |
368 | if os.path.exists(cached_features_file) and not args.overwrite_cache:
369 | start = time.time()
370 | self.features = torch.load(cached_features_file)
371 | logger.info(
372 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
373 | )
374 | else:
375 | corpus_fpath = os.path.join(
376 | args.downstream_corpus_root_dir,
377 | args.downstream_corpus_name.lower(),
378 | )
379 | logger.info(f"Creating features from {mode} dataset file at {corpus_fpath}")
380 | examples = self.corpus.get_examples(corpus_fpath, mode)
381 | self.features = convert_examples_to_features_fn(examples, tokenizer, args)
382 | start = time.time()
383 | logger.info(
384 | "Saving features into cached file, it could take a lot of time..."
385 | )
386 | torch.save(self.features, cached_features_file)
387 | logger.info(
388 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
389 | )
390 |
391 | def __len__(self):
392 | return len(self.features)
393 |
394 | def __getitem__(self, i):
395 | return self.features[i]
396 |
--------------------------------------------------------------------------------
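The sliding-window behaviour that _squad_convert_example_to_features obtains from return_overflowing_tokens plus its stride argument can be pictured with a pure-Python sketch (window and stride values are illustrative, not taken from the code):

    all_doc_tokens = [f"tok{i}" for i in range(10)]
    window, doc_stride = 6, 4  # hypothetical chunk size and stride
    spans, start = [], 0
    while start < len(all_doc_tokens):
        spans.append(all_doc_tokens[start:start + window])
        if start + window >= len(all_doc_tokens):
            break
        start += doc_stride
    # spans[0] and spans[1] share window - doc_stride = 2 tokens,
    # just as consecutive question-context chunks overlap above
    print([len(s) for s in spans])  # [6, 6]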