├── data └── dgre │ └── raw_data │ └── process.py ├── evaluate.py ├── generic.py ├── README.md ├── doccano.py ├── export_model.py ├── doccano.md ├── model.py ├── finetune.py ├── convert.py ├── uie_predictor.py └── utils.py /data/dgre/raw_data/process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import codecs 3 | 4 | data_path = "train.json" 5 | 6 | with codecs.open(data_path, 'r', encoding="utf-8") as fp: 7 | data = fp.readlines() 8 | output_file = open("../mid_data/doccano_train.json", "w", encoding="utf-8") 9 | 10 | for did, d in enumerate(data): 11 | d = eval(d) 12 | tmp = {} 13 | tmp["id"] = d['ID'] 14 | tmp['text'] = d['text'] 15 | tmp['relations'] = [] 16 | tmp['entities'] = [] 17 | ent_id = 0 18 | for rel_id,spo in enumerate(d['spo_list']): 19 | rel_tmp = {} 20 | ent_tmp = {} 21 | rel_tmp['id'] = rel_id 22 | ent_tmp['id'] = ent_id 23 | h = spo['h'] 24 | ent_tmp['start_offset'] = h['pos'][0] 25 | ent_tmp['end_offset'] = h['pos'][1] 26 | ent_tmp['label'] = "主体" 27 | if ent_tmp not in tmp['entities']: 28 | tmp['entities'].append(ent_tmp) 29 | from_id = ent_id 30 | ent_id += 1 31 | else: 32 | ind = tmp['entities'].index(ent_tmp) 33 | from_id = tmp['entities'][ind]['id'] 34 | ent_id = len(tmp['entities']) + 1 35 | 36 | t = spo['t'] 37 | ent_tmp = {} 38 | ent_tmp['id'] = ent_id 39 | ent_tmp['start_offset'] = t['pos'][0] 40 | ent_tmp['end_offset'] = t['pos'][1] 41 | ent_tmp['label'] = "客体" 42 | if ent_tmp not in tmp['entities']: 43 | tmp['entities'].append(ent_tmp) 44 | to_id = ent_id 45 | ent_id += 1 46 | else: 47 | ind = tmp['entities'].index(ent_tmp) 48 | to_id = tmp['entities'][ind]['id'] 49 | ent_id = len(tmp['entities']) + 1 50 | 51 | rel_tmp['from_id'] = from_id 52 | rel_tmp['to_id'] = to_id 53 | rel_tmp['type'] = spo['relation'] 54 | 55 | tmp['relations'].append(rel_tmp) 56 | output_file.write(json.dumps(tmp, ensure_ascii=False) + "\n") -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from model import UIE 16 | import argparse 17 | import torch 18 | from utils import SpanEvaluator, IEDataset, logger, tqdm 19 | from transformers import BertTokenizerFast 20 | from torch.utils.data import DataLoader 21 | 22 | 23 | @torch.no_grad() 24 | def evaluate(model, metric, data_loader, device='gpu', loss_fn=None, show_bar=True): 25 | """ 26 | Given a dataset, it evaluates the model and computes the metric. 27 | Args: 28 | model(obj:`torch.nn.Module`): A model to classify texts. 29 | metric(obj:`Metric`): The evaluation metric. 30 | data_loader(obj:`torch.utils.data.DataLoader`): The dataset loader which generates batches.
31 | """ 32 | return_loss = False 33 | if loss_fn is not None: 34 | return_loss = True 35 | model.eval() 36 | metric.reset() 37 | loss_list = [] 38 | loss_sum = 0 39 | loss_num = 0 40 | if show_bar: 41 | data_loader = tqdm( 42 | data_loader, desc="Evaluating", unit='batch') 43 | for batch in data_loader: 44 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch 45 | if device == 'gpu': 46 | input_ids = input_ids.cuda() 47 | token_type_ids = token_type_ids.cuda() 48 | att_mask = att_mask.cuda() 49 | outputs = model(input_ids=input_ids, 50 | token_type_ids=token_type_ids, 51 | attention_mask=att_mask) 52 | start_prob, end_prob = outputs[0], outputs[1] 53 | if device == 'gpu': 54 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu() 55 | start_ids = start_ids.type(torch.float32) 56 | end_ids = end_ids.type(torch.float32) 57 | 58 | if return_loss: 59 | # Calculate loss 60 | loss_start = loss_fn(start_prob, start_ids) 61 | loss_end = loss_fn(end_prob, end_ids) 62 | loss = (loss_start + loss_end) / 2.0 63 | loss = float(loss) 64 | loss_list.append(loss) 65 | loss_sum += loss 66 | loss_num += 1 67 | if show_bar: 68 | data_loader.set_postfix( 69 | { 70 | 'dev loss': f'{loss_sum / loss_num:.5f}' 71 | } 72 | ) 73 | 74 | # Calcalate metric 75 | num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, 76 | start_ids, end_ids) 77 | metric.update(num_correct, num_infer, num_label) 78 | precision, recall, f1 = metric.accumulate() 79 | model.train() 80 | if return_loss: 81 | loss_avg = sum(loss_list) / len(loss_list) 82 | return loss_avg, precision, recall, f1 83 | else: 84 | return precision, recall, f1 85 | 86 | 87 | def do_eval(): 88 | tokenizer = BertTokenizerFast.from_pretrained(args.model_path) 89 | model = UIE.from_pretrained(args.model_path) 90 | if args.device == 'gpu': 91 | model = model.cuda() 92 | 93 | test_ds = IEDataset(args.test_path, tokenizer=tokenizer, 94 | max_seq_len=args.max_seq_len) 95 | 96 | test_data_loader = DataLoader( 97 | test_ds, batch_size=args.batch_size, shuffle=False) 98 | metric = SpanEvaluator() 99 | precision, recall, f1 = evaluate( 100 | model, metric, test_data_loader, args.device) 101 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f" % 102 | (precision, recall, f1)) 103 | 104 | 105 | if __name__ == "__main__": 106 | # yapf: disable 107 | parser = argparse.ArgumentParser() 108 | 109 | parser.add_argument("-m", "--model_path", type=str, required=True, 110 | help="The path of saved model that you want to load.") 111 | parser.add_argument("-t", "--test_path", type=str, required=True, 112 | help="The path of test set.") 113 | parser.add_argument("-b", "--batch_size", type=int, default=16, 114 | help="Batch size per GPU/CPU for training.") 115 | parser.add_argument("--max_seq_len", type=int, default=512, 116 | help="The maximum total input sequence length after tokenization.") 117 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu", 118 | help="Select which device to run model, defaults to gpu.") 119 | 120 | args = parser.parse_args() 121 | 122 | do_eval() 123 | -------------------------------------------------------------------------------- /generic.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections import OrderedDict, UserDict 3 | from collections.abc import MutableMapping 4 | from contextlib import ExitStack 5 | from dataclasses import fields 6 | from enum import Enum 7 | from typing import Any, ContextManager, List, Tuple 8 
| 9 | import numpy as np 10 | 11 | 12 | def is_tensor(x): 13 | """ 14 | Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`. 15 | """ 16 | if is_torch_fx_proxy(x): 17 | return True 18 | if is_torch_available(): 19 | import torch 20 | 21 | if isinstance(x, torch.Tensor): 22 | return True 23 | if is_tf_available(): 24 | import tensorflow as tf 25 | 26 | if isinstance(x, tf.Tensor): 27 | return True 28 | 29 | if is_flax_available(): 30 | import jax.numpy as jnp 31 | from jax.core import Tracer 32 | 33 | if isinstance(x, (jnp.ndarray, Tracer)): 34 | return True 35 | 36 | return isinstance(x, np.ndarray) 37 | 38 | class ModelOutput(OrderedDict): 39 | """ 40 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a 41 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular 42 | python dictionary. 43 | 44 | You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple 45 | before. 46 | 47 | """ 48 | 49 | def __post_init__(self): 50 | class_fields = fields(self) 51 | 52 | # Safety and consistency checks 53 | if not len(class_fields): 54 | raise ValueError(f"{self.__class__.__name__} has no fields.") 55 | if not all(field.default is None for field in class_fields[1:]): 56 | raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") 57 | 58 | first_field = getattr(self, class_fields[0].name) 59 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) 60 | 61 | if other_fields_are_none and not is_tensor(first_field): 62 | if isinstance(first_field, dict): 63 | iterator = first_field.items() 64 | first_field_iterator = True 65 | else: 66 | try: 67 | iterator = iter(first_field) 68 | first_field_iterator = True 69 | except TypeError: 70 | first_field_iterator = False 71 | 72 | # if we provided an iterator as first field and the iterator is a (key, value) iterator 73 | # set the associated fields 74 | if first_field_iterator: 75 | for element in iterator: 76 | if ( 77 | not isinstance(element, (list, tuple)) 78 | or not len(element) == 2 79 | or not isinstance(element[0], str) 80 | ): 81 | break 82 | setattr(self, element[0], element[1]) 83 | if element[1] is not None: 84 | self[element[0]] = element[1] 85 | elif first_field is not None: 86 | self[class_fields[0].name] = first_field 87 | else: 88 | for field in class_fields: 89 | v = getattr(self, field.name) 90 | if v is not None: 91 | self[field.name] = v 92 | 93 | def __delitem__(self, *args, **kwargs): 94 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") 95 | 96 | def setdefault(self, *args, **kwargs): 97 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") 98 | 99 | def pop(self, *args, **kwargs): 100 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") 101 | 102 | def update(self, *args, **kwargs): 103 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") 104 | 105 | def __getitem__(self, k): 106 | if isinstance(k, str): 107 | inner_dict = {k: v for (k, v) in self.items()} 108 | return inner_dict[k] 109 | else: 110 | return self.to_tuple()[k] 111 | 112 | def __setattr__(self, name, value): 113 | if name in self.keys() and value is not None: 114 | # Don't call self.__setitem__ to avoid recursion errors 115 | 
super().__setitem__(name, value) 116 | super().__setattr__(name, value) 117 | 118 | def __setitem__(self, key, value): 119 | # Will raise a KeyException if needed 120 | super().__setitem__(key, value) 121 | # Don't call self.__setattr__ to avoid recursion errors 122 | super().__setattr__(key, value) 123 | 124 | def to_tuple(self) -> Tuple[Any]: 125 | """ 126 | Convert self to a tuple containing all the attributes/keys that are not `None`. 127 | """ 128 | return tuple(self[k] for k in self.keys()) 129 | 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_uie_re 2 | 基于pytorch的百度UIE关系抽取,源代码来源:[here](https://github.com/heiheiyoyo/uie_pytorch)。 3 | 4 | 百度UIE通用信息抽取的样例一般都是使用doccano标注数据,这里介绍如何使用通用数据集,利用UIE进行微调。 5 | 6 | # 依赖 7 | 8 | ``` 9 | torch>=1.7.0 10 | transformers==4.20.0 11 | colorlog 12 | colorama 13 | ``` 14 | 15 | # 步骤 16 | 17 | - 1、数据放在data下面,比如已经放置的dgre([工业知识图谱关系抽取-高端装备制造知识图谱自动化构建 竞赛 - DataFountain](https://www.datafountain.cn/competitions/584)),raw_data下面是原始的数据,新建一个process.py,将数据处理成类似mid_data下面的数据,运行`python process.py`即可得到如下格式的数据: 18 | 19 | ```json 20 | {"id": "AT0004", "text": "617号汽车故障报告故障现象一辆吉利车装用MR479发动机,行驶里程为23709公里,驾驶员反映该车在行驶中无异响,但在起步和换挡过程中车身有抖动现象,并且听到离合器内部有异响。", "relations": [{"id": 0, "from_id": 0, "to_id": 1, "type": "部件故障"}, {"id": 1, "from_id": 2, "to_id": 3, "type": "部件故障"}], "entities": [{"id": 0, "start_offset": 80, "end_offset": 83, "label": "主体"}, {"id": 1, "start_offset": 86, "end_offset": 88, "label": "客体"}, {"id": 2, "start_offset": 68, "end_offset": 70, "label": "主体"}, {"id": 3, "start_offset": 71, "end_offset": 73, "label": "客体"}]} 21 | ``` 22 | 23 | - 2、将mid_data下面的数据使用doccano.py转换成final_data下的数据,具体指令是: 24 | 25 | ```shell 26 | python doccano.py \ 27 | --doccano_file ./data/dgre/mid_data/doccano_train.json \ 28 | --task_type "ext" \ # ext表示抽取任务 29 | --splits 0.9 0.1 0.0 \ # 训练、验证、测试数据的比例;若只需要训练集、不对数据进行切分,可将第一位设置为1.0 30 | --save_dir ./data/dgre/final_data/ \ 31 | --negative_ratio 1 # 生成负样本的比率 32 | ``` 33 | 34 | 最终会在final_data下生成train.txt、dev.txt。 35 | 36 | - 3、将paddle版本的模型转换为pytorch版的模型: 37 | 38 | ```shell 39 | python convert.py --input_model=uie-base --output_model=uie_base_pytorch --no_validate_output 40 | ``` 41 | 42 | 其中input_model可选的模型可参考convert.py里面的MODEL_MAP。output_model是我们要保存的模型路径,下面会用到。之后我们可以测试下转换的效果: 43 | 44 | ```python 45 | from uie_predictor import UIEPredictor 46 | from pprint import pprint 47 | 48 | schema = ['时间', '选手', '赛事名称'] # Define the schema for entity extraction 49 | ie = UIEPredictor('./uie_base_pytorch', schema=schema) 50 | pprint(ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")) # Better print results using pprint 51 | ``` 52 | 53 | - 4、开始微调: 54 | 55 | ```shell 56 | python finetune.py \ 57 | --train_path "./data/dgre/final_data/train.txt" \ 58 | --dev_path "./data/dgre/final_data/dev.txt" \ 59 | --save_dir "./checkpoint/dgre" \ 60 | --learning_rate 1e-5 \ 61 | --batch_size 8 \ 62 | --max_seq_len 512 \ 63 | --num_epochs 3 \ 64 | --model "uie_base_pytorch" \ 65 | --seed 1000 \ 66 | --logging_steps 464 \ 67 | --valid_steps 464 \ 68 | --device "gpu" \ 69 | --max_model_num 1 70 | ``` 71 | 72 | 训练完成后,会在同目录下生成checkpoint/dgre/model_best/。 73 | 74 | - 5、进行验证: 75 | 76 | ```shell 77 | python evaluate.py \ 78 | --model_path "./checkpoint/dgre/model_best" \ 79 | --test_path "./data/dgre/final_data/dev.txt" \ 80 | --batch_size 16 \ 81 | --max_seq_len 512 82 | ``` 83 | 84 | - 6、使用训练好的模型进行预测: 85 | 86 | ```python 87 | from
uie_predictor import UIEPredictor 88 | from pprint import pprint 89 | 90 | schema = [{"主体":["部件故障", "性能故障", "检测工具", "组成"]}, "客体"] # Define the schema for entity extraction 91 | ie = UIEPredictor('./checkpoint/dgre/model_best', schema=schema) 92 | text = "分析诊断首先用故障诊断仪读取故障码为12,其含义是电控系统正常。考虑到电喷发动机控制是由怠速马达来实现的,所以先拆下怠速马达,发现其阀头上粘有大量胶质油污。用化油器清洗剂清洗后,装车试验,故障依旧。接着清洗喷油咀,故障仍未排除。最后把节气门体拆下来清洗。在操作过程中发现:一根插在节气门体下部真空管上的胶管已断裂,造成节气门后腔与大气相通,影响怠速运转稳定。这条胶管应该是连接在节气门进气管和气门室盖排气孔之间特制的丁字胶管的一部分,但该车没有使用特制的丁字胶管,它用一条直通胶管将节气门进气管和气门室盖排气孔连起来。维修方案把节气门体清洗干净后装车,再用一条专用特制的丁字形的三通胶管把节气门进气管、气门室盖排气孔和节气门体下部真空管接好,然后启动发动机,加速收油,发动机转速平稳下降" 93 | res = ie(text) 94 | pprint(res) # Better print results using pprint 95 | 96 | """ 97 | [{'主体': [{'end': 164, 98 | 'probability': 0.44656947, 99 | 'relations': {'性能故障': [{'end': 169, 100 | 'probability': 0.9079101, 101 | 'start': 167, 102 | 'text': '相通'}], 103 | '部件故障': [{'end': 169, 104 | 'probability': 0.91306436, 105 | 'start': 167, 106 | 'text': '相通'}]}, 107 | 'start': 159, 108 | 'text': '节气门后腔'}, 109 | {'end': 153, 110 | 'probability': 0.92779523, 111 | 'relations': {'性能故障': [{'end': 156, 112 | 'probability': 0.9924409, 113 | 'start': 154, 114 | 'text': '断裂'}], 115 | '部件故障': [{'end': 156, 116 | 'probability': 0.9910217, 117 | 'start': 154, 118 | 'text': '断裂'}]}, 119 | 'start': 151, 120 | 'text': '胶管'}, 121 | {'end': 68, 122 | 'probability': 0.32421115, 123 | 'relations': {'性能故障': [{'end': 77, 124 | 'probability': 0.6647082, 125 | 'start': 69, 126 | 'text': '粘有大量胶质油污'}], 127 | '部件故障': [{'end': 77, 128 | 'probability': 0.8483382, 129 | 'start': 69, 130 | 'text': '粘有大量胶质油污'}]}, 131 | 'start': 66, 132 | 'text': '阀头'}], 133 | '客体': [{'end': 77, 'probability': 0.5508156, 'start': 69, 'text': '粘有大量胶质油污'}, 134 | {'end': 156, 'probability': 0.9614242, 'start': 154, 'text': '断裂'}, 135 | {'end': 169, 136 | 'probability': 0.56304514, 137 | 'start': 164, 138 | 'text': '与大气相通'}]}] 139 | """ 140 | ``` 141 | 142 | 会发现,关系抽取会有关系的重复,可以多训练几个epoch看看。 143 | 144 | # 补充 145 | 146 | - 标签名最好是使用中文。 147 | - 可使用不同大小的模型进行训练和推理,以达到精度和速度的平衡。 148 | -------------------------------------------------------------------------------- /doccano.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
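# Usage note: this script converts a doccano JSONL(relation)/JSONL export (or data
# prepared in the same format, e.g. by data/dgre/raw_data/process.py) into the
# train/dev/test text files consumed by finetune.py. Example from README step 2:
#     python doccano.py --doccano_file ./data/dgre/mid_data/doccano_train.json \
#         --task_type "ext" --splits 0.9 0.1 0.0 --save_dir ./data/dgre/final_data/ --negative_ratio 1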
14 | 15 | import os 16 | import time 17 | import argparse 18 | import json 19 | from decimal import Decimal 20 | import numpy as np 21 | 22 | from utils import set_seed, convert_ext_examples, convert_cls_examples, logger 23 | 24 | 25 | def do_convert(): 26 | set_seed(args.seed) 27 | 28 | tic_time = time.time() 29 | if not os.path.exists(args.doccano_file): 30 | raise ValueError("Please input the correct path of doccano file.") 31 | 32 | if not os.path.exists(args.save_dir): 33 | os.makedirs(args.save_dir) 34 | 35 | if len(args.splits) != 0 and len(args.splits) != 3: 36 | raise ValueError("Only []/ len(splits)==3 accepted for splits.") 37 | 38 | def _check_sum(splits): 39 | return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal( 40 | str(splits[2])) == Decimal("1") 41 | 42 | if len(args.splits) == 3 and not _check_sum(args.splits): 43 | raise ValueError( 44 | "Please set correct splits, sum of elements in splits should be equal to 1." 45 | ) 46 | 47 | with open(args.doccano_file, "r", encoding="utf-8") as f: 48 | raw_examples = f.readlines() 49 | 50 | def _create_ext_examples(examples, 51 | negative_ratio=0, 52 | shuffle=False, 53 | is_train=True): 54 | entities, relations = convert_ext_examples( 55 | examples, negative_ratio, is_train=is_train) 56 | examples = entities + relations 57 | if shuffle: 58 | indexes = np.random.permutation(len(examples)) 59 | examples = [examples[i] for i in indexes] 60 | return examples 61 | 62 | def _create_cls_examples(examples, prompt_prefix, options, shuffle=False): 63 | examples = convert_cls_examples(examples, prompt_prefix, options) 64 | if shuffle: 65 | indexes = np.random.permutation(len(examples)) 66 | examples = [examples[i] for i in indexes] 67 | return examples 68 | 69 | def _save_examples(save_dir, file_name, examples): 70 | count = 0 71 | save_path = os.path.join(save_dir, file_name) 72 | if not examples: 73 | logger.info("Skip saving %d examples to %s." % (0, save_path)) 74 | return 75 | with open(save_path, "w", encoding="utf-8") as f: 76 | for example in examples: 77 | f.write(json.dumps(example, ensure_ascii=False) + "\n") 78 | count += 1 79 | logger.info("Save %d examples to %s." 
% (count, save_path)) 80 | 81 | if len(args.splits) == 0: 82 | if args.task_type == "ext": 83 | examples = _create_ext_examples(raw_examples, args.negative_ratio, 84 | args.is_shuffle) 85 | else: 86 | examples = _create_cls_examples(raw_examples, args.prompt_prefix, 87 | args.options, args.is_shuffle) 88 | _save_examples(args.save_dir, "train.txt", examples) 89 | else: 90 | if args.is_shuffle: 91 | indexes = np.random.permutation(len(raw_examples)) 92 | raw_examples = [raw_examples[i] for i in indexes] 93 | 94 | i1, i2, _ = args.splits 95 | p1 = int(len(raw_examples) * i1) 96 | p2 = int(len(raw_examples) * (i1 + i2)) 97 | 98 | if args.task_type == "ext": 99 | train_examples = _create_ext_examples( 100 | raw_examples[:p1], args.negative_ratio, args.is_shuffle) 101 | dev_examples = _create_ext_examples( 102 | raw_examples[p1:p2], -1, is_train=False) 103 | test_examples = _create_ext_examples( 104 | raw_examples[p2:], -1, is_train=False) 105 | else: 106 | train_examples = _create_cls_examples( 107 | raw_examples[:p1], args.prompt_prefix, args.options) 108 | dev_examples = _create_cls_examples( 109 | raw_examples[p1:p2], args.prompt_prefix, args.options) 110 | test_examples = _create_cls_examples( 111 | raw_examples[p2:], args.prompt_prefix, args.options) 112 | 113 | _save_examples(args.save_dir, "train.txt", train_examples) 114 | _save_examples(args.save_dir, "dev.txt", dev_examples) 115 | _save_examples(args.save_dir, "test.txt", test_examples) 116 | 117 | logger.info('Finished! It takes %.2f seconds' % (time.time() - tic_time)) 118 | 119 | 120 | if __name__ == "__main__": 121 | # yapf: disable 122 | parser = argparse.ArgumentParser() 123 | 124 | parser.add_argument("-d", "--doccano_file", default="./data/doccano.json", 125 | type=str, help="The doccano file exported from doccano platform.") 126 | parser.add_argument("-s", "--save_dir", default="./data", 127 | type=str, help="The path of data that you wanna save.") 128 | parser.add_argument("--negative_ratio", default=5, type=int, 129 | help="Used only for the extraction task, the ratio of positive and negative samples, number of negtive samples = negative_ratio * number of positive samples") 130 | parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*", 131 | help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60%% samples used for training, 20%% for evaluation and 20%% for test.") 132 | parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str, 133 | help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.") 134 | parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+", 135 | help="Used only for the classification task, the options for classification") 136 | parser.add_argument("--prompt_prefix", default="情感倾向", type=str, 137 | help="Used only for the classification task, the prompt prefix for classification") 138 | parser.add_argument("--is_shuffle", default=True, type=bool, 139 | help="Whether to shuffle the labeled dataset, defaults to True.") 140 | parser.add_argument("--seed", type=int, default=1000, 141 | help="random seed for initialization") 142 | 143 | args = parser.parse_args() 144 | # yapf: enable 145 | 146 | do_convert() 147 | -------------------------------------------------------------------------------- /export_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import os 17 | from itertools import chain 18 | from typing import List, Union 19 | import shutil 20 | from pathlib import Path 21 | 22 | import numpy as np 23 | import torch 24 | from transformers import (BertTokenizer, PreTrainedModel, 25 | PreTrainedTokenizerBase) 26 | 27 | from model import UIE 28 | from utils import logger 29 | 30 | 31 | def validate_onnx(tokenizer: PreTrainedTokenizerBase, pt_model: PreTrainedModel, onnx_path: Union[Path, str], strict: bool = True, atol: float = 1e-05): 32 | 33 | # 验证模型 34 | from onnxruntime import InferenceSession, SessionOptions 35 | from transformers import AutoTokenizer 36 | 37 | logger.info("Validating ONNX model...") 38 | if strict: 39 | ref_inputs = tokenizer('装备', "印媒所称的“印度第一艘国产航母”—“维克兰特”号", 40 | add_special_tokens=True, 41 | truncation=True, 42 | max_length=512, 43 | return_tensors="pt") 44 | else: 45 | batch_size = 2 46 | seq_length = 6 47 | dummy_input = [" ".join([tokenizer.unk_token]) 48 | * seq_length] * batch_size 49 | ref_inputs = dict(tokenizer(dummy_input, return_tensors="pt")) 50 | # ref_inputs = 51 | ref_outputs = pt_model(**ref_inputs) 52 | ref_outputs_dict = {} 53 | 54 | # We flatten potential collection of outputs (i.e. past_keys) to a flat structure 55 | for name, value in ref_outputs.items(): 56 | # Overwriting the output name as "present" since it is the name used for the ONNX outputs 57 | # ("past_key_values" being taken for the ONNX inputs) 58 | if name == "past_key_values": 59 | name = "present" 60 | ref_outputs_dict[name] = value 61 | 62 | # Create ONNX Runtime session 63 | options = SessionOptions() 64 | session = InferenceSession(str(onnx_path), options, providers=[ 65 | "CPUExecutionProvider"]) 66 | 67 | # We flatten potential collection of inputs (i.e. 
past_keys) 68 | onnx_inputs = {} 69 | for name, value in ref_inputs.items(): 70 | onnx_inputs[name] = value.numpy() 71 | onnx_named_outputs = ['start_prob', 'end_prob'] 72 | # Compute outputs from the ONNX model 73 | onnx_outputs = session.run(onnx_named_outputs, onnx_inputs) 74 | 75 | # Check we have a subset of the keys into onnx_outputs against ref_outputs 76 | ref_outputs_set, onnx_outputs_set = set( 77 | ref_outputs_dict.keys()), set(onnx_named_outputs) 78 | if not onnx_outputs_set.issubset(ref_outputs_set): 79 | logger.info( 80 | f"\t-[x] ONNX model output names {onnx_outputs_set} do not match reference model {ref_outputs_set}" 81 | ) 82 | 83 | raise ValueError( 84 | "Outputs doesn't match between reference model and ONNX exported model: " 85 | f"{onnx_outputs_set.difference(ref_outputs_set)}" 86 | ) 87 | else: 88 | logger.info( 89 | f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})") 90 | 91 | # Check the shape and values match 92 | for name, ort_value in zip(onnx_named_outputs, onnx_outputs): 93 | ref_value = ref_outputs_dict[name].detach().numpy() 94 | 95 | logger.info(f'\t- Validating ONNX Model output "{name}":') 96 | 97 | # Shape 98 | if not ort_value.shape == ref_value.shape: 99 | logger.info( 100 | f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}") 101 | raise ValueError( 102 | "Outputs shape doesn't match between reference model and ONNX exported model: " 103 | f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)" 104 | ) 105 | else: 106 | logger.info( 107 | f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}") 108 | 109 | # Values 110 | if not np.allclose(ref_value, ort_value, atol=atol): 111 | logger.info(f"\t\t-[x] values not close enough (atol: {atol})") 112 | raise ValueError( 113 | "Outputs values doesn't match between reference model and ONNX exported model: " 114 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}" 115 | ) 116 | else: 117 | logger.info(f"\t\t-[✓] all values close (atol: {atol})") 118 | 119 | 120 | def export_onnx(args: argparse.Namespace, tokenizer: PreTrainedTokenizerBase, model: PreTrainedModel, device: torch.device, input_names: List[str], output_names: List[str]): 121 | with torch.no_grad(): 122 | model = model.to(device) 123 | model.eval() 124 | model.config.return_dict = True 125 | model.config.use_cache = False 126 | 127 | # Create folder 128 | if not args.output_path.exists(): 129 | args.output_path.mkdir(parents=True) 130 | save_path = args.output_path / "inference.onnx" 131 | 132 | dynamic_axes = {name: {0: 'batch', 1: 'sequence'} 133 | for name in chain(input_names, output_names)} 134 | 135 | # Generate dummy input 136 | batch_size = 2 137 | seq_length = 6 138 | dummy_input = [" ".join([tokenizer.unk_token]) 139 | * seq_length] * batch_size 140 | inputs = dict(tokenizer(dummy_input, return_tensors="pt")) 141 | 142 | if save_path.exists(): 143 | logger.warning(f'Overwrite model {save_path.as_posix()}') 144 | save_path.unlink() 145 | 146 | torch.onnx.export(model, 147 | (inputs,), 148 | save_path, 149 | input_names=input_names, 150 | output_names=output_names, 151 | dynamic_axes=dynamic_axes, 152 | do_constant_folding=True, 153 | opset_version=11 154 | ) 155 | 156 | if not os.path.exists(save_path): 157 | logger.error(f'Export Failed!') 158 | 159 | return save_path 160 | 161 | 162 | def main(): 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument("-m", "--model_path", type=Path, required=True, 165 | default='./checkpoint/model_best', help="The 
path to model parameters to be loaded.") 166 | parser.add_argument("-o", "--output_path", type=Path, default=None, 167 | help="The path of model parameter in static graph to be saved.") 168 | args = parser.parse_args() 169 | 170 | if args.output_path is None: 171 | args.output_path = args.model_path 172 | 173 | tokenizer = BertTokenizer.from_pretrained(args.model_path) 174 | model = UIE.from_pretrained(args.model_path) 175 | device = torch.device('cpu') 176 | input_names = [ 177 | 'input_ids', 178 | 'token_type_ids', 179 | 'attention_mask', 180 | ] 181 | output_names = [ 182 | 'start_prob', 183 | 'end_prob' 184 | ] 185 | 186 | logger.info("Export Tokenizer Config...") 187 | 188 | export_tokenizer(args) 189 | 190 | logger.info("Export ONNX Model...") 191 | 192 | save_path = export_onnx( 193 | args, tokenizer, model, device, input_names, output_names) 194 | validate_onnx(tokenizer, model, save_path) 195 | 196 | logger.info(f"All good, model saved at: {save_path.as_posix()}") 197 | 198 | 199 | def export_tokenizer(args): 200 | for tokenizer_fine in ['tokenizer_config.json', 'special_tokens_map.json', 'vocab.txt']: 201 | file_from = args.model_path / tokenizer_fine 202 | file_to = args.output_path/tokenizer_fine 203 | if file_from.resolve() == file_to.resolve(): 204 | continue 205 | shutil.copyfile(file_from, file_to) 206 | 207 | 208 | if __name__ == "__main__": 209 | 210 | main() 211 | -------------------------------------------------------------------------------- /doccano.md: -------------------------------------------------------------------------------- 1 | # doccano 2 | 3 | **目录** 4 | 5 | * [1. 安装](#安装) 6 | * [2. 项目创建](#项目创建) 7 | * [3. 数据上传](#数据上传) 8 | * [4. 标签构建](#标签构建) 9 | * [5. 任务标注](#任务标注) 10 | * [6. 数据导出](#数据导出) 11 | * [7. 数据转换](#数据转换) 12 | 13 | 14 | 15 | ## 1. 安装 16 | 17 | 参考[doccano官方文档](https://github.com/doccano/doccano) 完成doccano的安装与初始配置。 18 | 19 | **以下标注示例用到的环境配置:** 20 | 21 | - doccano 1.6.2 22 | 23 | 24 | 25 | ## 2. 项目创建 26 | 27 | UIE支持抽取与分类两种类型的任务,根据实际需要创建一个新的项目: 28 | 29 | #### 2.1 抽取式任务项目创建 30 | 31 | 创建项目时选择**序列标注**任务,并勾选**Allow overlapping entity**及**Use relation Labeling**。适配**命名实体识别、关系抽取、事件抽取、评价观点抽取**等任务。 32 | 33 |
34 | 35 |
36 | 37 | #### 2.2 分类式任务项目创建 38 | 39 | 创建项目时选择**文本分类**任务。适配**文本分类、句子级情感倾向分类**等任务。 40 | 41 |
42 | 43 |
44 | 45 | 46 | 47 | ## 3. 数据上传 48 | 49 | 上传的文件为txt格式,每一行为一条待标注文本,示例: 50 | 51 | ```text 52 | 2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌 53 | 第十四届全运会在西安举办 54 | ``` 55 | 56 | 上传数据类型**选择TextLine**: 57 | 58 |
59 | 60 |
61 | 62 | **NOTE**:doccano支持`TextFile`、`TextLine`、`JSONL`和`CoNLL`四种数据上传格式,UIE定制训练中**统一使用TextLine**这一文件格式,即上传的文件需要为txt格式,且在数据标注时,该文件的每一行待标注文本显示为一页内容。 63 | 64 | 65 | 66 | ## 4. 标签构建 67 | 68 | #### 4.1 构建抽取式任务标签 69 | 70 | 抽取式任务包含**Span**与**Relation**两种标签类型,Span指**原文本中的目标信息片段**,如实体识别中某个类型的实体,事件抽取中的触发词和论元;Relation指**原文本中Span之间的关系**,如关系抽取中两个实体(Subject&Object)之间的关系,事件抽取中论元和触发词之间的关系。 71 | 72 | Span类型标签构建示例: 73 | 74 |
75 | 76 |
77 | 78 | Relation类型标签构建示例: 79 | 80 |
81 | 82 |
83 | 84 | #### 4.2 构建分类式任务标签 85 | 86 | 添加分类类别标签: 87 | 88 |
89 | 90 |
91 | 92 | 93 | 94 | ## 5. 任务标注 95 | 96 | #### 5.1 命名实体识别 97 | 98 | 命名实体识别(Named Entity Recognition,简称NER),是指识别文本中具有特定意义的实体。在开放域信息抽取中,**抽取的类别没有限制,用户可以自己定义**。 99 | 100 | 标注示例: 101 | 102 |
103 | 104 |
105 | 106 | 示例中定义了`时间`、`选手`、`赛事名称`和`得分`四种Span类型标签。 107 | 108 | #### 5.2 关系抽取 109 | 110 | 关系抽取(Relation Extraction,简称RE),是指从文本中识别实体并抽取实体之间的语义关系,即抽取三元组(实体一,关系类型,实体二)。 111 | 112 | 标注示例: 113 | 114 |
115 | 116 |
117 | 118 | 示例中定义了`作品名`、`人物名`和`时间`三种Span类型标签,以及`歌手`、`发行时间`和`所属专辑`三种Relation标签。Relation标签**由Subject对应实体指向Object对应实体**。 119 | 120 | #### 5.3 事件抽取 121 | 122 | 事件抽取 (Event Extraction, 简称EE),是指从自然语言文本中抽取事件并识别事件类型和事件论元的技术。UIE所包含的事件抽取任务,是指根据已知事件类型,抽取该事件所包含的事件论元。 123 | 124 | 标注示例: 125 | 126 |
127 | 128 |
129 | 130 | 示例中定义了`地震触发词`(触发词)、`等级`(事件论元)和`时间`(事件论元)三种Span标签,以及`时间`和`震级`两种Relation标签。触发词标签**统一格式为`XX触发词`**,`XX`表示具体事件类型,上例中的事件类型是`地震`,则对应触发词为`地震触发词`。Relation标签**由触发词指向对应的事件论元**。 131 | 132 | #### 5.4 评价观点抽取 133 | 134 | 评论观点抽取,是指抽取文本中包含的评价维度、观点词。 135 | 136 | 标注示例: 137 | 138 |
139 | 140 |
141 | 142 | 示例中定义了`评价维度`和`观点词`两种Span标签,以及`观点词`一种Relation标签。Relation标签**由评价维度指向观点词**。 143 | 144 | #### 5.5 分类任务 145 | 146 | 标注示例: 147 | 148 |
149 | 150 |
151 | 152 | 示例中定义了`正向`和`负向`两种类别标签对文本的情感倾向进行分类。 153 | 154 | 155 | 156 | ## 6. 数据导出 157 | 158 | #### 6.1 导出抽取式任务数据 159 | 160 | 选择导出的文件类型为``JSONL(relation)``,导出数据示例: 161 | 162 | ```text 163 | { 164 | "id": 38, 165 | "text": "百科名片你知道我要什么,是歌手高明骏演唱的一首歌曲,1989年发行,收录于个人专辑《丛林男孩》中", 166 | "relations": [ 167 | { 168 | "id": 20, 169 | "from_id": 51, 170 | "to_id": 53, 171 | "type": "歌手" 172 | }, 173 | { 174 | "id": 21, 175 | "from_id": 51, 176 | "to_id": 55, 177 | "type": "发行时间" 178 | }, 179 | { 180 | "id": 22, 181 | "from_id": 51, 182 | "to_id": 54, 183 | "type": "所属专辑" 184 | } 185 | ], 186 | "entities": [ 187 | { 188 | "id": 51, 189 | "start_offset": 4, 190 | "end_offset": 11, 191 | "label": "作品名" 192 | }, 193 | { 194 | "id": 53, 195 | "start_offset": 15, 196 | "end_offset": 18, 197 | "label": "人物名" 198 | }, 199 | { 200 | "id": 54, 201 | "start_offset": 42, 202 | "end_offset": 46, 203 | "label": "作品名" 204 | }, 205 | { 206 | "id": 55, 207 | "start_offset": 26, 208 | "end_offset": 31, 209 | "label": "时间" 210 | } 211 | ] 212 | } 213 | ``` 214 | 215 | 标注数据保存在同一个文本文件中,每条样例占一行且存储为``json``格式,其包含以下字段 216 | - ``id``: 样本在数据集中的唯一标识ID。 217 | - ``text``: 原始文本数据。 218 | - ``entities``: 数据中包含的Span标签,每个Span标签包含四个字段: 219 | - ``id``: Span在数据集中的唯一标识ID。 220 | - ``start_offset``: Span的起始token在文本中的下标。 221 | - ``end_offset``: Span的结束token在文本中下标的下一个位置。 222 | - ``label``: Span类型。 223 | - ``relations``: 数据中包含的Relation标签,每个Relation标签包含四个字段: 224 | - ``id``: (Span1, Relation, Span2)三元组在数据集中的唯一标识ID,不同样本中的相同三元组对应同一个ID。 225 | - ``from_id``: Span1对应的标识ID。 226 | - ``to_id``: Span2对应的标识ID。 227 | - ``type``: Relation类型。 228 | 229 | #### 6.2 导出分类式任务数据 230 | 231 | 选择导出的文件类型为``JSONL``,导出数据示例: 232 | 233 | ```text 234 | { 235 | "id": 41, 236 | "data": "大年初一就把车前保险杠给碰坏了,保险杠和保险公司 真够倒霉的,我决定步行反省。", 237 | "label": [ 238 | "负向" 239 | ] 240 | } 241 | ``` 242 | 243 | 标注数据保存在同一个文本文件中,每条样例占一行且存储为``json``格式,其包含以下字段 244 | - ``id``: 样本在数据集中的唯一标识ID。 245 | - ``data``: 原始文本数据。 246 | - ``label``: 文本对应类别标签。 247 | 248 | 249 | 250 | ## 7.数据转换 251 | 252 | 该章节详细说明如何通过`doccano.py`脚本对doccano平台导出的标注数据进行转换,一键生成训练/验证/测试集。 253 | 254 | #### 7.1 抽取式任务数据转换 255 | 256 | - 当标注完成后,在 doccano 平台上导出 `JSONL(relation)` 形式的文件,并将其重命名为 `doccano_ext.json` 后,放入 `./data` 目录下。 257 | - 通过 [doccano.py](./doccano.py) 脚本进行数据形式转换,然后便可以开始进行相应模型训练。 258 | 259 | ```shell 260 | python doccano.py \ 261 | --doccano_file ./data/doccano_ext.json \ 262 | --task_type "ext" \ 263 | --save_dir ./data \ 264 | --negative_ratio 5 265 | ``` 266 | 267 | #### 7.2 分类式任务数据转换 268 | 269 | - 当标注完成后,在 doccano 平台上导出 `JSON` 形式的文件,并将其重命名为 `doccano_cls.json` 后,放入 `./data` 目录下。 270 | - 在数据转换阶段,我们会自动构造用于模型训练需要的prompt信息。例如句子级情感分类中,prompt为``情感倾向[正向,负向]``,可以通过`prompt_prefix`和`options`参数进行声明。 271 | - 通过 [doccano.py](./doccano.py) 脚本进行数据形式转换,然后便可以开始进行相应模型训练。 272 | 273 | ```shell 274 | python doccano.py \ 275 | --doccano_file ./data/doccano_cls.json \ 276 | --task_type "cls" \ 277 | --save_dir ./data \ 278 | --splits 0.8 0.1 0.1 \ 279 | --prompt_prefix "情感倾向" \ 280 | --options "正向" "负向" 281 | ``` 282 | 283 | 可配置参数说明: 284 | 285 | - ``doccano_file``: 从doccano导出的数据标注文件。 286 | - ``save_dir``: 训练数据的保存目录,默认存储在``data``目录下。 287 | - ``negative_ratio``: 最大负例比例,该参数只对抽取类型任务有效,适当构造负例可提升模型效果。负例数量和实际的标签数量有关,最大负例数量 = negative_ratio * 正例数量。该参数只对训练集有效,默认为5。为了保证评估指标的准确性,验证集和测试集默认构造全负例。 288 | - ``splits``: 划分数据集时训练集、验证集所占的比例。默认为[0.8, 0.1, 0.1]表示按照``8:1:1``的比例将数据划分为训练集、验证集和测试集。 289 | - ``task_type``: 选择任务类型,可选有抽取和分类两种类型的任务。 290 | - ``options``: 指定分类任务的类别标签,该参数只对分类类型任务有效。 291 | - ``prompt_prefix``: 声明分类任务的prompt前缀信息,该参数只对分类类型任务有效。 292 | - 
``is_shuffle``: 是否对数据集进行随机打散,默认为True。 293 | - ``seed``: 随机种子,默认为1000. 294 | 295 | 备注: 296 | - 默认情况下 [doccano.py](./doccano.py) 脚本会按照比例将数据划分为 train/dev/test 数据集 297 | - 每次执行 [doccano.py](./doccano.py) 脚本,将会覆盖已有的同名数据文件 298 | - 在模型训练阶段我们推荐构造一些负例以提升模型效果,在数据转换阶段我们内置了这一功能。可通过`negative_ratio`控制自动构造的负样本比例;负样本数量 = negative_ratio * 正样本数量。 299 | - 对于从doccano导出的文件,默认文件中的每条数据都是经过人工正确标注的。 300 | 301 | ## References 302 | - **[doccano](https://github.com/doccano/doccano)** 303 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from dataclasses import dataclass 18 | from transformers import BertModel, BertPreTrainedModel, PretrainedConfig 19 | # from transformers.utils import ModelOutput 20 | from typing import Optional, Tuple 21 | from generic import ModelOutput 22 | 23 | 24 | @dataclass 25 | class UIEModelOutput(ModelOutput): 26 | """ 27 | Output class for outputs of UIE. 28 | Args: 29 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 30 | Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. 31 | start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 32 | Span-start scores (after Sigmoid). 33 | end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): 34 | Span-end scores (after Sigmoid). 35 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 36 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 37 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 38 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 39 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 40 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 41 | sequence_length)`. 42 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 43 | heads. 44 | """ 45 | loss: Optional[torch.FloatTensor] = None 46 | start_prob: torch.FloatTensor = None 47 | end_prob: torch.FloatTensor = None 48 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 49 | attentions: Optional[Tuple[torch.FloatTensor]] = None 50 | 51 | 52 | class UIE(BertPreTrainedModel): 53 | """ 54 | UIE model based on Bert model. 55 | This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the 56 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 57 | etc.) 58 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 59 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 60 | and behavior. 61 | Parameters: 62 | config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model. 63 | Initializing with a config file does not load the weights associated with the model, only the 64 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 65 | """ 66 | 67 | def __init__(self, config: PretrainedConfig): 68 | super(UIE, self).__init__(config) 69 | self.encoder = BertModel(config) 70 | self.config = config 71 | hidden_size = self.config.hidden_size 72 | 73 | self.linear_start = nn.Linear(hidden_size, 1) 74 | self.linear_end = nn.Linear(hidden_size, 1) 75 | self.sigmoid = nn.Sigmoid() 76 | 77 | if hasattr(config, 'use_task_id') and config.use_task_id: 78 | # Add task type embedding to BERT 79 | task_type_embeddings = nn.Embedding( 80 | config.task_type_vocab_size, config.hidden_size) 81 | self.encoder.embeddings.task_type_embeddings = task_type_embeddings 82 | 83 | def hook(module, input, output): 84 | input = input[0] 85 | return output+task_type_embeddings(torch.zeros(input.size(), dtype=torch.int64, device=input.device)) 86 | self.encoder.embeddings.word_embeddings.register_forward_hook(hook) 87 | 88 | self.post_init() 89 | 90 | def forward(self, input_ids: Optional[torch.Tensor] = None, 91 | token_type_ids: Optional[torch.Tensor] = None, 92 | position_ids: Optional[torch.Tensor] = None, 93 | attention_mask: Optional[torch.Tensor] = None, 94 | head_mask: Optional[torch.Tensor] = None, 95 | inputs_embeds: Optional[torch.Tensor] = None, 96 | start_positions: Optional[torch.Tensor] = None, 97 | end_positions: Optional[torch.Tensor] = None, 98 | output_attentions: Optional[bool] = None, 99 | output_hidden_states: Optional[bool] = None, 100 | return_dict: Optional[bool] = None 101 | ): 102 | """ 103 | Args: 104 | input_ids (`torch.LongTensor` of shape `({0})`): 105 | Indices of input sequence tokens in the vocabulary. 106 | Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and 107 | [`PreTrainedTokenizer.__call__`] for details. 108 | [What are input IDs?](../glossary#input-ids) 109 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): 110 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 111 | - 1 for tokens that are **not masked**, 112 | - 0 for tokens that are **masked**. 113 | [What are attention masks?](../glossary#attention-mask) 114 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): 115 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 116 | 1]`: 117 | - 0 corresponds to a *sentence A* token, 118 | - 1 corresponds to a *sentence B* token. 119 | [What are token type IDs?](../glossary#token-type-ids) 120 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 121 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 122 | config.max_position_embeddings - 1]`. 
123 | [What are position IDs?](../glossary#position-ids) 124 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 125 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 126 | - 1 indicates the head is **not masked**, 127 | - 0 indicates the head is **masked**. 128 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): 129 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 130 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 131 | model's internal embedding lookup matrix. 132 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 133 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 134 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 135 | are not taken into account for computing the loss. 136 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 137 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 138 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 139 | are not taken into account for computing the loss. 140 | output_attentions (`bool`, *optional*): 141 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 142 | tensors for more detail. 143 | output_hidden_states (`bool`, *optional*): 144 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 145 | more detail. 146 | return_dict (`bool`, *optional*): 147 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
148 | """ 149 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 150 | outputs = self.encoder( 151 | input_ids=input_ids, 152 | token_type_ids=token_type_ids, 153 | position_ids=position_ids, 154 | attention_mask=attention_mask, 155 | head_mask=head_mask, 156 | inputs_embeds=inputs_embeds, 157 | output_attentions=output_attentions, 158 | output_hidden_states=output_hidden_states, 159 | return_dict=return_dict 160 | ) 161 | sequence_output = outputs[0] 162 | 163 | start_logits = self.linear_start(sequence_output) 164 | start_logits = torch.squeeze(start_logits, -1) 165 | start_prob = self.sigmoid(start_logits) 166 | end_logits = self.linear_end(sequence_output) 167 | end_logits = torch.squeeze(end_logits, -1) 168 | end_prob = self.sigmoid(end_logits) 169 | 170 | total_loss = None 171 | if start_positions is not None and end_positions is not None: 172 | loss_fct = nn.BCELoss() 173 | start_loss = loss_fct(start_prob, start_positions) 174 | end_loss = loss_fct(end_prob, end_positions) 175 | total_loss = (start_loss + end_loss) / 2.0 176 | 177 | if not return_dict: 178 | output = (start_prob, end_prob) + outputs[2:] 179 | return ((total_loss,) + output) if total_loss is not None else output 180 | 181 | return UIEModelOutput( 182 | loss=total_loss, 183 | start_prob=start_prob, 184 | end_prob=end_prob, 185 | hidden_states=outputs.hidden_states, 186 | attentions=outputs.attentions, 187 | ) 188 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
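# Usage note: this script fine-tunes a converted UIE checkpoint on the train/dev files
# produced by doccano.py. Example from README step 4:
#     python finetune.py --train_path "./data/dgre/final_data/train.txt" \
#         --dev_path "./data/dgre/final_data/dev.txt" --save_dir "./checkpoint/dgre" \
#         --learning_rate 1e-5 --batch_size 8 --max_seq_len 512 --num_epochs 3 --model "uie_base_pytorch"
# The checkpoint with the best dev F1 is saved to <save_dir>/model_best/.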
14 | 15 | import argparse 16 | import shutil 17 | import sys 18 | import time 19 | import os 20 | import torch 21 | from torch.utils.data import DataLoader 22 | from transformers import BertTokenizerFast 23 | 24 | from utils import IEDataset, logger, tqdm 25 | from model import UIE 26 | from evaluate import evaluate 27 | from utils import set_seed, SpanEvaluator, EarlyStopping, logging_redirect_tqdm 28 | 29 | 30 | def do_train(): 31 | 32 | set_seed(args.seed) 33 | show_bar = True 34 | 35 | tokenizer = BertTokenizerFast.from_pretrained(args.model) 36 | model = UIE.from_pretrained(args.model) 37 | if args.device == 'gpu': 38 | model = model.cuda() 39 | train_ds = IEDataset(args.train_path, tokenizer=tokenizer, 40 | max_seq_len=args.max_seq_len) 41 | dev_ds = IEDataset(args.dev_path, tokenizer=tokenizer, 42 | max_seq_len=args.max_seq_len) 43 | 44 | train_data_loader = DataLoader( 45 | train_ds, batch_size=args.batch_size, shuffle=True) 46 | dev_data_loader = DataLoader( 47 | dev_ds, batch_size=args.batch_size, shuffle=True) 48 | 49 | optimizer = torch.optim.AdamW( 50 | lr=args.learning_rate, params=model.parameters()) 51 | 52 | criterion = torch.nn.functional.binary_cross_entropy 53 | metric = SpanEvaluator() 54 | 55 | if args.early_stopping: 56 | early_stopping_save_dir = os.path.join( 57 | args.save_dir, "early_stopping") 58 | if not os.path.exists(early_stopping_save_dir): 59 | os.makedirs(early_stopping_save_dir) 60 | if show_bar: 61 | def trace_func(*args, **kwargs): 62 | with logging_redirect_tqdm([logger.logger]): 63 | logger.info(*args, **kwargs) 64 | else: 65 | trace_func = logger.info 66 | early_stopping = EarlyStopping( 67 | patience=7, verbose=True, trace_func=trace_func, 68 | save_dir=early_stopping_save_dir) 69 | 70 | loss_list = [] 71 | loss_sum = 0 72 | loss_num = 0 73 | global_step = 0 74 | best_step = 0 75 | best_f1 = 0 76 | tic_train = time.time() 77 | epoch_iterator = range(1, args.num_epochs + 1) 78 | if show_bar: 79 | train_postfix_info = {'loss': 'unknown'} 80 | epoch_iterator = tqdm( 81 | epoch_iterator, desc='Training', unit='epoch') 82 | for epoch in epoch_iterator: 83 | train_data_iterator = train_data_loader 84 | if show_bar: 85 | train_data_iterator = tqdm(train_data_iterator, 86 | desc=f'Training Epoch {epoch}', unit='batch') 87 | train_data_iterator.set_postfix(train_postfix_info) 88 | for batch in train_data_iterator: 89 | if show_bar: 90 | epoch_iterator.refresh() 91 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch 92 | if args.device == 'gpu': 93 | input_ids = input_ids.cuda() 94 | token_type_ids = token_type_ids.cuda() 95 | att_mask = att_mask.cuda() 96 | start_ids = start_ids.cuda() 97 | end_ids = end_ids.cuda() 98 | outputs = model(input_ids=input_ids, 99 | token_type_ids=token_type_ids, 100 | attention_mask=att_mask) 101 | start_prob, end_prob = outputs[0], outputs[1] 102 | 103 | start_ids = start_ids.type(torch.float32) 104 | end_ids = end_ids.type(torch.float32) 105 | loss_start = criterion(start_prob, start_ids) 106 | loss_end = criterion(end_prob, end_ids) 107 | loss = (loss_start + loss_end) / 2.0 108 | loss.backward() 109 | optimizer.step() 110 | optimizer.zero_grad() 111 | loss_list.append(float(loss)) 112 | loss_sum += float(loss) 113 | loss_num += 1 114 | 115 | if show_bar: 116 | loss_avg = loss_sum / loss_num 117 | train_postfix_info.update({ 118 | 'loss': f'{loss_avg:.5f}' 119 | }) 120 | train_data_iterator.set_postfix(train_postfix_info) 121 | 122 | global_step += 1 123 | if global_step % args.logging_steps == 0: 124 | 
time_diff = time.time() - tic_train 125 | loss_avg = loss_sum / loss_num 126 | 127 | if show_bar: 128 | with logging_redirect_tqdm([logger.logger]): 129 | logger.info( 130 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s" 131 | % (global_step, epoch, loss_avg, 132 | args.logging_steps / time_diff)) 133 | else: 134 | logger.info( 135 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s" 136 | % (global_step, epoch, loss_avg, 137 | args.logging_steps / time_diff)) 138 | tic_train = time.time() 139 | 140 | if global_step % args.valid_steps == 0: 141 | save_dir = os.path.join( 142 | args.save_dir, "model_%d" % global_step) 143 | if not os.path.exists(save_dir): 144 | os.makedirs(save_dir) 145 | model_to_save = model 146 | model_to_save.save_pretrained(save_dir) 147 | tokenizer.save_pretrained(save_dir) 148 | if args.max_model_num: 149 | model_to_delete = global_step-args.max_model_num*args.valid_steps 150 | model_to_delete_path = os.path.join( 151 | args.save_dir, "model_%d" % model_to_delete) 152 | if model_to_delete > 0 and os.path.exists(model_to_delete_path): 153 | shutil.rmtree(model_to_delete_path) 154 | 155 | dev_loss_avg, precision, recall, f1 = evaluate( 156 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion) 157 | 158 | if show_bar: 159 | train_postfix_info.update({ 160 | 'F1': f'{f1:.3f}', 161 | 'dev loss': f'{dev_loss_avg:.5f}' 162 | }) 163 | train_data_iterator.set_postfix(train_postfix_info) 164 | with logging_redirect_tqdm([logger.logger]): 165 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 166 | % (precision, recall, f1, dev_loss_avg)) 167 | else: 168 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 169 | % (precision, recall, f1, dev_loss_avg)) 170 | # Save model which has best F1 171 | if f1 > best_f1: 172 | if show_bar: 173 | with logging_redirect_tqdm([logger.logger]): 174 | logger.info( 175 | f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}" 176 | ) 177 | else: 178 | logger.info( 179 | f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}" 180 | ) 181 | best_f1 = f1 182 | save_dir = os.path.join(args.save_dir, "model_best") 183 | model_to_save = model 184 | model_to_save.save_pretrained(save_dir) 185 | tokenizer.save_pretrained(save_dir) 186 | tic_train = time.time() 187 | 188 | if args.early_stopping: 189 | dev_loss_avg, precision, recall, f1 = evaluate( 190 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion) 191 | 192 | if show_bar: 193 | train_postfix_info.update({ 194 | 'F1': f'{f1:.3f}', 195 | 'dev loss': f'{dev_loss_avg:.5f}' 196 | }) 197 | train_data_iterator.set_postfix(train_postfix_info) 198 | with logging_redirect_tqdm([logger.logger]): 199 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 200 | % (precision, recall, f1, dev_loss_avg)) 201 | else: 202 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f" 203 | % (precision, recall, f1, dev_loss_avg)) 204 | 205 | # Early Stopping 206 | early_stopping(dev_loss_avg, model) 207 | if early_stopping.early_stop: 208 | if show_bar: 209 | with logging_redirect_tqdm([logger.logger]): 210 | logger.info("Early stopping") 211 | else: 212 | logger.info("Early stopping") 213 | tokenizer.save_pretrained(early_stopping_save_dir) 214 | sys.exit(0) 215 | 216 | 217 | if __name__ == "__main__": 218 | # yapf: disable 219 | parser = argparse.ArgumentParser() 220 | 221 | 
parser.add_argument("-b", "--batch_size", default=16, type=int, 222 | help="Batch size per GPU/CPU for training.") 223 | parser.add_argument("--learning_rate", default=1e-5, 224 | type=float, help="The initial learning rate for Adam.") 225 | parser.add_argument("-t", "--train_path", default=None, required=True, 226 | type=str, help="The path of train set.") 227 | parser.add_argument("-d", "--dev_path", default=None, required=True, 228 | type=str, help="The path of dev set.") 229 | parser.add_argument("-s", "--save_dir", default='./checkpoint', type=str, 230 | help="The output directory where the model checkpoints will be written.") 231 | parser.add_argument("--max_seq_len", default=512, type=int, help="The maximum input sequence length. " 232 | "Sequences longer than this will be split automatically.") 233 | parser.add_argument("--num_epochs", default=100, type=int, 234 | help="Total number of training epochs to perform.") 235 | parser.add_argument("--seed", default=1000, type=int, 236 | help="Random seed for initialization") 237 | parser.add_argument("--logging_steps", default=10, 238 | type=int, help="The interval steps to logging.") 239 | parser.add_argument("--valid_steps", default=100, type=int, 240 | help="The interval steps to evaluate model performance.") 241 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu", 242 | help="Select which device to train model, defaults to gpu.") 243 | parser.add_argument("-m", "--model", default="uie_base_pytorch", type=str, 244 | help="Select the pretrained model for few-shot learning.") 245 | parser.add_argument("--max_model_num", default=5, type=int, 246 | help="Max number of saved model. Best model and earlystopping model is not included.") 247 | parser.add_argument("--early_stopping", action='store_true', default=False, 248 | help="Use early stopping while training") 249 | 250 | args = parser.parse_args() 251 | 252 | do_train() 253 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
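# Usage note: this script downloads a Paddle UIE checkpoint (see MODEL_MAP below) and
# converts its weights into a PyTorch model directory that model.UIE /
# uie_predictor.UIEPredictor can load. Example from README step 3:
#     python convert.py --input_model=uie-base --output_model=uie_base_pytorch --no_validate_output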
14 | 15 | import argparse 16 | import collections 17 | import json 18 | import os 19 | import pickle 20 | import shutil 21 | import numpy as np 22 | 23 | import torch 24 | try: 25 | import paddle 26 | from paddle.utils.download import get_path_from_url 27 | paddle_installed = True 28 | except (ImportError, ModuleNotFoundError): 29 | from utils import get_path_from_url 30 | paddle_installed = False 31 | 32 | from model import UIE 33 | from utils import logger 34 | 35 | MODEL_MAP = { 36 | "uie-base": { 37 | "resource_file_urls": { 38 | "model_state.pdparams": 39 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams", 40 | "model_config.json": 41 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json", 42 | "vocab_file": 43 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 44 | "special_tokens_map": 45 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 46 | "tokenizer_config": 47 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json" 48 | } 49 | }, 50 | "uie-medium": { 51 | "resource_file_urls": { 52 | "model_state.pdparams": 53 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams", 54 | "model_config.json": 55 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json", 56 | "vocab_file": 57 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 58 | "special_tokens_map": 59 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 60 | "tokenizer_config": 61 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 62 | } 63 | }, 64 | "uie-mini": { 65 | "resource_file_urls": { 66 | "model_state.pdparams": 67 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams", 68 | "model_config.json": 69 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json", 70 | "vocab_file": 71 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 72 | "special_tokens_map": 73 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 74 | "tokenizer_config": 75 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 76 | } 77 | }, 78 | "uie-micro": { 79 | "resource_file_urls": { 80 | "model_state.pdparams": 81 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams", 82 | "model_config.json": 83 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json", 84 | "vocab_file": 85 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 86 | "special_tokens_map": 87 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 88 | "tokenizer_config": 89 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 90 | } 91 | }, 92 | "uie-nano": { 93 | "resource_file_urls": { 94 | "model_state.pdparams": 95 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams", 96 | "model_config.json": 97 | 
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json", 98 | "vocab_file": 99 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 100 | "special_tokens_map": 101 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 102 | "tokenizer_config": 103 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 104 | } 105 | }, 106 | "uie-medical-base": { 107 | "resource_file_urls": { 108 | "model_state.pdparams": 109 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams", 110 | "model_config.json": 111 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json", 112 | "vocab_file": 113 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", 114 | "special_tokens_map": 115 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json", 116 | "tokenizer_config": 117 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json", 118 | } 119 | }, 120 | "uie-tiny": { 121 | "resource_file_urls": { 122 | "model_state.pdparams": 123 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams", 124 | "model_config.json": 125 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json", 126 | "vocab_file": 127 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt", 128 | "special_tokens_map": 129 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json", 130 | "tokenizer_config": 131 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json" 132 | } 133 | } 134 | } 135 | 136 | 137 | def build_params_map(attention_num=12): 138 | """ 139 | build params map from paddle-paddle's ERNIE to transformer's BERT 140 | :return: 141 | """ 142 | weight_map = collections.OrderedDict({ 143 | 'encoder.embeddings.word_embeddings.weight': "encoder.embeddings.word_embeddings.weight", 144 | 'encoder.embeddings.position_embeddings.weight': "encoder.embeddings.position_embeddings.weight", 145 | 'encoder.embeddings.token_type_embeddings.weight': "encoder.embeddings.token_type_embeddings.weight", 146 | 'encoder.embeddings.task_type_embeddings.weight': "encoder.embeddings.task_type_embeddings.weight", 147 | 'encoder.embeddings.layer_norm.weight': 'encoder.embeddings.LayerNorm.gamma', 148 | 'encoder.embeddings.layer_norm.bias': 'encoder.embeddings.LayerNorm.beta', 149 | }) 150 | # add attention layers 151 | for i in range(attention_num): 152 | weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.query.weight' 153 | weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.query.bias' 154 | weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.key.weight' 155 | weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.key.bias' 156 | weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.value.weight' 157 | weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.bias'] = 
f'encoder.encoder.layer.{i}.attention.self.value.bias' 158 | weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.weight'] = f'encoder.encoder.layer.{i}.attention.output.dense.weight' 159 | weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.bias'] = f'encoder.encoder.layer.{i}.attention.output.dense.bias' 160 | weight_map[f'encoder.encoder.layers.{i}.norm1.weight'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.gamma' 161 | weight_map[f'encoder.encoder.layers.{i}.norm1.bias'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.beta' 162 | weight_map[f'encoder.encoder.layers.{i}.linear1.weight'] = f'encoder.encoder.layer.{i}.intermediate.dense.weight' 163 | weight_map[f'encoder.encoder.layers.{i}.linear1.bias'] = f'encoder.encoder.layer.{i}.intermediate.dense.bias' 164 | weight_map[f'encoder.encoder.layers.{i}.linear2.weight'] = f'encoder.encoder.layer.{i}.output.dense.weight' 165 | weight_map[f'encoder.encoder.layers.{i}.linear2.bias'] = f'encoder.encoder.layer.{i}.output.dense.bias' 166 | weight_map[f'encoder.encoder.layers.{i}.norm2.weight'] = f'encoder.encoder.layer.{i}.output.LayerNorm.gamma' 167 | weight_map[f'encoder.encoder.layers.{i}.norm2.bias'] = f'encoder.encoder.layer.{i}.output.LayerNorm.beta' 168 | # add pooler 169 | weight_map.update( 170 | { 171 | 'encoder.pooler.dense.weight': 'encoder.pooler.dense.weight', 172 | 'encoder.pooler.dense.bias': 'encoder.pooler.dense.bias', 173 | 'linear_start.weight': 'linear_start.weight', 174 | 'linear_start.bias': 'linear_start.bias', 175 | 'linear_end.weight': 'linear_end.weight', 176 | 'linear_end.bias': 'linear_end.bias', 177 | } 178 | ) 179 | return weight_map 180 | 181 | 182 | def extract_and_convert(input_dir, output_dir): 183 | if not os.path.exists(output_dir): 184 | os.makedirs(output_dir) 185 | logger.info('=' * 20 + 'save config file' + '=' * 20) 186 | config = json.load( 187 | open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8')) 188 | config = config['init_args'][0] 189 | config["architectures"] = ["UIE"] 190 | config['layer_norm_eps'] = 1e-12 191 | del config['init_class'] 192 | if 'sent_type_vocab_size' in config: 193 | config['type_vocab_size'] = config['sent_type_vocab_size'] 194 | config['intermediate_size'] = 4 * config['hidden_size'] 195 | json.dump(config, open(os.path.join(output_dir, 'config.json'), 196 | 'wt', encoding='utf-8'), indent=4) 197 | logger.info('=' * 20 + 'save vocab file' + '=' * 20) 198 | with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f: 199 | words = f.read().splitlines() 200 | words_set = set() 201 | words_duplicate_indices = [] 202 | for i in range(len(words)-1, -1, -1): 203 | word = words[i] 204 | if word in words_set: 205 | words_duplicate_indices.append(i) 206 | words_set.add(word) 207 | for i, idx in enumerate(words_duplicate_indices): 208 | words[idx] = chr(0x1F6A9+i) # Change duplicated word to 🚩 LOL 209 | with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f: 210 | for word in words: 211 | f.write(word+'\n') 212 | special_tokens_map = { 213 | "unk_token": "[UNK]", 214 | "sep_token": "[SEP]", 215 | "pad_token": "[PAD]", 216 | "cls_token": "[CLS]", 217 | "mask_token": "[MASK]" 218 | } 219 | json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'), 220 | 'wt', encoding='utf-8')) 221 | tokenizer_config = { 222 | "do_lower_case": True, 223 | "unk_token": "[UNK]", 224 | "sep_token": "[SEP]", 225 | "pad_token": "[PAD]", 226 | "cls_token": "[CLS]", 227 | "mask_token": 
"[MASK]", 228 | "tokenizer_class": "BertTokenizer" 229 | } 230 | json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'), 231 | 'wt', encoding='utf-8')) 232 | logger.info('=' * 20 + 'extract weights' + '=' * 20) 233 | state_dict = collections.OrderedDict() 234 | weight_map = build_params_map(attention_num=config['num_hidden_layers']) 235 | if paddle_installed: 236 | import paddle.fluid.dygraph as D 237 | from paddle import fluid 238 | with fluid.dygraph.guard(): 239 | paddle_paddle_params, _ = D.load_dygraph( 240 | os.path.join(input_dir, 'model_state')) 241 | else: 242 | paddle_paddle_params = pickle.load( 243 | open(os.path.join(input_dir, 'model_state.pdparams'), 'rb')) 244 | del paddle_paddle_params['StructuredToParameterName@@'] 245 | for weight_name, weight_value in paddle_paddle_params.items(): 246 | if 'weight' in weight_name: 247 | if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name: 248 | weight_value = weight_value.transpose() 249 | # Fix: embedding error 250 | if 'word_embeddings.weight' in weight_name: 251 | weight_value[0, :] = 0 252 | if weight_name not in weight_map: 253 | logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}") 254 | continue 255 | state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value) 256 | logger.info( 257 | f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}") 258 | torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin")) 259 | 260 | 261 | def check_model(input_model): 262 | if not os.path.exists(input_model): 263 | if input_model not in MODEL_MAP: 264 | raise ValueError('input_model not exists!') 265 | 266 | resource_file_urls = MODEL_MAP[input_model]['resource_file_urls'] 267 | logger.info("Downloading resource files...") 268 | 269 | for key, val in resource_file_urls.items(): 270 | file_path = os.path.join(input_model, key) 271 | if not os.path.exists(file_path): 272 | get_path_from_url(val, input_model) 273 | 274 | 275 | def validate_model(tokenizer, pt_model, pd_model: str, atol: float = 1e-5): 276 | 277 | logger.info("Validating PyTorch model...") 278 | 279 | batch_size = 2 280 | seq_length = 6 281 | seq_length_with_token = seq_length+2 282 | max_seq_length = 512 283 | dummy_input = [" ".join([tokenizer.unk_token]) 284 | * seq_length] * batch_size 285 | encoded_inputs = dict(tokenizer(dummy_input, pad_to_max_seq_len=True, max_seq_len=512, return_attention_mask=True, 286 | return_position_ids=True)) 287 | paddle_inputs = {} 288 | for name, value in encoded_inputs.items(): 289 | if name == "attention_mask": 290 | name = "att_mask" 291 | if name == "position_ids": 292 | name = "pos_ids" 293 | paddle_inputs[name] = paddle.to_tensor(value, dtype=paddle.int64) 294 | 295 | paddle_named_outputs = ['start_prob', 'end_prob'] 296 | paddle_outputs = pd_model(**paddle_inputs) 297 | 298 | torch_inputs = {} 299 | for name, value in encoded_inputs.items(): 300 | torch_inputs[name] = torch.tensor(value, dtype=torch.int64) 301 | torch_outputs = pt_model(**torch_inputs) 302 | torch_outputs_dict = {} 303 | 304 | for name, value in torch_outputs.items(): 305 | torch_outputs_dict[name] = value 306 | 307 | torch_outputs_set, ref_outputs_set = set( 308 | torch_outputs_dict.keys()), set(paddle_named_outputs) 309 | if not torch_outputs_set.issubset(ref_outputs_set): 310 | logger.info( 311 | f"\t-[x] Pytorch model output names {torch_outputs_set} do not match reference model {ref_outputs_set}" 312 | ) 313 | 314 | raise ValueError( 315 | "Outputs doesn't match between 
reference model and Pytorch converted model: " 316 | f"{torch_outputs_set.difference(ref_outputs_set)}" 317 | ) 318 | else: 319 | logger.info( 320 | f"\t-[✓] Pytorch model output names match reference model ({torch_outputs_set})") 321 | 322 | # Check the shape and values match 323 | for name, ref_value in zip(paddle_named_outputs, paddle_outputs): 324 | ref_value = ref_value.numpy() 325 | pt_value = torch_outputs_dict[name].detach().numpy() 326 | logger.info(f'\t- Validating PyTorch Model output "{name}":') 327 | 328 | # Shape 329 | if not pt_value.shape == ref_value.shape: 330 | logger.info( 331 | f"\t\t-[x] shape {pt_value.shape} doesn't match {ref_value.shape}") 332 | raise ValueError( 333 | "Outputs shape doesn't match between reference model and Pytorch converted model: " 334 | f"Got {ref_value.shape} (reference) and {pt_value.shape} (PyTorch)" 335 | ) 336 | else: 337 | logger.info( 338 | f"\t\t-[✓] {pt_value.shape} matches {ref_value.shape}") 339 | 340 | # Values 341 | if not np.allclose(ref_value, pt_value, atol=atol): 342 | logger.info( 343 | f"\t\t-[x] values not close enough (atol: {atol})") 344 | raise ValueError( 345 | "Outputs values doesn't match between reference model and Pytorch converted model: " 346 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - pt_value))}" 347 | ) 348 | else: 349 | logger.info( 350 | f"\t\t-[✓] all values close (atol: {atol})") 351 | 352 | 353 | def do_main(): 354 | check_model(args.input_model) 355 | extract_and_convert(args.input_model, args.output_model) 356 | if not args.no_validate_output: 357 | if paddle_installed: 358 | try: 359 | from paddlenlp.transformers import ErnieTokenizer 360 | from paddlenlp.taskflow.models import UIE as UIEPaddle 361 | except (ImportError, ModuleNotFoundError) as e: 362 | raise ModuleNotFoundError( 363 | 'Module PaddleNLP is not installed. Try install paddlenlp or run convert.py with --no_validate_output') from e 364 | tokenizer: ErnieTokenizer = ErnieTokenizer.from_pretrained( 365 | args.input_model) 366 | model = UIE.from_pretrained(args.output_model) 367 | model.eval() 368 | paddle_model = UIEPaddle.from_pretrained(args.input_model) 369 | paddle_model.eval() 370 | validate_model(tokenizer, model, paddle_model) 371 | else: 372 | logger.warning("Skipping validating PyTorch model because paddle is not installed. " 373 | "The outputs of the model may not be the same as Paddle model.") 374 | 375 | 376 | if __name__ == '__main__': 377 | parser = argparse.ArgumentParser() 378 | parser.add_argument("-i", "--input_model", default="uie-base", type=str, 379 | help="Directory of input paddle model.\n Will auto download model [uie-base/uie-tiny]") 380 | parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str, 381 | help="Directory of output pytorch model") 382 | parser.add_argument("--no_validate_output", action="store_true", 383 | help="Directory of output pytorch model") 384 | args = parser.parse_args() 385 | 386 | do_main() 387 | -------------------------------------------------------------------------------- /uie_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import six 17 | 18 | from transformers import BertTokenizerFast 19 | import math 20 | import argparse 21 | import os.path 22 | 23 | from utils import logger, get_bool_ids_greater_than, get_span, get_id_and_prob, cut_chinese_sent, dbc2sbc 24 | 25 | 26 | class ONNXInferBackend(object): 27 | def __init__(self, 28 | model_path_prefix, 29 | device='cpu', 30 | use_fp16=False): 31 | from onnxruntime import InferenceSession, SessionOptions 32 | logger.info(">>> [ONNXInferBackend] Creating Engine ...") 33 | onnx_model = float_onnx_file = os.path.join( 34 | model_path_prefix, "inference.onnx") 35 | if not os.path.exists(onnx_model): 36 | raise OSError(f'{onnx_model} not exists!') 37 | infer_model_dir = model_path_prefix 38 | 39 | if device == "gpu": 40 | providers = ['CUDAExecutionProvider'] 41 | logger.info(">>> [ONNXInferBackend] Use GPU to inference ...") 42 | if use_fp16: 43 | logger.info(">>> [ONNXInferBackend] Use FP16 to inference ...") 44 | from onnxconverter_common import float16 45 | import onnx 46 | fp16_model_file = os.path.join(infer_model_dir, 47 | "fp16_model.onnx") 48 | onnx_model = onnx.load_model(float_onnx_file) 49 | trans_model = float16.convert_float_to_float16( 50 | onnx_model, keep_io_types=True) 51 | onnx.save_model(trans_model, fp16_model_file) 52 | onnx_model = fp16_model_file 53 | else: 54 | providers = ['CPUExecutionProvider'] 55 | logger.info(">>> [ONNXInferBackend] Use CPU to inference ...") 56 | 57 | sess_options = SessionOptions() 58 | self.predictor = InferenceSession( 59 | onnx_model, sess_options=sess_options, providers=providers) 60 | if device == "gpu": 61 | try: 62 | assert 'CUDAExecutionProvider' in self.predictor.get_providers() 63 | except AssertionError: 64 | raise AssertionError( 65 | f"The environment for GPU inference is not set properly. " 66 | "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. 
" 67 | "Please run the following commands to reinstall: \n " 68 | "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu" 69 | ) 70 | logger.info(">>> [InferBackend] Engine Created ...") 71 | 72 | def infer(self, input_dict: dict): 73 | result = self.predictor.run(None, dict(input_dict)) 74 | return result 75 | 76 | 77 | class PyTorchInferBackend: 78 | def __init__(self, 79 | model_path_prefix, 80 | device='cpu', 81 | use_fp16=False): 82 | from model import UIE 83 | logger.info(">>> [PyTorchInferBackend] Creating Engine ...") 84 | self.model = UIE.from_pretrained(model_path_prefix) 85 | self.model.eval() 86 | self.device = device 87 | if self.device == 'gpu': 88 | logger.info(">>> [PyTorchInferBackend] Use GPU to inference ...") 89 | if use_fp16: 90 | logger.info( 91 | ">>> [PyTorchInferBackend] Use FP16 to inference ...") 92 | self.model = self.model.half() 93 | self.model = self.model.cuda() 94 | else: 95 | logger.info(">>> [PyTorchInferBackend] Use CPU to inference ...") 96 | logger.info(">>> [PyTorchInferBackend] Engine Created ...") 97 | 98 | def infer(self, input_dict): 99 | import torch 100 | for input_name, input_value in input_dict.items(): 101 | input_value = torch.LongTensor(input_value) 102 | if self.device == 'gpu': 103 | input_value = input_value.cuda() 104 | input_dict[input_name] = input_value 105 | 106 | outputs = self.model(**input_dict) 107 | start_prob, end_prob = outputs[0], outputs[1] 108 | if self.device == 'gpu': 109 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu() 110 | start_prob = start_prob.detach().numpy() 111 | end_prob = end_prob.detach().numpy() 112 | return start_prob, end_prob 113 | 114 | 115 | class UIEPredictor(object): 116 | 117 | def __init__(self, task_path, schema, engine='pytorch', device='cpu', position_prob=0.5, max_seq_len=512, batch_size=64, split_sentence=False, use_fp16=False): 118 | 119 | assert isinstance( 120 | device, six.string_types 121 | ), "The type of device must be string." 122 | assert device in [ 123 | 'cpu', 'gpu'], "The device must be cpu or gpu." 124 | self._engine = engine 125 | self._task_path = task_path 126 | self._device = device 127 | self._position_prob = position_prob 128 | self._max_seq_len = max_seq_len 129 | self._batch_size = 64 130 | self._split_sentence = False 131 | self._use_fp16 = use_fp16 132 | 133 | self._schema_tree = None 134 | self.set_schema(schema) 135 | self._prepare_predictor() 136 | 137 | def _prepare_predictor(self): 138 | self._tokenizer = BertTokenizerFast.from_pretrained( 139 | self._task_path) 140 | assert self._engine in ['pytorch', 141 | 'onnx'], "engine must be pytorch or onnx!" 142 | if self._engine == 'pytorch': 143 | self.inference_backend = PyTorchInferBackend( 144 | self._task_path, device=self._device, use_fp16=self._use_fp16) 145 | 146 | if self._engine == 'onnx': 147 | self.inference_backend = ONNXInferBackend( 148 | self._task_path, device=self._device, use_fp16=self._use_fp16) 149 | 150 | def set_schema(self, schema): 151 | if isinstance(schema, dict) or isinstance(schema, str): 152 | schema = [schema] 153 | self._schema_tree = self._build_tree(schema) 154 | 155 | def __call__(self, inputs): 156 | texts = inputs 157 | if isinstance(texts, str): 158 | texts = [texts] 159 | results = self._multi_stage_predict(texts) 160 | return results 161 | 162 | def _multi_stage_predict(self, datas): 163 | """ 164 | Traversal the schema tree and do multi-stage prediction. 
165 | Args: 166 | datas (list): a list of strings 167 | Returns: 168 | list: a list of predictions, where the list's length 169 | equals to the length of `datas` 170 | """ 171 | results = [{} for _ in range(len(datas))] 172 | # input check to early return 173 | if len(datas) < 1 or self._schema_tree is None: 174 | return results 175 | 176 | # copy to stay `self._schema_tree` unchanged 177 | schema_list = self._schema_tree.children[:] 178 | while len(schema_list) > 0: 179 | node = schema_list.pop(0) 180 | examples = [] 181 | input_map = {} 182 | cnt = 0 183 | idx = 0 184 | if not node.prefix: 185 | for data in datas: 186 | examples.append({ 187 | "text": data, 188 | "prompt": dbc2sbc(node.name) 189 | }) 190 | input_map[cnt] = [idx] 191 | idx += 1 192 | cnt += 1 193 | else: 194 | for pre, data in zip(node.prefix, datas): 195 | if len(pre) == 0: 196 | input_map[cnt] = [] 197 | else: 198 | for p in pre: 199 | examples.append({ 200 | "text": data, 201 | "prompt": dbc2sbc(p + node.name) 202 | }) 203 | input_map[cnt] = [i + idx for i in range(len(pre))] 204 | idx += len(pre) 205 | cnt += 1 206 | if len(examples) == 0: 207 | result_list = [] 208 | else: 209 | result_list = self._single_stage_predict(examples) 210 | 211 | if not node.parent_relations: 212 | relations = [[] for i in range(len(datas))] 213 | for k, v in input_map.items(): 214 | for idx in v: 215 | if len(result_list[idx]) == 0: 216 | continue 217 | if node.name not in results[k].keys(): 218 | results[k][node.name] = result_list[idx] 219 | else: 220 | results[k][node.name].extend(result_list[idx]) 221 | if node.name in results[k].keys(): 222 | relations[k].extend(results[k][node.name]) 223 | else: 224 | relations = node.parent_relations 225 | for k, v in input_map.items(): 226 | for i in range(len(v)): 227 | if len(result_list[v[i]]) == 0: 228 | continue 229 | if "relations" not in relations[k][i].keys(): 230 | relations[k][i]["relations"] = { 231 | node.name: result_list[v[i]] 232 | } 233 | elif node.name not in relations[k][i]["relations"].keys( 234 | ): 235 | relations[k][i]["relations"][ 236 | node.name] = result_list[v[i]] 237 | else: 238 | relations[k][i]["relations"][node.name].extend( 239 | result_list[v[i]]) 240 | 241 | new_relations = [[] for i in range(len(datas))] 242 | for i in range(len(relations)): 243 | for j in range(len(relations[i])): 244 | if "relations" in relations[i][j].keys( 245 | ) and node.name in relations[i][j]["relations"].keys(): 246 | for k in range( 247 | len(relations[i][j]["relations"][ 248 | node.name])): 249 | new_relations[i].append(relations[i][j][ 250 | "relations"][node.name][k]) 251 | relations = new_relations 252 | 253 | prefix = [[] for _ in range(len(datas))] 254 | for k, v in input_map.items(): 255 | for idx in v: 256 | for i in range(len(result_list[idx])): 257 | prefix[k].append(result_list[idx][i]["text"] + "的") 258 | 259 | for child in node.children: 260 | child.prefix = prefix 261 | child.parent_relations = relations 262 | schema_list.append(child) 263 | return results 264 | 265 | def _convert_ids_to_results(self, examples, sentence_ids, probs): 266 | """ 267 | Convert ids to raw text in a single stage. 
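        Illustrative note (editorial): span offsets that come back negative index into
        the prompt rather than the text, so they are returned as classification-style
        results without start/end fields; non-negative offsets index into the original
        text and keep their start/end positions.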
268 | """ 269 | results = [] 270 | for example, sentence_id, prob in zip(examples, sentence_ids, probs): 271 | if len(sentence_id) == 0: 272 | results.append([]) 273 | continue 274 | result_list = [] 275 | text = example["text"] 276 | prompt = example["prompt"] 277 | for i in range(len(sentence_id)): 278 | start, end = sentence_id[i] 279 | if start < 0 and end >= 0: 280 | continue 281 | if end < 0: 282 | start += (len(prompt) + 1) 283 | end += (len(prompt) + 1) 284 | result = {"text": prompt[start:end], 285 | "probability": prob[i]} 286 | result_list.append(result) 287 | else: 288 | result = { 289 | "text": text[start:end], 290 | "start": start, 291 | "end": end, 292 | "probability": prob[i] 293 | } 294 | result_list.append(result) 295 | results.append(result_list) 296 | return results 297 | 298 | def _auto_splitter(self, input_texts, max_text_len, split_sentence=False): 299 | ''' 300 | Split the raw texts automatically for model inference. 301 | Args: 302 | input_texts (List[str]): input raw texts. 303 | max_text_len (int): cutting length. 304 | split_sentence (bool): If True, sentence-level split will be performed. 305 | return: 306 | short_input_texts (List[str]): the short input texts for model inference. 307 | input_mapping (dict): mapping between raw text and short input texts. 308 | ''' 309 | input_mapping = {} 310 | short_input_texts = [] 311 | cnt_org = 0 312 | cnt_short = 0 313 | for text in input_texts: 314 | if not split_sentence: 315 | sens = [text] 316 | else: 317 | sens = cut_chinese_sent(text) 318 | for sen in sens: 319 | lens = len(sen) 320 | if lens <= max_text_len: 321 | short_input_texts.append(sen) 322 | if cnt_org not in input_mapping.keys(): 323 | input_mapping[cnt_org] = [cnt_short] 324 | else: 325 | input_mapping[cnt_org].append(cnt_short) 326 | cnt_short += 1 327 | else: 328 | temp_text_list = [ 329 | sen[i:i + max_text_len] 330 | for i in range(0, lens, max_text_len) 331 | ] 332 | short_input_texts.extend(temp_text_list) 333 | short_idx = cnt_short 334 | cnt_short += math.ceil(lens / max_text_len) 335 | temp_text_id = [ 336 | short_idx + i for i in range(cnt_short - short_idx) 337 | ] 338 | if cnt_org not in input_mapping.keys(): 339 | input_mapping[cnt_org] = temp_text_id 340 | else: 341 | input_mapping[cnt_org].extend(temp_text_id) 342 | cnt_org += 1 343 | return short_input_texts, input_mapping 344 | 345 | def _single_stage_predict(self, inputs): 346 | input_texts = [] 347 | prompts = [] 348 | for i in range(len(inputs)): 349 | input_texts.append(inputs[i]["text"]) 350 | prompts.append(inputs[i]["prompt"]) 351 | # max predict length should exclude the length of prompt and summary tokens 352 | max_predict_len = self._max_seq_len - len(max(prompts)) - 3 353 | 354 | short_input_texts, self.input_mapping = self._auto_splitter( 355 | input_texts, max_predict_len, split_sentence=self._split_sentence) 356 | 357 | short_texts_prompts = [] 358 | for k, v in self.input_mapping.items(): 359 | short_texts_prompts.extend([prompts[k] for i in range(len(v))]) 360 | short_inputs = [{ 361 | "text": short_input_texts[i], 362 | "prompt": short_texts_prompts[i] 363 | } for i in range(len(short_input_texts))] 364 | 365 | sentence_ids = [] 366 | probs = [] 367 | 368 | input_ids = [] 369 | token_type_ids = [] 370 | attention_mask = [] 371 | offset_maps = [] 372 | 373 | encoded_inputs = self._tokenizer( 374 | text=short_texts_prompts, 375 | text_pair=short_input_texts, 376 | stride=2, 377 | truncation=True, 378 | max_length=self._max_seq_len, 379 | padding="longest", 380 | 
add_special_tokens=True, 381 | return_offsets_mapping=True, 382 | return_tensors="np") 383 | 384 | start_prob_concat, end_prob_concat = [], [] 385 | for batch_start in range(0, len(short_input_texts), self._batch_size): 386 | input_ids = encoded_inputs["input_ids"][batch_start:batch_start+self._batch_size] 387 | token_type_ids = encoded_inputs["token_type_ids"][batch_start:batch_start+self._batch_size] 388 | attention_mask = encoded_inputs["attention_mask"][batch_start:batch_start+self._batch_size] 389 | offset_maps = encoded_inputs["offset_mapping"][batch_start:batch_start+self._batch_size] 390 | input_dict = { 391 | "input_ids": np.array( 392 | input_ids, dtype="int64"), 393 | "token_type_ids": np.array( 394 | token_type_ids, dtype="int64"), 395 | "attention_mask": np.array( 396 | attention_mask, dtype="int64") 397 | } 398 | 399 | outputs = self.inference_backend.infer(input_dict) 400 | start_prob, end_prob = outputs[0], outputs[1] 401 | start_prob_concat.append(start_prob) 402 | end_prob_concat.append(end_prob) 403 | start_prob_concat = np.concatenate(start_prob_concat) 404 | end_prob_concat = np.concatenate(end_prob_concat) 405 | 406 | start_ids_list = get_bool_ids_greater_than( 407 | start_prob_concat, limit=self._position_prob, return_prob=True) 408 | end_ids_list = get_bool_ids_greater_than( 409 | end_prob_concat, limit=self._position_prob, return_prob=True) 410 | 411 | input_ids = input_dict['input_ids'] 412 | sentence_ids = [] 413 | probs = [] 414 | for start_ids, end_ids, ids, offset_map in zip(start_ids_list, 415 | end_ids_list, 416 | input_ids.tolist(), 417 | offset_maps): 418 | for i in reversed(range(len(ids))): 419 | if ids[i] != 0: 420 | ids = ids[:i] 421 | break 422 | span_list = get_span(start_ids, end_ids, with_prob=True) 423 | sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist()) 424 | sentence_ids.append(sentence_id) 425 | probs.append(prob) 426 | 427 | results = self._convert_ids_to_results(short_inputs, sentence_ids, 428 | probs) 429 | results = self._auto_joiner(results, short_input_texts, 430 | self.input_mapping) 431 | return results 432 | 433 | def _auto_joiner(self, short_results, short_inputs, input_mapping): 434 | concat_results = [] 435 | is_cls_task = False 436 | for short_result in short_results: 437 | if short_result == []: 438 | continue 439 | elif 'start' not in short_result[0].keys( 440 | ) and 'end' not in short_result[0].keys(): 441 | is_cls_task = True 442 | break 443 | else: 444 | break 445 | for k, vs in input_mapping.items(): 446 | if is_cls_task: 447 | cls_options = {} 448 | single_results = [] 449 | for v in vs: 450 | if len(short_results[v]) == 0: 451 | continue 452 | if short_results[v][0]['text'] not in cls_options.keys(): 453 | cls_options[short_results[v][0][ 454 | 'text']] = [1, short_results[v][0]['probability']] 455 | else: 456 | cls_options[short_results[v][0]['text']][0] += 1 457 | cls_options[short_results[v][0]['text']][ 458 | 1] += short_results[v][0]['probability'] 459 | if len(cls_options) != 0: 460 | cls_res, cls_info = max(cls_options.items(), 461 | key=lambda x: x[1]) 462 | concat_results.append([{ 463 | 'text': cls_res, 464 | 'probability': cls_info[1] / cls_info[0] 465 | }]) 466 | else: 467 | concat_results.append([]) 468 | else: 469 | offset = 0 470 | single_results = [] 471 | for v in vs: 472 | if v == 0: 473 | single_results = short_results[v] 474 | offset += len(short_inputs[v]) 475 | else: 476 | for i in range(len(short_results[v])): 477 | if 'start' not in short_results[v][ 478 | i] or 'end' not in 
short_results[v][i]: 479 | continue 480 | short_results[v][i]['start'] += offset 481 | short_results[v][i]['end'] += offset 482 | offset += len(short_inputs[v]) 483 | single_results.extend(short_results[v]) 484 | concat_results.append(single_results) 485 | return concat_results 486 | 487 | def predict(self, input_data): 488 | results = self._multi_stage_predict(input_data) 489 | return results 490 | 491 | @classmethod 492 | def _build_tree(cls, schema, name='root'): 493 | """ 494 | Build the schema tree. 495 | """ 496 | schema_tree = SchemaTree(name) 497 | for s in schema: 498 | if isinstance(s, str): 499 | schema_tree.add_child(SchemaTree(s)) 500 | elif isinstance(s, dict): 501 | for k, v in s.items(): 502 | if isinstance(v, str): 503 | child = [v] 504 | elif isinstance(v, list): 505 | child = v 506 | else: 507 | raise TypeError( 508 | "Invalid schema, value for each key:value pairs should be list or string" 509 | "but {} received".format(type(v))) 510 | schema_tree.add_child(cls._build_tree(child, name=k)) 511 | else: 512 | raise TypeError( 513 | "Invalid schema, element should be string or dict, " 514 | "but {} received".format(type(s))) 515 | return schema_tree 516 | 517 | 518 | class SchemaTree(object): 519 | """ 520 | Implementataion of SchemaTree 521 | """ 522 | 523 | def __init__(self, name='root', children=None): 524 | self.name = name 525 | self.children = [] 526 | self.prefix = None 527 | self.parent_relations = None 528 | if children is not None: 529 | for child in children: 530 | self.add_child(child) 531 | 532 | def __repr__(self): 533 | return self.name 534 | 535 | def add_child(self, node): 536 | assert isinstance( 537 | node, SchemaTree 538 | ), "The children of a node should be an instacne of SchemaTree." 539 | self.children.append(node) 540 | 541 | 542 | def parse_args(): 543 | parser = argparse.ArgumentParser() 544 | # Required parameters 545 | parser.add_argument( 546 | "-m", 547 | "--model_path_prefix", 548 | type=str, 549 | default='uie_base_pytorch', 550 | help="The path prefix of inference model to be used.", ) 551 | parser.add_argument( 552 | "-p", 553 | "--position_prob", 554 | default=0.5, 555 | type=float, 556 | help="Probability threshold for start/end index probabiliry.", ) 557 | parser.add_argument( 558 | "--use_fp16", 559 | action='store_true', 560 | help="Whether to use fp16 inference, only takes effect when deploying on gpu.", 561 | ) 562 | parser.add_argument( 563 | "--max_seq_len", 564 | default=512, 565 | type=int, 566 | help="The maximum input sequence length. Sequences longer than this will be split automatically.", 567 | ) 568 | parser.add_argument( 569 | "-D", 570 | "--device", 571 | choices=['cpu', 'gpu'], 572 | default="gpu", 573 | help="Select which device to run model, defaults to gpu." 574 | ) 575 | parser.add_argument( 576 | "-e", 577 | "--engine", 578 | choices=['pytorch', 'onnx'], 579 | default="pytorch", 580 | help="Select which engine to run model, defaults to pytorch." 
581 | ) 582 | args = parser.parse_args() 583 | return args 584 | 585 | 586 | if __name__ == '__main__': 587 | args = parse_args() 588 | args.schema = ['航母'] 589 | uie = UIEPredictor(task_path=args.model_path_prefix, schema=args.schema, engine=args.engine, device=args.device, 590 | position_prob=args.position_prob, max_seq_len=args.max_seq_len, batch_size=64, split_sentence=False, use_fp16=args.use_fp16) 591 | print(uie("印媒所称的“印度第一艘国产航母”—“维克兰特”号")) 592 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import contextlib 16 | import functools 17 | import json 18 | import logging 19 | import math 20 | import random 21 | import re 22 | import shutil 23 | import threading 24 | import time 25 | from functools import partial 26 | 27 | import colorlog 28 | import numpy as np 29 | import torch 30 | from colorama import Back, Fore 31 | from torch.utils.data import Dataset 32 | from tqdm import tqdm 33 | from tqdm.contrib.logging import logging_redirect_tqdm 34 | 35 | loggers = {} 36 | 37 | log_config = { 38 | 'DEBUG': { 39 | 'level': 10, 40 | 'color': 'purple' 41 | }, 42 | 'INFO': { 43 | 'level': 20, 44 | 'color': 'green' 45 | }, 46 | 'TRAIN': { 47 | 'level': 21, 48 | 'color': 'cyan' 49 | }, 50 | 'EVAL': { 51 | 'level': 22, 52 | 'color': 'blue' 53 | }, 54 | 'WARNING': { 55 | 'level': 30, 56 | 'color': 'yellow' 57 | }, 58 | 'ERROR': { 59 | 'level': 40, 60 | 'color': 'red' 61 | }, 62 | 'CRITICAL': { 63 | 'level': 50, 64 | 'color': 'bold_red' 65 | } 66 | } 67 | 68 | 69 | def get_span(start_ids, end_ids, with_prob=False): 70 | """ 71 | Get span set from position start and end list. 72 | 73 | Args: 74 | start_ids (List[int]/List[tuple]): The start index list. 75 | end_ids (List[int]/List[tuple]): The end index list. 76 | with_prob (bool): If True, each element for start_ids and end_ids is a tuple aslike: (index, probability). 77 | Returns: 78 | set: The span set without overlapping, every id can only be used once . 
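        Example (illustrative): get_span([1, 4], [3, 5]) -> {(1, 3), (4, 5)}.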
79 | """ 80 | if with_prob: 81 | start_ids = sorted(start_ids, key=lambda x: x[0]) 82 | end_ids = sorted(end_ids, key=lambda x: x[0]) 83 | else: 84 | start_ids = sorted(start_ids) 85 | end_ids = sorted(end_ids) 86 | 87 | start_pointer = 0 88 | end_pointer = 0 89 | len_start = len(start_ids) 90 | len_end = len(end_ids) 91 | couple_dict = {} 92 | while start_pointer < len_start and end_pointer < len_end: 93 | if with_prob: 94 | start_id = start_ids[start_pointer][0] 95 | end_id = end_ids[end_pointer][0] 96 | else: 97 | start_id = start_ids[start_pointer] 98 | end_id = end_ids[end_pointer] 99 | 100 | if start_id == end_id: 101 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer] 102 | start_pointer += 1 103 | end_pointer += 1 104 | continue 105 | if start_id < end_id: 106 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer] 107 | start_pointer += 1 108 | continue 109 | if start_id > end_id: 110 | end_pointer += 1 111 | continue 112 | result = [(couple_dict[end], end) for end in couple_dict] 113 | result = set(result) 114 | return result 115 | 116 | 117 | def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False): 118 | """ 119 | Get idx of the last dimension in probability arrays, which is greater than a limitation. 120 | 121 | Args: 122 | probs (List[List[float]]): The input probability arrays. 123 | limit (float): The limitation for probability. 124 | return_prob (bool): Whether to return the probability 125 | Returns: 126 | List[List[int]]: The index of the last dimension meet the conditions. 127 | """ 128 | probs = np.array(probs) 129 | dim_len = len(probs.shape) 130 | if dim_len > 1: 131 | result = [] 132 | for p in probs: 133 | result.append(get_bool_ids_greater_than(p, limit, return_prob)) 134 | return result 135 | else: 136 | result = [] 137 | for i, p in enumerate(probs): 138 | if p > limit: 139 | if return_prob: 140 | result.append((i, p)) 141 | else: 142 | result.append(i) 143 | return result 144 | 145 | 146 | class SpanEvaluator: 147 | """ 148 | SpanEvaluator computes the precision, recall and F1-score for span detection. 149 | """ 150 | 151 | def __init__(self): 152 | super(SpanEvaluator, self).__init__() 153 | self.num_infer_spans = 0 154 | self.num_label_spans = 0 155 | self.num_correct_spans = 0 156 | 157 | def compute(self, start_probs, end_probs, gold_start_ids, gold_end_ids): 158 | """ 159 | Computes the precision, recall and F1-score for span detection. 160 | """ 161 | pred_start_ids = get_bool_ids_greater_than(start_probs) 162 | pred_end_ids = get_bool_ids_greater_than(end_probs) 163 | gold_start_ids = get_bool_ids_greater_than(gold_start_ids.tolist()) 164 | gold_end_ids = get_bool_ids_greater_than(gold_end_ids.tolist()) 165 | num_correct_spans = 0 166 | num_infer_spans = 0 167 | num_label_spans = 0 168 | for predict_start_ids, predict_end_ids, label_start_ids, label_end_ids in zip( 169 | pred_start_ids, pred_end_ids, gold_start_ids, gold_end_ids): 170 | [_correct, _infer, _label] = self.eval_span( 171 | predict_start_ids, predict_end_ids, label_start_ids, 172 | label_end_ids) 173 | num_correct_spans += _correct 174 | num_infer_spans += _infer 175 | num_label_spans += _label 176 | return num_correct_spans, num_infer_spans, num_label_spans 177 | 178 | def update(self, num_correct_spans, num_infer_spans, num_label_spans): 179 | """ 180 | This function takes (num_infer_spans, num_label_spans, num_correct_spans) as input, 181 | to accumulate and update the corresponding status of the SpanEvaluator object. 
182 | """ 183 | self.num_infer_spans += num_infer_spans 184 | self.num_label_spans += num_label_spans 185 | self.num_correct_spans += num_correct_spans 186 | 187 | def eval_span(self, predict_start_ids, predict_end_ids, label_start_ids, 188 | label_end_ids): 189 | """ 190 | evaluate position extraction (start, end) 191 | return num_correct, num_infer, num_label 192 | input: [1, 2, 10] [4, 12] [2, 10] [4, 11] 193 | output: (1, 2, 2) 194 | """ 195 | pred_set = get_span(predict_start_ids, predict_end_ids) 196 | label_set = get_span(label_start_ids, label_end_ids) 197 | num_correct = len(pred_set & label_set) 198 | num_infer = len(pred_set) 199 | num_label = len(label_set) 200 | return (num_correct, num_infer, num_label) 201 | 202 | def accumulate(self): 203 | """ 204 | This function returns the mean precision, recall and f1 score for all accumulated minibatches. 205 | 206 | Returns: 207 | tuple: Returns tuple (`precision, recall, f1 score`). 208 | """ 209 | precision = float(self.num_correct_spans / 210 | self.num_infer_spans) if self.num_infer_spans else 0. 211 | recall = float(self.num_correct_spans / 212 | self.num_label_spans) if self.num_label_spans else 0. 213 | f1_score = float(2 * precision * recall / 214 | (precision + recall)) if self.num_correct_spans else 0. 215 | return precision, recall, f1_score 216 | 217 | def reset(self): 218 | """ 219 | Reset function empties the evaluation memory for previous mini-batches. 220 | """ 221 | self.num_infer_spans = 0 222 | self.num_label_spans = 0 223 | self.num_correct_spans = 0 224 | 225 | def name(self): 226 | """ 227 | Return name of metric instance. 228 | """ 229 | return "precision", "recall", "f1" 230 | 231 | 232 | class IEDataset(Dataset): 233 | """ 234 | Dataset for Information Extraction fron jsonl file. 235 | The line type is 236 | { 237 | content 238 | result_list 239 | prompt 240 | } 241 | """ 242 | 243 | def __init__(self, file_path, tokenizer, max_seq_len) -> None: 244 | super().__init__() 245 | self.file_path = file_path 246 | self.dataset = list(reader(file_path)) 247 | self.tokenizer = tokenizer 248 | self.max_seq_len = max_seq_len 249 | 250 | def __len__(self): 251 | return len(self.dataset) 252 | 253 | def __getitem__(self, index): 254 | return convert_example(self.dataset[index], tokenizer=self.tokenizer, max_seq_len=self.max_seq_len) 255 | 256 | 257 | def reader(data_path, max_seq_len=512): 258 | """ 259 | read json 260 | """ 261 | with open(data_path, 'r', encoding='utf-8') as f: 262 | for line in f: 263 | json_line = json.loads(line) 264 | content = json_line['content'] 265 | prompt = json_line['prompt'] 266 | # Model Input is aslike: [CLS] Prompt [SEP] Content [SEP] 267 | # It include three summary tokens. 
268 | if max_seq_len <= len(prompt) + 3: 269 | raise ValueError( 270 | "The value of max_seq_len is too small, please set a larger value" 271 | ) 272 | max_content_len = max_seq_len - len(prompt) - 3 273 | if len(content) <= max_content_len: 274 | yield json_line 275 | else: 276 | result_list = json_line['result_list'] 277 | json_lines = [] 278 | accumulate = 0 279 | while True: 280 | cur_result_list = [] 281 | 282 | for result in result_list: 283 | if result['start'] + 1 <= max_content_len < result[ 284 | 'end']: 285 | max_content_len = result['start'] 286 | break 287 | 288 | cur_content = content[:max_content_len] 289 | res_content = content[max_content_len:] 290 | 291 | while True: 292 | if len(result_list) == 0: 293 | break 294 | elif result_list[0]['end'] <= max_content_len: 295 | if result_list[0]['end'] > 0: 296 | cur_result = result_list.pop(0) 297 | cur_result_list.append(cur_result) 298 | else: 299 | cur_result_list = [ 300 | result for result in result_list 301 | ] 302 | break 303 | else: 304 | break 305 | 306 | json_line = { 307 | 'content': cur_content, 308 | 'result_list': cur_result_list, 309 | 'prompt': prompt 310 | } 311 | json_lines.append(json_line) 312 | 313 | for result in result_list: 314 | if result['end'] <= 0: 315 | break 316 | result['start'] -= max_content_len 317 | result['end'] -= max_content_len 318 | accumulate += max_content_len 319 | max_content_len = max_seq_len - len(prompt) - 3 320 | if len(res_content) == 0: 321 | break 322 | elif len(res_content) < max_content_len: 323 | json_line = { 324 | 'content': res_content, 325 | 'result_list': result_list, 326 | 'prompt': prompt 327 | } 328 | json_lines.append(json_line) 329 | break 330 | else: 331 | content = res_content 332 | 333 | for json_line in json_lines: 334 | yield json_line 335 | 336 | 337 | def convert_example(example, tokenizer, max_seq_len): 338 | """ 339 | example: { 340 | title 341 | prompt 342 | content 343 | result_list 344 | } 345 | """ 346 | encoded_inputs = tokenizer( 347 | text=[example["prompt"]], 348 | text_pair=[example["content"]], 349 | truncation=True, 350 | max_length=max_seq_len, 351 | add_special_tokens=True, 352 | return_offsets_mapping=True) 353 | # encoded_inputs = encoded_inputs[0] 354 | offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"][0]] 355 | bias = 0 356 | for index in range(len(offset_mapping)): 357 | if index == 0: 358 | continue 359 | mapping = offset_mapping[index] 360 | if mapping[0] == 0 and mapping[1] == 0 and bias == 0: 361 | bias = index 362 | if mapping[0] == 0 and mapping[1] == 0: 363 | continue 364 | offset_mapping[index][0] += bias 365 | offset_mapping[index][1] += bias 366 | start_ids = [0 for x in range(max_seq_len)] 367 | end_ids = [0 for x in range(max_seq_len)] 368 | for item in example["result_list"]: 369 | start = map_offset(item["start"] + bias, offset_mapping) 370 | end = map_offset(item["end"] - 1 + bias, offset_mapping) 371 | start_ids[start] = 1.0 372 | end_ids[end] = 1.0 373 | 374 | tokenized_output = [ 375 | encoded_inputs["input_ids"][0], encoded_inputs["token_type_ids"][0], 376 | encoded_inputs["attention_mask"][0], 377 | start_ids, end_ids 378 | ] 379 | tokenized_output = [np.array(x, dtype="int64") for x in tokenized_output] 380 | tokenized_output = [ 381 | np.pad(x, (0, max_seq_len-x.shape[-1]), 'constant') for x in tokenized_output] 382 | return tuple(tokenized_output) 383 | 384 | 385 | def map_offset(ori_offset, offset_mapping): 386 | """ 387 | map ori offset to token offset 388 | """ 389 | for index, span in 
enumerate(offset_mapping): 390 | if span[0] <= ori_offset < span[1]: 391 | return index 392 | return -1 393 | 394 | 395 | def set_seed(seed): 396 | torch.manual_seed(seed) 397 | torch.cuda.manual_seed_all(seed) 398 | np.random.seed(seed) 399 | random.seed(seed) 400 | np.random.seed(seed) 401 | 402 | 403 | class Logger(object): 404 | ''' 405 | Deafult logger in UIE 406 | 407 | Args: 408 | name(str) : Logger name, default is 'UIE' 409 | ''' 410 | 411 | def __init__(self, name: str = None): 412 | name = 'UIE' if not name else name 413 | self.logger = logging.getLogger(name) 414 | 415 | for key, conf in log_config.items(): 416 | logging.addLevelName(conf['level'], key) 417 | self.__dict__[key] = functools.partial( 418 | self.__call__, conf['level']) 419 | self.__dict__[key.lower()] = functools.partial( 420 | self.__call__, conf['level']) 421 | 422 | self.format = colorlog.ColoredFormatter( 423 | '%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s', 424 | log_colors={key: conf['color'] 425 | for key, conf in log_config.items()}) 426 | 427 | self.handler = logging.StreamHandler() 428 | self.handler.setFormatter(self.format) 429 | 430 | self.logger.addHandler(self.handler) 431 | self.logLevel = 'DEBUG' 432 | self.logger.setLevel(logging.DEBUG) 433 | self.logger.propagate = False 434 | self._is_enable = True 435 | 436 | def disable(self): 437 | self._is_enable = False 438 | 439 | def enable(self): 440 | self._is_enable = True 441 | 442 | @property 443 | def is_enable(self) -> bool: 444 | return self._is_enable 445 | 446 | def __call__(self, log_level: str, msg: str): 447 | if not self.is_enable: 448 | return 449 | 450 | self.logger.log(log_level, msg) 451 | 452 | @contextlib.contextmanager 453 | def use_terminator(self, terminator: str): 454 | old_terminator = self.handler.terminator 455 | self.handler.terminator = terminator 456 | yield 457 | self.handler.terminator = old_terminator 458 | 459 | @contextlib.contextmanager 460 | def processing(self, msg: str, interval: float = 0.1): 461 | ''' 462 | Continuously print a progress bar with rotating special effects. 463 | 464 | Args: 465 | msg(str): Message to be printed. 466 | interval(float): Rotation interval. Default to 0.1. 
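        Example (illustrative; do_work() is a placeholder for any slow call):
            with logger.processing("Loading model"):
                do_work()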
467 | ''' 468 | end = False 469 | 470 | def _printer(): 471 | index = 0 472 | flags = ['\\', '|', '/', '-'] 473 | while not end: 474 | flag = flags[index % len(flags)] 475 | with self.use_terminator('\r'): 476 | self.info('{}: {}'.format(msg, flag)) 477 | time.sleep(interval) 478 | index += 1 479 | 480 | t = threading.Thread(target=_printer) 481 | t.start() 482 | yield 483 | end = True 484 | 485 | 486 | logger = Logger() 487 | 488 | 489 | BAR_FORMAT = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}} {Fore.RED}{{rate_fmt}}{{postfix}}{Fore.RESET} eta {Fore.CYAN}{{remaining}}{Fore.RESET}' 490 | BAR_FORMAT_NO_TIME = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}}{Fore.RESET}' 491 | BAR_TYPE = [ 492 | "░▝▗▖▘▚▞▛▙█", 493 | "░▖▘▝▗▚▞█", 494 | " ▖▘▝▗▚▞█", 495 | "░▒█", 496 | " >=", 497 | " ▏▎▍▌▋▊▉█" 498 | "░▏▎▍▌▋▊▉█" 499 | ] 500 | 501 | tqdm = partial(tqdm, bar_format=BAR_FORMAT, ascii=BAR_TYPE[0], leave=False) 502 | 503 | 504 | def get_id_and_prob(spans, offset_map): 505 | prompt_length = 0 506 | for i in range(1, len(offset_map)): 507 | if offset_map[i] != [0, 0]: 508 | prompt_length += 1 509 | else: 510 | break 511 | 512 | for i in range(1, prompt_length + 1): 513 | offset_map[i][0] -= (prompt_length + 1) 514 | offset_map[i][1] -= (prompt_length + 1) 515 | 516 | sentence_id = [] 517 | prob = [] 518 | for start, end in spans: 519 | prob.append(start[1] * end[1]) 520 | sentence_id.append( 521 | (offset_map[start[0]][0], offset_map[end[0]][1])) 522 | return sentence_id, prob 523 | 524 | 525 | def cut_chinese_sent(para): 526 | """ 527 | Cut the Chinese sentences more precisely, reference to 528 | "https://blog.csdn.net/blmoistawinde/article/details/82379256". 529 | """ 530 | para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para) 531 | para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) 532 | para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) 533 | para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para) 534 | para = para.rstrip() 535 | return para.split("\n") 536 | 537 | 538 | def dbc2sbc(s): 539 | rs = "" 540 | for char in s: 541 | code = ord(char) 542 | if code == 0x3000: 543 | code = 0x0020 544 | else: 545 | code -= 0xfee0 546 | if not (0x0021 <= code and code <= 0x7e): 547 | rs += char 548 | continue 549 | rs += chr(code) 550 | return rs 551 | 552 | 553 | class EarlyStopping: 554 | """Early stops the training if validation loss doesn't improve after a given patience.""" 555 | 556 | def __init__(self, patience=7, verbose=False, delta=0, save_dir='checkpoint/early_stopping', trace_func=print): 557 | """ 558 | Args: 559 | patience (int): How long to wait after last time validation loss improved. 560 | Default: 7 561 | verbose (bool): If True, prints a message for each validation loss improvement. 562 | Default: False 563 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 564 | Default: 0 565 | path (str): Path for the checkpoint to be saved to. 566 | Default: 'checkpoint/early_stopping' 567 | trace_func (function): trace print function. 
568 | Default: print 569 | """ 570 | self.patience = patience 571 | self.verbose = verbose 572 | self.counter = 0 573 | self.best_score = None 574 | self.early_stop = False 575 | self.val_loss_min = np.Inf 576 | self.delta = delta 577 | self.save_dir = save_dir 578 | self.trace_func = trace_func 579 | 580 | def __call__(self, val_loss, model): 581 | 582 | score = -val_loss 583 | 584 | if self.best_score is None: 585 | self.best_score = score 586 | self.save_checkpoint(val_loss, model) 587 | elif score < self.best_score + self.delta: 588 | self.counter += 1 589 | self.trace_func( 590 | f'EarlyStopping counter: {self.counter} out of {self.patience}') 591 | if self.counter >= self.patience: 592 | self.early_stop = True 593 | else: 594 | self.best_score = score 595 | self.save_checkpoint(val_loss, model) 596 | self.counter = 0 597 | 598 | def save_checkpoint(self, val_loss, model): 599 | '''Saves model when validation loss decrease.''' 600 | if self.verbose: 601 | self.trace_func( 602 | f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') 603 | model.save_pretrained(self.save_dir) 604 | self.val_loss_min = val_loss 605 | 606 | 607 | def convert_cls_examples(raw_examples, prompt_prefix, options): 608 | examples = [] 609 | logger.info(f"Converting doccano data...") 610 | with tqdm(total=len(raw_examples)) as pbar: 611 | for line in raw_examples: 612 | items = json.loads(line) 613 | # Compatible with doccano >= 1.6.2 614 | if "data" in items.keys(): 615 | text, labels = items["data"], items["label"] 616 | else: 617 | text, labels = items["text"], items["label"] 618 | random.shuffle(options) 619 | prompt = "" 620 | sep = "," 621 | for option in options: 622 | prompt += option 623 | prompt += sep 624 | prompt = prompt_prefix + "[" + prompt.rstrip(sep) + "]" 625 | 626 | result_list = [] 627 | example = { 628 | "content": text, 629 | "result_list": result_list, 630 | "prompt": prompt 631 | } 632 | for label in labels: 633 | start = prompt.rfind(label[0]) - len(prompt) - 1 634 | end = start + len(label) 635 | result = {"text": label, "start": start, "end": end} 636 | example["result_list"].append(result) 637 | examples.append(example) 638 | return examples 639 | 640 | 641 | def add_negative_example(examples, texts, prompts, label_set, negative_ratio): 642 | negative_examples = [] 643 | positive_examples = [] 644 | with tqdm(total=len(prompts)) as pbar: 645 | for i, prompt in enumerate(prompts): 646 | negative_sample = [] 647 | redundants_list = list(set(label_set) ^ set(prompt)) 648 | redundants_list.sort() 649 | 650 | num_positive = len(examples[i]) 651 | if num_positive != 0: 652 | actual_ratio = math.ceil(len(redundants_list) / num_positive) 653 | else: 654 | # Set num_positive to 1 for text without positive example 655 | num_positive, actual_ratio = 1, 0 656 | 657 | if actual_ratio <= negative_ratio or negative_ratio == -1: 658 | idxs = [k for k in range(len(redundants_list))] 659 | else: 660 | idxs = random.sample( 661 | range(0, len(redundants_list)), 662 | negative_ratio * num_positive) 663 | 664 | for idx in idxs: 665 | negative_result = { 666 | "content": texts[i], 667 | "result_list": [], 668 | "prompt": redundants_list[idx] 669 | } 670 | negative_examples.append(negative_result) 671 | positive_examples.extend(examples[i]) 672 | pbar.update(1) 673 | return positive_examples, negative_examples 674 | 675 | 676 | def add_full_negative_example(examples, texts, relation_prompts, predicate_set, 677 | subject_goldens): 678 | with 
tqdm(total=len(relation_prompts)) as pbar: 679 | for i, relation_prompt in enumerate(relation_prompts): 680 | negative_sample = [] 681 | for subject in subject_goldens[i]: 682 | for predicate in predicate_set: 683 | # The relation prompt is constructed as follows: 684 | # subject + "的" + predicate 685 | prompt = subject + "的" + predicate 686 | if prompt not in relation_prompt: 687 | negative_result = { 688 | "content": texts[i], 689 | "result_list": [], 690 | "prompt": prompt 691 | } 692 | negative_sample.append(negative_result) 693 | examples[i].extend(negative_sample) 694 | pbar.update(1) 695 | return examples 696 | 697 | 698 | def construct_relation_prompt_set(entity_name_set, predicate_set): 699 | relation_prompt_set = set() 700 | for entity_name in entity_name_set: 701 | for predicate in predicate_set: 702 | # The relation prompt is constructed as follows: 703 | # subject + "的" + predicate 704 | relation_prompt = entity_name + "的" + predicate 705 | relation_prompt_set.add(relation_prompt) 706 | return sorted(list(relation_prompt_set)) 707 | 708 | 709 | def convert_ext_examples(raw_examples, negative_ratio, is_train=True): 710 | texts = [] 711 | entity_examples = [] 712 | relation_examples = [] 713 | entity_prompts = [] 714 | relation_prompts = [] 715 | entity_label_set = [] 716 | entity_name_set = [] 717 | predicate_set = [] 718 | subject_goldens = [] 719 | 720 | logger.info(f"Converting doccano data...") 721 | with tqdm(total=len(raw_examples)) as pbar: 722 | for line in raw_examples: 723 | items = json.loads(line) 724 | entity_id = 0 725 | if "data" in items.keys(): 726 | relation_mode = False 727 | if isinstance(items["label"], 728 | dict) and "entities" in items["label"].keys(): 729 | relation_mode = True 730 | text = items["data"] 731 | entities = [] 732 | if not relation_mode: 733 | # Export file in JSONL format which doccano < 1.7.0 734 | for item in items["label"]: 735 | entity = { 736 | "id": entity_id, 737 | "start_offset": item[0], 738 | "end_offset": item[1], 739 | "label": item[2] 740 | } 741 | entities.append(entity) 742 | entity_id += 1 743 | else: 744 | # Export file in JSONL format for relation labeling task which doccano < 1.7.0 745 | for item in items["label"]["entities"]: 746 | entity = { 747 | "id": entity_id, 748 | "start_offset": item["start_offset"], 749 | "end_offset": item["end_offset"], 750 | "label": item["label"] 751 | } 752 | entities.append(entity) 753 | entity_id += 1 754 | relations = [] 755 | else: 756 | # Export file in JSONL format which doccano >= 1.7.0 757 | if "label" in items.keys(): 758 | text = items["text"] 759 | entities = [] 760 | for item in items["label"]: 761 | entity = { 762 | "id": entity_id, 763 | "start_offset": item[0], 764 | "end_offset": item[1], 765 | "label": item[2] 766 | } 767 | entities.append(entity) 768 | entity_id += 1 769 | relations = [] 770 | else: 771 | # Export file in JSONL (relation) format 772 | text, relations, entities = items["text"], items[ 773 | "relations"], items["entities"] 774 | texts.append(text) 775 | 776 | entity_example = [] 777 | entity_prompt = [] 778 | entity_example_map = {} 779 | entity_map = {} # id to entity name 780 | for entity in entities: 781 | entity_name = text[entity["start_offset"]:entity["end_offset"]] 782 | entity_map[entity["id"]] = { 783 | "name": entity_name, 784 | "start": entity["start_offset"], 785 | "end": entity["end_offset"] 786 | } 787 | 788 | entity_label = entity["label"] 789 | result = { 790 | "text": entity_name, 791 | "start": entity["start_offset"], 792 | "end": 
entity["end_offset"] 793 | } 794 | if entity_label not in entity_example_map.keys(): 795 | entity_example_map[entity_label] = { 796 | "content": text, 797 | "result_list": [result], 798 | "prompt": entity_label 799 | } 800 | else: 801 | entity_example_map[entity_label]["result_list"].append( 802 | result) 803 | 804 | if entity_label not in entity_label_set: 805 | entity_label_set.append(entity_label) 806 | if entity_name not in entity_name_set: 807 | entity_name_set.append(entity_name) 808 | entity_prompt.append(entity_label) 809 | 810 | for v in entity_example_map.values(): 811 | entity_example.append(v) 812 | 813 | entity_examples.append(entity_example) 814 | entity_prompts.append(entity_prompt) 815 | 816 | subject_golden = [] 817 | relation_example = [] 818 | relation_prompt = [] 819 | relation_example_map = {} 820 | for relation in relations: 821 | predicate = relation["type"] 822 | subject_id = relation["from_id"] 823 | object_id = relation["to_id"] 824 | # The relation prompt is constructed as follows: 825 | # subject + "的" + predicate 826 | prompt = entity_map[subject_id]["name"] + "的" + predicate 827 | if entity_map[subject_id]["name"] not in subject_golden: 828 | subject_golden.append(entity_map[subject_id]["name"]) 829 | result = { 830 | "text": entity_map[object_id]["name"], 831 | "start": entity_map[object_id]["start"], 832 | "end": entity_map[object_id]["end"] 833 | } 834 | if prompt not in relation_example_map.keys(): 835 | relation_example_map[prompt] = { 836 | "content": text, 837 | "result_list": [result], 838 | "prompt": prompt 839 | } 840 | else: 841 | relation_example_map[prompt]["result_list"].append(result) 842 | 843 | if predicate not in predicate_set: 844 | predicate_set.append(predicate) 845 | relation_prompt.append(prompt) 846 | 847 | for v in relation_example_map.values(): 848 | relation_example.append(v) 849 | 850 | relation_examples.append(relation_example) 851 | relation_prompts.append(relation_prompt) 852 | subject_goldens.append(subject_golden) 853 | pbar.update(1) 854 | 855 | def concat_examples(positive_examples, negative_examples, negative_ratio): 856 | examples = [] 857 | if math.ceil(len(negative_examples) / 858 | len(positive_examples)) <= negative_ratio: 859 | examples = positive_examples + negative_examples 860 | else: 861 | # Random sampling the negative examples to ensure overall negative ratio unchanged. 
862 |             idxs = random.sample(
863 |                 range(0, len(negative_examples)),
864 |                 negative_ratio * len(positive_examples))
865 |             negative_examples_sampled = []
866 |             for idx in idxs:
867 |                 negative_examples_sampled.append(negative_examples[idx])
868 |             examples = positive_examples + negative_examples_sampled
869 |         return examples
870 |
871 |     logger.info(f"Adding negative samples for first stage prompt...")
872 |     positive_examples, negative_examples = add_negative_example(
873 |         entity_examples, texts, entity_prompts, entity_label_set,
874 |         negative_ratio)
875 |     if len(positive_examples) == 0:
876 |         all_entity_examples = []
877 |     elif is_train:
878 |         all_entity_examples = concat_examples(positive_examples,
879 |                                               negative_examples, negative_ratio)
880 |     else:
881 |         all_entity_examples = positive_examples + negative_examples
882 |
883 |     all_relation_examples = []
884 |     if len(predicate_set) != 0:
885 |         if is_train:
886 |             logger.info(f"Adding negative samples for second stage prompt...")
887 |             relation_prompt_set = construct_relation_prompt_set(entity_name_set,
888 |                                                                 predicate_set)
889 |             positive_examples, negative_examples = add_negative_example(
890 |                 relation_examples, texts, relation_prompts, relation_prompt_set,
891 |                 negative_ratio)
892 |             all_relation_examples = concat_examples(
893 |                 positive_examples, negative_examples, negative_ratio)
894 |         else:
895 |             logger.info(f"Adding negative samples for second stage prompt...")
896 |             relation_examples = add_full_negative_example(
897 |                 relation_examples, texts, relation_prompts, predicate_set,
898 |                 subject_goldens)
899 |             all_relation_examples = [
900 |                 r
901 |                 for relation_example in relation_examples
902 |                 for r in relation_example
903 |             ]
904 |     return all_entity_examples, all_relation_examples
905 |
906 |
907 | def get_path_from_url(url,
908 |                       root_dir,
909 |                       check_exist=True,
910 |                       decompress=True):
911 |     """ Download from given url to root_dir.
912 |     If the file or directory specified by url already exists under
913 |     root_dir, return the path directly; otherwise download
914 |     from url, decompress it, and return the path.
915 |
916 |     Args:
917 |         url (str): download url
918 |         root_dir (str): root dir for downloading, it should be
919 |             WEIGHTS_HOME or DATASET_HOME
920 |         decompress (bool): decompress zip or tar file. Default is `True`
921 |
922 |     Returns:
923 |         str: a local path to save downloaded models & weights & datasets.
924 |     """
925 |
926 |     import os.path
927 |     import os
928 |     import tarfile
929 |     import zipfile
930 |
931 |     def is_url(path):
932 |         """
933 |         Whether path is URL.
934 |         Args:
935 |             path (string): URL string or not.
936 | """ 937 | return path.startswith('http://') or path.startswith('https://') 938 | 939 | def _map_path(url, root_dir): 940 | # parse path after download under root_dir 941 | fname = os.path.split(url)[-1] 942 | fpath = fname 943 | return os.path.join(root_dir, fpath) 944 | 945 | def _get_download(url, fullname): 946 | import requests 947 | # using requests.get method 948 | fname = os.path.basename(fullname) 949 | try: 950 | req = requests.get(url, stream=True) 951 | except Exception as e: # requests.exceptions.ConnectionError 952 | logger.info("Downloading {} from {} failed with exception {}".format( 953 | fname, url, str(e))) 954 | return False 955 | 956 | if req.status_code != 200: 957 | raise RuntimeError("Downloading from {} failed with code " 958 | "{}!".format(url, req.status_code)) 959 | 960 | # For protecting download interupted, download to 961 | # tmp_fullname firstly, move tmp_fullname to fullname 962 | # after download finished 963 | tmp_fullname = fullname + "_tmp" 964 | total_size = req.headers.get('content-length') 965 | with open(tmp_fullname, 'wb') as f: 966 | if total_size: 967 | with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar: 968 | for chunk in req.iter_content(chunk_size=1024): 969 | f.write(chunk) 970 | pbar.update(1) 971 | else: 972 | for chunk in req.iter_content(chunk_size=1024): 973 | if chunk: 974 | f.write(chunk) 975 | shutil.move(tmp_fullname, fullname) 976 | 977 | return fullname 978 | 979 | def _download(url, path): 980 | """ 981 | Download from url, save to path. 982 | 983 | url (str): download url 984 | path (str): download to given path 985 | """ 986 | 987 | if not os.path.exists(path): 988 | os.makedirs(path) 989 | 990 | fname = os.path.split(url)[-1] 991 | fullname = os.path.join(path, fname) 992 | retry_cnt = 0 993 | 994 | logger.info("Downloading {} from {}".format(fname, url)) 995 | DOWNLOAD_RETRY_LIMIT = 3 996 | while not os.path.exists(fullname): 997 | if retry_cnt < DOWNLOAD_RETRY_LIMIT: 998 | retry_cnt += 1 999 | else: 1000 | raise RuntimeError("Download from {} failed. 
" 1001 | "Retry limit reached".format(url)) 1002 | 1003 | if not _get_download(url, fullname): 1004 | time.sleep(1) 1005 | continue 1006 | 1007 | return fullname 1008 | 1009 | def _uncompress_file_zip(filepath): 1010 | with zipfile.ZipFile(filepath, 'r') as files: 1011 | file_list = files.namelist() 1012 | 1013 | file_dir = os.path.dirname(filepath) 1014 | 1015 | if _is_a_single_file(file_list): 1016 | rootpath = file_list[0] 1017 | uncompressed_path = os.path.join(file_dir, rootpath) 1018 | files.extractall(file_dir) 1019 | 1020 | elif _is_a_single_dir(file_list): 1021 | # `strip(os.sep)` to remove `os.sep` in the tail of path 1022 | rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split( 1023 | os.sep)[-1] 1024 | uncompressed_path = os.path.join(file_dir, rootpath) 1025 | 1026 | files.extractall(file_dir) 1027 | else: 1028 | rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] 1029 | uncompressed_path = os.path.join(file_dir, rootpath) 1030 | if not os.path.exists(uncompressed_path): 1031 | os.makedirs(uncompressed_path) 1032 | files.extractall(os.path.join(file_dir, rootpath)) 1033 | 1034 | return uncompressed_path 1035 | 1036 | def _is_a_single_file(file_list): 1037 | if len(file_list) == 1 and file_list[0].find(os.sep) < 0: 1038 | return True 1039 | return False 1040 | 1041 | def _is_a_single_dir(file_list): 1042 | new_file_list = [] 1043 | for file_path in file_list: 1044 | if '/' in file_path: 1045 | file_path = file_path.replace('/', os.sep) 1046 | elif '\\' in file_path: 1047 | file_path = file_path.replace('\\', os.sep) 1048 | new_file_list.append(file_path) 1049 | 1050 | file_name = new_file_list[0].split(os.sep)[0] 1051 | for i in range(1, len(new_file_list)): 1052 | if file_name != new_file_list[i].split(os.sep)[0]: 1053 | return False 1054 | return True 1055 | 1056 | def _uncompress_file_tar(filepath, mode="r:*"): 1057 | with tarfile.open(filepath, mode) as files: 1058 | file_list = files.getnames() 1059 | 1060 | file_dir = os.path.dirname(filepath) 1061 | 1062 | if _is_a_single_file(file_list): 1063 | rootpath = file_list[0] 1064 | uncompressed_path = os.path.join(file_dir, rootpath) 1065 | files.extractall(file_dir) 1066 | elif _is_a_single_dir(file_list): 1067 | rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split( 1068 | os.sep)[-1] 1069 | uncompressed_path = os.path.join(file_dir, rootpath) 1070 | files.extractall(file_dir) 1071 | else: 1072 | rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] 1073 | uncompressed_path = os.path.join(file_dir, rootpath) 1074 | if not os.path.exists(uncompressed_path): 1075 | os.makedirs(uncompressed_path) 1076 | 1077 | files.extractall(os.path.join(file_dir, rootpath)) 1078 | 1079 | return uncompressed_path 1080 | 1081 | def _decompress(fname): 1082 | """ 1083 | Decompress for zip and tar file 1084 | """ 1085 | logger.info("Decompressing {}...".format(fname)) 1086 | 1087 | # For protecting decompressing interupted, 1088 | # decompress to fpath_tmp directory firstly, if decompress 1089 | # successed, move decompress files to fpath and delete 1090 | # fpath_tmp and remove download compress file. 
1091 |
1092 |         if tarfile.is_tarfile(fname):
1093 |             uncompressed_path = _uncompress_file_tar(fname)
1094 |         elif zipfile.is_zipfile(fname):
1095 |             uncompressed_path = _uncompress_file_zip(fname)
1096 |         else:
1097 |             raise TypeError("Unsupported compress file type {}".format(fname))
1098 |
1099 |         return uncompressed_path
1100 |
1101 |     assert is_url(url), "downloading from {} is not a URL".format(url)
1102 |     fullpath = _map_path(url, root_dir)
1103 |     if os.path.exists(fullpath) and check_exist:
1104 |         logger.info("Found {}".format(fullpath))
1105 |     else:
1106 |         fullpath = _download(url, root_dir)
1107 |
1108 |     if decompress and (tarfile.is_tarfile(fullpath) or
1109 |                        zipfile.is_zipfile(fullpath)):
1110 |         fullpath = _decompress(fullpath)
1111 |
1112 |     return fullpath
1113 |
--------------------------------------------------------------------------------
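
Usage sketch - EarlyStopping: a minimal, self-contained illustration of the call pattern for the EarlyStopping helper defined in utils.py above. The loss values are made up and MagicMock merely stands in for a transformers-style model exposing save_pretrained(); in a real run the dev loss would come from evaluate() in evaluate.py and the model would be the UIE class from model.py.

# Sketch only: made-up dev-loss curve, mock model. Shows when checkpoints are
# saved and when training would stop.
from unittest.mock import MagicMock

from utils import EarlyStopping

model = MagicMock()  # stand-in for a model with a save_pretrained(save_dir) method
stopper = EarlyStopping(patience=3, verbose=True, save_dir="checkpoint/early_stopping")

for epoch, dev_loss in enumerate([0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.70]):
    stopper(dev_loss, model)      # saves a checkpoint whenever dev loss improves
    if stopper.early_stop:        # True once patience epochs pass without improvement
        print("stopping at epoch", epoch)
        break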
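
Usage sketch - doccano conversion: a hypothetical driver for convert_ext_examples(), assuming this file is importable as utils (it is imported that way elsewhere in the repo) and that doccano_ext.jsonl is a doccano relation-export file. The file names, the 80/20 split, and negative_ratio=5 are illustrative choices only; the repo's own doccano.py is presumably the official entry point for this step.

# Sketch only: read a doccano JSONL export, build UIE-style entity/relation
# samples with the helpers above, and write them out one JSON object per line.
import json
import random

from utils import convert_ext_examples, logger

with open("doccano_ext.jsonl", "r", encoding="utf-8") as f:   # hypothetical export path
    raw_examples = f.readlines()

random.seed(1000)
random.shuffle(raw_examples)
split = int(len(raw_examples) * 0.8)   # arbitrary train/dev split

# is_train=True samples negatives randomly; is_train=False adds the full
# negative prompt set instead (see add_full_negative_example above).
train_entity, train_relation = convert_ext_examples(raw_examples[:split], negative_ratio=5, is_train=True)
dev_entity, dev_relation = convert_ext_examples(raw_examples[split:], negative_ratio=5, is_train=False)

for path, examples in [("train.txt", train_entity + train_relation),
                       ("dev.txt", dev_entity + dev_relation)]:
    with open(path, "w", encoding="utf-8") as f:
        for example in examples:
            f.write(json.dumps(example, ensure_ascii=False) + "\n")
    logger.info("Wrote {} examples to {}".format(len(examples), path))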