├── __init__.py
├── callback
│   ├── __init__.py
│   ├── optimizater
│   │   ├── __init__.py
│   │   ├── planradam.py
│   │   ├── novograd.py
│   │   ├── sgdw.py
│   │   ├── lars.py
│   │   ├── radam.py
│   │   ├── nadam.py
│   │   ├── adamw.py
│   │   ├── ralamb.py
│   │   ├── lookahead.py
│   │   ├── lamb.py
│   │   ├── ralars.py
│   │   ├── adabound.py
│   │   └── adafactor.py
│   ├── trainingmonitor.py
│   ├── progressbar.py
│   ├── modelcheckpoint.py
│   ├── adversarial.py
│   └── lr_scheduler.py
├── losses
│   ├── __init__.py
│   ├── focal_loss.py
│   └── label_smoothing.py
├── models
│   ├── __init__.py
│   ├── layers
│   │   ├── __init__.py
│   │   ├── linears.py
│   │   └── crf.py
│   └── bert_for_ner.py
├── tools
│   ├── __init__.py
│   ├── convert_albert_tf_checkpoint_to_pytorch.py
│   ├── plot.py
│   ├── download_clue_data.py
│   ├── finetuning_argparse.py
│   └── common.py
├── metrics
│   ├── __init__.py
│   └── ner_metrics.py
├── processors
│   ├── __init__.py
│   ├── utils_ner.py
│   ├── ner_seq.py
│   └── ner_span.py
├── datasets
│   ├── cluener
│   │   └── __init__.py
│   └── cner
│       └── .gitignore
├── outputs
│   └── cner_output
│       ├── bert
│       │   ├── __init__.py
│       │   └── .gitignore
│       └── .gitignore
├── .idea
│   ├── .gitignore
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── vcs.xml
│   ├── modules.xml
│   └── BERT-NER-Pytorch.iml
├── scripts
│   ├── run_ner_softmax.sh
│   ├── run_ner_crf.sh
│   └── run_ner_span.sh
├── LICENSE
├── .gitignore
└── README.md
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/callback/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/callback/optimizater/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/metrics/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/processors/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/datasets/cluener/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/outputs/cner_output/bert/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/BERT-NER-Pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/losses/focal_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class FocalLoss(nn.Module):
6 | '''Multi-class Focal loss implementation'''
7 | def __init__(self, gamma=2, weight=None,ignore_index=-100):
8 | super(FocalLoss, self).__init__()
9 | self.gamma = gamma
10 | self.weight = weight
11 | self.ignore_index=ignore_index
12 |
13 | def forward(self, input, target):
14 | """
15 | input: [N, C]
16 | target: [N, ]
17 | """
18 | logpt = F.log_softmax(input, dim=1)
19 | pt = torch.exp(logpt)
20 | logpt = (1-pt)**self.gamma * logpt
21 | loss = F.nll_loss(logpt, target, self.weight,ignore_index=self.ignore_index)
22 | return loss
23 |
--------------------------------------------------------------------------------
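
A minimal usage sketch for the `FocalLoss` above (not part of the repository; the tensors and shapes are illustrative). It acts as a drop-in replacement for `nn.CrossEntropyLoss` over flattened token logits, down-weighting easy examples by `(1 - pt) ** gamma` and skipping positions labeled with `ignore_index`.

```python
import torch
from losses.focal_loss import FocalLoss

# 4 tokens, 3 labels; -100 marks padded positions that should be ignored
logits = torch.randn(4, 3, requires_grad=True)
labels = torch.tensor([0, 2, 1, -100])

criterion = FocalLoss(gamma=2, ignore_index=-100)
loss = criterion(logits, labels)
loss.backward()
print(loss.item())
```
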
/scripts/run_ner_softmax.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=`pwd`
2 | export BERT_BASE_DIR=$CURRENT_DIR/prev_trained_model/bert-base-chinese
3 | export DATA_DIR=$CURRENT_DIR/datasets
4 | export OUTPUT_DIR=$CURRENT_DIR/outputs
5 | TASK_NAME="cner"
6 |
7 | python run_ner_softmax.py \
8 | --model_type=bert \
9 | --model_name_or_path=$BERT_BASE_DIR \
10 | --task_name=$TASK_NAME \
11 | --do_train \
12 | --do_eval \
13 | --do_lower_case \
14 | --loss_type=ce \
15 | --data_dir=$DATA_DIR/${TASK_NAME}/ \
16 | --train_max_seq_length=128 \
17 | --eval_max_seq_length=512 \
18 | --per_gpu_train_batch_size=24 \
19 | --per_gpu_eval_batch_size=24 \
20 | --learning_rate=3e-5 \
21 | --num_train_epochs=3.0 \
22 | --logging_steps=-1 \
23 | --save_steps=-1 \
24 |   --output_dir=$OUTPUT_DIR/${TASK_NAME}_output/ \
25 | --overwrite_output_dir \
26 | --seed=42
27 |
--------------------------------------------------------------------------------
/scripts/run_ner_crf.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=`pwd`
2 | export BERT_BASE_DIR=$CURRENT_DIR/prev_trained_model/bert-base-chinese
3 | export DATA_DIR=$CURRENT_DIR/datasets
4 | export OUTPUT_DIR=$CURRENT_DIR/outputs
5 | TASK_NAME="cner"
6 | #
7 | python run_ner_crf.py \
8 | --model_type=bert \
9 | --model_name_or_path=$BERT_BASE_DIR \
10 | --task_name=$TASK_NAME \
11 | --do_train \
12 | --do_eval \
13 | --do_lower_case \
14 | --data_dir=$DATA_DIR/${TASK_NAME}/ \
15 | --train_max_seq_length=128 \
16 | --eval_max_seq_length=512 \
17 | --per_gpu_train_batch_size=24 \
18 | --per_gpu_eval_batch_size=24 \
19 | --learning_rate=3e-5 \
20 | --crf_learning_rate=1e-3 \
21 | --num_train_epochs=4.0 \
22 | --logging_steps=-1 \
23 | --save_steps=-1 \
24 |   --output_dir=$OUTPUT_DIR/${TASK_NAME}_output/ \
25 | --overwrite_output_dir \
26 | --seed=42
27 |
--------------------------------------------------------------------------------
/scripts/run_ner_span.sh:
--------------------------------------------------------------------------------
1 | CURRENT_DIR=`pwd`
2 | export BERT_BASE_DIR=$CURRENT_DIR/prev_trained_model/bert-base-chinese
3 | export DATA_DIR=$CURRENT_DIR/datasets
4 | export OUTPUT_DIR=$CURRENT_DIR/outputs
5 | TASK_NAME="cner"
6 |
7 | python run_ner_span.py \
8 | --model_type=bert \
9 | --model_name_or_path=$BERT_BASE_DIR \
10 | --task_name=$TASK_NAME \
11 | --do_train \
12 | --do_eval \
13 | --do_adv \
14 | --do_lower_case \
15 | --loss_type=ce \
16 | --data_dir=$DATA_DIR/${TASK_NAME}/ \
17 | --train_max_seq_length=128 \
18 | --eval_max_seq_length=512 \
19 | --per_gpu_train_batch_size=24 \
20 | --per_gpu_eval_batch_size=24 \
21 | --learning_rate=2e-5 \
22 | --num_train_epochs=4.0 \
23 | --logging_steps=-1 \
24 | --save_steps=-1 \
25 |   --output_dir=$OUTPUT_DIR/${TASK_NAME}_output/ \
26 | --overwrite_output_dir \
27 | --seed=42
28 |
29 |
--------------------------------------------------------------------------------
/losses/label_smoothing.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class LabelSmoothingCrossEntropy(nn.Module):
5 | def __init__(self, eps=0.1, reduction='mean',ignore_index=-100):
6 | super(LabelSmoothingCrossEntropy, self).__init__()
7 | self.eps = eps
8 | self.reduction = reduction
9 | self.ignore_index = ignore_index
10 |
11 | def forward(self, output, target):
12 | c = output.size()[-1]
13 | log_preds = F.log_softmax(output, dim=-1)
14 | if self.reduction=='sum':
15 | loss = -log_preds.sum()
16 | else:
17 | loss = -log_preds.sum(dim=-1)
18 | if self.reduction=='mean':
19 | loss = loss.mean()
20 | return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction,
21 | ignore_index=self.ignore_index)
--------------------------------------------------------------------------------
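
A usage sketch for `LabelSmoothingCrossEntropy` (illustrative, not from the repository). The loss adds `eps` of the uniform-distribution cross-entropy (the mean of `-log p` over all classes) to `1 - eps` of the standard NLL term.

```python
import torch
from losses.label_smoothing import LabelSmoothingCrossEntropy

logits = torch.randn(8, 5, requires_grad=True)   # 8 tokens, 5 labels
targets = torch.randint(0, 5, (8,))

criterion = LabelSmoothingCrossEntropy(eps=0.1)
loss = criterion(logits, targets)                # (1 - eps) * NLL + eps * uniform term
loss.backward()
```
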
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Weitang Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/models/layers/linears.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class FeedForwardNetwork(nn.Module):
6 | def __init__(self, input_size, hidden_size, output_size, dropout_rate=0):
7 | super(FeedForwardNetwork, self).__init__()
8 | self.dropout_rate = dropout_rate
9 | self.linear1 = nn.Linear(input_size, hidden_size)
10 | self.linear2 = nn.Linear(hidden_size, output_size)
11 |
12 | def forward(self, x):
13 | x_proj = F.dropout(F.relu(self.linear1(x)), p=self.dropout_rate, training=self.training)
14 | x_proj = self.linear2(x_proj)
15 | return x_proj
16 |
17 |
18 | class PoolerStartLogits(nn.Module):
19 | def __init__(self, hidden_size, num_classes):
20 | super(PoolerStartLogits, self).__init__()
21 | self.dense = nn.Linear(hidden_size, num_classes)
22 |
23 | def forward(self, hidden_states, p_mask=None):
24 | x = self.dense(hidden_states)
25 | return x
26 |
27 | class PoolerEndLogits(nn.Module):
28 | def __init__(self, hidden_size, num_classes):
29 | super(PoolerEndLogits, self).__init__()
30 | self.dense_0 = nn.Linear(hidden_size, hidden_size)
31 | self.activation = nn.Tanh()
32 | self.LayerNorm = nn.LayerNorm(hidden_size)
33 | self.dense_1 = nn.Linear(hidden_size, num_classes)
34 |
35 | def forward(self, hidden_states, start_positions=None, p_mask=None):
36 | x = self.dense_0(torch.cat([hidden_states, start_positions], dim=-1))
37 | x = self.activation(x)
38 | x = self.LayerNorm(x)
39 | x = self.dense_1(x)
40 | return x
41 |
--------------------------------------------------------------------------------
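
A shape-check sketch for the pooler heads above (illustrative, not from the repository). Presumably the span model constructs `PoolerEndLogits` with a `hidden_size` that already includes the width of the start-position embedding, since `forward` concatenates that embedding onto the encoder output before `dense_0`; the example below assumes that convention.

```python
import torch
from models.layers.linears import PoolerStartLogits, PoolerEndLogits

hidden_size, num_labels, seq_len = 768, 9, 16
sequence_output = torch.randn(2, seq_len, hidden_size)

start_fc = PoolerStartLogits(hidden_size, num_labels)
start_logits = start_fc(sequence_output)                 # [2, seq_len, num_labels]

# dense_0 sees hidden_size + num_labels features after the concat in forward()
end_fc = PoolerEndLogits(hidden_size + num_labels, num_labels)
start_label_emb = torch.randn(2, seq_len, num_labels)    # e.g. a soft start-label embedding
end_logits = end_fc(sequence_output, start_positions=start_label_emb)

print(start_logits.shape, end_logits.shape)              # torch.Size([2, 16, 9]) twice
```
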
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/datasets/cner/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/outputs/cner_output/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/outputs/cner_output/bert/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/callback/trainingmonitor.py:
--------------------------------------------------------------------------------
1 | # encoding:utf-8
2 | import numpy as np
3 | from pathlib import Path
4 | import matplotlib.pyplot as plt
5 | from ..tools.common import load_json
6 | from ..tools.common import save_json
7 | plt.switch_backend('agg')
8 |
9 | class TrainingMonitor():
10 | def __init__(self, file_dir, arch, add_test=False):
11 | '''
12 | :param startAt: 重新开始训练的epoch点
13 | '''
14 | if isinstance(file_dir, Path):
15 | pass
16 | else:
17 | file_dir = Path(file_dir)
18 | file_dir.mkdir(parents=True, exist_ok=True)
19 |
20 | self.arch = arch
21 | self.file_dir = file_dir
22 | self.H = {}
23 | self.add_test = add_test
24 | self.json_path = file_dir / (arch + "_training_monitor.json")
25 |
26 | def reset(self,start_at):
27 | if start_at > 0:
28 | if self.json_path is not None:
29 | if self.json_path.exists():
30 | self.H = load_json(self.json_path)
31 | for k in self.H.keys():
32 | self.H[k] = self.H[k][:start_at]
33 |
34 | def epoch_step(self, logs={}):
35 | for (k, v) in logs.items():
36 | l = self.H.get(k, [])
37 | # np.float32会报错
38 | if not isinstance(v, np.float):
39 | v = round(float(v), 4)
40 | l.append(v)
41 | self.H[k] = l
42 |
43 | # 写入文件
44 | if self.json_path is not None:
45 | save_json(data = self.H,file_path=self.json_path)
46 |
47 | # 保存train图像
48 | if len(self.H["loss"]) == 1:
49 | self.paths = {key: self.file_dir / (self.arch + f'_{key.upper()}') for key in self.H.keys()}
50 |
51 | if len(self.H["loss"]) > 1:
52 | # 指标变化
53 | # 曲线
54 | # 需要成对出现
55 | keys = [key for key, _ in self.H.items() if '_' not in key]
56 | for key in keys:
57 | N = np.arange(0, len(self.H[key]))
58 | plt.style.use("ggplot")
59 | plt.figure()
60 | plt.plot(N, self.H[key], label=f"train_{key}")
61 | plt.plot(N, self.H[f"valid_{key}"], label=f"valid_{key}")
62 | if self.add_test:
63 | plt.plot(N, self.H[f"test_{key}"], label=f"test_{key}")
64 | plt.legend()
65 | plt.xlabel("Epoch #")
66 | plt.ylabel(key)
67 | plt.title(f"Training {key} [Epoch {len(self.H[key])}]")
68 | plt.savefig(str(self.paths[key]))
69 | plt.close()
70 |
--------------------------------------------------------------------------------
/callback/optimizater/planradam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.optimizer import Optimizer
4 | class PlainRAdam(Optimizer):
5 |
6 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
7 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
8 |
9 | super(PlainRAdam, self).__init__(params, defaults)
10 |
11 | def __setstate__(self, state):
12 | super(PlainRAdam, self).__setstate__(state)
13 |
14 | def step(self, closure=None):
15 |
16 | loss = None
17 | if closure is not None:
18 | loss = closure()
19 |
20 | for group in self.param_groups:
21 |
22 | for p in group['params']:
23 | if p.grad is None:
24 | continue
25 | grad = p.grad.data.float()
26 | if grad.is_sparse:
27 | raise RuntimeError('RAdam does not support sparse gradients')
28 |
29 | p_data_fp32 = p.data.float()
30 |
31 | state = self.state[p]
32 |
33 | if len(state) == 0:
34 | state['step'] = 0
35 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
36 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
37 | else:
38 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
39 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
40 |
41 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
42 | beta1, beta2 = group['betas']
43 |
44 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
45 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
46 |
47 | state['step'] += 1
48 | beta2_t = beta2 ** state['step']
49 | N_sma_max = 2 / (1 - beta2) - 1
50 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
51 |
52 | if group['weight_decay'] != 0:
53 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
54 |
55 | # more conservative since it's an approximated value
56 | if N_sma >= 5:
57 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
58 | denom = exp_avg_sq.sqrt().add_(group['eps'])
59 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
60 | else:
61 | step_size = group['lr'] / (1 - beta1 ** state['step'])
62 | p_data_fp32.add_(-step_size, exp_avg)
63 |
64 | p.data.copy_(p_data_fp32)
65 |
66 | return loss
--------------------------------------------------------------------------------
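
A small sketch of driving `PlainRAdam` in a single training step (illustrative; the model and data are placeholders). The optimizer uses the legacy two-argument `add_`/`addcmul_` signatures, so it is best run under the repository's stated PyTorch < 1.5 requirement.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from callback.optimizater.planradam import PlainRAdam

model = nn.Linear(10, 1)                     # toy model
optimizer = PlainRAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

x, y = torch.randn(32, 10), torch.randn(32, 1)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()                             # early steps fall back to the unrectified (momentum-only) update
optimizer.zero_grad()
```
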
/tools/convert_albert_tf_checkpoint_to_pytorch.py:
--------------------------------------------------------------------------------
1 | """Convert ALBERT checkpoint."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import argparse
8 | import torch
9 | from models.transformers.modeling_albert import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
10 | # from model.modeling_albert_bright import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
11 | import logging
12 | logging.basicConfig(level=logging.INFO)
13 |
14 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
15 | # Initialise PyTorch model
16 | config = AlbertConfig.from_pretrained(bert_config_file)
17 | # print("Building PyTorch model from configuration: {}".format(str(config)))
18 | model = AlbertForPreTraining(config)
19 | # Load weights from tf checkpoint
20 | load_tf_weights_in_albert(model, config, tf_checkpoint_path)
21 |
22 | # Save pytorch-model
23 | print("Save PyTorch model to {}".format(pytorch_dump_path))
24 | torch.save(model.state_dict(), pytorch_dump_path)
25 |
26 |
27 | if __name__ == "__main__":
28 | parser = argparse.ArgumentParser()
29 | ## Required parameters
30 | parser.add_argument("--tf_checkpoint_path",
31 | default = None,
32 | type = str,
33 | required = True,
34 | help = "Path to the TensorFlow checkpoint path.")
35 | parser.add_argument("--bert_config_file",
36 | default = None,
37 | type = str,
38 | required = True,
39 | help = "The config json file corresponding to the pre-trained BERT model. \n"
40 | "This specifies the model architecture.")
41 | parser.add_argument("--pytorch_dump_path",
42 | default = None,
43 | type = str,
44 | required = True,
45 | help = "Path to the output PyTorch model.")
46 | args = parser.parse_args()
47 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,args.bert_config_file,
48 | args.pytorch_dump_path)
49 |
50 | '''
51 | python convert_albert_tf_checkpoint_to_pytorch.py \
52 | --tf_checkpoint_path=./prev_trained_model/albert_large_zh \
53 | --bert_config_file=./prev_trained_model/albert_large_zh/config.json \
54 | --pytorch_dump_path=./prev_trained_model/albert_large_zh/pytorch_model.bin
55 |
56 |
57 | from model.modeling_albert_bright import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
58 | python convert_albert_tf_checkpoint_to_pytorch.py \
59 | --tf_checkpoint_path=./prev_trained_model/albert_base_bright \
60 | --bert_config_file=./prev_trained_model/albert_base_bright/config.json \
61 | --pytorch_dump_path=./prev_trained_model/albert_base_bright/pytorch_model.bin
62 | '''
--------------------------------------------------------------------------------
/tools/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from sklearn.metrics import confusion_matrix
4 | plt.switch_backend('agg')
5 |
6 | def plot_confusion_matrix(y_true, y_pred, classes,
7 | save_path,normalize=False,title=None,
8 | cmap=plt.cm.Blues):
9 | """
10 | This function prints and plots the confusion matrix.
11 | Normalization can be applied by setting `normalize=True`.
12 | """
13 | if not title:
14 | if normalize:
15 | title = 'Normalized confusion matrix'
16 | else:
17 | title = 'Confusion matrix, without normalization'
18 | # Compute confusion matrix
19 | cm = confusion_matrix(y_true=y_true, y_pred=y_pred)
20 | # Only use the labels that appear in the data
21 | if normalize:
22 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
23 | print("Normalized confusion matrix")
24 | else:
25 | print('Confusion matrix, without normalization')
26 | # --- plot--- #
27 | plt.rcParams['savefig.dpi'] = 200
28 | plt.rcParams['figure.dpi'] = 200
29 | plt.rcParams['figure.figsize'] = [20, 20] # plot
30 | plt.rcParams.update({'font.size': 10})
31 | fig, ax = plt.subplots()
32 | im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
33 | # --- bar --- #
34 | from mpl_toolkits.axes_grid1 import make_axes_locatable
35 | divider = make_axes_locatable(ax)
36 | cax = divider.append_axes("right", size="5%", pad=0.05)
37 | plt.colorbar(im, cax=cax)
38 | # --- bar --- #
39 | # ax.figure.colorbar(im, ax=ax)
40 | # We want to show all ticks...
41 | ax.set(xticks=np.arange(cm.shape[1]),
42 | yticks=np.arange(cm.shape[0]),
43 | # ... and label them with the respective list entries
44 | xticklabels=classes, yticklabels=classes,
45 | title=title,
46 | ylabel='True label',
47 | xlabel='Predicted label')
48 |
49 | # Rotate the tick labels and set their alignment.
50 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
51 | rotation_mode="anchor")
52 | # Loop over data dimensions and create text annotations.
53 | fmt = '.2f' if normalize else 'd'
54 | thresh = cm.max() / 2.
55 | for i in range(cm.shape[0]):
56 | for j in range(cm.shape[1]):
57 | ax.text(j, i, format(cm[i, j], fmt),
58 | ha="center", va="center",
59 | color="white" if cm[i, j] > thresh else "black")
60 | fig.tight_layout()
61 | plt.savefig(save_path)
62 |
63 | if __name__ == "__main__":
64 | y_true = ['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O']
65 | y_pred = ['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O','B-PER', 'I-PER', 'O']
66 | classes = ['O','B-MISC', 'I-MISC','B-PER', 'I-PER']
67 | save_path = './ner_confusion_matrix.png'
68 | plot_confusion_matrix(y_true,y_pred,classes,save_path)
--------------------------------------------------------------------------------
/callback/progressbar.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import time
3 |
4 |
5 | class ProgressBar(object):
6 | '''
7 | custom progress bar
8 | Example:
9 | >>> pbar = ProgressBar(n_total=30,desc='Training')
10 | >>> step = 2
11 | >>> pbar(step=step,info={'loss':20})
12 | '''
13 |
14 | def __init__(self, n_total, width=30, desc='Training',num_epochs = None):
15 |
16 | self.width = width
17 | self.n_total = n_total
18 | self.desc = desc
19 | self.start_time = time.time()
20 | self.num_epochs = num_epochs
21 |
22 | def reset(self):
23 | """Method to reset internal variables."""
24 | self.start_time = time.time()
25 |
26 | def _time_info(self, now, current):
27 | time_per_unit = (now - self.start_time) / current
28 | if current < self.n_total:
29 | eta = time_per_unit * (self.n_total - current)
30 | if eta > 3600:
31 | eta_format = ('%d:%02d:%02d' %
32 | (eta // 3600, (eta % 3600) // 60, eta % 60))
33 | elif eta > 60:
34 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
35 | else:
36 | eta_format = '%ds' % eta
37 | time_info = f' - ETA: {eta_format}'
38 | else:
39 | if time_per_unit >= 1:
40 | time_info = f' {time_per_unit:.1f}s/step'
41 | elif time_per_unit >= 1e-3:
42 | time_info = f' {time_per_unit * 1e3:.1f}ms/step'
43 | else:
44 | time_info = f' {time_per_unit * 1e6:.1f}us/step'
45 | return time_info
46 |
47 | def _bar(self, now, current):
48 | recv_per = current / self.n_total
49 | bar = f'[{self.desc}] {current}/{self.n_total} ['
50 | if recv_per >= 1: recv_per = 1
51 | prog_width = int(self.width * recv_per)
52 | if prog_width > 0:
53 | bar += '=' * (prog_width - 1)
54 | if current < self.n_total:
55 | bar += ">"
56 | else:
57 | bar += '='
58 | bar += '.' * (self.width - prog_width)
59 | bar += ']'
60 | return bar
61 |
62 | def epoch_start(self,current_epoch):
63 | sys.stdout.write("\n")
64 | if (current_epoch is not None) and (self.num_epochs is not None):
65 | sys.stdout.write(f"Epoch: {current_epoch}/{self.num_epochs}")
66 | sys.stdout.write("\n")
67 |
68 | def __call__(self, step, info={}):
69 | now = time.time()
70 | current = step + 1
71 | bar = self._bar(now, current)
72 | show_bar = f"\r{bar}" + self._time_info(now, current)
73 | if len(info) != 0:
74 | show_bar = f'{show_bar} ' + " [" + "-".join(
75 | [f' {key}={value:.4f} ' for key, value in info.items()]) + "]"
76 | if current >= self.n_total:
77 | show_bar += '\n'
78 | sys.stdout.write(show_bar)
79 | sys.stdout.flush()
80 |
81 |
--------------------------------------------------------------------------------
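
A slightly fuller usage sketch than the class docstring (illustrative), showing how `epoch_start`, `reset`, and the `__call__` update are typically wired into a training loop.

```python
import time
from callback.progressbar import ProgressBar

n_batches, n_epochs = 10, 2
pbar = ProgressBar(n_total=n_batches, desc='Training', num_epochs=n_epochs)
for epoch in range(1, n_epochs + 1):
    pbar.epoch_start(current_epoch=epoch)
    pbar.reset()                                # restart the timer used for the ETA estimate
    for step in range(n_batches):
        time.sleep(0.01)                        # stand-in for a real training step
        pbar(step=step, info={'loss': 1.0 / (step + 1)})
```
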
/tools/download_clue_data.py:
--------------------------------------------------------------------------------
1 | """ Script for downloading all CLUE data.
2 | For licence information, see the original dataset information links
3 | available from: https://www.cluebenchmarks.com/
4 | Example usage:
5 | python download_clue_data.py --data_dir data --tasks all
6 | """
7 |
8 | import os
9 | import sys
10 | import argparse
11 | import urllib.request
12 | import zipfile
13 |
14 | TASKS = ["afqmc", "cmnli", "copa", "csl", "iflytek", "tnews", "wsc","cmrc","chid","drcd",'cluener']
15 |
16 | TASK2PATH = {
17 | "afqmc": "https://storage.googleapis.com/cluebenchmark/tasks/afqmc_public.zip",
18 | "cmnli": "https://storage.googleapis.com/cluebenchmark/tasks/cmnli_public.zip",
19 | "copa": "https://storage.googleapis.com/cluebenchmark/tasks/copa_public.zip",
20 | "csl": "https://storage.googleapis.com/cluebenchmark/tasks/csl_public.zip",
21 | "iflytek": "https://storage.googleapis.com/cluebenchmark/tasks/iflytek_public.zip",
22 | "tnews": "https://storage.googleapis.com/cluebenchmark/tasks/tnews_public.zip",
23 | "wsc": "https://storage.googleapis.com/cluebenchmark/tasks/wsc_public.zip",
24 | 'cmrc': "https://storage.googleapis.com/cluebenchmark/tasks/cmrc2018_public.zip",
25 | "chid": "https://storage.googleapis.com/cluebenchmark/tasks/chid_public.zip",
26 | "drcd": "https://storage.googleapis.com/cluebenchmark/tasks/drcd_public.zip",
27 | 'cluener':'https://storage.googleapis.com/cluebenchmark/tasks/cluener_public.zip'
28 | }
29 |
30 | def download_and_extract(task, data_dir):
31 | print("Downloading and extracting %s..." % task)
32 | if not os.path.isdir(data_dir):
33 | os.mkdir(data_dir)
34 | data_file = os.path.join(data_dir, "%s_public.zip" % task)
35 | save_dir = os.path.join(data_dir,task)
36 | if not os.path.isdir(save_dir):
37 | os.mkdir(save_dir)
38 | urllib.request.urlretrieve(TASK2PATH[task], data_file)
39 | with zipfile.ZipFile(data_file) as zip_ref:
40 | zip_ref.extractall(save_dir)
41 | os.remove(data_file)
42 | print(f"\tCompleted! Downloaded {task} data to directory {save_dir}")
43 |
44 | def get_tasks(task_names):
45 | task_names = task_names.split(",")
46 | if "all" in task_names:
47 | tasks = TASKS
48 | else:
49 | tasks = []
50 | for task_name in task_names:
51 | assert task_name in TASKS, "Task %s not found!" % task_name
52 | tasks.append(task_name)
53 | return tasks
54 |
55 | def main(arguments):
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument(
58 | "-d", "--data_dir", help="directory to save data to", type=str, default="../CLUEdatasets"
59 | )
60 | parser.add_argument(
61 | "-t",
62 | "--tasks",
63 | help="tasks to download data for as a comma separated string",
64 | type=str,
65 | default="all",
66 | )
67 | args = parser.parse_args(arguments)
68 |
69 | if not os.path.exists(args.data_dir):
70 | os.mkdir(args.data_dir)
71 | tasks = get_tasks(args.tasks)
72 |
73 | for task in tasks:
74 | download_and_extract(task, args.data_dir)
75 |
76 | if __name__ == "__main__":
77 | sys.exit(main(sys.argv[1:]))
78 |
79 | '''
80 | python tools/download_clue_data.py --data_dir=./CLUEdatasets --tasks=cluener
81 | '''
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Chinese NER using Bert
2 |
3 | BERT for Chinese NER.
4 |
5 | **update**: Some other implementations are worth referring to, including Biaffine, GlobalPointer, etc.: [examples](https://github.com/lonePatient/TorchBlocks/tree/master/examples)
6 |
7 | ### dataset list
8 |
9 | 1. cner: datasets/cner
10 | 2. CLUENER: https://github.com/CLUEbenchmark/CLUENER
11 |
12 | ### model list
13 |
14 | 1. BERT+Softmax
15 | 2. BERT+CRF
16 | 3. BERT+Span
17 |
18 | ### requirement
19 |
20 | 1. 1.1.0 <= PyTorch < 1.5.0
21 | 2. cuda=9.0
22 | 3. python3.6+
23 |
24 | ### input format
25 |
26 | Input format (BIOS tag scheme preferred): one character and its label per line. Sentences are separated by a blank line.
27 |
28 | ```text
29 | 美 B-LOC
30 | 国 I-LOC
31 | 的 O
32 | 华 B-PER
33 | 莱 I-PER
34 | 士 I-PER
35 |
36 | 我 O
37 | 跟 O
38 | 他 O
39 | ```
40 |
41 | ### run the code
42 |
43 | 1. Modify the configuration information in `run_ner_xxx.py` or `run_ner_xxx.sh`.
44 | 2. `sh scripts/run_ner_xxx.sh`
45 |
46 | **note**: file structure of the model
47 |
48 | ```text
49 | ├── prev_trained_model
50 | | └── bert_base
51 | | | └── pytorch_model.bin
52 | | | └── config.json
53 | | | └── vocab.txt
54 | | | └── ......
55 | ```
56 |
57 | ### CLUENER result
58 |
59 | The overall performance of BERT on **dev**:
60 |
61 | | | Accuracy (entity) | Recall (entity) | F1 score (entity) |
62 | | ------------ | ------------------ | ------------------ | ------------------ |
63 | | BERT+Softmax | 0.7897 | 0.8031 | 0.7963 |
64 | | BERT+CRF | 0.7977 | 0.8177 | 0.8076 |
65 | | BERT+Span | 0.8132 | 0.8092 | 0.8112 |
66 | | BERT+Span+adv | 0.8267 | 0.8073 | **0.8169** |
67 | | BERT-small(6 layers)+Span+kd | 0.8241 | 0.7839 | 0.8051 |
68 | | BERT+Span+focal_loss | 0.8121 | 0.8008 | 0.8064 |
69 | | BERT+Span+label_smoothing | 0.8235 | 0.7946 | 0.8088 |
70 |
71 | ### ALBERT for CLUENER
72 |
73 | The overall performance of ALBERT on **dev**:
74 |
75 | | model | version | Accuracy(entity) | Recall(entity) | F1(entity) | Train time/epoch |
76 | | ------ | ------------- | ---------------- | -------------- | ---------- | ---------------- |
77 | | albert | base_google | 0.8014 | 0.6908 | 0.7420 | 0.75x |
78 | | albert | large_google | 0.8024 | 0.7520 | 0.7763 | 2.1x |
79 | | albert | xlarge_google | 0.8286 | 0.7773 | 0.8021 | 6.7x |
80 | | bert | google | 0.8118 | 0.8031 | **0.8074** | ----- |
81 | | albert | base_bright | 0.8068 | 0.7529 | 0.7789 | 0.75x |
82 | | albert | large_bright | 0.8152 | 0.7480 | 0.7802 | 2.2x |
83 | | albert | xlarge_bright | 0.8222 | 0.7692 | 0.7948 | 7.3x |
84 |
85 | ### Cner result
86 |
87 | The overall performance of BERT on **dev(test)**:
88 |
89 | | | Accuracy (entity) | Recall (entity) | F1 score (entity) |
90 | | ------------ | ------------------ | ------------------ | ------------------ |
91 | | BERT+Softmax | 0.9586(0.9566) | 0.9644(0.9613) | 0.9615(0.9590) |
92 | | BERT+CRF | 0.9562(0.9539) | 0.9671(**0.9644**) | 0.9616(0.9591) |
93 | | BERT+Span | 0.9604(**0.9620**) | 0.9617(0.9632) | 0.9611(**0.9626**) |
94 | | BERT+Span+focal_loss | 0.9516(0.9569) | 0.9644(0.9681) | 0.9580(0.9625) |
95 | | BERT+Span+label_smoothing | 0.9566(0.9568) | 0.9624(0.9656) | 0.9595(0.9612) |
96 |
--------------------------------------------------------------------------------
/callback/optimizater/novograd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class NovoGrad(Optimizer):
7 | """Implements NovoGrad algorithm.
8 | Arguments:
9 | params (iterable): iterable of parameters to optimize or dicts defining
10 | parameter groups
11 | lr (float, optional): learning rate (default: 1e-2)
12 | betas (Tuple[float, float], optional): coefficients used for computing
13 | running averages of gradient and its square (default: (0.95, 0.98))
14 | eps (float, optional): term added to the denominator to improve
15 | numerical stability (default: 1e-8)
16 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
17 | Example:
18 | >>> model = ResNet()
19 | >>> optimizer = NovoGrad(model.parameters(), lr=1e-2, weight_decay=1e-5)
20 | """
21 |
22 | def __init__(self, params, lr=0.01, betas=(0.95, 0.98), eps=1e-8,
23 | weight_decay=0, grad_averaging=False):
24 | if lr < 0.0:
25 | raise ValueError("Invalid learning rate: {}".format(lr))
26 | if not 0.0 <= betas[0] < 1.0:
27 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
28 | if not 0.0 <= betas[1] < 1.0:
29 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
30 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging)
31 | super().__init__(params, defaults)
32 |
33 | def step(self, closure=None):
34 | loss = None
35 | if closure is not None:
36 | loss = closure()
37 | for group in self.param_groups:
38 | for p in group['params']:
39 | if p.grad is None:
40 | continue
41 | grad = p.grad.data
42 | if grad.is_sparse:
43 | raise RuntimeError('NovoGrad does not support sparse gradients')
44 | state = self.state[p]
45 | g_2 = torch.sum(grad ** 2)
46 | if len(state) == 0:
47 | state['step'] = 0
48 | state['moments'] = grad.div(g_2.sqrt() + group['eps']) + \
49 | group['weight_decay'] * p.data
50 | state['grads_ema'] = g_2
51 | moments = state['moments']
52 | grads_ema = state['grads_ema']
53 | beta1, beta2 = group['betas']
54 | state['step'] += 1
55 | grads_ema.mul_(beta2).add_(1 - beta2, g_2)
56 |
57 | denom = grads_ema.sqrt().add_(group['eps'])
58 | grad.div_(denom)
59 | # weight decay
60 | if group['weight_decay'] != 0:
61 | decayed_weights = torch.mul(p.data, group['weight_decay'])
62 | grad.add_(decayed_weights)
63 |
64 | # Momentum --> SAG
65 | if group['grad_averaging']:
66 | grad.mul_(1.0 - beta1)
67 |
68 | moments.mul_(beta1).add_(grad) # velocity
69 |
70 | bias_correction1 = 1 - beta1 ** state['step']
71 | bias_correction2 = 1 - beta2 ** state['step']
72 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
73 | p.data.add_(-step_size, moments)
74 |
75 | return loss
76 |
--------------------------------------------------------------------------------
/callback/optimizater/sgdw.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 |
4 | class SGDW(Optimizer):
5 | r"""Implements stochastic gradient descent (optionally with momentum) with
6 | weight decay from the paper `Fixing Weight Decay Regularization in Adam`_.
7 |
8 | Nesterov momentum is based on the formula from
9 | `On the importance of initialization and momentum in deep learning`__.
10 |
11 | Args:
12 | params (iterable): iterable of parameters to optimize or dicts defining
13 | parameter groups
14 | lr (float): learning rate
15 | momentum (float, optional): momentum factor (default: 0)
16 | weight_decay (float, optional): weight decay factor (default: 0)
17 | dampening (float, optional): dampening for momentum (default: 0)
18 | nesterov (bool, optional): enables Nesterov momentum (default: False)
19 |
20 | .. _Fixing Weight Decay Regularization in Adam:
21 | https://arxiv.org/abs/1711.05101
22 |
23 | Example:
24 | >>> model = LSTM()
25 | >>> optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9,weight_decay=1e-5)
26 | """
27 | def __init__(self, params, lr=0.1, momentum=0, dampening=0,
28 | weight_decay=0, nesterov=False):
29 | if lr < 0.0:
30 | raise ValueError(f"Invalid learning rate: {lr}")
31 | if momentum < 0.0:
32 | raise ValueError(f"Invalid momentum value: {momentum}")
33 | if weight_decay < 0.0:
34 | raise ValueError(f"Invalid weight_decay value: {weight_decay}")
35 |
36 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
37 | weight_decay=weight_decay, nesterov=nesterov)
38 | if nesterov and (momentum <= 0 or dampening != 0):
39 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
40 | super(SGDW, self).__init__(params, defaults)
41 |
42 | def __setstate__(self, state):
43 | super(SGDW, self).__setstate__(state)
44 | for group in self.param_groups:
45 | group.setdefault('nesterov', False)
46 |
47 | def step(self, closure=None):
48 | """Performs a single optimization step.
49 |
50 | Arguments:
51 | closure (callable, optional): A closure that reevaluates the model
52 | and returns the loss.
53 | """
54 | loss = None
55 | if closure is not None:
56 | loss = closure()
57 |
58 | for group in self.param_groups:
59 | weight_decay = group['weight_decay']
60 | momentum = group['momentum']
61 | dampening = group['dampening']
62 | nesterov = group['nesterov']
63 | for p in group['params']:
64 | if p.grad is None:
65 | continue
66 | d_p = p.grad.data
67 | if momentum != 0:
68 | param_state = self.state[p]
69 | if 'momentum_buffer' not in param_state:
70 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
71 | buf.mul_(momentum).add_(d_p)
72 | else:
73 | buf = param_state['momentum_buffer']
74 | buf.mul_(momentum).add_(1 - dampening, d_p)
75 | if nesterov:
76 | d_p = d_p.add(momentum, buf)
77 | else:
78 | d_p = buf
79 | if weight_decay != 0:
80 | p.data.add_(-weight_decay, p.data)
81 | p.data.add_(-group['lr'], d_p)
82 | return loss
--------------------------------------------------------------------------------
/callback/optimizater/lars.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 |
4 | class Lars(Optimizer):
5 | r"""Implements the LARS optimizer from https://arxiv.org/pdf/1708.03888.pdf
6 |
7 | Args:
8 | params (iterable): iterable of parameters to optimize or dicts defining
9 | parameter groups
10 | lr (float): learning rate
11 | momentum (float, optional): momentum factor (default: 0)
12 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
13 | dampening (float, optional): dampening for momentum (default: 0)
14 | nesterov (bool, optional): enables Nesterov momentum (default: False)
15 | scale_clip (tuple, optional): the lower and upper bounds for the weight norm in local LR of LARS
16 | Example:
17 | >>> model = ResNet()
18 | >>> optimizer = Lars(model.parameters(), lr=1e-2, weight_decay=1e-5)
19 | """
20 |
21 | def __init__(self, params, lr, momentum=0, dampening=0,
22 | weight_decay=0, nesterov=False, scale_clip=None):
23 | if lr < 0.0:
24 | raise ValueError("Invalid learning rate: {}".format(lr))
25 | if momentum < 0.0:
26 | raise ValueError("Invalid momentum value: {}".format(momentum))
27 | if weight_decay < 0.0:
28 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
29 |
30 | defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
31 | weight_decay=weight_decay, nesterov=nesterov)
32 | if nesterov and (momentum <= 0 or dampening != 0):
33 | raise ValueError("Nesterov momentum requires a momentum and zero dampening")
34 | super(Lars, self).__init__(params, defaults)
35 | # LARS arguments
36 | self.scale_clip = scale_clip
37 | if self.scale_clip is None:
38 | self.scale_clip = (0, 10)
39 |
40 | def __setstate__(self, state):
41 | super(Lars, self).__setstate__(state)
42 | for group in self.param_groups:
43 | group.setdefault('nesterov', False)
44 |
45 | def step(self, closure=None):
46 | """Performs a single optimization step.
47 |
48 | Arguments:
49 | closure (callable, optional): A closure that reevaluates the model
50 | and returns the loss.
51 | """
52 | loss = None
53 | if closure is not None:
54 | loss = closure()
55 |
56 | for group in self.param_groups:
57 | weight_decay = group['weight_decay']
58 | momentum = group['momentum']
59 | dampening = group['dampening']
60 | nesterov = group['nesterov']
61 |
62 | for p in group['params']:
63 | if p.grad is None:
64 | continue
65 | d_p = p.grad.data
66 | if weight_decay != 0:
67 | d_p.add_(weight_decay, p.data)
68 | if momentum != 0:
69 | param_state = self.state[p]
70 | if 'momentum_buffer' not in param_state:
71 | buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
72 | else:
73 | buf = param_state['momentum_buffer']
74 | buf.mul_(momentum).add_(1 - dampening, d_p)
75 | if nesterov:
76 | d_p = d_p.add(momentum, buf)
77 | else:
78 | d_p = buf
79 |
80 | # LARS
81 | p_norm = p.data.pow(2).sum().sqrt()
82 | update_norm = d_p.pow(2).sum().sqrt()
83 | # Compute the local LR
84 | if p_norm == 0 or update_norm == 0:
85 | local_lr = 1
86 | else:
87 | local_lr = p_norm / update_norm
88 |
89 | p.data.add_(-group['lr'] * local_lr, d_p)
90 |
91 | return loss
--------------------------------------------------------------------------------
/callback/modelcheckpoint.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import numpy as np
3 | import torch
4 | from ..tools.common import logger
5 |
6 | class ModelCheckpoint(object):
7 | '''
8 |     Model checkpointing, with two modes:
9 |     1. save only the best model so far
10 |     2. save the model every `epoch_freq` epochs
11 | '''
12 | def __init__(self, checkpoint_dir,
13 | monitor,
14 | arch,mode='min',
15 | epoch_freq=1,
16 | best = None,
17 | save_best_only = True):
18 | if isinstance(checkpoint_dir,Path):
19 | checkpoint_dir = checkpoint_dir
20 | else:
21 | checkpoint_dir = Path(checkpoint_dir)
22 |         checkpoint_dir.mkdir(parents=True, exist_ok=True)
23 |         assert checkpoint_dir.is_dir()
24 | self.base_path = checkpoint_dir
25 | self.arch = arch
26 | self.monitor = monitor
27 | self.epoch_freq = epoch_freq
28 | self.save_best_only = save_best_only
29 |
30 |         # comparison mode for the monitored metric
31 | if mode == 'min':
32 | self.monitor_op = np.less
33 | self.best = np.Inf
34 |
35 | elif mode == 'max':
36 | self.monitor_op = np.greater
37 | self.best = -np.Inf
38 |         # when resuming from a checkpoint,
39 |         # restore the previous best value
40 | if best:
41 | self.best = best
42 |
43 | if save_best_only:
44 | self.model_name = f"BEST_{arch}_MODEL.pth"
45 |
46 | def epoch_step(self, state,current):
47 | '''
48 |         For plain PyTorch models.
49 |         :param state: information to be saved
50 |         :param current: current value of the monitored metric
51 | :return:
52 | '''
53 |         # save only the best model
54 | if self.save_best_only:
55 | if self.monitor_op(current, self.best):
56 | logger.info(f"\nEpoch {state['epoch']}: {self.monitor} improved from {self.best:.5f} to {current:.5f}")
57 | self.best = current
58 | state['best'] = self.best
59 | best_path = self.base_path/ self.model_name
60 | torch.save(state, str(best_path))
61 |         # otherwise, save the model every few epochs
62 | else:
63 | filename = self.base_path / f"EPOCH_{state['epoch']}_{state[self.monitor]}_{self.arch}_MODEL.pth"
64 | if state['epoch'] % self.epoch_freq == 0:
65 | logger.info(f"\nEpoch {state['epoch']}: save model to disk.")
66 | torch.save(state, str(filename))
67 |
68 | def bert_epoch_step(self, state,current):
69 | '''
70 |         For BERT-style models that expose the pytorch_transformers `save_pretrained` API.
71 | :param state:
72 | :param current:
73 | :return:
74 | '''
75 | model_to_save = state['model']
76 | if self.save_best_only:
77 | if self.monitor_op(current, self.best):
78 | logger.info(f"\nEpoch {state['epoch']}: {self.monitor} improved from {self.best:.5f} to {current:.5f}")
79 | self.best = current
80 | state['best'] = self.best
81 | model_to_save.save_pretrained(str(self.base_path))
82 | output_config_file = self.base_path / 'configs.json'
83 | with open(str(output_config_file), 'w') as f:
84 | f.write(model_to_save.config.to_json_string())
85 | state.pop("model")
86 | torch.save(state,self.base_path / 'checkpoint_info.bin')
87 | else:
88 | if state['epoch'] % self.epoch_freq == 0:
89 | save_path = self.base_path / f"checkpoint-epoch-{state['epoch']}"
90 | save_path.mkdir(exist_ok=True)
91 | logger.info(f"\nEpoch {state['epoch']}: save model to disk.")
92 | model_to_save.save_pretrained(save_path)
93 | output_config_file = save_path / 'configs.json'
94 | with open(str(output_config_file), 'w') as f:
95 | f.write(model_to_save.config.to_json_string())
96 | state.pop("model")
97 | torch.save(state, save_path / 'checkpoint_info.bin')
98 |
--------------------------------------------------------------------------------
/callback/optimizater/radam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.optimizer import Optimizer
4 | class RAdam(Optimizer):
5 | """Implements the RAdam optimizer from https://arxiv.org/pdf/1908.03265.pdf
6 | Args:
7 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups
8 | lr (float, optional): learning rate
9 | betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
10 | eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
11 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
12 | Example:
13 | >>> model = ResNet()
14 | >>> optimizer = RAdam(model.parameters(), lr=0.001)
15 | """
16 |
17 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
18 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
19 | self.buffer = [[None, None, None] for ind in range(10)]
20 | super(RAdam, self).__init__(params, defaults)
21 |
22 | def __setstate__(self, state):
23 | super(RAdam, self).__setstate__(state)
24 |
25 | def step(self, closure=None):
26 |
27 | loss = None
28 | if closure is not None:
29 | loss = closure()
30 |
31 | for group in self.param_groups:
32 |
33 | for p in group['params']:
34 | if p.grad is None:
35 | continue
36 | grad = p.grad.data.float()
37 | if grad.is_sparse:
38 | raise RuntimeError('RAdam does not support sparse gradients')
39 |
40 | p_data_fp32 = p.data.float()
41 |
42 | state = self.state[p]
43 |
44 | if len(state) == 0:
45 | state['step'] = 0
46 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
47 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
48 | else:
49 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
50 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
51 |
52 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
53 | beta1, beta2 = group['betas']
54 |
55 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
56 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
57 |
58 | state['step'] += 1
59 | buffered = self.buffer[int(state['step'] % 10)]
60 | if state['step'] == buffered[0]:
61 | N_sma, step_size = buffered[1], buffered[2]
62 | else:
63 | buffered[0] = state['step']
64 | beta2_t = beta2 ** state['step']
65 | N_sma_max = 2 / (1 - beta2) - 1
66 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
67 | buffered[1] = N_sma
68 |
69 | # more conservative since it's an approximated value
70 | if N_sma >= 5:
71 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
72 | else:
73 | step_size = 1.0 / (1 - beta1 ** state['step'])
74 | buffered[2] = step_size
75 |
76 | if group['weight_decay'] != 0:
77 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
78 |
79 | # more conservative since it's an approximated value
80 | if N_sma >= 5:
81 | denom = exp_avg_sq.sqrt().add_(group['eps'])
82 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
83 | else:
84 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
85 |
86 | p.data.copy_(p_data_fp32)
87 |
88 | return loss
--------------------------------------------------------------------------------
/callback/optimizater/nadam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.optimizer import Optimizer
4 |
5 | class Nadam(Optimizer):
6 | """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
7 |
8 | It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
9 |
10 | Arguments:
11 | params (iterable): iterable of parameters to optimize or dicts defining
12 | parameter groups
13 | lr (float, optional): learning rate (default: 2e-3)
14 | betas (Tuple[float, float], optional): coefficients used for computing
15 | running averages of gradient and its square
16 | eps (float, optional): term added to the denominator to improve
17 | numerical stability (default: 1e-8)
18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 | schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
20 |
21 | __ http://cs229.stanford.edu/proj2015/054_report.pdf
22 | __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
23 |
24 | Originally taken from: https://github.com/pytorch/pytorch/pull/1408
25 | NOTE: Has potential issues but does work well on some problems.
26 | Example:
27 | >>> model = LSTM()
28 | >>> optimizer = Nadam(model.parameters())
29 | """
30 |
31 | def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
32 | weight_decay=0, schedule_decay=4e-3):
33 | defaults = dict(lr=lr, betas=betas, eps=eps,
34 | weight_decay=weight_decay, schedule_decay=schedule_decay)
35 | super(Nadam, self).__init__(params, defaults)
36 |
37 | def step(self, closure=None):
38 | """Performs a single optimization step.
39 |
40 | Arguments:
41 | closure (callable, optional): A closure that reevaluates the model
42 | and returns the loss.
43 | """
44 | loss = None
45 | if closure is not None:
46 | loss = closure()
47 |
48 | for group in self.param_groups:
49 | for p in group['params']:
50 | if p.grad is None:
51 | continue
52 | grad = p.grad.data
53 | state = self.state[p]
54 |
55 | # State initialization
56 | if len(state) == 0:
57 | state['step'] = 0
58 | state['m_schedule'] = 1.
59 | state['exp_avg'] = grad.new().resize_as_(grad).zero_()
60 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()
61 |
62 | # Warming momentum schedule
63 | m_schedule = state['m_schedule']
64 | schedule_decay = group['schedule_decay']
65 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
66 | beta1, beta2 = group['betas']
67 | eps = group['eps']
68 | state['step'] += 1
69 | t = state['step']
70 |
71 | if group['weight_decay'] != 0:
72 | grad = grad.add(group['weight_decay'], p.data)
73 |
74 | momentum_cache_t = beta1 * \
75 | (1. - 0.5 * (0.96 ** (t * schedule_decay)))
76 | momentum_cache_t_1 = beta1 * \
77 | (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
78 | m_schedule_new = m_schedule * momentum_cache_t
79 | m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
80 | state['m_schedule'] = m_schedule_new
81 |
82 | # Decay the first and second moment running average coefficient
83 | exp_avg.mul_(beta1).add_(1. - beta1, grad)
84 | exp_avg_sq.mul_(beta2).addcmul_(1. - beta2, grad, grad)
85 | exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
86 | denom = exp_avg_sq_prime.sqrt_().add_(eps)
87 |
88 | p.data.addcdiv_(-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new), grad, denom)
89 | p.data.addcdiv_(-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next), exp_avg, denom)
90 |
91 | return loss
--------------------------------------------------------------------------------
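The "warming momentum schedule" in Nadam's step() above multiplies beta1 by (1 - 0.5 * 0.96 ** (t * schedule_decay)), so the effective momentum starts near half of beta1 and warms toward beta1 over training. A small standalone sketch of that schedule with the default hyperparameters (not part of the optimizer itself):

# Nadam's momentum warm-up with the defaults beta1=0.9, schedule_decay=4e-3.
beta1, schedule_decay = 0.9, 4e-3
for t in (1, 100, 1000, 10000):
    momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
    print(t, round(momentum_cache_t, 3))
# Prints roughly 0.450, 0.457, 0.518, 0.812: the effective momentum climbs toward beta1.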
/metrics/ner_metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import Counter
3 | from processors.utils_ner import get_entities
4 |
5 | class SeqEntityScore(object):
6 | def __init__(self, id2label,markup='bios'):
7 | self.id2label = id2label
8 | self.markup = markup
9 | self.reset()
10 |
11 | def reset(self):
12 | self.origins = []
13 | self.founds = []
14 | self.rights = []
15 |
16 | def compute(self, origin, found, right):
17 | recall = 0 if origin == 0 else (right / origin)
18 | precision = 0 if found == 0 else (right / found)
19 | f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
20 | return recall, precision, f1
21 |
22 | def result(self):
23 | class_info = {}
24 | origin_counter = Counter([x[0] for x in self.origins])
25 | found_counter = Counter([x[0] for x in self.founds])
26 | right_counter = Counter([x[0] for x in self.rights])
27 | for type_, count in origin_counter.items():
28 | origin = count
29 | found = found_counter.get(type_, 0)
30 | right = right_counter.get(type_, 0)
31 | recall, precision, f1 = self.compute(origin, found, right)
32 | class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}
33 | origin = len(self.origins)
34 | found = len(self.founds)
35 | right = len(self.rights)
36 | recall, precision, f1 = self.compute(origin, found, right)
37 | return {'acc': precision, 'recall': recall, 'f1': f1}, class_info
38 |
39 | def update(self, label_paths, pred_paths):
40 | '''
41 | labels_paths: [[],[],[],....]
42 | pred_paths: [[],[],[],.....]
43 |
44 | :param label_paths:
45 | :param pred_paths:
46 | :return:
47 | Example:
48 | >>> labels_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
49 | >>> pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
50 | '''
51 | for label_path, pre_path in zip(label_paths, pred_paths):
52 | label_entities = get_entities(label_path, self.id2label,self.markup)
53 | pre_entities = get_entities(pre_path, self.id2label,self.markup)
54 | self.origins.extend(label_entities)
55 | self.founds.extend(pre_entities)
56 | self.rights.extend([pre_entity for pre_entity in pre_entities if pre_entity in label_entities])
57 |
58 | class SpanEntityScore(object):
59 | def __init__(self, id2label):
60 | self.id2label = id2label
61 | self.reset()
62 |
63 | def reset(self):
64 | self.origins = []
65 | self.founds = []
66 | self.rights = []
67 |
68 | def compute(self, origin, found, right):
69 | recall = 0 if origin == 0 else (right / origin)
70 | precision = 0 if found == 0 else (right / found)
71 | f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
72 | return recall, precision, f1
73 |
74 | def result(self):
75 | class_info = {}
76 | origin_counter = Counter([self.id2label[x[0]] for x in self.origins])
77 | found_counter = Counter([self.id2label[x[0]] for x in self.founds])
78 | right_counter = Counter([self.id2label[x[0]] for x in self.rights])
79 | for type_, count in origin_counter.items():
80 | origin = count
81 | found = found_counter.get(type_, 0)
82 | right = right_counter.get(type_, 0)
83 | recall, precision, f1 = self.compute(origin, found, right)
84 | class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}
85 | origin = len(self.origins)
86 | found = len(self.founds)
87 | right = len(self.rights)
88 | recall, precision, f1 = self.compute(origin, found, right)
89 | return {'acc': precision, 'recall': recall, 'f1': f1}, class_info
90 |
91 | def update(self, true_subject, pred_subject):
92 | self.origins.extend(true_subject)
93 | self.founds.extend(pred_subject)
94 | self.rights.extend([pre_entity for pre_entity in pred_subject if pre_entity in true_subject])
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
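A minimal sketch of driving the SeqEntityScore class above with the label/prediction paths from its update() docstring; id2label is passed as an empty dict because get_entities only consults it for integer tags, and the tags here are already strings.

from metrics.ner_metrics import SeqEntityScore

metric = SeqEntityScore(id2label={}, markup='bios')
label_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
metric.update(label_paths, pred_paths)
overall, per_class = metric.result()
# overall -> {'acc': 0.5, 'recall': 0.5, 'f1': 0.5}: the PER span matches exactly, the shifted MISC span does not.
# per_class breaks the same numbers down per entity type ('acc' here is precision, as in result()).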
/callback/adversarial.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class FGM():
4 | '''
5 | Example
6 | # initialization
7 | fgm = FGM(model,epsilon=1,emb_name='word_embeddings.')
8 | for batch_input, batch_label in data:
9 | # normal training step
10 | loss = model(batch_input, batch_label)
11 | loss.backward() # backward pass to get the clean gradients
12 | # adversarial training
13 | fgm.attack() # add an adversarial perturbation to the embeddings
14 | loss_adv = model(batch_input, batch_label)
15 | loss_adv.backward() # backward pass, accumulating the adversarial gradients on top of the clean ones
16 | fgm.restore() # restore the original embedding parameters
17 | # gradient descent, update the parameters
18 | optimizer.step()
19 | model.zero_grad()
20 | '''
21 | def __init__(self, model,emb_name,epsilon=1.0):
22 | # emb_name should be set to the parameter name of the embedding layer in your model
23 | self.model = model
24 | self.epsilon = epsilon
25 | self.emb_name = emb_name
26 | self.backup = {}
27 |
28 | def attack(self):
29 | for name, param in self.model.named_parameters():
30 | if param.requires_grad and self.emb_name in name:
31 | self.backup[name] = param.data.clone()
32 | norm = torch.norm(param.grad)
33 | if norm!=0 and not torch.isnan(norm):
34 | r_at = self.epsilon * param.grad / norm
35 | param.data.add_(r_at)
36 |
37 | def restore(self):
38 | for name, param in self.model.named_parameters():
39 | if param.requires_grad and self.emb_name in name:
40 | assert name in self.backup
41 | param.data = self.backup[name]
42 | self.backup = {}
43 |
44 | class PGD():
45 | '''
46 | Example
47 | pgd = PGD(model,emb_name='word_embeddings.',epsilon=1.0,alpha=0.3)
48 | K = 3
49 | for batch_input, batch_label in data:
50 | # normal training step
51 | loss = model(batch_input, batch_label)
52 | loss.backward() # backward pass to get the clean gradients
53 | pgd.backup_grad()
54 | # adversarial training
55 | for t in range(K):
56 | pgd.attack(is_first_attack=(t==0)) # add an adversarial perturbation to the embeddings; back up param.data on the first attack
57 | if t != K-1:
58 | model.zero_grad()
59 | else:
60 | pgd.restore_grad()
61 | loss_adv = model(batch_input, batch_label)
62 | loss_adv.backward() # backward pass, accumulating the adversarial gradients on top of the clean ones
63 | pgd.restore() # restore the original embedding parameters
64 | # gradient descent, update the parameters
65 | optimizer.step()
66 | model.zero_grad()
67 | '''
68 | def __init__(self, model,emb_name,epsilon=1.,alpha=0.3):
69 | # emb_name should be set to the parameter name of the embedding layer in your model
70 | self.model = model
71 | self.emb_name = emb_name
72 | self.epsilon = epsilon
73 | self.alpha = alpha
74 | self.emb_backup = {}
75 | self.grad_backup = {}
76 |
77 | def attack(self,is_first_attack=False):
78 | for name, param in self.model.named_parameters():
79 | if param.requires_grad and self.emb_name in name:
80 | if is_first_attack:
81 | self.emb_backup[name] = param.data.clone()
82 | norm = torch.norm(param.grad)
83 | if norm != 0:
84 | r_at = self.alpha * param.grad / norm
85 | param.data.add_(r_at)
86 | param.data = self.project(name, param.data, self.epsilon)
87 |
88 | def restore(self):
89 | for name, param in self.model.named_parameters():
90 | if param.requires_grad and self.emb_name in name:
91 | assert name in self.emb_backup
92 | param.data = self.emb_backup[name]
93 | self.emb_backup = {}
94 |
95 | def project(self, param_name, param_data, epsilon):
96 | r = param_data - self.emb_backup[param_name]
97 | if torch.norm(r) > epsilon:
98 | r = epsilon * r / torch.norm(r)
99 | return self.emb_backup[param_name] + r
100 |
101 | def backup_grad(self):
102 | for name, param in self.model.named_parameters():
103 | if param.requires_grad:
104 | self.grad_backup[name] = param.grad.clone()
105 |
106 | def restore_grad(self):
107 | for name, param in self.model.named_parameters():
108 | if param.requires_grad:
109 | param.grad = self.grad_backup[name]
--------------------------------------------------------------------------------
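A self-contained toy run of the FGM class above, following the docstring's loop; the TinyClassifier module, its parameter names, and the data are made up purely for illustration. The emb_name value matches the default of the --adv_name flag defined later in tools/finetuning_argparse.py.

import torch
import torch.nn as nn
from callback.adversarial import FGM

class TinyClassifier(nn.Module):
    """Toy embedding classifier, used only to illustrate the FGM call order."""
    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(100, 16)
        self.fc = nn.Linear(16, 2)
    def forward(self, ids, labels):
        logits = self.fc(self.word_embeddings(ids).mean(dim=1))
        return nn.functional.cross_entropy(logits, labels)

model = TinyClassifier()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
fgm = FGM(model, emb_name='word_embeddings', epsilon=1.0)

ids, labels = torch.randint(0, 100, (4, 7)), torch.randint(0, 2, (4,))
loss = model(ids, labels)
loss.backward()                   # clean gradients
fgm.attack()                      # perturb the embedding weights along their gradient
model(ids, labels).backward()     # accumulate the adversarial gradients
fgm.restore()                     # restore the clean embedding weights
optimizer.step()
model.zero_grad()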
/callback/optimizater/adamw.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.optimizer import Optimizer
4 |
5 | class AdamW(Optimizer):
6 | """ Implements Adam algorithm with weight decay fix.
7 |
8 | Parameters:
9 | lr (float): learning rate. Default 1e-3.
10 | betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
11 | eps (float): Adams epsilon. Default: 1e-6
12 | weight_decay (float): Weight decay. Default: 0.0
13 | correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
14 | Example:
15 | >>> model = LSTM()
16 | >>> optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
17 | """
18 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
19 | if lr < 0.0:
20 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
21 | if not 0.0 <= betas[0] < 1.0:
22 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
23 | if not 0.0 <= betas[1] < 1.0:
24 | raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
25 | if not 0.0 <= eps:
26 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
27 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
28 | correct_bias=correct_bias)
29 | super(AdamW, self).__init__(params, defaults)
30 |
31 | def step(self, closure=None):
32 | """Performs a single optimization step.
33 |
34 | Arguments:
35 | closure (callable, optional): A closure that reevaluates the model
36 | and returns the loss.
37 | """
38 | loss = None
39 | if closure is not None:
40 | loss = closure()
41 |
42 | for group in self.param_groups:
43 | for p in group['params']:
44 | if p.grad is None:
45 | continue
46 | grad = p.grad.data
47 | if grad.is_sparse:
48 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
49 |
50 | state = self.state[p]
51 |
52 | # State initialization
53 | if len(state) == 0:
54 | state['step'] = 0
55 | # Exponential moving average of gradient values
56 | state['exp_avg'] = torch.zeros_like(p.data)
57 | # Exponential moving average of squared gradient values
58 | state['exp_avg_sq'] = torch.zeros_like(p.data)
59 |
60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
61 | beta1, beta2 = group['betas']
62 |
63 | state['step'] += 1
64 |
65 | # Decay the first and second moment running average coefficient
66 | # In-place operations to update the averages at the same time
67 | exp_avg.mul_(beta1).add_(1.0 - beta1, grad)
68 | exp_avg_sq.mul_(beta2).addcmul_(1.0 - beta2, grad, grad)
69 | denom = exp_avg_sq.sqrt().add_(group['eps'])
70 |
71 | step_size = group['lr']
72 | if group['correct_bias']: # No bias correction for Bert
73 | bias_correction1 = 1.0 - beta1 ** state['step']
74 | bias_correction2 = 1.0 - beta2 ** state['step']
75 | step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
76 |
77 | p.data.addcdiv_(-step_size, exp_avg, denom)
78 |
79 | # Just adding the square of the weights to the loss function is *not*
80 | # the correct way of using L2 regularization/weight decay with Adam,
81 | # since that will interact with the m and v parameters in strange ways.
82 | #
83 | # Instead we want to decay the weights in a manner that doesn't interact
84 | # with the m/v parameters. This is equivalent to adding the square
85 | # of the weights to the loss with plain (non-momentum) SGD.
86 | # Add weight decay at the end (fixed version)
87 | if group['weight_decay'] > 0.0:
88 | p.data.add_(-group['lr'] * group['weight_decay'], p.data)
89 |
90 | return loss
91 |
--------------------------------------------------------------------------------
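Because the weight decay above is decoupled and applied directly to the weights, BERT fine-tuning usually puts bias and LayerNorm parameters into a no-decay group. A sketch of that grouping, using a hypothetical TinyEncoder whose attribute names mimic the usual Hugging Face BERT parameter names:

import torch.nn as nn
from callback.optimizater.adamw import AdamW

class TinyEncoder(nn.Module):
    """Stand-in for a BERT encoder; only the parameter names matter here."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(768, 768)
        self.LayerNorm = nn.LayerNorm(768)
        self.classifier = nn.Linear(768, 9)

model = TinyEncoder()
no_decay = ['bias', 'LayerNorm.weight']
grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = AdamW(grouped_parameters, lr=5e-5, eps=1e-8)

The per-group 'weight_decay' values override the constructor default, so only the first group is actually decayed.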
/callback/optimizater/ralamb.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 | class Ralamb(Optimizer):
6 | '''
7 | RAdam + LARS
8 | Example:
9 | >>> model = ResNet()
10 | >>> optimizer = Ralamb(model.parameters(), lr=0.001)
11 | '''
12 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
13 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
14 | self.buffer = [[None, None, None] for ind in range(10)]
15 | super(Ralamb, self).__init__(params, defaults)
16 |
17 | def __setstate__(self, state):
18 | super(Ralamb, self).__setstate__(state)
19 |
20 | def step(self, closure=None):
21 |
22 | loss = None
23 | if closure is not None:
24 | loss = closure()
25 |
26 | for group in self.param_groups:
27 |
28 | for p in group['params']:
29 | if p.grad is None:
30 | continue
31 | grad = p.grad.data.float()
32 | if grad.is_sparse:
33 | raise RuntimeError('Ralamb does not support sparse gradients')
34 |
35 | p_data_fp32 = p.data.float()
36 |
37 | state = self.state[p]
38 |
39 | if len(state) == 0:
40 | state['step'] = 0
41 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
42 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
43 | else:
44 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
45 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
46 |
47 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
48 | beta1, beta2 = group['betas']
49 |
50 | # Decay the first and second moment running average coefficient
51 | # m_t
52 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
53 | # v_t
54 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
55 |
56 | state['step'] += 1
57 | buffered = self.buffer[int(state['step'] % 10)]
58 |
59 | if state['step'] == buffered[0]:
60 | N_sma, radam_step_size = buffered[1], buffered[2]
61 | else:
62 | buffered[0] = state['step']
63 | beta2_t = beta2 ** state['step']
64 | N_sma_max = 2 / (1 - beta2) - 1
65 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
66 | buffered[1] = N_sma
67 |
68 | # more conservative since it's an approximated value
69 | if N_sma >= 5:
70 | radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
71 | else:
72 | radam_step_size = 1.0 / (1 - beta1 ** state['step'])
73 | buffered[2] = radam_step_size
74 |
75 | if group['weight_decay'] != 0:
76 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
77 |
78 | # more conservative since it's an approximated value
79 | radam_step = p_data_fp32.clone()
80 | if N_sma >= 5:
81 | denom = exp_avg_sq.sqrt().add_(group['eps'])
82 | radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom)
83 | else:
84 | radam_step.add_(-radam_step_size * group['lr'], exp_avg)
85 |
86 | radam_norm = radam_step.pow(2).sum().sqrt()
87 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
88 | if weight_norm == 0 or radam_norm == 0:
89 | trust_ratio = 1
90 | else:
91 | trust_ratio = weight_norm / radam_norm
92 |
93 | state['weight_norm'] = weight_norm
94 | state['adam_norm'] = radam_norm
95 | state['trust_ratio'] = trust_ratio
96 |
97 | if N_sma >= 5:
98 | p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom)
99 | else:
100 | p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg)
101 |
102 | p.data.copy_(p_data_fp32)
103 |
104 | return loss
--------------------------------------------------------------------------------
/callback/optimizater/lookahead.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 | from collections import defaultdict
4 |
5 | class Lookahead(Optimizer):
6 | '''
7 | PyTorch implementation of the lookahead wrapper.
8 | Lookahead Optimizer: https://arxiv.org/abs/1907.08610
9 |
10 | We found that evaluation performance is typically better using the slow weights.
11 | This can be done in PyTorch with something like this in your eval loop:
12 | if args.lookahead:
13 | optimizer._backup_and_load_cache()
14 | val_loss = eval_func(model)
15 | optimizer._clear_and_load_backup()
16 | '''
17 | def __init__(self, optimizer,alpha=0.5, k=6,pullback_momentum="none"):
18 | '''
19 | :param optimizer: inner optimizer
20 | :param alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
21 | :param k (int): number of lookahead steps
22 | :param pullback_momentum (str): change to inner optimizer momentum on interpolation update
23 | '''
24 | if not 0.0 <= alpha <= 1.0:
25 | raise ValueError(f'Invalid slow update rate: {alpha}')
26 | if not 1 <= k:
27 | raise ValueError(f'Invalid lookahead steps: {k}')
28 | self.optimizer = optimizer
29 | self.param_groups = self.optimizer.param_groups
30 | self.alpha = alpha
31 | self.k = k
32 | self.step_counter = 0
33 | assert pullback_momentum in ["reset", "pullback", "none"]
34 | self.pullback_momentum = pullback_momentum
35 | self.state = defaultdict(dict)
36 |
37 | # Cache the current optimizer parameters
38 | for group in self.optimizer.param_groups:
39 | for p in group['params']:
40 | param_state = self.state[p]
41 | param_state['cached_params'] = torch.zeros_like(p.data)
42 | param_state['cached_params'].copy_(p.data)
43 |
44 | def __getstate__(self):
45 | return {
46 | 'state': self.state,
47 | 'optimizer': self.optimizer,
48 | 'alpha': self.alpha,
49 | 'step_counter': self.step_counter,
50 | 'k':self.k,
51 | 'pullback_momentum': self.pullback_momentum
52 | }
53 |
54 | def zero_grad(self):
55 | self.optimizer.zero_grad()
56 |
57 | def state_dict(self):
58 | return self.optimizer.state_dict()
59 |
60 | def load_state_dict(self, state_dict):
61 | self.optimizer.load_state_dict(state_dict)
62 |
63 | def _backup_and_load_cache(self):
64 | """Useful for performing evaluation on the slow weights (which typically generalize better)
65 | """
66 | for group in self.optimizer.param_groups:
67 | for p in group['params']:
68 | param_state = self.state[p]
69 | param_state['backup_params'] = torch.zeros_like(p.data)
70 | param_state['backup_params'].copy_(p.data)
71 | p.data.copy_(param_state['cached_params'])
72 |
73 | def _clear_and_load_backup(self):
74 | for group in self.optimizer.param_groups:
75 | for p in group['params']:
76 | param_state = self.state[p]
77 | p.data.copy_(param_state['backup_params'])
78 | del param_state['backup_params']
79 |
80 | def step(self, closure=None):
81 | """Performs a single Lookahead optimization step.
82 | Arguments:
83 | closure (callable, optional): A closure that reevaluates the model
84 | and returns the loss.
85 | """
86 | loss = self.optimizer.step(closure)
87 | self.step_counter += 1
88 |
89 | if self.step_counter >= self.k:
90 | self.step_counter = 0
91 | # Lookahead and cache the current optimizer parameters
92 | for group in self.optimizer.param_groups:
93 | for p in group['params']:
94 | param_state = self.state[p]
95 | p.data.mul_(self.alpha).add_(1.0 - self.alpha, param_state['cached_params']) # crucial line
96 | param_state['cached_params'].copy_(p.data)
97 | if self.pullback_momentum == "pullback":
98 | internal_momentum = self.optimizer.state[p]["momentum_buffer"]
99 | self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_(
100 | 1.0 - self.alpha, param_state["cached_mom"])
101 | param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
102 | elif self.pullback_momentum == "reset":
103 | self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)
104 |
105 | return loss
106 |
--------------------------------------------------------------------------------
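A minimal sketch of wrapping an inner optimizer with the Lookahead class above; the model, data, and inner SGD are placeholders, and the slow-weight evaluation at the end follows the pattern quoted in the class docstring.

import torch
import torch.nn as nn
from callback.optimizater.lookahead import Lookahead

model = nn.Linear(10, 2)                              # placeholder model
inner = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(inner, alpha=0.5, k=6)          # slow weights updated every k fast steps

for _ in range(12):
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                                  # every 6th call interpolates the fast weights toward the cached slow weights

# Evaluate on the slow weights, then restore the fast weights.
optimizer._backup_and_load_cache()
# ... run validation here ...
optimizer._clear_and_load_backup()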
/callback/optimizater/lamb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.optimizer import Optimizer
3 |
4 |
5 | class Lamb(Optimizer):
6 | r"""Implements Lamb algorithm.
7 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
8 | Arguments:
9 | params (iterable): iterable of parameters to optimize or dicts defining
10 | parameter groups
11 | lr (float, optional): learning rate (default: 1e-3)
12 | betas (Tuple[float, float], optional): coefficients used for computing
13 | running averages of gradient and its square (default: (0.9, 0.999))
14 | eps (float, optional): term added to the denominator to improve
15 | numerical stability (default: 1e-8)
16 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
17 | adam (bool, optional): always use trust ratio = 1, which turns this into
18 | Adam. Useful for comparison purposes.
19 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
20 | https://arxiv.org/abs/1904.00962
21 | Example:
22 | >>> model = ResNet()
23 | >>> optimizer = Lamb(model.parameters(), lr=1e-2, weight_decay=1e-5)
24 | """
25 |
26 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
27 | weight_decay=0, adam=False):
28 | if not 0.0 <= lr:
29 | raise ValueError("Invalid learning rate: {}".format(lr))
30 | if not 0.0 <= eps:
31 | raise ValueError("Invalid epsilon value: {}".format(eps))
32 | if not 0.0 <= betas[0] < 1.0:
33 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
34 | if not 0.0 <= betas[1] < 1.0:
35 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
36 | defaults = dict(lr=lr, betas=betas, eps=eps,
37 | weight_decay=weight_decay)
38 | self.adam = adam
39 | super(Lamb, self).__init__(params, defaults)
40 |
41 | def step(self, closure=None):
42 | """Performs a single optimization step.
43 | Arguments:
44 | closure (callable, optional): A closure that reevaluates the model
45 | and returns the loss.
46 | """
47 | loss = None
48 | if closure is not None:
49 | loss = closure()
50 |
51 | for group in self.param_groups:
52 | for p in group['params']:
53 | if p.grad is None:
54 | continue
55 | grad = p.grad.data
56 | if grad.is_sparse:
57 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
58 |
59 | state = self.state[p]
60 |
61 | # State initialization
62 | if len(state) == 0:
63 | state['step'] = 0
64 | # Exponential moving average of gradient values
65 | state['exp_avg'] = torch.zeros_like(p.data)
66 | # Exponential moving average of squared gradient values
67 | state['exp_avg_sq'] = torch.zeros_like(p.data)
68 |
69 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
70 | beta1, beta2 = group['betas']
71 |
72 | state['step'] += 1
73 |
74 | # Decay the first and second moment running average coefficient
75 | # m_t
76 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
77 | # v_t
78 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
79 |
80 | # Paper v3 does not use debiasing.
81 | # bias_correction1 = 1 - beta1 ** state['step']
82 | # bias_correction2 = 1 - beta2 ** state['step']
83 | # Apply bias to lr to avoid broadcast.
84 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
85 |
86 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
87 |
88 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
89 | if group['weight_decay'] != 0:
90 | adam_step.add_(group['weight_decay'], p.data)
91 |
92 | adam_norm = adam_step.pow(2).sum().sqrt()
93 | if weight_norm == 0 or adam_norm == 0:
94 | trust_ratio = 1
95 | else:
96 | trust_ratio = weight_norm / adam_norm
97 | state['weight_norm'] = weight_norm
98 | state['adam_norm'] = adam_norm
99 | state['trust_ratio'] = trust_ratio
100 | if self.adam:
101 | trust_ratio = 1
102 |
103 | p.data.add_(-step_size * trust_ratio, adam_step)
104 |
105 | return loss
--------------------------------------------------------------------------------
/callback/optimizater/ralars.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer
4 |
5 |
6 | class RaLars(Optimizer):
7 | """Implements the RAdam optimizer from https://arxiv.org/pdf/1908.03265.pdf
8 | with optional Layer-wise adaptive Scaling from https://arxiv.org/pdf/1708.03888.pdf
9 |
10 | Args:
11 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups
12 | lr (float, optional): learning rate
13 | betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999))
14 | eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
15 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
16 | scale_clip (float, optional): the maximal upper bound for the scale factor of LARS
17 | Example:
18 | >>> model = ResNet()
19 | >>> optimizer = RaLars(model.parameters(), lr=0.001)
20 | """
21 |
22 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
23 | scale_clip=None):
24 | if not 0.0 <= lr:
25 | raise ValueError("Invalid learning rate: {}".format(lr))
26 | if not 0.0 <= eps:
27 | raise ValueError("Invalid epsilon value: {}".format(eps))
28 | if not 0.0 <= betas[0] < 1.0:
29 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
30 | if not 0.0 <= betas[1] < 1.0:
31 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
32 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
33 | super(RaLars, self).__init__(params, defaults)
34 | # LARS arguments
35 | self.scale_clip = scale_clip
36 | if self.scale_clip is None:
37 | self.scale_clip = (0, 10)
38 |
39 | def step(self, closure=None):
40 | """Performs a single optimization step.
41 | Arguments:
42 | closure (callable, optional): A closure that reevaluates the model
43 | and returns the loss.
44 | """
45 | loss = None
46 | if closure is not None:
47 | loss = closure()
48 |
49 | for group in self.param_groups:
50 |
51 | # Get group-shared variables
52 | beta1, beta2 = group['betas']
53 | sma_inf = group.get('sma_inf')
54 | # Compute max length of SMA on first step
55 | if not isinstance(sma_inf, float):
56 | group['sma_inf'] = 2 / (1 - beta2) - 1
57 | sma_inf = group.get('sma_inf')
58 |
59 | for p in group['params']:
60 | if p.grad is None:
61 | continue
62 | grad = p.grad.data
63 | if grad.is_sparse:
64 | raise RuntimeError('RAdam does not support sparse gradients')
65 |
66 | state = self.state[p]
67 |
68 | # State initialization
69 | if len(state) == 0:
70 | state['step'] = 0
71 | # Exponential moving average of gradient values
72 | state['exp_avg'] = torch.zeros_like(p.data)
73 | # Exponential moving average of squared gradient values
74 | state['exp_avg_sq'] = torch.zeros_like(p.data)
75 |
76 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
77 |
78 | state['step'] += 1
79 |
80 | # Decay the first and second moment running average coefficient
81 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
82 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
83 |
84 | # Bias correction
85 | bias_correction1 = 1 - beta1 ** state['step']
86 | bias_correction2 = 1 - beta2 ** state['step']
87 |
88 | # Compute length of SMA
89 | sma_t = sma_inf - 2 * state['step'] * (1 - bias_correction2) / bias_correction2
90 |
91 | update = torch.zeros_like(p.data)
92 | if sma_t > 4:
93 | # Variance rectification term
94 | r_t = math.sqrt((sma_t - 4) * (sma_t - 2) * sma_inf / ((sma_inf - 4) * (sma_inf - 2) * sma_t))
95 | # Adaptive momentum
96 | update.addcdiv_(r_t, exp_avg / bias_correction1,
97 | (exp_avg_sq / bias_correction2).sqrt().add_(group['eps']))
98 | else:
99 | # Unadapted momentum
100 | update.add_(exp_avg / bias_correction1)
101 |
102 | # Weight decay
103 | if group['weight_decay'] != 0:
104 | update.add_(group['weight_decay'], p.data)
105 |
106 | # LARS
107 | p_norm = p.data.pow(2).sum().sqrt()
108 | update_norm = update.pow(2).sum().sqrt()
109 | phi_p = p_norm.clamp(*self.scale_clip)
110 | # Compute the local LR
111 | if phi_p == 0 or update_norm == 0:
112 | local_lr = 1
113 | else:
114 | local_lr = phi_p / update_norm
115 |
116 | state['local_lr'] = local_lr
117 |
118 | p.data.add_(-group['lr'] * local_lr, update)
119 |
120 | return loss
121 |
--------------------------------------------------------------------------------
/callback/optimizater/adabound.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch.optim.optimizer import Optimizer
4 |
5 | class AdaBound(Optimizer):
6 | """Implements AdaBound algorithm.
7 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
8 | Arguments:
9 | params (iterable): iterable of parameters to optimize or dicts defining
10 | parameter groups
11 | lr (float, optional): Adam learning rate (default: 1e-3)
12 | betas (Tuple[float, float], optional): coefficients used for computing
13 | running averages of gradient and its square (default: (0.9, 0.999))
14 | final_lr (float, optional): final (SGD) learning rate (default: 0.1)
15 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
16 | eps (float, optional): term added to the denominator to improve
17 | numerical stability (default: 1e-8)
18 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
19 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
20 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
21 | https://openreview.net/forum?id=Bkg3g2R9FX
22 | Example:
23 | >>> model = LSTM()
24 | >>> optimizer = AdaBound(model.parameters())
25 | """
26 |
27 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
28 | eps=1e-8, weight_decay=0, amsbound=False):
29 | if not 0.0 <= lr:
30 | raise ValueError("Invalid learning rate: {}".format(lr))
31 | if not 0.0 <= eps:
32 | raise ValueError("Invalid epsilon value: {}".format(eps))
33 | if not 0.0 <= betas[0] < 1.0:
34 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
35 | if not 0.0 <= betas[1] < 1.0:
36 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
37 | if not 0.0 <= final_lr:
38 | raise ValueError("Invalid final learning rate: {}".format(final_lr))
39 | if not 0.0 <= gamma < 1.0:
40 | raise ValueError("Invalid gamma parameter: {}".format(gamma))
41 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
42 | weight_decay=weight_decay, amsbound=amsbound)
43 | super(AdaBound, self).__init__(params, defaults)
44 |
45 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
46 |
47 | def __setstate__(self, state):
48 | super(AdaBound, self).__setstate__(state)
49 | for group in self.param_groups:
50 | group.setdefault('amsbound', False)
51 |
52 | def step(self, closure=None):
53 | """Performs a single optimization step.
54 | Arguments:
55 | closure (callable, optional): A closure that reevaluates the model
56 | and returns the loss.
57 | """
58 | loss = None
59 | if closure is not None:
60 | loss = closure()
61 | for group, base_lr in zip(self.param_groups, self.base_lrs):
62 | for p in group['params']:
63 | if p.grad is None:
64 | continue
65 | grad = p.grad.data
66 | if grad.is_sparse:
67 | raise RuntimeError(
68 | 'Adam does not support sparse gradients, please consider SparseAdam instead')
69 | amsbound = group['amsbound']
70 | state = self.state[p]
71 | # State initialization
72 | if len(state) == 0:
73 | state['step'] = 0
74 | # Exponential moving average of gradient values
75 | state['exp_avg'] = torch.zeros_like(p.data)
76 | # Exponential moving average of squared gradient values
77 | state['exp_avg_sq'] = torch.zeros_like(p.data)
78 | if amsbound:
79 | # Maintains max of all exp. moving avg. of sq. grad. values
80 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
81 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
82 | if amsbound:
83 | max_exp_avg_sq = state['max_exp_avg_sq']
84 | beta1, beta2 = group['betas']
85 | state['step'] += 1
86 | if group['weight_decay'] != 0:
87 | grad = grad.add(group['weight_decay'], p.data)
88 | # Decay the first and second moment running average coefficient
89 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
90 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
91 | if amsbound:
92 | # Maintains the maximum of all 2nd moment running avg. till now
93 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
94 | # Use the max. for normalizing running avg. of gradient
95 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
96 | else:
97 | denom = exp_avg_sq.sqrt().add_(group['eps'])
98 |
99 | bias_correction1 = 1 - beta1 ** state['step']
100 | bias_correction2 = 1 - beta2 ** state['step']
101 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
102 |
103 | # Applies bounds on actual learning rate
104 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
105 | final_lr = group['final_lr'] * group['lr'] / base_lr
106 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
107 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
108 | step_size = torch.full_like(denom, step_size)
109 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
110 | p.data.add_(-step_size)
111 | return loss
--------------------------------------------------------------------------------
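The defining piece of AdaBound above is the pair of bounds that clip the per-parameter step and squeeze it toward the SGD learning rate final_lr as training proceeds (in step() the bound is additionally rescaled by group['lr'] / base_lr so that lr schedulers carry over). A tiny standalone sketch of how the window closes with the defaults final_lr=0.1, gamma=1e-3:

# AdaBound's clipping window with final_lr=0.1, gamma=1e-3, ignoring the lr/base_lr rescaling.
final_lr, gamma = 0.1, 1e-3
for step in (1, 100, 1000, 10000):
    lower = final_lr * (1 - 1 / (gamma * step + 1))
    upper = final_lr * (1 + 1 / (gamma * step))
    print(step, round(lower, 4), round(upper, 4))
# The window starts essentially unbounded above (upper is about 100 at step 1)
# and converges to final_lr = 0.1 from both sides.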
/models/bert_for_ner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .layers.crf import CRF
5 | from transformers import BertModel,BertPreTrainedModel
6 | from .layers.linears import PoolerEndLogits, PoolerStartLogits
7 | from torch.nn import CrossEntropyLoss
8 | from losses.focal_loss import FocalLoss
9 | from losses.label_smoothing import LabelSmoothingCrossEntropy
10 |
11 | class BertSoftmaxForNer(BertPreTrainedModel):
12 | def __init__(self, config):
13 | super(BertSoftmaxForNer, self).__init__(config)
14 | self.num_labels = config.num_labels
15 | self.bert = BertModel(config)
16 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
17 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
18 | self.loss_type = config.loss_type
19 | self.init_weights()
20 |
21 | def forward(self, input_ids, attention_mask=None, token_type_ids=None,labels=None):
22 | outputs = self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
23 | sequence_output = outputs[0]
24 | sequence_output = self.dropout(sequence_output)
25 | logits = self.classifier(sequence_output)
26 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
27 | if labels is not None:
28 | assert self.loss_type in ['lsr', 'focal', 'ce']
29 | if self.loss_type == 'lsr':
30 | loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
31 | elif self.loss_type == 'focal':
32 | loss_fct = FocalLoss(ignore_index=0)
33 | else:
34 | loss_fct = CrossEntropyLoss(ignore_index=0)
35 | # Only keep active parts of the loss
36 | if attention_mask is not None:
37 | active_loss = attention_mask.view(-1) == 1
38 | active_logits = logits.view(-1, self.num_labels)[active_loss]
39 | active_labels = labels.view(-1)[active_loss]
40 | loss = loss_fct(active_logits, active_labels)
41 | else:
42 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
43 | outputs = (loss,) + outputs
44 | return outputs # (loss), scores, (hidden_states), (attentions)
45 |
46 | class BertCrfForNer(BertPreTrainedModel):
47 | def __init__(self, config):
48 | super(BertCrfForNer, self).__init__(config)
49 | self.bert = BertModel(config)
50 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
51 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
52 | self.crf = CRF(num_tags=config.num_labels, batch_first=True)
53 | self.init_weights()
54 |
55 | def forward(self, input_ids, token_type_ids=None, attention_mask=None,labels=None):
56 | outputs =self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
57 | sequence_output = outputs[0]
58 | sequence_output = self.dropout(sequence_output)
59 | logits = self.classifier(sequence_output)
60 | outputs = (logits,)
61 | if labels is not None:
62 | loss = self.crf(emissions = logits, tags=labels, mask=attention_mask)
63 | outputs =(-1*loss,)+outputs
64 | return outputs # (loss), scores
65 |
66 | class BertSpanForNer(BertPreTrainedModel):
67 | def __init__(self, config,):
68 | super(BertSpanForNer, self).__init__(config)
69 | self.soft_label = config.soft_label
70 | self.num_labels = config.num_labels
71 | self.loss_type = config.loss_type
72 | self.bert = BertModel(config)
73 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
74 | self.start_fc = PoolerStartLogits(config.hidden_size, self.num_labels)
75 | if self.soft_label:
76 | self.end_fc = PoolerEndLogits(config.hidden_size + self.num_labels, self.num_labels)
77 | else:
78 | self.end_fc = PoolerEndLogits(config.hidden_size + 1, self.num_labels)
79 | self.init_weights()
80 |
81 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,end_positions=None):
82 | outputs = self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
83 | sequence_output = outputs[0]
84 | sequence_output = self.dropout(sequence_output)
85 | start_logits = self.start_fc(sequence_output)
86 | if start_positions is not None and self.training:
87 | if self.soft_label:
88 | batch_size = input_ids.size(0)
89 | seq_len = input_ids.size(1)
90 | label_logits = torch.FloatTensor(batch_size, seq_len, self.num_labels)
91 | label_logits.zero_()
92 | label_logits = label_logits.to(input_ids.device)
93 | label_logits.scatter_(2, start_positions.unsqueeze(2), 1)
94 | else:
95 | label_logits = start_positions.unsqueeze(2).float()
96 | else:
97 | label_logits = F.softmax(start_logits, -1)
98 | if not self.soft_label:
99 | label_logits = torch.argmax(label_logits, -1).unsqueeze(2).float()
100 | end_logits = self.end_fc(sequence_output, label_logits)
101 | outputs = (start_logits, end_logits,) + outputs[2:]
102 |
103 | if start_positions is not None and end_positions is not None:
104 | assert self.loss_type in ['lsr', 'focal', 'ce']
105 | if self.loss_type =='lsr':
106 | loss_fct = LabelSmoothingCrossEntropy()
107 | elif self.loss_type == 'focal':
108 | loss_fct = FocalLoss()
109 | else:
110 | loss_fct = CrossEntropyLoss()
111 | start_logits = start_logits.view(-1, self.num_labels)
112 | end_logits = end_logits.view(-1, self.num_labels)
113 | active_loss = attention_mask.view(-1) == 1
114 | active_start_logits = start_logits[active_loss]
115 | active_end_logits = end_logits[active_loss]
116 |
117 | active_start_labels = start_positions.view(-1)[active_loss]
118 | active_end_labels = end_positions.view(-1)[active_loss]
119 |
120 | start_loss = loss_fct(active_start_logits, active_start_labels)
121 | end_loss = loss_fct(active_end_logits, active_end_labels)
122 | total_loss = (start_loss + end_loss) / 2
123 | outputs = (total_loss,) + outputs
124 | return outputs
125 |
126 |
--------------------------------------------------------------------------------
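A minimal forward-pass sketch for BertSoftmaxForNer above. The model reads the extra config attribute loss_type (the span model also reads soft_label), which the training scripts are expected to supply via --loss_type; here a randomly initialised BertConfig stands in for from_pretrained purely to show the input/output shapes, so the numbers are arbitrary.

import torch
from transformers import BertConfig
from models.bert_for_ner import BertSoftmaxForNer

config = BertConfig(num_labels=9)     # 9 is an arbitrary tag-set size for illustration
config.loss_type = 'ce'               # extra attribute read in forward(); set from --loss_type in the run scripts
model = BertSoftmaxForNer(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
labels = torch.randint(0, config.num_labels, (2, 16))
with torch.no_grad():
    loss, logits = model(input_ids, attention_mask=attention_mask, labels=labels)[:2]
# logits: (batch, seq_len, num_labels); the loss only counts positions where attention_mask == 1.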
/tools/finetuning_argparse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_argparse():
4 | parser = argparse.ArgumentParser()
5 | # Required parameters
6 | parser.add_argument("--task_name", default=None, type=str, required=True,
7 | help="The name of the task to train selected in the list: ")
8 | parser.add_argument("--data_dir", default=None, type=str, required=True,
9 | help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.", )
10 | parser.add_argument("--model_type", default=None, type=str, required=True,
11 | help="Model type selected in the list: ")
12 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
13 | help="Path to pre-trained model or shortcut name selected in the list: " )
14 | parser.add_argument("--output_dir", default=None, type=str, required=True,
15 | help="The output directory where the model predictions and checkpoints will be written.", )
16 |
17 | # Other parameters
18 | parser.add_argument('--markup', default='bios', type=str,
19 | choices=['bios', 'bio'])
20 | parser.add_argument('--loss_type', default='ce', type=str,
21 | choices=['lsr', 'focal', 'ce'])
22 | parser.add_argument("--config_name", default="", type=str,
23 | help="Pretrained config name or path if not the same as model_name")
24 | parser.add_argument("--tokenizer_name", default="", type=str,
25 | help="Pretrained tokenizer name or path if not the same as model_name", )
26 | parser.add_argument("--cache_dir", default="", type=str,
27 | help="Where do you want to store the pre-trained models downloaded from s3", )
28 | parser.add_argument("--train_max_seq_length", default=128, type=int,
29 | help="The maximum total input sequence length after tokenization. Sequences longer "
30 | "than this will be truncated, sequences shorter will be padded.", )
31 | parser.add_argument("--eval_max_seq_length", default=512, type=int,
32 | help="The maximum total input sequence length after tokenization. Sequences longer "
33 | "than this will be truncated, sequences shorter will be padded.", )
34 | parser.add_argument("--do_train", action="store_true",
35 | help="Whether to run training.")
36 | parser.add_argument("--do_eval", action="store_true",
37 | help="Whether to run eval on the dev set.")
38 | parser.add_argument("--do_predict", action="store_true",
39 | help="Whether to run predictions on the test set.")
40 | parser.add_argument("--evaluate_during_training", action="store_true",
41 | help="Whether to run evaluation during training at each logging step.", )
42 | parser.add_argument("--do_lower_case", action="store_true",
43 | help="Set this flag if you are using an uncased model.")
44 | # adversarial training
45 | parser.add_argument("--do_adv", action="store_true",
46 | help="Whether to adversarial training.")
47 | parser.add_argument('--adv_epsilon', default=1.0, type=float,
48 | help="Epsilon for adversarial.")
49 | parser.add_argument('--adv_name', default='word_embeddings', type=str,
50 | help="name for adversarial layer.")
51 |
52 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
53 | help="Batch size per GPU/CPU for training.")
54 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
55 | help="Batch size per GPU/CPU for evaluation.")
56 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
57 | help="Number of updates steps to accumulate before performing a backward/update pass.", )
58 | parser.add_argument("--learning_rate", default=5e-5, type=float,
59 | help="The initial learning rate for Adam.")
60 | parser.add_argument("--crf_learning_rate", default=5e-5, type=float,
61 | help="The initial learning rate for crf and linear layer.")
62 | parser.add_argument("--weight_decay", default=0.01, type=float,
63 | help="Weight decay if we apply some.")
64 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
65 | help="Epsilon for Adam optimizer.")
66 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
67 | help="Max gradient norm.")
68 | parser.add_argument("--num_train_epochs", default=3.0, type=float,
69 | help="Total number of training epochs to perform.")
70 | parser.add_argument("--max_steps", default=-1, type=int,
71 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.", )
72 |
73 | parser.add_argument("--warmup_proportion", default=0.1, type=float,
74 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training.")
75 | parser.add_argument("--logging_steps", type=int, default=50,
76 | help="Log every X updates steps.")
77 | parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
78 | parser.add_argument("--eval_all_checkpoints", action="store_true",
79 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", )
80 | parser.add_argument("--predict_checkpoints",type=int, default=0,
81 | help="predict checkpoints starting with the same prefix as model_name ending and ending with step number")
82 | parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
83 | parser.add_argument("--overwrite_output_dir", action="store_true",
84 | help="Overwrite the content of the output directory")
85 | parser.add_argument("--overwrite_cache", action="store_true",
86 | help="Overwrite the cached training and evaluation sets")
87 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
88 | parser.add_argument("--fp16", action="store_true",
89 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", )
90 | parser.add_argument("--fp16_opt_level", type=str, default="O1",
91 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
92 | "See details at https://nvidia.github.io/apex/amp.html", )
93 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
94 | parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
95 | parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
96 | return parser
--------------------------------------------------------------------------------
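A sketch of parsing the same flags that the scripts/run_ner_*.sh scripts pass on the command line; the task name and the paths below are placeholders and are not guaranteed to exist in a checkout.

from tools.finetuning_argparse import get_argparse

args = get_argparse().parse_args([
    '--task_name', 'cner',
    '--data_dir', 'datasets/cner/',
    '--model_type', 'bert',
    '--model_name_or_path', 'prev_trained_model/bert-base-chinese',
    '--output_dir', 'outputs/cner_output/bert/',
    '--do_train', '--do_eval',
    '--loss_type', 'ce',
    '--learning_rate', '3e-5',
    '--num_train_epochs', '4',
])
print(args.task_name, args.markup, args.per_gpu_train_batch_size)   # -> cner bios 8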
/processors/utils_ner.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import torch
4 | from transformers import BertTokenizer
5 |
6 | class DataProcessor(object):
7 | """Base class for data converters for sequence classification data sets."""
8 |
9 | def get_train_examples(self, data_dir):
10 | """Gets a collection of `InputExample`s for the train set."""
11 | raise NotImplementedError()
12 |
13 | def get_dev_examples(self, data_dir):
14 | """Gets a collection of `InputExample`s for the dev set."""
15 | raise NotImplementedError()
16 |
17 | def get_labels(self):
18 | """Gets the list of labels for this data set."""
19 | raise NotImplementedError()
20 |
21 | @classmethod
22 | def _read_tsv(cls, input_file, quotechar=None):
23 | """Reads a tab separated value file."""
24 | with open(input_file, "r", encoding="utf-8-sig") as f:
25 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
26 | lines = []
27 | for line in reader:
28 | lines.append(line)
29 | return lines
30 |
31 | @classmethod
32 | def _read_text(cls, input_file):
33 | lines = []
34 | with open(input_file,'r') as f:
35 | words = []
36 | labels = []
37 | for line in f:
38 | if line.startswith("-DOCSTART-") or line == "" or line == "\n":
39 | if words:
40 | lines.append({"words":words,"labels":labels})
41 | words = []
42 | labels = []
43 | else:
44 | splits = line.split(" ")
45 | words.append(splits[0])
46 | if len(splits) > 1:
47 | labels.append(splits[-1].replace("\n", ""))
48 | else:
49 | # Examples could have no label for mode = "test"
50 | labels.append("O")
51 | if words:
52 | lines.append({"words":words,"labels":labels})
53 | return lines
54 |
55 | @classmethod
56 | def _read_json(cls, input_file):
57 | lines = []
58 | with open(input_file,'r') as f:
59 | for line in f:
60 | line = json.loads(line.strip())
61 | text = line['text']
62 | label_entities = line.get('label',None)
63 | words = list(text)
64 | labels = ['O'] * len(words)
65 | if label_entities is not None:
66 | for key,value in label_entities.items():
67 | for sub_name,sub_index in value.items():
68 | for start_index,end_index in sub_index:
69 | assert ''.join(words[start_index:end_index+1]) == sub_name
70 | if start_index == end_index:
71 | labels[start_index] = 'S-'+key
72 | else:
73 | labels[start_index] = 'B-'+key
74 | labels[start_index+1:end_index+1] = ['I-'+key]*(len(sub_name)-1)
75 | lines.append({"words": words, "labels": labels})
76 | return lines
77 |
78 | def get_entity_bios(seq,id2label):
79 | """Gets entities from sequence.
80 | note: BIOS
81 | Args:
82 | seq (list): sequence of labels.
83 | Returns:
84 | list: list of (chunk_type, chunk_start, chunk_end).
85 | Example:
86 | # >>> seq = ['B-PER', 'I-PER', 'O', 'S-LOC']
87 | # >>> get_entity_bios(seq)
88 | [['PER', 0,1], ['LOC', 3, 3]]
89 | """
90 | chunks = []
91 | chunk = [-1, -1, -1]
92 | for indx, tag in enumerate(seq):
93 | if not isinstance(tag, str):
94 | tag = id2label[tag]
95 | if tag.startswith("S-"):
96 | if chunk[2] != -1:
97 | chunks.append(chunk)
98 | chunk = [-1, -1, -1]
99 | chunk[1] = indx
100 | chunk[2] = indx
101 | chunk[0] = tag.split('-')[1]
102 | chunks.append(chunk)
103 | chunk = [-1, -1, -1]
104 | if tag.startswith("B-"):
105 | if chunk[2] != -1:
106 | chunks.append(chunk)
107 | chunk = [-1, -1, -1]
108 | chunk[1] = indx
109 | chunk[0] = tag.split('-')[1]
110 | elif tag.startswith('I-') and chunk[1] != -1:
111 | _type = tag.split('-')[1]
112 | if _type == chunk[0]:
113 | chunk[2] = indx
114 | if indx == len(seq) - 1:
115 | chunks.append(chunk)
116 | else:
117 | if chunk[2] != -1:
118 | chunks.append(chunk)
119 | chunk = [-1, -1, -1]
120 | return chunks
121 |
122 | def get_entity_bio(seq,id2label):
123 | """Gets entities from sequence.
124 | note: BIO
125 | Args:
126 | seq (list): sequence of labels.
127 | Returns:
128 | list: list of (chunk_type, chunk_start, chunk_end).
129 | Example:
130 | seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
131 | get_entity_bio(seq)
132 | #output
133 | [['PER', 0,1], ['LOC', 3, 3]]
134 | """
135 | chunks = []
136 | chunk = [-1, -1, -1]
137 | for indx, tag in enumerate(seq):
138 | if not isinstance(tag, str):
139 | tag = id2label[tag]
140 | if tag.startswith("B-"):
141 | if chunk[2] != -1:
142 | chunks.append(chunk)
143 | chunk = [-1, -1, -1]
144 | chunk[1] = indx
145 | chunk[0] = tag.split('-')[1]
146 | chunk[2] = indx
147 | if indx == len(seq) - 1:
148 | chunks.append(chunk)
149 | elif tag.startswith('I-') and chunk[1] != -1:
150 | _type = tag.split('-')[1]
151 | if _type == chunk[0]:
152 | chunk[2] = indx
153 |
154 | if indx == len(seq) - 1:
155 | chunks.append(chunk)
156 | else:
157 | if chunk[2] != -1:
158 | chunks.append(chunk)
159 | chunk = [-1, -1, -1]
160 | return chunks
161 |
162 | def get_entities(seq,id2label,markup='bios'):
163 | '''
164 | :param seq:
165 | :param id2label:
166 | :param markup:
167 | :return:
168 | '''
169 | assert markup in ['bio','bios']
170 | if markup =='bio':
171 | return get_entity_bio(seq,id2label)
172 | else:
173 | return get_entity_bios(seq,id2label)
174 |
175 | def bert_extract_item(start_logits, end_logits):
176 | S = []
177 | start_pred = torch.argmax(start_logits, -1).cpu().numpy()[0][1:-1]
178 | end_pred = torch.argmax(end_logits, -1).cpu().numpy()[0][1:-1]
179 | for i, s_l in enumerate(start_pred):
180 | if s_l == 0:
181 | continue
182 | for j, e_l in enumerate(end_pred[i:]):
183 | if s_l == e_l:
184 | S.append((s_l, i, i + j))
185 | break
186 | return S
187 |
--------------------------------------------------------------------------------
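Two small sketches for the decoding helpers above. get_entities works directly on string tags (id2label is only consulted for integer tags), and bert_extract_item pairs each non-zero start label with the first matching end label after stripping the [CLS]/[SEP] positions; the toy logits below are made up to produce a single span.

import torch
from processors.utils_ner import get_entities, bert_extract_item

# BIOS decoding on string tags.
seq = ['B-PER', 'I-PER', 'O', 'S-LOC']
print(get_entities(seq, id2label=None, markup='bios'))   # -> [['PER', 0, 1], ['LOC', 3, 3]]

# Span decoding from toy start/end logits for one sequence of length 6 (including [CLS] and [SEP]).
start_logits = torch.zeros(1, 6, 3)
end_logits = torch.zeros(1, 6, 3)
start_logits[0, 2, 1] = 5.0   # label 1 starts at position 2 of the full sequence
end_logits[0, 3, 1] = 5.0     # and ends at position 3
print(bert_extract_item(start_logits, end_logits))       # -> [(1, 1, 2)] once the [CLS]/[SEP] offset is stripped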
/callback/optimizater/adafactor.py:
--------------------------------------------------------------------------------
1 | import operator
2 | import torch
3 | from copy import copy
4 | import functools
5 | from math import sqrt
6 | from torch.optim.optimizer import Optimizer
7 |
8 |
9 | class AdaFactor(Optimizer):
10 | '''
11 | # Code below is an implementation of https://arxiv.org/pdf/1804.04235.pdf
12 | # inspired but modified from https://github.com/DeadAt0m/adafactor-pytorch
13 | Example:
14 | >>> model = LSTM()
15 | >>> optimizer = AdaFactor(model.parameters(),lr= lr)
16 | '''
17 |
18 | def __init__(self, params, lr=None, beta1=0.9, beta2=0.999, eps1=1e-30,
19 | eps2=1e-3, cliping_threshold=1, non_constant_decay=True,
20 | enable_factorization=True, ams_grad=True, weight_decay=0):
21 |
22 | enable_momentum = beta1 != 0
23 | if non_constant_decay:
24 | ams_grad = False
25 |
26 | defaults = dict(lr=lr, beta1=beta1, beta2=beta2, eps1=eps1,
27 | eps2=eps2, cliping_threshold=cliping_threshold,
28 | weight_decay=weight_decay, ams_grad=ams_grad,
29 | enable_factorization=enable_factorization,
30 | enable_momentum=enable_momentum,
31 | non_constant_decay=non_constant_decay)
32 |
33 | super(AdaFactor, self).__init__(params, defaults)
34 |
35 | def __setstate__(self, state):
36 | super(AdaFactor, self).__setstate__(state)
37 |
38 | def _experimental_reshape(self, shape):
39 | temp_shape = shape[2:]
40 | if len(temp_shape) == 1:
41 | new_shape = (shape[0], shape[1] * shape[2])
42 | else:
43 | tmp_div = len(temp_shape) // 2 + len(temp_shape) % 2
44 | new_shape = (shape[0] * functools.reduce(operator.mul,
45 | temp_shape[tmp_div:], 1),
46 | shape[1] * functools.reduce(operator.mul,
47 | temp_shape[:tmp_div], 1))
48 | return new_shape, copy(shape)
49 |
50 | def _check_shape(self, shape):
51 | '''
52 | output1 - True - algorithm for matrix, False - vector;
53 | output2 - need reshape
54 | '''
55 | if len(shape) > 2:
56 | return True, True
57 | elif len(shape) == 2 and (shape[0] == 1 or shape[1] == 1):
58 | return False, False
59 | elif len(shape) == 2:
60 | return True, False
61 | else:
62 | return False, False
63 |
64 | def _rms(self, x):
65 | return sqrt(torch.mean(x.pow(2)))
66 |
67 | def step(self, closure=None):
68 | loss = None
69 | if closure is not None:
70 | loss = closure()
71 | for group in self.param_groups:
72 | for p in group['params']:
73 | if p.grad is None:
74 | continue
75 | grad = p.grad.data
76 |
77 | if grad.is_sparse:
78 | raise RuntimeError('Adam does not support sparse '
79 | 'gradients, use SparseAdam instead')
80 |
81 | is_matrix, is_need_reshape = self._check_shape(grad.size())
82 | new_shape = p.data.size()
83 | if is_need_reshape and group['enable_factorization']:
84 | new_shape, old_shape = \
85 | self._experimental_reshape(p.data.size())
86 | grad = grad.view(new_shape)
87 |
88 | state = self.state[p]
89 | if len(state) == 0:
90 | state['step'] = 0
91 | if group['enable_momentum']:
92 | state['exp_avg'] = torch.zeros(new_shape,
93 | dtype=torch.float32,
94 | device=p.grad.device)
95 |
96 | if is_matrix and group['enable_factorization']:
97 | state['exp_avg_sq_R'] = \
98 | torch.zeros((1, new_shape[1]),
99 | dtype=torch.float32,
100 | device=p.grad.device)
101 | state['exp_avg_sq_C'] = \
102 | torch.zeros((new_shape[0], 1),
103 | dtype=torch.float32,
104 | device=p.grad.device)
105 | else:
106 | state['exp_avg_sq'] = torch.zeros(new_shape,
107 | dtype=torch.float32,
108 | device=p.grad.device)
109 | if group['ams_grad']:
110 | state['exp_avg_sq_hat'] = \
111 | torch.zeros(new_shape, dtype=torch.float32,
112 | device=p.grad.device)
113 |
114 | if group['enable_momentum']:
115 | exp_avg = state['exp_avg']
116 |
117 | if is_matrix and group['enable_factorization']:
118 | exp_avg_sq_r = state['exp_avg_sq_R']
119 | exp_avg_sq_c = state['exp_avg_sq_C']
120 | else:
121 | exp_avg_sq = state['exp_avg_sq']
122 |
123 | if group['ams_grad']:
124 | exp_avg_sq_hat = state['exp_avg_sq_hat']
125 |
126 | state['step'] += 1
127 | lr_t = group['lr']
128 | lr_t *= max(group['eps2'], self._rms(p.data))
129 |
130 | if group['enable_momentum']:
131 | if group['non_constant_decay']:
132 | beta1_t = group['beta1'] * \
133 | (1 - group['beta1'] ** (state['step'] - 1)) \
134 | / (1 - group['beta1'] ** state['step'])
135 | else:
136 | beta1_t = group['beta1']
137 |                     exp_avg.mul_(beta1_t).add_(grad, alpha=1 - beta1_t)
138 |
139 | if group['non_constant_decay']:
140 | beta2_t = group['beta2'] * \
141 | (1 - group['beta2'] ** (state['step'] - 1)) / \
142 | (1 - group['beta2'] ** state['step'])
143 | else:
144 | beta2_t = group['beta2']
145 |
146 | if is_matrix and group['enable_factorization']:
147 |                     exp_avg_sq_r.mul_(beta2_t). \
148 |                         add_(torch.sum(torch.mul(grad, grad).
149 |                                        add_(group['eps1']),
150 |                                        dim=0, keepdim=True), alpha=1 - beta2_t)
151 |                     exp_avg_sq_c.mul_(beta2_t). \
152 |                         add_(torch.sum(torch.mul(grad, grad).
153 |                                        add_(group['eps1']),
154 |                                        dim=1, keepdim=True), alpha=1 - beta2_t)
155 |                     v = torch.mul(exp_avg_sq_c,
156 |                                   exp_avg_sq_r).div_(torch.sum(exp_avg_sq_r))
157 |                 else:
158 |                     exp_avg_sq.mul_(beta2_t). \
159 |                         addcmul_(grad, grad, value=1 - beta2_t). \
160 |                         add_((1 - beta2_t) * group['eps1'])
161 | v = exp_avg_sq
162 | g = grad
163 | if group['enable_momentum']:
164 | g = torch.div(exp_avg, 1 - beta1_t ** state['step'])
165 | if group['ams_grad']:
166 | torch.max(exp_avg_sq_hat, v, out=exp_avg_sq_hat)
167 | v = exp_avg_sq_hat
168 | u = torch.div(g, (torch.div(v, 1 - beta2_t **
169 | state['step'])).sqrt().add_(group['eps1']))
170 | else:
171 | u = torch.div(g, v.sqrt())
172 | u.div_(max(1, self._rms(u) / group['cliping_threshold']))
173 | p.data.add_(-lr_t * (u.view(old_shape) if is_need_reshape and
174 | group['enable_factorization'] else u))
175 | if group['weight_decay'] != 0:
176 |                     p.data.add_(p.data, alpha=-group['weight_decay'] * lr_t)
177 | return loss
178 |
--------------------------------------------------------------------------------
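A minimal usage sketch for the AdaFactor optimizer above. The import path is assumed from the repository layout, and the model, data and loss below are hypothetical placeholders rather than anything defined in this repository:

    import torch
    import torch.nn as nn
    from callback.optimizater.adafactor import AdaFactor  # import path assumed from the repo layout

    model = nn.Linear(128, 10)                             # placeholder model
    optimizer = AdaFactor(model.parameters(), lr=1e-3, weight_decay=0.01)

    for _ in range(10):                                    # placeholder training loop
        inputs = torch.randn(32, 128)
        targets = torch.randint(0, 10, (32,))
        loss = nn.functional.cross_entropy(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()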
/processors/ner_seq.py:
--------------------------------------------------------------------------------
1 | """ Named entity recognition fine-tuning: utilities to work with CLUENER task. """
2 | import torch
3 | import logging
4 | import os
5 | import copy
6 | import json
7 | from .utils_ner import DataProcessor
8 | logger = logging.getLogger(__name__)
9 |
10 | class InputExample(object):
11 | """A single training/test example for token classification."""
12 | def __init__(self, guid, text_a, labels):
13 |         """Constructs an InputExample.
14 | Args:
15 | guid: Unique id for the example.
16 | text_a: list. The words of the sequence.
17 | labels: (Optional) list. The labels for each word of the sequence. This should be
18 | specified for train and dev examples, but not for test examples.
19 | """
20 | self.guid = guid
21 | self.text_a = text_a
22 | self.labels = labels
23 |
24 | def __repr__(self):
25 | return str(self.to_json_string())
26 | def to_dict(self):
27 | """Serializes this instance to a Python dictionary."""
28 | output = copy.deepcopy(self.__dict__)
29 | return output
30 | def to_json_string(self):
31 | """Serializes this instance to a JSON string."""
32 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
33 |
34 | class InputFeatures(object):
35 | """A single set of features of data."""
36 | def __init__(self, input_ids, input_mask, input_len,segment_ids, label_ids):
37 | self.input_ids = input_ids
38 | self.input_mask = input_mask
39 | self.segment_ids = segment_ids
40 | self.label_ids = label_ids
41 | self.input_len = input_len
42 |
43 | def __repr__(self):
44 | return str(self.to_json_string())
45 |
46 | def to_dict(self):
47 | """Serializes this instance to a Python dictionary."""
48 | output = copy.deepcopy(self.__dict__)
49 | return output
50 |
51 | def to_json_string(self):
52 | """Serializes this instance to a JSON string."""
53 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
54 |
55 | def collate_fn(batch):
56 |     """
57 |         batch is a list of (input_ids, attention_mask, token_type_ids, length, label_ids) tuples.
58 |         Truncates every field to the length of the longest sequence in the batch.
59 |     """
60 | all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))
61 | max_len = max(all_lens).item()
62 | all_input_ids = all_input_ids[:, :max_len]
63 | all_attention_mask = all_attention_mask[:, :max_len]
64 | all_token_type_ids = all_token_type_ids[:, :max_len]
65 | all_labels = all_labels[:,:max_len]
66 | return all_input_ids, all_attention_mask, all_token_type_ids, all_labels,all_lens
67 |
68 | def convert_examples_to_features(examples,label_list,max_seq_length,tokenizer,
69 | cls_token_at_end=False,cls_token="[CLS]",cls_token_segment_id=1,
70 | sep_token="[SEP]",pad_on_left=False,pad_token=0,pad_token_segment_id=0,
71 | sequence_a_segment_id=0,mask_padding_with_zero=True,):
72 |     """ Loads a data file into a list of `InputFeatures`
73 |         `cls_token_at_end` defines the location of the CLS token:
74 |             - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
75 |             - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
76 |         `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet)
77 |     """
78 | label_map = {label: i for i, label in enumerate(label_list)}
79 | features = []
80 | for (ex_index, example) in enumerate(examples):
81 | if ex_index % 10000 == 0:
82 | logger.info("Writing example %d of %d", ex_index, len(examples))
83 | if isinstance(example.text_a,list):
84 | example.text_a = " ".join(example.text_a)
85 | tokens = tokenizer.tokenize(example.text_a)
86 | label_ids = [label_map[x] for x in example.labels]
87 | # Account for [CLS] and [SEP] with "- 2".
88 | special_tokens_count = 2
89 | if len(tokens) > max_seq_length - special_tokens_count:
90 | tokens = tokens[: (max_seq_length - special_tokens_count)]
91 | label_ids = label_ids[: (max_seq_length - special_tokens_count)]
92 |
93 | # The convention in BERT is:
94 | # (a) For sequence pairs:
95 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
96 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
97 | # (b) For single sequences:
98 | # tokens: [CLS] the dog is hairy . [SEP]
99 | # type_ids: 0 0 0 0 0 0 0
100 | #
101 | # Where "type_ids" are used to indicate whether this is the first
102 | # sequence or the second sequence. The embedding vectors for `type=0` and
103 | # `type=1` were learned during pre-training and are added to the wordpiece
104 | # embedding vector (and position vector). This is not *strictly* necessary
105 | # since the [SEP] token unambiguously separates the sequences, but it makes
106 | # it easier for the model to learn the concept of sequences.
107 | #
108 | # For classification tasks, the first vector (corresponding to [CLS]) is
109 |         # used as the "sentence vector". Note that this only makes sense because
110 | # the entire model is fine-tuned.
111 | tokens += [sep_token]
112 | label_ids += [label_map['O']]
113 | segment_ids = [sequence_a_segment_id] * len(tokens)
114 |
115 | if cls_token_at_end:
116 | tokens += [cls_token]
117 | label_ids += [label_map['O']]
118 | segment_ids += [cls_token_segment_id]
119 | else:
120 | tokens = [cls_token] + tokens
121 | label_ids = [label_map['O']] + label_ids
122 | segment_ids = [cls_token_segment_id] + segment_ids
123 |
124 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
125 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
126 | # tokens are attended to.
127 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
128 | input_len = len(label_ids)
129 | # Zero-pad up to the sequence length.
130 | padding_length = max_seq_length - len(input_ids)
131 | if pad_on_left:
132 | input_ids = ([pad_token] * padding_length) + input_ids
133 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
134 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
135 | label_ids = ([pad_token] * padding_length) + label_ids
136 | else:
137 | input_ids += [pad_token] * padding_length
138 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length
139 | segment_ids += [pad_token_segment_id] * padding_length
140 | label_ids += [pad_token] * padding_length
141 |
142 | assert len(input_ids) == max_seq_length
143 | assert len(input_mask) == max_seq_length
144 | assert len(segment_ids) == max_seq_length
145 | assert len(label_ids) == max_seq_length
146 | if ex_index < 5:
147 | logger.info("*** Example ***")
148 | logger.info("guid: %s", example.guid)
149 | logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
150 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
151 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
152 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
153 | logger.info("label_ids: %s", " ".join([str(x) for x in label_ids]))
154 |
155 | features.append(InputFeatures(input_ids=input_ids, input_mask=input_mask,input_len = input_len,
156 | segment_ids=segment_ids, label_ids=label_ids))
157 | return features
158 |
159 |
160 | class CnerProcessor(DataProcessor):
161 | """Processor for the chinese ner data set."""
162 |
163 | def get_train_examples(self, data_dir):
164 | """See base class."""
165 | return self._create_examples(self._read_text(os.path.join(data_dir, "train.char.bmes")), "train")
166 |
167 | def get_dev_examples(self, data_dir):
168 | """See base class."""
169 | return self._create_examples(self._read_text(os.path.join(data_dir, "dev.char.bmes")), "dev")
170 |
171 | def get_test_examples(self, data_dir):
172 | """See base class."""
173 | return self._create_examples(self._read_text(os.path.join(data_dir, "test.char.bmes")), "test")
174 |
175 | def get_labels(self):
176 | """See base class."""
177 | return ["X",'B-CONT','B-EDU','B-LOC','B-NAME','B-ORG','B-PRO','B-RACE','B-TITLE',
178 | 'I-CONT','I-EDU','I-LOC','I-NAME','I-ORG','I-PRO','I-RACE','I-TITLE',
179 | 'O','S-NAME','S-ORG','S-RACE',"[START]", "[END]"]
180 |
181 | def _create_examples(self, lines, set_type):
182 | """Creates examples for the training and dev sets."""
183 | examples = []
184 | for (i, line) in enumerate(lines):
185 | if i == 0:
186 | continue
187 | guid = "%s-%s" % (set_type, i)
188 | text_a= line['words']
189 | # BIOS
190 | labels = []
191 | for x in line['labels']:
192 | if 'M-' in x:
193 | labels.append(x.replace('M-','I-'))
194 | elif 'E-' in x:
195 | labels.append(x.replace('E-', 'I-'))
196 | else:
197 | labels.append(x)
198 | examples.append(InputExample(guid=guid, text_a=text_a, labels=labels))
199 | return examples
200 |
201 | class CluenerProcessor(DataProcessor):
202 | """Processor for the chinese ner data set."""
203 |
204 | def get_train_examples(self, data_dir):
205 | """See base class."""
206 | return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), "train")
207 |
208 | def get_dev_examples(self, data_dir):
209 | """See base class."""
210 | return self._create_examples(self._read_json(os.path.join(data_dir, "dev.json")), "dev")
211 |
212 | def get_test_examples(self, data_dir):
213 | """See base class."""
214 | return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), "test")
215 |
216 | def get_labels(self):
217 | """See base class."""
218 | return ["X", "B-address", "B-book", "B-company", 'B-game', 'B-government', 'B-movie', 'B-name',
219 | 'B-organization', 'B-position','B-scene',"I-address",
220 | "I-book", "I-company", 'I-game', 'I-government', 'I-movie', 'I-name',
221 | 'I-organization', 'I-position','I-scene',
222 | "S-address", "S-book", "S-company", 'S-game', 'S-government', 'S-movie',
223 | 'S-name', 'S-organization', 'S-position',
224 | 'S-scene','O',"[START]", "[END]"]
225 |
226 | def _create_examples(self, lines, set_type):
227 | """Creates examples for the training and dev sets."""
228 | examples = []
229 | for (i, line) in enumerate(lines):
230 | guid = "%s-%s" % (set_type, i)
231 | text_a= line['words']
232 | # BIOS
233 | labels = line['labels']
234 | examples.append(InputExample(guid=guid, text_a=text_a, labels=labels))
235 | return examples
236 |
237 | ner_processors = {
238 | "cner": CnerProcessor,
239 | 'cluener':CluenerProcessor
240 | }
241 |
--------------------------------------------------------------------------------
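A hedged end-to-end sketch of how the pieces in ner_seq.py are typically wired together. The tokenizer checkpoint, data directory and batch size are placeholders, and the transformers package plus the processors.ner_seq module path are assumptions, not something this file pins down:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from transformers import BertTokenizer
    from processors.ner_seq import CnerProcessor, convert_examples_to_features, collate_fn

    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")   # placeholder checkpoint
    processor = CnerProcessor()
    label_list = processor.get_labels()
    examples = processor.get_train_examples("datasets/cner")         # placeholder data_dir

    features = convert_examples_to_features(examples, label_list, max_seq_length=128,
                                            tokenizer=tokenizer, cls_token_segment_id=0)  # 0 for BERT
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    all_lens = torch.tensor([f.input_len for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)

    # field order must match the unpacking inside collate_fn
    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_lens, all_label_ids)
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)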
/processors/ner_span.py:
--------------------------------------------------------------------------------
1 | """ Named entity recognition fine-tuning: utilities to work with the CNER/CLUENER span-based tasks. """
2 | import torch
3 | import logging
4 | import os
5 | import copy
6 | import json
7 | from .utils_ner import DataProcessor,get_entities
8 | logger = logging.getLogger(__name__)
9 |
10 | class InputExample(object):
11 | """A single training/test example for token classification."""
12 | def __init__(self, guid, text_a, subject):
13 | self.guid = guid
14 | self.text_a = text_a
15 | self.subject = subject
16 | def __repr__(self):
17 | return str(self.to_json_string())
18 | def to_dict(self):
19 | """Serializes this instance to a Python dictionary."""
20 | output = copy.deepcopy(self.__dict__)
21 | return output
22 | def to_json_string(self):
23 | """Serializes this instance to a JSON string."""
24 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
25 |
26 | class InputFeature(object):
27 | """A single set of features of data."""
28 |
29 | def __init__(self, input_ids, input_mask, input_len, segment_ids, start_ids,end_ids, subjects):
30 | self.input_ids = input_ids
31 | self.input_mask = input_mask
32 | self.segment_ids = segment_ids
33 | self.start_ids = start_ids
34 | self.input_len = input_len
35 | self.end_ids = end_ids
36 | self.subjects = subjects
37 |
38 | def __repr__(self):
39 | return str(self.to_json_string())
40 |
41 | def to_dict(self):
42 | """Serializes this instance to a Python dictionary."""
43 | output = copy.deepcopy(self.__dict__)
44 | return output
45 |
46 | def to_json_string(self):
47 | """Serializes this instance to a JSON string."""
48 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
49 |
50 | def collate_fn(batch):
51 |     """
52 |         batch is a list of (input_ids, input_mask, segment_ids, start_ids, end_ids, length) tuples.
53 |         Truncates every field to the length of the longest sequence in the batch.
54 |     """
55 | all_input_ids, all_input_mask, all_segment_ids, all_start_ids,all_end_ids,all_lens = map(torch.stack, zip(*batch))
56 | max_len = max(all_lens).item()
57 | all_input_ids = all_input_ids[:, :max_len]
58 | all_input_mask = all_input_mask[:, :max_len]
59 | all_segment_ids = all_segment_ids[:, :max_len]
60 | all_start_ids = all_start_ids[:,:max_len]
61 | all_end_ids = all_end_ids[:, :max_len]
62 | return all_input_ids, all_input_mask, all_segment_ids, all_start_ids,all_end_ids,all_lens
63 |
64 | def convert_examples_to_features(examples,label_list,max_seq_length,tokenizer,
65 | cls_token_at_end=False,cls_token="[CLS]",cls_token_segment_id=1,
66 | sep_token="[SEP]",pad_on_left=False,pad_token=0,pad_token_segment_id=0,
67 | sequence_a_segment_id=0,mask_padding_with_zero=True,):
68 |     """ Loads a data file into a list of `InputFeature`s
69 |         `cls_token_at_end` defines the location of the CLS token:
70 |             - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
71 |             - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
72 |         `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet)
73 |     """
74 | label2id = {label: i for i, label in enumerate(label_list)}
75 | features = []
76 | for (ex_index, example) in enumerate(examples):
77 | if ex_index % 10000 == 0:
78 | logger.info("Writing example %d of %d", ex_index, len(examples))
79 | textlist = example.text_a
80 | subjects = example.subject
81 | if isinstance(textlist,list):
82 | textlist = " ".join(textlist)
83 | tokens = tokenizer.tokenize(textlist)
84 | start_ids = [0] * len(tokens)
85 | end_ids = [0] * len(tokens)
86 | subjects_id = []
87 | for subject in subjects:
88 | label = subject[0]
89 | start = subject[1]
90 | end = subject[2]
91 | start_ids[start] = label2id[label]
92 | end_ids[end] = label2id[label]
93 | subjects_id.append((label2id[label], start, end))
94 | # Account for [CLS] and [SEP] with "- 2".
95 | special_tokens_count = 2
96 | if len(tokens) > max_seq_length - special_tokens_count:
97 | tokens = tokens[: (max_seq_length - special_tokens_count)]
98 | start_ids = start_ids[: (max_seq_length - special_tokens_count)]
99 | end_ids = end_ids[: (max_seq_length - special_tokens_count)]
100 | # The convention in BERT is:
101 | # (a) For sequence pairs:
102 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
103 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
104 | # (b) For single sequences:
105 | # tokens: [CLS] the dog is hairy . [SEP]
106 | # type_ids: 0 0 0 0 0 0 0
107 | #
108 | # Where "type_ids" are used to indicate whether this is the first
109 | # sequence or the second sequence. The embedding vectors for `type=0` and
110 | # `type=1` were learned during pre-training and are added to the wordpiece
111 | # embedding vector (and position vector). This is not *strictly* necessary
112 | # since the [SEP] token unambiguously separates the sequences, but it makes
113 | # it easier for the model to learn the concept of sequences.
114 | #
115 | # For classification tasks, the first vector (corresponding to [CLS]) is
116 |         # used as the "sentence vector". Note that this only makes sense because
117 | # the entire model is fine-tuned.
118 | tokens += [sep_token]
119 | start_ids += [0]
120 | end_ids += [0]
121 | segment_ids = [sequence_a_segment_id] * len(tokens)
122 | if cls_token_at_end:
123 | tokens += [cls_token]
124 | start_ids += [0]
125 | end_ids += [0]
126 | segment_ids += [cls_token_segment_id]
127 | else:
128 | tokens = [cls_token] + tokens
129 | start_ids = [0]+ start_ids
130 | end_ids = [0]+ end_ids
131 | segment_ids = [cls_token_segment_id] + segment_ids
132 |
133 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
134 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
135 | # tokens are attended to.
136 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
137 | input_len = len(input_ids)
138 | # Zero-pad up to the sequence length.
139 | padding_length = max_seq_length - len(input_ids)
140 | if pad_on_left:
141 | input_ids = ([pad_token] * padding_length) + input_ids
142 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
143 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
144 | start_ids = ([0] * padding_length) + start_ids
145 | end_ids = ([0] * padding_length) + end_ids
146 | else:
147 | input_ids += [pad_token] * padding_length
148 | input_mask += [0 if mask_padding_with_zero else 1] * padding_length
149 | segment_ids += [pad_token_segment_id] * padding_length
150 | start_ids += ([0] * padding_length)
151 | end_ids += ([0] * padding_length)
152 |
153 | assert len(input_ids) == max_seq_length
154 | assert len(input_mask) == max_seq_length
155 | assert len(segment_ids) == max_seq_length
156 | assert len(start_ids) == max_seq_length
157 | assert len(end_ids) == max_seq_length
158 |
159 | if ex_index < 5:
160 | logger.info("*** Example ***")
161 | logger.info("guid: %s", example.guid)
162 | logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
163 | logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
164 | logger.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
165 | logger.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
166 |             logger.info("start_ids: %s", " ".join([str(x) for x in start_ids]))
167 |             logger.info("end_ids: %s", " ".join([str(x) for x in end_ids]))
168 |
169 | features.append(InputFeature(input_ids=input_ids,
170 | input_mask=input_mask,
171 | segment_ids=segment_ids,
172 | start_ids=start_ids,
173 | end_ids=end_ids,
174 | subjects=subjects_id,
175 | input_len=input_len))
176 | return features
177 |
178 | class CnerProcessor(DataProcessor):
179 | """Processor for the chinese ner data set."""
180 |
181 | def get_train_examples(self, data_dir):
182 | """See base class."""
183 | return self._create_examples(self._read_text(os.path.join(data_dir, "train.char.bmes")), "train")
184 |
185 | def get_dev_examples(self, data_dir):
186 | """See base class."""
187 | return self._create_examples(self._read_text(os.path.join(data_dir, "dev.char.bmes")), "dev")
188 |
189 | def get_test_examples(self, data_dir):
190 | """See base class."""
191 | return self._create_examples(self._read_text(os.path.join(data_dir, "test.char.bmes")), "test")
192 |
193 | def get_labels(self):
194 | """See base class."""
195 | return ["O", "CONT", "ORG","LOC",'EDU','NAME','PRO','RACE','TITLE']
196 |
197 | def _create_examples(self, lines, set_type):
198 | """Creates examples for the training and dev sets."""
199 | examples = []
200 | for (i, line) in enumerate(lines):
201 | if i == 0:
202 | continue
203 | guid = "%s-%s" % (set_type, i)
204 | text_a = line['words']
205 | labels = []
206 | for x in line['labels']:
207 | if 'M-' in x:
208 | labels.append(x.replace('M-','I-'))
209 | elif 'E-' in x:
210 | labels.append(x.replace('E-', 'I-'))
211 | else:
212 | labels.append(x)
213 | subject = get_entities(labels,id2label=None,markup='bios')
214 | examples.append(InputExample(guid=guid, text_a=text_a, subject=subject))
215 | return examples
216 |
217 | class CluenerProcessor(DataProcessor):
218 | """Processor for the chinese ner data set."""
219 |
220 | def get_train_examples(self, data_dir):
221 | """See base class."""
222 | return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), "train")
223 |
224 | def get_dev_examples(self, data_dir):
225 | """See base class."""
226 | return self._create_examples(self._read_json(os.path.join(data_dir, "dev.json")), "dev")
227 |
228 | def get_test_examples(self, data_dir):
229 | """See base class."""
230 | return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), "test")
231 |
232 | def get_labels(self):
233 | """See base class."""
234 | return ["O", "address", "book","company",'game','government','movie','name','organization','position','scene']
235 |
236 | def _create_examples(self, lines, set_type):
237 | """Creates examples for the training and dev sets."""
238 | examples = []
239 | for (i, line) in enumerate(lines):
240 | guid = "%s-%s" % (set_type, i)
241 | text_a = line['words']
242 | labels = line['labels']
243 | subject = get_entities(labels,id2label=None,markup='bios')
244 | examples.append(InputExample(guid=guid, text_a=text_a, subject=subject))
245 | return examples
246 |
247 | ner_processors = {
248 | "cner": CnerProcessor,
249 | 'cluener':CluenerProcessor
250 | }
251 |
252 |
253 |
--------------------------------------------------------------------------------
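A small, hedged illustration of the span encoding built in ner_span.py above: for every subject (label, start, end) produced by get_entities, the start and end token positions receive the label id and all other positions stay 0. The sentence and label set here are toy placeholders:

    label_list = ["O", "name", "company"]           # toy label set
    label2id = {label: i for i, label in enumerate(label_list)}

    tokens = list("张三在百度工作")                  # toy character-level tokens (7 characters)
    subjects = [("name", 0, 1), ("company", 3, 4)]  # (label, start index, end index), inclusive

    start_ids = [0] * len(tokens)
    end_ids = [0] * len(tokens)
    for label, start, end in subjects:
        start_ids[start] = label2id[label]
        end_ids[end] = label2id[label]

    # start_ids -> [1, 0, 0, 2, 0, 0, 0]
    # end_ids   -> [0, 1, 0, 0, 2, 0, 0]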
/tools/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import numpy as np
5 | import json
6 | import pickle
7 | import torch.nn as nn
8 | from collections import OrderedDict
9 | from pathlib import Path
10 | import logging
11 |
12 | logger = logging.getLogger()
13 | def print_config(config):
14 | info = "Running with the following configs:\n"
15 | for k, v in config.items():
16 | info += f"\t{k} : {str(v)}\n"
17 | print("\n" + info + "\n")
18 | return
19 |
20 | def init_logger(log_file=None, log_file_level=logging.NOTSET):
21 | '''
22 | Example:
23 | >>> init_logger(log_file)
24 | >>> logger.info("abc'")
25 | '''
26 | if isinstance(log_file,Path):
27 | log_file = str(log_file)
28 | log_format = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
29 | datefmt='%m/%d/%Y %H:%M:%S')
30 |
31 | logger = logging.getLogger()
32 | logger.setLevel(logging.INFO)
33 | console_handler = logging.StreamHandler()
34 | console_handler.setFormatter(log_format)
35 | logger.handlers = [console_handler]
36 | if log_file and log_file != '':
37 | file_handler = logging.FileHandler(log_file)
38 | file_handler.setLevel(log_file_level)
39 | # file_handler.setFormatter(log_format)
40 | logger.addHandler(file_handler)
41 | return logger
42 |
43 | def seed_everything(seed=1029):
44 |     '''
45 |     Set the random seed for the whole environment (Python, NumPy, PyTorch, CUDA)
46 |     so that runs are reproducible.
47 |     :param seed:
48 |     :return:
49 |     '''
50 | random.seed(seed)
51 | os.environ['PYTHONHASHSEED'] = str(seed)
52 | np.random.seed(seed)
53 | torch.manual_seed(seed)
54 | torch.cuda.manual_seed(seed)
55 | torch.cuda.manual_seed_all(seed)
56 | # some cudnn methods can be random even after fixing the seed
57 | # unless you tell it to be deterministic
58 | torch.backends.cudnn.deterministic = True
59 |
60 |
61 | def prepare_device(n_gpu_use):
62 |     """
63 |     setup GPU device if available, move model into configured device
64 |     # if n_gpu_use is a number, use range to build the device id list
65 |     # if the input is a list, list[0] is used by default as the controller (main device)
66 |     """
67 | if not n_gpu_use:
68 | device_type = 'cpu'
69 | else:
70 | n_gpu_use = n_gpu_use.split(",")
71 | device_type = f"cuda:{n_gpu_use[0]}"
72 | n_gpu = torch.cuda.device_count()
73 | if len(n_gpu_use) > 0 and n_gpu == 0:
74 | logger.warning("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
75 | device_type = 'cpu'
76 | if len(n_gpu_use) > n_gpu:
77 | msg = f"Warning: The number of GPU\'s configured to use is {n_gpu_use}, but only {n_gpu} are available on this machine."
78 | logger.warning(msg)
79 | n_gpu_use = range(n_gpu)
80 | device = torch.device(device_type)
81 | list_ids = n_gpu_use
82 | return device, list_ids
83 |
84 |
85 | def model_device(n_gpu, model):
86 |     '''
87 |     Decide whether to run on CPU or GPU.
88 |     Supports single-machine multi-GPU (nn.DataParallel).
89 |     :param n_gpu:
90 |     :param model:
91 |     :return:
92 |     '''
93 | device, device_ids = prepare_device(n_gpu)
94 | if len(device_ids) > 1:
95 | logger.info(f"current {len(device_ids)} GPUs")
96 | model = torch.nn.DataParallel(model, device_ids=device_ids)
97 | if len(device_ids) == 1:
98 | os.environ['CUDA_VISIBLE_DEVICES'] = str(device_ids[0])
99 | model = model.to(device)
100 | return model, device
101 |
102 |
103 | def restore_checkpoint(resume_path, model=None):
104 |     '''
105 |     Load a model checkpoint.
106 |     :param resume_path:
107 |     :param model:
108 |     :param optimizer:
109 |     :return:
110 |     Note: this pattern cannot be used for loading a BERT model directly; use the
111 |     built-in Bert_model.from_pretrained(state_dict = your saved state_dict) instead.
112 |     '''
113 | if isinstance(resume_path, Path):
114 | resume_path = str(resume_path)
115 | checkpoint = torch.load(resume_path)
116 | best = checkpoint['best']
117 | start_epoch = checkpoint['epoch'] + 1
118 | states = checkpoint['state_dict']
119 | if isinstance(model, nn.DataParallel):
120 | model.module.load_state_dict(states)
121 | else:
122 | model.load_state_dict(states)
123 | return [model,best,start_epoch]
124 |
125 |
126 | def save_pickle(data, file_path):
127 |     '''
128 |     Save data to a pickle file
129 |     (the path may be a str or a pathlib.Path).
130 |     :param data:
131 |     :param file_path:
132 |     :return:
133 |     '''
134 | if isinstance(file_path, Path):
135 | file_path = str(file_path)
136 | with open(file_path, 'wb') as f:
137 | pickle.dump(data, f)
138 |
139 |
140 | def load_pickle(input_file):
141 |     '''
142 |     Load data from a pickle file
143 |     (the path may be a str or a pathlib.Path).
144 |     :param input_file:
145 |     :return:
146 |     '''
147 | with open(str(input_file), 'rb') as f:
148 | data = pickle.load(f)
149 | return data
150 |
151 |
152 | def save_json(data, file_path):
153 |     '''
154 |     Save data to a JSON file
155 |     (the path may be a str or a pathlib.Path).
156 |     :param data:
157 |     :param file_path:
158 |     :return:
159 |     '''
160 | if not isinstance(file_path, Path):
161 | file_path = Path(file_path)
162 | # if isinstance(data,dict):
163 | # data = json.dumps(data)
164 | with open(str(file_path), 'w') as f:
165 | json.dump(data, f)
166 |
167 | def save_numpy(data, file_path):
168 |     '''
169 |     Save data to a .npy file.
170 |     :param data:
171 |     :param file_path:
172 |     :return:
173 |     '''
174 | if not isinstance(file_path, Path):
175 | file_path = Path(file_path)
176 | np.save(str(file_path),data)
177 |
178 | def load_numpy(file_path):
179 |     '''
180 |     Load a .npy file.
181 |     :param file_path:
182 |     :return: the loaded numpy array
183 |     '''
184 |     if not isinstance(file_path, Path):
185 |         file_path = Path(file_path)
186 |     return np.load(str(file_path))
187 |
188 | def load_json(file_path):
189 |     '''
190 |     Load a JSON file
191 |     (the path may be a str or a pathlib.Path).
192 |     :param file_path:
193 |     :return:
194 |     '''
195 | if not isinstance(file_path, Path):
196 | file_path = Path(file_path)
197 | with open(str(file_path), 'r') as f:
198 | data = json.load(f)
199 | return data
200 |
201 | def json_to_text(file_path,data):
202 |     '''
203 |     Write a list of JSON-serializable objects to a text file, one JSON line per item.
204 |     :param file_path:
205 |     :param data:
206 |     :return:
207 |     '''
208 | if not isinstance(file_path, Path):
209 | file_path = Path(file_path)
210 | with open(str(file_path), 'w') as fw:
211 | for line in data:
212 | line = json.dumps(line, ensure_ascii=False)
213 | fw.write(line + '\n')
214 |
215 | def save_model(model, model_path):
216 |     """ Save a model state_dict with all tensors moved to CPU (no GPU/device info),
217 |     unwrapping nn.DataParallel if necessary.
218 |     :param model:
219 |     :param model_path:
220 |     :return:
221 |     """
222 | if isinstance(model_path, Path):
223 | model_path = str(model_path)
224 | if isinstance(model, nn.DataParallel):
225 | model = model.module
226 | state_dict = model.state_dict()
227 | for key in state_dict:
228 | state_dict[key] = state_dict[key].cpu()
229 | torch.save(state_dict, model_path)
230 |
231 | def load_model(model, model_path):
232 |     '''
233 |     Load model weights from a checkpoint file
234 |     (expects a checkpoint dict with a 'state_dict' entry),
235 |     unwrapping nn.DataParallel if necessary.
236 |     :param model:
237 |     :param model_path:
238 |     :return:
239 |     '''
240 | if isinstance(model_path, Path):
241 | model_path = str(model_path)
242 | logging.info(f"loading model from {str(model_path)} .")
243 | states = torch.load(model_path)
244 | state = states['state_dict']
245 | if isinstance(model, nn.DataParallel):
246 | model.module.load_state_dict(state)
247 | else:
248 | model.load_state_dict(state)
249 | return model
250 |
251 |
252 | class AverageMeter(object):
253 | '''
254 | computes and stores the average and current value
255 | Example:
256 | >>> loss = AverageMeter()
257 | >>> for step,batch in enumerate(train_data):
258 | >>> pred = self.model(batch)
259 | >>> raw_loss = self.metrics(pred,target)
260 | >>> loss.update(raw_loss.item(),n = 1)
261 | >>> cur_loss = loss.avg
262 | '''
263 |
264 | def __init__(self):
265 | self.reset()
266 |
267 | def reset(self):
268 | self.val = 0
269 | self.avg = 0
270 | self.sum = 0
271 | self.count = 0
272 |
273 | def update(self, val, n=1):
274 | self.val = val
275 | self.sum += val * n
276 | self.count += n
277 | self.avg = self.sum / self.count
278 |
279 |
280 | def summary(model, *inputs, batch_size=-1, show_input=True):
281 | '''
282 |     Print a summary of the model structure (layers, shapes, parameter counts).
283 | :param model:
284 | :param inputs:
285 | :param batch_size:
286 | :param show_input:
287 | :return:
288 | Example:
289 | >>> print("model summary info: ")
290 | >>> for step,batch in enumerate(train_data):
291 | >>> summary(self.model,*batch,show_input=True)
292 | >>> break
293 | '''
294 |
295 | def register_hook(module):
296 | def hook(module, input, output=None):
297 | class_name = str(module.__class__).split(".")[-1].split("'")[0]
298 | module_idx = len(summary)
299 |
300 | m_key = f"{class_name}-{module_idx + 1}"
301 | summary[m_key] = OrderedDict()
302 | summary[m_key]["input_shape"] = list(input[0].size())
303 | summary[m_key]["input_shape"][0] = batch_size
304 |
305 | if show_input is False and output is not None:
306 | if isinstance(output, (list, tuple)):
307 | for out in output:
308 | if isinstance(out, torch.Tensor):
309 | summary[m_key]["output_shape"] = [
310 | [-1] + list(out.size())[1:]
311 | ][0]
312 | else:
313 | summary[m_key]["output_shape"] = [
314 | [-1] + list(out[0].size())[1:]
315 | ][0]
316 | else:
317 | summary[m_key]["output_shape"] = list(output.size())
318 | summary[m_key]["output_shape"][0] = batch_size
319 |
320 | params = 0
321 | if hasattr(module, "weight") and hasattr(module.weight, "size"):
322 | params += torch.prod(torch.LongTensor(list(module.weight.size())))
323 | summary[m_key]["trainable"] = module.weight.requires_grad
324 | if hasattr(module, "bias") and hasattr(module.bias, "size"):
325 | params += torch.prod(torch.LongTensor(list(module.bias.size())))
326 | summary[m_key]["nb_params"] = params
327 |
328 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and not (module == model)):
329 | if show_input is True:
330 | hooks.append(module.register_forward_pre_hook(hook))
331 | else:
332 | hooks.append(module.register_forward_hook(hook))
333 |
334 | # create properties
335 | summary = OrderedDict()
336 | hooks = []
337 |
338 | # register hook
339 | model.apply(register_hook)
340 | model(*inputs)
341 |
342 | # remove these hooks
343 | for h in hooks:
344 | h.remove()
345 |
346 | print("-----------------------------------------------------------------------")
347 | if show_input is True:
348 | line_new = f"{'Layer (type)':>25} {'Input Shape':>25} {'Param #':>15}"
349 | else:
350 | line_new = f"{'Layer (type)':>25} {'Output Shape':>25} {'Param #':>15}"
351 | print(line_new)
352 | print("=======================================================================")
353 |
354 | total_params = 0
355 | total_output = 0
356 | trainable_params = 0
357 | for layer in summary:
358 | # input_shape, output_shape, trainable, nb_params
359 | if show_input is True:
360 | line_new = "{:>25} {:>25} {:>15}".format(
361 | layer,
362 | str(summary[layer]["input_shape"]),
363 | "{0:,}".format(summary[layer]["nb_params"]),
364 | )
365 | else:
366 | line_new = "{:>25} {:>25} {:>15}".format(
367 | layer,
368 | str(summary[layer]["output_shape"]),
369 | "{0:,}".format(summary[layer]["nb_params"]),
370 | )
371 |
372 | total_params += summary[layer]["nb_params"]
373 | if show_input is True:
374 | total_output += np.prod(summary[layer]["input_shape"])
375 | else:
376 | total_output += np.prod(summary[layer]["output_shape"])
377 | if "trainable" in summary[layer]:
378 |             if summary[layer]["trainable"]:
379 | trainable_params += summary[layer]["nb_params"]
380 |
381 | print(line_new)
382 |
383 | print("=======================================================================")
384 | print(f"Total params: {total_params:0,}")
385 | print(f"Trainable params: {trainable_params:0,}")
386 | print(f"Non-trainable params: {(total_params - trainable_params):0,}")
387 | print("-----------------------------------------------------------------------")
--------------------------------------------------------------------------------
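A brief, hedged sketch of how the helpers in tools/common.py are commonly combined at the start of a training script. The module path, log file location and the loop body are placeholders:

    from tools.common import seed_everything, init_logger, AverageMeter

    seed_everything(42)
    logger = init_logger(log_file="train.log")      # placeholder log path

    loss_meter = AverageMeter()
    for step in range(100):                         # placeholder training loop
        raw_loss = 0.5 / (step + 1)                 # stand-in for a real loss value
        loss_meter.update(raw_loss, n=1)
    logger.info("average loss: %.4f", loss_meter.avg)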
/models/layers/crf.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import List, Optional
4 |
5 | class CRF(nn.Module):
6 | """Conditional random field.
7 | This module implements a conditional random field [LMP01]_. The forward computation
8 | of this class computes the log likelihood of the given sequence of tags and
9 | emission score tensor. This class also has `~CRF.decode` method which finds
10 | the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
11 | Args:
12 | num_tags: Number of tags.
13 | batch_first: Whether the first dimension corresponds to the size of a minibatch.
14 | Attributes:
15 | start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
16 | ``(num_tags,)``.
17 | end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
18 | ``(num_tags,)``.
19 | transitions (`~torch.nn.Parameter`): Transition score tensor of size
20 | ``(num_tags, num_tags)``.
21 | .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
22 | "Conditional random fields: Probabilistic models for segmenting and
23 | labeling sequence data". *Proc. 18th International Conf. on Machine
24 | Learning*. Morgan Kaufmann. pp. 282–289.
25 | .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
26 | """
27 |
28 | def __init__(self, num_tags: int, batch_first: bool = False) -> None:
29 | if num_tags <= 0:
30 | raise ValueError(f'invalid number of tags: {num_tags}')
31 | super().__init__()
32 | self.num_tags = num_tags
33 | self.batch_first = batch_first
34 | self.start_transitions = nn.Parameter(torch.empty(num_tags))
35 | self.end_transitions = nn.Parameter(torch.empty(num_tags))
36 | self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
37 |
38 | self.reset_parameters()
39 |
40 | def reset_parameters(self) -> None:
41 | """Initialize the transition parameters.
42 | The parameters will be initialized randomly from a uniform distribution
43 | between -0.1 and 0.1.
44 | """
45 | nn.init.uniform_(self.start_transitions, -0.1, 0.1)
46 | nn.init.uniform_(self.end_transitions, -0.1, 0.1)
47 | nn.init.uniform_(self.transitions, -0.1, 0.1)
48 |
49 | def __repr__(self) -> str:
50 | return f'{self.__class__.__name__}(num_tags={self.num_tags})'
51 |
52 | def forward(self, emissions: torch.Tensor,
53 | tags: torch.LongTensor,
54 | mask: Optional[torch.ByteTensor] = None,
55 | reduction: str = 'mean') -> torch.Tensor:
56 | """Compute the conditional log likelihood of a sequence of tags given emission scores.
57 | Args:
58 | emissions (`~torch.Tensor`): Emission score tensor of size
59 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
60 | ``(batch_size, seq_length, num_tags)`` otherwise.
61 | tags (`~torch.LongTensor`): Sequence of tags tensor of size
62 | ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
63 | ``(batch_size, seq_length)`` otherwise.
64 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
65 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
66 | reduction: Specifies the reduction to apply to the output:
67 | ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
68 | ``sum``: the output will be summed over batches. ``mean``: the output will be
69 | averaged over batches. ``token_mean``: the output will be averaged over tokens.
70 | Returns:
71 | `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
72 | reduction is ``none``, ``()`` otherwise.
73 | """
74 | if reduction not in ('none', 'sum', 'mean', 'token_mean'):
75 | raise ValueError(f'invalid reduction: {reduction}')
76 | if mask is None:
77 | mask = torch.ones_like(tags, dtype=torch.uint8, device=tags.device)
78 | if mask.dtype != torch.uint8:
79 | mask = mask.byte()
80 | self._validate(emissions, tags=tags, mask=mask)
81 |
82 | if self.batch_first:
83 | emissions = emissions.transpose(0, 1)
84 | tags = tags.transpose(0, 1)
85 | mask = mask.transpose(0, 1)
86 |
87 | # shape: (batch_size,)
88 | numerator = self._compute_score(emissions, tags, mask)
89 | # shape: (batch_size,)
90 | denominator = self._compute_normalizer(emissions, mask)
91 | # shape: (batch_size,)
92 | llh = numerator - denominator
93 |
94 | if reduction == 'none':
95 | return llh
96 | if reduction == 'sum':
97 | return llh.sum()
98 | if reduction == 'mean':
99 | return llh.mean()
100 | return llh.sum() / mask.float().sum()
101 |
102 | def decode(self, emissions: torch.Tensor,
103 | mask: Optional[torch.ByteTensor] = None,
104 | nbest: Optional[int] = None,
105 | pad_tag: Optional[int] = None) -> List[List[List[int]]]:
106 | """Find the most likely tag sequence using Viterbi algorithm.
107 | Args:
108 | emissions (`~torch.Tensor`): Emission score tensor of size
109 | ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
110 | ``(batch_size, seq_length, num_tags)`` otherwise.
111 | mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
112 | if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
113 | nbest (`int`): Number of most probable paths for each sequence
114 | pad_tag (`int`): Tag at padded positions. Often input varies in length and
115 | the length will be padded to the maximum length in the batch. Tags at
116 | the padded positions will be assigned with a padding tag, i.e. `pad_tag`
117 | Returns:
118 | A PyTorch tensor of the best tag sequence for each batch of shape
119 | (nbest, batch_size, seq_length)
120 | """
121 | if nbest is None:
122 | nbest = 1
123 | if mask is None:
124 | mask = torch.ones(emissions.shape[:2], dtype=torch.uint8,
125 | device=emissions.device)
126 | if mask.dtype != torch.uint8:
127 | mask = mask.byte()
128 | self._validate(emissions, mask=mask)
129 |
130 | if self.batch_first:
131 | emissions = emissions.transpose(0, 1)
132 | mask = mask.transpose(0, 1)
133 |
134 | if nbest == 1:
135 | return self._viterbi_decode(emissions, mask, pad_tag).unsqueeze(0)
136 | return self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)
137 |
138 | def _validate(self, emissions: torch.Tensor,
139 | tags: Optional[torch.LongTensor] = None,
140 | mask: Optional[torch.ByteTensor] = None) -> None:
141 | if emissions.dim() != 3:
142 | raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
143 | if emissions.size(2) != self.num_tags:
144 | raise ValueError(
145 | f'expected last dimension of emissions is {self.num_tags}, '
146 | f'got {emissions.size(2)}')
147 |
148 | if tags is not None:
149 | if emissions.shape[:2] != tags.shape:
150 | raise ValueError(
151 | 'the first two dimensions of emissions and tags must match, '
152 | f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
153 |
154 | if mask is not None:
155 | if emissions.shape[:2] != mask.shape:
156 | raise ValueError(
157 | 'the first two dimensions of emissions and mask must match, '
158 | f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
159 | no_empty_seq = not self.batch_first and mask[0].all()
160 | no_empty_seq_bf = self.batch_first and mask[:, 0].all()
161 | if not no_empty_seq and not no_empty_seq_bf:
162 | raise ValueError('mask of the first timestep must all be on')
163 |
164 | def _compute_score(self, emissions: torch.Tensor,
165 | tags: torch.LongTensor,
166 | mask: torch.ByteTensor) -> torch.Tensor:
167 | # emissions: (seq_length, batch_size, num_tags)
168 | # tags: (seq_length, batch_size)
169 | # mask: (seq_length, batch_size)
170 | seq_length, batch_size = tags.shape
171 | mask = mask.float()
172 |
173 | # Start transition score and first emission
174 | # shape: (batch_size,)
175 | score = self.start_transitions[tags[0]]
176 | score += emissions[0, torch.arange(batch_size), tags[0]]
177 |
178 | for i in range(1, seq_length):
179 | # Transition score to next tag, only added if next timestep is valid (mask == 1)
180 | # shape: (batch_size,)
181 | score += self.transitions[tags[i - 1], tags[i]] * mask[i]
182 |
183 | # Emission score for next tag, only added if next timestep is valid (mask == 1)
184 | # shape: (batch_size,)
185 | score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
186 |
187 | # End transition score
188 | # shape: (batch_size,)
189 | seq_ends = mask.long().sum(dim=0) - 1
190 | # shape: (batch_size,)
191 | last_tags = tags[seq_ends, torch.arange(batch_size)]
192 | # shape: (batch_size,)
193 | score += self.end_transitions[last_tags]
194 |
195 | return score
196 |
197 | def _compute_normalizer(self, emissions: torch.Tensor,
198 | mask: torch.ByteTensor) -> torch.Tensor:
199 | # emissions: (seq_length, batch_size, num_tags)
200 | # mask: (seq_length, batch_size)
201 | seq_length = emissions.size(0)
202 |
203 | # Start transition score and first emission; score has size of
204 | # (batch_size, num_tags) where for each batch, the j-th column stores
205 | # the score that the first timestep has tag j
206 | # shape: (batch_size, num_tags)
207 | score = self.start_transitions + emissions[0]
208 |
209 | for i in range(1, seq_length):
210 | # Broadcast score for every possible next tag
211 | # shape: (batch_size, num_tags, 1)
212 | broadcast_score = score.unsqueeze(2)
213 |
214 | # Broadcast emission score for every possible current tag
215 | # shape: (batch_size, 1, num_tags)
216 | broadcast_emissions = emissions[i].unsqueeze(1)
217 |
218 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where
219 | # for each sample, entry at row i and column j stores the sum of scores of all
220 | # possible tag sequences so far that end with transitioning from tag i to tag j
221 | # and emitting
222 | # shape: (batch_size, num_tags, num_tags)
223 | next_score = broadcast_score + self.transitions + broadcast_emissions
224 |
225 | # Sum over all possible current tags, but we're in score space, so a sum
226 | # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
227 | # all possible tag sequences so far, that end in tag i
228 | # shape: (batch_size, num_tags)
229 | next_score = torch.logsumexp(next_score, dim=1)
230 |
231 | # Set score to the next score if this timestep is valid (mask == 1)
232 | # shape: (batch_size, num_tags)
233 | score = torch.where(mask[i].unsqueeze(1), next_score, score)
234 |
235 | # End transition score
236 | # shape: (batch_size, num_tags)
237 | score += self.end_transitions
238 |
239 | # Sum (log-sum-exp) over all possible tags
240 | # shape: (batch_size,)
241 | return torch.logsumexp(score, dim=1)
242 |
243 | def _viterbi_decode(self, emissions: torch.FloatTensor,
244 | mask: torch.ByteTensor,
245 | pad_tag: Optional[int] = None) -> List[List[int]]:
246 | # emissions: (seq_length, batch_size, num_tags)
247 | # mask: (seq_length, batch_size)
248 | # return: (batch_size, seq_length)
249 | if pad_tag is None:
250 | pad_tag = 0
251 |
252 | device = emissions.device
253 | seq_length, batch_size = mask.shape
254 |
255 | # Start transition and first emission
256 | # shape: (batch_size, num_tags)
257 | score = self.start_transitions + emissions[0]
258 | history_idx = torch.zeros((seq_length, batch_size, self.num_tags),
259 | dtype=torch.long, device=device)
260 | oor_idx = torch.zeros((batch_size, self.num_tags),
261 | dtype=torch.long, device=device)
262 | oor_tag = torch.full((seq_length, batch_size), pad_tag,
263 | dtype=torch.long, device=device)
264 |
265 | # - score is a tensor of size (batch_size, num_tags) where for every batch,
266 | # value at column j stores the score of the best tag sequence so far that ends
267 | # with tag j
268 | # - history_idx saves where the best tags candidate transitioned from; this is used
269 | # when we trace back the best tag sequence
270 | # - oor_idx saves the best tags candidate transitioned from at the positions
271 | # where mask is 0, i.e. out of range (oor)
272 |
273 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence
274 | # for every possible next tag
275 | for i in range(1, seq_length):
276 | # Broadcast viterbi score for every possible next tag
277 | # shape: (batch_size, num_tags, 1)
278 | broadcast_score = score.unsqueeze(2)
279 |
280 | # Broadcast emission score for every possible current tag
281 | # shape: (batch_size, 1, num_tags)
282 | broadcast_emission = emissions[i].unsqueeze(1)
283 |
284 | # Compute the score tensor of size (batch_size, num_tags, num_tags) where
285 | # for each sample, entry at row i and column j stores the score of the best
286 | # tag sequence so far that ends with transitioning from tag i to tag j and emitting
287 | # shape: (batch_size, num_tags, num_tags)
288 | next_score = broadcast_score + self.transitions + broadcast_emission
289 |
290 | # Find the maximum score over all possible current tag
291 | # shape: (batch_size, num_tags)
292 | next_score, indices = next_score.max(dim=1)
293 |
294 | # Set score to the next score if this timestep is valid (mask == 1)
295 | # and save the index that produces the next score
296 | # shape: (batch_size, num_tags)
297 | score = torch.where(mask[i].unsqueeze(-1), next_score, score)
298 | indices = torch.where(mask[i].unsqueeze(-1), indices, oor_idx)
299 | history_idx[i - 1] = indices
300 |
301 | # End transition score
302 | # shape: (batch_size, num_tags)
303 | end_score = score + self.end_transitions
304 | _, end_tag = end_score.max(dim=1)
305 |
306 | # shape: (batch_size,)
307 | seq_ends = mask.long().sum(dim=0) - 1
308 |
309 | # insert the best tag at each sequence end (last position with mask == 1)
310 | history_idx = history_idx.transpose(1, 0).contiguous()
311 | history_idx.scatter_(1, seq_ends.view(-1, 1, 1).expand(-1, 1, self.num_tags),
312 | end_tag.view(-1, 1, 1).expand(-1, 1, self.num_tags))
313 | history_idx = history_idx.transpose(1, 0).contiguous()
314 |
315 | # The most probable path for each sequence
316 | best_tags_arr = torch.zeros((seq_length, batch_size),
317 | dtype=torch.long, device=device)
318 | best_tags = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
319 | for idx in range(seq_length - 1, -1, -1):
320 | best_tags = torch.gather(history_idx[idx], 1, best_tags)
321 | best_tags_arr[idx] = best_tags.data.view(batch_size)
322 |
323 | return torch.where(mask, best_tags_arr, oor_tag).transpose(0, 1)
324 |
325 | def _viterbi_decode_nbest(self, emissions: torch.FloatTensor,
326 | mask: torch.ByteTensor,
327 | nbest: int,
328 | pad_tag: Optional[int] = None) -> List[List[List[int]]]:
329 | # emissions: (seq_length, batch_size, num_tags)
330 | # mask: (seq_length, batch_size)
331 | # return: (nbest, batch_size, seq_length)
332 | if pad_tag is None:
333 | pad_tag = 0
334 |
335 | device = emissions.device
336 | seq_length, batch_size = mask.shape
337 |
338 | # Start transition and first emission
339 | # shape: (batch_size, num_tags)
340 | score = self.start_transitions + emissions[0]
341 | history_idx = torch.zeros((seq_length, batch_size, self.num_tags, nbest),
342 | dtype=torch.long, device=device)
343 | oor_idx = torch.zeros((batch_size, self.num_tags, nbest),
344 | dtype=torch.long, device=device)
345 | oor_tag = torch.full((seq_length, batch_size, nbest), pad_tag,
346 | dtype=torch.long, device=device)
347 |
348 | # + score is a tensor of size (batch_size, num_tags) where for every batch,
349 | # value at column j stores the score of the best tag sequence so far that ends
350 | # with tag j
351 | # + history_idx saves where the best tags candidate transitioned from; this is used
352 | # when we trace back the best tag sequence
353 | # - oor_idx saves the best tags candidate transitioned from at the positions
354 | # where mask is 0, i.e. out of range (oor)
355 |
356 | # Viterbi algorithm recursive case: we compute the score of the best tag sequence
357 | # for every possible next tag
358 | for i in range(1, seq_length):
359 | if i == 1:
360 | broadcast_score = score.unsqueeze(-1)
361 | broadcast_emission = emissions[i].unsqueeze(1)
362 | # shape: (batch_size, num_tags, num_tags)
363 | next_score = broadcast_score + self.transitions + broadcast_emission
364 | else:
365 | broadcast_score = score.unsqueeze(-1)
366 | broadcast_emission = emissions[i].unsqueeze(1).unsqueeze(2)
367 | # shape: (batch_size, num_tags, nbest, num_tags)
368 | next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission
369 |
370 | # Find the top `nbest` maximum score over all possible current tag
371 | # shape: (batch_size, nbest, num_tags)
372 | next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)
373 |
374 | if i == 1:
375 | score = score.unsqueeze(-1).expand(-1, -1, nbest)
376 | indices = indices * nbest
377 |
378 | # convert to shape: (batch_size, num_tags, nbest)
379 | next_score = next_score.transpose(2, 1)
380 | indices = indices.transpose(2, 1)
381 |
382 | # Set score to the next score if this timestep is valid (mask == 1)
383 | # and save the index that produces the next score
384 | # shape: (batch_size, num_tags, nbest)
385 | score = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), next_score, score)
386 | indices = torch.where(mask[i].unsqueeze(-1).unsqueeze(-1), indices, oor_idx)
387 | history_idx[i - 1] = indices
388 |
389 | # End transition score shape: (batch_size, num_tags, nbest)
390 | end_score = score + self.end_transitions.unsqueeze(-1)
391 | _, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)
392 |
393 | # shape: (batch_size,)
394 | seq_ends = mask.long().sum(dim=0) - 1
395 |
396 | # insert the best tag at each sequence end (last position with mask == 1)
397 | history_idx = history_idx.transpose(1, 0).contiguous()
398 | history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
399 | end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
400 | history_idx = history_idx.transpose(1, 0).contiguous()
401 |
402 | # The most probable path for each sequence
403 | best_tags_arr = torch.zeros((seq_length, batch_size, nbest),
404 | dtype=torch.long, device=device)
405 | best_tags = torch.arange(nbest, dtype=torch.long, device=device) \
406 | .view(1, -1).expand(batch_size, -1)
407 | for idx in range(seq_length - 1, -1, -1):
408 | best_tags = torch.gather(history_idx[idx].view(batch_size, -1), 1, best_tags)
409 | best_tags_arr[idx] = best_tags.data.view(batch_size, -1) // nbest
410 |
411 | return torch.where(mask.unsqueeze(-1), best_tags_arr, oor_tag).permute(2, 1, 0)
--------------------------------------------------------------------------------
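A compact, hedged example of driving the CRF layer above with batch-first tensors. The shapes and random emissions are illustrative only, the import path is assumed from the repository layout, and the byte mask follows the convention used by this implementation:

    import torch
    from models.layers.crf import CRF               # import path assumed from the repo layout

    num_tags, batch_size, seq_len = 5, 2, 4
    crf = CRF(num_tags=num_tags, batch_first=True)

    emissions = torch.randn(batch_size, seq_len, num_tags, requires_grad=True)
    tags = torch.randint(0, num_tags, (batch_size, seq_len))
    mask = torch.ones(batch_size, seq_len, dtype=torch.uint8)

    loss = -crf(emissions, tags, mask=mask, reduction='mean')   # negative log-likelihood
    loss.backward()
    best_paths = crf.decode(emissions, mask=mask)               # shape: (nbest=1, batch_size, seq_len)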
/callback/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import warnings
4 | from torch.optim.optimizer import Optimizer
5 | from torch.optim.lr_scheduler import LambdaLR
6 |
7 | def get_constant_schedule(optimizer, last_epoch=-1):
8 | """ Create a schedule with a constant learning rate.
9 | """
10 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
11 |
12 |
13 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
14 | """ Create a schedule with a constant learning rate preceded by a warmup
15 | period during which the learning rate increases linearly between 0 and 1.
16 | """
17 | def lr_lambda(current_step):
18 | if current_step < num_warmup_steps:
19 | return float(current_step) / float(max(1.0, num_warmup_steps))
20 | return 1.
21 |
22 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
23 |
24 |
25 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
26 | """ Create a schedule with a learning rate that decreases linearly after
27 | linearly increasing during a warmup period.
28 | """
29 | def lr_lambda(current_step):
30 | if current_step < num_warmup_steps:
31 | return float(current_step) / float(max(1, num_warmup_steps))
32 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
33 |
34 | return LambdaLR(optimizer, lr_lambda, last_epoch)
35 |
36 |
37 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1):
38 | """ Create a schedule with a learning rate that decreases following the
39 | values of the cosine function between 0 and `pi * cycles` after a warmup
40 | period during which it increases linearly between 0 and 1.
41 | """
42 | def lr_lambda(current_step):
43 | if current_step < num_warmup_steps:
44 | return float(current_step) / float(max(1, num_warmup_steps))
45 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
46 | return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress)))
47 |
48 | return LambdaLR(optimizer, lr_lambda, last_epoch)
49 |
50 |
51 | def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1):
52 | """ Create a schedule with a learning rate that decreases following the
53 | values of the cosine function with several hard restarts, after a warmup
54 | period during which it increases linearly between 0 and 1.
55 | """
56 | def lr_lambda(current_step):
57 | if current_step < num_warmup_steps:
58 | return float(current_step) / float(max(1, num_warmup_steps))
59 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
60 | if progress >= 1.:
61 | return 0.
62 | return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.))))
63 |
64 | return LambdaLR(optimizer, lr_lambda, last_epoch)
65 |
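# A hedged usage sketch for the warmup schedules above. They follow the
# huggingface/transformers convention of calling scheduler.step() once per
# optimizer update; the optimizer choice and step counts are placeholders.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
#   scheduler = get_linear_schedule_with_warmup(
#       optimizer, num_warmup_steps=100, num_training_steps=1000)
#   for batch in train_dataloader:
#       loss = model(**batch)[0]        # placeholder forward pass
#       loss.backward()
#       optimizer.step()
#       scheduler.step()
#       optimizer.zero_grad()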
66 |
67 | class CustomDecayLR(object):
68 | '''
69 | 自定义学习率变化机制
70 | Example:
71 | >>> scheduler = CustomDecayLR(optimizer)
72 | >>> for epoch in range(100):
73 |         >>>     scheduler.epoch_step(epoch)
74 | >>> train(...)
75 | >>> ...
76 | >>> optimizer.zero_grad()
77 | >>> loss.backward()
78 | >>> optimizer.step()
79 | >>> validate(...)
80 | '''
81 | def __init__(self,optimizer,lr):
82 | self.optimizer = optimizer
83 | self.lr = lr
84 |
85 | def epoch_step(self,epoch):
86 | lr = self.lr
87 | if epoch > 12:
88 | lr = lr / 1000
89 | elif epoch > 8:
90 | lr = lr / 100
91 | elif epoch > 4:
92 | lr = lr / 10
93 | for param_group in self.optimizer.param_groups:
94 | param_group['lr'] = lr
95 |
96 | class BertLR(object):
97 | '''
98 |     BERT's built-in learning-rate schedule: linear warmup followed by linear decay.
99 | Example:
100 |         >>> scheduler = BertLR(optimizer, learning_rate=5e-5, t_total=1000, warmup=0.1)
101 | >>> for epoch in range(100):
102 |         >>>     # the lr is updated per batch via batch_step() below
103 | >>> train(...)
104 | >>> ...
105 | >>> optimizer.zero_grad()
106 | >>> loss.backward()
107 | >>> optimizer.step()
108 | >>> scheduler.batch_step()
109 | >>> validate(...)
110 | '''
111 | def __init__(self,optimizer,learning_rate,t_total,warmup):
112 | self.learning_rate = learning_rate
113 | self.optimizer = optimizer
114 | self.t_total = t_total
115 | self.warmup = warmup
116 |
117 |     # Linear warmup: ramp up over the first `warmup` fraction of steps, then decay linearly as 1 - x.
118 | def warmup_linear(self,x, warmup=0.002):
119 | if x < warmup:
120 | return x / warmup
121 | return 1.0 - x
122 |
123 | def batch_step(self,training_step):
124 | lr_this_step = self.learning_rate * self.warmup_linear(training_step / self.t_total,self.warmup)
125 | for param_group in self.optimizer.param_groups:
126 | param_group['lr'] = lr_this_step
127 |
128 | class CyclicLR(object):
129 | '''
130 | Cyclical learning rates for training neural networks
131 | Example:
132 | >>> scheduler = CyclicLR(optimizer)
133 | >>> for epoch in range(100):
134 |         >>>     # the lr is updated per batch via batch_step() below
135 | >>> train(...)
136 | >>> ...
137 | >>> optimizer.zero_grad()
138 | >>> loss.backward()
139 | >>> optimizer.step()
140 | >>> scheduler.batch_step()
141 | >>> validate(...)
142 | '''
143 | def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3,
144 | step_size=2000, mode='triangular', gamma=1.,
145 | scale_fn=None, scale_mode='cycle', last_batch_iteration=-1):
146 |
147 | if not isinstance(optimizer, Optimizer):
148 | raise TypeError('{} is not an Optimizer'.format(
149 | type(optimizer).__name__))
150 |
151 | self.optimizer = optimizer
152 |
153 | if isinstance(base_lr, list) or isinstance(base_lr, tuple):
154 | if len(base_lr) != len(optimizer.param_groups):
155 | raise ValueError("expected {} base_lr, got {}".format(
156 | len(optimizer.param_groups), len(base_lr)))
157 | self.base_lrs = list(base_lr)
158 | else:
159 | self.base_lrs = [base_lr] * len(optimizer.param_groups)
160 |
161 | if isinstance(max_lr, list) or isinstance(max_lr, tuple):
162 | if len(max_lr) != len(optimizer.param_groups):
163 | raise ValueError("expected {} max_lr, got {}".format(
164 | len(optimizer.param_groups), len(max_lr)))
165 | self.max_lrs = list(max_lr)
166 | else:
167 | self.max_lrs = [max_lr] * len(optimizer.param_groups)
168 |
169 | self.step_size = step_size
170 |
171 | if mode not in ['triangular', 'triangular2', 'exp_range'] \
172 | and scale_fn is None:
173 | raise ValueError('mode is invalid and scale_fn is None')
174 |
175 | self.mode = mode
176 | self.gamma = gamma
177 |
178 | if scale_fn is None:
179 | if self.mode == 'triangular':
180 | self.scale_fn = self._triangular_scale_fn
181 | self.scale_mode = 'cycle'
182 | elif self.mode == 'triangular2':
183 | self.scale_fn = self._triangular2_scale_fn
184 | self.scale_mode = 'cycle'
185 | elif self.mode == 'exp_range':
186 | self.scale_fn = self._exp_range_scale_fn
187 | self.scale_mode = 'iterations'
188 | else:
189 | self.scale_fn = scale_fn
190 | self.scale_mode = scale_mode
191 |
192 | self.batch_step(last_batch_iteration + 1)
193 | self.last_batch_iteration = last_batch_iteration
194 |
195 | def _triangular_scale_fn(self, x):
196 | return 1.
197 |
198 | def _triangular2_scale_fn(self, x):
199 | return 1 / (2. ** (x - 1))
200 |
201 | def _exp_range_scale_fn(self, x):
202 | return self.gamma**(x)
203 |
204 | def get_lr(self):
205 | step_size = float(self.step_size)
206 | cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size))
207 | x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1)
208 |
209 | lrs = []
210 | param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs)
211 | for param_group, base_lr, max_lr in param_lrs:
212 | base_height = (max_lr - base_lr) * np.maximum(0, (1 - x))
213 | if self.scale_mode == 'cycle':
214 | lr = base_lr + base_height * self.scale_fn(cycle)
215 | else:
216 | lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration)
217 | lrs.append(lr)
218 | return lrs
219 |
220 | def batch_step(self, batch_iteration=None):
221 | if batch_iteration is None:
222 | batch_iteration = self.last_batch_iteration + 1
223 | self.last_batch_iteration = batch_iteration
224 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
225 | param_group['lr'] = lr
226 |
227 | class ReduceLROnPlateau(object):
228 | """Reduce learning rate when a metric has stopped improving.
229 | Models often benefit from reducing the learning rate by a factor
230 | of 2-10 once learning stagnates. This scheduler reads a metrics
231 | quantity and if no improvement is seen for a 'patience' number
232 | of epochs, the learning rate is reduced.
233 |
234 | Args:
235 | factor: factor by which the learning rate will
236 | be reduced. new_lr = lr * factor
237 | patience: number of epochs with no improvement
238 | after which learning rate will be reduced.
239 | verbose: int. 0: quiet, 1: update messages.
240 | mode: one of {min, max}. In `min` mode,
241 | lr will be reduced when the quantity
242 | monitored has stopped decreasing; in `max`
243 | mode it will be reduced when the quantity
244 | monitored has stopped increasing.
245 | epsilon: threshold for measuring the new optimum,
246 | to only focus on significant changes.
247 | cooldown: number of epochs to wait before resuming
248 | normal operation after lr has been reduced.
249 | min_lr: lower bound on the learning rate.
250 |
251 |
252 | Example:
253 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
254 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
255 | >>> for epoch in range(10):
256 | >>> train(...)
257 | >>> val_acc, val_loss = validate(...)
258 | >>> scheduler.epoch_step(val_loss, epoch)
259 | """
260 |
261 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
262 | verbose=0, epsilon=1e-4, cooldown=0, min_lr=0,eps=1e-8):
263 |
264 | super(ReduceLROnPlateau, self).__init__()
265 | assert isinstance(optimizer, Optimizer)
266 | if factor >= 1.0:
267 | raise ValueError('ReduceLROnPlateau '
268 | 'does not support a factor >= 1.0.')
269 | self.factor = factor
270 | self.min_lr = min_lr
271 | self.epsilon = epsilon
272 | self.patience = patience
273 | self.verbose = verbose
274 | self.cooldown = cooldown
275 | self.cooldown_counter = 0 # Cooldown counter.
276 | self.monitor_op = None
277 | self.wait = 0
278 | self.best = 0
279 | self.mode = mode
280 | self.optimizer = optimizer
281 | self.eps = eps
282 | self._reset()
283 |
284 | def _reset(self):
285 | """Resets wait counter and cooldown counter.
286 | """
287 | if self.mode not in ['min', 'max']:
288 |             raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!' % self.mode)
289 | if self.mode == 'min':
290 | self.monitor_op = lambda a, b: np.less(a, b - self.epsilon)
291 |             self.best = np.inf
292 | else:
293 | self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon)
294 |             self.best = -np.inf
295 | self.cooldown_counter = 0
296 | self.wait = 0
297 |
298 | def reset(self):
299 | self._reset()
300 |
301 | def epoch_step(self, metrics, epoch):
302 | current = metrics
303 | if current is None:
304 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning)
305 | else:
306 | if self.in_cooldown():
307 | self.cooldown_counter -= 1
308 | self.wait = 0
309 |
310 | if self.monitor_op(current, self.best):
311 | self.best = current
312 | self.wait = 0
313 | elif not self.in_cooldown():
314 | if self.wait >= self.patience:
315 | for param_group in self.optimizer.param_groups:
316 | old_lr = float(param_group['lr'])
317 | if old_lr > self.min_lr + self.eps:
318 | new_lr = old_lr * self.factor
319 | new_lr = max(new_lr, self.min_lr)
320 | param_group['lr'] = new_lr
321 | if self.verbose > 0:
322 | print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, new_lr))
323 | self.cooldown_counter = self.cooldown
324 | self.wait = 0
325 | self.wait += 1
326 |
327 | def in_cooldown(self):
328 | return self.cooldown_counter > 0
329 |
330 | class ReduceLRWDOnPlateau(ReduceLROnPlateau):
331 | """Reduce learning rate and weight decay when a metric has stopped
332 | improving. Models often benefit from reducing the learning rate by
333 | a factor of 2-10 once learning stagnates. This scheduler reads a metric
334 | quantity and if no improvement is seen for a 'patience' number
335 | of epochs, the learning rate and weight decay factor is reduced for
336 |     optimizers that implement the weight decay method from the paper
337 | `Fixing Weight Decay Regularization in Adam`_.
338 |
339 | .. _Fixing Weight Decay Regularization in Adam:
340 | https://arxiv.org/abs/1711.05101
341 | for AdamW or SGDW
342 | Example:
343 | >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3)
344 | >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min')
345 | >>> for epoch in range(10):
346 | >>> train(...)
347 | >>> val_loss = validate(...)
348 | >>> # Note that step should be called after validate()
349 | >>> scheduler.epoch_step(val_loss)
350 | """
351 | def epoch_step(self, metrics, epoch):
352 | current = metrics
353 | if current is None:
354 | warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning)
355 | else:
356 | if self.in_cooldown():
357 | self.cooldown_counter -= 1
358 | self.wait = 0
359 |
360 | if self.monitor_op(current, self.best):
361 | self.best = current
362 | self.wait = 0
363 | elif not self.in_cooldown():
364 | if self.wait >= self.patience:
365 | for param_group in self.optimizer.param_groups:
366 | old_lr = float(param_group['lr'])
367 | if old_lr > self.min_lr + self.eps:
368 | new_lr = old_lr * self.factor
369 | new_lr = max(new_lr, self.min_lr)
370 | param_group['lr'] = new_lr
371 | if self.verbose > 0:
372 | print('\nEpoch %d: reducing learning rate to %s.' % (epoch, new_lr))
373 | if param_group['weight_decay'] != 0:
374 | old_weight_decay = float(param_group['weight_decay'])
375 | new_weight_decay = max(old_weight_decay * self.factor, self.min_lr)
376 | if old_weight_decay > new_weight_decay + self.eps:
377 | param_group['weight_decay'] = new_weight_decay
378 | if self.verbose:
379 |                                     print(f'\nEpoch {epoch}: reducing weight decay to {new_weight_decay:.4e}.')
380 | self.cooldown_counter = self.cooldown
381 | self.wait = 0
382 | self.wait += 1
383 |
384 | class CosineLRWithRestarts(object):
385 | """Decays learning rate with cosine annealing, normalizes weight decay
386 | hyperparameter value, implements restarts.
387 | https://arxiv.org/abs/1711.05101
388 |
389 | Args:
390 | optimizer (Optimizer): Wrapped optimizer.
391 | batch_size: minibatch size
392 | epoch_size: training samples per epoch
393 | restart_period: epoch count in the first restart period
394 | t_mult: multiplication factor by which the next restart period will extend/shrink
395 |
396 | Example:
397 | >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2)
398 | >>> for epoch in range(100):
399 |         >>>     # lr and weight_decay are updated per batch via batch_step() below
400 | >>> train(...)
401 | >>> ...
402 | >>> optimizer.zero_grad()
403 | >>> loss.backward()
404 | >>> optimizer.step()
405 | >>> scheduler.batch_step()
406 | >>> validate(...)
407 | """
408 |
409 | def __init__(self, optimizer, batch_size, epoch_size, restart_period=100,
410 | t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False):
411 | if not isinstance(optimizer, Optimizer):
412 | raise TypeError('{} is not an Optimizer'.format(
413 | type(optimizer).__name__))
414 | self.optimizer = optimizer
415 | if last_epoch == -1:
416 | for group in optimizer.param_groups:
417 | group.setdefault('initial_lr', group['lr'])
418 | else:
419 | for i, group in enumerate(optimizer.param_groups):
420 | if 'initial_lr' not in group:
421 | raise KeyError("param 'initial_lr' is not specified "
422 | "in param_groups[{}] when resuming an"
423 | " optimizer".format(i))
424 | self.base_lrs = list(map(lambda group: group['initial_lr'],
425 | optimizer.param_groups))
426 |
427 | self.last_epoch = last_epoch
428 | self.batch_size = batch_size
429 | self.iteration = 0
430 | self.epoch_size = epoch_size
431 | self.eta_threshold = eta_threshold
432 | self.t_mult = t_mult
433 | self.verbose = verbose
434 | self.base_weight_decays = list(map(lambda group: group['weight_decay'],
435 | optimizer.param_groups))
436 | self.restart_period = restart_period
437 | self.restarts = 0
438 | self.t_epoch = -1
439 | self.batch_increments = []
440 | self._set_batch_increment()
441 |
442 | def _schedule_eta(self):
443 | """
444 | Threshold value could be adjusted to shrink eta_min and eta_max values.
445 | """
446 | eta_min = 0
447 | eta_max = 1
448 | if self.restarts <= self.eta_threshold:
449 | return eta_min, eta_max
450 | else:
451 | d = self.restarts - self.eta_threshold
452 | k = d * 0.09
453 | return (eta_min + k, eta_max - k)
454 |
455 | def get_lr(self, t_cur):
456 | eta_min, eta_max = self._schedule_eta()
457 |
458 | eta_t = (eta_min + 0.5 * (eta_max - eta_min)
459 | * (1. + math.cos(math.pi *
460 | (t_cur / self.restart_period))))
461 |
462 | weight_decay_norm_multi = math.sqrt(self.batch_size /
463 | (self.epoch_size *
464 | self.restart_period))
465 | lrs = [base_lr * eta_t for base_lr in self.base_lrs]
466 | weight_decays = [base_weight_decay * eta_t * weight_decay_norm_multi
467 | for base_weight_decay in self.base_weight_decays]
468 |
469 | if self.t_epoch % self.restart_period < self.t_epoch:
470 | if self.verbose:
471 | print("Restart at epoch {}".format(self.last_epoch))
472 | self.restart_period *= self.t_mult
473 | self.restarts += 1
474 | self.t_epoch = 0
475 |
476 | return zip(lrs, weight_decays)
477 |
478 | def _set_batch_increment(self):
479 | d, r = divmod(self.epoch_size, self.batch_size)
480 | batches_in_epoch = d + 2 if r > 0 else d + 1
481 | self.iteration = 0
482 | self.batch_increments = list(np.linspace(0, 1, batches_in_epoch))
483 |
484 | def batch_step(self):
485 | self.last_epoch += 1
486 | self.t_epoch += 1
487 | self._set_batch_increment()
488 | try:
489 | t_cur = self.t_epoch + self.batch_increments[self.iteration]
490 | self.iteration += 1
491 |         except IndexError:
492 | raise RuntimeError("Epoch size and batch size used in the "
493 | "training loop and while initializing "
494 | "scheduler should be the same.")
495 |
496 | for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups,self.get_lr(t_cur)):
497 | param_group['lr'] = lr
498 | param_group['weight_decay'] = weight_decay
499 |
500 |
501 | class NoamLR(object):
502 | '''
503 |     Noam learning-rate schedule, following the update rule from the paper "Attention Is All You Need".
504 | Example:
505 | >>> scheduler = NoamLR(d_model,factor,warm_up,optimizer)
506 | >>> for epoch in range(100):
507 | >>> scheduler.step()
508 | >>> train(...)
509 | >>> ...
510 |         >>>     global_step += 1
511 | >>> optimizer.zero_grad()
512 | >>> loss.backward()
513 | >>> optimizer.step()
514 | >>> scheduler.batch_step(global_step)
515 | >>> validate(...)
516 | '''
517 | def __init__(self,d_model,factor,warm_up,optimizer):
518 | self.optimizer = optimizer
519 | self.warm_up = warm_up
520 | self.factor = factor
521 | self.d_model = d_model
522 | self._lr = 0
523 |
524 | def get_lr(self,step):
525 | lr = self.factor * (self.d_model ** (-0.5) * min(step ** (-0.5),step * self.warm_up ** (-1.5)))
526 | return lr
527 |
528 | def batch_step(self,step):
529 | '''
530 |         Update the learning rate of every parameter group for the given global step.
531 | :return:
532 | '''
533 | lr = self.get_lr(step)
534 | for p in self.optimizer.param_groups:
535 | p['lr'] = lr
536 | self._lr = lr
537 |
--------------------------------------------------------------------------------