├── data └── similar │ └── label.txt ├── README.md ├── utils.py ├── model.py ├── .gitignore ├── main.py ├── Rotransformer.py ├── trainer.py ├── data_loader.py └── LICENSE /data/similar/label.txt: -------------------------------------------------------------------------------- 1 | 完全匹配 2 | 不匹配 3 | 部分匹配 4 | UNK -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Roformer_Simlarity 2 | 基于Roformer的文本相似度 3 | 4 | ## Bert预训练模型请自行下载,pytorch版就行 5 | ## 运行 6 | ``` 7 | python main --do_train 8 | ``` 9 | 10 | ## 参考 11 | 苏剑林:https://kexue.fm/archives/8397 12 | 13 | 论文:https://arxiv.org/abs/2104.09864 14 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # @Time:2021/6/29:45 3 | # @author: SinGaln 4 | 5 | """utils文件""" 6 | import os 7 | import torch 8 | import random 9 | import logging 10 | import numpy as np 11 | from model import BertModelOutputs 12 | from transformers import BertConfig, BertTokenizer 13 | from sklearn.metrics import precision_score, recall_score, f1_score 14 | 15 | MODEL_CLASSES = { 16 | "bert":(BertConfig, BertModelOutputs, BertTokenizer) 17 | } 18 | 19 | MODEL_PATH_MAP = { 20 | "bert":"./chinese_bert_wwm" 21 | } 22 | 23 | # 获取label(完全匹配, 部分匹配, 不匹配) 24 | def get_labels(args): 25 | return [label.strip() for label in 26 | open(os.path.join(args.data_dir, args.task, args.label_file), "r", encoding="utf-8")] 27 | 28 | # 加载tokenizer 29 | def load_tokenizer(args): 30 | return MODEL_CLASSES[args.model_type][2].from_pretrained(args.pretrained_model_path) 31 | 32 | # 设置logger 33 | def init_logger(): 34 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 35 | datefmt="%Y/%m/%d %H:%M:%S", 36 | level=logging.INFO) 37 | 38 | # 设置种子 39 | def set_seed(args): 40 | random.seed(args.seed) 41 | np.random.seed(args.seed) 42 | torch.manual_seed(args.seed) 43 | if torch.cuda.is_available(): 44 | torch.cuda.manual_seed_all(args.seed) 45 | 46 | # 计算precision_score, recall_score, f1_score 47 | def get_metrics(pred_label, true_label): 48 | assert len(pred_label) == len(true_label) 49 | return { 50 | "precision_score":precision_score(true_label, pred_label, average="macro"), 51 | "recall_score": recall_score(true_label, pred_label, average="macro"), 52 | "f1": f1_score(true_label, pred_label, average="macro") 53 | } 54 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @Time : 2021/6/1 21:39 4 | @Author : SinGaln 5 | """ 6 | from Rotransformer import RoTransformerEncoder 7 | from transformers import BertModel, BertPreTrainedModel 8 | 9 | class BertModelOutputs(BertPreTrainedModel): 10 | def __init__(self, config, args): 11 | super(BertModelOutputs, self).__init__(config) 12 | self.args = args 13 | self.BertOutput = BertModel(config=config) 14 | 15 | self.ro_transformer = RoTransformerEncoder(args) 16 | 17 | def forward(self, input_ids, token_type_ids, attention_mask): 18 | outputs = self.BertOutput(input_ids=input_ids, token_type_ids=token_type_ids, 19 | attention_mask=attention_mask) 20 | sequence_outputs = outputs[0] 21 | pooled_outputs = outputs[1] 22 | logits = self.ro_transformer(sequence_outputs) 23 | return logits 24 | 25 | # if __name__=="__main__": 26 | # parser = argparse.ArgumentParser() 27 | # """配置参数测试""" 28 | # parser.add_argument("--embedding_size", default=768, type=int,required=True, help="The hidden size of model.") 29 | # parser.add_argument("--hidden_size", default=1024, type=int,required=True, help="The hidden size of model.") 30 | # parser.add_argument("--num_attention_heads", default=12, type=int, required=True, help="The number of attention heads.") 31 | # parser.add_argument("--attention_dropout_prob", default=0.2, type=float, required=True, help="The dropout rate of attention.") 32 | # parser.add_argument("--feed_dropout_rate", default=0.1, type=float, required=True, help="The dropout rate of attention.") 33 | # 34 | # args = parser.parse_args() 35 | # model = RoTransformerEncoder(args) 36 | # a = model(torch.rand(32, 512, 768)) 37 | # print(a.shape) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # @Time:2021/6/215:14 3 | # @author: SinGaln 4 | 5 | import argparse 6 | from trainer import Trainer 7 | from data_loader import load_and_cache_examples 8 | from utils import init_logger, load_tokenizer, set_seed 9 | 10 | def main(args): 11 | init_logger() 12 | set_seed(args) 13 | tokenizer = load_tokenizer(args) 14 | 15 | train_data = load_and_cache_examples(args, tokenizer, mode="train") 16 | test_data = load_and_cache_examples(args, tokenizer, mode="test") 17 | 18 | trainer = Trainer(args, train_data, test_data) 19 | 20 | if args.do_train: 21 | trainer.train() 22 | if args.do_eval: 23 | trainer.load_model() 24 | trainer.evaluate("test") 25 | 26 | if __name__=="__main__": 27 | parser = argparse.ArgumentParser() 28 | 29 | parser.add_argument("--data_dir", default="./data", type=str, help="The path of data.") 30 | parser.add_argument("--task", default="similar", type=str, help="The name of task.") 31 | parser.add_argument("--label_file", default="label.txt",type=str, help="label file path.") 32 | parser.add_argument("--model_type", default="bert", type=str, help="Pretrained model name.") 33 | parser.add_argument("--model_dir", default="./similar_bert", type=str, help="Save path of new model.") 34 | parser.add_argument("--pretrained_model_path", default="./chinese_bert_wwm", type=str, help="Pretrained model path.") 35 | parser.add_argument("--seed", default=1234, type=int, help="The seed of random.") 36 | parser.add_argument("--max_seq_len", default=100, type=int, help="The max sequence length of data.") 37 | parser.add_argument("--ignore_index", default=0, type=int, help="Specifies a target value that is ignored and does not distribute to the input gradient.") 38 | parser.add_argument("--embedding_size", default=768, type=int, help="Embedding size of input data.") 39 | parser.add_argument("--num_attention_heads", default=12, type=int, help="The number of attention heads.") 40 | parser.add_argument("--attention_dropout_prob", default=0.1, type=float, help="The dropout rate of multi attention model.") 41 | parser.add_argument("--hidden_size", default=1024, type=int, help="The hidden size of model middle layer.") 42 | parser.add_argument("--feed_dropout_rate", default=0.1, type=int, help="The dropout rate of feed forward layer.") 43 | parser.add_argument("--max_steps", default=-1, type=int, help="Set total number of training steps to perform.") 44 | parser.add_argument("--num_train_epochs", default=5.0, type=float, help="Total number of training epochs to perform.") 45 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int, help="Number of updates steps to accumulate before performing a backwrd pass.") 46 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 47 | parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.") 48 | parser.add_argument("--adam_epsilon", default=1e-6, type=float, help="The epsilon value of Adam.") 49 | parser.add_argument("--warm_up", default=0, type=int, help="Linear warmup over warmup steps.") 50 | parser.add_argument("--train_batch_size", default=32, type=int, help="Batch size for training.") 51 | parser.add_argument("--logging_steps", default=1000, type=int, help="Log every x updates steps.") 52 | parser.add_argument("--save_steps", default=999, type=int, help="Save checkpoint every x updates steps.") 53 | parser.add_argument("--vocab_size", default=8007, type=int, help="The size for vocab.") 54 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 55 | parser.add_argument("--eval_batch_size", default=32, type=int, help="Batch size for evaluate.") 56 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 57 | parser.add_argument("--do_eval", action="store_true", help="Whether to run evaluate.") 58 | 59 | args = parser.parse_args() 60 | main(args) 61 | -------------------------------------------------------------------------------- /Rotransformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # @Time:2021/6/215:57 3 | # @author: SinGaln 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class MultiHeadsAttention(nn.Module): 11 | def __init__(self, args): 12 | super(MultiHeadsAttention, self).__init__() 13 | self.args = args 14 | 15 | if args.embedding_size % args.num_attention_heads != 0: 16 | raise ValueError( 17 | "The hidden size (%d) is not a multiple of the number of attention " 18 | "heads (%d)" % (args.embedding_size, args.num_attention_heads)) 19 | self.num_attention_heads = args.num_attention_heads 20 | self.attention_head_size = int(args.embedding_size / args.num_attention_heads) 21 | self.all_head_size = self.num_attention_heads * self.attention_head_size 22 | 23 | self.query = nn.Linear(args.embedding_size, self.all_head_size) 24 | self.key = nn.Linear(args.embedding_size, self.all_head_size) 25 | self.value = nn.Linear(args.embedding_size, self.all_head_size) 26 | 27 | self.dropout = nn.Dropout(args.attention_dropout_prob) 28 | 29 | def transpose_for_scores(self, x): 30 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 31 | x = x.view(*new_x_shape) 32 | return x.permute(0, 2, 1, 3) 33 | 34 | def sinusoidal_position_embeddings(self, inputs): 35 | output_dim = self.args.embedding_size // self.args.num_attention_heads 36 | seq_len = inputs.size(1) 37 | position_ids = torch.arange( 38 | 0, seq_len, dtype=torch.float32, device=inputs.device) 39 | 40 | indices = torch.arange( 41 | 0, output_dim // 2, dtype=torch.float32, device=inputs.device) 42 | indices = torch.pow(10000.0, -2 * indices / output_dim) 43 | embeddings = torch.einsum('n,d->nd', position_ids, indices) # [seq_len, output_dim // 2] 44 | embeddings = torch.stack([embeddings.sin(), embeddings.cos()], dim=-1) # [seq_len, output_dim // 2, 2] 45 | embeddings = torch.reshape(embeddings, (seq_len, output_dim)) # [seq_len, output_dim] 46 | embeddings = embeddings[None, None, :, :] # [1, 1, seq_len, output_dim] 47 | return embeddings 48 | 49 | def forward(self, inputs, attention_mask=None): 50 | mixed_query_layer = self.query(inputs) # [batch_size, seq_len, hidden_size] 51 | mixed_key_layer = self.key(inputs) # [batch_size, seq_len, hidden_size] 52 | mixed_value_layer = self.value(inputs) # [batch_size, seq_len, hidden_size] 53 | 54 | query_layer = self.transpose_for_scores(mixed_query_layer) # [batch_size, num_heads, seq_len, heads_size] 55 | key_layer = self.transpose_for_scores(mixed_key_layer) # [batch_size, num_heads, seq_len, heads_size] 56 | value_layer = self.transpose_for_scores(mixed_value_layer) # [batch_size, num_heads, seq_len, heads_size] 57 | 58 | sinusoidal_positions = self.sinusoidal_position_embeddings(inputs) 59 | # 计算cos 60 | cos_pos = torch.repeat_interleave(sinusoidal_positions[..., 1::2], 2, dim=-1) 61 | # 计算sin 62 | sin_pos = torch.repeat_interleave(sinusoidal_positions[..., ::2], 2, dim=-1) 63 | ''' 64 | query_layer[..., 1::2]为按列取最后一维的偶数列 shape:[batch_size, num_heads, seq_len, head_dim / 2] 65 | query_layer[..., ::2]为按列取的最后一维的奇数列 shape:[batch_size, num_heads, seq_len, head_dim / 2] 66 | 67 | 通过stack拼接后得到的为增加了一维,如下例所示: 68 | a = [[[1, 2, 3], 69 | [4, 5, 6], 70 | [7, 8, 9]]] 71 | 72 | b = [[[10, 20, 30], 73 | [40, 50, 60], 74 | [70, 80, 90]]] 75 | 76 | c = torch.stack(a,b,dim=0) 77 | tensor([[[[ 1, 2, 3], 78 | [ 4, 5, 6], 79 | [ 7, 8, 9]]], 80 | 81 | [[[10, 20, 30], 82 | [40, 50, 60], 83 | [70, 80, 90]]]]) torch.Size([2, 1, 3, 3]) 84 | d = torch.stack(a,b,dim=1) 85 | tensor([[[[ 1, 2, 3], 86 | [ 4, 5, 6], 87 | [ 7, 8, 9]], 88 | 89 | [[10, 20, 30], 90 | [40, 50, 60], 91 | [70, 80, 90]]]]) torch.Size([1, 2, 3, 3]) 92 | e = torch.stack(a,b,dim=-1) 93 | tensor([[[[ 1, 10], 94 | [ 2, 20], 95 | [ 3, 30]], 96 | 97 | [[ 4, 40], 98 | [ 5, 50], 99 | [ 6, 60]], 100 | 101 | [[ 7, 70], 102 | [ 8, 80], 103 | [ 9, 90]]]]) torch.Size([1, 3, 3, 2]) 104 | 通过以上例子就可以知道,这两个矩阵拼接后的维度增加了一维,并且是两个矩阵最后一维的元素进行拼接,如上述的e一样, 105 | 所以torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]],dim=-1) shape:[batch_size, num_heads, seq_len, head_size/2, 2] 106 | 最后通过reshape把最后的两维进行合并得到qw2,kw2 shape:[batch_size, num_heads,seq_len, head_dim] 107 | ''' 108 | qw2 = torch.stack([-query_layer[..., 1::2], query_layer[..., ::2]], 109 | dim=-1).reshape_as(query_layer) # [batch_size, num_heads,seq_len, head_dim] 110 | query_layer = query_layer * cos_pos + qw2 * sin_pos # [batch_size, num_heads, seq_len, head_dim] 111 | kw2 = torch.stack([-key_layer[..., 1::2], key_layer[..., ::2]], 112 | dim=-1).reshape_as(key_layer) # [batch_size, num_heads,seq_len, head_dim] 113 | key_layer = key_layer * cos_pos + kw2 * sin_pos # [batch_size, num_heads, seq_len, head_dim] 114 | 115 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [batch_size, num_heads, seq_len, seq_len] 116 | attention_scores = attention_scores / math.sqrt(self.all_head_size) 117 | 118 | if attention_mask is not None: 119 | attention_scores = attention_scores + attention_mask 120 | 121 | # 对attention scores 按列进行归一化 122 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 123 | # dropout 124 | attention_probs = self.dropout(attention_probs) 125 | context_layer = torch.matmul(attention_probs, value_layer) # [batch_size, num_heads, seq_len, head_dim] 126 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [batch_size, seq_len, num_heads, head_dim] 127 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 128 | context_layer = context_layer.view(*new_context_layer_shape) # [batch_size, seq_len, embedding_size] 129 | return context_layer, attention_scores 130 | 131 | class Position_Wise_Feed_Forward(nn.Module): 132 | def __init__(self, args): 133 | super(Position_Wise_Feed_Forward, self).__init__() 134 | self.args = args 135 | 136 | self.linear1 = nn.Linear(args.embedding_size, args.hidden_size) 137 | self.linear2 = nn.Linear(args.hidden_size, args.embedding_size) 138 | self.dropout = nn.Dropout(args.feed_dropout_rate) 139 | self.layer_norm = nn.LayerNorm(args.embedding_size) 140 | 141 | def forward(self, x): 142 | # x:[batch_size, seq_len, embedding_size] 143 | outputs = self.dropout(self.linear2(nn.functional.relu(self.linear1(x)))) 144 | outputs = outputs + x # 残差连接 145 | outputs = self.layer_norm(outputs) 146 | return outputs 147 | 148 | class Pooler(nn.Module): 149 | def __init__(self, args): 150 | super().__init__() 151 | self.dense = nn.Linear(args.embedding_size, args.embedding_size) 152 | self.activation = nn.Tanh() 153 | 154 | def forward(self, hidden_states): 155 | # We "pool" the model by simply taking the hidden state corresponding 156 | # to the first token. 157 | first_token_tensor = hidden_states[:, 0] 158 | pooled_output = self.dense(first_token_tensor) 159 | pooled_output = self.activation(pooled_output) 160 | return pooled_output 161 | 162 | class RoTransformerEncoder(nn.Module): 163 | def __init__(self, args): 164 | super(RoTransformerEncoder, self).__init__() 165 | self.args = args 166 | 167 | self.multi_attention = MultiHeadsAttention(args) 168 | self.feed_forward = Position_Wise_Feed_Forward(args) 169 | self.pooler = Pooler(args) 170 | self.dense = nn.Linear(args.embedding_size, 4) 171 | 172 | 173 | def forward(self, x): 174 | context, attention_score = self.multi_attention(x) 175 | outputs = self.feed_forward(context) 176 | outputs = self.pooler(outputs) 177 | print("outputs", outputs, outputs.shape) 178 | logits = self.dense(outputs) 179 | print("logits", logits, logits.shape) 180 | return logits -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # @Time:2021/6/213:42 3 | # @author: SinGaln 4 | 5 | """训练文件""" 6 | import os 7 | import torch 8 | import logging 9 | import numpy as np 10 | from tqdm import tqdm, trange 11 | from utils import MODEL_CLASSES, get_metrics, get_labels 12 | from transformers import AdamW, get_linear_schedule_with_warmup 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class Trainer(object): 19 | def __init__(self, args, train_data=None, test_data=None): 20 | self.args = args 21 | self.train_data = train_data 22 | self.test_data = test_data 23 | self.label_lst = get_labels(args) 24 | 25 | self.pad_token_label_id = args.ignore_index 26 | # 模型初始化 27 | self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type] 28 | self.configs = self.config_class.from_pretrained(args.pretrained_model_path, 29 | finetuning_task=self.args.task) 30 | self.model = self.model_class.from_pretrained(args.pretrained_model_path, 31 | args=args, 32 | config=self.configs) 33 | 34 | # 设别选择(GPU or CPU) 35 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 36 | self.model.to(self.device) 37 | # 多GPU 38 | if torch.cuda.device_count() > 1: 39 | self.model = torch.nn.DataParallel(self.model) 40 | 41 | def train(self): 42 | train_sampler = RandomSampler(self.train_data) 43 | train_loader = DataLoader(self.train_data, sampler=train_sampler, batch_size=self.args.train_batch_size) 44 | 45 | if self.args.max_steps > 0: 46 | total_steps = self.args.max_steps 47 | self.args.num_train_epochs = self.args.max_steps // ( 48 | len(train_loader) // self.args.gradient_accumulation_steps) + 1 49 | else: 50 | total_steps = len(train_loader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs 51 | 52 | # optimizer and schedule 53 | no_decay = ["bias", "LayerNorm.weight"] 54 | # optimizer parameters setting 55 | optimizer_grouped_parameters = [ 56 | {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 57 | "weight_decay": self.args.weight_decay}, 58 | {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 59 | "weight_decay": 0.0} 60 | ] 61 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon) 62 | schedule = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warm_up, 63 | num_training_steps=total_steps) 64 | 65 | # train information 66 | logger.info("********** Running Training **********") 67 | logger.info("num example = %d", len(self.train_data)) 68 | logger.info("num epochs = %d", self.args.num_train_epochs) 69 | logger.info("train batch size = %d", self.args.train_batch_size) 70 | logger.info("gradient accumulation steps = %d", self.args.gradient_accumulation_steps) 71 | logger.info("total steps = %d", total_steps) 72 | logger.info("logger steps = %d", self.args.logging_steps) 73 | logger.info("save steps = %d", self.args.save_steps) 74 | 75 | global_steps = 0 76 | tr_loss = 0.0 77 | loss_fun = torch.nn.CrossEntropyLoss() 78 | self.model.zero_grad() 79 | 80 | train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch") 81 | for _ in train_iterator: 82 | epoch_iterator = tqdm(train_loader) 83 | for step, batch in enumerate(epoch_iterator): 84 | self.model.train() 85 | batch = tuple(t.to(self.device) for t in batch) 86 | input_ids, attention_mask, token_type_ids, label_id = batch 87 | # print(input_ids) 88 | # print(attention_mask) 89 | # print(token_type_ids) 90 | # print(label_id) 91 | output = self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) 92 | # print(output[-1]) 93 | # print(label_id.view(-1)) 94 | loss = loss_fun(output.view(-1, len(self.label_lst)), label_id) 95 | print() 96 | print("loss",loss) 97 | if torch.cuda.device_count() > 1: 98 | loss = loss.mean() 99 | if self.args.gradient_accumulation_steps > 1: 100 | loss = loss / self.args.gradient_accumulation_steps 101 | else: 102 | loss.backward() 103 | tr_loss += loss.item() 104 | 105 | if (step + 1) % self.args.gradient_accumulation_steps == 0: 106 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 107 | optimizer.step() 108 | schedule.step() 109 | optimizer.zero_grad() 110 | global_steps += 1 111 | 112 | if self.args.logging_steps > 0 and global_steps % self.args.logging_steps == 0: 113 | self.evaluate("test") 114 | if self.args.save_steps > 0 and global_steps % self.args.save_steps == 0: 115 | self.save_model() 116 | 117 | if 0 < self.args.max_steps < global_steps: 118 | epoch_iterator.close() 119 | break 120 | return global_steps, tr_loss / global_steps 121 | 122 | def evaluate(self, mode): 123 | if mode == "test": 124 | dataset = self.test_data 125 | else: 126 | raise Exception("The dataset is not existing!") 127 | 128 | eval_sampler = SequentialSampler(dataset) 129 | eval_loader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size) 130 | 131 | # eval logging 132 | logger.info("********** logger information **********") 133 | logger.info("num example = %d", len(dataset)) 134 | logger.info("batch size = %d", self.args.eval_batch_size) 135 | eval_loss = 0.0 136 | eval_steps = 0 137 | loss_fun = torch.nn.CrossEntropyLoss() 138 | label_preds = None 139 | label_ids = None 140 | 141 | self.model.eval() 142 | for batch in tqdm(eval_loader): 143 | batch = tuple(t.to(self.device) for t in batch) 144 | with torch.no_grad(): 145 | input_ids, attention_mask, token_type_ids, label_id = batch 146 | outputs = self.model(input_ids, attention_mask, token_type_ids) 147 | 148 | eval_loss = loss_fun(outputs.view(-1, len(self.label_lst)), label_id) 149 | eval_loss + eval_loss.mean().item() 150 | eval_steps += 1 151 | 152 | # predict 153 | if label_preds is None: 154 | label_preds = outputs.detach().cpu().numpy() 155 | label_ids = label_id.detach().cpu().numpy() 156 | # else: 157 | # label_preds = np.append(label_preds, outputs.detach().cpu().numpy(), axis=0) 158 | # label_ids = np.append(label_ids, label_id.detach().cpu().numpy(), axis=0) 159 | eval_loss = eval_loss / eval_steps 160 | results = { 161 | "loss": eval_loss 162 | } 163 | 164 | label_preds = np.argmax(label_preds, axis=1) 165 | print("label_preds", label_preds, len(label_preds)) 166 | print("label_ids", label_ids, len(label_ids)) 167 | total_result = get_metrics(label_preds, label_ids) 168 | results.update(total_result) 169 | 170 | logger.info("********** Evaluate Results **********") 171 | for key in sorted(results.keys()): 172 | logger.info("%s = %s", key, str(results[key])) 173 | return results 174 | 175 | def save_model(self): 176 | # 模型保存 177 | if not os.path.exists(self.args.model_dir): 178 | os.makedirs(self.args.model_dir) 179 | model_to_save = self.model.module if hasattr(self.model, "module") else self.model 180 | model_to_save.save_pretrained(self.args.model_dir) 181 | 182 | # 保存模型训练的超参 183 | torch.save(self.args, os.path.join(self.args.model_dir, "train_args.bin")) 184 | logger.info("model parameters save %s", self.args.model_dir) 185 | 186 | def load_model(self): 187 | # 加载模型 188 | if os.path.exists(self.args.model_dir): 189 | raise Exception("The model is not existing!") 190 | try: 191 | self.model = self.model_class.from_pretrained(self.args.model_dir, 192 | args=self.args) 193 | self.model.to(self.device) 194 | logger.info("********** model load success **********") 195 | except: 196 | raise Exception("The model lost or damage!") 197 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # @Time:2021/6/210:11 3 | # @author: SinGaln 4 | 5 | """处理加载数据""" 6 | import os 7 | import ast 8 | import copy 9 | import json 10 | import torch 11 | import logging 12 | from utils import get_labels 13 | from torch.utils.data import TensorDataset 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class InputExample(object): 19 | """ 20 | 输入的为两个句子,利用[SEP]进行分隔 21 | Args: 22 | guid: 实例的唯一id 23 | text_a: 句对中的第一个句子 24 | text_b: 句对中的第二个句子 25 | label: 实例对象的标签 26 | """ 27 | 28 | def __init__(self, guid, text_a, text_b, label): 29 | self.guid = guid 30 | self.text_a = text_a 31 | self.text_b = text_b 32 | self.label = label 33 | 34 | def __repr__(self): 35 | return str(self.to_json_string()) 36 | 37 | def to_dict(self): 38 | output = copy.deepcopy(self.__dict__) 39 | return output 40 | 41 | def to_json_string(self): 42 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 43 | 44 | 45 | class InputFeatures(object): 46 | def __init__(self, input_ids, attention_mask, token_type_ids, labels, input_length): 47 | self.input_ids = input_ids 48 | self.attention_mask = attention_mask 49 | self.token_type_ids = token_type_ids 50 | self.input_length = input_length 51 | self.labels = labels 52 | 53 | def __repr__(self): 54 | return str(self.to_json_string()) 55 | 56 | def to_dict(self): 57 | output = copy.deepcopy(self.__dict__) 58 | return output 59 | 60 | def to_json_string(self): 61 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 62 | 63 | 64 | class DataProcessor(object): 65 | """对输入的数据进行处理""" 66 | 67 | def __init__(self, args): 68 | self.args = args 69 | self.labels = get_labels(args) 70 | 71 | self.input_text_file = "data.txt" 72 | 73 | @classmethod 74 | def _read_file(cls, input_file): 75 | return_list = [] 76 | with open(input_file, "r", encoding="utf-8") as f: 77 | data_lst = f.readlines() 78 | for sentences in data_lst: 79 | sentences_dict = ast.literal_eval(sentences) 80 | query_text = sentences_dict.get("query", False) 81 | condition_lst = sentences_dict.get("candidate", False) 82 | for condition_dic in condition_lst: 83 | condition_text = condition_dic.get("text", False) 84 | condition_label = condition_dic.get("label", False) 85 | return_list.append((query_text, condition_text, condition_label)) 86 | return return_list 87 | 88 | def _create_examples(self, texts, set_type): 89 | """创建训练, 测试数据的实例""" 90 | examples = [] 91 | for i, data in enumerate(texts): 92 | guid = "%s-%s" % (set_type, i) 93 | # query的文本内容 94 | words = [word for word in data[0]] 95 | # condition_text的文本内容 96 | condition_text = [cond for cond in data[1]] 97 | # 标签转为id 98 | labels = self.labels.index(data[2]) if data[2] in self.labels else self.labels.index("UNK") 99 | examples.append(InputExample(guid=guid, text_a=words, text_b=condition_text, label=labels)) 100 | return examples 101 | 102 | def get_examples(self, mode): 103 | """ 104 | 更加mode返回相应的训练,测试数据 105 | Args: 106 | mode:train, test 107 | """ 108 | data_path = os.path.join(self.args.data_dir, self.args.task, mode) 109 | logger.info("数据加载 {}".format(data_path)) 110 | return self._create_examples(texts=self._read_file(os.path.join(data_path, self.input_text_file)), 111 | set_type=mode) 112 | 113 | 114 | processors = { 115 | "similar": DataProcessor 116 | } 117 | 118 | def concat_seq_pair(tokens_a, tokens_b, max_seq_len): 119 | while True: 120 | total_length = len(tokens_a) + len(tokens_b) 121 | if total_length < max_seq_len: 122 | break 123 | if len(tokens_a) > len(tokens_b): 124 | tokens_a.pop() 125 | else: 126 | tokens_b.pop() 127 | 128 | def convert_examples_to_features(examples, max_seq_len, tokenizer, 129 | cls_token_segment_id=0, 130 | sequence_a_segment_id=0, 131 | sequence_b_segment_id=1, 132 | sep_token_segment_id=1): 133 | # 基本设置 134 | cls_token = tokenizer.cls_token 135 | sep_token = tokenizer.sep_token 136 | unk_token = tokenizer.unk_token 137 | pad_token_id = tokenizer.pad_token_id 138 | 139 | features = [] 140 | for (example_index, example) in enumerate(examples): 141 | if example_index % 1000 == 0: 142 | logger.info("Writing example %d of %d" % (example_index, len(examples))) 143 | text_a = example.text_a 144 | text_b = example.text_b 145 | labels = example.label 146 | 147 | tokens_a = [] 148 | tokens_b = [] 149 | for word in text_a: 150 | word_tokens = tokenizer.tokenize(word) 151 | if not word_tokens: 152 | word_tokens = [unk_token] 153 | tokens_a.extend(word_tokens) 154 | 155 | for word in text_b: 156 | word_tokens = tokenizer.tokenize(word) 157 | if not word_tokens: 158 | word_tokens = [unk_token] 159 | tokens_b.extend(word_tokens) 160 | 161 | concat_seq_pair(tokens_a, tokens_b, max_seq_len) 162 | tokens = [] 163 | token_type_ids = [] 164 | tokens.append(cls_token) 165 | token_type_ids.append(cls_token_segment_id) 166 | for token in tokens_a: 167 | tokens.append(token) 168 | token_type_ids.append(sequence_a_segment_id) 169 | tokens.append(sep_token) 170 | token_type_ids.append(cls_token_segment_id) 171 | 172 | for token in tokens_b: 173 | tokens.append(token) 174 | token_type_ids.append(sequence_b_segment_id) 175 | tokens.append(sep_token) 176 | token_type_ids.append(sep_token_segment_id) 177 | 178 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 179 | 180 | attention_mask = [1] * len(input_ids) 181 | 182 | # padding 183 | while len(input_ids) < max_seq_len: 184 | input_ids.append(pad_token_id) 185 | attention_mask.append(pad_token_id) 186 | token_type_ids.append(pad_token_id) 187 | # # 先拼接再padding 188 | # tokens = [cls_token] + tokens_a + [sep_token] + tokens_b + [sep_token] 189 | # input_ids = tokenizer.convert_tokens_to_ids(tokens) 190 | # if len(input_ids) > max_seq_len: 191 | # input_ids = input_ids[:(max_seq_len - len(input_ids))] 192 | # else: 193 | # input_ids = input_ids + (max_seq_len - len(input_ids)) * [pad_token_id] 194 | # # print(input_ids, len(input_ids)) 195 | # token_type_ids = [cls_token_segment_id] + (len(tokens_a) * [sequence_a_segment_id]) + [cls_token_segment_id] + \ 196 | # (max_seq_len - len([cls_token_segment_id] + (len(tokens_a) * [sequence_a_segment_id]) + [cls_token_segment_id])) * [sep_token_segment_id] 197 | # # print(token_type_ids, len(token_type_ids)) 198 | # attention_mask = [] 199 | # for i in input_ids: 200 | # if i != 0: 201 | # attention_mask.append(1) 202 | # else: 203 | # attention_mask.append(0) 204 | # print(attention_mask, len(attention_mask)) 205 | # special_token_count = 2 206 | # if len(tokens_a) > max_seq_len - special_token_count: 207 | # tokens_a = tokens_a[:(max_seq_len - special_token_count)] 208 | # 209 | # # 增加[SEP]和[CLS] 210 | # tokens_a += [sep_token] 211 | # token_a_type_ids = [sequence_a_segment_id] * len(tokens_a) 212 | # tokens_b += [sep_token] 213 | # token_b_type_ids = [sequence_b_segment_id] * len(tokens_b) 214 | # # 增加[CLS] 215 | # tokens_a = [cls_token] + tokens_a 216 | # token_a_type_ids = [cls_token_segment_id] + token_a_type_ids 217 | # # 拼接token_a和token_b 218 | # tokens = tokens_a + tokens_b 219 | # token_type_ids = token_a_type_ids + token_b_type_ids 220 | # 221 | # input_ids = tokenizer.convert_tokens_to_ids(tokens) 222 | # attention mask 用1表示真实token, 0表示padding token 223 | 224 | assert len(input_ids) == max_seq_len, "Error with input length {} vs {}".format(len(input_ids), max_seq_len) 225 | assert len(attention_mask) == max_seq_len, "Error with attention mask {} vs {}".format(len(attention_mask), 226 | max_seq_len) 227 | assert len(token_type_ids) == max_seq_len, "Error with token type {} vs {}".format(len(token_type_ids), 228 | max_seq_len) 229 | label_id = int(labels) 230 | 231 | if example_index < 5: 232 | logger.info("*** Example ***") 233 | logger.info("guid: %s" % example.guid) 234 | logger.info("tokens: % s" % " ".join([str(x) for x in tokens])) 235 | logger.info("inputs_ids: % s" % " ".join([str(x) for x in input_ids])) 236 | logger.info("token_type_ids: % s" % " ".join([str(x) for x in token_type_ids])) 237 | logger.info("attention_mask: % s" % " ".join([str(x) for x in attention_mask])) 238 | logger.info("labels: % s (id = %d)" % (example.label, label_id)) 239 | 240 | features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, 241 | labels=label_id, input_length=len(input_ids))) 242 | 243 | return features 244 | 245 | 246 | def load_and_cache_examples(args, tokenizer, mode): 247 | processor = processors[args.task](args) 248 | 249 | cached_features_file = os.path.join( 250 | args.data_dir, 251 | "{}_{}_{}_{}".format( 252 | mode, 253 | args.task, 254 | list(filter(None, args.pretrained_model_path.split("/"))).pop(), 255 | args.max_seq_len 256 | ) 257 | ) 258 | 259 | if os.path.exists(cached_features_file): 260 | logger.info("Loading features from cached file %s", cached_features_file) 261 | features = torch.load(cached_features_file) 262 | else: 263 | logger.info("Creating features from dataset file at %s", args.data_dir) 264 | if mode == "train": 265 | examples = processor.get_examples("train") 266 | elif mode == "test": 267 | examples = processor.get_examples("test") 268 | else: 269 | raise Exception("The mode only include train, test!") 270 | 271 | pad_token_label_id = args.ignore_index 272 | features = convert_examples_to_features(examples, args.max_seq_len, tokenizer, pad_token_label_id) 273 | 274 | logger.info("Save features into cache file %s", cached_features_file) 275 | torch.save(features, cached_features_file) 276 | 277 | input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 278 | attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 279 | token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 280 | label_ids = torch.tensor([f.labels for f in features], dtype=torch.long) 281 | 282 | dataset = TensorDataset(input_ids, attention_mask, token_type_ids, label_ids) 283 | return dataset 284 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------