├── __init__.py ├── callback ├── __init__.py ├── optimizater │ ├── __init__.py │ ├── planradam.py │ ├── novograd.py │ ├── sgdw.py │ ├── lars.py │ ├── radam.py │ ├── nadam.py │ ├── adamw.py │ ├── ralamb.py │ ├── lookahead.py │ ├── lamb.py │ ├── ralars.py │ ├── adabound.py │ └── adafactor.py ├── trainingmonitor.py ├── progressbar.py ├── modelcheckpoint.py ├── adversarial.py └── lr_scheduler.py ├── losses ├── __init__.py ├── focal_loss.py └── label_smoothing.py ├── models ├── __init__.py ├── layers │ ├── __init__.py │ ├── linears.py │ └── crf.py └── bert_for_ner.py ├── tools ├── __init__.py ├── convert_albert_tf_checkpoint_to_pytorch.py ├── plot.py ├── download_clue_data.py ├── finetuning_argparse.py └── common.py ├── metrics ├── __init__.py └── ner_metrics.py ├── processors ├── __init__.py ├── utils_ner.py ├── ner_seq.py └── ner_span.py ├── datasets ├── cluener │ └── __init__.py └── cner │ └── .gitignore ├── outputs └── cner_output │ ├── bert │ ├── __init__.py │ └── .gitignore │ └── .gitignore ├── .idea ├── .gitignore ├── misc.xml ├── inspectionProfiles │ ├── profiles_settings.xml │ └── Project_Default.xml ├── vcs.xml ├── modules.xml └── BERT-NER-Pytorch.iml ├── scripts ├── run_ner_softmax.sh ├── run_ner_crf.sh └── run_ner_span.sh ├── LICENSE ├── .gitignore └── README.md /__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /callback/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /callback/optimizater/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /processors/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /datasets/cluener/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /outputs/cner_output/bert/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.idea/.gitignore: 
-------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/BERT-NER-Pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FocalLoss(nn.Module): 6 | '''Multi-class Focal loss implementation''' 7 | def __init__(self, gamma=2, weight=None,ignore_index=-100): 8 | super(FocalLoss, self).__init__() 9 | self.gamma = gamma 10 | self.weight = weight 11 | self.ignore_index=ignore_index 12 | 13 | def forward(self, input, target): 14 | """ 15 | input: [N, C] 16 | target: [N, ] 17 | """ 18 | logpt = F.log_softmax(input, dim=1) 19 | pt = torch.exp(logpt) 20 | logpt = (1-pt)**self.gamma * logpt 21 | loss = F.nll_loss(logpt, target, self.weight,ignore_index=self.ignore_index) 22 | return loss 23 | -------------------------------------------------------------------------------- /scripts/run_ner_softmax.sh: -------------------------------------------------------------------------------- 1 | CURRENT_DIR=`pwd` 2 | export BERT_BASE_DIR=$CURRENT_DIR/prev_trained_model/bert-base-chinese 3 | export DATA_DIR=$CURRENT_DIR/datasets 4 | export OUTPUR_DIR=$CURRENT_DIR/outputs 5 | TASK_NAME="cner" 6 | 7 | python run_ner_softmax.py \ 8 | --model_type=bert \ 9 | --model_name_or_path=$BERT_BASE_DIR \ 10 | --task_name=$TASK_NAME \ 11 | --do_train \ 12 | --do_eval \ 13 | --do_lower_case \ 14 | --loss_type=ce \ 15 | --data_dir=$DATA_DIR/${TASK_NAME}/ \ 16 | --train_max_seq_length=128 \ 17 | --eval_max_seq_length=512 \ 18 | --per_gpu_train_batch_size=24 \ 19 | --per_gpu_eval_batch_size=24 \ 20 | --learning_rate=3e-5 \ 21 | --num_train_epochs=3.0 \ 22 | --logging_steps=-1 \ 23 | --save_steps=-1 \ 24 | --output_dir=$OUTPUR_DIR/${TASK_NAME}_output/ \ 25 | --overwrite_output_dir \ 26 | --seed=42 27 | -------------------------------------------------------------------------------- /scripts/run_ner_crf.sh: -------------------------------------------------------------------------------- 1 | CURRENT_DIR=`pwd` 2 | export BERT_BASE_DIR=$CURRENT_DIR/prev_trained_model/bert-base-chinese 3 | export DATA_DIR=$CURRENT_DIR/datasets 4 | export OUTPUR_DIR=$CURRENT_DIR/outputs 5 | TASK_NAME="cner" 6 | # 7 | python run_ner_crf.py \ 8 | --model_type=bert \ 9 | 
--model_name_or_path=$BERT_BASE_DIR \ 10 | --task_name=$TASK_NAME \ 11 | --do_train \ 12 | --do_eval \ 13 | --do_lower_case \ 14 | --data_dir=$DATA_DIR/${TASK_NAME}/ \ 15 | --train_max_seq_length=128 \ 16 | --eval_max_seq_length=512 \ 17 | --per_gpu_train_batch_size=24 \ 18 | --per_gpu_eval_batch_size=24 \ 19 | --learning_rate=3e-5 \ 20 | --crf_learning_rate=1e-3 \ 21 | --num_train_epochs=4.0 \ 22 | --logging_steps=-1 \ 23 | --save_steps=-1 \ 24 | --output_dir=$OUTPUR_DIR/${TASK_NAME}_output/ \ 25 | --overwrite_output_dir \ 26 | --seed=42 27 | -------------------------------------------------------------------------------- /scripts/run_ner_span.sh: -------------------------------------------------------------------------------- 1 | CURRENT_DIR=`pwd` 2 | export BERT_BASE_DIR=$CURRENT_DIR/prev_trained_model/bert-base-chinese 3 | export DATA_DIR=$CURRENT_DIR/datasets 4 | export OUTPUR_DIR=$CURRENT_DIR/outputs 5 | TASK_NAME="cner" 6 | 7 | python run_ner_span.py \ 8 | --model_type=bert \ 9 | --model_name_or_path=$BERT_BASE_DIR \ 10 | --task_name=$TASK_NAME \ 11 | --do_train \ 12 | --do_eval \ 13 | --do_adv \ 14 | --do_lower_case \ 15 | --loss_type=ce \ 16 | --data_dir=$DATA_DIR/${TASK_NAME}/ \ 17 | --train_max_seq_length=128 \ 18 | --eval_max_seq_length=512 \ 19 | --per_gpu_train_batch_size=24 \ 20 | --per_gpu_eval_batch_size=24 \ 21 | --learning_rate=2e-5 \ 22 | --num_train_epochs=4.0 \ 23 | --logging_steps=-1 \ 24 | --save_steps=-1 \ 25 | --output_dir=$OUTPUR_DIR/${TASK_NAME}_output/ \ 26 | --overwrite_output_dir \ 27 | --seed=42 28 | 29 | -------------------------------------------------------------------------------- /losses/label_smoothing.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class LabelSmoothingCrossEntropy(nn.Module): 5 | def __init__(self, eps=0.1, reduction='mean',ignore_index=-100): 6 | super(LabelSmoothingCrossEntropy, self).__init__() 7 | self.eps = eps 8 | self.reduction = reduction 9 | self.ignore_index = ignore_index 10 | 11 | def forward(self, output, target): 12 | c = output.size()[-1] 13 | log_preds = F.log_softmax(output, dim=-1) 14 | if self.reduction=='sum': 15 | loss = -log_preds.sum() 16 | else: 17 | loss = -log_preds.sum(dim=-1) 18 | if self.reduction=='mean': 19 | loss = loss.mean() 20 | return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction, 21 | ignore_index=self.ignore_index) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Weitang Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/layers/linears.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FeedForwardNetwork(nn.Module): 6 | def __init__(self, input_size, hidden_size, output_size, dropout_rate=0): 7 | super(FeedForwardNetwork, self).__init__() 8 | self.dropout_rate = dropout_rate 9 | self.linear1 = nn.Linear(input_size, hidden_size) 10 | self.linear2 = nn.Linear(hidden_size, output_size) 11 | 12 | def forward(self, x): 13 | x_proj = F.dropout(F.relu(self.linear1(x)), p=self.dropout_rate, training=self.training) 14 | x_proj = self.linear2(x_proj) 15 | return x_proj 16 | 17 | 18 | class PoolerStartLogits(nn.Module): 19 | def __init__(self, hidden_size, num_classes): 20 | super(PoolerStartLogits, self).__init__() 21 | self.dense = nn.Linear(hidden_size, num_classes) 22 | 23 | def forward(self, hidden_states, p_mask=None): 24 | x = self.dense(hidden_states) 25 | return x 26 | 27 | class PoolerEndLogits(nn.Module): 28 | def __init__(self, hidden_size, num_classes): 29 | super(PoolerEndLogits, self).__init__() 30 | self.dense_0 = nn.Linear(hidden_size, hidden_size) 31 | self.activation = nn.Tanh() 32 | self.LayerNorm = nn.LayerNorm(hidden_size) 33 | self.dense_1 = nn.Linear(hidden_size, num_classes) 34 | 35 | def forward(self, hidden_states, start_positions=None, p_mask=None): 36 | x = self.dense_0(torch.cat([hidden_states, start_positions], dim=-1)) 37 | x = self.activation(x) 38 | x = self.LayerNorm(x) 39 | x = self.dense_1(x) 40 | return x 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /datasets/cner/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /outputs/cner_output/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /outputs/cner_output/bert/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /callback/trainingmonitor.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | import numpy as np 3 | from pathlib import Path 4 | import matplotlib.pyplot as plt 5 | from ..tools.common import load_json 6 | from ..tools.common import save_json 7 | plt.switch_backend('agg') 8 | 9 | class TrainingMonitor(): 10 | def __init__(self, file_dir, arch, add_test=False): 11 | ''' 12 | :param startAt: 重新开始训练的epoch点 13 | ''' 14 | if isinstance(file_dir, Path): 15 | pass 16 | else: 17 | file_dir = Path(file_dir) 18 | file_dir.mkdir(parents=True, exist_ok=True) 19 | 20 | self.arch = arch 21 | self.file_dir = file_dir 22 | self.H = {} 23 | self.add_test = add_test 24 | self.json_path = file_dir / (arch + "_training_monitor.json") 25 | 26 | def reset(self,start_at): 27 | if start_at > 0: 28 | if self.json_path is not None: 29 | if self.json_path.exists(): 30 | self.H = load_json(self.json_path) 31 | for k in self.H.keys(): 32 | self.H[k] = self.H[k][:start_at] 33 | 34 | def epoch_step(self, logs={}): 35 | for (k, v) in logs.items(): 36 | l = self.H.get(k, []) 37 | # np.float32会报错 38 | if not isinstance(v, np.float): 39 | v = round(float(v), 4) 40 | l.append(v) 41 | self.H[k] = l 42 | 43 | # 写入文件 44 | if self.json_path is not None: 45 | save_json(data = self.H,file_path=self.json_path) 46 | 47 | # 保存train图像 48 | if len(self.H["loss"]) == 1: 49 | self.paths = {key: self.file_dir / (self.arch + f'_{key.upper()}') for key in self.H.keys()} 50 | 51 | if len(self.H["loss"]) > 1: 52 | # 指标变化 53 | # 曲线 54 | # 需要成对出现 55 | keys = [key for key, _ in self.H.items() if '_' not in key] 56 | for key in keys: 57 | N = np.arange(0, len(self.H[key])) 58 | plt.style.use("ggplot") 59 | plt.figure() 60 | plt.plot(N, self.H[key], label=f"train_{key}") 61 | plt.plot(N, self.H[f"valid_{key}"], label=f"valid_{key}") 62 | if self.add_test: 63 | plt.plot(N, self.H[f"test_{key}"], label=f"test_{key}") 64 | plt.legend() 65 | plt.xlabel("Epoch #") 66 | plt.ylabel(key) 67 | plt.title(f"Training {key} [Epoch {len(self.H[key])}]") 68 | plt.savefig(str(self.paths[key])) 69 | plt.close() 70 | -------------------------------------------------------------------------------- /callback/optimizater/planradam.py: -------------------------------------------------------------------------------- 
1 | import torch 2 | import math 3 | from torch.optim.optimizer import Optimizer 4 | class PlainRAdam(Optimizer): 5 | 6 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 7 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 8 | 9 | super(PlainRAdam, self).__init__(params, defaults) 10 | 11 | def __setstate__(self, state): 12 | super(PlainRAdam, self).__setstate__(state) 13 | 14 | def step(self, closure=None): 15 | 16 | loss = None 17 | if closure is not None: 18 | loss = closure() 19 | 20 | for group in self.param_groups: 21 | 22 | for p in group['params']: 23 | if p.grad is None: 24 | continue 25 | grad = p.grad.data.float() 26 | if grad.is_sparse: 27 | raise RuntimeError('RAdam does not support sparse gradients') 28 | 29 | p_data_fp32 = p.data.float() 30 | 31 | state = self.state[p] 32 | 33 | if len(state) == 0: 34 | state['step'] = 0 35 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 36 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 37 | else: 38 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 39 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 40 | 41 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 42 | beta1, beta2 = group['betas'] 43 | 44 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 45 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 46 | 47 | state['step'] += 1 48 | beta2_t = beta2 ** state['step'] 49 | N_sma_max = 2 / (1 - beta2) - 1 50 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 51 | 52 | if group['weight_decay'] != 0: 53 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 54 | 55 | # more conservative since it's an approximated value 56 | if N_sma >= 5: 57 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 58 | denom = exp_avg_sq.sqrt().add_(group['eps']) 59 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 60 | else: 61 | step_size = group['lr'] / (1 - beta1 ** state['step']) 62 | p_data_fp32.add_(-step_size, exp_avg) 63 | 64 | p.data.copy_(p_data_fp32) 65 | 66 | return loss -------------------------------------------------------------------------------- /tools/convert_albert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | """Convert ALBERT checkpoint.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import torch 9 | from models.transformers.modeling_albert import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert 10 | # from model.modeling_albert_bright import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert 11 | import logging 12 | logging.basicConfig(level=logging.INFO) 13 | 14 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 15 | # Initialise PyTorch model 16 | config = AlbertConfig.from_pretrained(bert_config_file) 17 | # print("Building PyTorch model from configuration: {}".format(str(config))) 18 | model = AlbertForPreTraining(config) 19 | # Load weights from tf checkpoint 20 | load_tf_weights_in_albert(model, config, tf_checkpoint_path) 21 | 22 | # Save pytorch-model 23 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 24 | torch.save(model.state_dict(), pytorch_dump_path) 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = 
argparse.ArgumentParser() 29 | ## Required parameters 30 | parser.add_argument("--tf_checkpoint_path", 31 | default = None, 32 | type = str, 33 | required = True, 34 | help = "Path to the TensorFlow checkpoint path.") 35 | parser.add_argument("--bert_config_file", 36 | default = None, 37 | type = str, 38 | required = True, 39 | help = "The config json file corresponding to the pre-trained BERT model. \n" 40 | "This specifies the model architecture.") 41 | parser.add_argument("--pytorch_dump_path", 42 | default = None, 43 | type = str, 44 | required = True, 45 | help = "Path to the output PyTorch model.") 46 | args = parser.parse_args() 47 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,args.bert_config_file, 48 | args.pytorch_dump_path) 49 | 50 | ''' 51 | python convert_albert_tf_checkpoint_to_pytorch.py \ 52 | --tf_checkpoint_path=./prev_trained_model/albert_large_zh \ 53 | --bert_config_file=./prev_trained_model/albert_large_zh/config.json \ 54 | --pytorch_dump_path=./prev_trained_model/albert_large_zh/pytorch_model.bin 55 | 56 | 57 | from model.modeling_albert_bright import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert 58 | python convert_albert_tf_checkpoint_to_pytorch.py \ 59 | --tf_checkpoint_path=./prev_trained_model/albert_base_bright \ 60 | --bert_config_file=./prev_trained_model/albert_base_bright/config.json \ 61 | --pytorch_dump_path=./prev_trained_model/albert_base_bright/pytorch_model.bin 62 | ''' -------------------------------------------------------------------------------- /tools/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from sklearn.metrics import confusion_matrix 4 | plt.switch_backend('agg') 5 | 6 | def plot_confusion_matrix(y_true, y_pred, classes, 7 | save_path,normalize=False,title=None, 8 | cmap=plt.cm.Blues): 9 | """ 10 | This function prints and plots the confusion matrix. 11 | Normalization can be applied by setting `normalize=True`. 12 | """ 13 | if not title: 14 | if normalize: 15 | title = 'Normalized confusion matrix' 16 | else: 17 | title = 'Confusion matrix, without normalization' 18 | # Compute confusion matrix 19 | cm = confusion_matrix(y_true=y_true, y_pred=y_pred) 20 | # Only use the labels that appear in the data 21 | if normalize: 22 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 23 | print("Normalized confusion matrix") 24 | else: 25 | print('Confusion matrix, without normalization') 26 | # --- plot--- # 27 | plt.rcParams['savefig.dpi'] = 200 28 | plt.rcParams['figure.dpi'] = 200 29 | plt.rcParams['figure.figsize'] = [20, 20] # plot 30 | plt.rcParams.update({'font.size': 10}) 31 | fig, ax = plt.subplots() 32 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 33 | # --- bar --- # 34 | from mpl_toolkits.axes_grid1 import make_axes_locatable 35 | divider = make_axes_locatable(ax) 36 | cax = divider.append_axes("right", size="5%", pad=0.05) 37 | plt.colorbar(im, cax=cax) 38 | # --- bar --- # 39 | # ax.figure.colorbar(im, ax=ax) 40 | # We want to show all ticks... 41 | ax.set(xticks=np.arange(cm.shape[1]), 42 | yticks=np.arange(cm.shape[0]), 43 | # ... and label them with the respective list entries 44 | xticklabels=classes, yticklabels=classes, 45 | title=title, 46 | ylabel='True label', 47 | xlabel='Predicted label') 48 | 49 | # Rotate the tick labels and set their alignment. 
50 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 51 | rotation_mode="anchor") 52 | # Loop over data dimensions and create text annotations. 53 | fmt = '.2f' if normalize else 'd' 54 | thresh = cm.max() / 2. 55 | for i in range(cm.shape[0]): 56 | for j in range(cm.shape[1]): 57 | ax.text(j, i, format(cm[i, j], fmt), 58 | ha="center", va="center", 59 | color="white" if cm[i, j] > thresh else "black") 60 | fig.tight_layout() 61 | plt.savefig(save_path) 62 | 63 | if __name__ == "__main__": 64 | y_true = ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O'] 65 | y_pred = ['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O','B-PER', 'I-PER', 'O'] 66 | classes = ['O','B-MISC', 'I-MISC','B-PER', 'I-PER'] 67 | save_path = './ner_confusion_matrix.png' 68 | plot_confusion_matrix(y_true,y_pred,classes,save_path) -------------------------------------------------------------------------------- /callback/progressbar.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | 4 | 5 | class ProgressBar(object): 6 | ''' 7 | custom progress bar 8 | Example: 9 | >>> pbar = ProgressBar(n_total=30,desc='Training') 10 | >>> step = 2 11 | >>> pbar(step=step,info={'loss':20}) 12 | ''' 13 | 14 | def __init__(self, n_total, width=30, desc='Training',num_epochs = None): 15 | 16 | self.width = width 17 | self.n_total = n_total 18 | self.desc = desc 19 | self.start_time = time.time() 20 | self.num_epochs = num_epochs 21 | 22 | def reset(self): 23 | """Method to reset internal variables.""" 24 | self.start_time = time.time() 25 | 26 | def _time_info(self, now, current): 27 | time_per_unit = (now - self.start_time) / current 28 | if current < self.n_total: 29 | eta = time_per_unit * (self.n_total - current) 30 | if eta > 3600: 31 | eta_format = ('%d:%02d:%02d' % 32 | (eta // 3600, (eta % 3600) // 60, eta % 60)) 33 | elif eta > 60: 34 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 35 | else: 36 | eta_format = '%ds' % eta 37 | time_info = f' - ETA: {eta_format}' 38 | else: 39 | if time_per_unit >= 1: 40 | time_info = f' {time_per_unit:.1f}s/step' 41 | elif time_per_unit >= 1e-3: 42 | time_info = f' {time_per_unit * 1e3:.1f}ms/step' 43 | else: 44 | time_info = f' {time_per_unit * 1e6:.1f}us/step' 45 | return time_info 46 | 47 | def _bar(self, now, current): 48 | recv_per = current / self.n_total 49 | bar = f'[{self.desc}] {current}/{self.n_total} [' 50 | if recv_per >= 1: recv_per = 1 51 | prog_width = int(self.width * recv_per) 52 | if prog_width > 0: 53 | bar += '=' * (prog_width - 1) 54 | if current < self.n_total: 55 | bar += ">" 56 | else: 57 | bar += '=' 58 | bar += '.' 
* (self.width - prog_width) 59 | bar += ']' 60 | return bar 61 | 62 | def epoch_start(self,current_epoch): 63 | sys.stdout.write("\n") 64 | if (current_epoch is not None) and (self.num_epochs is not None): 65 | sys.stdout.write(f"Epoch: {current_epoch}/{self.num_epochs}") 66 | sys.stdout.write("\n") 67 | 68 | def __call__(self, step, info={}): 69 | now = time.time() 70 | current = step + 1 71 | bar = self._bar(now, current) 72 | show_bar = f"\r{bar}" + self._time_info(now, current) 73 | if len(info) != 0: 74 | show_bar = f'{show_bar} ' + " [" + "-".join( 75 | [f' {key}={value:.4f} ' for key, value in info.items()]) + "]" 76 | if current >= self.n_total: 77 | show_bar += '\n' 78 | sys.stdout.write(show_bar) 79 | sys.stdout.flush() 80 | 81 | -------------------------------------------------------------------------------- /tools/download_clue_data.py: -------------------------------------------------------------------------------- 1 | """ Script for downloading all CLUE data. 2 | For licence information, see the original dataset information links 3 | available from: https://www.cluebenchmarks.com/ 4 | Example usage: 5 | python download_clue_data.py --data_dir data --tasks all 6 | """ 7 | 8 | import os 9 | import sys 10 | import argparse 11 | import urllib.request 12 | import zipfile 13 | 14 | TASKS = ["afqmc", "cmnli", "copa", "csl", "iflytek", "tnews", "wsc","cmrc","chid","drcd",'cluener'] 15 | 16 | TASK2PATH = { 17 | "afqmc": "https://storage.googleapis.com/cluebenchmark/tasks/afqmc_public.zip", 18 | "cmnli": "https://storage.googleapis.com/cluebenchmark/tasks/cmnli_public.zip", 19 | "copa": "https://storage.googleapis.com/cluebenchmark/tasks/copa_public.zip", 20 | "csl": "https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip", 21 | "iflytek": "https://storage.googleapis.com/cluebenchmark/tasks/iflytek_public.zip", 22 | "tnews": "https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip", 23 | "wsc": "https://storage.googleapis.com/cluebenchmark/tasks/wsc_public.zip", 24 | 'cmrc': "https://storage.googleapis.com/cluebenchmark/tasks/cmrc2018_public.zip", 25 | "chid": "https://storage.googleapis.com/cluebenchmark/tasks/chid_public.zip", 26 | "drcd": "https://storage.googleapis.com/cluebenchmark/tasks/drcd_public.zip", 27 | 'cluener':'https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip' 28 | } 29 | 30 | def download_and_extract(task, data_dir): 31 | print("Downloading and extracting %s..." % task) 32 | if not os.path.isdir(data_dir): 33 | os.mkdir(data_dir) 34 | data_file = os.path.join(data_dir, "%s_public.zip" % task) 35 | save_dir = os.path.join(data_dir,task) 36 | if not os.path.isdir(save_dir): 37 | os.mkdir(save_dir) 38 | urllib.request.urlretrieve(TASK2PATH[task], data_file) 39 | with zipfile.ZipFile(data_file) as zip_ref: 40 | zip_ref.extractall(save_dir) 41 | os.remove(data_file) 42 | print(f"\tCompleted! Downloaded {task} data to directory {save_dir}") 43 | 44 | def get_tasks(task_names): 45 | task_names = task_names.split(",") 46 | if "all" in task_names: 47 | tasks = TASKS 48 | else: 49 | tasks = [] 50 | for task_name in task_names: 51 | assert task_name in TASKS, "Task %s not found!" 
% task_name 52 | tasks.append(task_name) 53 | return tasks 54 | 55 | def main(arguments): 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument( 58 | "-d", "--data_dir", help="directory to save data to", type=str, default="../CLUEdatasets" 59 | ) 60 | parser.add_argument( 61 | "-t", 62 | "--tasks", 63 | help="tasks to download data for as a comma separated string", 64 | type=str, 65 | default="all", 66 | ) 67 | args = parser.parse_args(arguments) 68 | 69 | if not os.path.exists(args.data_dir): 70 | os.mkdir(args.data_dir) 71 | tasks = get_tasks(args.tasks) 72 | 73 | for task in tasks: 74 | download_and_extract(task, args.data_dir) 75 | 76 | if __name__ == "__main__": 77 | sys.exit(main(sys.argv[1:])) 78 | 79 | ''' 80 | python tools/download_clue_data.py --data_dir=./CLUEdatasets --tasks=cluener 81 | ''' -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Chinese NER using Bert 2 | 3 | BERT for Chinese NER. 4 | 5 | **update**: Some further implementations worth referring to, including Biaffine and GlobalPointer, are available here: [examples](https://github.com/lonePatient/TorchBlocks/tree/master/examples) 6 | 7 | ### dataset list 8 | 9 | 1. cner: datasets/cner 10 | 2. CLUENER: https://github.com/CLUEbenchmark/CLUENER 11 | 12 | ### model list 13 | 14 | 1. BERT+Softmax 15 | 2. BERT+CRF 16 | 3. BERT+Span 17 | 18 | ### requirement 19 | 20 | 1. 1.1.0 <= PyTorch < 1.5.0 21 | 2. cuda=9.0 22 | 3. python3.6+ 23 | 24 | ### input format 25 | 26 | The input format (the BIOS tag scheme is preferred) has one character and its label per line. Sentences are separated by a blank line. 27 | 28 | ```text 29 | 美 B-LOC 30 | 国 I-LOC 31 | 的 O 32 | 华 B-PER 33 | 莱 I-PER 34 | 士 I-PER 35 | 36 | 我 O 37 | 跟 O 38 | 他 O 39 | ``` 40 | 41 | ### run the code 42 | 43 | 1. Modify the configuration information in `run_ner_xxx.py` or `run_ner_xxx.sh`. 44 | 2. `sh scripts/run_ner_xxx.sh` 45 | 46 | **note**: file structure of the pre-trained model 47 | 48 | ```text 49 | ├── prev_trained_model 50 | | └── bert_base 51 | | | └── pytorch_model.bin 52 | | | └── config.json 53 | | | └── vocab.txt 54 | | | └── ......
55 | ``` 56 | 57 | ### CLUENER result 58 | 59 | The overall performance of BERT on **dev**: 60 | 61 | | | Accuracy (entity) | Recall (entity) | F1 score (entity) | 62 | | ------------ | ------------------ | ------------------ | ------------------ | 63 | | BERT+Softmax | 0.7897 | 0.8031 | 0.7963 | 64 | | BERT+CRF | 0.7977 | 0.8177 | 0.8076 | 65 | | BERT+Span | 0.8132 | 0.8092 | 0.8112 | 66 | | BERT+Span+adv | 0.8267 | 0.8073 | **0.8169** | 67 | | BERT-small(6 layers)+Span+kd | 0.8241 | 0.7839 | 0.8051 | 68 | | BERT+Span+focal_loss | 0.8121 | 0.8008 | 0.8064 | 69 | | BERT+Span+label_smoothing | 0.8235 | 0.7946 | 0.8088 | 70 | 71 | ### ALBERT for CLUENER 72 | 73 | The overall performance of ALBERT on **dev**: 74 | 75 | | model | version | Accuracy(entity) | Recall(entity) | F1(entity) | Train time/epoch | 76 | | ------ | ------------- | ---------------- | -------------- | ---------- | ---------------- | 77 | | albert | base_google | 0.8014 | 0.6908 | 0.7420 | 0.75x | 78 | | albert | large_google | 0.8024 | 0.7520 | 0.7763 | 2.1x | 79 | | albert | xlarge_google | 0.8286 | 0.7773 | 0.8021 | 6.7x | 80 | | bert | google | 0.8118 | 0.8031 | **0.8074** | ----- | 81 | | albert | base_bright | 0.8068 | 0.7529 | 0.7789 | 0.75x | 82 | | albert | large_bright | 0.8152 | 0.7480 | 0.7802 | 2.2x | 83 | | albert | xlarge_bright | 0.8222 | 0.7692 | 0.7948 | 7.3x | 84 | 85 | ### Cner result 86 | 87 | The overall performance of BERT on **dev(test)**: 88 | 89 | | | Accuracy (entity) | Recall (entity) | F1 score (entity) | 90 | | ------------ | ------------------ | ------------------ | ------------------ | 91 | | BERT+Softmax | 0.9586(0.9566) | 0.9644(0.9613) | 0.9615(0.9590) | 92 | | BERT+CRF | 0.9562(0.9539) | 0.9671(**0.9644**) | 0.9616(0.9591) | 93 | | BERT+Span | 0.9604(**0.9620**) | 0.9617(0.9632) | 0.9611(**0.9626**) | 94 | | BERT+Span+focal_loss | 0.9516(0.9569) | 0.9644(0.9681) | 0.9580(0.9625) | 95 | | BERT+Span+label_smoothing | 0.9566(0.9568) | 0.9624(0.9656) | 0.9595(0.9612) | 96 | -------------------------------------------------------------------------------- /callback/optimizater/novograd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class NovoGrad(Optimizer): 7 | """Implements NovoGrad algorithm. 
8 | Arguments: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float, optional): learning rate (default: 1e-2) 12 | betas (Tuple[float, float], optional): coefficients used for computing 13 | running averages of gradient and its square (default: (0.95, 0.98)) 14 | eps (float, optional): term added to the denominator to improve 15 | numerical stability (default: 1e-8) 16 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 17 | Example: 18 | >>> model = ResNet() 19 | >>> optimizer = NovoGrad(model.parameters(), lr=1e-2, weight_decay=1e-5) 20 | """ 21 | 22 | def __init__(self, params, lr=0.01, betas=(0.95, 0.98), eps=1e-8, 23 | weight_decay=0, grad_averaging=False): 24 | if lr < 0.0: 25 | raise ValueError("Invalid learning rate: {}".format(lr)) 26 | if not 0.0 <= betas[0] < 1.0: 27 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 28 | if not 0.0 <= betas[1] < 1.0: 29 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 30 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging) 31 | super().__init__(params, defaults) 32 | 33 | def step(self, closure=None): 34 | loss = None 35 | if closure is not None: 36 | loss = closure() 37 | for group in self.param_groups: 38 | for p in group['params']: 39 | if p.grad is None: 40 | continue 41 | grad = p.grad.data 42 | if grad.is_sparse: 43 | raise RuntimeError('NovoGrad does not support sparse gradients') 44 | state = self.state[p] 45 | g_2 = torch.sum(grad ** 2) 46 | if len(state) == 0: 47 | state['step'] = 0 48 | state['moments'] = grad.div(g_2.sqrt() + group['eps']) + \ 49 | group['weight_decay'] * p.data 50 | state['grads_ema'] = g_2 51 | moments = state['moments'] 52 | grads_ema = state['grads_ema'] 53 | beta1, beta2 = group['betas'] 54 | state['step'] += 1 55 | grads_ema.mul_(beta2).add_(1 - beta2, g_2) 56 | 57 | denom = grads_ema.sqrt().add_(group['eps']) 58 | grad.div_(denom) 59 | # weight decay 60 | if group['weight_decay'] != 0: 61 | decayed_weights = torch.mul(p.data, group['weight_decay']) 62 | grad.add_(decayed_weights) 63 | 64 | # Momentum --> SAG 65 | if group['grad_averaging']: 66 | grad.mul_(1.0 - beta1) 67 | 68 | moments.mul_(beta1).add_(grad) # velocity 69 | 70 | bias_correction1 = 1 - beta1 ** state['step'] 71 | bias_correction2 = 1 - beta2 ** state['step'] 72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 73 | p.data.add_(-step_size, moments) 74 | 75 | return loss 76 | -------------------------------------------------------------------------------- /callback/optimizater/sgdw.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | class SGDW(Optimizer): 5 | r"""Implements stochastic gradient descent (optionally with momentum) with 6 | weight decay from the paper `Fixing Weight Decay Regularization in Adam`_. 7 | 8 | Nesterov momentum is based on the formula from 9 | `On the importance of initialization and momentum in deep learning`__. 
10 | 11 | Args: 12 | params (iterable): iterable of parameters to optimize or dicts defining 13 | parameter groups 14 | lr (float): learning rate 15 | momentum (float, optional): momentum factor (default: 0) 16 | weight_decay (float, optional): weight decay factor (default: 0) 17 | dampening (float, optional): dampening for momentum (default: 0) 18 | nesterov (bool, optional): enables Nesterov momentum (default: False) 19 | 20 | .. _Fixing Weight Decay Regularization in Adam: 21 | https://arxiv.org/abs/1711.05101 22 | 23 | Example: 24 | >>> model = LSTM() 25 | >>> optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9,weight_decay=1e-5) 26 | """ 27 | def __init__(self, params, lr=0.1, momentum=0, dampening=0, 28 | weight_decay=0, nesterov=False): 29 | if lr < 0.0: 30 | raise ValueError(f"Invalid learning rate: {lr}") 31 | if momentum < 0.0: 32 | raise ValueError(f"Invalid momentum value: {momentum}") 33 | if weight_decay < 0.0: 34 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 35 | 36 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 37 | weight_decay=weight_decay, nesterov=nesterov) 38 | if nesterov and (momentum <= 0 or dampening != 0): 39 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 40 | super(SGDW, self).__init__(params, defaults) 41 | 42 | def __setstate__(self, state): 43 | super(SGDW, self).__setstate__(state) 44 | for group in self.param_groups: 45 | group.setdefault('nesterov', False) 46 | 47 | def step(self, closure=None): 48 | """Performs a single optimization step. 49 | 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | weight_decay = group['weight_decay'] 60 | momentum = group['momentum'] 61 | dampening = group['dampening'] 62 | nesterov = group['nesterov'] 63 | for p in group['params']: 64 | if p.grad is None: 65 | continue 66 | d_p = p.grad.data 67 | if momentum != 0: 68 | param_state = self.state[p] 69 | if 'momentum_buffer' not in param_state: 70 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 71 | buf.mul_(momentum).add_(d_p) 72 | else: 73 | buf = param_state['momentum_buffer'] 74 | buf.mul_(momentum).add_(1 - dampening, d_p) 75 | if nesterov: 76 | d_p = d_p.add(momentum, buf) 77 | else: 78 | d_p = buf 79 | if weight_decay != 0: 80 | p.data.add_(-weight_decay, p.data) 81 | p.data.add_(-group['lr'], d_p) 82 | return loss -------------------------------------------------------------------------------- /callback/optimizater/lars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | class Lars(Optimizer): 5 | r"""Implements the LARS optimizer from https://arxiv.org/pdf/1708.03888.pdf 6 | 7 | Args: 8 | params (iterable): iterable of parameters to optimize or dicts defining 9 | parameter groups 10 | lr (float): learning rate 11 | momentum (float, optional): momentum factor (default: 0) 12 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 13 | dampening (float, optional): dampening for momentum (default: 0) 14 | nesterov (bool, optional): enables Nesterov momentum (default: False) 15 | scale_clip (tuple, optional): the lower and upper bounds for the weight norm in local LR of LARS 16 | Example: 17 | >>> model = ResNet() 18 | >>> optimizer = Lars(model.parameters(), lr=1e-2, 
weight_decay=1e-5) 19 | """ 20 | 21 | def __init__(self, params, lr, momentum=0, dampening=0, 22 | weight_decay=0, nesterov=False, scale_clip=None): 23 | if lr < 0.0: 24 | raise ValueError("Invalid learning rate: {}".format(lr)) 25 | if momentum < 0.0: 26 | raise ValueError("Invalid momentum value: {}".format(momentum)) 27 | if weight_decay < 0.0: 28 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 29 | 30 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening, 31 | weight_decay=weight_decay, nesterov=nesterov) 32 | if nesterov and (momentum <= 0 or dampening != 0): 33 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 34 | super(Lars, self).__init__(params, defaults) 35 | # LARS arguments 36 | self.scale_clip = scale_clip 37 | if self.scale_clip is None: 38 | self.scale_clip = (0, 10) 39 | 40 | def __setstate__(self, state): 41 | super(Lars, self).__setstate__(state) 42 | for group in self.param_groups: 43 | group.setdefault('nesterov', False) 44 | 45 | def step(self, closure=None): 46 | """Performs a single optimization step. 47 | 48 | Arguments: 49 | closure (callable, optional): A closure that reevaluates the model 50 | and returns the loss. 51 | """ 52 | loss = None 53 | if closure is not None: 54 | loss = closure() 55 | 56 | for group in self.param_groups: 57 | weight_decay = group['weight_decay'] 58 | momentum = group['momentum'] 59 | dampening = group['dampening'] 60 | nesterov = group['nesterov'] 61 | 62 | for p in group['params']: 63 | if p.grad is None: 64 | continue 65 | d_p = p.grad.data 66 | if weight_decay != 0: 67 | d_p.add_(weight_decay, p.data) 68 | if momentum != 0: 69 | param_state = self.state[p] 70 | if 'momentum_buffer' not in param_state: 71 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach() 72 | else: 73 | buf = param_state['momentum_buffer'] 74 | buf.mul_(momentum).add_(1 - dampening, d_p) 75 | if nesterov: 76 | d_p = d_p.add(momentum, buf) 77 | else: 78 | d_p = buf 79 | 80 | # LARS 81 | p_norm = p.data.pow(2).sum().sqrt() 82 | update_norm = d_p.pow(2).sum().sqrt() 83 | # Compute the local LR 84 | if p_norm == 0 or update_norm == 0: 85 | local_lr = 1 86 | else: 87 | local_lr = p_norm / update_norm 88 | 89 | p.data.add_(-group['lr'] * local_lr, d_p) 90 | 91 | return loss -------------------------------------------------------------------------------- /callback/modelcheckpoint.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import torch 4 | from ..tools.common import logger 5 | 6 | class ModelCheckpoint(object): 7 | ''' 8 | 模型保存,两种模式: 9 | 1. 直接保存最好模型 10 | 2. 
按照epoch频率保存模型 11 | ''' 12 | def __init__(self, checkpoint_dir, 13 | monitor, 14 | arch,mode='min', 15 | epoch_freq=1, 16 | best = None, 17 | save_best_only = True): 18 | if isinstance(checkpoint_dir,Path): 19 | checkpoint_dir = checkpoint_dir 20 | else: 21 | checkpoint_dir = Path(checkpoint_dir) 22 | assert checkpoint_dir.is_dir() 23 | checkpoint_dir.mkdir(exist_ok=True) 24 | self.base_path = checkpoint_dir 25 | self.arch = arch 26 | self.monitor = monitor 27 | self.epoch_freq = epoch_freq 28 | self.save_best_only = save_best_only 29 | 30 | # 计算模式 31 | if mode == 'min': 32 | self.monitor_op = np.less 33 | self.best = np.Inf 34 | 35 | elif mode == 'max': 36 | self.monitor_op = np.greater 37 | self.best = -np.Inf 38 | # 这里主要重新加载模型时候 39 | #对best重新赋值 40 | if best: 41 | self.best = best 42 | 43 | if save_best_only: 44 | self.model_name = f"BEST_{arch}_MODEL.pth" 45 | 46 | def epoch_step(self, state,current): 47 | ''' 48 | 正常模型 49 | :param state: 需要保存的信息 50 | :param current: 当前判断指标 51 | :return: 52 | ''' 53 | # 是否保存最好模型 54 | if self.save_best_only: 55 | if self.monitor_op(current, self.best): 56 | logger.info(f"\nEpoch {state['epoch']}: {self.monitor} improved from {self.best:.5f} to {current:.5f}") 57 | self.best = current 58 | state['best'] = self.best 59 | best_path = self.base_path/ self.model_name 60 | torch.save(state, str(best_path)) 61 | # 每隔几个epoch保存下模型 62 | else: 63 | filename = self.base_path / f"EPOCH_{state['epoch']}_{state[self.monitor]}_{self.arch}_MODEL.pth" 64 | if state['epoch'] % self.epoch_freq == 0: 65 | logger.info(f"\nEpoch {state['epoch']}: save model to disk.") 66 | torch.save(state, str(filename)) 67 | 68 | def bert_epoch_step(self, state,current): 69 | ''' 70 | 适合bert类型模型,适合pytorch_transformer模块 71 | :param state: 72 | :param current: 73 | :return: 74 | ''' 75 | model_to_save = state['model'] 76 | if self.save_best_only: 77 | if self.monitor_op(current, self.best): 78 | logger.info(f"\nEpoch {state['epoch']}: {self.monitor} improved from {self.best:.5f} to {current:.5f}") 79 | self.best = current 80 | state['best'] = self.best 81 | model_to_save.save_pretrained(str(self.base_path)) 82 | output_config_file = self.base_path / 'configs.json' 83 | with open(str(output_config_file), 'w') as f: 84 | f.write(model_to_save.config.to_json_string()) 85 | state.pop("model") 86 | torch.save(state,self.base_path / 'checkpoint_info.bin') 87 | else: 88 | if state['epoch'] % self.epoch_freq == 0: 89 | save_path = self.base_path / f"checkpoint-epoch-{state['epoch']}" 90 | save_path.mkdir(exist_ok=True) 91 | logger.info(f"\nEpoch {state['epoch']}: save model to disk.") 92 | model_to_save.save_pretrained(save_path) 93 | output_config_file = save_path / 'configs.json' 94 | with open(str(output_config_file), 'w') as f: 95 | f.write(model_to_save.config.to_json_string()) 96 | state.pop("model") 97 | torch.save(state, save_path / 'checkpoint_info.bin') 98 | -------------------------------------------------------------------------------- /callback/optimizater/radam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.optimizer import Optimizer 4 | class RAdam(Optimizer): 5 | """Implements the RAdam optimizer from https://arxiv.org/pdf/1908.03265.pdf 6 | Args: 7 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups 8 | lr (float, optional): learning rate 9 | betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square 
(default: (0.9, 0.999)) 10 | eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) 11 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 12 | Example: 13 | >>> model = ResNet() 14 | >>> optimizer = RAdam(model.parameters(), lr=0.001) 15 | """ 16 | 17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 18 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 19 | self.buffer = [[None, None, None] for ind in range(10)] 20 | super(RAdam, self).__init__(params, defaults) 21 | 22 | def __setstate__(self, state): 23 | super(RAdam, self).__setstate__(state) 24 | 25 | def step(self, closure=None): 26 | 27 | loss = None 28 | if closure is not None: 29 | loss = closure() 30 | 31 | for group in self.param_groups: 32 | 33 | for p in group['params']: 34 | if p.grad is None: 35 | continue 36 | grad = p.grad.data.float() 37 | if grad.is_sparse: 38 | raise RuntimeError('RAdam does not support sparse gradients') 39 | 40 | p_data_fp32 = p.data.float() 41 | 42 | state = self.state[p] 43 | 44 | if len(state) == 0: 45 | state['step'] = 0 46 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 47 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 48 | else: 49 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 50 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 51 | 52 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 53 | beta1, beta2 = group['betas'] 54 | 55 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 56 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 57 | 58 | state['step'] += 1 59 | buffered = self.buffer[int(state['step'] % 10)] 60 | if state['step'] == buffered[0]: 61 | N_sma, step_size = buffered[1], buffered[2] 62 | else: 63 | buffered[0] = state['step'] 64 | beta2_t = beta2 ** state['step'] 65 | N_sma_max = 2 / (1 - beta2) - 1 66 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 67 | buffered[1] = N_sma 68 | 69 | # more conservative since it's an approximated value 70 | if N_sma >= 5: 71 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 72 | else: 73 | step_size = 1.0 / (1 - beta1 ** state['step']) 74 | buffered[2] = step_size 75 | 76 | if group['weight_decay'] != 0: 77 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 78 | 79 | # more conservative since it's an approximated value 80 | if N_sma >= 5: 81 | denom = exp_avg_sq.sqrt().add_(group['eps']) 82 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 83 | else: 84 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 85 | 86 | p.data.copy_(p_data_fp32) 87 | 88 | return loss -------------------------------------------------------------------------------- /callback/optimizater/nadam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.optimizer import Optimizer 4 | 5 | class Nadam(Optimizer): 6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum). 7 | 8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__. 
9 | 10 | Arguments: 11 | params (iterable): iterable of parameters to optimize or dicts defining 12 | parameter groups 13 | lr (float, optional): learning rate (default: 2e-3) 14 | betas (Tuple[float, float], optional): coefficients used for computing 15 | running averages of gradient and its square 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3) 20 | 21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf 22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf 23 | 24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408 25 | NOTE: Has potential issues but does work well on some problems. 26 | Example: 27 | >>> model = LSTM() 28 | >>> optimizer = Nadam(model.parameters()) 29 | """ 30 | 31 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, 32 | weight_decay=0, schedule_decay=4e-3): 33 | defaults = dict(lr=lr, betas=betas, eps=eps, 34 | weight_decay=weight_decay, schedule_decay=schedule_decay) 35 | super(Nadam, self).__init__(params, defaults) 36 | 37 | def step(self, closure=None): 38 | """Performs a single optimization step. 39 | 40 | Arguments: 41 | closure (callable, optional): A closure that reevaluates the model 42 | and returns the loss. 43 | """ 44 | loss = None 45 | if closure is not None: 46 | loss = closure() 47 | 48 | for group in self.param_groups: 49 | for p in group['params']: 50 | if p.grad is None: 51 | continue 52 | grad = p.grad.data 53 | state = self.state[p] 54 | 55 | # State initialization 56 | if len(state) == 0: 57 | state['step'] = 0 58 | state['m_schedule'] = 1. 59 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 60 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 61 | 62 | # Warming momentum schedule 63 | m_schedule = state['m_schedule'] 64 | schedule_decay = group['schedule_decay'] 65 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 66 | beta1, beta2 = group['betas'] 67 | eps = group['eps'] 68 | state['step'] += 1 69 | t = state['step'] 70 | 71 | if group['weight_decay'] != 0: 72 | grad = grad.add(group['weight_decay'], p.data) 73 | 74 | momentum_cache_t = beta1 * \ 75 | (1. - 0.5 * (0.96 ** (t * schedule_decay))) 76 | momentum_cache_t_1 = beta1 * \ 77 | (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay))) 78 | m_schedule_new = m_schedule * momentum_cache_t 79 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1 80 | state['m_schedule'] = m_schedule_new 81 | 82 | # Decay the first and second moment running average coefficient 83 | exp_avg.mul_(beta1).add_(1. - beta1, grad) 84 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad) 85 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t) 86 | denom = exp_avg_sq_prime.sqrt_().add_(eps) 87 | 88 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom) 89 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. 
- m_schedule_next), exp_avg, denom) 90 | 91 | return loss -------------------------------------------------------------------------------- /metrics/ner_metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import Counter 3 | from processors.utils_ner import get_entities 4 | 5 | class SeqEntityScore(object): 6 | def __init__(self, id2label,markup='bios'): 7 | self.id2label = id2label 8 | self.markup = markup 9 | self.reset() 10 | 11 | def reset(self): 12 | self.origins = [] 13 | self.founds = [] 14 | self.rights = [] 15 | 16 | def compute(self, origin, found, right): 17 | recall = 0 if origin == 0 else (right / origin) 18 | precision = 0 if found == 0 else (right / found) 19 | f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall) 20 | return recall, precision, f1 21 | 22 | def result(self): 23 | class_info = {} 24 | origin_counter = Counter([x[0] for x in self.origins]) 25 | found_counter = Counter([x[0] for x in self.founds]) 26 | right_counter = Counter([x[0] for x in self.rights]) 27 | for type_, count in origin_counter.items(): 28 | origin = count 29 | found = found_counter.get(type_, 0) 30 | right = right_counter.get(type_, 0) 31 | recall, precision, f1 = self.compute(origin, found, right) 32 | class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)} 33 | origin = len(self.origins) 34 | found = len(self.founds) 35 | right = len(self.rights) 36 | recall, precision, f1 = self.compute(origin, found, right) 37 | return {'acc': precision, 'recall': recall, 'f1': f1}, class_info 38 | 39 | def update(self, label_paths, pred_paths): 40 | ''' 41 | labels_paths: [[],[],[],....] 42 | pred_paths: [[],[],[],.....] 43 | 44 | :param label_paths: 45 | :param pred_paths: 46 | :return: 47 | Example: 48 | >>> labels_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] 49 | >>> pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']] 50 | ''' 51 | for label_path, pre_path in zip(label_paths, pred_paths): 52 | label_entities = get_entities(label_path, self.id2label,self.markup) 53 | pre_entities = get_entities(pre_path, self.id2label,self.markup) 54 | self.origins.extend(label_entities) 55 | self.founds.extend(pre_entities) 56 | self.rights.extend([pre_entity for pre_entity in pre_entities if pre_entity in label_entities]) 57 | 58 | class SpanEntityScore(object): 59 | def __init__(self, id2label): 60 | self.id2label = id2label 61 | self.reset() 62 | 63 | def reset(self): 64 | self.origins = [] 65 | self.founds = [] 66 | self.rights = [] 67 | 68 | def compute(self, origin, found, right): 69 | recall = 0 if origin == 0 else (right / origin) 70 | precision = 0 if found == 0 else (right / found) 71 | f1 = 0. 
if recall + precision == 0 else (2 * precision * recall) / (precision + recall) 72 | return recall, precision, f1 73 | 74 | def result(self): 75 | class_info = {} 76 | origin_counter = Counter([self.id2label[x[0]] for x in self.origins]) 77 | found_counter = Counter([self.id2label[x[0]] for x in self.founds]) 78 | right_counter = Counter([self.id2label[x[0]] for x in self.rights]) 79 | for type_, count in origin_counter.items(): 80 | origin = count 81 | found = found_counter.get(type_, 0) 82 | right = right_counter.get(type_, 0) 83 | recall, precision, f1 = self.compute(origin, found, right) 84 | class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)} 85 | origin = len(self.origins) 86 | found = len(self.founds) 87 | right = len(self.rights) 88 | recall, precision, f1 = self.compute(origin, found, right) 89 | return {'acc': precision, 'recall': recall, 'f1': f1}, class_info 90 | 91 | def update(self, true_subject, pred_subject): 92 | self.origins.extend(true_subject) 93 | self.founds.extend(pred_subject) 94 | self.rights.extend([pre_entity for pre_entity in pred_subject if pre_entity in true_subject]) 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /callback/adversarial.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class FGM(): 4 | ''' 5 | Example 6 | # 初始化 7 | fgm = FGM(model,epsilon=1,emb_name='word_embeddings.') 8 | for batch_input, batch_label in data: 9 | # 正常训练 10 | loss = model(batch_input, batch_label) 11 | loss.backward() # 反向传播,得到正常的grad 12 | # 对抗训练 13 | fgm.attack() # 在embedding上添加对抗扰动 14 | loss_adv = model(batch_input, batch_label) 15 | loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度 16 | fgm.restore() # 恢复embedding参数 17 | # 梯度下降,更新参数 18 | optimizer.step() 19 | model.zero_grad() 20 | ''' 21 | def __init__(self, model,emb_name,epsilon=1.0): 22 | # emb_name这个参数要换成你模型中embedding的参数名 23 | self.model = model 24 | self.epsilon = epsilon 25 | self.emb_name = emb_name 26 | self.backup = {} 27 | 28 | def attack(self): 29 | for name, param in self.model.named_parameters(): 30 | if param.requires_grad and self.emb_name in name: 31 | self.backup[name] = param.data.clone() 32 | norm = torch.norm(param.grad) 33 | if norm!=0 and not torch.isnan(norm): 34 | r_at = self.epsilon * param.grad / norm 35 | param.data.add_(r_at) 36 | 37 | def restore(self): 38 | for name, param in self.model.named_parameters(): 39 | if param.requires_grad and self.emb_name in name: 40 | assert name in self.backup 41 | param.data = self.backup[name] 42 | self.backup = {} 43 | 44 | class PGD(): 45 | ''' 46 | Example 47 | pgd = PGD(model,emb_name='word_embeddings.',epsilon=1.0,alpha=0.3) 48 | K = 3 49 | for batch_input, batch_label in data: 50 | # 正常训练 51 | loss = model(batch_input, batch_label) 52 | loss.backward() # 反向传播,得到正常的grad 53 | pgd.backup_grad() 54 | # 对抗训练 55 | for t in range(K): 56 | pgd.attack(is_first_attack=(t==0)) # 在embedding上添加对抗扰动, first attack时备份param.data 57 | if t != K-1: 58 | model.zero_grad() 59 | else: 60 | pgd.restore_grad() 61 | loss_adv = model(batch_input, batch_label) 62 | loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度 63 | pgd.restore() # 恢复embedding参数 64 | # 梯度下降,更新参数 65 | optimizer.step() 66 | model.zero_grad() 67 | ''' 68 | def __init__(self, model,emb_name,epsilon=1.,alpha=0.3): 69 | # emb_name这个参数要换成你模型中embedding的参数名 70 | self.model = model 71 | self.emb_name = emb_name 72 | self.epsilon = epsilon 73 | self.alpha = alpha 74 
| self.emb_backup = {} 75 | self.grad_backup = {} 76 | 77 | def attack(self,is_first_attack=False): 78 | for name, param in self.model.named_parameters(): 79 | if param.requires_grad and self.emb_name in name: 80 | if is_first_attack: 81 | self.emb_backup[name] = param.data.clone() 82 | norm = torch.norm(param.grad) 83 | if norm != 0: 84 | r_at = self.alpha * param.grad / norm 85 | param.data.add_(r_at) 86 | param.data = self.project(name, param.data, self.epsilon) 87 | 88 | def restore(self): 89 | for name, param in self.model.named_parameters(): 90 | if param.requires_grad and self.emb_name in name: 91 | assert name in self.emb_backup 92 | param.data = self.emb_backup[name] 93 | self.emb_backup = {} 94 | 95 | def project(self, param_name, param_data, epsilon): 96 | r = param_data - self.emb_backup[param_name] 97 | if torch.norm(r) > epsilon: 98 | r = epsilon * r / torch.norm(r) 99 | return self.emb_backup[param_name] + r 100 | 101 | def backup_grad(self): 102 | for name, param in self.model.named_parameters(): 103 | if param.requires_grad: 104 | self.grad_backup[name] = param.grad.clone() 105 | 106 | def restore_grad(self): 107 | for name, param in self.model.named_parameters(): 108 | if param.requires_grad: 109 | param.grad = self.grad_backup[name] -------------------------------------------------------------------------------- /callback/optimizater/adamw.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.optimizer import Optimizer 4 | 5 | class AdamW(Optimizer): 6 | """ Implements Adam algorithm with weight decay fix. 7 | 8 | Parameters: 9 | lr (float): learning rate. Default 1e-3. 10 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999) 11 | eps (float): Adams epsilon. Default: 1e-6 12 | weight_decay (float): Weight decay. Default: 0.0 13 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True. 14 | Example: 15 | >>> model = LSTM() 16 | >>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5) 17 | """ 18 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True): 19 | if lr < 0.0: 20 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 21 | if not 0.0 <= betas[0] < 1.0: 22 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])) 23 | if not 0.0 <= betas[1] < 1.0: 24 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])) 25 | if not 0.0 <= eps: 26 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps)) 27 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 28 | correct_bias=correct_bias) 29 | super(AdamW, self).__init__(params, defaults) 30 | 31 | def step(self, closure=None): 32 | """Performs a single optimization step. 33 | 34 | Arguments: 35 | closure (callable, optional): A closure that reevaluates the model 36 | and returns the loss. 
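# Editor's note: a minimal usage sketch (not part of the original file) showing how this
# AdamW is typically instantiated with separate weight-decay groups; the stand-in Linear
# model and the hyper-parameter values are assumptions for illustration only.
import torch.nn as nn
model = nn.Linear(768, 23)  # stand-in for an NER classification head
no_decay = ("bias",)
grouped_parameters = [
    {"params": [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(grouped_parameters, lr=3e-5, eps=1e-6)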
37 | """ 38 | loss = None 39 | if closure is not None: 40 | loss = closure() 41 | 42 | for group in self.param_groups: 43 | for p in group['params']: 44 | if p.grad is None: 45 | continue 46 | grad = p.grad.data 47 | if grad.is_sparse: 48 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 49 | 50 | state = self.state[p] 51 | 52 | # State initialization 53 | if len(state) == 0: 54 | state['step'] = 0 55 | # Exponential moving average of gradient values 56 | state['exp_avg'] = torch.zeros_like(p.data) 57 | # Exponential moving average of squared gradient values 58 | state['exp_avg_sq'] = torch.zeros_like(p.data) 59 | 60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 61 | beta1, beta2 = group['betas'] 62 | 63 | state['step'] += 1 64 | 65 | # Decay the first and second moment running average coefficient 66 | # In-place operations to update the averages at the same time 67 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad) 68 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad) 69 | denom = exp_avg_sq.sqrt().add_(group['eps']) 70 | 71 | step_size = group['lr'] 72 | if group['correct_bias']: # No bias correction for Bert 73 | bias_correction1 = 1.0 - beta1 ** state['step'] 74 | bias_correction2 = 1.0 - beta2 ** state['step'] 75 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 76 | 77 | p.data.addcdiv_(-step_size, exp_avg, denom) 78 | 79 | # Just adding the square of the weights to the loss function is *not* 80 | # the correct way of using L2 regularization/weight decay with Adam, 81 | # since that will interact with the m and v parameters in strange ways. 82 | # 83 | # Instead we want to decay the weights in a manner that doesn't interact 84 | # with the m/v parameters. This is equivalent to adding the square 85 | # of the weights to the loss with plain (non-momentum) SGD. 
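# Editor's note (illustration, not original code): the decoupled rule applied just below is
#     p <- p - lr * weight_decay * p
# i.e. the weights shrink directly, instead of weight_decay * p being folded into the
# gradient as plain L2 regularization would do.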
86 | # Add weight decay at the end (fixed version) 87 | if group['weight_decay'] > 0.0: 88 | p.data.add_(-group['lr'] * group['weight_decay'], p.data) 89 | 90 | return loss 91 | -------------------------------------------------------------------------------- /callback/optimizater/ralamb.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | class Ralamb(Optimizer): 6 | ''' 7 | RAdam + LARS 8 | Example: 9 | >>> model = ResNet() 10 | >>> optimizer = Ralamb(model.parameters(), lr=0.001) 11 | ''' 12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 14 | self.buffer = [[None, None, None] for ind in range(10)] 15 | super(Ralamb, self).__init__(params, defaults) 16 | 17 | def __setstate__(self, state): 18 | super(Ralamb, self).__setstate__(state) 19 | 20 | def step(self, closure=None): 21 | 22 | loss = None 23 | if closure is not None: 24 | loss = closure() 25 | 26 | for group in self.param_groups: 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad.data.float() 32 | if grad.is_sparse: 33 | raise RuntimeError('Ralamb does not support sparse gradients') 34 | 35 | p_data_fp32 = p.data.float() 36 | 37 | state = self.state[p] 38 | 39 | if len(state) == 0: 40 | state['step'] = 0 41 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 43 | else: 44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 46 | 47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 48 | beta1, beta2 = group['betas'] 49 | 50 | # Decay the first and second moment running average coefficient 51 | # m_t 52 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 53 | # v_t 54 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 55 | 56 | state['step'] += 1 57 | buffered = self.buffer[int(state['step'] % 10)] 58 | 59 | if state['step'] == buffered[0]: 60 | N_sma, radam_step_size = buffered[1], buffered[2] 61 | else: 62 | buffered[0] = state['step'] 63 | beta2_t = beta2 ** state['step'] 64 | N_sma_max = 2 / (1 - beta2) - 1 65 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 66 | buffered[1] = N_sma 67 | 68 | # more conservative since it's an approximated value 69 | if N_sma >= 5: 70 | radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 71 | else: 72 | radam_step_size = 1.0 / (1 - beta1 ** state['step']) 73 | buffered[2] = radam_step_size 74 | 75 | if group['weight_decay'] != 0: 76 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 77 | 78 | # more conservative since it's an approximated value 79 | radam_step = p_data_fp32.clone() 80 | if N_sma >= 5: 81 | denom = exp_avg_sq.sqrt().add_(group['eps']) 82 | radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom) 83 | else: 84 | radam_step.add_(-radam_step_size * group['lr'], exp_avg) 85 | 86 | radam_norm = radam_step.pow(2).sum().sqrt() 87 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 88 | if weight_norm == 0 or radam_norm == 0: 89 | trust_ratio = 1 90 | else: 91 | trust_ratio = weight_norm / radam_norm 92 | 93 | state['weight_norm'] = weight_norm 94 | state['adam_norm'] = radam_norm 95 | state['trust_ratio'] = trust_ratio 96 | 
97 | if N_sma >= 5: 98 | p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom) 99 | else: 100 | p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg) 101 | 102 | p.data.copy_(p_data_fp32) 103 | 104 | return loss -------------------------------------------------------------------------------- /callback/optimizater/lookahead.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | from collections import defaultdict 4 | 5 | class Lookahead(Optimizer): 6 | ''' 7 | PyTorch implementation of the lookahead wrapper. 8 | Lookahead Optimizer: https://arxiv.org/abs/1907.08610 9 | 10 | We found that evaluation performance is typically better using the slow weights. 11 | This can be done in PyTorch with something like this in your eval loop: 12 | if args.lookahead: 13 | optimizer._backup_and_load_cache() 14 | val_loss = eval_func(model) 15 | optimizer._clear_and_load_backup() 16 | ''' 17 | def __init__(self, optimizer,alpha=0.5, k=6,pullback_momentum="none"): 18 | ''' 19 | :param optimizer:inner optimizer 20 | :param k (int): number of lookahead steps 21 | :param alpha(float): linear interpolation factor. 1.0 recovers the inner optimizer. 22 | :param pullback_momentum (str): change to inner optimizer momentum on interpolation update 23 | ''' 24 | if not 0.0 <= alpha <= 1.0: 25 | raise ValueError(f'Invalid slow update rate: {alpha}') 26 | if not 1 <= k: 27 | raise ValueError(f'Invalid lookahead steps: {k}') 28 | self.optimizer = optimizer 29 | self.param_groups = self.optimizer.param_groups 30 | self.alpha = alpha 31 | self.k = k 32 | self.step_counter = 0 33 | assert pullback_momentum in ["reset", "pullback", "none"] 34 | self.pullback_momentum = pullback_momentum 35 | self.state = defaultdict(dict) 36 | 37 | # Cache the current optimizer parameters 38 | for group in self.optimizer.param_groups: 39 | for p in group['params']: 40 | param_state = self.state[p] 41 | param_state['cached_params'] = torch.zeros_like(p.data) 42 | param_state['cached_params'].copy_(p.data) 43 | 44 | def __getstate__(self): 45 | return { 46 | 'state': self.state, 47 | 'optimizer': self.optimizer, 48 | 'alpha': self.alpha, 49 | 'step_counter': self.step_counter, 50 | 'k':self.k, 51 | 'pullback_momentum': self.pullback_momentum 52 | } 53 | 54 | def zero_grad(self): 55 | self.optimizer.zero_grad() 56 | 57 | def state_dict(self): 58 | return self.optimizer.state_dict() 59 | 60 | def load_state_dict(self, state_dict): 61 | self.optimizer.load_state_dict(state_dict) 62 | 63 | def _backup_and_load_cache(self): 64 | """Useful for performing evaluation on the slow weights (which typically generalize better) 65 | """ 66 | for group in self.optimizer.param_groups: 67 | for p in group['params']: 68 | param_state = self.state[p] 69 | param_state['backup_params'] = torch.zeros_like(p.data) 70 | param_state['backup_params'].copy_(p.data) 71 | p.data.copy_(param_state['cached_params']) 72 | 73 | def _clear_and_load_backup(self): 74 | for group in self.optimizer.param_groups: 75 | for p in group['params']: 76 | param_state = self.state[p] 77 | p.data.copy_(param_state['backup_params']) 78 | del param_state['backup_params'] 79 | 80 | def step(self, closure=None): 81 | """Performs a single Lookahead optimization step. 82 | Arguments: 83 | closure (callable, optional): A closure that reevaluates the model 84 | and returns the loss. 
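# Editor's note: a minimal usage sketch (not part of the original file); the toy model,
# the SGD settings and the k/alpha values are illustrative assumptions.
import torch
net = torch.nn.Linear(10, 2)
inner = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(inner, alpha=0.5, k=6)
loss = net(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()       # inner update; slow weights are interpolated every k-th call
optimizer.zero_grad()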
85 | """ 86 | loss = self.optimizer.step(closure) 87 | self.step_counter += 1 88 | 89 | if self.step_counter >= self.k: 90 | self.step_counter = 0 91 | # Lookahead and cache the current optimizer parameters 92 | for group in self.optimizer.param_groups: 93 | for p in group['params']: 94 | param_state = self.state[p] 95 | p.data.mul_(self.alpha).add_(1.0 - self.alpha, param_state['cached_params']) # crucial line 96 | param_state['cached_params'].copy_(p.data) 97 | if self.pullback_momentum == "pullback": 98 | internal_momentum = self.optimizer.state[p]["momentum_buffer"] 99 | self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_( 100 | 1.0 - self.alpha, param_state["cached_mom"]) 101 | param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"] 102 | elif self.pullback_momentum == "reset": 103 | self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data) 104 | 105 | return loss 106 | -------------------------------------------------------------------------------- /callback/optimizater/lamb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.optimizer import Optimizer 3 | 4 | 5 | class Lamb(Optimizer): 6 | r"""Implements Lamb algorithm. 7 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 8 | Arguments: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float, optional): learning rate (default: 1e-3) 12 | betas (Tuple[float, float], optional): coefficients used for computing 13 | running averages of gradient and its square (default: (0.9, 0.999)) 14 | eps (float, optional): term added to the denominator to improve 15 | numerical stability (default: 1e-8) 16 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 17 | adam (bool, optional): always use trust ratio = 1, which turns this into 18 | Adam. Useful for comparison purposes. 19 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 20 | https://arxiv.org/abs/1904.00962 21 | Example: 22 | >>> model = ResNet() 23 | >>> optimizer = Lamb(model.parameters(), lr=1e-2, weight_decay=1e-5) 24 | """ 25 | 26 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 27 | weight_decay=0, adam=False): 28 | if not 0.0 <= lr: 29 | raise ValueError("Invalid learning rate: {}".format(lr)) 30 | if not 0.0 <= eps: 31 | raise ValueError("Invalid epsilon value: {}".format(eps)) 32 | if not 0.0 <= betas[0] < 1.0: 33 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 34 | if not 0.0 <= betas[1] < 1.0: 35 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 36 | defaults = dict(lr=lr, betas=betas, eps=eps, 37 | weight_decay=weight_decay) 38 | self.adam = adam 39 | super(Lamb, self).__init__(params, defaults) 40 | 41 | def step(self, closure=None): 42 | """Performs a single optimization step. 43 | Arguments: 44 | closure (callable, optional): A closure that reevaluates the model 45 | and returns the loss. 
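# Editor's note: illustrative sketch (not part of the original file) of the layer-wise
# trust ratio applied below: the Adam-style update is rescaled by ||w|| / ||update||,
# with the weight norm clamped to [0, 10]; the tensors are dummies.
import torch
w = torch.randn(256, 64)
adam_step = torch.randn(256, 64) * 1e-3
weight_norm = w.pow(2).sum().sqrt().clamp(0, 10)
adam_norm = adam_step.pow(2).sum().sqrt()
trust_ratio = 1.0 if weight_norm == 0 or adam_norm == 0 else weight_norm / adam_norm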
46 | """ 47 | loss = None 48 | if closure is not None: 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | for p in group['params']: 53 | if p.grad is None: 54 | continue 55 | grad = p.grad.data 56 | if grad.is_sparse: 57 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 58 | 59 | state = self.state[p] 60 | 61 | # State initialization 62 | if len(state) == 0: 63 | state['step'] = 0 64 | # Exponential moving average of gradient values 65 | state['exp_avg'] = torch.zeros_like(p.data) 66 | # Exponential moving average of squared gradient values 67 | state['exp_avg_sq'] = torch.zeros_like(p.data) 68 | 69 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 70 | beta1, beta2 = group['betas'] 71 | 72 | state['step'] += 1 73 | 74 | # Decay the first and second moment running average coefficient 75 | # m_t 76 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 77 | # v_t 78 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 79 | 80 | # Paper v3 does not use debiasing. 81 | # bias_correction1 = 1 - beta1 ** state['step'] 82 | # bias_correction2 = 1 - beta2 ** state['step'] 83 | # Apply bias to lr to avoid broadcast. 84 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 85 | 86 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 87 | 88 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 89 | if group['weight_decay'] != 0: 90 | adam_step.add_(group['weight_decay'], p.data) 91 | 92 | adam_norm = adam_step.pow(2).sum().sqrt() 93 | if weight_norm == 0 or adam_norm == 0: 94 | trust_ratio = 1 95 | else: 96 | trust_ratio = weight_norm / adam_norm 97 | state['weight_norm'] = weight_norm 98 | state['adam_norm'] = adam_norm 99 | state['trust_ratio'] = trust_ratio 100 | if self.adam: 101 | trust_ratio = 1 102 | 103 | p.data.add_(-step_size * trust_ratio, adam_step) 104 | 105 | return loss -------------------------------------------------------------------------------- /callback/optimizater/ralars.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer 4 | 5 | 6 | class RaLars(Optimizer): 7 | """Implements the RAdam optimizer from https://arxiv.org/pdf/1908.03265.pdf 8 | with optional Layer-wise adaptive Scaling from https://arxiv.org/pdf/1708.03888.pdf 9 | 10 | Args: 11 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups 12 | lr (float, optional): learning rate 13 | betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) 14 | eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) 15 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 16 | scale_clip (float, optional): the maximal upper bound for the scale factor of LARS 17 | Example: 18 | >>> model = ResNet() 19 | >>> optimizer = RaLars(model.parameters(), lr=0.001) 20 | """ 21 | 22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, 23 | scale_clip=None): 24 | if not 0.0 <= lr: 25 | raise ValueError("Invalid learning rate: {}".format(lr)) 26 | if not 0.0 <= eps: 27 | raise ValueError("Invalid epsilon value: {}".format(eps)) 28 | if not 0.0 <= betas[0] < 1.0: 29 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 30 | if not 0.0 <= betas[1] < 1.0: 31 | raise ValueError("Invalid beta parameter at index 1: 
{}".format(betas[1])) 32 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 33 | super(RaLars, self).__init__(params, defaults) 34 | # LARS arguments 35 | self.scale_clip = scale_clip 36 | if self.scale_clip is None: 37 | self.scale_clip = (0, 10) 38 | 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | Arguments: 42 | closure (callable, optional): A closure that reevaluates the model 43 | and returns the loss. 44 | """ 45 | loss = None 46 | if closure is not None: 47 | loss = closure() 48 | 49 | for group in self.param_groups: 50 | 51 | # Get group-shared variables 52 | beta1, beta2 = group['betas'] 53 | sma_inf = group.get('sma_inf') 54 | # Compute max length of SMA on first step 55 | if not isinstance(sma_inf, float): 56 | group['sma_inf'] = 2 / (1 - beta2) - 1 57 | sma_inf = group.get('sma_inf') 58 | 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | grad = p.grad.data 63 | if grad.is_sparse: 64 | raise RuntimeError('RAdam does not support sparse gradients') 65 | 66 | state = self.state[p] 67 | 68 | # State initialization 69 | if len(state) == 0: 70 | state['step'] = 0 71 | # Exponential moving average of gradient values 72 | state['exp_avg'] = torch.zeros_like(p.data) 73 | # Exponential moving average of squared gradient values 74 | state['exp_avg_sq'] = torch.zeros_like(p.data) 75 | 76 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 77 | 78 | state['step'] += 1 79 | 80 | # Decay the first and second moment running average coefficient 81 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 82 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 83 | 84 | # Bias correction 85 | bias_correction1 = 1 - beta1 ** state['step'] 86 | bias_correction2 = 1 - beta2 ** state['step'] 87 | 88 | # Compute length of SMA 89 | sma_t = sma_inf - 2 * state['step'] * (1 - bias_correction2) / bias_correction2 90 | 91 | update = torch.zeros_like(p.data) 92 | if sma_t > 4: 93 | #  Variance rectification term 94 | r_t = math.sqrt((sma_t - 4) * (sma_t - 2) * sma_inf / ((sma_inf - 4) * (sma_inf - 2) * sma_t)) 95 | #  Adaptive momentum 96 | update.addcdiv_(r_t, exp_avg / bias_correction1, 97 | (exp_avg_sq / bias_correction2).sqrt().add_(group['eps'])) 98 | else: 99 | # Unadapted momentum 100 | update.add_(exp_avg / bias_correction1) 101 | 102 | # Weight decay 103 | if group['weight_decay'] != 0: 104 | update.add_(group['weight_decay'], p.data) 105 | 106 | # LARS 107 | p_norm = p.data.pow(2).sum().sqrt() 108 | update_norm = update.pow(2).sum().sqrt() 109 | phi_p = p_norm.clamp(*self.scale_clip) 110 | # Compute the local LR 111 | if phi_p == 0 or update_norm == 0: 112 | local_lr = 1 113 | else: 114 | local_lr = phi_p / update_norm 115 | 116 | state['local_lr'] = local_lr 117 | 118 | p.data.add_(-group['lr'] * local_lr, update) 119 | 120 | return loss 121 | -------------------------------------------------------------------------------- /callback/optimizater/adabound.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.optim.optimizer import Optimizer 4 | 5 | class AdaBound(Optimizer): 6 | """Implements AdaBound algorithm. 7 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 
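# Editor's note: illustrative helper (not part of the original file) restating the dynamic
# bounds that AdaBound.step below clamps each parameter's step size with; note that step()
# additionally rescales final_lr by group['lr'] / base_lr before applying these bounds.
def adabound_bounds(step, final_lr=0.1, gamma=1e-3):
    lower = final_lr * (1 - 1 / (gamma * step + 1))
    upper = final_lr * (1 + 1 / (gamma * step))
    return lower, upper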
8 | Arguments: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float, optional): Adam learning rate (default: 1e-3) 12 | betas (Tuple[float, float], optional): coefficients used for computing 13 | running averages of gradient and its square (default: (0.9, 0.999)) 14 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 15 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 16 | eps (float, optional): term added to the denominator to improve 17 | numerical stability (default: 1e-8) 18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 19 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 20 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 21 | https://openreview.net/forum?id=Bkg3g2R9FX 22 | Example: 23 | >>> model = LSTM() 24 | >>> optimizer = AdaBound(model.parameters()) 25 | """ 26 | 27 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, 28 | eps=1e-8, weight_decay=0, amsbound=False): 29 | if not 0.0 <= lr: 30 | raise ValueError("Invalid learning rate: {}".format(lr)) 31 | if not 0.0 <= eps: 32 | raise ValueError("Invalid epsilon value: {}".format(eps)) 33 | if not 0.0 <= betas[0] < 1.0: 34 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 35 | if not 0.0 <= betas[1] < 1.0: 36 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 37 | if not 0.0 <= final_lr: 38 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 39 | if not 0.0 <= gamma < 1.0: 40 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 41 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, 42 | weight_decay=weight_decay, amsbound=amsbound) 43 | super(AdaBound, self).__init__(params, defaults) 44 | 45 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) 46 | 47 | def __setstate__(self, state): 48 | super(AdaBound, self).__setstate__(state) 49 | for group in self.param_groups: 50 | group.setdefault('amsbound', False) 51 | 52 | def step(self, closure=None): 53 | """Performs a single optimization step. 54 | Arguments: 55 | closure (callable, optional): A closure that reevaluates the model 56 | and returns the loss. 57 | """ 58 | loss = None 59 | if closure is not None: 60 | loss = closure() 61 | for group, base_lr in zip(self.param_groups, self.base_lrs): 62 | for p in group['params']: 63 | if p.grad is None: 64 | continue 65 | grad = p.grad.data 66 | if grad.is_sparse: 67 | raise RuntimeError( 68 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 69 | amsbound = group['amsbound'] 70 | state = self.state[p] 71 | # State initialization 72 | if len(state) == 0: 73 | state['step'] = 0 74 | # Exponential moving average of gradient values 75 | state['exp_avg'] = torch.zeros_like(p.data) 76 | # Exponential moving average of squared gradient values 77 | state['exp_avg_sq'] = torch.zeros_like(p.data) 78 | if amsbound: 79 | # Maintains max of all exp. moving avg. of sq. grad. 
values 80 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 81 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 82 | if amsbound: 83 | max_exp_avg_sq = state['max_exp_avg_sq'] 84 | beta1, beta2 = group['betas'] 85 | state['step'] += 1 86 | if group['weight_decay'] != 0: 87 | grad = grad.add(group['weight_decay'], p.data) 88 | # Decay the first and second moment running average coefficient 89 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 90 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 91 | if amsbound: 92 | # Maintains the maximum of all 2nd moment running avg. till now 93 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 94 | # Use the max. for normalizing running avg. of gradient 95 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 96 | else: 97 | denom = exp_avg_sq.sqrt().add_(group['eps']) 98 | 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 102 | 103 | # Applies bounds on actual learning rate 104 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 105 | final_lr = group['final_lr'] * group['lr'] / base_lr 106 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) 107 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) 108 | step_size = torch.full_like(denom, step_size) 109 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) 110 | p.data.add_(-step_size) 111 | return loss -------------------------------------------------------------------------------- /models/bert_for_ner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .layers.crf import CRF 5 | from transformers import BertModel,BertPreTrainedModel 6 | from .layers.linears import PoolerEndLogits, PoolerStartLogits 7 | from torch.nn import CrossEntropyLoss 8 | from losses.focal_loss import FocalLoss 9 | from losses.label_smoothing import LabelSmoothingCrossEntropy 10 | 11 | class BertSoftmaxForNer(BertPreTrainedModel): 12 | def __init__(self, config): 13 | super(BertSoftmaxForNer, self).__init__(config) 14 | self.num_labels = config.num_labels 15 | self.bert = BertModel(config) 16 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 17 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 18 | self.loss_type = config.loss_type 19 | self.init_weights() 20 | 21 | def forward(self, input_ids, attention_mask=None, token_type_ids=None,labels=None): 22 | outputs = self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids) 23 | sequence_output = outputs[0] 24 | sequence_output = self.dropout(sequence_output) 25 | logits = self.classifier(sequence_output) 26 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here 27 | if labels is not None: 28 | assert self.loss_type in ['lsr', 'focal', 'ce'] 29 | if self.loss_type == 'lsr': 30 | loss_fct = LabelSmoothingCrossEntropy(ignore_index=0) 31 | elif self.loss_type == 'focal': 32 | loss_fct = FocalLoss(ignore_index=0) 33 | else: 34 | loss_fct = CrossEntropyLoss(ignore_index=0) 35 | # Only keep active parts of the loss 36 | if attention_mask is not None: 37 | active_loss = attention_mask.view(-1) == 1 38 | active_logits = logits.view(-1, self.num_labels)[active_loss] 39 | active_labels = labels.view(-1)[active_loss] 40 | loss = 
loss_fct(active_logits, active_labels) 41 | else: 42 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 43 | outputs = (loss,) + outputs 44 | return outputs # (loss), scores, (hidden_states), (attentions) 45 | 46 | class BertCrfForNer(BertPreTrainedModel): 47 | def __init__(self, config): 48 | super(BertCrfForNer, self).__init__(config) 49 | self.bert = BertModel(config) 50 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 51 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 52 | self.crf = CRF(num_tags=config.num_labels, batch_first=True) 53 | self.init_weights() 54 | 55 | def forward(self, input_ids, token_type_ids=None, attention_mask=None,labels=None): 56 | outputs =self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids) 57 | sequence_output = outputs[0] 58 | sequence_output = self.dropout(sequence_output) 59 | logits = self.classifier(sequence_output) 60 | outputs = (logits,) 61 | if labels is not None: 62 | loss = self.crf(emissions = logits, tags=labels, mask=attention_mask) 63 | outputs =(-1*loss,)+outputs 64 | return outputs # (loss), scores 65 | 66 | class BertSpanForNer(BertPreTrainedModel): 67 | def __init__(self, config,): 68 | super(BertSpanForNer, self).__init__(config) 69 | self.soft_label = config.soft_label 70 | self.num_labels = config.num_labels 71 | self.loss_type = config.loss_type 72 | self.bert = BertModel(config) 73 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 74 | self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels) 75 | if self.soft_label: 76 | self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels) 77 | else: 78 | self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels) 79 | self.init_weights() 80 | 81 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,end_positions=None): 82 | outputs = self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids) 83 | sequence_output = outputs[0] 84 | sequence_output = self.dropout(sequence_output) 85 | start_logits = self.start_fc(sequence_output) 86 | if start_positions is not None and self.training: 87 | if self.soft_label: 88 | batch_size = input_ids.size(0) 89 | seq_len = input_ids.size(1) 90 | label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels) 91 | label_logits.zero_() 92 | label_logits = label_logits.to(input_ids.device) 93 | label_logits.scatter_(2, start_positions.unsqueeze(2), 1) 94 | else: 95 | label_logits = start_positions.unsqueeze(2).float() 96 | else: 97 | label_logits = F.softmax(start_logits, -1) 98 | if not self.soft_label: 99 | label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float() 100 | end_logits = self.end_fc(sequence_output, label_logits) 101 | outputs = (start_logits, end_logits,) + outputs[2:] 102 | 103 | if start_positions is not None and end_positions is not None: 104 | assert self.loss_type in ['lsr', 'focal', 'ce'] 105 | if self.loss_type =='lsr': 106 | loss_fct = LabelSmoothingCrossEntropy() 107 | elif self.loss_type == 'focal': 108 | loss_fct = FocalLoss() 109 | else: 110 | loss_fct = CrossEntropyLoss() 111 | start_logits = start_logits.view(-1, self.num_labels) 112 | end_logits = end_logits.view(-1, self.num_labels) 113 | active_loss = attention_mask.view(-1) == 1 114 | active_start_logits = start_logits[active_loss] 115 | active_end_logits = end_logits[active_loss] 116 | 117 | active_start_labels = start_positions.view(-1)[active_loss] 118 
| active_end_labels = end_positions.view(-1)[active_loss] 119 | 120 | start_loss = loss_fct(active_start_logits, active_start_labels) 121 | end_loss = loss_fct(active_end_logits, active_end_labels) 122 | total_loss = (start_loss + end_loss) / 2 123 | outputs = (total_loss,) + outputs 124 | return outputs 125 | 126 | -------------------------------------------------------------------------------- /tools/finetuning_argparse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_argparse(): 4 | parser = argparse.ArgumentParser() 5 | # Required parameters 6 | parser.add_argument("--task_name", default=None, type=str, required=True, 7 | help="The name of the task to train selected in the list: ") 8 | parser.add_argument("--data_dir", default=None, type=str, required=True, 9 | help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", ) 10 | parser.add_argument("--model_type", default=None, type=str, required=True, 11 | help="Model type selected in the list: ") 12 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 13 | help="Path to pre-trained model or shortcut name selected in the list: " ) 14 | parser.add_argument("--output_dir", default=None, type=str, required=True, 15 | help="The output directory where the model predictions and checkpoints will be written.", ) 16 | 17 | # Other parameters 18 | parser.add_argument('--markup', default='bios', type=str, 19 | choices=['bios', 'bio']) 20 | parser.add_argument('--loss_type', default='ce', type=str, 21 | choices=['lsr', 'focal', 'ce']) 22 | parser.add_argument("--config_name", default="", type=str, 23 | help="Pretrained config name or path if not the same as model_name") 24 | parser.add_argument("--tokenizer_name", default="", type=str, 25 | help="Pretrained tokenizer name or path if not the same as model_name", ) 26 | parser.add_argument("--cache_dir", default="", type=str, 27 | help="Where do you want to store the pre-trained models downloaded from s3", ) 28 | parser.add_argument("--train_max_seq_length", default=128, type=int, 29 | help="The maximum total input sequence length after tokenization. Sequences longer " 30 | "than this will be truncated, sequences shorter will be padded.", ) 31 | parser.add_argument("--eval_max_seq_length", default=512, type=int, 32 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 33 | "than this will be truncated, sequences shorter will be padded.", ) 34 | parser.add_argument("--do_train", action="store_true", 35 | help="Whether to run training.") 36 | parser.add_argument("--do_eval", action="store_true", 37 | help="Whether to run eval on the dev set.") 38 | parser.add_argument("--do_predict", action="store_true", 39 | help="Whether to run predictions on the test set.") 40 | parser.add_argument("--evaluate_during_training", action="store_true", 41 | help="Whether to run evaluation during training at each logging step.", ) 42 | parser.add_argument("--do_lower_case", action="store_true", 43 | help="Set this flag if you are using an uncased model.") 44 | # adversarial training 45 | parser.add_argument("--do_adv", action="store_true", 46 | help="Whether to adversarial training.") 47 | parser.add_argument('--adv_epsilon', default=1.0, type=float, 48 | help="Epsilon for adversarial.") 49 | parser.add_argument('--adv_name', default='word_embeddings', type=str, 50 | help="name for adversarial layer.") 51 | 52 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, 53 | help="Batch size per GPU/CPU for training.") 54 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 55 | help="Batch size per GPU/CPU for evaluation.") 56 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 57 | help="Number of updates steps to accumulate before performing a backward/update pass.", ) 58 | parser.add_argument("--learning_rate", default=5e-5, type=float, 59 | help="The initial learning rate for Adam.") 60 | parser.add_argument("--crf_learning_rate", default=5e-5, type=float, 61 | help="The initial learning rate for crf and linear layer.") 62 | parser.add_argument("--weight_decay", default=0.01, type=float, 63 | help="Weight decay if we apply some.") 64 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 65 | help="Epsilon for Adam optimizer.") 66 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 67 | help="Max gradient norm.") 68 | parser.add_argument("--num_train_epochs", default=3.0, type=float, 69 | help="Total number of training epochs to perform.") 70 | parser.add_argument("--max_steps", default=-1, type=int, 71 | help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", ) 72 | 73 | parser.add_argument("--warmup_proportion", default=0.1, type=float, 74 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training.") 75 | parser.add_argument("--logging_steps", type=int, default=50, 76 | help="Log every X updates steps.") 77 | parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.") 78 | parser.add_argument("--eval_all_checkpoints", action="store_true", 79 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", ) 80 | parser.add_argument("--predict_checkpoints",type=int, default=0, 81 | help="predict checkpoints starting with the same prefix as model_name ending and ending with step number") 82 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available") 83 | parser.add_argument("--overwrite_output_dir", action="store_true", 84 | help="Overwrite the content of the output directory") 85 | parser.add_argument("--overwrite_cache", action="store_true", 86 | help="Overwrite the cached training and evaluation sets") 87 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 88 | parser.add_argument("--fp16", action="store_true", 89 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", ) 90 | parser.add_argument("--fp16_opt_level", type=str, default="O1", 91 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 92 | "See details at https://nvidia.github.io/apex/amp.html", ) 93 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 94 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.") 95 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.") 96 | return parser -------------------------------------------------------------------------------- /processors/utils_ner.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import torch 4 | from transformers import BertTokenizer 5 | 6 | class DataProcessor(object): 7 | """Base class for data converters for sequence classification data sets.""" 8 | 9 | def get_train_examples(self, data_dir): 10 | """Gets a collection of `InputExample`s for the train set.""" 11 | raise NotImplementedError() 12 | 13 | def get_dev_examples(self, data_dir): 14 | """Gets a collection of `InputExample`s for the dev set.""" 15 | raise NotImplementedError() 16 | 17 | def get_labels(self): 18 | """Gets the list of labels for this data set.""" 19 | raise NotImplementedError() 20 | 21 | @classmethod 22 | def _read_tsv(cls, input_file, quotechar=None): 23 | """Reads a tab separated value file.""" 24 | with open(input_file, "r", encoding="utf-8-sig") as f: 25 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 26 | lines = [] 27 | for line in reader: 28 | lines.append(line) 29 | return lines 30 | 31 | @classmethod 32 | def _read_text(self,input_file): 33 | lines = [] 34 | with open(input_file,'r') as f: 35 | words = [] 36 | labels = [] 37 | for line in f: 38 | if line.startswith("-DOCSTART-") or line == "" or line == "\n": 39 | if words: 40 | lines.append({"words":words,"labels":labels}) 41 | words = [] 42 | labels = [] 43 | else: 44 | splits = line.split(" ") 45 | words.append(splits[0]) 46 | if len(splits) > 1: 47 | 
labels.append(splits[-1].replace("\n", "")) 48 | else: 49 | # Examples could have no label for mode = "test" 50 | labels.append("O") 51 | if words: 52 | lines.append({"words":words,"labels":labels}) 53 | return lines 54 | 55 | @classmethod 56 | def _read_json(self,input_file): 57 | lines = [] 58 | with open(input_file,'r') as f: 59 | for line in f: 60 | line = json.loads(line.strip()) 61 | text = line['text'] 62 | label_entities = line.get('label',None) 63 | words = list(text) 64 | labels = ['O'] * len(words) 65 | if label_entities is not None: 66 | for key,value in label_entities.items(): 67 | for sub_name,sub_index in value.items(): 68 | for start_index,end_index in sub_index: 69 | assert ''.join(words[start_index:end_index+1]) == sub_name 70 | if start_index == end_index: 71 | labels[start_index] = 'S-'+key 72 | else: 73 | labels[start_index] = 'B-'+key 74 | labels[start_index+1:end_index+1] = ['I-'+key]*(len(sub_name)-1) 75 | lines.append({"words": words, "labels": labels}) 76 | return lines 77 | 78 | def get_entity_bios(seq,id2label): 79 | """Gets entities from sequence. 80 | note: BIOS 81 | Args: 82 | seq (list): sequence of labels. 83 | Returns: 84 | list: list of (chunk_type, chunk_start, chunk_end). 85 | Example: 86 | # >>> seq = ['B-PER', 'I-PER', 'O', 'S-LOC'] 87 | # >>> get_entity_bios(seq) 88 | [['PER', 0,1], ['LOC', 3, 3]] 89 | """ 90 | chunks = [] 91 | chunk = [-1, -1, -1] 92 | for indx, tag in enumerate(seq): 93 | if not isinstance(tag, str): 94 | tag = id2label[tag] 95 | if tag.startswith("S-"): 96 | if chunk[2] != -1: 97 | chunks.append(chunk) 98 | chunk = [-1, -1, -1] 99 | chunk[1] = indx 100 | chunk[2] = indx 101 | chunk[0] = tag.split('-')[1] 102 | chunks.append(chunk) 103 | chunk = (-1, -1, -1) 104 | if tag.startswith("B-"): 105 | if chunk[2] != -1: 106 | chunks.append(chunk) 107 | chunk = [-1, -1, -1] 108 | chunk[1] = indx 109 | chunk[0] = tag.split('-')[1] 110 | elif tag.startswith('I-') and chunk[1] != -1: 111 | _type = tag.split('-')[1] 112 | if _type == chunk[0]: 113 | chunk[2] = indx 114 | if indx == len(seq) - 1: 115 | chunks.append(chunk) 116 | else: 117 | if chunk[2] != -1: 118 | chunks.append(chunk) 119 | chunk = [-1, -1, -1] 120 | return chunks 121 | 122 | def get_entity_bio(seq,id2label): 123 | """Gets entities from sequence. 124 | note: BIO 125 | Args: 126 | seq (list): sequence of labels. 127 | Returns: 128 | list: list of (chunk_type, chunk_start, chunk_end). 
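# Editor's note: a runnable illustration (not part of the original file) of the BIOS
# decoding implemented by get_entity_bios above; id2label goes unused when the tags are
# already strings.
tags = ['B-PER', 'I-PER', 'O', 'S-LOC']
print(get_entity_bios(tags, id2label={}))   # -> [['PER', 0, 1], ['LOC', 3, 3]]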
129 | Example: 130 | seq = ['B-PER', 'I-PER', 'O', 'B-LOC'] 131 | get_entity_bio(seq) 132 | #output 133 | [['PER', 0,1], ['LOC', 3, 3]] 134 | """ 135 | chunks = [] 136 | chunk = [-1, -1, -1] 137 | for indx, tag in enumerate(seq): 138 | if not isinstance(tag, str): 139 | tag = id2label[tag] 140 | if tag.startswith("B-"): 141 | if chunk[2] != -1: 142 | chunks.append(chunk) 143 | chunk = [-1, -1, -1] 144 | chunk[1] = indx 145 | chunk[0] = tag.split('-')[1] 146 | chunk[2] = indx 147 | if indx == len(seq) - 1: 148 | chunks.append(chunk) 149 | elif tag.startswith('I-') and chunk[1] != -1: 150 | _type = tag.split('-')[1] 151 | if _type == chunk[0]: 152 | chunk[2] = indx 153 | 154 | if indx == len(seq) - 1: 155 | chunks.append(chunk) 156 | else: 157 | if chunk[2] != -1: 158 | chunks.append(chunk) 159 | chunk = [-1, -1, -1] 160 | return chunks 161 | 162 | def get_entities(seq,id2label,markup='bios'): 163 | ''' 164 | :param seq: 165 | :param id2label: 166 | :param markup: 167 | :return: 168 | ''' 169 | assert markup in ['bio','bios'] 170 | if markup =='bio': 171 | return get_entity_bio(seq,id2label) 172 | else: 173 | return get_entity_bios(seq,id2label) 174 | 175 | def bert_extract_item(start_logits, end_logits): 176 | S = [] 177 | start_pred = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1] 178 | end_pred = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1] 179 | for i, s_l in enumerate(start_pred): 180 | if s_l == 0: 181 | continue 182 | for j, e_l in enumerate(end_pred[i:]): 183 | if s_l == e_l: 184 | S.append((s_l, i, i + j)) 185 | break 186 | return S 187 | -------------------------------------------------------------------------------- /callback/optimizater/adafactor.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import torch 3 | from copy import copy 4 | import functools 5 | from math import sqrt 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class AdaFactor(Optimizer): 10 | ''' 11 | # Code below is an implementation of https://arxiv.org/pdf/1804.04235.pdf 12 | # inspired but modified from https://github.com/DeadAt0m/adafactor-pytorch 13 | Example: 14 | >>> model = LSTM() 15 | >>> optimizer = AdaFactor(model.parameters(),lr= lr) 16 | ''' 17 | 18 | def __init__(self, params, lr=None, beta1=0.9, beta2=0.999, eps1=1e-30, 19 | eps2=1e-3, cliping_threshold=1, non_constant_decay=True, 20 | enable_factorization=True, ams_grad=True, weight_decay=0): 21 | 22 | enable_momentum = beta1 != 0 23 | if non_constant_decay: 24 | ams_grad = False 25 | 26 | defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps1=eps1, 27 | eps2=eps2, cliping_threshold=cliping_threshold, 28 | weight_decay=weight_decay, ams_grad=ams_grad, 29 | enable_factorization=enable_factorization, 30 | enable_momentum=enable_momentum, 31 | non_constant_decay=non_constant_decay) 32 | 33 | super(AdaFactor, self).__init__(params, defaults) 34 | 35 | def __setstate__(self, state): 36 | super(AdaFactor, self).__setstate__(state) 37 | 38 | def _experimental_reshape(self, shape): 39 | temp_shape = shape[2:] 40 | if len(temp_shape) == 1: 41 | new_shape = (shape[0], shape[1] * shape[2]) 42 | else: 43 | tmp_div = len(temp_shape) // 2 + len(temp_shape) % 2 44 | new_shape = (shape[0] * functools.reduce(operator.mul, 45 | temp_shape[tmp_div:], 1), 46 | shape[1] * functools.reduce(operator.mul, 47 | temp_shape[:tmp_div], 1)) 48 | return new_shape, copy(shape) 49 | 50 | def _check_shape(self, shape): 51 | ''' 52 | output1 - True - algorithm for matrix, False - vector; 53 | 
output2 - need reshape 54 | ''' 55 | if len(shape) > 2: 56 | return True, True 57 | elif len(shape) == 2: 58 | return True, False 59 | elif len(shape) == 2 and (shape[0] == 1 or shape[1] == 1): 60 | return False, False 61 | else: 62 | return False, False 63 | 64 | def _rms(self, x): 65 | return sqrt(torch.mean(x.pow(2))) 66 | 67 | def step(self, closure=None): 68 | loss = None 69 | if closure is not None: 70 | loss = closure() 71 | for group in self.param_groups: 72 | for p in group['params']: 73 | if p.grad is None: 74 | continue 75 | grad = p.grad.data 76 | 77 | if grad.is_sparse: 78 | raise RuntimeError('Adam does not support sparse \ 79 | gradients, use SparseAdam instead') 80 | 81 | is_matrix, is_need_reshape = self._check_shape(grad.size()) 82 | new_shape = p.data.size() 83 | if is_need_reshape and group['enable_factorization']: 84 | new_shape, old_shape = \ 85 | self._experimental_reshape(p.data.size()) 86 | grad = grad.view(new_shape) 87 | 88 | state = self.state[p] 89 | if len(state) == 0: 90 | state['step'] = 0 91 | if group['enable_momentum']: 92 | state['exp_avg'] = torch.zeros(new_shape, 93 | dtype=torch.float32, 94 | device=p.grad.device) 95 | 96 | if is_matrix and group['enable_factorization']: 97 | state['exp_avg_sq_R'] = \ 98 | torch.zeros((1, new_shape[1]), 99 | dtype=torch.float32, 100 | device=p.grad.device) 101 | state['exp_avg_sq_C'] = \ 102 | torch.zeros((new_shape[0], 1), 103 | dtype=torch.float32, 104 | device=p.grad.device) 105 | else: 106 | state['exp_avg_sq'] = torch.zeros(new_shape, 107 | dtype=torch.float32, 108 | device=p.grad.device) 109 | if group['ams_grad']: 110 | state['exp_avg_sq_hat'] = \ 111 | torch.zeros(new_shape, dtype=torch.float32, 112 | device=p.grad.device) 113 | 114 | if group['enable_momentum']: 115 | exp_avg = state['exp_avg'] 116 | 117 | if is_matrix and group['enable_factorization']: 118 | exp_avg_sq_r = state['exp_avg_sq_R'] 119 | exp_avg_sq_c = state['exp_avg_sq_C'] 120 | else: 121 | exp_avg_sq = state['exp_avg_sq'] 122 | 123 | if group['ams_grad']: 124 | exp_avg_sq_hat = state['exp_avg_sq_hat'] 125 | 126 | state['step'] += 1 127 | lr_t = group['lr'] 128 | lr_t *= max(group['eps2'], self._rms(p.data)) 129 | 130 | if group['enable_momentum']: 131 | if group['non_constant_decay']: 132 | beta1_t = group['beta1'] * \ 133 | (1 - group['beta1'] ** (state['step'] - 1)) \ 134 | / (1 - group['beta1'] ** state['step']) 135 | else: 136 | beta1_t = group['beta1'] 137 | exp_avg.mul_(beta1_t).add_(1 - beta1_t, grad) 138 | 139 | if group['non_constant_decay']: 140 | beta2_t = group['beta2'] * \ 141 | (1 - group['beta2'] ** (state['step'] - 1)) / \ 142 | (1 - group['beta2'] ** state['step']) 143 | else: 144 | beta2_t = group['beta2'] 145 | 146 | if is_matrix and group['enable_factorization']: 147 | exp_avg_sq_r.mul_(beta2_t). \ 148 | add_(1 - beta2_t, torch.sum(torch.mul(grad, grad). 149 | add_(group['eps1']), 150 | dim=0, keepdim=True)) 151 | exp_avg_sq_c.mul_(beta2_t). \ 152 | add_(1 - beta2_t, torch.sum(torch.mul(grad, grad). 153 | add_(group['eps1']), 154 | dim=1, keepdim=True)) 155 | v = torch.mul(exp_avg_sq_c, 156 | exp_avg_sq_r).div_(torch.sum(exp_avg_sq_r)) 157 | else: 158 | exp_avg_sq.mul_(beta2_t). \ 159 | addcmul_(1 - beta2_t, grad, grad). 
\ 160 | add_((1 - beta2_t) * group['eps1']) 161 | v = exp_avg_sq 162 | g = grad 163 | if group['enable_momentum']: 164 | g = torch.div(exp_avg, 1 - beta1_t ** state['step']) 165 | if group['ams_grad']: 166 | torch.max(exp_avg_sq_hat, v, out=exp_avg_sq_hat) 167 | v = exp_avg_sq_hat 168 | u = torch.div(g, (torch.div(v, 1 - beta2_t ** 169 | state['step'])).sqrt().add_(group['eps1'])) 170 | else: 171 | u = torch.div(g, v.sqrt()) 172 | u.div_(max(1, self._rms(u) / group['cliping_threshold'])) 173 | p.data.add_(-lr_t * (u.view(old_shape) if is_need_reshape and 174 | group['enable_factorization'] else u)) 175 | if group['weight_decay'] != 0: 176 | p.data.add_(-group['weight_decay'] * lr_t, p.data) 177 | return loss 178 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 131 | -------------------------------------------------------------------------------- /processors/ner_seq.py: -------------------------------------------------------------------------------- 1 | """ Named entity recognition fine-tuning: utilities to work with CLUENER task. """ 2 | import torch 3 | import logging 4 | import os 5 | import copy 6 | import json 7 | from .utils_ner import DataProcessor 8 | logger = logging.getLogger(__name__) 9 | 10 | class InputExample(object): 11 | """A single training/test example for token classification.""" 12 | def __init__(self, guid, text_a, labels): 13 | """Constructs a InputExample. 14 | Args: 15 | guid: Unique id for the example. 16 | text_a: list. The words of the sequence. 17 | labels: (Optional) list. The labels for each word of the sequence. This should be 18 | specified for train and dev examples, but not for test examples. 19 | """ 20 | self.guid = guid 21 | self.text_a = text_a 22 | self.labels = labels 23 | 24 | def __repr__(self): 25 | return str(self.to_json_string()) 26 | def to_dict(self): 27 | """Serializes this instance to a Python dictionary.""" 28 | output = copy.deepcopy(self.__dict__) 29 | return output 30 | def to_json_string(self): 31 | """Serializes this instance to a JSON string.""" 32 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 33 | 34 | class InputFeatures(object): 35 | """A single set of features of data.""" 36 | def __init__(self, input_ids, input_mask, input_len,segment_ids, label_ids): 37 | self.input_ids = input_ids 38 | self.input_mask = input_mask 39 | self.segment_ids = segment_ids 40 | self.label_ids = label_ids 41 | self.input_len = input_len 42 | 43 | def __repr__(self): 44 | return str(self.to_json_string()) 45 | 46 | def to_dict(self): 47 | """Serializes this instance to a Python dictionary.""" 48 | output = copy.deepcopy(self.__dict__) 49 | return output 50 | 51 | def to_json_string(self): 52 | """Serializes this instance to a JSON string.""" 53 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 54 | 55 | def collate_fn(batch): 56 | """ 57 | batch should be a list of (sequence, target, length) tuples... 
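# Editor's note: a minimal illustration (not part of the original file) of the dynamic
# truncation that collate_fn below performs: every pre-padded field is cut back to the
# longest real length in the batch; the tensors are dummies.
import torch
input_ids = torch.tensor([[5, 6, 0, 0], [7, 8, 9, 0]])   # padded to max_seq_length = 4
lens = torch.tensor([2, 3])
max_len = lens.max().item()
input_ids = input_ids[:, :max_len]   # shape (2, 3); masks, segment ids and labels likewise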
58 | Returns a padded tensor of sequences sorted from longest to shortest, 59 | """ 60 | all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch)) 61 | max_len = max(all_lens).item() 62 | all_input_ids = all_input_ids[:, :max_len] 63 | all_attention_mask = all_attention_mask[:, :max_len] 64 | all_token_type_ids = all_token_type_ids[:, :max_len] 65 | all_labels = all_labels[:,:max_len] 66 | return all_input_ids, all_attention_mask, all_token_type_ids, all_labels,all_lens 67 | 68 | def convert_examples_to_features(examples,label_list,max_seq_length,tokenizer, 69 | cls_token_at_end=False,cls_token="[CLS]",cls_token_segment_id=1, 70 | sep_token="[SEP]",pad_on_left=False,pad_token=0,pad_token_segment_id=0, 71 | sequence_a_segment_id=0,mask_padding_with_zero=True,): 72 | """ Loads a data file into a list of `InputBatch`s 73 | `cls_token_at_end` define the location of the CLS token: 74 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 75 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 76 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 77 | """ 78 | label_map = {label: i for i, label in enumerate(label_list)} 79 | features = [] 80 | for (ex_index, example) in enumerate(examples): 81 | if ex_index % 10000 == 0: 82 | logger.info("Writing example %d of %d", ex_index, len(examples)) 83 | if isinstance(example.text_a,list): 84 | example.text_a = " ".join(example.text_a) 85 | tokens = tokenizer.tokenize(example.text_a) 86 | label_ids = [label_map[x] for x in example.labels] 87 | # Account for [CLS] and [SEP] with "- 2". 88 | special_tokens_count = 2 89 | if len(tokens) > max_seq_length - special_tokens_count: 90 | tokens = tokens[: (max_seq_length - special_tokens_count)] 91 | label_ids = label_ids[: (max_seq_length - special_tokens_count)] 92 | 93 | # The convention in BERT is: 94 | # (a) For sequence pairs: 95 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 96 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 97 | # (b) For single sequences: 98 | # tokens: [CLS] the dog is hairy . [SEP] 99 | # type_ids: 0 0 0 0 0 0 0 100 | # 101 | # Where "type_ids" are used to indicate whether this is the first 102 | # sequence or the second sequence. The embedding vectors for `type=0` and 103 | # `type=1` were learned during pre-training and are added to the wordpiece 104 | # embedding vector (and position vector). This is not *strictly* necessary 105 | # since the [SEP] token unambiguously separates the sequences, but it makes 106 | # it easier for the model to learn the concept of sequences. 107 | # 108 | # For classification tasks, the first vector (corresponding to [CLS]) is 109 | # used as as the "sentence vector". Note that this only makes sense because 110 | # the entire model is fine-tuned. 111 | tokens += [sep_token] 112 | label_ids += [label_map['O']] 113 | segment_ids = [sequence_a_segment_id] * len(tokens) 114 | 115 | if cls_token_at_end: 116 | tokens += [cls_token] 117 | label_ids += [label_map['O']] 118 | segment_ids += [cls_token_segment_id] 119 | else: 120 | tokens = [cls_token] + tokens 121 | label_ids = [label_map['O']] + label_ids 122 | segment_ids = [cls_token_segment_id] + segment_ids 123 | 124 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 125 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 126 | # tokens are attended to. 
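        # Illustrative sketch (example values only, not taken from the original file):
        # for a two-token example ["吴", "京"] labelled ['B-NAME', 'I-NAME'] with
        # max_seq_length=6 and the default BERT-style settings, the arrays built
        # below come out roughly as
        #   tokens:      [CLS] 吴 京 [SEP]
        #   input_ids:   4 real ids followed by 2 pad_token ids
        #   input_mask:  [1, 1, 1, 1, 0, 0]
        #   label_ids:   label_map ids of ['O', 'B-NAME', 'I-NAME', 'O'] plus 2 pad ids
        # with input_len recording the unpadded length (4 here).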
127 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 128 | input_len = len(label_ids) 129 | # Zero-pad up to the sequence length. 130 | padding_length = max_seq_length - len(input_ids) 131 | if pad_on_left: 132 | input_ids = ([pad_token] * padding_length) + input_ids 133 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 134 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 135 | label_ids = ([pad_token] * padding_length) + label_ids 136 | else: 137 | input_ids += [pad_token] * padding_length 138 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 139 | segment_ids += [pad_token_segment_id] * padding_length 140 | label_ids += [pad_token] * padding_length 141 | 142 | assert len(input_ids) == max_seq_length 143 | assert len(input_mask) == max_seq_length 144 | assert len(segment_ids) == max_seq_length 145 | assert len(label_ids) == max_seq_length 146 | if ex_index < 5: 147 | logger.info("*** Example ***") 148 | logger.info("guid: %s", example.guid) 149 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 150 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 151 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 152 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 153 | logger.info("label_ids: %s", " ".join([str(x) for x in label_ids])) 154 | 155 | features.append(InputFeatures(input_ids=input_ids, input_mask=input_mask,input_len = input_len, 156 | segment_ids=segment_ids, label_ids=label_ids)) 157 | return features 158 | 159 | 160 | class CnerProcessor(DataProcessor): 161 | """Processor for the chinese ner data set.""" 162 | 163 | def get_train_examples(self, data_dir): 164 | """See base class.""" 165 | return self._create_examples(self._read_text(os.path.join(data_dir, "train.char.bmes")), "train") 166 | 167 | def get_dev_examples(self, data_dir): 168 | """See base class.""" 169 | return self._create_examples(self._read_text(os.path.join(data_dir, "dev.char.bmes")), "dev") 170 | 171 | def get_test_examples(self, data_dir): 172 | """See base class.""" 173 | return self._create_examples(self._read_text(os.path.join(data_dir, "test.char.bmes")), "test") 174 | 175 | def get_labels(self): 176 | """See base class.""" 177 | return ["X",'B-CONT','B-EDU','B-LOC','B-NAME','B-ORG','B-PRO','B-RACE','B-TITLE', 178 | 'I-CONT','I-EDU','I-LOC','I-NAME','I-ORG','I-PRO','I-RACE','I-TITLE', 179 | 'O','S-NAME','S-ORG','S-RACE',"[START]", "[END]"] 180 | 181 | def _create_examples(self, lines, set_type): 182 | """Creates examples for the training and dev sets.""" 183 | examples = [] 184 | for (i, line) in enumerate(lines): 185 | if i == 0: 186 | continue 187 | guid = "%s-%s" % (set_type, i) 188 | text_a= line['words'] 189 | # BIOS 190 | labels = [] 191 | for x in line['labels']: 192 | if 'M-' in x: 193 | labels.append(x.replace('M-','I-')) 194 | elif 'E-' in x: 195 | labels.append(x.replace('E-', 'I-')) 196 | else: 197 | labels.append(x) 198 | examples.append(InputExample(guid=guid, text_a=text_a, labels=labels)) 199 | return examples 200 | 201 | class CluenerProcessor(DataProcessor): 202 | """Processor for the chinese ner data set.""" 203 | 204 | def get_train_examples(self, data_dir): 205 | """See base class.""" 206 | return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), "train") 207 | 208 | def get_dev_examples(self, data_dir): 209 | """See base class.""" 210 | return 
self._create_examples(self._read_json(os.path.join(data_dir, "dev.json")), "dev") 211 | 212 | def get_test_examples(self, data_dir): 213 | """See base class.""" 214 | return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), "test") 215 | 216 | def get_labels(self): 217 | """See base class.""" 218 | return ["X", "B-address", "B-book", "B-company", 'B-game', 'B-government', 'B-movie', 'B-name', 219 | 'B-organization', 'B-position','B-scene',"I-address", 220 | "I-book", "I-company", 'I-game', 'I-government', 'I-movie', 'I-name', 221 | 'I-organization', 'I-position','I-scene', 222 | "S-address", "S-book", "S-company", 'S-game', 'S-government', 'S-movie', 223 | 'S-name', 'S-organization', 'S-position', 224 | 'S-scene','O',"[START]", "[END]"] 225 | 226 | def _create_examples(self, lines, set_type): 227 | """Creates examples for the training and dev sets.""" 228 | examples = [] 229 | for (i, line) in enumerate(lines): 230 | guid = "%s-%s" % (set_type, i) 231 | text_a= line['words'] 232 | # BIOS 233 | labels = line['labels'] 234 | examples.append(InputExample(guid=guid, text_a=text_a, labels=labels)) 235 | return examples 236 | 237 | ner_processors = { 238 | "cner": CnerProcessor, 239 | 'cluener':CluenerProcessor 240 | } 241 | -------------------------------------------------------------------------------- /processors/ner_span.py: -------------------------------------------------------------------------------- 1 | """ Named entity recognition fine-tuning: utilities to work with CoNLL-2003 task. """ 2 | import torch 3 | import logging 4 | import os 5 | import copy 6 | import json 7 | from .utils_ner import DataProcessor,get_entities 8 | logger = logging.getLogger(__name__) 9 | 10 | class InputExample(object): 11 | """A single training/test example for token classification.""" 12 | def __init__(self, guid, text_a, subject): 13 | self.guid = guid 14 | self.text_a = text_a 15 | self.subject = subject 16 | def __repr__(self): 17 | return str(self.to_json_string()) 18 | def to_dict(self): 19 | """Serializes this instance to a Python dictionary.""" 20 | output = copy.deepcopy(self.__dict__) 21 | return output 22 | def to_json_string(self): 23 | """Serializes this instance to a JSON string.""" 24 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 25 | 26 | class InputFeature(object): 27 | """A single set of features of data.""" 28 | 29 | def __init__(self, input_ids, input_mask, input_len, segment_ids, start_ids,end_ids, subjects): 30 | self.input_ids = input_ids 31 | self.input_mask = input_mask 32 | self.segment_ids = segment_ids 33 | self.start_ids = start_ids 34 | self.input_len = input_len 35 | self.end_ids = end_ids 36 | self.subjects = subjects 37 | 38 | def __repr__(self): 39 | return str(self.to_json_string()) 40 | 41 | def to_dict(self): 42 | """Serializes this instance to a Python dictionary.""" 43 | output = copy.deepcopy(self.__dict__) 44 | return output 45 | 46 | def to_json_string(self): 47 | """Serializes this instance to a JSON string.""" 48 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 49 | 50 | def collate_fn(batch): 51 | """ 52 | batch should be a list of (sequence, target, length) tuples... 
53 | Returns a padded tensor of sequences sorted from longest to shortest, 54 | """ 55 | all_input_ids, all_input_mask, all_segment_ids, all_start_ids,all_end_ids,all_lens = map(torch.stack, zip(*batch)) 56 | max_len = max(all_lens).item() 57 | all_input_ids = all_input_ids[:, :max_len] 58 | all_input_mask = all_input_mask[:, :max_len] 59 | all_segment_ids = all_segment_ids[:, :max_len] 60 | all_start_ids = all_start_ids[:,:max_len] 61 | all_end_ids = all_end_ids[:, :max_len] 62 | return all_input_ids, all_input_mask, all_segment_ids, all_start_ids,all_end_ids,all_lens 63 | 64 | def convert_examples_to_features(examples,label_list,max_seq_length,tokenizer, 65 | cls_token_at_end=False,cls_token="[CLS]",cls_token_segment_id=1, 66 | sep_token="[SEP]",pad_on_left=False,pad_token=0,pad_token_segment_id=0, 67 | sequence_a_segment_id=0,mask_padding_with_zero=True,): 68 | """ Loads a data file into a list of `InputBatch`s 69 | `cls_token_at_end` define the location of the CLS token: 70 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 71 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 72 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 73 | """ 74 | label2id = {label: i for i, label in enumerate(label_list)} 75 | features = [] 76 | for (ex_index, example) in enumerate(examples): 77 | if ex_index % 10000 == 0: 78 | logger.info("Writing example %d of %d", ex_index, len(examples)) 79 | textlist = example.text_a 80 | subjects = example.subject 81 | if isinstance(textlist,list): 82 | textlist = " ".join(textlist) 83 | tokens = tokenizer.tokenize(textlist) 84 | start_ids = [0] * len(tokens) 85 | end_ids = [0] * len(tokens) 86 | subjects_id = [] 87 | for subject in subjects: 88 | label = subject[0] 89 | start = subject[1] 90 | end = subject[2] 91 | start_ids[start] = label2id[label] 92 | end_ids[end] = label2id[label] 93 | subjects_id.append((label2id[label], start, end)) 94 | # Account for [CLS] and [SEP] with "- 2". 95 | special_tokens_count = 2 96 | if len(tokens) > max_seq_length - special_tokens_count: 97 | tokens = tokens[: (max_seq_length - special_tokens_count)] 98 | start_ids = start_ids[: (max_seq_length - special_tokens_count)] 99 | end_ids = end_ids[: (max_seq_length - special_tokens_count)] 100 | # The convention in BERT is: 101 | # (a) For sequence pairs: 102 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 103 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 104 | # (b) For single sequences: 105 | # tokens: [CLS] the dog is hairy . [SEP] 106 | # type_ids: 0 0 0 0 0 0 0 107 | # 108 | # Where "type_ids" are used to indicate whether this is the first 109 | # sequence or the second sequence. The embedding vectors for `type=0` and 110 | # `type=1` were learned during pre-training and are added to the wordpiece 111 | # embedding vector (and position vector). This is not *strictly* necessary 112 | # since the [SEP] token unambiguously separates the sequences, but it makes 113 | # it easier for the model to learn the concept of sequences. 114 | # 115 | # For classification tasks, the first vector (corresponding to [CLS]) is 116 | # used as as the "sentence vector". Note that this only makes sense because 117 | # the entire model is fine-tuned. 
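        # Illustrative sketch (example values only): for tokens ['高', '勇', '：', '男']
        # with subjects [['NAME', 0, 1]] and the CNER label list, label2id['NAME'] == 5
        # is written at the span boundaries, giving
        #   start_ids: [5, 0, 0, 0]
        #   end_ids:   [0, 5, 0, 0]
        # The [CLS]/[SEP] handling and padding below then add zeros around these arrays.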
118 | tokens += [sep_token] 119 | start_ids += [0] 120 | end_ids += [0] 121 | segment_ids = [sequence_a_segment_id] * len(tokens) 122 | if cls_token_at_end: 123 | tokens += [cls_token] 124 | start_ids += [0] 125 | end_ids += [0] 126 | segment_ids += [cls_token_segment_id] 127 | else: 128 | tokens = [cls_token] + tokens 129 | start_ids = [0]+ start_ids 130 | end_ids = [0]+ end_ids 131 | segment_ids = [cls_token_segment_id] + segment_ids 132 | 133 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 134 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 135 | # tokens are attended to. 136 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 137 | input_len = len(input_ids) 138 | # Zero-pad up to the sequence length. 139 | padding_length = max_seq_length - len(input_ids) 140 | if pad_on_left: 141 | input_ids = ([pad_token] * padding_length) + input_ids 142 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 143 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 144 | start_ids = ([0] * padding_length) + start_ids 145 | end_ids = ([0] * padding_length) + end_ids 146 | else: 147 | input_ids += [pad_token] * padding_length 148 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length 149 | segment_ids += [pad_token_segment_id] * padding_length 150 | start_ids += ([0] * padding_length) 151 | end_ids += ([0] * padding_length) 152 | 153 | assert len(input_ids) == max_seq_length 154 | assert len(input_mask) == max_seq_length 155 | assert len(segment_ids) == max_seq_length 156 | assert len(start_ids) == max_seq_length 157 | assert len(end_ids) == max_seq_length 158 | 159 | if ex_index < 5: 160 | logger.info("*** Example ***") 161 | logger.info("guid: %s", example.guid) 162 | logger.info("tokens: %s", " ".join([str(x) for x in tokens])) 163 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids])) 164 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask])) 165 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) 166 | logger.info("start_ids: %s" % " ".join([str(x) for x in start_ids])) 167 | logger.info("end_ids: %s" % " ".join([str(x) for x in end_ids])) 168 | 169 | features.append(InputFeature(input_ids=input_ids, 170 | input_mask=input_mask, 171 | segment_ids=segment_ids, 172 | start_ids=start_ids, 173 | end_ids=end_ids, 174 | subjects=subjects_id, 175 | input_len=input_len)) 176 | return features 177 | 178 | class CnerProcessor(DataProcessor): 179 | """Processor for the chinese ner data set.""" 180 | 181 | def get_train_examples(self, data_dir): 182 | """See base class.""" 183 | return self._create_examples(self._read_text(os.path.join(data_dir, "train.char.bmes")), "train") 184 | 185 | def get_dev_examples(self, data_dir): 186 | """See base class.""" 187 | return self._create_examples(self._read_text(os.path.join(data_dir, "dev.char.bmes")), "dev") 188 | 189 | def get_test_examples(self, data_dir): 190 | """See base class.""" 191 | return self._create_examples(self._read_text(os.path.join(data_dir, "test.char.bmes")), "test") 192 | 193 | def get_labels(self): 194 | """See base class.""" 195 | return ["O", "CONT", "ORG","LOC",'EDU','NAME','PRO','RACE','TITLE'] 196 | 197 | def _create_examples(self, lines, set_type): 198 | """Creates examples for the training and dev sets.""" 199 | examples = [] 200 | for (i, line) in enumerate(lines): 201 | if i == 0: 202 | continue 203 | guid = "%s-%s" % (set_type, i) 204 | text_a = line['words'] 205 | labels = [] 
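            # Sketch of the conversion done below (example values only): BMES labels
            # such as ['B-NAME', 'M-NAME', 'E-NAME', 'O'] are rewritten to BIOS as
            # ['B-NAME', 'I-NAME', 'I-NAME', 'O'], and get_entities then collapses
            # them into span subjects of the form [label, start_index, end_index],
            # e.g. ['NAME', 0, 2].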
206 | for x in line['labels']: 207 | if 'M-' in x: 208 | labels.append(x.replace('M-','I-')) 209 | elif 'E-' in x: 210 | labels.append(x.replace('E-', 'I-')) 211 | else: 212 | labels.append(x) 213 | subject = get_entities(labels,id2label=None,markup='bios') 214 | examples.append(InputExample(guid=guid, text_a=text_a, subject=subject)) 215 | return examples 216 | 217 | class CluenerProcessor(DataProcessor): 218 | """Processor for the chinese ner data set.""" 219 | 220 | def get_train_examples(self, data_dir): 221 | """See base class.""" 222 | return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), "train") 223 | 224 | def get_dev_examples(self, data_dir): 225 | """See base class.""" 226 | return self._create_examples(self._read_json(os.path.join(data_dir, "dev.json")), "dev") 227 | 228 | def get_test_examples(self, data_dir): 229 | """See base class.""" 230 | return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), "test") 231 | 232 | def get_labels(self): 233 | """See base class.""" 234 | return ["O", "address", "book","company",'game','government','movie','name','organization','position','scene'] 235 | 236 | def _create_examples(self, lines, set_type): 237 | """Creates examples for the training and dev sets.""" 238 | examples = [] 239 | for (i, line) in enumerate(lines): 240 | guid = "%s-%s" % (set_type, i) 241 | text_a = line['words'] 242 | labels = line['labels'] 243 | subject = get_entities(labels,id2label=None,markup='bios') 244 | examples.append(InputExample(guid=guid, text_a=text_a, subject=subject)) 245 | return examples 246 | 247 | ner_processors = { 248 | "cner": CnerProcessor, 249 | 'cluener':CluenerProcessor 250 | } 251 | 252 | 253 | -------------------------------------------------------------------------------- /tools/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import json 6 | import pickle 7 | import torch.nn as nn 8 | from collections import OrderedDict 9 | from pathlib import Path 10 | import logging 11 | 12 | logger = logging.getLogger() 13 | def print_config(config): 14 | info = "Running with the following configs:\n" 15 | for k, v in config.items(): 16 | info += f"\t{k} : {str(v)}\n" 17 | print("\n" + info + "\n") 18 | return 19 | 20 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 21 | ''' 22 | Example: 23 | >>> init_logger(log_file) 24 | >>> logger.info("abc'") 25 | ''' 26 | if isinstance(log_file,Path): 27 | log_file = str(log_file) 28 | log_format = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 29 | datefmt='%m/%d/%Y %H:%M:%S') 30 | 31 | logger = logging.getLogger() 32 | logger.setLevel(logging.INFO) 33 | console_handler = logging.StreamHandler() 34 | console_handler.setFormatter(log_format) 35 | logger.handlers = [console_handler] 36 | if log_file and log_file != '': 37 | file_handler = logging.FileHandler(log_file) 38 | file_handler.setLevel(log_file_level) 39 | # file_handler.setFormatter(log_format) 40 | logger.addHandler(file_handler) 41 | return logger 42 | 43 | def seed_everything(seed=1029): 44 | ''' 45 | 设置整个开发环境的seed 46 | :param seed: 47 | :param device: 48 | :return: 49 | ''' 50 | random.seed(seed) 51 | os.environ['PYTHONHASHSEED'] = str(seed) 52 | np.random.seed(seed) 53 | torch.manual_seed(seed) 54 | torch.cuda.manual_seed(seed) 55 | torch.cuda.manual_seed_all(seed) 56 | # some cudnn methods can be random even after fixing the seed 
57 | # unless you tell it to be deterministic 58 | torch.backends.cudnn.deterministic = True 59 | 60 | 61 | def prepare_device(n_gpu_use): 62 | """ 63 | setup GPU device if available, move model into configured device 64 | # 如果n_gpu_use为数字,则使用range生成list 65 | # 如果输入的是一个list,则默认使用list[0]作为controller 66 | """ 67 | if not n_gpu_use: 68 | device_type = 'cpu' 69 | else: 70 | n_gpu_use = n_gpu_use.split(",") 71 | device_type = f"cuda:{n_gpu_use[0]}" 72 | n_gpu = torch.cuda.device_count() 73 | if len(n_gpu_use) > 0 and n_gpu == 0: 74 | logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.") 75 | device_type = 'cpu' 76 | if len(n_gpu_use) > n_gpu: 77 | msg = f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are available on this machine." 78 | logger.warning(msg) 79 | n_gpu_use = range(n_gpu) 80 | device = torch.device(device_type) 81 | list_ids = n_gpu_use 82 | return device, list_ids 83 | 84 | 85 | def model_device(n_gpu, model): 86 | ''' 87 | 判断环境 cpu还是gpu 88 | 支持单机多卡 89 | :param n_gpu: 90 | :param model: 91 | :return: 92 | ''' 93 | device, device_ids = prepare_device(n_gpu) 94 | if len(device_ids) > 1: 95 | logger.info(f"current {len(device_ids)} GPUs") 96 | model = torch.nn.DataParallel(model, device_ids=device_ids) 97 | if len(device_ids) == 1: 98 | os.environ['CUDA_VISIBLE_DEVICES'] = str(device_ids[0]) 99 | model = model.to(device) 100 | return model, device 101 | 102 | 103 | def restore_checkpoint(resume_path, model=None): 104 | ''' 105 | 加载模型 106 | :param resume_path: 107 | :param model: 108 | :param optimizer: 109 | :return: 110 | 注意: 如果是加载Bert模型的话,需要调整,不能使用该模式 111 | 可以使用模块自带的Bert_model.from_pretrained(state_dict = your save state_dict) 112 | ''' 113 | if isinstance(resume_path, Path): 114 | resume_path = str(resume_path) 115 | checkpoint = torch.load(resume_path) 116 | best = checkpoint['best'] 117 | start_epoch = checkpoint['epoch'] + 1 118 | states = checkpoint['state_dict'] 119 | if isinstance(model, nn.DataParallel): 120 | model.module.load_state_dict(states) 121 | else: 122 | model.load_state_dict(states) 123 | return [model,best,start_epoch] 124 | 125 | 126 | def save_pickle(data, file_path): 127 | ''' 128 | 保存成pickle文件 129 | :param data: 130 | :param file_name: 131 | :param pickle_path: 132 | :return: 133 | ''' 134 | if isinstance(file_path, Path): 135 | file_path = str(file_path) 136 | with open(file_path, 'wb') as f: 137 | pickle.dump(data, f) 138 | 139 | 140 | def load_pickle(input_file): 141 | ''' 142 | 读取pickle文件 143 | :param pickle_path: 144 | :param file_name: 145 | :return: 146 | ''' 147 | with open(str(input_file), 'rb') as f: 148 | data = pickle.load(f) 149 | return data 150 | 151 | 152 | def save_json(data, file_path): 153 | ''' 154 | 保存成json文件 155 | :param data: 156 | :param json_path: 157 | :param file_name: 158 | :return: 159 | ''' 160 | if not isinstance(file_path, Path): 161 | file_path = Path(file_path) 162 | # if isinstance(data,dict): 163 | # data = json.dumps(data) 164 | with open(str(file_path), 'w') as f: 165 | json.dump(data, f) 166 | 167 | def save_numpy(data, file_path): 168 | ''' 169 | 保存成.npy文件 170 | :param data: 171 | :param file_path: 172 | :return: 173 | ''' 174 | if not isinstance(file_path, Path): 175 | file_path = Path(file_path) 176 | np.save(str(file_path),data) 177 | 178 | def load_numpy(file_path): 179 | ''' 180 | 加载.npy文件 181 | :param file_path: 182 | :return: 183 | ''' 184 | if not isinstance(file_path, Path): 185 | file_path = Path(file_path) 186 | 
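    # NOTE: the array produced by np.load below is neither assigned nor returned,
    # so load_numpy currently yields None; callers that need the data back should
    # call np.load(str(file_path)) themselves.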
np.load(str(file_path)) 187 | 188 | def load_json(file_path): 189 | ''' 190 | 加载json文件 191 | :param json_path: 192 | :param file_name: 193 | :return: 194 | ''' 195 | if not isinstance(file_path, Path): 196 | file_path = Path(file_path) 197 | with open(str(file_path), 'r') as f: 198 | data = json.load(f) 199 | return data 200 | 201 | def json_to_text(file_path,data): 202 | ''' 203 | 将json list写入text文件中 204 | :param file_path: 205 | :param data: 206 | :return: 207 | ''' 208 | if not isinstance(file_path, Path): 209 | file_path = Path(file_path) 210 | with open(str(file_path), 'w') as fw: 211 | for line in data: 212 | line = json.dumps(line, ensure_ascii=False) 213 | fw.write(line + '\n') 214 | 215 | def save_model(model, model_path): 216 | """ 存储不含有显卡信息的state_dict或model 217 | :param model: 218 | :param model_name: 219 | :param only_param: 220 | :return: 221 | """ 222 | if isinstance(model_path, Path): 223 | model_path = str(model_path) 224 | if isinstance(model, nn.DataParallel): 225 | model = model.module 226 | state_dict = model.state_dict() 227 | for key in state_dict: 228 | state_dict[key] = state_dict[key].cpu() 229 | torch.save(state_dict, model_path) 230 | 231 | def load_model(model, model_path): 232 | ''' 233 | 加载模型 234 | :param model: 235 | :param model_name: 236 | :param model_path: 237 | :param only_param: 238 | :return: 239 | ''' 240 | if isinstance(model_path, Path): 241 | model_path = str(model_path) 242 | logging.info(f"loading model from {str(model_path)} .") 243 | states = torch.load(model_path) 244 | state = states['state_dict'] 245 | if isinstance(model, nn.DataParallel): 246 | model.module.load_state_dict(state) 247 | else: 248 | model.load_state_dict(state) 249 | return model 250 | 251 | 252 | class AverageMeter(object): 253 | ''' 254 | computes and stores the average and current value 255 | Example: 256 | >>> loss = AverageMeter() 257 | >>> for step,batch in enumerate(train_data): 258 | >>> pred = self.model(batch) 259 | >>> raw_loss = self.metrics(pred,target) 260 | >>> loss.update(raw_loss.item(),n = 1) 261 | >>> cur_loss = loss.avg 262 | ''' 263 | 264 | def __init__(self): 265 | self.reset() 266 | 267 | def reset(self): 268 | self.val = 0 269 | self.avg = 0 270 | self.sum = 0 271 | self.count = 0 272 | 273 | def update(self, val, n=1): 274 | self.val = val 275 | self.sum += val * n 276 | self.count += n 277 | self.avg = self.sum / self.count 278 | 279 | 280 | def summary(model, *inputs, batch_size=-1, show_input=True): 281 | ''' 282 | 打印模型结构信息 283 | :param model: 284 | :param inputs: 285 | :param batch_size: 286 | :param show_input: 287 | :return: 288 | Example: 289 | >>> print("model summary info: ") 290 | >>> for step,batch in enumerate(train_data): 291 | >>> summary(self.model,*batch,show_input=True) 292 | >>> break 293 | ''' 294 | 295 | def register_hook(module): 296 | def hook(module, input, output=None): 297 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 298 | module_idx = len(summary) 299 | 300 | m_key = f"{class_name}-{module_idx + 1}" 301 | summary[m_key] = OrderedDict() 302 | summary[m_key]["input_shape"] = list(input[0].size()) 303 | summary[m_key]["input_shape"][0] = batch_size 304 | 305 | if show_input is False and output is not None: 306 | if isinstance(output, (list, tuple)): 307 | for out in output: 308 | if isinstance(out, torch.Tensor): 309 | summary[m_key]["output_shape"] = [ 310 | [-1] + list(out.size())[1:] 311 | ][0] 312 | else: 313 | summary[m_key]["output_shape"] = [ 314 | [-1] + list(out[0].size())[1:] 315 | ][0] 316 | 
else: 317 | summary[m_key]["output_shape"] = list(output.size()) 318 | summary[m_key]["output_shape"][0] = batch_size 319 | 320 | params = 0 321 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 322 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 323 | summary[m_key]["trainable"] = module.weight.requires_grad 324 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 325 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 326 | summary[m_key]["nb_params"] = params 327 | 328 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and not (module == model)): 329 | if show_input is True: 330 | hooks.append(module.register_forward_pre_hook(hook)) 331 | else: 332 | hooks.append(module.register_forward_hook(hook)) 333 | 334 | # create properties 335 | summary = OrderedDict() 336 | hooks = [] 337 | 338 | # register hook 339 | model.apply(register_hook) 340 | model(*inputs) 341 | 342 | # remove these hooks 343 | for h in hooks: 344 | h.remove() 345 | 346 | print("-----------------------------------------------------------------------") 347 | if show_input is True: 348 | line_new = f"{'Layer (type)':>25} {'Input Shape':>25} {'Param #':>15}" 349 | else: 350 | line_new = f"{'Layer (type)':>25} {'Output Shape':>25} {'Param #':>15}" 351 | print(line_new) 352 | print("=======================================================================") 353 | 354 | total_params = 0 355 | total_output = 0 356 | trainable_params = 0 357 | for layer in summary: 358 | # input_shape, output_shape, trainable, nb_params 359 | if show_input is True: 360 | line_new = "{:>25} {:>25} {:>15}".format( 361 | layer, 362 | str(summary[layer]["input_shape"]), 363 | "{0:,}".format(summary[layer]["nb_params"]), 364 | ) 365 | else: 366 | line_new = "{:>25} {:>25} {:>15}".format( 367 | layer, 368 | str(summary[layer]["output_shape"]), 369 | "{0:,}".format(summary[layer]["nb_params"]), 370 | ) 371 | 372 | total_params += summary[layer]["nb_params"] 373 | if show_input is True: 374 | total_output += np.prod(summary[layer]["input_shape"]) 375 | else: 376 | total_output += np.prod(summary[layer]["output_shape"]) 377 | if "trainable" in summary[layer]: 378 | if summary[layer]["trainable"] == True: 379 | trainable_params += summary[layer]["nb_params"] 380 | 381 | print(line_new) 382 | 383 | print("=======================================================================") 384 | print(f"Total params: {total_params:0,}") 385 | print(f"Trainable params: {trainable_params:0,}") 386 | print(f"Non-trainable params: {(total_params - trainable_params):0,}") 387 | print("-----------------------------------------------------------------------") -------------------------------------------------------------------------------- /models/layers/crf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List, Optional 4 | 5 | class CRF(nn.Module): 6 | """Conditional random field. 7 | This module implements a conditional random field [LMP01]_. The forward computation 8 | of this class computes the log likelihood of the given sequence of tags and 9 | emission score tensor. This class also has `~CRF.decode` method which finds 10 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_. 11 | Args: 12 | num_tags: Number of tags. 13 | batch_first: Whether the first dimension corresponds to the size of a minibatch. 
14 | Attributes: 15 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size 16 | ``(num_tags,)``. 17 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size 18 | ``(num_tags,)``. 19 | transitions (`~torch.nn.Parameter`): Transition score tensor of size 20 | ``(num_tags, num_tags)``. 21 | .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001). 22 | "Conditional random fields: Probabilistic models for segmenting and 23 | labeling sequence data". *Proc. 18th International Conf. on Machine 24 | Learning*. Morgan Kaufmann. pp. 282–289. 25 | .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm 26 | """ 27 | 28 | def __init__(self, num_tags: int, batch_first: bool = False) -> None: 29 | if num_tags <= 0: 30 | raise ValueError(f'invalid number of tags: {num_tags}') 31 | super().__init__() 32 | self.num_tags = num_tags 33 | self.batch_first = batch_first 34 | self.start_transitions = nn.Parameter(torch.empty(num_tags)) 35 | self.end_transitions = nn.Parameter(torch.empty(num_tags)) 36 | self.transitions = nn.Parameter(torch.empty(num_tags, num_tags)) 37 | 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self) -> None: 41 | """Initialize the transition parameters. 42 | The parameters will be initialized randomly from a uniform distribution 43 | between -0.1 and 0.1. 44 | """ 45 | nn.init.uniform_(self.start_transitions, -0.1, 0.1) 46 | nn.init.uniform_(self.end_transitions, -0.1, 0.1) 47 | nn.init.uniform_(self.transitions, -0.1, 0.1) 48 | 49 | def __repr__(self) -> str: 50 | return f'{self.__class__.__name__}(num_tags={self.num_tags})' 51 | 52 | def forward(self, emissions: torch.Tensor, 53 | tags: torch.LongTensor, 54 | mask: Optional[torch.ByteTensor] = None, 55 | reduction: str = 'mean') -> torch.Tensor: 56 | """Compute the conditional log likelihood of a sequence of tags given emission scores. 57 | Args: 58 | emissions (`~torch.Tensor`): Emission score tensor of size 59 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 60 | ``(batch_size, seq_length, num_tags)`` otherwise. 61 | tags (`~torch.LongTensor`): Sequence of tags tensor of size 62 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``, 63 | ``(batch_size, seq_length)`` otherwise. 64 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 65 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 66 | reduction: Specifies the reduction to apply to the output: 67 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied. 68 | ``sum``: the output will be summed over batches. ``mean``: the output will be 69 | averaged over batches. ``token_mean``: the output will be averaged over tokens. 70 | Returns: 71 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if 72 | reduction is ``none``, ``()`` otherwise. 
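        Example (usage sketch; shapes and values are illustrative only):
            >>> crf = CRF(num_tags=5, batch_first=True)
            >>> emissions = torch.randn(2, 7, 5)            # (batch_size, seq_length, num_tags)
            >>> tags = torch.randint(0, 5, (2, 7))          # (batch_size, seq_length)
            >>> mask = torch.ones(2, 7, dtype=torch.uint8)  # all timesteps valid
            >>> loss = -crf(emissions, tags, mask=mask)     # negative mean log likelihood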
73 | """ 74 | if reduction not in ('none', 'sum', 'mean', 'token_mean'): 75 | raise ValueError(f'invalid reduction: {reduction}') 76 | if mask is None: 77 | mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device) 78 | if mask.dtype != torch.uint8: 79 | mask = mask.byte() 80 | self._validate(emissions, tags=tags, mask=mask) 81 | 82 | if self.batch_first: 83 | emissions = emissions.transpose(0, 1) 84 | tags = tags.transpose(0, 1) 85 | mask = mask.transpose(0, 1) 86 | 87 | # shape: (batch_size,) 88 | numerator = self._compute_score(emissions, tags, mask) 89 | # shape: (batch_size,) 90 | denominator = self._compute_normalizer(emissions, mask) 91 | # shape: (batch_size,) 92 | llh = numerator - denominator 93 | 94 | if reduction == 'none': 95 | return llh 96 | if reduction == 'sum': 97 | return llh.sum() 98 | if reduction == 'mean': 99 | return llh.mean() 100 | return llh.sum() / mask.float().sum() 101 | 102 | def decode(self, emissions: torch.Tensor, 103 | mask: Optional[torch.ByteTensor] = None, 104 | nbest: Optional[int] = None, 105 | pad_tag: Optional[int] = None) -> List[List[List[int]]]: 106 | """Find the most likely tag sequence using Viterbi algorithm. 107 | Args: 108 | emissions (`~torch.Tensor`): Emission score tensor of size 109 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``, 110 | ``(batch_size, seq_length, num_tags)`` otherwise. 111 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)`` 112 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise. 113 | nbest (`int`): Number of most probable paths for each sequence 114 | pad_tag (`int`): Tag at padded positions. Often input varies in length and 115 | the length will be padded to the maximum length in the batch. Tags at 116 | the padded positions will be assigned with a padding tag, i.e. 
`pad_tag` 117 | Returns: 118 | A PyTorch tensor of the best tag sequence for each batch of shape 119 | (nbest, batch_size, seq_length) 120 | """ 121 | if nbest is None: 122 | nbest = 1 123 | if mask is None: 124 | mask = torch.ones(emissions.shape[:2], dtype=torch.uint8, 125 | device=emissions.device) 126 | if mask.dtype != torch.uint8: 127 | mask = mask.byte() 128 | self._validate(emissions, mask=mask) 129 | 130 | if self.batch_first: 131 | emissions = emissions.transpose(0, 1) 132 | mask = mask.transpose(0, 1) 133 | 134 | if nbest == 1: 135 | return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0) 136 | return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag) 137 | 138 | def _validate(self, emissions: torch.Tensor, 139 | tags: Optional[torch.LongTensor] = None, 140 | mask: Optional[torch.ByteTensor] = None) -> None: 141 | if emissions.dim() != 3: 142 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}') 143 | if emissions.size(2) != self.num_tags: 144 | raise ValueError( 145 | f'expected last dimension of emissions is {self.num_tags}, ' 146 | f'got {emissions.size(2)}') 147 | 148 | if tags is not None: 149 | if emissions.shape[:2] != tags.shape: 150 | raise ValueError( 151 | 'the first two dimensions of emissions and tags must match, ' 152 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}') 153 | 154 | if mask is not None: 155 | if emissions.shape[:2] != mask.shape: 156 | raise ValueError( 157 | 'the first two dimensions of emissions and mask must match, ' 158 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}') 159 | no_empty_seq = not self.batch_first and mask[0].all() 160 | no_empty_seq_bf = self.batch_first and mask[:, 0].all() 161 | if not no_empty_seq and not no_empty_seq_bf: 162 | raise ValueError('mask of the first timestep must all be on') 163 | 164 | def _compute_score(self, emissions: torch.Tensor, 165 | tags: torch.LongTensor, 166 | mask: torch.ByteTensor) -> torch.Tensor: 167 | # emissions: (seq_length, batch_size, num_tags) 168 | # tags: (seq_length, batch_size) 169 | # mask: (seq_length, batch_size) 170 | seq_length, batch_size = tags.shape 171 | mask = mask.float() 172 | 173 | # Start transition score and first emission 174 | # shape: (batch_size,) 175 | score = self.start_transitions[tags[0]] 176 | score += emissions[0, torch.arange(batch_size), tags[0]] 177 | 178 | for i in range(1, seq_length): 179 | # Transition score to next tag, only added if next timestep is valid (mask == 1) 180 | # shape: (batch_size,) 181 | score += self.transitions[tags[i - 1], tags[i]] * mask[i] 182 | 183 | # Emission score for next tag, only added if next timestep is valid (mask == 1) 184 | # shape: (batch_size,) 185 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i] 186 | 187 | # End transition score 188 | # shape: (batch_size,) 189 | seq_ends = mask.long().sum(dim=0) - 1 190 | # shape: (batch_size,) 191 | last_tags = tags[seq_ends, torch.arange(batch_size)] 192 | # shape: (batch_size,) 193 | score += self.end_transitions[last_tags] 194 | 195 | return score 196 | 197 | def _compute_normalizer(self, emissions: torch.Tensor, 198 | mask: torch.ByteTensor) -> torch.Tensor: 199 | # emissions: (seq_length, batch_size, num_tags) 200 | # mask: (seq_length, batch_size) 201 | seq_length = emissions.size(0) 202 | 203 | # Start transition score and first emission; score has size of 204 | # (batch_size, num_tags) where for each batch, the j-th column stores 205 | # the score that the first timestep has tag j 206 | # 
shape: (batch_size, num_tags) 207 | score = self.start_transitions + emissions[0] 208 | 209 | for i in range(1, seq_length): 210 | # Broadcast score for every possible next tag 211 | # shape: (batch_size, num_tags, 1) 212 | broadcast_score = score.unsqueeze(2) 213 | 214 | # Broadcast emission score for every possible current tag 215 | # shape: (batch_size, 1, num_tags) 216 | broadcast_emissions = emissions[i].unsqueeze(1) 217 | 218 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 219 | # for each sample, entry at row i and column j stores the sum of scores of all 220 | # possible tag sequences so far that end with transitioning from tag i to tag j 221 | # and emitting 222 | # shape: (batch_size, num_tags, num_tags) 223 | next_score = broadcast_score + self.transitions + broadcast_emissions 224 | 225 | # Sum over all possible current tags, but we're in score space, so a sum 226 | # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of 227 | # all possible tag sequences so far, that end in tag i 228 | # shape: (batch_size, num_tags) 229 | next_score = torch.logsumexp(next_score, dim=1) 230 | 231 | # Set score to the next score if this timestep is valid (mask == 1) 232 | # shape: (batch_size, num_tags) 233 | score = torch.where(mask[i].unsqueeze(1), next_score, score) 234 | 235 | # End transition score 236 | # shape: (batch_size, num_tags) 237 | score += self.end_transitions 238 | 239 | # Sum (log-sum-exp) over all possible tags 240 | # shape: (batch_size,) 241 | return torch.logsumexp(score, dim=1) 242 | 243 | def _viterbi_decode(self, emissions: torch.FloatTensor, 244 | mask: torch.ByteTensor, 245 | pad_tag: Optional[int] = None) -> List[List[int]]: 246 | # emissions: (seq_length, batch_size, num_tags) 247 | # mask: (seq_length, batch_size) 248 | # return: (batch_size, seq_length) 249 | if pad_tag is None: 250 | pad_tag = 0 251 | 252 | device = emissions.device 253 | seq_length, batch_size = mask.shape 254 | 255 | # Start transition and first emission 256 | # shape: (batch_size, num_tags) 257 | score = self.start_transitions + emissions[0] 258 | history_idx = torch.zeros((seq_length, batch_size, self.num_tags), 259 | dtype=torch.long, device=device) 260 | oor_idx = torch.zeros((batch_size, self.num_tags), 261 | dtype=torch.long, device=device) 262 | oor_tag = torch.full((seq_length, batch_size), pad_tag, 263 | dtype=torch.long, device=device) 264 | 265 | # - score is a tensor of size (batch_size, num_tags) where for every batch, 266 | # value at column j stores the score of the best tag sequence so far that ends 267 | # with tag j 268 | # - history_idx saves where the best tags candidate transitioned from; this is used 269 | # when we trace back the best tag sequence 270 | # - oor_idx saves the best tags candidate transitioned from at the positions 271 | # where mask is 0, i.e. 
out of range (oor) 272 | 273 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence 274 | # for every possible next tag 275 | for i in range(1, seq_length): 276 | # Broadcast viterbi score for every possible next tag 277 | # shape: (batch_size, num_tags, 1) 278 | broadcast_score = score.unsqueeze(2) 279 | 280 | # Broadcast emission score for every possible current tag 281 | # shape: (batch_size, 1, num_tags) 282 | broadcast_emission = emissions[i].unsqueeze(1) 283 | 284 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where 285 | # for each sample, entry at row i and column j stores the score of the best 286 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting 287 | # shape: (batch_size, num_tags, num_tags) 288 | next_score = broadcast_score + self.transitions + broadcast_emission 289 | 290 | # Find the maximum score over all possible current tag 291 | # shape: (batch_size, num_tags) 292 | next_score, indices = next_score.max(dim=1) 293 | 294 | # Set score to the next score if this timestep is valid (mask == 1) 295 | # and save the index that produces the next score 296 | # shape: (batch_size, num_tags) 297 | score = torch.where(mask[i].unsqueeze(-1), next_score, score) 298 | indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx) 299 | history_idx[i - 1] = indices 300 | 301 | # End transition score 302 | # shape: (batch_size, num_tags) 303 | end_score = score + self.end_transitions 304 | _, end_tag = end_score.max(dim=1) 305 | 306 | # shape: (batch_size,) 307 | seq_ends = mask.long().sum(dim=0) - 1 308 | 309 | # insert the best tag at each sequence end (last position with mask == 1) 310 | history_idx = history_idx.transpose(1, 0).contiguous() 311 | history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags), 312 | end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags)) 313 | history_idx = history_idx.transpose(1, 0).contiguous() 314 | 315 | # The most probable path for each sequence 316 | best_tags_arr = torch.zeros((seq_length, batch_size), 317 | dtype=torch.long, device=device) 318 | best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device) 319 | for idx in range(seq_length - 1, -1, -1): 320 | best_tags = torch.gather(history_idx[idx], 1, best_tags) 321 | best_tags_arr[idx] = best_tags.data.view(batch_size) 322 | 323 | return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1) 324 | 325 | def _viterbi_decode_nbest(self, emissions: torch.FloatTensor, 326 | mask: torch.ByteTensor, 327 | nbest: int, 328 | pad_tag: Optional[int] = None) -> List[List[List[int]]]: 329 | # emissions: (seq_length, batch_size, num_tags) 330 | # mask: (seq_length, batch_size) 331 | # return: (nbest, batch_size, seq_length) 332 | if pad_tag is None: 333 | pad_tag = 0 334 | 335 | device = emissions.device 336 | seq_length, batch_size = mask.shape 337 | 338 | # Start transition and first emission 339 | # shape: (batch_size, num_tags) 340 | score = self.start_transitions + emissions[0] 341 | history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest), 342 | dtype=torch.long, device=device) 343 | oor_idx = torch.zeros((batch_size, self.num_tags, nbest), 344 | dtype=torch.long, device=device) 345 | oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag, 346 | dtype=torch.long, device=device) 347 | 348 | # + score is a tensor of size (batch_size, num_tags) where for every batch, 349 | # value at column j stores the score of the best tag sequence so far that ends 350 | # with tag 
j 351 | # + history_idx saves where the best tags candidate transitioned from; this is used 352 | # when we trace back the best tag sequence 353 | # - oor_idx saves the best tags candidate transitioned from at the positions 354 | # where mask is 0, i.e. out of range (oor) 355 | 356 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence 357 | # for every possible next tag 358 | for i in range(1, seq_length): 359 | if i == 1: 360 | broadcast_score = score.unsqueeze(-1) 361 | broadcast_emission = emissions[i].unsqueeze(1) 362 | # shape: (batch_size, num_tags, num_tags) 363 | next_score = broadcast_score + self.transitions + broadcast_emission 364 | else: 365 | broadcast_score = score.unsqueeze(-1) 366 | broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2) 367 | # shape: (batch_size, num_tags, nbest, num_tags) 368 | next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission 369 | 370 | # Find the top `nbest` maximum score over all possible current tag 371 | # shape: (batch_size, nbest, num_tags) 372 | next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1) 373 | 374 | if i == 1: 375 | score = score.unsqueeze(-1).expand(-1, -1, nbest) 376 | indices = indices * nbest 377 | 378 | # convert to shape: (batch_size, num_tags, nbest) 379 | next_score = next_score.transpose(2, 1) 380 | indices = indices.transpose(2, 1) 381 | 382 | # Set score to the next score if this timestep is valid (mask == 1) 383 | # and save the index that produces the next score 384 | # shape: (batch_size, num_tags, nbest) 385 | score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), next_score, score) 386 | indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, oor_idx) 387 | history_idx[i - 1] = indices 388 | 389 | # End transition score shape: (batch_size, num_tags, nbest) 390 | end_score = score + self.end_transitions.unsqueeze(-1) 391 | _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1) 392 | 393 | # shape: (batch_size,) 394 | seq_ends = mask.long().sum(dim=0) - 1 395 | 396 | # insert the best tag at each sequence end (last position with mask == 1) 397 | history_idx = history_idx.transpose(1, 0).contiguous() 398 | history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest), 399 | end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest)) 400 | history_idx = history_idx.transpose(1, 0).contiguous() 401 | 402 | # The most probable path for each sequence 403 | best_tags_arr = torch.zeros((seq_length, batch_size, nbest), 404 | dtype=torch.long, device=device) 405 | best_tags = torch.arange(nbest, dtype=torch.long, device=device) \ 406 | .view(1, -1).expand(batch_size, -1) 407 | for idx in range(seq_length - 1, -1, -1): 408 | best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, best_tags) 409 | best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest 410 | 411 | return torch.where(mask.unsqueeze(-1), best_tags_arr, oor_tag).permute(2, 1, 0) -------------------------------------------------------------------------------- /callback/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import warnings 4 | from torch.optim.optimizer import Optimizer 5 | from torch.optim.lr_scheduler import LambdaLR 6 | 7 | def get_constant_schedule(optimizer, last_epoch=-1): 8 | """ Create a schedule with a constant learning rate. 
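    Example (usage sketch; optimizer and num_training_steps are assumed to be
    defined by the caller):
        >>> scheduler = get_constant_schedule(optimizer)
        >>> for step in range(num_training_steps):
        ...     optimizer.step()
        ...     scheduler.step()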
9 | """ 10 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 11 | 12 | 13 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 14 | """ Create a schedule with a constant learning rate preceded by a warmup 15 | period during which the learning rate increases linearly between 0 and 1. 16 | """ 17 | def lr_lambda(current_step): 18 | if current_step < num_warmup_steps: 19 | return float(current_step) / float(max(1.0, num_warmup_steps)) 20 | return 1. 21 | 22 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 23 | 24 | 25 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 26 | """ Create a schedule with a learning rate that decreases linearly after 27 | linearly increasing during a warmup period. 28 | """ 29 | def lr_lambda(current_step): 30 | if current_step < num_warmup_steps: 31 | return float(current_step) / float(max(1, num_warmup_steps)) 32 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 33 | 34 | return LambdaLR(optimizer, lr_lambda, last_epoch) 35 | 36 | 37 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1): 38 | """ Create a schedule with a learning rate that decreases following the 39 | values of the cosine function between 0 and `pi * cycles` after a warmup 40 | period during which it increases linearly between 0 and 1. 41 | """ 42 | def lr_lambda(current_step): 43 | if current_step < num_warmup_steps: 44 | return float(current_step) / float(max(1, num_warmup_steps)) 45 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 46 | return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress))) 47 | 48 | return LambdaLR(optimizer, lr_lambda, last_epoch) 49 | 50 | 51 | def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1): 52 | """ Create a schedule with a learning rate that decreases following the 53 | values of the cosine function with several hard restarts, after a warmup 54 | period during which it increases linearly between 0 and 1. 55 | """ 56 | def lr_lambda(current_step): 57 | if current_step < num_warmup_steps: 58 | return float(current_step) / float(max(1, num_warmup_steps)) 59 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 60 | if progress >= 1.: 61 | return 0. 62 | return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.)))) 63 | 64 | return LambdaLR(optimizer, lr_lambda, last_epoch) 65 | 66 | 67 | class CustomDecayLR(object): 68 | ''' 69 | 自定义学习率变化机制 70 | Example: 71 | >>> scheduler = CustomDecayLR(optimizer) 72 | >>> for epoch in range(100): 73 | >>> scheduler.epoch_step() 74 | >>> train(...) 75 | >>> ... 76 | >>> optimizer.zero_grad() 77 | >>> loss.backward() 78 | >>> optimizer.step() 79 | >>> validate(...) 
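    Note: epoch_step expects the current epoch number, i.e. call
    scheduler.epoch_step(epoch) inside the loop above; the base lr is divided
    by 10 after epoch 4, by 100 after epoch 8 and by 1000 after epoch 12.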
80 | ''' 81 | def __init__(self,optimizer,lr): 82 | self.optimizer = optimizer 83 | self.lr = lr 84 | 85 | def epoch_step(self,epoch): 86 | lr = self.lr 87 | if epoch > 12: 88 | lr = lr / 1000 89 | elif epoch > 8: 90 | lr = lr / 100 91 | elif epoch > 4: 92 | lr = lr / 10 93 | for param_group in self.optimizer.param_groups: 94 | param_group['lr'] = lr 95 | 96 | class BertLR(object): 97 | ''' 98 | Bert模型内定的学习率变化机制 99 | Example: 100 | >>> scheduler = BertLR(optimizer) 101 | >>> for epoch in range(100): 102 | >>> scheduler.step() 103 | >>> train(...) 104 | >>> ... 105 | >>> optimizer.zero_grad() 106 | >>> loss.backward() 107 | >>> optimizer.step() 108 | >>> scheduler.batch_step() 109 | >>> validate(...) 110 | ''' 111 | def __init__(self,optimizer,learning_rate,t_total,warmup): 112 | self.learning_rate = learning_rate 113 | self.optimizer = optimizer 114 | self.t_total = t_total 115 | self.warmup = warmup 116 | 117 | # 线性预热方式 118 | def warmup_linear(self,x, warmup=0.002): 119 | if x < warmup: 120 | return x / warmup 121 | return 1.0 - x 122 | 123 | def batch_step(self,training_step): 124 | lr_this_step = self.learning_rate * self.warmup_linear(training_step / self.t_total,self.warmup) 125 | for param_group in self.optimizer.param_groups: 126 | param_group['lr'] = lr_this_step 127 | 128 | class CyclicLR(object): 129 | ''' 130 | Cyclical learning rates for training neural networks 131 | Example: 132 | >>> scheduler = CyclicLR(optimizer) 133 | >>> for epoch in range(100): 134 | >>> scheduler.step() 135 | >>> train(...) 136 | >>> ... 137 | >>> optimizer.zero_grad() 138 | >>> loss.backward() 139 | >>> optimizer.step() 140 | >>> scheduler.batch_step() 141 | >>> validate(...) 142 | ''' 143 | def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3, 144 | step_size=2000, mode='triangular', gamma=1., 145 | scale_fn=None, scale_mode='cycle', last_batch_iteration=-1): 146 | 147 | if not isinstance(optimizer, Optimizer): 148 | raise TypeError('{} is not an Optimizer'.format( 149 | type(optimizer).__name__)) 150 | 151 | self.optimizer = optimizer 152 | 153 | if isinstance(base_lr, list) or isinstance(base_lr, tuple): 154 | if len(base_lr) != len(optimizer.param_groups): 155 | raise ValueError("expected {} base_lr, got {}".format( 156 | len(optimizer.param_groups), len(base_lr))) 157 | self.base_lrs = list(base_lr) 158 | else: 159 | self.base_lrs = [base_lr] * len(optimizer.param_groups) 160 | 161 | if isinstance(max_lr, list) or isinstance(max_lr, tuple): 162 | if len(max_lr) != len(optimizer.param_groups): 163 | raise ValueError("expected {} max_lr, got {}".format( 164 | len(optimizer.param_groups), len(max_lr))) 165 | self.max_lrs = list(max_lr) 166 | else: 167 | self.max_lrs = [max_lr] * len(optimizer.param_groups) 168 | 169 | self.step_size = step_size 170 | 171 | if mode not in ['triangular', 'triangular2', 'exp_range'] \ 172 | and scale_fn is None: 173 | raise ValueError('mode is invalid and scale_fn is None') 174 | 175 | self.mode = mode 176 | self.gamma = gamma 177 | 178 | if scale_fn is None: 179 | if self.mode == 'triangular': 180 | self.scale_fn = self._triangular_scale_fn 181 | self.scale_mode = 'cycle' 182 | elif self.mode == 'triangular2': 183 | self.scale_fn = self._triangular2_scale_fn 184 | self.scale_mode = 'cycle' 185 | elif self.mode == 'exp_range': 186 | self.scale_fn = self._exp_range_scale_fn 187 | self.scale_mode = 'iterations' 188 | else: 189 | self.scale_fn = scale_fn 190 | self.scale_mode = scale_mode 191 | 192 | self.batch_step(last_batch_iteration + 1) 193 | 
self.last_batch_iteration = last_batch_iteration 194 | 195 | def _triangular_scale_fn(self, x): 196 | return 1. 197 | 198 | def _triangular2_scale_fn(self, x): 199 | return 1 / (2. ** (x - 1)) 200 | 201 | def _exp_range_scale_fn(self, x): 202 | return self.gamma**(x) 203 | 204 | def get_lr(self): 205 | step_size = float(self.step_size) 206 | cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size)) 207 | x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1) 208 | 209 | lrs = [] 210 | param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs) 211 | for param_group, base_lr, max_lr in param_lrs: 212 | base_height = (max_lr - base_lr) * np.maximum(0, (1 - x)) 213 | if self.scale_mode == 'cycle': 214 | lr = base_lr + base_height * self.scale_fn(cycle) 215 | else: 216 | lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration) 217 | lrs.append(lr) 218 | return lrs 219 | 220 | def batch_step(self, batch_iteration=None): 221 | if batch_iteration is None: 222 | batch_iteration = self.last_batch_iteration + 1 223 | self.last_batch_iteration = batch_iteration 224 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 225 | param_group['lr'] = lr 226 | 227 | class ReduceLROnPlateau(object): 228 | """Reduce learning rate when a metric has stopped improving. 229 | Models often benefit from reducing the learning rate by a factor 230 | of 2-10 once learning stagnates. This scheduler reads a metrics 231 | quantity and if no improvement is seen for a 'patience' number 232 | of epochs, the learning rate is reduced. 233 | 234 | Args: 235 | factor: factor by which the learning rate will 236 | be reduced. new_lr = lr * factor 237 | patience: number of epochs with no improvement 238 | after which learning rate will be reduced. 239 | verbose: int. 0: quiet, 1: update messages. 240 | mode: one of {min, max}. In `min` mode, 241 | lr will be reduced when the quantity 242 | monitored has stopped decreasing; in `max` 243 | mode it will be reduced when the quantity 244 | monitored has stopped increasing. 245 | epsilon: threshold for measuring the new optimum, 246 | to only focus on significant changes. 247 | cooldown: number of epochs to wait before resuming 248 | normal operation after lr has been reduced. 249 | min_lr: lower bound on the learning rate. 250 | 251 | 252 | Example: 253 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) 254 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min') 255 | >>> for epoch in range(10): 256 | >>> train(...) 257 | >>> val_acc, val_loss = validate(...) 258 | >>> scheduler.epoch_step(val_loss, epoch) 259 | """ 260 | 261 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10, 262 | verbose=0, epsilon=1e-4, cooldown=0, min_lr=0,eps=1e-8): 263 | 264 | super(ReduceLROnPlateau, self).__init__() 265 | assert isinstance(optimizer, Optimizer) 266 | if factor >= 1.0: 267 | raise ValueError('ReduceLROnPlateau ' 268 | 'does not support a factor >= 1.0.') 269 | self.factor = factor 270 | self.min_lr = min_lr 271 | self.epsilon = epsilon 272 | self.patience = patience 273 | self.verbose = verbose 274 | self.cooldown = cooldown 275 | self.cooldown_counter = 0 # Cooldown counter. 276 | self.monitor_op = None 277 | self.wait = 0 278 | self.best = 0 279 | self.mode = mode 280 | self.optimizer = optimizer 281 | self.eps = eps 282 | self._reset() 283 | 284 | def _reset(self): 285 | """Resets wait counter and cooldown counter. 
        """
        if self.mode not in ['min', 'max']:
            raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!' % self.mode)
        if self.mode == 'min':
            self.monitor_op = lambda a, b: np.less(a, b - self.epsilon)
            self.best = np.inf
        else:
            self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon)
            self.best = -np.inf
        self.cooldown_counter = 0
        self.wait = 0

    def reset(self):
        self._reset()

    def epoch_step(self, metrics, epoch):
        current = metrics
        if current is None:
            warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning)
        else:
            if self.in_cooldown():
                self.cooldown_counter -= 1
                self.wait = 0

            if self.monitor_op(current, self.best):
                self.best = current
                self.wait = 0
            elif not self.in_cooldown():
                if self.wait >= self.patience:
                    for param_group in self.optimizer.param_groups:
                        old_lr = float(param_group['lr'])
                        if old_lr > self.min_lr + self.eps:
                            new_lr = old_lr * self.factor
                            new_lr = max(new_lr, self.min_lr)
                            param_group['lr'] = new_lr
                            if self.verbose > 0:
                                print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, new_lr))
                    self.cooldown_counter = self.cooldown
                    self.wait = 0
                self.wait += 1

    def in_cooldown(self):
        return self.cooldown_counter > 0

class ReduceLRWDOnPlateau(ReduceLROnPlateau):
    """Reduce learning rate and weight decay when a metric has stopped
    improving. Models often benefit from reducing the learning rate by
    a factor of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, the learning rate and the weight decay factor are reduced,
    for optimizers that implement the decoupled weight decay method from
    the paper `Fixing Weight Decay Regularization in Adam`_
    (e.g. AdamW or SGDW).

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101

    Example:
        >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3)
        >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that epoch_step should be called after validate()
        >>>     scheduler.epoch_step(val_loss, epoch)
    """
    def epoch_step(self, metrics, epoch):
        current = metrics
        if current is None:
            warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning)
        else:
            if self.in_cooldown():
                self.cooldown_counter -= 1
                self.wait = 0

            if self.monitor_op(current, self.best):
                self.best = current
                self.wait = 0
            elif not self.in_cooldown():
                if self.wait >= self.patience:
                    for param_group in self.optimizer.param_groups:
                        old_lr = float(param_group['lr'])
                        if old_lr > self.min_lr + self.eps:
                            new_lr = old_lr * self.factor
                            new_lr = max(new_lr, self.min_lr)
                            param_group['lr'] = new_lr
                            if self.verbose > 0:
                                print('\nEpoch %d: reducing learning rate to %s.' % (epoch, new_lr))
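                            # Also shrink the decoupled weight decay term by the
                            # same factor, as suggested for AdamW/SGDW in
                            # arXiv:1711.05101.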
                            if param_group['weight_decay'] != 0:
                                old_weight_decay = float(param_group['weight_decay'])
                                new_weight_decay = max(old_weight_decay * self.factor, self.min_lr)
                                if old_weight_decay > new_weight_decay + self.eps:
                                    param_group['weight_decay'] = new_weight_decay
                                    if self.verbose:
                                        print('\nEpoch %d: reducing weight decay factor to %.4e.' % (epoch, new_weight_decay))
                    self.cooldown_counter = self.cooldown
                    self.wait = 0
                self.wait += 1

class CosineLRWithRestarts(object):
    """Decays learning rate with cosine annealing, normalizes the weight decay
    hyperparameter value, and implements warm restarts.
    https://arxiv.org/abs/1711.05101

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        batch_size: minibatch size
        epoch_size: training samples per epoch
        restart_period: epoch count in the first restart period
        t_mult: multiplication factor by which the next restart period will extend/shrink

    Example:
        >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>         ...
        >>>         optimizer.zero_grad()
        >>>         loss.backward()
        >>>         optimizer.step()
        >>>         scheduler.batch_step()
        >>>     validate(...)
    """

    def __init__(self, optimizer, batch_size, epoch_size, restart_period=100,
                 t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False):
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an"
                                   " optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'],
                                 optimizer.param_groups))

        self.last_epoch = last_epoch
        self.batch_size = batch_size
        self.iteration = 0
        self.epoch_size = epoch_size
        self.eta_threshold = eta_threshold
        self.t_mult = t_mult
        self.verbose = verbose
        self.base_weight_decays = list(map(lambda group: group['weight_decay'],
                                           optimizer.param_groups))
        self.restart_period = restart_period
        self.restarts = 0
        self.t_epoch = -1
        self.batch_increments = []
        self._set_batch_increment()

    def _schedule_eta(self):
        """
        Threshold value could be adjusted to shrink eta_min and eta_max values.
        """
        eta_min = 0
        eta_max = 1
        if self.restarts <= self.eta_threshold:
            return eta_min, eta_max
        else:
            d = self.restarts - self.eta_threshold
            k = d * 0.09
            return (eta_min + k, eta_max - k)

    def get_lr(self, t_cur):
        eta_min, eta_max = self._schedule_eta()

        eta_t = (eta_min + 0.5 * (eta_max - eta_min)
                 * (1. + math.cos(math.pi *
                                  (t_cur / self.restart_period))))

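        # Normalized weight decay (arXiv:1711.05101): scale the base weight decay
        # by sqrt(batch_size / (epoch_size * restart_period)) so its effective
        # strength stays roughly comparable across batch sizes and run lengths.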
        weight_decay_norm_multi = math.sqrt(self.batch_size /
                                            (self.epoch_size *
                                             self.restart_period))
        lrs = [base_lr * eta_t for base_lr in self.base_lrs]
        weight_decays = [base_weight_decay * eta_t * weight_decay_norm_multi
                         for base_weight_decay in self.base_weight_decays]

        if self.t_epoch % self.restart_period < self.t_epoch:
            if self.verbose:
                print("Restart at epoch {}".format(self.last_epoch))
            self.restart_period *= self.t_mult
            self.restarts += 1
            self.t_epoch = 0

        return zip(lrs, weight_decays)

    def _set_batch_increment(self):
        d, r = divmod(self.epoch_size, self.batch_size)
        batches_in_epoch = d + 2 if r > 0 else d + 1
        self.iteration = 0
        self.batch_increments = list(np.linspace(0, 1, batches_in_epoch))

    def batch_step(self):
        self.last_epoch += 1
        self.t_epoch += 1
        self._set_batch_increment()
        try:
            t_cur = self.t_epoch + self.batch_increments[self.iteration]
            self.iteration += 1
        except IndexError:
            raise RuntimeError("Epoch size and batch size used in the "
                               "training loop and while initializing "
                               "scheduler should be the same.")

        for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups, self.get_lr(t_cur)):
            param_group['lr'] = lr
            param_group['weight_decay'] = weight_decay


class NoamLR(object):
    '''
    The learning-rate update scheme from the paper "Attention Is All You Need"
    (the Noam schedule): linear warmup for `warm_up` steps, then decay
    proportional to the inverse square root of the step number.
    Example:
        >>> scheduler = NoamLR(d_model, factor, warm_up, optimizer)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>         ...
        >>>         global_step += 1
        >>>         optimizer.zero_grad()
        >>>         loss.backward()
        >>>         optimizer.step()
        >>>         scheduler.batch_step(global_step)
        >>>     validate(...)
    '''
    def __init__(self, d_model, factor, warm_up, optimizer):
        self.optimizer = optimizer
        self.warm_up = warm_up
        self.factor = factor
        self.d_model = d_model
        self._lr = 0

    def get_lr(self, step):
        # lr = factor * d_model^-0.5 * min(step^-0.5, step * warm_up^-1.5);
        # the two terms meet at step == warm_up, where the lr peaks.
        # `step` is expected to start from 1 (step == 0 would divide by zero).
        lr = self.factor * (self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warm_up ** (-1.5)))
        return lr

    def batch_step(self, step):
        '''
        Update the learning rate for the given global step.
        '''
        lr = self.get_lr(step)
        for p in self.optimizer.param_groups:
            p['lr'] = lr
        self._lr = lr

--------------------------------------------------------------------------------
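Illustrative usage (not part of the repository): a minimal sketch of driving the ReduceLROnPlateau class above from a training loop, assuming the module is importable as `callback.lr_scheduler` per the repository layout; `model`, the placeholder validation loss, and the hyperparameters below are made up for the example.

import torch
from callback.lr_scheduler import ReduceLROnPlateau  # import path assumed from the repo tree

model = torch.nn.Linear(10, 2)  # stand-in for a real NER model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=1)

for epoch in range(10):
    # ... run one training epoch here ...
    val_loss = 0.5  # placeholder: a validation loss that has stopped improving
    scheduler.epoch_step(val_loss, epoch)  # lr is halved once `patience` epochs pass without improvement
    print(epoch, optimizer.param_groups[0]['lr'])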