├── workspace
│   └── .keep
├── Dockerfile_mlflow
├── requirements.txt
├── README.md
├── docker-compose.yml
├── LICENSE
├── run_doccat.sh
├── run_ner.sh
├── Dockerfile
├── .gitignore
├── doccat.py
└── ner.py

/workspace/.keep:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Dockerfile_mlflow:
--------------------------------------------------------------------------------
1 | FROM python:3.8
2 |
3 | RUN apt-get update
4 | RUN apt-get -y install sqlite3 libsqlite3-dev
5 |
6 | RUN pip install --upgrade pip && \
7 |     pip install mlflow
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | mlflow==1.12.1
2 | scikit-learn==0.23.2
3 | pytorch_lightning==1.1.0
4 | transformers==4.0.1
5 | seqeval
6 | requests
7 | fugashi
8 | ipadic
9 | gorilla # for mlflow
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # transformers_ner_ja
2 | Japanese NER with Transformers + PyTorch-Lightning + MLflow Tracking
3 |
4 | ## GPU Training
5 | - build: `docker build -t trf-ner-ja-train .`
6 | - run: `docker run --rm --gpus all -v /where/to/workspace:/app/workspace trf-ner-ja-train`
7 | - NOTE: set `export GPUS=1` in run_ner.sh
8 |
9 | ## MLflow Tracking
10 | - build: `docker-compose build`
11 | - run: `docker-compose up`
12 | - NOTE: check that `./workspace/mlruns/0/xxx` is created for each run
13 | - NOTE: GPU support in docker-compose will be released in 1.28.0; see https://github.com/docker/compose/pull/7929
14 | - view: open http://localhost:5000/ in your browser
15 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3"
2 | services:
3 |   pl:
4 |     build:
5 |       context: .
6 |       dockerfile: Dockerfile
7 |     volumes:
8 |       - ./workspace:/app/workspace
9 |     working_dir: /app
10 |     command: bash /app/run_ner.sh
11 |     restart: "no"
12 |     # runtime: nvidia # can be used after v1.28.0
13 |
14 |   mlflow:
15 |     build:
16 |       context: .
17 |       dockerfile: Dockerfile_mlflow
18 |     volumes:
19 |       - ./workspace:/app/workspace
20 |     working_dir: /app/workspace
21 |     ports:
22 |       - "5000:5000"
23 |     command: mlflow server --backend-store-uri file:/app/workspace/mlruns --host 0.0.0.0 --port 5000
24 |     restart: always
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Kazuki Inamura
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/run_doccat.sh:
--------------------------------------------------------------------------------
1 | export BERT_MODEL=cl-tohoku/bert-base-japanese
2 | # export WORK_DIR=${PWD}
3 | export DATA_DIR=${WORK_DIR}/data/
4 | export OUTPUT_DIR=${WORK_DIR}/outputs/
5 | export CACHE=${WORK_DIR}/cache/
6 | export SEED=42
7 | mkdir -p $OUTPUT_DIR
8 | # In Docker, the following error occurs when memory is insufficient:
9 | # `RuntimeError: DataLoader worker is killed by signal: Killed.`
10 | # If it occurs, reduce NUM_WORKERS, MAX_LENGTH, or BATCH_SIZE, or increase the Docker memory limit
11 | export NUM_WORKERS=4
12 | export GPUS=0
13 |
14 | export MAX_LENGTH=128
15 | export BATCH_SIZE=32
16 | export LEARNING_RATE=5e-5
17 |
18 | export NUM_EPOCHS=1
19 | export NUM_SAMPLES=100
20 |
21 | python3 doccat.py \
22 |   --model_name_or_path=$BERT_MODEL \
23 |   --output_dir=$OUTPUT_DIR \
24 |   --accumulate_grad_batches=1 \
25 |   --max_epochs=$NUM_EPOCHS \
26 |   --seed=$SEED \
27 |   --do_train \
28 |   --do_predict \
29 |   --cache_dir=$CACHE \
30 |   --gpus=$GPUS \
31 |   --data_dir=$DATA_DIR \
32 |   --num_workers=$NUM_WORKERS \
33 |   --max_seq_length=$MAX_LENGTH \
34 |   --train_batch_size=$BATCH_SIZE \
35 |   --eval_batch_size=$BATCH_SIZE \
36 |   --learning_rate=$LEARNING_RATE \
37 |   --adam_epsilon=1e-8 \
38 |   --weight_decay=0.0 \
39 |   --num_samples=$NUM_SAMPLES
40 |
--------------------------------------------------------------------------------
/run_ner.sh:
--------------------------------------------------------------------------------
1 | export BERT_MODEL=cl-tohoku/bert-base-japanese
2 | # export WORK_DIR=${PWD}
3 | export DATA_DIR=${WORK_DIR}/data/
4 | export OUTPUT_DIR=${WORK_DIR}/outputs/
5 | export CACHE=${WORK_DIR}/cache/
6 | export LABEL_PATH=$DATA_DIR/label_types.txt
7 | export SEED=42
8 | mkdir -p $OUTPUT_DIR
9 | # In Docker, the following error occurs when memory is insufficient:
10 | # `RuntimeError: DataLoader worker is killed by signal: Killed.`
11 | # If it occurs, reduce NUM_WORKERS, MAX_LENGTH, or BATCH_SIZE, or increase the Docker memory limit
12 | export NUM_WORKERS=8
13 | export GPUS=1
14 |
15 | export MAX_LENGTH=128
16 | export BATCH_SIZE=16
17 | export LEARNING_RATE=5e-5
18 | export PATIENCE=3
19 | export ANNEAL_FACTOR=0.5
20 |
21 | export NUM_EPOCHS=1
22 | export NUM_SAMPLES=100
23 |
24 | python3 ner.py \
25 |   --model_name_or_path=$BERT_MODEL \
26 |   --output_dir=$OUTPUT_DIR \
27 |   --accumulate_grad_batches=1 \
28 |   --max_epochs=$NUM_EPOCHS \
29 |   --seed=$SEED \
30 |   --do_train \
31 |   --do_predict \
32 |   --cache_dir=$CACHE \
33 |   --gpus=$GPUS \
34 |   --data_dir=$DATA_DIR \
35 |   --labels=$LABEL_PATH \
36 |   --num_workers=$NUM_WORKERS \
37 |   --max_seq_length=$MAX_LENGTH \
38 |   --train_batch_size=$BATCH_SIZE \
39 |   --eval_batch_size=$BATCH_SIZE \
40 |   --learning_rate=$LEARNING_RATE \
41 |   --patience=$PATIENCE \
42 |   --anneal_factor=$ANNEAL_FACTOR \
43 |   --adam_epsilon=1e-8 \
44 |   --weight_decay=0.0 \
45 |   --num_samples=$NUM_SAMPLES
46 |
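Both run scripts expect the environment the Dockerfile prepares: WORK_DIR set to /app/workspace, the packages from requirements.txt installed, and MLFLOW_TRACKING_URI pointing at the mlruns directory. A minimal sketch of running them outside Docker follows; the paths are illustrative and it assumes the pinned dependencies install cleanly on the host, adapting the commented-out WORK_DIR line in the scripts:

    pip install -r requirements.txt                    # one-time setup; includes fugashi/ipadic for Japanese tokenization
    export WORK_DIR=${PWD}/workspace                   # illustrative; the Docker image uses /app/workspace
    mkdir -p $WORK_DIR
    export MLFLOW_TRACKING_URI=file:$WORK_DIR/mlruns   # same tracking store the Dockerfile configures
    bash run_ner.sh                                    # or run_doccat.sh; set GPUS=0 in the script if no GPU is available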
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.6.0-cuda10.1-cudnn7-runtime 2 | 3 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 4 | 5 | ENV APP_ROOT /app 6 | ENV WORK_DIR /app/workspace 7 | ENV MLFLOW_TRACKING_URI file:/app/workspace/mlruns 8 | 9 | ENV DEBIAN_FRONTEND noninteractive 10 | 11 | RUN mkdir -p $APP_ROOT 12 | WORKDIR $APP_ROOT 13 | 14 | RUN ln -sf /usr/share/zoneinfo/Asia/Tokyo /etc/localtime 15 | # to support install openjdk-11-jre-headless 16 | RUN mkdir -p /usr/share/man/man1 17 | RUN apt-get update \ 18 | && apt-get upgrade -y \ 19 | && apt-get install -y \ 20 | build-essential \ 21 | git \ 22 | bzip2 \ 23 | ca-certificates \ 24 | libssl-dev \ 25 | libmysqlclient-dev \ 26 | default-libmysqlclient-dev \ 27 | make \ 28 | cmake \ 29 | protobuf-compiler \ 30 | curl \ 31 | sudo \ 32 | software-properties-common \ 33 | xz-utils \ 34 | file \ 35 | mecab \ 36 | libmecab-dev \ 37 | python3-pip \ 38 | openjdk-11-jre-headless \ 39 | && curl -sL https://deb.nodesource.com/setup_10.x | bash - \ 40 | && apt-get update && apt-get install -y nodejs \ 41 | && apt-get clean \ 42 | && rm -rf /var/lib/apt/lists/* 43 | RUN ln -s /etc/mecabrc /usr/local/etc/mecabrc 44 | RUN pip3 install -U pip 45 | COPY ./requirements.txt . 46 | RUN pip install -r requirements.txt 47 | # pip install hydra-core --upgrade 48 | COPY *.py ./ 49 | COPY *.sh ./ 50 | COPY workspace/data/* /app/workspace/ 51 | # COPY config.yaml . 52 | RUN mkdir -p $WORK_DIR/mlruns 53 | 54 | CMD ["bash", "./run_ner.sh"] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | workspace/ -------------------------------------------------------------------------------- /doccat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | from argparse import ArgumentParser, Namespace 4 | from dataclasses import dataclass 5 | from enum import Enum 6 | from pathlib import Path 7 | from typing import Any, Dict, List, Optional, Union 8 | 9 | import mlflow.pytorch 10 | import pandas as pd 11 | import pytorch_lightning as pl 12 | import requests 13 | import torch 14 | from pytorch_lightning.callbacks import ( 15 | EarlyStopping, 16 | LearningRateMonitor, 17 | ModelCheckpoint, 18 | ) 19 | from pytorch_lightning.utilities import rank_zero_info 20 | from sklearn.metrics import accuracy_score 21 | from sklearn.model_selection import train_test_split 22 | from torch.optim.lr_scheduler import ReduceLROnPlateau 23 | from torch.utils.data import DataLoader, Dataset 24 | from transformers import ( 25 | AdamW, 26 | BatchEncoding, 27 | BertConfig, 28 | BertForSequenceClassification, 29 | BertTokenizerFast, 30 | PretrainedConfig, 31 | PreTrainedModel, 32 | PreTrainedTokenizerFast, 33 | ) 34 | from transformers.modeling_outputs import SequenceClassifierOutput 35 | from transformers.optimization import Adafactor 36 | 37 | # huggingface/tokenizers: Disabling parallelism to avoid deadlocks. 
38 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 39 | 40 | IntList = List[int] 41 | IntListList = List[IntList] 42 | StrList = List[str] 43 | TEXT_COL_NAME: str = "text" 44 | LABEL_COL_NAME: str = "label" 45 | 46 | 47 | class LABELS(Enum): 48 | sports_watch = "sports-watch" 49 | topic_news = "topic-news" 50 | dokujo_tsushin = "dokujo-tsushin" 51 | peachy = "peachy" 52 | movie_enter = "movie-enter" 53 | kaden_channel = "kaden-channel" 54 | livedoor_homme = "livedoor-homme" 55 | smax = "smax" 56 | it_life_hack = "it-life-hack" 57 | 58 | 59 | class Split(Enum): 60 | train = "train" 61 | dev = "dev" 62 | test = "test" 63 | 64 | 65 | @dataclass 66 | class SequenceClassificationExample: 67 | guid: str 68 | text: str 69 | label: str 70 | 71 | 72 | @dataclass 73 | class InputFeatures: 74 | input_ids: IntList 75 | attention_mask: IntList 76 | label_ids: IntList 77 | 78 | 79 | def download_and_extract_corpus(data_dir: Path) -> Optional[Path]: 80 | """livedoorコーパスデータのダウンロード""" 81 | filepath = Path("ldcc.tar") 82 | url = "https://www.rondhuit.com/download/ldcc-20140209.tar.gz" 83 | response = requests.get(url) 84 | if response.ok: 85 | with open(filepath, "wb") as fp: 86 | fp.write(response.content) 87 | with tarfile.open(filepath, "r") as fp: 88 | fp.extractall(data_dir) 89 | filepath.unlink() 90 | return data_dir / "text" 91 | return None 92 | 93 | 94 | def make_livedoor_corpus_dataset(data_dir: str = "./data") -> pd.DataFrame: 95 | # ライブドアコーパスを[カテゴリ, 本文]形式でpd.DataFrameで読み込む 96 | pdir = Path(data_dir) 97 | if not (pdir / "text").exists(): 98 | pdir.mkdir(exist_ok=True) 99 | parent_path = download_and_extract_corpus(Path(data_dir)) 100 | else: 101 | parent_path = pdir / "text" 102 | 103 | categories = [v.value for v in LABELS] 104 | docs = [] 105 | for category in categories: 106 | for p in (parent_path / f"{category}").glob(f"{category}*.txt"): 107 | with open(p, "r") as f: 108 | next(f) # url 109 | next(f) # date 110 | next(f) # title 111 | body = "\n".join([line.strip() for line in f if line.strip()]) 112 | docs.append((category, body)) 113 | 114 | return pd.DataFrame(docs, columns=[LABEL_COL_NAME, TEXT_COL_NAME]) 115 | 116 | 117 | class SequenceClassificationDataset(Dataset): 118 | """ 119 | Build feature dataset so that the model can load 120 | """ 121 | 122 | def __init__( 123 | self, 124 | examples: List[SequenceClassificationExample], 125 | tokenizer: PreTrainedTokenizerFast, 126 | label_to_id: Dict[str, int], 127 | tokens_per_batch: int = 32, 128 | ): 129 | self.features: List[InputFeatures] = [] 130 | self.examples: List[SequenceClassificationExample] = examples 131 | texts: StrList = [ex.text for ex in self.examples] 132 | labels: StrList = [ex.label for ex in self.examples] 133 | 134 | # tokenize text into subwords with padding and truncation 135 | self.encodings: List[BatchEncoding] = [ 136 | tokenizer.encode_plus( 137 | text, 138 | add_special_tokens=True, 139 | max_length=tokens_per_batch, 140 | return_token_type_ids=False, 141 | padding="max_length", 142 | return_attention_mask=True, 143 | return_tensors="np", 144 | truncation=True, 145 | ) 146 | for text in texts 147 | ] 148 | 149 | # register features 150 | self.features = [ 151 | InputFeatures( 152 | input_ids=encoding.input_ids.flatten().tolist(), 153 | attention_mask=encoding.attention_mask.flatten().tolist(), 154 | label_ids=[label_to_id.get(label, 0)], 155 | ) 156 | for encoding, label in zip(self.encodings, labels) 157 | ] 158 | self._n_features = len(self.features) 159 | 160 | def __len__(self): 161 | return 
self._n_features 162 | 163 | def __getitem__(self, idx) -> InputFeatures: 164 | return self.features[idx] 165 | 166 | 167 | class InputFeaturesBatch: 168 | def __init__(self, features: List[InputFeatures]): 169 | self.input_ids: torch.Tensor 170 | self.attention_masks: torch.Tensor 171 | self.label_ids: Optional[torch.Tensor] 172 | 173 | self._n_features = len(features) 174 | input_ids_list: IntListList = [] 175 | masks_list: IntListList = [] 176 | label_ids_list: IntListList = [] 177 | for f in features: 178 | input_ids_list.append(f.input_ids) 179 | masks_list.append(f.attention_mask) 180 | if f.label_ids is not None: 181 | label_ids_list.append(f.label_ids) 182 | self.input_ids = torch.LongTensor(input_ids_list) 183 | self.attention_mask = torch.LongTensor(masks_list) 184 | if label_ids_list: 185 | self.label_ids = torch.LongTensor(label_ids_list) 186 | 187 | def __len__(self): 188 | return self._n_features 189 | 190 | def __getitem__(self, item): 191 | return getattr(self, item) 192 | 193 | 194 | class SequenceClassificationDataModule(pl.LightningDataModule): 195 | """ 196 | Prepare dataset and build DataLoader 197 | """ 198 | 199 | def __init__(self, hparams: Namespace): 200 | self.tokenizer: PreTrainedTokenizerFast 201 | self.train_examples: List[SequenceClassificationExample] 202 | self.val_examples: List[SequenceClassificationExample] 203 | self.test_examples: List[SequenceClassificationExample] 204 | self.train_dataset: SequenceClassificationDataset 205 | self.val_dataset: SequenceClassificationDataset 206 | self.test_dataset: SequenceClassificationDataset 207 | self.df_org: pd.DataFrame 208 | self.df_use: pd.DataFrame 209 | self.label_to_id: Dict[str, int] 210 | 211 | super().__init__() 212 | self.max_seq_length = hparams.max_seq_length 213 | self.cache_dir = hparams.cache_dir 214 | if not os.path.exists(self.cache_dir): 215 | os.mkdir(self.cache_dir) 216 | self.data_dir = hparams.data_dir 217 | if not os.path.exists(self.data_dir): 218 | os.mkdir(self.data_dir) 219 | self.tokenizer_name = hparams.model_name_or_path 220 | self.train_batch_size = hparams.train_batch_size 221 | self.eval_batch_size = hparams.eval_batch_size 222 | self.num_workers = hparams.num_workers 223 | self.num_samples = hparams.num_samples 224 | 225 | def prepare_data(self): 226 | """ 227 | Downloads the data and prepare the tokenizer 228 | """ 229 | self.tokenizer = BertTokenizerFast.from_pretrained( 230 | self.tokenizer_name, 231 | cache_dir=self.cache_dir, 232 | tokenize_chinese_chars=False, 233 | strip_accents=False, 234 | ) 235 | df = make_livedoor_corpus_dataset(self.data_dir) 236 | self.df_org = df 237 | if self.num_samples > 0: 238 | df = df.iloc[: self.num_samples] 239 | self.df_use = df 240 | 241 | def setup(self, stage=None): 242 | """ 243 | split the data into train, test, validation data 244 | :param stage: Stage - training or testing 245 | """ 246 | 247 | df = self.df_use 248 | # label_to_id = {k: v for v, k in enumerate(LABELS)} 249 | self.label_to_id = { 250 | k: v for v, k in enumerate(sorted(set(df[LABEL_COL_NAME].values.tolist()))) 251 | } 252 | 253 | df_train, df_test = train_test_split( 254 | df, test_size=0.3, stratify=df[LABEL_COL_NAME] 255 | ) 256 | df_val, df_test = train_test_split( 257 | df_test, test_size=0.5, stratify=df_test[LABEL_COL_NAME] 258 | ) 259 | 260 | self.train_examples = [ 261 | SequenceClassificationExample(guid=f"train-{i}", text=t, label=l) 262 | for i, (t, l) in df_train[[TEXT_COL_NAME, LABEL_COL_NAME]].iterrows() 263 | ] 264 | self.val_examples = [ 265 | 
SequenceClassificationExample(guid=f"val-{i}", text=t, label=l) 266 | for i, (t, l) in df_val[[TEXT_COL_NAME, LABEL_COL_NAME]].iterrows() 267 | ] 268 | self.test_examples = [ 269 | SequenceClassificationExample(guid=f"test-{i}", text=t, label=l) 270 | for i, (t, l) in df_test[[TEXT_COL_NAME, LABEL_COL_NAME]].iterrows() 271 | ] 272 | 273 | self.train_dataset = self.create_dataset(self.train_examples) 274 | self.val_dataset = self.create_dataset(self.val_examples) 275 | self.test_dataset = self.create_dataset(self.test_examples) 276 | 277 | self.dataset_size = len(self.train_dataset) 278 | 279 | def create_dataset( 280 | self, data: List[SequenceClassificationExample] 281 | ) -> SequenceClassificationDataset: 282 | return SequenceClassificationDataset( 283 | data, 284 | self.tokenizer, 285 | self.label_to_id, 286 | self.max_seq_length, 287 | ) 288 | 289 | @staticmethod 290 | def create_dataloader( 291 | ds: SequenceClassificationDataset, 292 | batch_size: int, 293 | num_workers: int = 0, 294 | shuffle: bool = False, 295 | ) -> DataLoader: 296 | return DataLoader( 297 | ds, 298 | collate_fn=InputFeaturesBatch, 299 | batch_size=batch_size, 300 | num_workers=num_workers, 301 | pin_memory=True, 302 | shuffle=shuffle, 303 | ) 304 | 305 | def train_dataloader(self): 306 | return self.create_dataloader( 307 | self.train_dataset, self.train_batch_size, self.num_workers, shuffle=True 308 | ) 309 | 310 | def val_dataloader(self): 311 | return self.create_dataloader( 312 | self.val_dataset, self.eval_batch_size, self.num_workers, shuffle=False 313 | ) 314 | 315 | def test_dataloader(self): 316 | return self.create_dataloader( 317 | self.test_dataset, self.eval_batch_size, self.num_workers, shuffle=False 318 | ) 319 | 320 | def total_steps(self) -> int: 321 | """ 322 | The number of total training steps that will be run. Used for lr scheduler purposes. 323 | """ 324 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores 325 | effective_batch_size = ( 326 | self.hparams.train_batch_size 327 | * self.hparams.accumulate_grad_batches 328 | * num_devices 329 | ) 330 | return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs 331 | 332 | @staticmethod 333 | def add_model_specific_args(parent_parser): 334 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 335 | parser.add_argument( 336 | "--train_batch_size", 337 | type=int, 338 | default=32, 339 | help="input batch size for training (default: 32)", 340 | ) 341 | parser.add_argument( 342 | "--eval_batch_size", 343 | type=int, 344 | default=32, 345 | help="input batch size for validation/test (default: 32)", 346 | ) 347 | parser.add_argument( 348 | "--num_workers", 349 | type=int, 350 | default=4, 351 | metavar="N", 352 | help="number of workers (default: 3)", 353 | ) 354 | parser.add_argument( 355 | "--max_seq_length", 356 | default=256, 357 | type=int, 358 | help="The maximum total input sequence length after tokenization. Sequences longer " 359 | "than this will be truncated, sequences shorter will be padded.", 360 | ) 361 | parser.add_argument( 362 | "--data_dir", 363 | default="data", 364 | type=str, 365 | required=True, 366 | help="The input data dir. 
Should contain the training files for the CoNLL-2003 NER task.", 367 | ) 368 | parser.add_argument( 369 | "--num_samples", 370 | type=int, 371 | default=15000, 372 | metavar="N", 373 | help="Number of samples to be used for training and evaluation steps (default: 15000) Maximum:100000", 374 | ) 375 | return parser 376 | 377 | 378 | class SequenceClassificationModule(pl.LightningModule): 379 | """ 380 | Initialize a model and config for token-classification 381 | """ 382 | 383 | def __init__(self, hparams: Union[Dict, Namespace]): 384 | # NOTE: internal code may pass hparams as dict **kwargs 385 | if isinstance(hparams, Dict): 386 | hparams = Namespace(**hparams) 387 | 388 | num_labels = len(LABELS) 389 | 390 | super().__init__() 391 | # Enable to access arguments via self.hparams 392 | self.save_hyperparameters(hparams) 393 | 394 | self.step_count = 0 395 | self.output_dir = Path(self.hparams.output_dir) 396 | self.cache_dir = None 397 | if self.hparams.cache_dir: 398 | if not os.path.exists(self.hparams.cache_dir): 399 | os.mkdir(self.hparams.cache_dir) 400 | self.cache_dir = self.hparams.cache_dir 401 | 402 | # AutoTokenizer 403 | # trf>=4.0.0: PreTrainedTokenizerFast by default 404 | # NOTE: AutoTokenizer doesn't load PreTrainedTokenizerFast... 405 | self.tokenizer_name = self.hparams.model_name_or_path 406 | self.tokenizer = BertTokenizerFast.from_pretrained( 407 | self.tokenizer_name, 408 | cache_dir=self.cache_dir, 409 | tokenize_chinese_chars=False, 410 | strip_accents=False, 411 | ) 412 | 413 | # AutoConfig 414 | config_name = self.hparams.model_name_or_path 415 | self.config: PretrainedConfig = BertConfig.from_pretrained( 416 | config_name, 417 | **({"num_labels": num_labels} if num_labels is not None else {}), 418 | cache_dir=self.cache_dir, 419 | ) 420 | extra_model_params = ( 421 | "encoder_layerdrop", 422 | "decoder_layerdrop", 423 | "dropout", 424 | "attention_dropout", 425 | ) 426 | for p in extra_model_params: 427 | if getattr(self.hparams, p, None) and hasattr(self.config, p): 428 | setattr(self.config, p, getattr(self.hparams, p, None)) 429 | 430 | # AutoModelForSequenceClassification 431 | self.model: PreTrainedModel = BertForSequenceClassification.from_pretrained( 432 | self.hparams.model_name_or_path, 433 | from_tf=bool(".ckpt" in self.hparams.model_name_or_path), 434 | config=self.config, 435 | cache_dir=self.cache_dir, 436 | ) 437 | 438 | self.scheduler = None 439 | self.optimizer = None 440 | 441 | def forward(self, **inputs) -> SequenceClassifierOutput: 442 | """BertForSequenceClassification.forward""" 443 | return self.model(**inputs) 444 | 445 | def shared_step(self, batch: InputFeaturesBatch) -> SequenceClassifierOutput: 446 | # .to(self.device) is not necessary with pl.Traner ?? 
447 | inputs = { 448 | "input_ids": batch.input_ids.to(self.device), 449 | "attention_mask": batch.attention_mask.to(self.device), 450 | "labels": batch.label_ids.to(self.device), 451 | } 452 | return self.model(**inputs) 453 | 454 | def training_step( 455 | self, train_batch: InputFeaturesBatch, batch_idx 456 | ) -> Dict[str, torch.Tensor]: 457 | output = self.shared_step(train_batch) 458 | loss = output.loss 459 | self.log("train_loss", loss, prog_bar=True) 460 | return {"loss": loss} 461 | 462 | def validation_step( 463 | self, val_batch: InputFeaturesBatch, batch_idx 464 | ) -> Dict[str, torch.Tensor]: 465 | output = self.shared_step(val_batch) 466 | return { 467 | "val_step_loss": output.loss, 468 | } 469 | 470 | def validation_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]): 471 | avg_loss = torch.stack([x["val_step_loss"] for x in outputs]).mean() 472 | self.log("val_loss", avg_loss, sync_dist=True) 473 | 474 | def test_step( 475 | self, test_batch: InputFeaturesBatch, batch_idx 476 | ) -> Dict[str, torch.Tensor]: 477 | output = self.shared_step(test_batch) 478 | _, y_hat = torch.max(output.logits, dim=1) # values, indices 479 | test_acc = accuracy_score(y_hat.cpu(), test_batch.label_ids.detach().cpu()) 480 | return {"test_acc": torch.Tensor([test_acc])} 481 | 482 | def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]): 483 | avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean() 484 | self.log("avg_test_acc", avg_test_acc) 485 | 486 | def configure_optimizers(self): 487 | """Prepare optimizer and schedule (linear warmup and decay)""" 488 | model = self.model 489 | no_decay = ["bias", "LayerNorm.weight"] 490 | optimizer_grouped_parameters = [ 491 | { 492 | "params": [ 493 | p 494 | for n, p in model.named_parameters() 495 | if not any(nd in n for nd in no_decay) 496 | ], 497 | "weight_decay": self.hparams.weight_decay, 498 | }, 499 | { 500 | "params": [ 501 | p 502 | for n, p in model.named_parameters() 503 | if any(nd in n for nd in no_decay) 504 | ], 505 | "weight_decay": 0.0, 506 | }, 507 | ] 508 | if self.hparams.adafactor: 509 | self.optimizer = Adafactor( 510 | optimizer_grouped_parameters, 511 | lr=self.hparams.learning_rate, 512 | scale_parameter=False, 513 | relative_step=False, 514 | ) 515 | else: 516 | self.optimizer = AdamW( 517 | optimizer_grouped_parameters, 518 | lr=self.hparams.learning_rate, 519 | eps=self.hparams.adam_epsilon, 520 | ) 521 | self.scheduler = { 522 | "scheduler": ReduceLROnPlateau( 523 | self.optimizer, 524 | mode="min", 525 | factor=0.2, 526 | patience=2, 527 | min_lr=1e-6, 528 | verbose=True, 529 | ), 530 | "monitor": "val_loss", 531 | } 532 | 533 | return [self.optimizer], [self.scheduler] 534 | 535 | @pl.utilities.rank_zero_only 536 | def on_save_checkpoint(self, checkpoint: Dict[str, Any]): 537 | save_path = self.output_dir.joinpath("best_tfmr") 538 | self.model.config.save_step = self.step_count 539 | self.model.save_pretrained(save_path) 540 | self.tokenizer.save_pretrained(save_path) 541 | 542 | @staticmethod 543 | def add_model_specific_args(parent_parser): 544 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 545 | 546 | parser.add_argument( 547 | "--encoder_layerdrop", 548 | type=float, 549 | help="Encoder layer dropout probability (Optional). Goes into model.config", 550 | ) 551 | parser.add_argument( 552 | "--decoder_layerdrop", 553 | type=float, 554 | help="Decoder layer dropout probability (Optional). 
Goes into model.config", 555 | ) 556 | parser.add_argument( 557 | "--dropout", 558 | type=float, 559 | help="Dropout probability (Optional). Goes into model.config", 560 | ) 561 | parser.add_argument( 562 | "--attention_dropout", 563 | type=float, 564 | help="Attention dropout probability (Optional). Goes into model.config", 565 | ) 566 | parser.add_argument( 567 | "--weight_decay", 568 | default=0.0, 569 | type=float, 570 | help="Weight decay if we apply some.", 571 | ) 572 | parser.add_argument( 573 | "--learning_rate", 574 | default=5e-5, 575 | type=float, 576 | help="The initial learning rate for Adam.", 577 | ) 578 | parser.add_argument( 579 | "--adam_epsilon", 580 | default=1e-8, 581 | type=float, 582 | help="Epsilon for Adam optimizer.", 583 | ) 584 | parser.add_argument("--adafactor", action="store_true") 585 | return parser 586 | 587 | 588 | class LoggingCallback(pl.Callback): 589 | # def on_batch_end(self, trainer, pl_module): 590 | # lr_scheduler = trainer.lr_schedulers[0]["scheduler"] 591 | # # lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())} 592 | # # pl_module.logger.log_metrics(lrs) 593 | # pl_module.logger.log_metrics({"last_lr": lr_scheduler._last_lr}) 594 | 595 | def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 596 | rank_zero_info("***** Validation results *****") 597 | metrics = trainer.callback_metrics 598 | # Log results 599 | for key in sorted(metrics): 600 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 601 | 602 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 603 | rank_zero_info("***** Test results *****") 604 | metrics = trainer.callback_metrics 605 | # Log and save results to file 606 | output_test_results_file = os.path.join( 607 | pl_module.hparams.output_dir, "test_results.txt" 608 | ) 609 | with open(output_test_results_file, "w") as writer: 610 | for key in sorted(metrics): 611 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 612 | writer.write("{} = {}\n".format(key, str(metrics[key]))) 613 | 614 | 615 | def make_trainer(argparse_args: Namespace): 616 | """ 617 | Prepare pl.Trainer with callbacks and args 618 | """ 619 | 620 | early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=True) 621 | 622 | checkpoint_callback = ModelCheckpoint( 623 | dirpath=argparse_args.output_dir, 624 | filename="checkpoint-{epoch}-{val_loss:.2f}", 625 | save_top_k=1, 626 | verbose=True, 627 | monitor="val_loss", 628 | mode="min", 629 | ) 630 | lr_logger = LearningRateMonitor() 631 | logging_callback = LoggingCallback() 632 | 633 | train_params = {"deterministic": True} 634 | if args.gpus > 1: 635 | train_params["distributed_backend"] = "ddp" 636 | train_params["accumulate_grad_batches"] = args.accumulate_grad_batches 637 | 638 | trainer = pl.Trainer.from_argparse_args( 639 | argparse_args, 640 | callbacks=[lr_logger, early_stopping, checkpoint_callback, logging_callback], 641 | **train_params, 642 | ) 643 | return trainer, checkpoint_callback 644 | 645 | 646 | if __name__ == "__main__": 647 | 648 | parser = ArgumentParser(description="Transformers Document Classifier") 649 | 650 | parser.add_argument( 651 | "--model_name_or_path", 652 | default=None, 653 | type=str, 654 | required=True, 655 | help="Path to pretrained model or model identifier from huggingface.co/models", 656 | ) 657 | parser.add_argument( 658 | "--output_dir", 659 | default=None, 660 | type=str, 661 | required=True, 662 | help="The output directory where the model predictions and 
checkpoints will be written.", 663 | ) 664 | parser.add_argument( 665 | "--cache_dir", 666 | default="", 667 | type=str, 668 | help="Where do you want to store the pre-trained models downloaded from huggingface.co", 669 | ) 670 | parser.add_argument( 671 | "--seed", type=int, default=42, help="random seed for initialization" 672 | ) 673 | parser.add_argument( 674 | "--do_train", action="store_true", help="Whether to run training." 675 | ) 676 | parser.add_argument( 677 | "--do_predict", 678 | action="store_true", 679 | help="Whether to run predictions on the test set.", 680 | ) 681 | 682 | parser = pl.Trainer.add_argparse_args(parent_parser=parser) 683 | parser = SequenceClassificationModule.add_model_specific_args(parent_parser=parser) 684 | parser = SequenceClassificationDataModule.add_model_specific_args( 685 | parent_parser=parser 686 | ) 687 | args = parser.parse_args() 688 | 689 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED. 690 | pl.seed_everything(args.seed) 691 | 692 | Path(args.output_dir).mkdir(exist_ok=True) 693 | 694 | # Logs loss and any other metrics specified in the fit function, 695 | # and optimizer data as parameters. Model checkpoints are logged 696 | # as artifacts and pytorch model is stored under `model` directory. 697 | mlflow.pytorch.autolog(log_every_n_epoch=1) 698 | 699 | dm = SequenceClassificationDataModule(args) 700 | dm.prepare_data() 701 | dm.setup(stage="fit") 702 | 703 | model = SequenceClassificationModule(args) 704 | 705 | trainer, checkpoint_callback = make_trainer(args) 706 | 707 | trainer.fit(model, dm) 708 | 709 | if args.do_predict: 710 | # NOTE: load the best checkpoint automatically 711 | trainer.test() 712 | -------------------------------------------------------------------------------- /ner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser, Namespace 4 | from dataclasses import dataclass 5 | from enum import Enum 6 | from itertools import product, starmap 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional, Union 9 | 10 | import mlflow.pytorch 11 | import numpy as np 12 | import pytorch_lightning as pl 13 | import requests 14 | import torch 15 | from pytorch_lightning.callbacks import ( 16 | EarlyStopping, 17 | LearningRateMonitor, 18 | ModelCheckpoint, 19 | ) 20 | from pytorch_lightning.utilities import rank_zero_info 21 | from seqeval.metrics import ( 22 | accuracy_score, 23 | f1_score, 24 | precision_score, 25 | recall_score, 26 | ) 27 | from seqeval.scheme import BILOU 28 | from tokenizers import Encoding 29 | from torch.optim.lr_scheduler import ReduceLROnPlateau 30 | from torch.utils.data import DataLoader, Dataset 31 | from transformers import ( 32 | AdamW, 33 | BatchEncoding, 34 | BertConfig, 35 | BertForTokenClassification, 36 | BertTokenizerFast, 37 | PretrainedConfig, 38 | PreTrainedModel, 39 | PreTrainedTokenizerFast, 40 | ) 41 | from transformers.modeling_outputs import TokenClassifierOutput 42 | from transformers.optimization import Adafactor 43 | 44 | # huggingface/tokenizers: Disabling parallelism to avoid deadlocks. 
45 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 46 | 47 | logger = logging.getLogger(__name__) 48 | 49 | IntList = List[int] 50 | IntListList = List[IntList] 51 | StrList = List[str] 52 | StrListList = List[StrList] 53 | PAD_TOKEN_LABEL_ID = -100 54 | 55 | 56 | class Split(Enum): 57 | train = "train" 58 | dev = "dev" 59 | test = "test" 60 | 61 | 62 | @dataclass 63 | class SpanAnnotation: 64 | start: int 65 | end: int 66 | label: str 67 | 68 | 69 | @dataclass 70 | class StringSpanExample: 71 | guid: str 72 | content: str 73 | annotations: List[SpanAnnotation] 74 | 75 | 76 | @dataclass 77 | class TokenClassificationExample: 78 | guid: str 79 | words: StrList 80 | labels: StrList 81 | 82 | 83 | @dataclass 84 | class InputFeatures: 85 | input_ids: IntList 86 | attention_mask: IntList 87 | label_ids: IntList 88 | 89 | 90 | def download_dataset(data_dir: Union[str, Path]): 91 | def _download_data(url, file_path): 92 | response = requests.get(url) 93 | if response.ok: 94 | with open(file_path, "w") as fp: 95 | fp.write(response.content.decode("utf8")) 96 | return file_path 97 | 98 | for mode in Split: 99 | mode = mode.value 100 | url = f"https://github.com/megagonlabs/UD_Japanese-GSD/releases/download/v2.6-NE/{mode}.bio" 101 | file_path = os.path.join(data_dir, f"{mode}.txt") 102 | if _download_data(url, file_path): 103 | logger.info(f"{mode} data is successfully downloaded") 104 | 105 | 106 | def is_boundary_line(line: str) -> bool: 107 | return line.startswith("-DOCSTART-") or line == "" or line == "\n" 108 | 109 | 110 | def bio2biolu(lines: StrList, label_idx: int = -1, delimiter: str = "\t") -> StrList: 111 | new_lines = [] 112 | n_lines = len(lines) 113 | for i, line in enumerate(lines): 114 | if is_boundary_line(line): 115 | new_lines.append(line) 116 | else: 117 | next_iob = None 118 | if i < n_lines - 1: 119 | next_line = lines[i + 1].strip() 120 | if not is_boundary_line(next_line): 121 | next_iob = next_line.split(delimiter)[label_idx][0] 122 | 123 | line = line.strip() 124 | current_line_content = line.split(delimiter) 125 | current_label = current_line_content[label_idx] 126 | word = current_line_content[0] 127 | tag_type = current_label[2:] 128 | iob = current_label[0] 129 | 130 | iob_transition = (iob, next_iob) 131 | current_iob = iob 132 | if iob_transition == ("B", "I"): 133 | current_iob = "B" 134 | elif iob_transition == ("I", "I"): 135 | current_iob = "I" 136 | elif iob_transition in {("B", "O"), ("B", "B"), ("B", None)}: 137 | current_iob = "U" 138 | elif iob_transition in {("I", "B"), ("I", "O"), ("I", None)}: 139 | current_iob = "L" 140 | elif iob == "O": 141 | current_iob = "O" 142 | else: 143 | logger.warning(f"Invalid BIO transition: {iob_transition}") 144 | if iob not in set("BIOLU"): 145 | current_iob = "O" 146 | biolu = f"{current_iob}-{tag_type}" if current_iob != "O" else "O" 147 | new_line = f"{word}{delimiter}{biolu}" 148 | new_lines.append(new_line) 149 | return new_lines 150 | 151 | 152 | def read_examples_from_file( 153 | data_dir: str, 154 | mode: Union[Split, str], 155 | label_idx: int = -1, 156 | delimiter: str = "\t", 157 | is_bio: bool = True, 158 | ) -> List[TokenClassificationExample]: 159 | """ 160 | Read token-wise data like CoNLL2003 from file 161 | """ 162 | if isinstance(mode, Split): 163 | mode = mode.value 164 | file_path = os.path.join(data_dir, f"{mode}.txt") 165 | guid_index = 1 166 | examples = [] 167 | with open(file_path, encoding="utf-8") as f: 168 | lines = [line for line in f] 169 | if is_bio: 170 | lines = bio2biolu(lines) 171 | 
words = [] 172 | labels = [] 173 | for line in lines: 174 | if is_boundary_line(line): 175 | if words: 176 | examples.append( 177 | TokenClassificationExample( 178 | guid=f"{mode}-{guid_index}", words=words, labels=labels 179 | ) 180 | ) 181 | guid_index += 1 182 | words = [] 183 | labels = [] 184 | else: 185 | splits = line.strip().split(delimiter) 186 | words.append(splits[0]) 187 | if len(splits) > 1: 188 | labels.append(splits[label_idx]) 189 | else: 190 | # for mode = "test" 191 | labels.append("O") 192 | if words: 193 | examples.append( 194 | TokenClassificationExample( 195 | guid=f"{mode}-{guid_index}", words=words, labels=labels 196 | ) 197 | ) 198 | return examples 199 | 200 | 201 | def convert_spandata( 202 | examples: List[TokenClassificationExample], 203 | ) -> List[StringSpanExample]: 204 | """ 205 | Convert token-wise data like CoNLL2003 into string-wise span data 206 | """ 207 | 208 | def _get_original_spans(words, text): 209 | word_spans = [] 210 | start = 0 211 | for w in words: 212 | word_spans.append((start, start + len(w))) 213 | start += len(w) 214 | assert words == [text[s:e] for s, e in word_spans] 215 | return word_spans 216 | 217 | new_examples: List[StringSpanExample] = [] 218 | for example in examples: 219 | words = example.words 220 | text = "".join(words) 221 | labels = example.labels 222 | annotations: List[SpanAnnotation] = [] 223 | 224 | word_spans = _get_original_spans(words, text) 225 | label_span = [] 226 | labeltype = "" 227 | for span, label in zip(word_spans, labels): 228 | if label == "O" and label_span and labeltype: 229 | start, end = label_span[0][0], label_span[-1][-1] 230 | annotations.append( 231 | SpanAnnotation(start=start, end=end, label=labeltype) 232 | ) 233 | label_span = [] 234 | elif label != "O": 235 | labeltype = label[2:] 236 | label_span.append(span) 237 | if label_span and labeltype: 238 | start, end = label_span[0][0], label_span[-1][-1] 239 | annotations.append(SpanAnnotation(start=start, end=end, label=labeltype)) 240 | 241 | new_examples.append( 242 | StringSpanExample(guid=example.guid, content=text, annotations=annotations) 243 | ) 244 | return new_examples 245 | 246 | 247 | class LabelTokenAligner: 248 | """ 249 | Align word-wise BIOLU-labels with subword tokens 250 | """ 251 | 252 | def __init__(self, labels_path: str): 253 | with open(labels_path, "r") as f: 254 | labels = [l for l in f.read().splitlines() if l and l != "O"] 255 | 256 | self.labels_to_id = {"O": 0} 257 | self.ids_to_label = {0: "O"} 258 | for i, (label, s) in enumerate(product(labels, "BILU"), 1): 259 | l = f"{s}-{label}" 260 | self.labels_to_id[l] = i 261 | self.ids_to_label[i] = l 262 | 263 | @staticmethod 264 | def get_ids_to_label(labels_path: str) -> Dict[int, str]: 265 | with open(labels_path, "r") as f: 266 | labels = [l for l in f.read().splitlines() if l and l != "O"] 267 | ids_to_label = { 268 | i: f"{s}-{label}" for i, (label, s) in enumerate(product(labels, "BILU"), 1) 269 | } 270 | ids_to_label[0] = "O" 271 | return ids_to_label 272 | 273 | @staticmethod 274 | def align_tokens_and_annotations_bilou( 275 | tokenized: Encoding, annotations: List[SpanAnnotation] 276 | ) -> StrList: 277 | """Make word-wise BIOLU-labels aligned with given subwords 278 | :param tokenized: output of PreTrainedTokenizerFast 279 | :param annotations: annotations of string span format 280 | """ 281 | aligned_labels = ["O"] * len( 282 | tokenized.tokens 283 | ) # Make a list to store our labels the same length as our tokens 284 | for anno in annotations: 285 | 
annotation_token_ix_set = set() 286 | for char_ix in range(anno.start, anno.end): 287 | token_ix = tokenized.char_to_token(char_ix) 288 | if token_ix is not None: 289 | annotation_token_ix_set.add(token_ix) 290 | if len(annotation_token_ix_set) == 1: 291 | token_ix = annotation_token_ix_set.pop() 292 | prefix = "U" 293 | aligned_labels[token_ix] = f"{prefix}-{anno.label}" 294 | else: 295 | last_token_in_anno_ix = len(annotation_token_ix_set) - 1 296 | for num, token_ix in enumerate(sorted(annotation_token_ix_set)): 297 | if num == 0: 298 | prefix = "B" 299 | elif num == last_token_in_anno_ix: 300 | prefix = "L" 301 | else: 302 | prefix = "I" 303 | aligned_labels[token_ix] = f"{prefix}-{anno.label}" 304 | return aligned_labels 305 | 306 | def align_labels_with_tokens( 307 | self, tokenized_text: Encoding, annotations: List[SpanAnnotation] 308 | ) -> IntList: 309 | # TODO: switch label encoding scheme, align_tokens_and_annotations_bio 310 | raw_labels = self.align_tokens_and_annotations_bilou( 311 | tokenized_text, annotations 312 | ) 313 | return list(map(lambda x: self.labels_to_id.get(x, 0), raw_labels)) 314 | 315 | 316 | class TokenClassificationDataset(Dataset): 317 | """ 318 | Build feature dataset so that the model can load 319 | """ 320 | 321 | def __init__( 322 | self, 323 | examples: List[StringSpanExample], 324 | tokenizer: PreTrainedTokenizerFast, 325 | label_token_aligner: LabelTokenAligner, 326 | tokens_per_batch: int = 32, 327 | window_stride: Optional[int] = None, 328 | ): 329 | """tokenize_and_align_labels with long text (i.e. truncation is disabled)""" 330 | self.features: List[InputFeatures] = [] 331 | self.examples: List[TokenClassificationExample] = [] 332 | texts: StrList = [ex.content for ex in examples] 333 | annotations: List[List[SpanAnnotation]] = [ex.annotations for ex in examples] 334 | 335 | if window_stride is None: 336 | self.window_stride = tokens_per_batch 337 | elif window_stride > tokens_per_batch: 338 | logger.error( 339 | "window_stride must be smaller than tokens_per_batch(max_seq_length)" 340 | ) 341 | else: 342 | logger.warning( 343 | """window_stride != tokens_per_batch: 344 | The input data windows are overlapping. Merge the overlapping labels after processing InputFeatures. 
345 | """ 346 | ) 347 | 348 | # tokenize text into subwords 349 | # NOTE: add_special_tokens 350 | tokenized_batch: BatchEncoding = tokenizer(texts, add_special_tokens=False) 351 | encodings: List[Encoding] = tokenized_batch.encodings 352 | 353 | # align word-wise labels with subwords 354 | aligned_label_ids: IntListList = list( 355 | starmap( 356 | label_token_aligner.align_labels_with_tokens, 357 | zip(encodings, annotations), 358 | ) 359 | ) 360 | 361 | # perform manual padding and register features 362 | guids: StrList = [ex.guid for ex in examples] 363 | for guid, encoding, label_ids in zip(guids, encodings, aligned_label_ids): 364 | seq_length = len(label_ids) 365 | for start in range(0, seq_length, self.window_stride): 366 | end = min(start + tokens_per_batch, seq_length) 367 | n_padding_to_add = max(0, tokens_per_batch - end + start) 368 | self.features.append( 369 | InputFeatures( 370 | input_ids=encoding.ids[start:end] 371 | + [tokenizer.pad_token_id] * n_padding_to_add, 372 | label_ids=( 373 | label_ids[start:end] 374 | + [PAD_TOKEN_LABEL_ID] * n_padding_to_add 375 | ), 376 | attention_mask=( 377 | encoding.attention_mask[start:end] + [0] * n_padding_to_add 378 | ), 379 | ) 380 | ) 381 | subwords = encoding.tokens[start:end] 382 | labels = [ 383 | label_token_aligner.ids_to_label[i] for i in label_ids[start:end] 384 | ] 385 | self.examples.append( 386 | TokenClassificationExample(guid=guid, words=subwords, labels=labels) 387 | ) 388 | self._n_features = len(self.features) 389 | 390 | def __len__(self): 391 | return self._n_features 392 | 393 | def __getitem__(self, idx) -> InputFeatures: 394 | return self.features[idx] 395 | 396 | 397 | class InputFeaturesBatch: 398 | def __init__(self, features: List[InputFeatures]): 399 | self.input_ids: torch.Tensor 400 | self.attention_masks: torch.Tensor 401 | self.label_ids: Optional[torch.Tensor] 402 | 403 | self._n_features = len(features) 404 | input_ids_list: IntListList = [] 405 | masks_list: IntListList = [] 406 | label_ids_list: IntListList = [] 407 | for f in features: 408 | input_ids_list.append(f.input_ids) 409 | masks_list.append(f.attention_mask) 410 | if f.label_ids is not None: 411 | label_ids_list.append(f.label_ids) 412 | self.input_ids = torch.LongTensor(input_ids_list) 413 | self.attention_mask = torch.LongTensor(masks_list) 414 | if label_ids_list: 415 | self.label_ids = torch.LongTensor(label_ids_list) 416 | 417 | def __len__(self): 418 | return self._n_features 419 | 420 | def __getitem__(self, item): 421 | return getattr(self, item) 422 | 423 | 424 | class TokenClassificationDataModule(pl.LightningDataModule): 425 | """ 426 | Prepare dataset and build DataLoader 427 | """ 428 | 429 | def __init__(self, hparams: Namespace): 430 | self.tokenizer: PreTrainedTokenizerFast 431 | self.train_examples: List[TokenClassificationExample] 432 | self.val_examples: List[TokenClassificationExample] 433 | self.test_examples: List[TokenClassificationExample] 434 | self.train_data: List[StringSpanExample] 435 | self.val_data: List[StringSpanExample] 436 | self.test_data: List[StringSpanExample] 437 | self.train_dataset: TokenClassificationDataset 438 | self.val_dataset: TokenClassificationDataset 439 | self.test_dataset: TokenClassificationDataset 440 | 441 | super().__init__() 442 | self.max_seq_length = hparams.max_seq_length 443 | self.cache_dir = hparams.cache_dir if hparams.cache_dir else None 444 | if self.cache_dir is not None and not os.path.exists(self.cache_dir): 445 | os.mkdir(self.cache_dir) 446 | self.data_dir = 
hparams.data_dir 447 | if not os.path.exists(self.data_dir): 448 | os.mkdir(self.data_dir) 449 | self.tokenizer_name = hparams.model_name_or_path 450 | self.train_batch_size = hparams.train_batch_size 451 | self.eval_batch_size = hparams.eval_batch_size 452 | self.num_workers = hparams.num_workers 453 | self.num_samples = hparams.num_samples 454 | self.labels_path = hparams.labels 455 | 456 | def prepare_data(self): 457 | """ 458 | Downloads the data and prepare the tokenizer 459 | """ 460 | self.tokenizer = BertTokenizerFast.from_pretrained( 461 | self.tokenizer_name, 462 | cache_dir=self.cache_dir, 463 | tokenize_chinese_chars=False, 464 | strip_accents=False, 465 | ) 466 | data_dir = Path(self.data_dir) 467 | if ( 468 | not (data_dir / f"{Split.train.value}.txt").exists() 469 | or not (data_dir / f"{Split.dev.value}.txt").exists() 470 | or not (data_dir / f"{Split.test.value}.txt").exists() 471 | ): 472 | download_dataset(self.data_dir) 473 | self.train_examples = read_examples_from_file(self.data_dir, Split.train) 474 | self.val_examples = read_examples_from_file(self.data_dir, Split.dev) 475 | self.test_examples = read_examples_from_file(self.data_dir, Split.test) 476 | if self.num_samples > 0: 477 | self.train_examples = self.train_examples[: self.num_samples] 478 | self.val_examples = self.val_examples[: self.num_samples] 479 | self.test_examples = self.test_examples[: self.num_samples] 480 | self.train_spandata = convert_spandata(self.train_examples) 481 | self.val_spandata = convert_spandata(self.val_examples) 482 | self.test_spandata = convert_spandata(self.test_examples) 483 | 484 | if not os.path.exists(self.labels_path): 485 | all_labels = { 486 | l 487 | for ex in self.train_examples + self.val_examples + self.test_examples 488 | for l in ex.labels 489 | } 490 | label_types = sorted({l[2:] for l in sorted(all_labels) if l != "O"}) 491 | with open(self.labels_path, "w") as fp: 492 | fp.write("\n".join(label_types)) 493 | self.label_token_aligner = LabelTokenAligner(self.labels_path) 494 | 495 | self.train_dataset = self.create_dataset(self.train_spandata) 496 | self.val_dataset = self.create_dataset(self.val_spandata) 497 | self.test_dataset = self.create_dataset(self.test_spandata) 498 | 499 | self.dataset_size = len(self.train_dataset) 500 | 501 | def setup(self, stage=None): 502 | """ 503 | split the data into train, test, validation data 504 | :param stage: Stage - training or testing 505 | """ 506 | # our dataset is splitted in prior 507 | 508 | def create_dataset( 509 | self, data: List[StringSpanExample] 510 | ) -> TokenClassificationDataset: 511 | return TokenClassificationDataset( 512 | data, 513 | self.tokenizer, 514 | self.label_token_aligner, 515 | self.max_seq_length, 516 | ) 517 | 518 | @staticmethod 519 | def create_dataloader( 520 | ds: TokenClassificationDataset, 521 | batch_size: int, 522 | num_workers: int = 0, 523 | shuffle: bool = False, 524 | ) -> DataLoader: 525 | return DataLoader( 526 | ds, 527 | collate_fn=InputFeaturesBatch, 528 | batch_size=batch_size, 529 | num_workers=num_workers, 530 | pin_memory=True, 531 | shuffle=shuffle, 532 | ) 533 | 534 | def train_dataloader(self): 535 | return self.create_dataloader( 536 | self.train_dataset, self.train_batch_size, self.num_workers, shuffle=True 537 | ) 538 | 539 | def val_dataloader(self): 540 | return self.create_dataloader( 541 | self.val_dataset, self.eval_batch_size, self.num_workers, shuffle=False 542 | ) 543 | 544 | def test_dataloader(self): 545 | return self.create_dataloader( 546 | 
self.test_dataset, self.eval_batch_size, self.num_workers, shuffle=False 547 | ) 548 | 549 | def total_steps(self) -> int: 550 | """ 551 | The number of total training steps that will be run. Used for lr scheduler purposes. 552 | """ 553 | num_devices = max(1, self.hparams.gpus) # TODO: consider num_tpu_cores 554 | effective_batch_size = ( 555 | self.hparams.train_batch_size 556 | * self.hparams.accumulate_grad_batches 557 | * num_devices 558 | ) 559 | return (self.dataset_size / effective_batch_size) * self.hparams.max_epochs 560 | 561 | @staticmethod 562 | def add_model_specific_args(parent_parser): 563 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 564 | parser.add_argument( 565 | "--train_batch_size", 566 | type=int, 567 | default=32, 568 | help="input batch size for training (default: 32)", 569 | ) 570 | parser.add_argument( 571 | "--eval_batch_size", 572 | type=int, 573 | default=32, 574 | help="input batch size for validation/test (default: 32)", 575 | ) 576 | parser.add_argument( 577 | "--num_workers", 578 | type=int, 579 | default=4, 580 | metavar="N", 581 | help="number of workers (default: 3)", 582 | ) 583 | parser.add_argument( 584 | "--max_seq_length", 585 | default=256, 586 | type=int, 587 | help="The maximum total input sequence length after tokenization. Sequences longer " 588 | "than this will be truncated, sequences shorter will be padded.", 589 | ) 590 | parser.add_argument( 591 | "--labels", 592 | default="", 593 | type=str, 594 | help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.", 595 | ) 596 | parser.add_argument( 597 | "--data_dir", 598 | default="data", 599 | type=str, 600 | required=True, 601 | help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", 602 | ) 603 | parser.add_argument( 604 | "--num_samples", 605 | type=int, 606 | default=15000, 607 | metavar="N", 608 | help="Number of samples to be used for training and evaluation steps (default: 15000) Maximum:100000", 609 | ) 610 | return parser 611 | 612 | 613 | class TokenClassificationModule(pl.LightningModule): 614 | """ 615 | Initialize a model and config for token-classification 616 | """ 617 | 618 | def __init__(self, hparams: Union[Dict, Namespace]): 619 | # NOTE: internal code may pass hparams as dict **kwargs 620 | if isinstance(hparams, Dict): 621 | hparams = Namespace(**hparams) 622 | 623 | self.label_ids_to_label = LabelTokenAligner.get_ids_to_label(hparams.labels) 624 | num_labels = len(self.label_ids_to_label) 625 | 626 | super().__init__() 627 | # Enable to access arguments via self.hparams 628 | self.save_hyperparameters(hparams) 629 | 630 | self.step_count = 0 631 | self.output_dir = Path(self.hparams.output_dir) 632 | self.cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None 633 | if self.cache_dir is not None and not os.path.exists(self.hparams.cache_dir): 634 | os.mkdir(self.cache_dir) 635 | 636 | # AutoTokenizer 637 | # trf>=4.0.0: PreTrainedTokenizerFast by default 638 | # NOTE: AutoTokenizer doesn't load PreTrainedTokenizerFast... 
639 | self.tokenizer_name = self.hparams.model_name_or_path 640 | self.tokenizer = BertTokenizerFast.from_pretrained( 641 | self.tokenizer_name, 642 | cache_dir=self.cache_dir, 643 | tokenize_chinese_chars=False, 644 | strip_accents=False, 645 | ) 646 | 647 | # AutoConfig 648 | config_name = self.hparams.model_name_or_path 649 | self.config: PretrainedConfig = BertConfig.from_pretrained( 650 | config_name, 651 | **({"num_labels": num_labels} if num_labels is not None else {}), 652 | cache_dir=self.cache_dir, 653 | ) 654 | extra_model_params = ( 655 | "encoder_layerdrop", 656 | "decoder_layerdrop", 657 | "dropout", 658 | "attention_dropout", 659 | ) 660 | for p in extra_model_params: 661 | if getattr(self.hparams, p, None) and hasattr(self.config, p): 662 | setattr(self.config, p, getattr(self.hparams, p, None)) 663 | 664 | # AutoModelForTokenClassification 665 | self.model: PreTrainedModel = BertForTokenClassification.from_pretrained( 666 | self.hparams.model_name_or_path, 667 | from_tf=bool(".ckpt" in self.hparams.model_name_or_path), 668 | config=self.config, 669 | cache_dir=self.cache_dir, 670 | ) 671 | 672 | self.scheduler = None 673 | self.optimizer = None 674 | 675 | def forward(self, **inputs) -> TokenClassifierOutput: 676 | """BertForTokenClassification.forward""" 677 | return self.model(**inputs) 678 | 679 | def shared_step(self, batch: InputFeaturesBatch) -> TokenClassifierOutput: 680 | # .to(self.device) is not necessary with pl.Traner ?? 681 | inputs = { 682 | "input_ids": batch.input_ids.to(self.device), 683 | "attention_mask": batch.attention_mask.to(self.device), 684 | "labels": batch.label_ids.to(self.device), 685 | } 686 | return self.model(**inputs) 687 | 688 | def training_step( 689 | self, train_batch: InputFeaturesBatch, batch_idx 690 | ) -> Dict[str, torch.Tensor]: 691 | output = self.shared_step(train_batch) 692 | loss = output.loss 693 | self.log("train_loss", loss, prog_bar=True) 694 | return {"loss": loss} 695 | 696 | def validation_step( 697 | self, val_batch: InputFeaturesBatch, batch_idx 698 | ) -> Dict[str, torch.Tensor]: 699 | output = self.shared_step(val_batch) 700 | return { 701 | "val_step_loss": output.loss, 702 | } 703 | 704 | def validation_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]): 705 | avg_loss = torch.stack([x["val_step_loss"] for x in outputs]).mean() 706 | self.log("val_loss", avg_loss, sync_dist=True) 707 | 708 | def test_step( 709 | self, test_batch: InputFeaturesBatch, batch_idx 710 | ) -> Dict[str, torch.Tensor]: 711 | output = self.shared_step(test_batch) 712 | return {"pred": output.logits, "target": test_batch.label_ids} 713 | 714 | def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]): 715 | preds = np.concatenate( 716 | [x["pred"].detach().cpu().numpy() for x in outputs], axis=0 717 | ) 718 | preds = np.argmax(preds, axis=2) 719 | target_ids = np.concatenate( 720 | [x["target"].detach().cpu().numpy() for x in outputs], axis=0 721 | ) 722 | 723 | target_list: StrListList = [[] for _ in range(target_ids.shape[0])] 724 | preds_list: StrListList = [[] for _ in range(target_ids.shape[0])] 725 | for i in range(target_ids.shape[0]): 726 | for j in range(target_ids.shape[1]): 727 | if target_ids[i][j] != PAD_TOKEN_LABEL_ID: 728 | target_list[i].append(self.label_ids_to_label[target_ids[i][j]]) 729 | preds_list[i].append(self.label_ids_to_label[preds[i][j]]) 730 | 731 | accuracy = accuracy_score(target_list, preds_list) 732 | precision = precision_score( 733 | target_list, preds_list, mode="strict", scheme=BILOU 734 | ) 
735 |         recall = recall_score(target_list, preds_list, mode="strict", scheme=BILOU)
736 |         f1 = f1_score(target_list, preds_list, mode="strict", scheme=BILOU)
737 |         self.log("test_accuracy", accuracy)
738 |         self.log("test_precision", precision)
739 |         self.log("test_recall", recall)
740 |         self.log("test_f1", f1)
741 | 
742 |     def configure_optimizers(self):
743 |         """Prepare the optimizer and lr schedule (ReduceLROnPlateau on val_loss)"""
744 |         model = self.model
745 |         no_decay = ["bias", "LayerNorm.weight"]
746 |         optimizer_grouped_parameters = [
747 |             {
748 |                 "params": [
749 |                     p
750 |                     for n, p in model.named_parameters()
751 |                     if not any(nd in n for nd in no_decay)
752 |                 ],
753 |                 "weight_decay": self.hparams.weight_decay,
754 |             },
755 |             {
756 |                 "params": [
757 |                     p
758 |                     for n, p in model.named_parameters()
759 |                     if any(nd in n for nd in no_decay)
760 |                 ],
761 |                 "weight_decay": 0.0,
762 |             },
763 |         ]
764 |         if self.hparams.adafactor:
765 |             self.optimizer = Adafactor(
766 |                 optimizer_grouped_parameters,
767 |                 lr=self.hparams.learning_rate,
768 |                 scale_parameter=False,
769 |                 relative_step=False,
770 |             )
771 |         else:
772 |             self.optimizer = AdamW(
773 |                 optimizer_grouped_parameters,
774 |                 lr=self.hparams.learning_rate,
775 |                 eps=self.hparams.adam_epsilon,
776 |             )
777 |         return {
778 |             "optimizer": self.optimizer,
779 |             "lr_scheduler": ReduceLROnPlateau(
780 |                 self.optimizer,
781 |                 mode="min",
782 |                 factor=self.hparams.anneal_factor,
783 |                 patience=self.hparams.patience,
784 |                 min_lr=1e-6,
785 |                 verbose=True,
786 |             ),
787 |             "monitor": "val_loss",
788 |         }
789 | 
790 |     @pl.utilities.rank_zero_only
791 |     def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
792 |         save_path = self.output_dir.joinpath("best_tfmr")
793 |         self.model.config.save_step = self.step_count
794 |         self.model.save_pretrained(save_path)
795 |         self.tokenizer.save_pretrained(save_path)
796 | 
797 |     @staticmethod
798 |     def add_model_specific_args(parent_parser):
799 |         parser = ArgumentParser(parents=[parent_parser], add_help=False)
800 | 
801 |         parser.add_argument(
802 |             "--encoder_layerdrop",
803 |             type=float,
804 |             help="Encoder layer dropout probability (Optional). Goes into model.config",
805 |         )
806 |         parser.add_argument(
807 |             "--decoder_layerdrop",
808 |             type=float,
809 |             help="Decoder layer dropout probability (Optional). Goes into model.config",
810 |         )
811 |         parser.add_argument(
812 |             "--dropout",
813 |             type=float,
814 |             help="Dropout probability (Optional). Goes into model.config",
815 |         )
816 |         parser.add_argument(
817 |             "--attention_dropout",
818 |             type=float,
819 |             help="Attention dropout probability (Optional). Goes into model.config",
820 |         )
821 |         parser.add_argument(
822 |             "--weight_decay",
823 |             default=0.0,
824 |             type=float,
825 |             help="Weight decay to apply (if any).",
826 |         )
827 |         parser.add_argument(
828 |             "--learning_rate",
829 |             default=5e-5,
830 |             type=float,
831 |             help="The initial learning rate for Adam.",
832 |         )
833 |         parser.add_argument(
834 |             "--adam_epsilon",
835 |             default=1e-8,
836 |             type=float,
837 |             help="Epsilon for Adam optimizer.",
838 |         )
839 |         parser.add_argument("--adafactor", action="store_true", help="Use Adafactor instead of AdamW.")
840 |         parser.add_argument(
841 |             "--patience",
842 |             default=3,
843 |             type=int,
844 |             help="Number of epochs with no improvement after which learning rate will be reduced.",
845 |         )
846 |         parser.add_argument(
847 |             "--anneal_factor",
848 |             default=0.5,
849 |             type=float,
850 |             help="Factor by which the learning rate will be reduced.",
851 |         )
852 |         return parser
853 | 
854 | 
855 | class LoggingCallback(pl.Callback):
856 |     # def on_batch_end(self, trainer, pl_module):
857 |     #     lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
858 |     #     # lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())}
859 |     #     # pl_module.logger.log_metrics(lrs)
860 |     #     pl_module.logger.log_metrics({"last_lr": lr_scheduler._last_lr})
861 | 
862 |     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
863 |         rank_zero_info("***** Validation results *****")
864 |         metrics = trainer.callback_metrics
865 |         # Log results
866 |         for key in sorted(metrics):
867 |             rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
868 | 
869 |     def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
870 |         rank_zero_info("***** Test results *****")
871 |         metrics = trainer.callback_metrics
872 |         # Log and save results to file
873 |         output_test_results_file = os.path.join(
874 |             pl_module.hparams.output_dir, "test_results.txt"
875 |         )
876 |         with open(output_test_results_file, "w") as writer:
877 |             for key in sorted(metrics):
878 |                 rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
879 |                 writer.write("{} = {}\n".format(key, str(metrics[key])))
880 | 
881 | 
882 | def make_trainer(argparse_args: Namespace):
883 |     """
884 |     Prepare pl.Trainer with callbacks and args
885 |     """
886 | 
887 |     early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=True)
888 | 
889 |     checkpoint_callback = ModelCheckpoint(
890 |         dirpath=argparse_args.output_dir,
891 |         filename="checkpoint-{epoch}-{val_loss:.2f}",
892 |         save_top_k=1,
893 |         verbose=True,
894 |         monitor="val_loss",
895 |         mode="min",
896 |     )
897 |     lr_logger = LearningRateMonitor()
898 |     logging_callback = LoggingCallback()
899 | 
900 |     train_params = {"deterministic": True}
901 |     if argparse_args.gpus > 1:
902 |         train_params["distributed_backend"] = "ddp"
903 |     train_params["accumulate_grad_batches"] = argparse_args.accumulate_grad_batches
904 | 
905 |     trainer = pl.Trainer.from_argparse_args(
906 |         argparse_args,
907 |         callbacks=[lr_logger, early_stopping, checkpoint_callback, logging_callback],
908 |         **train_params,
909 |     )
910 |     return trainer, checkpoint_callback
911 | 
912 | 
913 | if __name__ == "__main__":
914 | 
915 |     parser = ArgumentParser(description="Transformers Token Classifier")
916 | 
917 |     parser.add_argument(
918 |         "--model_name_or_path",
919 |         default=None,
920 |         type=str,
921 |         required=True,
922 |         help="Path to pretrained model or model identifier from huggingface.co/models",
923 |     )
924 |     parser.add_argument(
925 |         "--output_dir",
926 |         default=None,
927 |         type=str,
928 |         required=True,
929 |         help="The output directory where the model predictions and checkpoints will be written.",
930 |     )
931 |     parser.add_argument(
932 |         "--cache_dir",
933 |         default="",
934 |         type=str,
935 |         help="Directory in which to cache the pre-trained models downloaded from huggingface.co",
936 |     )
937 |     parser.add_argument(
938 |         "--seed", type=int, default=42, help="random seed for initialization"
939 |     )
940 |     parser.add_argument(
941 |         "--do_train", action="store_true", help="Whether to run training."
942 |     )
943 |     parser.add_argument(
944 |         "--do_predict",
945 |         action="store_true",
946 |         help="Whether to run predictions on the test set.",
947 |     )
948 | 
949 |     parser = pl.Trainer.add_argparse_args(parent_parser=parser)
950 |     parser = TokenClassificationModule.add_model_specific_args(parent_parser=parser)
951 |     parser = TokenClassificationDataModule.add_model_specific_args(parent_parser=parser)
952 |     args = parser.parse_args()
953 | 
954 |     # Sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
955 |     pl.seed_everything(args.seed)
956 | 
957 |     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
958 | 
959 |     # Logs loss and any other metrics specified in the fit function,
960 |     # and optimizer data as parameters. Model checkpoints are logged
961 |     # as artifacts, and the pytorch model is stored under the `model` directory.
962 |     mlflow.pytorch.autolog(log_every_n_epoch=1)
963 | 
964 |     dm = TokenClassificationDataModule(args)
965 |     dm.prepare_data()
966 |     dm.setup(stage="fit")
967 |     # The DataModule must be set up first, because it generates args.labels automatically
968 |     model = TokenClassificationModule(args)
969 | 
970 |     trainer, checkpoint_callback = make_trainer(args)
971 |     if args.do_train:
972 |         trainer.fit(model, dm)
973 | 
974 |     if args.do_predict:
975 |         # NOTE: trainer.test() loads the best checkpoint automatically
976 |         trainer.test()
977 | 
--------------------------------------------------------------------------------
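Usage sketch: assuming a training run has exported the fine-tuned model to `<output_dir>/best_tfmr` (the directory written by `on_save_checkpoint`), it can be reloaded for inference roughly as below. The directory path and the sample sentence are placeholders, and `config.id2label` may only contain generic `LABEL_n` names unless a label map was stored in the config before saving.

# Minimal inference sketch (assumes transformers>=4.0 and torch are installed).
import torch
from transformers import BertForTokenClassification, BertTokenizerFast

model_dir = "workspace/outputs/best_tfmr"  # placeholder: the directory written by on_save_checkpoint
tokenizer = BertTokenizerFast.from_pretrained(model_dir)
model = BertForTokenClassification.from_pretrained(model_dir)
model.eval()

text = "メロスは激怒した。"  # placeholder Japanese sentence
encoded = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**encoded).logits  # shape: (1, seq_len, num_labels)
pred_ids = logits.argmax(dim=-1)[0].tolist()

tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())
for token, pred_id in zip(tokens, pred_ids):
    # id2label may be generic (LABEL_0, ...) unless the label map was saved into the config
    print(token, model.config.id2label[pred_id])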