├── data
│   └── dgre
│       └── raw_data
│           └── process.py
├── evaluate.py
├── generic.py
├── README.md
├── doccano.py
├── export_model.py
├── doccano.md
├── model.py
├── finetune.py
├── convert.py
├── uie_predictor.py
└── utils.py
/data/dgre/raw_data/process.py:
--------------------------------------------------------------------------------
1 | import json
2 | import codecs
3 |
4 | data_path = "train.json"
5 |
6 | with codecs.open(data_path, 'r', encoding="utf-8") as fp:
7 | data = fp.readlines()
8 | output_file = open("../mid_data/doccano_train.json", "w", encoding="utf-8")
9 |
10 | for did, d in enumerate(data):
11 | d = eval(d)
12 | tmp = {}
13 | tmp["id"] = d['ID']
14 | tmp['text'] = d['text']
15 | tmp['relations'] = []
16 | tmp['entities'] = []
17 | ent_id = 0
18 | for rel_id,spo in enumerate(d['spo_list']):
19 | rel_tmp = {}
20 | ent_tmp = {}
21 | rel_tmp['id'] = rel_id
22 | ent_tmp['id'] = ent_id
23 | h = spo['h']
24 | ent_tmp['start_offset'] = h['pos'][0]
25 | ent_tmp['end_offset'] = h['pos'][1]
26 | ent_tmp['label'] = "主体"
27 | if ent_tmp not in tmp['entities']:
28 | tmp['entities'].append(ent_tmp)
29 | from_id = ent_id
30 | ent_id += 1
31 | else:
32 | ind = tmp['entities'].index(ent_tmp)
33 | from_id = tmp['entities'][ind]['id']
34 | ent_id = len(tmp['entities']) + 1
35 |
36 | t = spo['t']
37 | ent_tmp = {}
38 | ent_tmp['id'] = ent_id
39 | ent_tmp['start_offset'] = t['pos'][0]
40 | ent_tmp['end_offset'] = t['pos'][1]
41 | ent_tmp['label'] = "客体"
42 | if ent_tmp not in tmp['entities']:
43 | tmp['entities'].append(ent_tmp)
44 | to_id = ent_id
45 | ent_id += 1
46 | else:
47 | ind = tmp['entities'].index(ent_tmp)
48 | to_id = tmp['entities'][ind]['id']
49 | ent_id = len(tmp['entities']) + 1
50 |
51 | rel_tmp['from_id'] = from_id
52 | rel_tmp['to_id'] = to_id
53 | rel_tmp['type'] = spo['relation']
54 |
55 | tmp['relations'].append(rel_tmp)
56 | output_file.write(json.dumps(tmp, ensure_ascii=False) + "\n")
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from model import UIE
16 | import argparse
17 | import torch
18 | from utils import SpanEvaluator, IEDataset, logger, tqdm
19 | from transformers import BertTokenizerFast
20 | from torch.utils.data import DataLoader
21 |
22 |
23 | @torch.no_grad()
24 | def evaluate(model, metric, data_loader, device='gpu', loss_fn=None, show_bar=True):
25 | """
26 | Given a dataset, it evaluates the model and computes the metric.
27 | Args:
28 | model(obj:`torch.nn.Module`): A model to classify texts.
29 | metric(obj:`Metric`): The evaluation metric.
30 | data_loader(obj:`torch.utils.data.DataLoader`): The dataset loader which generates batches.
31 | """
32 | return_loss = False
33 | if loss_fn is not None:
34 | return_loss = True
35 | model.eval()
36 | metric.reset()
37 | loss_list = []
38 | loss_sum = 0
39 | loss_num = 0
40 | if show_bar:
41 | data_loader = tqdm(
42 | data_loader, desc="Evaluating", unit='batch')
43 | for batch in data_loader:
44 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch
45 | if device == 'gpu':
46 | input_ids = input_ids.cuda()
47 | token_type_ids = token_type_ids.cuda()
48 | att_mask = att_mask.cuda()
49 | outputs = model(input_ids=input_ids,
50 | token_type_ids=token_type_ids,
51 | attention_mask=att_mask)
52 | start_prob, end_prob = outputs[0], outputs[1]
53 | if device == 'gpu':
54 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
55 | start_ids = start_ids.type(torch.float32)
56 | end_ids = end_ids.type(torch.float32)
57 |
58 | if return_loss:
59 | # Calculate loss
60 | loss_start = loss_fn(start_prob, start_ids)
61 | loss_end = loss_fn(end_prob, end_ids)
62 | loss = (loss_start + loss_end) / 2.0
63 | loss = float(loss)
64 | loss_list.append(loss)
65 | loss_sum += loss
66 | loss_num += 1
67 | if show_bar:
68 | data_loader.set_postfix(
69 | {
70 | 'dev loss': f'{loss_sum / loss_num:.5f}'
71 | }
72 | )
73 |
74 | # Calculate metric
75 | num_correct, num_infer, num_label = metric.compute(start_prob, end_prob,
76 | start_ids, end_ids)
77 | metric.update(num_correct, num_infer, num_label)
78 | precision, recall, f1 = metric.accumulate()
79 | model.train()
80 | if return_loss:
81 | loss_avg = sum(loss_list) / len(loss_list)
82 | return loss_avg, precision, recall, f1
83 | else:
84 | return precision, recall, f1
85 |
86 |
87 | def do_eval():
88 | tokenizer = BertTokenizerFast.from_pretrained(args.model_path)
89 | model = UIE.from_pretrained(args.model_path)
90 | if args.device == 'gpu':
91 | model = model.cuda()
92 |
93 | test_ds = IEDataset(args.test_path, tokenizer=tokenizer,
94 | max_seq_len=args.max_seq_len)
95 |
96 | test_data_loader = DataLoader(
97 | test_ds, batch_size=args.batch_size, shuffle=False)
98 | metric = SpanEvaluator()
99 | precision, recall, f1 = evaluate(
100 | model, metric, test_data_loader, args.device)
101 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f" %
102 | (precision, recall, f1))
103 |
104 |
105 | if __name__ == "__main__":
106 | # yapf: disable
107 | parser = argparse.ArgumentParser()
108 |
109 | parser.add_argument("-m", "--model_path", type=str, required=True,
110 | help="The path of saved model that you want to load.")
111 | parser.add_argument("-t", "--test_path", type=str, required=True,
112 | help="The path of test set.")
113 | parser.add_argument("-b", "--batch_size", type=int, default=16,
114 | help="Batch size per GPU/CPU for training.")
115 | parser.add_argument("--max_seq_len", type=int, default=512,
116 | help="The maximum total input sequence length after tokenization.")
117 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu",
118 | help="Select which device to run model, defaults to gpu.")
119 |
120 | args = parser.parse_args()
121 |
122 | do_eval()
123 |
--------------------------------------------------------------------------------
/generic.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from collections import OrderedDict, UserDict
3 | from collections.abc import MutableMapping
4 | from contextlib import ExitStack
5 | from dataclasses import fields
6 | from enum import Enum
7 | from typing import Any, ContextManager, List, Tuple
8 |
9 | import numpy as np
10 | from transformers.utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy  # helpers used by is_tensor below
11 |
12 | def is_tensor(x):
13 | """
14 | Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`.
15 | """
16 | if is_torch_fx_proxy(x):
17 | return True
18 | if is_torch_available():
19 | import torch
20 |
21 | if isinstance(x, torch.Tensor):
22 | return True
23 | if is_tf_available():
24 | import tensorflow as tf
25 |
26 | if isinstance(x, tf.Tensor):
27 | return True
28 |
29 | if is_flax_available():
30 | import jax.numpy as jnp
31 | from jax.core import Tracer
32 |
33 | if isinstance(x, (jnp.ndarray, Tracer)):
34 | return True
35 |
36 | return isinstance(x, np.ndarray)
37 |
38 | class ModelOutput(OrderedDict):
39 | """
40 | Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
41 | tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
42 | python dictionary.
43 |
44 | You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
45 | before.
46 |
47 | """
48 |
49 | def __post_init__(self):
50 | class_fields = fields(self)
51 |
52 | # Safety and consistency checks
53 | if not len(class_fields):
54 | raise ValueError(f"{self.__class__.__name__} has no fields.")
55 | if not all(field.default is None for field in class_fields[1:]):
56 | raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
57 |
58 | first_field = getattr(self, class_fields[0].name)
59 | other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
60 |
61 | if other_fields_are_none and not is_tensor(first_field):
62 | if isinstance(first_field, dict):
63 | iterator = first_field.items()
64 | first_field_iterator = True
65 | else:
66 | try:
67 | iterator = iter(first_field)
68 | first_field_iterator = True
69 | except TypeError:
70 | first_field_iterator = False
71 |
72 | # if we provided an iterator as first field and the iterator is a (key, value) iterator
73 | # set the associated fields
74 | if first_field_iterator:
75 | for element in iterator:
76 | if (
77 | not isinstance(element, (list, tuple))
78 | or not len(element) == 2
79 | or not isinstance(element[0], str)
80 | ):
81 | break
82 | setattr(self, element[0], element[1])
83 | if element[1] is not None:
84 | self[element[0]] = element[1]
85 | elif first_field is not None:
86 | self[class_fields[0].name] = first_field
87 | else:
88 | for field in class_fields:
89 | v = getattr(self, field.name)
90 | if v is not None:
91 | self[field.name] = v
92 |
93 | def __delitem__(self, *args, **kwargs):
94 | raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
95 |
96 | def setdefault(self, *args, **kwargs):
97 | raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
98 |
99 | def pop(self, *args, **kwargs):
100 | raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
101 |
102 | def update(self, *args, **kwargs):
103 | raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
104 |
105 | def __getitem__(self, k):
106 | if isinstance(k, str):
107 | inner_dict = {k: v for (k, v) in self.items()}
108 | return inner_dict[k]
109 | else:
110 | return self.to_tuple()[k]
111 |
112 | def __setattr__(self, name, value):
113 | if name in self.keys() and value is not None:
114 | # Don't call self.__setitem__ to avoid recursion errors
115 | super().__setitem__(name, value)
116 | super().__setattr__(name, value)
117 |
118 | def __setitem__(self, key, value):
119 | # Will raise a KeyException if needed
120 | super().__setitem__(key, value)
121 | # Don't call self.__setattr__ to avoid recursion errors
122 | super().__setattr__(key, value)
123 |
124 | def to_tuple(self) -> Tuple[Any]:
125 | """
126 | Convert self to a tuple containing all the attributes/keys that are not `None`.
127 | """
128 | return tuple(self[k] for k in self.keys())
129 |
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch_uie_re
2 | Relation extraction with Baidu UIE, implemented in PyTorch. The original code comes from [here](https://github.com/heiheiyoyo/uie_pytorch).
3 |
4 | Most examples for Baidu UIE (universal information extraction) assume data annotated with doccano. This repository shows how to fine-tune UIE on a generic, pre-labeled dataset instead.
5 |
6 | # Requirements
7 |
8 | ```
9 | torch>=1.7.0
10 | transformers==4.20.0
11 | colorlog
12 | colorama
13 | ```
14 |
15 | # Steps
16 |
17 | - 1. Put the data under data, for example the included dgre dataset ([工业知识图谱关系抽取-高端装备制造知识图谱自动化构建 竞赛 - DataFountain](https://www.datafountain.cn/competitions/584)). The original data lives under raw_data; write a process.py that converts it into the format used under mid_data and run it with ```python process.py```. Each converted line looks like the example below, and a small sketch after the example shows how the offsets map back to text spans:
18 |
19 | ```json
20 | {"id": "AT0004", "text": "617号汽车故障报告故障现象一辆吉利车装用MR479发动机,行驶里程为23709公里,驾驶员反映该车在行驶中无异响,但在起步和换挡过程中车身有抖动现象,并且听到离合器内部有异响。", "relations": [{"id": 0, "from_id": 0, "to_id": 1, "type": "部件故障"}, {"id": 1, "from_id": 2, "to_id": 3, "type": "部件故障"}], "entities": [{"id": 0, "start_offset": 80, "end_offset": 83, "label": "主体"}, {"id": 1, "start_offset": 86, "end_offset": 88, "label": "客体"}, {"id": 2, "start_offset": 68, "end_offset": 70, "label": "主体"}, {"id": 3, "start_offset": 71, "end_offset": 73, "label": "客体"}]}
21 | ```
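As a quick check of the offsets (a minimal sketch; the path assumes process.py has already been run and you are in the repository root), each entity's surface form can be recovered by slicing the text, since end_offset points one position past the span:

```python
# Minimal sketch: recover entity surface forms from the doccano-style offsets
import json

with open("./data/dgre/mid_data/doccano_train.json", encoding="utf-8") as f:
    example = json.loads(f.readline())

for ent in example["entities"]:
    # end_offset is exclusive, so a plain slice yields the entity text
    print(ent["label"], example["text"][ent["start_offset"]:ent["end_offset"]])
```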
22 |
23 | - 2. Convert the data under mid_data into the data under final_data with doccano.py. In the command below, --task_type "ext" selects the extraction task, --splits sets the train/dev/test ratios (set the first value to 1.0 if you do not want to split the data at all), and --negative_ratio sets the ratio of generated negative samples:
24 |
25 | ```shell
26 | python doccano.py \
27 | --doccano_file ./data/dgre/mid_data/doccano_train.json \
28 | --task_type "ext" \
29 | --splits 0.9 0.1 0.0 \
30 | --save_dir ./data/dgre/final_data/ \
31 | --negative_ratio 1
32 | ```
33 |
34 | This produces train.txt and dev.txt under final_data; a quick way to inspect them is shown below.
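To sanity-check the generated files, one converted example can be inspected (a minimal sketch; paths taken from the command above):

```python
# Minimal sketch: look at the fields of the first converted training example
import json

with open("./data/dgre/final_data/train.txt", encoding="utf-8") as f:
    first = json.loads(f.readline())
print(sorted(first.keys()))
```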
35 |
36 | - 3. Convert the Paddle checkpoint to a PyTorch checkpoint:
37 |
38 | ```shell
39 | python convert.py --input_model=uie-base --output_model=uie_base_pytorch --no_validate_output
40 | ```
41 |
42 | The models available for input_model are listed in MODEL_MAP inside convert.py. output_model is the path where the converted model is saved and is used below. Afterwards we can test the conversion:
43 |
44 | ```python
45 | from uie_predictor import UIEPredictor
46 | from pprint import pprint
47 |
48 | schema = ['时间', '选手', '赛事名称'] # Define the schema for entity extraction
49 | ie = UIEPredictor('./uie_base_pytorch', schema=schema)
50 | pprint(ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!")) # Better print results using pprint
51 | ```
52 |
53 | - 4. Start fine-tuning:
54 |
55 | ```shell
56 | python finetune.py \
57 | --train_path "./data/dgre/final_data/train.txt" \
58 | --dev_path "./data/dgre/final_data/dev.txt" \
59 | --save_dir "./checkpoint/dgre" \
60 | --learning_rate 1e-5 \
61 | --batch_size 8 \
62 | --max_seq_len 512 \
63 | --num_epochs 3 \
64 | --model "uie_base_pytorch" \
65 | --seed 1000 \
66 | --logging_steps 464 \
67 | --valid_steps 464 \
68 | --device "gpu" \
69 | --max_model_num 1
70 | ```
71 |
72 | After training finishes, checkpoint/dgre/model_best/ is created under the current directory.
73 |
74 | - 5. Run evaluation:
75 |
76 | ```shell
77 | python evaluate.py \
78 | --model_path "./checkpoint/dgre/model_best" \
79 | --test_path "./data/dgre/final_data/dev.txt" \
80 | --batch_size 16 \
81 | --max_seq_len 512
82 | ```
83 |
84 | - 6. Use the trained model for prediction:
85 |
86 | ```python
87 | from uie_predictor import UIEPredictor
88 | from pprint import pprint
89 |
90 | schema = [{"主体":["部件故障", "性能故障", "检测工具", "组成"]}, "客体"] # Define the schema: the subject entity with its relation types, plus the object entity
91 | ie = UIEPredictor('./checkpoint/dgre/model_best', schema=schema)
92 | text = "分析诊断首先用故障诊断仪读取故障码为12,其含义是电控系统正常。考虑到电喷发动机控制是由怠速马达来实现的,所以先拆下怠速马达,发现其阀头上粘有大量胶质油污。用化油器清洗剂清洗后,装车试验,故障依旧。接着清洗喷油咀,故障仍未排除。最后把节气门体拆下来清洗。在操作过程中发现:一根插在节气门体下部真空管上的胶管已断裂,造成节气门后腔与大气相通,影响怠速运转稳定。这条胶管应该是连接在节气门进气管和气门室盖排气孔之间特制的丁字胶管的一部分,但该车没有使用特制的丁字胶管,它用一条直通胶管将节气门进气管和气门室盖排气孔连起来。维修方案把节气门体清洗干净后装车,再用一条专用特制的丁字形的三通胶管把节气门进气管、气门室盖排气孔和节气门体下部真空管接好,然后启动发动机,加速收油,发动机转速平稳下降"
93 | res = ie(text)
94 | pprint(res) # Better print results using pprint
95 |
96 | """
97 | [{'主体': [{'end': 164,
98 | 'probability': 0.44656947,
99 | 'relations': {'性能故障': [{'end': 169,
100 | 'probability': 0.9079101,
101 | 'start': 167,
102 | 'text': '相通'}],
103 | '部件故障': [{'end': 169,
104 | 'probability': 0.91306436,
105 | 'start': 167,
106 | 'text': '相通'}]},
107 | 'start': 159,
108 | 'text': '节气门后腔'},
109 | {'end': 153,
110 | 'probability': 0.92779523,
111 | 'relations': {'性能故障': [{'end': 156,
112 | 'probability': 0.9924409,
113 | 'start': 154,
114 | 'text': '断裂'}],
115 | '部件故障': [{'end': 156,
116 | 'probability': 0.9910217,
117 | 'start': 154,
118 | 'text': '断裂'}]},
119 | 'start': 151,
120 | 'text': '胶管'},
121 | {'end': 68,
122 | 'probability': 0.32421115,
123 | 'relations': {'性能故障': [{'end': 77,
124 | 'probability': 0.6647082,
125 | 'start': 69,
126 | 'text': '粘有大量胶质油污'}],
127 | '部件故障': [{'end': 77,
128 | 'probability': 0.8483382,
129 | 'start': 69,
130 | 'text': '粘有大量胶质油污'}]},
131 | 'start': 66,
132 | 'text': '阀头'}],
133 | '客体': [{'end': 77, 'probability': 0.5508156, 'start': 69, 'text': '粘有大量胶质油污'},
134 | {'end': 156, 'probability': 0.9614242, 'start': 154, 'text': '断裂'},
135 | {'end': 169,
136 | 'probability': 0.56304514,
137 | 'start': 164,
138 | 'text': '与大气相通'}]}]
139 | """
140 | ```
141 |
142 | You will notice that the extracted relations contain duplicates; you can try training for a few more epochs.
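If the duplicates are a problem, a small post-processing sketch (written against the result structure shown above, not part of the original code) can keep only the most probable relation type for each object span:

```python
# Minimal sketch: for each subject, keep only the most probable relation type per object span
def dedupe_relations(result):
    for item in result:
        for subject in item.get("主体", []):
            best = {}  # (start, end) -> (probability, relation type, span)
            for rel_type, spans in subject.get("relations", {}).items():
                for span in spans:
                    key = (span["start"], span["end"])
                    if key not in best or span["probability"] > best[key][0]:
                        best[key] = (span["probability"], rel_type, span)
            merged = {}
            for prob, rel_type, span in best.values():
                merged.setdefault(rel_type, []).append(span)
            subject["relations"] = merged
    return result
```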
143 |
144 | # Notes
145 |
146 | - Label names are best written in Chinese.
147 | - Models of different sizes can be used for training and inference to balance accuracy against speed.
148 |
--------------------------------------------------------------------------------
/doccano.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import time
17 | import argparse
18 | import json
19 | from decimal import Decimal
20 | import numpy as np
21 |
22 | from utils import set_seed, convert_ext_examples, convert_cls_examples, logger
23 |
24 |
25 | def do_convert():
26 | set_seed(args.seed)
27 |
28 | tic_time = time.time()
29 | if not os.path.exists(args.doccano_file):
30 | raise ValueError("Please input the correct path of doccano file.")
31 |
32 | if not os.path.exists(args.save_dir):
33 | os.makedirs(args.save_dir)
34 |
35 | if len(args.splits) != 0 and len(args.splits) != 3:
36 | raise ValueError("Only []/ len(splits)==3 accepted for splits.")
37 |
38 | def _check_sum(splits):
39 | return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal(
40 | str(splits[2])) == Decimal("1")
41 |
42 | if len(args.splits) == 3 and not _check_sum(args.splits):
43 | raise ValueError(
44 | "Please set correct splits, sum of elements in splits should be equal to 1."
45 | )
46 |
47 | with open(args.doccano_file, "r", encoding="utf-8") as f:
48 | raw_examples = f.readlines()
49 |
50 | def _create_ext_examples(examples,
51 | negative_ratio=0,
52 | shuffle=False,
53 | is_train=True):
54 | entities, relations = convert_ext_examples(
55 | examples, negative_ratio, is_train=is_train)
56 | examples = entities + relations
57 | if shuffle:
58 | indexes = np.random.permutation(len(examples))
59 | examples = [examples[i] for i in indexes]
60 | return examples
61 |
62 | def _create_cls_examples(examples, prompt_prefix, options, shuffle=False):
63 | examples = convert_cls_examples(examples, prompt_prefix, options)
64 | if shuffle:
65 | indexes = np.random.permutation(len(examples))
66 | examples = [examples[i] for i in indexes]
67 | return examples
68 |
69 | def _save_examples(save_dir, file_name, examples):
70 | count = 0
71 | save_path = os.path.join(save_dir, file_name)
72 | if not examples:
73 | logger.info("Skip saving %d examples to %s." % (0, save_path))
74 | return
75 | with open(save_path, "w", encoding="utf-8") as f:
76 | for example in examples:
77 | f.write(json.dumps(example, ensure_ascii=False) + "\n")
78 | count += 1
79 | logger.info("Save %d examples to %s." % (count, save_path))
80 |
81 | if len(args.splits) == 0:
82 | if args.task_type == "ext":
83 | examples = _create_ext_examples(raw_examples, args.negative_ratio,
84 | args.is_shuffle)
85 | else:
86 | examples = _create_cls_examples(raw_examples, args.prompt_prefix,
87 | args.options, args.is_shuffle)
88 | _save_examples(args.save_dir, "train.txt", examples)
89 | else:
90 | if args.is_shuffle:
91 | indexes = np.random.permutation(len(raw_examples))
92 | raw_examples = [raw_examples[i] for i in indexes]
93 |
94 | i1, i2, _ = args.splits
95 | p1 = int(len(raw_examples) * i1)
96 | p2 = int(len(raw_examples) * (i1 + i2))
97 |
98 | if args.task_type == "ext":
99 | train_examples = _create_ext_examples(
100 | raw_examples[:p1], args.negative_ratio, args.is_shuffle)
101 | dev_examples = _create_ext_examples(
102 | raw_examples[p1:p2], -1, is_train=False)
103 | test_examples = _create_ext_examples(
104 | raw_examples[p2:], -1, is_train=False)
105 | else:
106 | train_examples = _create_cls_examples(
107 | raw_examples[:p1], args.prompt_prefix, args.options)
108 | dev_examples = _create_cls_examples(
109 | raw_examples[p1:p2], args.prompt_prefix, args.options)
110 | test_examples = _create_cls_examples(
111 | raw_examples[p2:], args.prompt_prefix, args.options)
112 |
113 | _save_examples(args.save_dir, "train.txt", train_examples)
114 | _save_examples(args.save_dir, "dev.txt", dev_examples)
115 | _save_examples(args.save_dir, "test.txt", test_examples)
116 |
117 | logger.info('Finished! It takes %.2f seconds' % (time.time() - tic_time))
118 |
119 |
120 | if __name__ == "__main__":
121 | # yapf: disable
122 | parser = argparse.ArgumentParser()
123 |
124 | parser.add_argument("-d", "--doccano_file", default="./data/doccano.json",
125 | type=str, help="The doccano file exported from doccano platform.")
126 | parser.add_argument("-s", "--save_dir", default="./data",
127 | type=str, help="The path of data that you wanna save.")
128 | parser.add_argument("--negative_ratio", default=5, type=int,
129 | help="Used only for the extraction task, the ratio of positive and negative samples, number of negtive samples = negative_ratio * number of positive samples")
130 | parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*",
131 | help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60%% samples used for training, 20%% for evaluation and 20%% for test.")
132 | parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str,
133 | help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.")
134 | parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+",
135 | help="Used only for the classification task, the options for classification")
136 | parser.add_argument("--prompt_prefix", default="情感倾向", type=str,
137 | help="Used only for the classification task, the prompt prefix for classification")
138 | parser.add_argument("--is_shuffle", default=True, type=bool,
139 | help="Whether to shuffle the labeled dataset, defaults to True.")
140 | parser.add_argument("--seed", type=int, default=1000,
141 | help="random seed for initialization")
142 |
143 | args = parser.parse_args()
144 | # yapf: enable
145 |
146 | do_convert()
147 |
--------------------------------------------------------------------------------
/export_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import os
17 | from itertools import chain
18 | from typing import List, Union
19 | import shutil
20 | from pathlib import Path
21 |
22 | import numpy as np
23 | import torch
24 | from transformers import (BertTokenizer, PreTrainedModel,
25 | PreTrainedTokenizerBase)
26 |
27 | from model import UIE
28 | from utils import logger
29 |
30 |
31 | def validate_onnx(tokenizer: PreTrainedTokenizerBase, pt_model: PreTrainedModel, onnx_path: Union[Path, str], strict: bool = True, atol: float = 1e-05):
32 |
33 | # Validate the exported ONNX model against the PyTorch model
34 | from onnxruntime import InferenceSession, SessionOptions
35 | from transformers import AutoTokenizer
36 |
37 | logger.info("Validating ONNX model...")
38 | if strict:
39 | ref_inputs = tokenizer('装备', "印媒所称的“印度第一艘国产航母”—“维克兰特”号",
40 | add_special_tokens=True,
41 | truncation=True,
42 | max_length=512,
43 | return_tensors="pt")
44 | else:
45 | batch_size = 2
46 | seq_length = 6
47 | dummy_input = [" ".join([tokenizer.unk_token])
48 | * seq_length] * batch_size
49 | ref_inputs = dict(tokenizer(dummy_input, return_tensors="pt"))
50 | # ref_inputs =
51 | ref_outputs = pt_model(**ref_inputs)
52 | ref_outputs_dict = {}
53 |
54 | # We flatten potential collection of outputs (i.e. past_keys) to a flat structure
55 | for name, value in ref_outputs.items():
56 | # Overwriting the output name as "present" since it is the name used for the ONNX outputs
57 | # ("past_key_values" being taken for the ONNX inputs)
58 | if name == "past_key_values":
59 | name = "present"
60 | ref_outputs_dict[name] = value
61 |
62 | # Create ONNX Runtime session
63 | options = SessionOptions()
64 | session = InferenceSession(str(onnx_path), options, providers=[
65 | "CPUExecutionProvider"])
66 |
67 | # We flatten potential collection of inputs (i.e. past_keys)
68 | onnx_inputs = {}
69 | for name, value in ref_inputs.items():
70 | onnx_inputs[name] = value.numpy()
71 | onnx_named_outputs = ['start_prob', 'end_prob']
72 | # Compute outputs from the ONNX model
73 | onnx_outputs = session.run(onnx_named_outputs, onnx_inputs)
74 |
75 | # Check we have a subset of the keys into onnx_outputs against ref_outputs
76 | ref_outputs_set, onnx_outputs_set = set(
77 | ref_outputs_dict.keys()), set(onnx_named_outputs)
78 | if not onnx_outputs_set.issubset(ref_outputs_set):
79 | logger.info(
80 | f"\t-[x] ONNX model output names {onnx_outputs_set} do not match reference model {ref_outputs_set}"
81 | )
82 |
83 | raise ValueError(
84 | "Outputs doesn't match between reference model and ONNX exported model: "
85 | f"{onnx_outputs_set.difference(ref_outputs_set)}"
86 | )
87 | else:
88 | logger.info(
89 | f"\t-[✓] ONNX model output names match reference model ({onnx_outputs_set})")
90 |
91 | # Check the shape and values match
92 | for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
93 | ref_value = ref_outputs_dict[name].detach().numpy()
94 |
95 | logger.info(f'\t- Validating ONNX Model output "{name}":')
96 |
97 | # Shape
98 | if not ort_value.shape == ref_value.shape:
99 | logger.info(
100 | f"\t\t-[x] shape {ort_value.shape} doesn't match {ref_value.shape}")
101 | raise ValueError(
102 | "Outputs shape doesn't match between reference model and ONNX exported model: "
103 | f"Got {ref_value.shape} (reference) and {ort_value.shape} (ONNX)"
104 | )
105 | else:
106 | logger.info(
107 | f"\t\t-[✓] {ort_value.shape} matches {ref_value.shape}")
108 |
109 | # Values
110 | if not np.allclose(ref_value, ort_value, atol=atol):
111 | logger.info(f"\t\t-[x] values not close enough (atol: {atol})")
112 | raise ValueError(
113 | "Outputs values doesn't match between reference model and ONNX exported model: "
114 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - ort_value))}"
115 | )
116 | else:
117 | logger.info(f"\t\t-[✓] all values close (atol: {atol})")
118 |
119 |
120 | def export_onnx(args: argparse.Namespace, tokenizer: PreTrainedTokenizerBase, model: PreTrainedModel, device: torch.device, input_names: List[str], output_names: List[str]):
121 | with torch.no_grad():
122 | model = model.to(device)
123 | model.eval()
124 | model.config.return_dict = True
125 | model.config.use_cache = False
126 |
127 | # Create folder
128 | if not args.output_path.exists():
129 | args.output_path.mkdir(parents=True)
130 | save_path = args.output_path / "inference.onnx"
131 |
132 | dynamic_axes = {name: {0: 'batch', 1: 'sequence'}
133 | for name in chain(input_names, output_names)}
134 |
135 | # Generate dummy input
136 | batch_size = 2
137 | seq_length = 6
138 | dummy_input = [" ".join([tokenizer.unk_token])
139 | * seq_length] * batch_size
140 | inputs = dict(tokenizer(dummy_input, return_tensors="pt"))
141 |
142 | if save_path.exists():
143 | logger.warning(f'Overwrite model {save_path.as_posix()}')
144 | save_path.unlink()
145 |
146 | torch.onnx.export(model,
147 | (inputs,),
148 | save_path,
149 | input_names=input_names,
150 | output_names=output_names,
151 | dynamic_axes=dynamic_axes,
152 | do_constant_folding=True,
153 | opset_version=11
154 | )
155 |
156 | if not os.path.exists(save_path):
157 | logger.error(f'Export Failed!')
158 |
159 | return save_path
160 |
161 |
162 | def main():
163 | parser = argparse.ArgumentParser()
164 | parser.add_argument("-m", "--model_path", type=Path, required=True,
165 | default='./checkpoint/model_best', help="The path to model parameters to be loaded.")
166 | parser.add_argument("-o", "--output_path", type=Path, default=None,
167 | help="The path of model parameter in static graph to be saved.")
168 | args = parser.parse_args()
169 |
170 | if args.output_path is None:
171 | args.output_path = args.model_path
172 |
173 | tokenizer = BertTokenizer.from_pretrained(args.model_path)
174 | model = UIE.from_pretrained(args.model_path)
175 | device = torch.device('cpu')
176 | input_names = [
177 | 'input_ids',
178 | 'token_type_ids',
179 | 'attention_mask',
180 | ]
181 | output_names = [
182 | 'start_prob',
183 | 'end_prob'
184 | ]
185 |
186 | logger.info("Export Tokenizer Config...")
187 |
188 | export_tokenizer(args)
189 |
190 | logger.info("Export ONNX Model...")
191 |
192 | save_path = export_onnx(
193 | args, tokenizer, model, device, input_names, output_names)
194 | validate_onnx(tokenizer, model, save_path)
195 |
196 | logger.info(f"All good, model saved at: {save_path.as_posix()}")
197 |
198 |
199 | def export_tokenizer(args):
200 | for tokenizer_file in ['tokenizer_config.json', 'special_tokens_map.json', 'vocab.txt']:
201 | file_from = args.model_path / tokenizer_file
202 | file_to = args.output_path / tokenizer_file
203 | if file_from.resolve() == file_to.resolve():
204 | continue
205 | shutil.copyfile(file_from, file_to)
206 |
207 |
208 | if __name__ == "__main__":
209 |
210 | main()
211 |
--------------------------------------------------------------------------------
/doccano.md:
--------------------------------------------------------------------------------
1 | # doccano
2 |
3 | **Table of Contents**
4 |
5 | * [1. Installation](#1-installation)
6 | * [2. Project Creation](#2-project-creation)
7 | * [3. Data Upload](#3-data-upload)
8 | * [4. Label Construction](#4-label-construction)
9 | * [5. Task Annotation](#5-task-annotation)
10 | * [6. Data Export](#6-data-export)
11 | * [7. Data Conversion](#7-data-conversion)
12 |
13 |
14 |
15 | ## 1. Installation
16 |
17 | Refer to the [official doccano documentation](https://github.com/doccano/doccano) to complete the installation and initial configuration of doccano.
18 |
19 | **Environment used for the annotation examples below:**
20 |
21 | - doccano 1.6.2
22 |
23 |
24 |
25 | ## 2. Project Creation
26 |
27 | UIE supports two types of tasks, extraction and classification. Create a new project according to your needs:
28 |
29 | #### 2.1 Creating an extraction project
30 |
31 | When creating the project, choose the **Sequence Labeling** task and tick **Allow overlapping entity** and **Use relation Labeling**. This setup covers tasks such as **named entity recognition, relation extraction, event extraction, and opinion extraction**.
32 |
33 |
34 |

35 |
36 |
37 | #### 2.2 Creating a classification project
38 |
39 | When creating the project, choose the **Text Classification** task. This setup covers tasks such as **text classification and sentence-level sentiment classification**.
40 |
41 |
42 |

43 |
44 |
45 |
46 |
47 | ## 3. Data Upload
48 |
49 | The uploaded file must be in txt format with one text to annotate per line, for example:
50 |
51 | ```text
52 | 2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌
53 | 第十四届全运会在西安举办
54 | ```
55 |
56 | Select **TextLine** as the upload data type:
57 |
58 |
59 |

60 |
61 |
62 | **NOTE**: doccano supports four upload formats: `TextFile`, `TextLine`, `JSONL` and `CoNLL`. For customized UIE training, **always use TextLine**, i.e. the uploaded file must be a txt file, and during annotation each line of the file is displayed as one page of content.
63 |
64 |
65 |
66 | ## 4. Label Construction
67 |
68 | #### 4.1 Building labels for extraction tasks
69 |
70 | Extraction tasks use two label types, **Span** and **Relation**. A Span is a **target snippet of the original text**, e.g. an entity of a given type in entity recognition, or a trigger word or argument in event extraction. A Relation is a **relationship between Spans in the original text**, e.g. the relation between two entities (Subject & Object) in relation extraction, or the relation between an argument and a trigger word in event extraction.
71 |
72 | Example of building Span labels:
73 |
74 |
75 |

76 |
77 |
78 | Example of building Relation labels:
79 |
80 |
81 |

82 |
83 |
84 | #### 4.2 Building labels for classification tasks
85 |
86 | Add the classification category labels:
87 |
88 |
89 |

90 |
91 |
92 |
93 |
94 | ## 5. Task Annotation
95 |
96 | #### 5.1 Named entity recognition
97 |
98 | Named entity recognition (NER) identifies entities with specific meanings in text. In open-domain information extraction, **the categories to extract are not restricted and can be defined by the user**.
99 |
100 | Annotation example:
101 |
102 |
103 |

104 |
105 |
106 | The example defines four Span labels: `时间`, `选手`, `赛事名称` and `得分`.
107 |
108 | #### 5.2 Relation extraction
109 |
110 | Relation extraction (RE) identifies entities in text and extracts the semantic relations between them, i.e. triples of the form (entity 1, relation type, entity 2).
111 |
112 | Annotation example:
113 |
114 |
115 |

116 |
117 |
118 | The example defines three Span labels (`作品名`, `人物名`, `时间`) and three Relation labels (`歌手`, `发行时间`, `所属专辑`). A Relation label **points from the Subject entity to the Object entity**.
119 |
120 | #### 5.3 Event extraction
121 |
122 | Event extraction (EE) extracts events from natural language text and identifies their event types and event arguments. The event extraction task covered by UIE extracts the arguments of an event given a known event type.
123 |
124 | Annotation example:
125 |
126 |
127 |

128 |
129 |
130 | The example defines three Span labels, `地震触发词` (trigger word), `等级` (event argument) and `时间` (event argument), and two Relation labels, `时间` and `震级`. Trigger-word labels **always follow the format `XX触发词`**, where `XX` is the concrete event type; in the example the event type is `地震`, so the trigger label is `地震触发词`. A Relation label **points from the trigger word to the corresponding event argument**.
131 |
132 | #### 5.4 Opinion extraction
133 |
134 | Opinion extraction extracts the aspects and opinion words contained in a text.
135 |
136 | Annotation example:
137 |
138 |
139 |

140 |
141 |
142 | The example defines two Span labels, `评价维度` and `观点词`, and one Relation label, `观点词`. A Relation label **points from the aspect to the opinion word**.
143 |
144 | #### 5.5 Classification tasks
145 |
146 | Annotation example:
147 |
148 |
149 |

150 |
151 |
152 | The example defines two category labels, `正向` and `负向`, to classify the sentiment polarity of the text.
153 |
154 |
155 |
156 | ## 6. Data Export
157 |
158 | #### 6.1 Exporting extraction task data
159 |
160 | Choose ``JSONL(relation)`` as the export file type. An exported record looks like:
161 |
162 | ```text
163 | {
164 | "id": 38,
165 | "text": "百科名片你知道我要什么,是歌手高明骏演唱的一首歌曲,1989年发行,收录于个人专辑《丛林男孩》中",
166 | "relations": [
167 | {
168 | "id": 20,
169 | "from_id": 51,
170 | "to_id": 53,
171 | "type": "歌手"
172 | },
173 | {
174 | "id": 21,
175 | "from_id": 51,
176 | "to_id": 55,
177 | "type": "发行时间"
178 | },
179 | {
180 | "id": 22,
181 | "from_id": 51,
182 | "to_id": 54,
183 | "type": "所属专辑"
184 | }
185 | ],
186 | "entities": [
187 | {
188 | "id": 51,
189 | "start_offset": 4,
190 | "end_offset": 11,
191 | "label": "作品名"
192 | },
193 | {
194 | "id": 53,
195 | "start_offset": 15,
196 | "end_offset": 18,
197 | "label": "人物名"
198 | },
199 | {
200 | "id": 54,
201 | "start_offset": 42,
202 | "end_offset": 46,
203 | "label": "作品名"
204 | },
205 | {
206 | "id": 55,
207 | "start_offset": 26,
208 | "end_offset": 31,
209 | "label": "时间"
210 | }
211 | ]
212 | }
213 | ```
214 |
215 | The annotated data is stored in a single text file, one example per line in ``json`` format, with the following fields (a small parsing sketch follows the list):
216 | - ``id``: unique ID of the example in the dataset.
217 | - ``text``: the original text.
218 | - ``entities``: the Span labels contained in the example; each Span label has four fields:
219 | - ``id``: unique ID of the Span in the dataset.
220 | - ``start_offset``: index of the Span's first token in the text.
221 | - ``end_offset``: index one past the Span's last token in the text.
222 | - ``label``: the Span type.
223 | - ``relations``: the Relation labels contained in the example; each Relation label has four fields:
224 | - ``id``: unique ID of the (Span1, Relation, Span2) triple in the dataset; the same triple in different examples shares the same ID.
225 | - ``from_id``: ID of Span1.
226 | - ``to_id``: ID of Span2.
227 | - ``type``: the Relation type.
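As referenced above, a minimal parsing sketch for this format (the file name follows Section 7.1) is:

```python
# Minimal sketch: read one JSONL(relation) record and print its spans and triples
import json

with open("./data/doccano_ext.json", encoding="utf-8") as f:
    record = json.loads(f.readline())

entities = {e["id"]: e for e in record["entities"]}
for e in record["entities"]:
    print(e["label"], record["text"][e["start_offset"]:e["end_offset"]])
for r in record["relations"]:
    head, tail = entities[r["from_id"]], entities[r["to_id"]]
    print(record["text"][head["start_offset"]:head["end_offset"]],
          r["type"],
          record["text"][tail["start_offset"]:tail["end_offset"]])
```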
228 |
229 | #### 6.2 Exporting classification task data
230 |
231 | Choose ``JSONL`` as the export file type. An exported record looks like:
232 |
233 | ```text
234 | {
235 | "id": 41,
236 | "data": "大年初一就把车前保险杠给碰坏了,保险杠和保险公司 真够倒霉的,我决定步行反省。",
237 | "label": [
238 | "负向"
239 | ]
240 | }
241 | ```
242 |
243 | The annotated data is stored in a single text file, one example per line in ``json`` format, with the following fields:
244 | - ``id``: unique ID of the example in the dataset.
245 | - ``data``: the original text.
246 | - ``label``: the category label(s) of the text.
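A corresponding minimal check for this format (file name as used in Section 7.2) might be:

```python
# Minimal sketch: read one classification record exported from doccano
import json

with open("./data/doccano_cls.json", encoding="utf-8") as f:
    record = json.loads(f.readline())
print(record["label"], record["data"])
```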
247 |
248 |
249 |
250 | ## 7. Data Conversion
251 |
252 | This section explains how to convert the annotated data exported from the doccano platform with the `doccano.py` script and generate train/dev/test sets in one step.
253 |
254 | #### 7.1 Converting extraction task data
255 |
256 | - After annotation is finished, export a `JSONL(relation)` file from the doccano platform, rename it to `doccano_ext.json`, and place it under the `./data` directory.
257 | - Convert the data format with the [doccano.py](./doccano.py) script; model training can then start.
258 |
259 | ```shell
260 | python doccano.py \
261 | --doccano_file ./data/doccano_ext.json \
262 | --task_type "ext" \
263 | --save_dir ./data \
264 | --negative_ratio 5
265 | ```
266 |
267 | #### 7.2 Converting classification task data
268 |
269 | - After annotation is finished, export a `JSON` file from the doccano platform, rename it to `doccano_cls.json`, and place it under the `./data` directory.
270 | - During data conversion, the prompt information required for model training is constructed automatically. For sentence-level sentiment classification, for example, the prompt is ``情感倾向[正向,负向]``; it is declared via the `prompt_prefix` and `options` arguments.
271 | - Convert the data format with the [doccano.py](./doccano.py) script; model training can then start.
272 |
273 | ```shell
274 | python doccano.py \
275 | --doccano_file ./data/doccano_cls.json \
276 | --task_type "cls" \
277 | --save_dir ./data \
278 | --splits 0.8 0.1 0.1 \
279 | --prompt_prefix "情感倾向" \
280 | --options "正向" "负向"
281 | ```
282 |
283 | Configurable arguments:
284 |
285 | - ``doccano_file``: the annotation file exported from doccano.
286 | - ``save_dir``: directory where the training data is saved, ``data`` by default.
287 | - ``negative_ratio``: maximum negative-sample ratio; only effective for extraction tasks, where constructing a suitable number of negatives improves model quality. The number of negatives depends on the actual number of labels: maximum number of negatives = negative_ratio * number of positives. This only affects the training set and defaults to 5. To keep the evaluation metrics accurate, the dev and test sets are built with all negatives by default.
288 | - ``splits``: proportions of the training and dev sets when splitting the dataset. The default [0.8, 0.1, 0.1] splits the data into train/dev/test sets at an ``8:1:1`` ratio.
289 | - ``task_type``: the task type; extraction and classification tasks are available.
290 | - ``options``: category labels for the classification task; only effective for classification tasks.
291 | - ``prompt_prefix``: prompt prefix for the classification task; only effective for classification tasks.
292 | - ``is_shuffle``: whether to shuffle the dataset, True by default.
293 | - ``seed``: random seed, 1000 by default.
294 |
295 | Notes:
296 | - By default the [doccano.py](./doccano.py) script splits the data into train/dev/test sets according to the given ratios.
297 | - Each run of the [doccano.py](./doccano.py) script overwrites existing data files with the same names.
298 | - We recommend constructing some negative examples to improve model quality, and this is built into the data conversion step. The ratio of automatically constructed negatives is controlled by `negative_ratio`; number of negative samples = negative_ratio * number of positive samples.
299 | - Files exported from doccano are assumed to contain only correctly human-annotated data.
300 |
301 | ## References
302 | - **[doccano](https://github.com/doccano/doccano)**
303 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from dataclasses import dataclass
18 | from transformers import BertModel, BertPreTrainedModel, PretrainedConfig
19 | # from transformers.utils import ModelOutput
20 | from typing import Optional, Tuple
21 | from generic import ModelOutput
22 |
23 |
24 | @dataclass
25 | class UIEModelOutput(ModelOutput):
26 | """
27 | Output class for outputs of UIE.
28 | Args:
29 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
30 | Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
31 | start_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
32 | Span-start scores (after Sigmoid).
33 | end_prob (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
34 | Span-end scores (after Sigmoid).
35 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
36 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
37 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
38 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
39 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
40 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
41 | sequence_length)`.
42 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
43 | heads.
44 | """
45 | loss: Optional[torch.FloatTensor] = None
46 | start_prob: torch.FloatTensor = None
47 | end_prob: torch.FloatTensor = None
48 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
49 | attentions: Optional[Tuple[torch.FloatTensor]] = None
50 |
51 |
52 | class UIE(BertPreTrainedModel):
53 | """
54 | UIE model based on Bert model.
55 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
56 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
57 | etc.)
58 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
59 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
60 | and behavior.
61 | Parameters:
62 | config ([`PretrainedConfig`]): Model configuration class with all the parameters of the model.
63 | Initializing with a config file does not load the weights associated with the model, only the
64 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
65 | """
66 |
67 | def __init__(self, config: PretrainedConfig):
68 | super(UIE, self).__init__(config)
69 | self.encoder = BertModel(config)
70 | self.config = config
71 | hidden_size = self.config.hidden_size
72 |
73 | self.linear_start = nn.Linear(hidden_size, 1)
74 | self.linear_end = nn.Linear(hidden_size, 1)
75 | self.sigmoid = nn.Sigmoid()
76 |
77 | if hasattr(config, 'use_task_id') and config.use_task_id:
78 | # Add task type embedding to BERT
79 | task_type_embeddings = nn.Embedding(
80 | config.task_type_vocab_size, config.hidden_size)
81 | self.encoder.embeddings.task_type_embeddings = task_type_embeddings
82 |
83 | def hook(module, input, output):
84 | input = input[0]
85 | return output+task_type_embeddings(torch.zeros(input.size(), dtype=torch.int64, device=input.device))
86 | self.encoder.embeddings.word_embeddings.register_forward_hook(hook)
87 |
88 | self.post_init()
89 |
90 | def forward(self, input_ids: Optional[torch.Tensor] = None,
91 | token_type_ids: Optional[torch.Tensor] = None,
92 | position_ids: Optional[torch.Tensor] = None,
93 | attention_mask: Optional[torch.Tensor] = None,
94 | head_mask: Optional[torch.Tensor] = None,
95 | inputs_embeds: Optional[torch.Tensor] = None,
96 | start_positions: Optional[torch.Tensor] = None,
97 | end_positions: Optional[torch.Tensor] = None,
98 | output_attentions: Optional[bool] = None,
99 | output_hidden_states: Optional[bool] = None,
100 | return_dict: Optional[bool] = None
101 | ):
102 | """
103 | Args:
104 | input_ids (`torch.LongTensor` of shape `({0})`):
105 | Indices of input sequence tokens in the vocabulary.
106 | Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and
107 | [`PreTrainedTokenizer.__call__`] for details.
108 | [What are input IDs?](../glossary#input-ids)
109 | attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
110 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
111 | - 1 for tokens that are **not masked**,
112 | - 0 for tokens that are **masked**.
113 | [What are attention masks?](../glossary#attention-mask)
114 | token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
115 | Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
116 | 1]`:
117 | - 0 corresponds to a *sentence A* token,
118 | - 1 corresponds to a *sentence B* token.
119 | [What are token type IDs?](../glossary#token-type-ids)
120 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
121 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
122 | config.max_position_embeddings - 1]`.
123 | [What are position IDs?](../glossary#position-ids)
124 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
125 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
126 | - 1 indicates the head is **not masked**,
127 | - 0 indicates the head is **masked**.
128 | inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
129 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
130 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
131 | model's internal embedding lookup matrix.
132 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
133 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
134 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
135 | are not taken into account for computing the loss.
136 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
137 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
138 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
139 | are not taken into account for computing the loss.
140 | output_attentions (`bool`, *optional*):
141 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
142 | tensors for more detail.
143 | output_hidden_states (`bool`, *optional*):
144 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
145 | more detail.
146 | return_dict (`bool`, *optional*):
147 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
148 | """
149 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
150 | outputs = self.encoder(
151 | input_ids=input_ids,
152 | token_type_ids=token_type_ids,
153 | position_ids=position_ids,
154 | attention_mask=attention_mask,
155 | head_mask=head_mask,
156 | inputs_embeds=inputs_embeds,
157 | output_attentions=output_attentions,
158 | output_hidden_states=output_hidden_states,
159 | return_dict=return_dict
160 | )
161 | sequence_output = outputs[0]
162 |
163 | start_logits = self.linear_start(sequence_output)
164 | start_logits = torch.squeeze(start_logits, -1)
165 | start_prob = self.sigmoid(start_logits)
166 | end_logits = self.linear_end(sequence_output)
167 | end_logits = torch.squeeze(end_logits, -1)
168 | end_prob = self.sigmoid(end_logits)
169 |
170 | total_loss = None
171 | if start_positions is not None and end_positions is not None:
172 | loss_fct = nn.BCELoss()
173 | start_loss = loss_fct(start_prob, start_positions)
174 | end_loss = loss_fct(end_prob, end_positions)
175 | total_loss = (start_loss + end_loss) / 2.0
176 |
177 | if not return_dict:
178 | output = (start_prob, end_prob) + outputs[2:]
179 | return ((total_loss,) + output) if total_loss is not None else output
180 |
181 | return UIEModelOutput(
182 | loss=total_loss,
183 | start_prob=start_prob,
184 | end_prob=end_prob,
185 | hidden_states=outputs.hidden_states,
186 | attentions=outputs.attentions,
187 | )
188 |
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import shutil
17 | import sys
18 | import time
19 | import os
20 | import torch
21 | from torch.utils.data import DataLoader
22 | from transformers import BertTokenizerFast
23 |
24 | from utils import IEDataset, logger, tqdm
25 | from model import UIE
26 | from evaluate import evaluate
27 | from utils import set_seed, SpanEvaluator, EarlyStopping, logging_redirect_tqdm
28 |
29 |
30 | def do_train():
31 |
32 | set_seed(args.seed)
33 | show_bar = True
34 |
35 | tokenizer = BertTokenizerFast.from_pretrained(args.model)
36 | model = UIE.from_pretrained(args.model)
37 | if args.device == 'gpu':
38 | model = model.cuda()
39 | train_ds = IEDataset(args.train_path, tokenizer=tokenizer,
40 | max_seq_len=args.max_seq_len)
41 | dev_ds = IEDataset(args.dev_path, tokenizer=tokenizer,
42 | max_seq_len=args.max_seq_len)
43 |
44 | train_data_loader = DataLoader(
45 | train_ds, batch_size=args.batch_size, shuffle=True)
46 | dev_data_loader = DataLoader(
47 | dev_ds, batch_size=args.batch_size, shuffle=True)
48 |
49 | optimizer = torch.optim.AdamW(
50 | lr=args.learning_rate, params=model.parameters())
51 |
52 | criterion = torch.nn.functional.binary_cross_entropy
53 | metric = SpanEvaluator()
54 |
55 | if args.early_stopping:
56 | early_stopping_save_dir = os.path.join(
57 | args.save_dir, "early_stopping")
58 | if not os.path.exists(early_stopping_save_dir):
59 | os.makedirs(early_stopping_save_dir)
60 | if show_bar:
61 | def trace_func(*args, **kwargs):
62 | with logging_redirect_tqdm([logger.logger]):
63 | logger.info(*args, **kwargs)
64 | else:
65 | trace_func = logger.info
66 | early_stopping = EarlyStopping(
67 | patience=7, verbose=True, trace_func=trace_func,
68 | save_dir=early_stopping_save_dir)
69 |
70 | loss_list = []
71 | loss_sum = 0
72 | loss_num = 0
73 | global_step = 0
74 | best_step = 0
75 | best_f1 = 0
76 | tic_train = time.time()
77 | epoch_iterator = range(1, args.num_epochs + 1)
78 | if show_bar:
79 | train_postfix_info = {'loss': 'unknown'}
80 | epoch_iterator = tqdm(
81 | epoch_iterator, desc='Training', unit='epoch')
82 | for epoch in epoch_iterator:
83 | train_data_iterator = train_data_loader
84 | if show_bar:
85 | train_data_iterator = tqdm(train_data_iterator,
86 | desc=f'Training Epoch {epoch}', unit='batch')
87 | train_data_iterator.set_postfix(train_postfix_info)
88 | for batch in train_data_iterator:
89 | if show_bar:
90 | epoch_iterator.refresh()
91 | input_ids, token_type_ids, att_mask, start_ids, end_ids = batch
92 | if args.device == 'gpu':
93 | input_ids = input_ids.cuda()
94 | token_type_ids = token_type_ids.cuda()
95 | att_mask = att_mask.cuda()
96 | start_ids = start_ids.cuda()
97 | end_ids = end_ids.cuda()
98 | outputs = model(input_ids=input_ids,
99 | token_type_ids=token_type_ids,
100 | attention_mask=att_mask)
101 | start_prob, end_prob = outputs[0], outputs[1]
102 |
103 | start_ids = start_ids.type(torch.float32)
104 | end_ids = end_ids.type(torch.float32)
105 | loss_start = criterion(start_prob, start_ids)
106 | loss_end = criterion(end_prob, end_ids)
107 | loss = (loss_start + loss_end) / 2.0
108 | loss.backward()
109 | optimizer.step()
110 | optimizer.zero_grad()
111 | loss_list.append(float(loss))
112 | loss_sum += float(loss)
113 | loss_num += 1
114 |
115 | if show_bar:
116 | loss_avg = loss_sum / loss_num
117 | train_postfix_info.update({
118 | 'loss': f'{loss_avg:.5f}'
119 | })
120 | train_data_iterator.set_postfix(train_postfix_info)
121 |
122 | global_step += 1
123 | if global_step % args.logging_steps == 0:
124 | time_diff = time.time() - tic_train
125 | loss_avg = loss_sum / loss_num
126 |
127 | if show_bar:
128 | with logging_redirect_tqdm([logger.logger]):
129 | logger.info(
130 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
131 | % (global_step, epoch, loss_avg,
132 | args.logging_steps / time_diff))
133 | else:
134 | logger.info(
135 | "global step %d, epoch: %d, loss: %.5f, speed: %.2f step/s"
136 | % (global_step, epoch, loss_avg,
137 | args.logging_steps / time_diff))
138 | tic_train = time.time()
139 |
140 | if global_step % args.valid_steps == 0:
141 | save_dir = os.path.join(
142 | args.save_dir, "model_%d" % global_step)
143 | if not os.path.exists(save_dir):
144 | os.makedirs(save_dir)
145 | model_to_save = model
146 | model_to_save.save_pretrained(save_dir)
147 | tokenizer.save_pretrained(save_dir)
148 | if args.max_model_num:
149 | model_to_delete = global_step-args.max_model_num*args.valid_steps
150 | model_to_delete_path = os.path.join(
151 | args.save_dir, "model_%d" % model_to_delete)
152 | if model_to_delete > 0 and os.path.exists(model_to_delete_path):
153 | shutil.rmtree(model_to_delete_path)
154 |
155 | dev_loss_avg, precision, recall, f1 = evaluate(
156 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion)
157 |
158 | if show_bar:
159 | train_postfix_info.update({
160 | 'F1': f'{f1:.3f}',
161 | 'dev loss': f'{dev_loss_avg:.5f}'
162 | })
163 | train_data_iterator.set_postfix(train_postfix_info)
164 | with logging_redirect_tqdm([logger.logger]):
165 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
166 | % (precision, recall, f1, dev_loss_avg))
167 | else:
168 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
169 | % (precision, recall, f1, dev_loss_avg))
170 | # Save model which has best F1
171 | if f1 > best_f1:
172 | if show_bar:
173 | with logging_redirect_tqdm([logger.logger]):
174 | logger.info(
175 | f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}"
176 | )
177 | else:
178 | logger.info(
179 | f"best F1 performence has been updated: {best_f1:.5f} --> {f1:.5f}"
180 | )
181 | best_f1 = f1
182 | save_dir = os.path.join(args.save_dir, "model_best")
183 | model_to_save = model
184 | model_to_save.save_pretrained(save_dir)
185 | tokenizer.save_pretrained(save_dir)
186 | tic_train = time.time()
187 |
188 | if args.early_stopping:
189 | dev_loss_avg, precision, recall, f1 = evaluate(
190 | model, metric, data_loader=dev_data_loader, device=args.device, loss_fn=criterion)
191 |
192 | if show_bar:
193 | train_postfix_info.update({
194 | 'F1': f'{f1:.3f}',
195 | 'dev loss': f'{dev_loss_avg:.5f}'
196 | })
197 | train_data_iterator.set_postfix(train_postfix_info)
198 | with logging_redirect_tqdm([logger.logger]):
199 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
200 | % (precision, recall, f1, dev_loss_avg))
201 | else:
202 | logger.info("Evaluation precision: %.5f, recall: %.5f, F1: %.5f, dev loss: %.5f"
203 | % (precision, recall, f1, dev_loss_avg))
204 |
205 | # Early Stopping
206 | early_stopping(dev_loss_avg, model)
207 | if early_stopping.early_stop:
208 | if show_bar:
209 | with logging_redirect_tqdm([logger.logger]):
210 | logger.info("Early stopping")
211 | else:
212 | logger.info("Early stopping")
213 | tokenizer.save_pretrained(early_stopping_save_dir)
214 | sys.exit(0)
215 |
216 |
217 | if __name__ == "__main__":
218 | # yapf: disable
219 | parser = argparse.ArgumentParser()
220 |
221 | parser.add_argument("-b", "--batch_size", default=16, type=int,
222 | help="Batch size per GPU/CPU for training.")
223 | parser.add_argument("--learning_rate", default=1e-5,
224 | type=float, help="The initial learning rate for Adam.")
225 | parser.add_argument("-t", "--train_path", default=None, required=True,
226 | type=str, help="The path of train set.")
227 | parser.add_argument("-d", "--dev_path", default=None, required=True,
228 | type=str, help="The path of dev set.")
229 | parser.add_argument("-s", "--save_dir", default='./checkpoint', type=str,
230 | help="The output directory where the model checkpoints will be written.")
231 | parser.add_argument("--max_seq_len", default=512, type=int, help="The maximum input sequence length. "
232 | "Sequences longer than this will be split automatically.")
233 | parser.add_argument("--num_epochs", default=100, type=int,
234 | help="Total number of training epochs to perform.")
235 | parser.add_argument("--seed", default=1000, type=int,
236 | help="Random seed for initialization")
237 | parser.add_argument("--logging_steps", default=10,
238 | type=int, help="The interval steps to logging.")
239 | parser.add_argument("--valid_steps", default=100, type=int,
240 | help="The interval steps to evaluate model performance.")
241 | parser.add_argument("-D", '--device', choices=['cpu', 'gpu'], default="gpu",
242 | help="Select which device to train model, defaults to gpu.")
243 | parser.add_argument("-m", "--model", default="uie_base_pytorch", type=str,
244 | help="Select the pretrained model for few-shot learning.")
245 | parser.add_argument("--max_model_num", default=5, type=int,
246 | help="Max number of saved model. Best model and earlystopping model is not included.")
247 | parser.add_argument("--early_stopping", action='store_true', default=False,
248 | help="Use early stopping while training")
249 |
250 | args = parser.parse_args()
251 |
252 | do_train()
253 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | import collections
17 | import json
18 | import os
19 | import pickle
20 | import shutil
21 | import numpy as np
22 |
23 | import torch
24 | try:
25 | import paddle
26 | from paddle.utils.download import get_path_from_url
27 | paddle_installed = True
28 | except (ImportError, ModuleNotFoundError):
29 | from utils import get_path_from_url
30 | paddle_installed = False
31 |
32 | from model import UIE
33 | from utils import logger
34 |
35 | MODEL_MAP = {
36 | "uie-base": {
37 | "resource_file_urls": {
38 | "model_state.pdparams":
39 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams",
40 | "model_config.json":
41 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
42 | "vocab_file":
43 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
44 | "special_tokens_map":
45 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
46 | "tokenizer_config":
47 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
48 | }
49 | },
50 | "uie-medium": {
51 | "resource_file_urls": {
52 | "model_state.pdparams":
53 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
54 | "model_config.json":
55 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
56 | "vocab_file":
57 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
58 | "special_tokens_map":
59 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
60 | "tokenizer_config":
61 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
62 | }
63 | },
64 | "uie-mini": {
65 | "resource_file_urls": {
66 | "model_state.pdparams":
67 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
68 | "model_config.json":
69 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
70 | "vocab_file":
71 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
72 | "special_tokens_map":
73 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
74 | "tokenizer_config":
75 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
76 | }
77 | },
78 | "uie-micro": {
79 | "resource_file_urls": {
80 | "model_state.pdparams":
81 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
82 | "model_config.json":
83 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
84 | "vocab_file":
85 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
86 | "special_tokens_map":
87 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
88 | "tokenizer_config":
89 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
90 | }
91 | },
92 | "uie-nano": {
93 | "resource_file_urls": {
94 | "model_state.pdparams":
95 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
96 | "model_config.json":
97 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
98 | "vocab_file":
99 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
100 | "special_tokens_map":
101 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
102 | "tokenizer_config":
103 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
104 | }
105 | },
106 | "uie-medical-base": {
107 | "resource_file_urls": {
108 | "model_state.pdparams":
109 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
110 | "model_config.json":
111 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
112 | "vocab_file":
113 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
114 | "special_tokens_map":
115 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
116 | "tokenizer_config":
117 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
118 | }
119 | },
120 | "uie-tiny": {
121 | "resource_file_urls": {
122 | "model_state.pdparams":
123 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
124 | "model_config.json":
125 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
126 | "vocab_file":
127 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
128 | "special_tokens_map":
129 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
130 | "tokenizer_config":
131 | "https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
132 | }
133 | }
134 | }
135 |
136 |
137 | def build_params_map(attention_num=12):
138 | """
139 | Build a parameter-name mapping from PaddlePaddle's ERNIE to Transformers' BERT.
140 | :return:
141 | """
142 | weight_map = collections.OrderedDict({
143 | 'encoder.embeddings.word_embeddings.weight': "encoder.embeddings.word_embeddings.weight",
144 | 'encoder.embeddings.position_embeddings.weight': "encoder.embeddings.position_embeddings.weight",
145 | 'encoder.embeddings.token_type_embeddings.weight': "encoder.embeddings.token_type_embeddings.weight",
146 | 'encoder.embeddings.task_type_embeddings.weight': "encoder.embeddings.task_type_embeddings.weight",
147 | 'encoder.embeddings.layer_norm.weight': 'encoder.embeddings.LayerNorm.gamma',
148 | 'encoder.embeddings.layer_norm.bias': 'encoder.embeddings.LayerNorm.beta',
149 | })
150 | # add attention layers
151 | for i in range(attention_num):
152 | weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.query.weight'
153 | weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.query.bias'
154 | weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.key.weight'
155 | weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.key.bias'
156 | weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.weight'] = f'encoder.encoder.layer.{i}.attention.self.value.weight'
157 | weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.bias'] = f'encoder.encoder.layer.{i}.attention.self.value.bias'
158 | weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.weight'] = f'encoder.encoder.layer.{i}.attention.output.dense.weight'
159 | weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.bias'] = f'encoder.encoder.layer.{i}.attention.output.dense.bias'
160 | weight_map[f'encoder.encoder.layers.{i}.norm1.weight'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.gamma'
161 | weight_map[f'encoder.encoder.layers.{i}.norm1.bias'] = f'encoder.encoder.layer.{i}.attention.output.LayerNorm.beta'
162 | weight_map[f'encoder.encoder.layers.{i}.linear1.weight'] = f'encoder.encoder.layer.{i}.intermediate.dense.weight'
163 | weight_map[f'encoder.encoder.layers.{i}.linear1.bias'] = f'encoder.encoder.layer.{i}.intermediate.dense.bias'
164 | weight_map[f'encoder.encoder.layers.{i}.linear2.weight'] = f'encoder.encoder.layer.{i}.output.dense.weight'
165 | weight_map[f'encoder.encoder.layers.{i}.linear2.bias'] = f'encoder.encoder.layer.{i}.output.dense.bias'
166 | weight_map[f'encoder.encoder.layers.{i}.norm2.weight'] = f'encoder.encoder.layer.{i}.output.LayerNorm.gamma'
167 | weight_map[f'encoder.encoder.layers.{i}.norm2.bias'] = f'encoder.encoder.layer.{i}.output.LayerNorm.beta'
168 | # add pooler
169 | weight_map.update(
170 | {
171 | 'encoder.pooler.dense.weight': 'encoder.pooler.dense.weight',
172 | 'encoder.pooler.dense.bias': 'encoder.pooler.dense.bias',
173 | 'linear_start.weight': 'linear_start.weight',
174 | 'linear_start.bias': 'linear_start.bias',
175 | 'linear_end.weight': 'linear_end.weight',
176 | 'linear_end.bias': 'linear_end.bias',
177 | }
178 | )
179 | return weight_map
180 |
181 |
182 | def extract_and_convert(input_dir, output_dir):
183 | if not os.path.exists(output_dir):
184 | os.makedirs(output_dir)
185 | logger.info('=' * 20 + 'save config file' + '=' * 20)
186 | config = json.load(
187 | open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8'))
188 | config = config['init_args'][0]
189 | config["architectures"] = ["UIE"]
190 | config['layer_norm_eps'] = 1e-12
191 | del config['init_class']
192 | if 'sent_type_vocab_size' in config:
193 | config['type_vocab_size'] = config['sent_type_vocab_size']
194 | config['intermediate_size'] = 4 * config['hidden_size']
195 | json.dump(config, open(os.path.join(output_dir, 'config.json'),
196 | 'wt', encoding='utf-8'), indent=4)
197 | logger.info('=' * 20 + 'save vocab file' + '=' * 20)
198 | with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
199 | words = f.read().splitlines()
200 | words_set = set()
201 | words_duplicate_indices = []
202 | for i in range(len(words)-1, -1, -1):
203 | word = words[i]
204 | if word in words_set:
205 | words_duplicate_indices.append(i)
206 | words_set.add(word)
207 | for i, idx in enumerate(words_duplicate_indices):
208 | words[idx] = chr(0x1F6A9+i)  # Replace each duplicated word with a unique placeholder character (starting at 🚩)
209 | with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
210 | for word in words:
211 | f.write(word+'\n')
212 | special_tokens_map = {
213 | "unk_token": "[UNK]",
214 | "sep_token": "[SEP]",
215 | "pad_token": "[PAD]",
216 | "cls_token": "[CLS]",
217 | "mask_token": "[MASK]"
218 | }
219 | json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'),
220 | 'wt', encoding='utf-8'))
221 | tokenizer_config = {
222 | "do_lower_case": True,
223 | "unk_token": "[UNK]",
224 | "sep_token": "[SEP]",
225 | "pad_token": "[PAD]",
226 | "cls_token": "[CLS]",
227 | "mask_token": "[MASK]",
228 | "tokenizer_class": "BertTokenizer"
229 | }
230 | json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'),
231 | 'wt', encoding='utf-8'))
232 | logger.info('=' * 20 + 'extract weights' + '=' * 20)
233 | state_dict = collections.OrderedDict()
234 | weight_map = build_params_map(attention_num=config['num_hidden_layers'])
235 | if paddle_installed:
236 | import paddle.fluid.dygraph as D
237 | from paddle import fluid
238 | with fluid.dygraph.guard():
239 | paddle_paddle_params, _ = D.load_dygraph(
240 | os.path.join(input_dir, 'model_state'))
241 | else:
242 | paddle_paddle_params = pickle.load(
243 | open(os.path.join(input_dir, 'model_state.pdparams'), 'rb'))
244 | del paddle_paddle_params['StructuredToParameterName@@']
245 | for weight_name, weight_value in paddle_paddle_params.items():
246 | if 'weight' in weight_name:
247 | if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
248 | weight_value = weight_value.transpose()
249 | # Fix: embedding error
250 | if 'word_embeddings.weight' in weight_name:
251 | weight_value[0, :] = 0
252 | if weight_name not in weight_map:
253 | logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
254 | continue
255 | state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
256 | logger.info(
257 | f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
258 | torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
259 |
260 |
261 | def check_model(input_model):
262 | if not os.path.exists(input_model):
263 | if input_model not in MODEL_MAP:
264 | raise ValueError('input_model does not exist!')
265 |
266 | resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
267 | logger.info("Downloading resource files...")
268 |
269 | for key, val in resource_file_urls.items():
270 | file_path = os.path.join(input_model, key)
271 | if not os.path.exists(file_path):
272 | get_path_from_url(val, input_model)
273 |
274 |
275 | def validate_model(tokenizer, pt_model, pd_model, atol: float = 1e-5):
276 |
277 | logger.info("Validating PyTorch model...")
278 |
279 | batch_size = 2
280 | seq_length = 6
281 | seq_length_with_token = seq_length+2
282 | max_seq_length = 512
283 | dummy_input = [" ".join([tokenizer.unk_token])
284 | * seq_length] * batch_size
285 | encoded_inputs = dict(tokenizer(dummy_input, pad_to_max_seq_len=True, max_seq_len=512, return_attention_mask=True,
286 | return_position_ids=True))
287 | paddle_inputs = {}
288 | for name, value in encoded_inputs.items():
289 | if name == "attention_mask":
290 | name = "att_mask"
291 | if name == "position_ids":
292 | name = "pos_ids"
293 | paddle_inputs[name] = paddle.to_tensor(value, dtype=paddle.int64)
294 |
295 | paddle_named_outputs = ['start_prob', 'end_prob']
296 | paddle_outputs = pd_model(**paddle_inputs)
297 |
298 | torch_inputs = {}
299 | for name, value in encoded_inputs.items():
300 | torch_inputs[name] = torch.tensor(value, dtype=torch.int64)
301 | torch_outputs = pt_model(**torch_inputs)
302 | torch_outputs_dict = {}
303 |
304 | for name, value in torch_outputs.items():
305 | torch_outputs_dict[name] = value
306 |
307 | torch_outputs_set, ref_outputs_set = set(
308 | torch_outputs_dict.keys()), set(paddle_named_outputs)
309 | if not torch_outputs_set.issubset(ref_outputs_set):
310 | logger.info(
311 | f"\t-[x] Pytorch model output names {torch_outputs_set} do not match reference model {ref_outputs_set}"
312 | )
313 |
314 | raise ValueError(
315 | "Outputs don't match between the reference model and the converted PyTorch model: "
316 | f"{torch_outputs_set.difference(ref_outputs_set)}"
317 | )
318 | else:
319 | logger.info(
320 | f"\t-[✓] Pytorch model output names match reference model ({torch_outputs_set})")
321 |
322 | # Check the shape and values match
323 | for name, ref_value in zip(paddle_named_outputs, paddle_outputs):
324 | ref_value = ref_value.numpy()
325 | pt_value = torch_outputs_dict[name].detach().numpy()
326 | logger.info(f'\t- Validating PyTorch Model output "{name}":')
327 |
328 | # Shape
329 | if not pt_value.shape == ref_value.shape:
330 | logger.info(
331 | f"\t\t-[x] shape {pt_value.shape} doesn't match {ref_value.shape}")
332 | raise ValueError(
333 | "Output shapes don't match between the reference model and the converted PyTorch model: "
334 | f"Got {ref_value.shape} (reference) and {pt_value.shape} (PyTorch)"
335 | )
336 | else:
337 | logger.info(
338 | f"\t\t-[✓] {pt_value.shape} matches {ref_value.shape}")
339 |
340 | # Values
341 | if not np.allclose(ref_value, pt_value, atol=atol):
342 | logger.info(
343 | f"\t\t-[x] values not close enough (atol: {atol})")
344 | raise ValueError(
345 | "Output values don't match between the reference model and the converted PyTorch model: "
346 | f"Got max absolute difference of: {np.amax(np.abs(ref_value - pt_value))}"
347 | )
348 | else:
349 | logger.info(
350 | f"\t\t-[✓] all values close (atol: {atol})")
351 |
352 |
353 | def do_main():
354 | check_model(args.input_model)
355 | extract_and_convert(args.input_model, args.output_model)
356 | if not args.no_validate_output:
357 | if paddle_installed:
358 | try:
359 | from paddlenlp.transformers import ErnieTokenizer
360 | from paddlenlp.taskflow.models import UIE as UIEPaddle
361 | except (ImportError, ModuleNotFoundError) as e:
362 | raise ModuleNotFoundError(
363 | 'Module PaddleNLP is not installed. Try install paddlenlp or run convert.py with --no_validate_output') from e
364 | tokenizer: ErnieTokenizer = ErnieTokenizer.from_pretrained(
365 | args.input_model)
366 | model = UIE.from_pretrained(args.output_model)
367 | model.eval()
368 | paddle_model = UIEPaddle.from_pretrained(args.input_model)
369 | paddle_model.eval()
370 | validate_model(tokenizer, model, paddle_model)
371 | else:
372 | logger.warning("Skipping validation of the converted PyTorch model because paddle is not installed. "
373 | "Its outputs may not match the Paddle model.")
374 |
375 |
376 | if __name__ == '__main__':
377 | parser = argparse.ArgumentParser()
378 | parser.add_argument("-i", "--input_model", default="uie-base", type=str,
379 | help="Directory of the input Paddle model. A known model name [uie-base/uie-tiny] will be downloaded automatically.")
380 | parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str,
381 | help="Directory of output pytorch model")
382 | parser.add_argument("--no_validate_output", action="store_true",
383 | help="Skip validating the converted PyTorch model against the Paddle model")
384 | args = parser.parse_args()
385 |
386 | do_main()
387 |
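A minimal programmatic use of the helpers above, as a sketch (the directory names are
placeholders; the same flow is what do_main() performs via the CLI):

    from convert import check_model, extract_and_convert

    check_model("uie-base")                              # download resource files if missing
    extract_and_convert("uie-base", "uie_base_pytorch")  # write config, vocab and pytorch_model.bin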
--------------------------------------------------------------------------------
/uie_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import six
17 |
18 | from transformers import BertTokenizerFast
19 | import math
20 | import argparse
21 | import os.path
22 |
23 | from utils import logger, get_bool_ids_greater_than, get_span, get_id_and_prob, cut_chinese_sent, dbc2sbc
24 |
25 |
26 | class ONNXInferBackend(object):
27 | def __init__(self,
28 | model_path_prefix,
29 | device='cpu',
30 | use_fp16=False):
31 | from onnxruntime import InferenceSession, SessionOptions
32 | logger.info(">>> [ONNXInferBackend] Creating Engine ...")
33 | onnx_model = float_onnx_file = os.path.join(
34 | model_path_prefix, "inference.onnx")
35 | if not os.path.exists(onnx_model):
36 | raise OSError(f'{onnx_model} does not exist!')
37 | infer_model_dir = model_path_prefix
38 |
39 | if device == "gpu":
40 | providers = ['CUDAExecutionProvider']
41 | logger.info(">>> [ONNXInferBackend] Use GPU to inference ...")
42 | if use_fp16:
43 | logger.info(">>> [ONNXInferBackend] Use FP16 to inference ...")
44 | from onnxconverter_common import float16
45 | import onnx
46 | fp16_model_file = os.path.join(infer_model_dir,
47 | "fp16_model.onnx")
48 | onnx_model = onnx.load_model(float_onnx_file)
49 | trans_model = float16.convert_float_to_float16(
50 | onnx_model, keep_io_types=True)
51 | onnx.save_model(trans_model, fp16_model_file)
52 | onnx_model = fp16_model_file
53 | else:
54 | providers = ['CPUExecutionProvider']
55 | logger.info(">>> [ONNXInferBackend] Use CPU to inference ...")
56 |
57 | sess_options = SessionOptions()
58 | self.predictor = InferenceSession(
59 | onnx_model, sess_options=sess_options, providers=providers)
60 | if device == "gpu":
61 | try:
62 | assert 'CUDAExecutionProvider' in self.predictor.get_providers()
63 | except AssertionError:
64 | raise AssertionError(
65 | f"The environment for GPU inference is not set properly. "
66 | "A possible cause is that you had installed both onnxruntime and onnxruntime-gpu. "
67 | "Please run the following commands to reinstall: \n "
68 | "1) pip uninstall -y onnxruntime onnxruntime-gpu \n 2) pip install onnxruntime-gpu"
69 | )
70 | logger.info(">>> [InferBackend] Engine Created ...")
71 |
72 | def infer(self, input_dict: dict):
73 | result = self.predictor.run(None, dict(input_dict))
74 | return result
75 |
76 |
77 | class PyTorchInferBackend:
78 | def __init__(self,
79 | model_path_prefix,
80 | device='cpu',
81 | use_fp16=False):
82 | from model import UIE
83 | logger.info(">>> [PyTorchInferBackend] Creating Engine ...")
84 | self.model = UIE.from_pretrained(model_path_prefix)
85 | self.model.eval()
86 | self.device = device
87 | if self.device == 'gpu':
88 | logger.info(">>> [PyTorchInferBackend] Use GPU to inference ...")
89 | if use_fp16:
90 | logger.info(
91 | ">>> [PyTorchInferBackend] Use FP16 to inference ...")
92 | self.model = self.model.half()
93 | self.model = self.model.cuda()
94 | else:
95 | logger.info(">>> [PyTorchInferBackend] Use CPU to inference ...")
96 | logger.info(">>> [PyTorchInferBackend] Engine Created ...")
97 |
98 | def infer(self, input_dict):
99 | import torch
100 | for input_name, input_value in input_dict.items():
101 | input_value = torch.LongTensor(input_value)
102 | if self.device == 'gpu':
103 | input_value = input_value.cuda()
104 | input_dict[input_name] = input_value
105 |
106 | outputs = self.model(**input_dict)
107 | start_prob, end_prob = outputs[0], outputs[1]
108 | if self.device == 'gpu':
109 | start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
110 | start_prob = start_prob.detach().numpy()
111 | end_prob = end_prob.detach().numpy()
112 | return start_prob, end_prob
113 |
114 |
115 | class UIEPredictor(object):
116 |
117 | def __init__(self, task_path, schema, engine='pytorch', device='cpu', position_prob=0.5, max_seq_len=512, batch_size=64, split_sentence=False, use_fp16=False):
118 |
119 | assert isinstance(
120 | device, six.string_types
121 | ), "The type of device must be string."
122 | assert device in [
123 | 'cpu', 'gpu'], "The device must be cpu or gpu."
124 | self._engine = engine
125 | self._task_path = task_path
126 | self._device = device
127 | self._position_prob = position_prob
128 | self._max_seq_len = max_seq_len
129 | self._batch_size = batch_size
130 | self._split_sentence = split_sentence
131 | self._use_fp16 = use_fp16
132 |
133 | self._schema_tree = None
134 | self.set_schema(schema)
135 | self._prepare_predictor()
136 |
137 | def _prepare_predictor(self):
138 | self._tokenizer = BertTokenizerFast.from_pretrained(
139 | self._task_path)
140 | assert self._engine in ['pytorch',
141 | 'onnx'], "engine must be pytorch or onnx!"
142 | if self._engine == 'pytorch':
143 | self.inference_backend = PyTorchInferBackend(
144 | self._task_path, device=self._device, use_fp16=self._use_fp16)
145 |
146 | if self._engine == 'onnx':
147 | self.inference_backend = ONNXInferBackend(
148 | self._task_path, device=self._device, use_fp16=self._use_fp16)
149 |
150 | def set_schema(self, schema):
151 | if isinstance(schema, dict) or isinstance(schema, str):
152 | schema = [schema]
153 | self._schema_tree = self._build_tree(schema)
154 |
155 | def __call__(self, inputs):
156 | texts = inputs
157 | if isinstance(texts, str):
158 | texts = [texts]
159 | results = self._multi_stage_predict(texts)
160 | return results
161 |
162 | def _multi_stage_predict(self, datas):
163 | """
164 | Traverse the schema tree and do multi-stage prediction.
165 | Args:
166 | datas (list): a list of strings
167 | Returns:
168 | list: a list of predictions whose length
169 | equals the length of `datas`
170 | """
171 | results = [{} for _ in range(len(datas))]
172 | # input check to early return
173 | if len(datas) < 1 or self._schema_tree is None:
174 | return results
175 |
176 | # copy to stay `self._schema_tree` unchanged
177 | schema_list = self._schema_tree.children[:]
178 | while len(schema_list) > 0:
179 | node = schema_list.pop(0)
180 | examples = []
181 | input_map = {}
182 | cnt = 0
183 | idx = 0
184 | if not node.prefix:
185 | for data in datas:
186 | examples.append({
187 | "text": data,
188 | "prompt": dbc2sbc(node.name)
189 | })
190 | input_map[cnt] = [idx]
191 | idx += 1
192 | cnt += 1
193 | else:
194 | for pre, data in zip(node.prefix, datas):
195 | if len(pre) == 0:
196 | input_map[cnt] = []
197 | else:
198 | for p in pre:
199 | examples.append({
200 | "text": data,
201 | "prompt": dbc2sbc(p + node.name)
202 | })
203 | input_map[cnt] = [i + idx for i in range(len(pre))]
204 | idx += len(pre)
205 | cnt += 1
206 | if len(examples) == 0:
207 | result_list = []
208 | else:
209 | result_list = self._single_stage_predict(examples)
210 |
211 | if not node.parent_relations:
212 | relations = [[] for i in range(len(datas))]
213 | for k, v in input_map.items():
214 | for idx in v:
215 | if len(result_list[idx]) == 0:
216 | continue
217 | if node.name not in results[k].keys():
218 | results[k][node.name] = result_list[idx]
219 | else:
220 | results[k][node.name].extend(result_list[idx])
221 | if node.name in results[k].keys():
222 | relations[k].extend(results[k][node.name])
223 | else:
224 | relations = node.parent_relations
225 | for k, v in input_map.items():
226 | for i in range(len(v)):
227 | if len(result_list[v[i]]) == 0:
228 | continue
229 | if "relations" not in relations[k][i].keys():
230 | relations[k][i]["relations"] = {
231 | node.name: result_list[v[i]]
232 | }
233 | elif node.name not in relations[k][i]["relations"].keys(
234 | ):
235 | relations[k][i]["relations"][
236 | node.name] = result_list[v[i]]
237 | else:
238 | relations[k][i]["relations"][node.name].extend(
239 | result_list[v[i]])
240 |
241 | new_relations = [[] for i in range(len(datas))]
242 | for i in range(len(relations)):
243 | for j in range(len(relations[i])):
244 | if "relations" in relations[i][j].keys(
245 | ) and node.name in relations[i][j]["relations"].keys():
246 | for k in range(
247 | len(relations[i][j]["relations"][
248 | node.name])):
249 | new_relations[i].append(relations[i][j][
250 | "relations"][node.name][k])
251 | relations = new_relations
252 |
253 | prefix = [[] for _ in range(len(datas))]
254 | for k, v in input_map.items():
255 | for idx in v:
256 | for i in range(len(result_list[idx])):
257 | prefix[k].append(result_list[idx][i]["text"] + "的")
258 |
259 | for child in node.children:
260 | child.prefix = prefix
261 | child.parent_relations = relations
262 | schema_list.append(child)
263 | return results
264 |
265 | def _convert_ids_to_results(self, examples, sentence_ids, probs):
266 | """
267 | Convert ids to raw text in a single stage.
268 | """
269 | results = []
270 | for example, sentence_id, prob in zip(examples, sentence_ids, probs):
271 | if len(sentence_id) == 0:
272 | results.append([])
273 | continue
274 | result_list = []
275 | text = example["text"]
276 | prompt = example["prompt"]
277 | for i in range(len(sentence_id)):
278 | start, end = sentence_id[i]
279 | if start < 0 and end >= 0:
280 | continue
281 | if end < 0:
282 | start += (len(prompt) + 1)
283 | end += (len(prompt) + 1)
284 | result = {"text": prompt[start:end],
285 | "probability": prob[i]}
286 | result_list.append(result)
287 | else:
288 | result = {
289 | "text": text[start:end],
290 | "start": start,
291 | "end": end,
292 | "probability": prob[i]
293 | }
294 | result_list.append(result)
295 | results.append(result_list)
296 | return results
297 |
298 | def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
299 | '''
300 | Split the raw texts automatically for model inference.
301 | Args:
302 | input_texts (List[str]): input raw texts.
303 | max_text_len (int): cutting length.
304 | split_sentence (bool): If True, sentence-level split will be performed.
305 | return:
306 | short_input_texts (List[str]): the short input texts for model inference.
307 | input_mapping (dict): mapping between raw text and short input texts.
308 | '''
309 | input_mapping = {}
310 | short_input_texts = []
311 | cnt_org = 0
312 | cnt_short = 0
313 | for text in input_texts:
314 | if not split_sentence:
315 | sens = [text]
316 | else:
317 | sens = cut_chinese_sent(text)
318 | for sen in sens:
319 | lens = len(sen)
320 | if lens <= max_text_len:
321 | short_input_texts.append(sen)
322 | if cnt_org not in input_mapping.keys():
323 | input_mapping[cnt_org] = [cnt_short]
324 | else:
325 | input_mapping[cnt_org].append(cnt_short)
326 | cnt_short += 1
327 | else:
328 | temp_text_list = [
329 | sen[i:i + max_text_len]
330 | for i in range(0, lens, max_text_len)
331 | ]
332 | short_input_texts.extend(temp_text_list)
333 | short_idx = cnt_short
334 | cnt_short += math.ceil(lens / max_text_len)
335 | temp_text_id = [
336 | short_idx + i for i in range(cnt_short - short_idx)
337 | ]
338 | if cnt_org not in input_mapping.keys():
339 | input_mapping[cnt_org] = temp_text_id
340 | else:
341 | input_mapping[cnt_org].extend(temp_text_id)
342 | cnt_org += 1
343 | return short_input_texts, input_mapping
344 |
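# Worked example for _auto_splitter (derived from the logic above): with max_text_len=4
# and split_sentence=False, ["abcdef"] is chunked into ["abcd", "ef"] and the returned
# mapping is {0: [0, 1]}, i.e. both short texts belong to original input 0.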
345 | def _single_stage_predict(self, inputs):
346 | input_texts = []
347 | prompts = []
348 | for i in range(len(inputs)):
349 | input_texts.append(inputs[i]["text"])
350 | prompts.append(inputs[i]["prompt"])
351 | # max predict length should exclude the length of prompt and summary tokens
352 | max_predict_len = self._max_seq_len - max(len(p) for p in prompts) - 3
353 |
354 | short_input_texts, self.input_mapping = self._auto_splitter(
355 | input_texts, max_predict_len, split_sentence=self._split_sentence)
356 |
357 | short_texts_prompts = []
358 | for k, v in self.input_mapping.items():
359 | short_texts_prompts.extend([prompts[k] for i in range(len(v))])
360 | short_inputs = [{
361 | "text": short_input_texts[i],
362 | "prompt": short_texts_prompts[i]
363 | } for i in range(len(short_input_texts))]
364 |
365 | sentence_ids = []
366 | probs = []
367 |
368 | input_ids = []
369 | token_type_ids = []
370 | attention_mask = []
371 | offset_maps = []
372 |
373 | encoded_inputs = self._tokenizer(
374 | text=short_texts_prompts,
375 | text_pair=short_input_texts,
376 | stride=2,
377 | truncation=True,
378 | max_length=self._max_seq_len,
379 | padding="longest",
380 | add_special_tokens=True,
381 | return_offsets_mapping=True,
382 | return_tensors="np")
383 |
384 | start_prob_concat, end_prob_concat = [], []
385 | for batch_start in range(0, len(short_input_texts), self._batch_size):
386 | input_ids = encoded_inputs["input_ids"][batch_start:batch_start+self._batch_size]
387 | token_type_ids = encoded_inputs["token_type_ids"][batch_start:batch_start+self._batch_size]
388 | attention_mask = encoded_inputs["attention_mask"][batch_start:batch_start+self._batch_size]
389 | offset_maps = encoded_inputs["offset_mapping"][batch_start:batch_start+self._batch_size]
390 | input_dict = {
391 | "input_ids": np.array(
392 | input_ids, dtype="int64"),
393 | "token_type_ids": np.array(
394 | token_type_ids, dtype="int64"),
395 | "attention_mask": np.array(
396 | attention_mask, dtype="int64")
397 | }
398 |
399 | outputs = self.inference_backend.infer(input_dict)
400 | start_prob, end_prob = outputs[0], outputs[1]
401 | start_prob_concat.append(start_prob)
402 | end_prob_concat.append(end_prob)
403 | start_prob_concat = np.concatenate(start_prob_concat)
404 | end_prob_concat = np.concatenate(end_prob_concat)
405 |
406 | start_ids_list = get_bool_ids_greater_than(
407 | start_prob_concat, limit=self._position_prob, return_prob=True)
408 | end_ids_list = get_bool_ids_greater_than(
409 | end_prob_concat, limit=self._position_prob, return_prob=True)
410 |
411 | input_ids = input_dict['input_ids']
412 | sentence_ids = []
413 | probs = []
414 | for start_ids, end_ids, ids, offset_map in zip(start_ids_list,
415 | end_ids_list,
416 | input_ids.tolist(),
417 | offset_maps):
418 | for i in reversed(range(len(ids))):
419 | if ids[i] != 0:
420 | ids = ids[:i]
421 | break
422 | span_list = get_span(start_ids, end_ids, with_prob=True)
423 | sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist())
424 | sentence_ids.append(sentence_id)
425 | probs.append(prob)
426 |
427 | results = self._convert_ids_to_results(short_inputs, sentence_ids,
428 | probs)
429 | results = self._auto_joiner(results, short_input_texts,
430 | self.input_mapping)
431 | return results
432 |
433 | def _auto_joiner(self, short_results, short_inputs, input_mapping):
434 | concat_results = []
435 | is_cls_task = False
436 | for short_result in short_results:
437 | if short_result == []:
438 | continue
439 | elif 'start' not in short_result[0].keys(
440 | ) and 'end' not in short_result[0].keys():
441 | is_cls_task = True
442 | break
443 | else:
444 | break
445 | for k, vs in input_mapping.items():
446 | if is_cls_task:
447 | cls_options = {}
448 | single_results = []
449 | for v in vs:
450 | if len(short_results[v]) == 0:
451 | continue
452 | if short_results[v][0]['text'] not in cls_options.keys():
453 | cls_options[short_results[v][0][
454 | 'text']] = [1, short_results[v][0]['probability']]
455 | else:
456 | cls_options[short_results[v][0]['text']][0] += 1
457 | cls_options[short_results[v][0]['text']][
458 | 1] += short_results[v][0]['probability']
459 | if len(cls_options) != 0:
460 | cls_res, cls_info = max(cls_options.items(),
461 | key=lambda x: x[1])
462 | concat_results.append([{
463 | 'text': cls_res,
464 | 'probability': cls_info[1] / cls_info[0]
465 | }])
466 | else:
467 | concat_results.append([])
468 | else:
469 | offset = 0
470 | single_results = []
471 | for v in vs:
472 | if v == 0:
473 | single_results = short_results[v]
474 | offset += len(short_inputs[v])
475 | else:
476 | for i in range(len(short_results[v])):
477 | if 'start' not in short_results[v][
478 | i] or 'end' not in short_results[v][i]:
479 | continue
480 | short_results[v][i]['start'] += offset
481 | short_results[v][i]['end'] += offset
482 | offset += len(short_inputs[v])
483 | single_results.extend(short_results[v])
484 | concat_results.append(single_results)
485 | return concat_results
486 |
487 | def predict(self, input_data):
488 | results = self._multi_stage_predict(input_data)
489 | return results
490 |
491 | @classmethod
492 | def _build_tree(cls, schema, name='root'):
493 | """
494 | Build the schema tree.
495 | """
496 | schema_tree = SchemaTree(name)
497 | for s in schema:
498 | if isinstance(s, str):
499 | schema_tree.add_child(SchemaTree(s))
500 | elif isinstance(s, dict):
501 | for k, v in s.items():
502 | if isinstance(v, str):
503 | child = [v]
504 | elif isinstance(v, list):
505 | child = v
506 | else:
507 | raise TypeError(
508 | "Invalid schema, value for each key:value pair should be list or string, "
509 | "but {} received".format(type(v)))
510 | schema_tree.add_child(cls._build_tree(child, name=k))
511 | else:
512 | raise TypeError(
513 | "Invalid schema, element should be string or dict, "
514 | "but {} received".format(type(s)))
515 | return schema_tree
516 |
517 |
518 | class SchemaTree(object):
519 | """
520 | Implementation of SchemaTree
521 | """
522 |
523 | def __init__(self, name='root', children=None):
524 | self.name = name
525 | self.children = []
526 | self.prefix = None
527 | self.parent_relations = None
528 | if children is not None:
529 | for child in children:
530 | self.add_child(child)
531 |
532 | def __repr__(self):
533 | return self.name
534 |
535 | def add_child(self, node):
536 | assert isinstance(
537 | node, SchemaTree
538 | ), "The children of a node should be an instance of SchemaTree."
539 | self.children.append(node)
540 |
541 |
542 | def parse_args():
543 | parser = argparse.ArgumentParser()
544 | # Required parameters
545 | parser.add_argument(
546 | "-m",
547 | "--model_path_prefix",
548 | type=str,
549 | default='uie_base_pytorch',
550 | help="The path prefix of inference model to be used.", )
551 | parser.add_argument(
552 | "-p",
553 | "--position_prob",
554 | default=0.5,
555 | type=float,
556 | help="Probability threshold for the start/end index probability.", )
557 | parser.add_argument(
558 | "--use_fp16",
559 | action='store_true',
560 | help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
561 | )
562 | parser.add_argument(
563 | "--max_seq_len",
564 | default=512,
565 | type=int,
566 | help="The maximum input sequence length. Sequences longer than this will be split automatically.",
567 | )
568 | parser.add_argument(
569 | "-D",
570 | "--device",
571 | choices=['cpu', 'gpu'],
572 | default="gpu",
573 | help="Select which device to run the model on; defaults to gpu."
574 | )
575 | parser.add_argument(
576 | "-e",
577 | "--engine",
578 | choices=['pytorch', 'onnx'],
579 | default="pytorch",
580 | help="Select which engine to run the model with; defaults to pytorch."
581 | )
582 | args = parser.parse_args()
583 | return args
584 |
585 |
586 | if __name__ == '__main__':
587 | args = parse_args()
588 | args.schema = ['航母']
589 | uie = UIEPredictor(task_path=args.model_path_prefix, schema=args.schema, engine=args.engine, device=args.device,
590 | position_prob=args.position_prob, max_seq_len=args.max_seq_len, batch_size=64, split_sentence=False, use_fp16=args.use_fp16)
591 | print(uie("印媒所称的“印度第一艘国产航母”—“维克兰特”号"))
592 |
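A usage sketch with a nested (relation) schema; the schema labels and the model directory
below are illustrative placeholders rather than values defined by this repository:

    from uie_predictor import UIEPredictor

    schema = {"航母": ["国家"]}  # subject entity with one relation prompt
    uie = UIEPredictor(task_path="uie_base_pytorch", schema=schema,
                       engine="pytorch", device="cpu")
    print(uie("印媒所称的“印度第一艘国产航母”—“维克兰特”号"))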
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022 Heiheiyoyo. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import contextlib
16 | import functools
17 | import json
18 | import logging
19 | import math
20 | import random
21 | import re
22 | import shutil
23 | import threading
24 | import time
25 | from functools import partial
26 |
27 | import colorlog
28 | import numpy as np
29 | import torch
30 | from colorama import Back, Fore
31 | from torch.utils.data import Dataset
32 | from tqdm import tqdm
33 | from tqdm.contrib.logging import logging_redirect_tqdm
34 |
35 | loggers = {}
36 |
37 | log_config = {
38 | 'DEBUG': {
39 | 'level': 10,
40 | 'color': 'purple'
41 | },
42 | 'INFO': {
43 | 'level': 20,
44 | 'color': 'green'
45 | },
46 | 'TRAIN': {
47 | 'level': 21,
48 | 'color': 'cyan'
49 | },
50 | 'EVAL': {
51 | 'level': 22,
52 | 'color': 'blue'
53 | },
54 | 'WARNING': {
55 | 'level': 30,
56 | 'color': 'yellow'
57 | },
58 | 'ERROR': {
59 | 'level': 40,
60 | 'color': 'red'
61 | },
62 | 'CRITICAL': {
63 | 'level': 50,
64 | 'color': 'bold_red'
65 | }
66 | }
67 |
68 |
69 | def get_span(start_ids, end_ids, with_prob=False):
70 | """
71 | Get span set from position start and end list.
72 |
73 | Args:
74 | start_ids (List[int]/List[tuple]): The start index list.
75 | end_ids (List[int]/List[tuple]): The end index list.
76 | with_prob (bool): If True, each element of start_ids and end_ids is a tuple like (index, probability).
77 | Returns:
78 | set: The span set without overlaps; every id can only be used once.
79 | """
80 | if with_prob:
81 | start_ids = sorted(start_ids, key=lambda x: x[0])
82 | end_ids = sorted(end_ids, key=lambda x: x[0])
83 | else:
84 | start_ids = sorted(start_ids)
85 | end_ids = sorted(end_ids)
86 |
87 | start_pointer = 0
88 | end_pointer = 0
89 | len_start = len(start_ids)
90 | len_end = len(end_ids)
91 | couple_dict = {}
92 | while start_pointer < len_start and end_pointer < len_end:
93 | if with_prob:
94 | start_id = start_ids[start_pointer][0]
95 | end_id = end_ids[end_pointer][0]
96 | else:
97 | start_id = start_ids[start_pointer]
98 | end_id = end_ids[end_pointer]
99 |
100 | if start_id == end_id:
101 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
102 | start_pointer += 1
103 | end_pointer += 1
104 | continue
105 | if start_id < end_id:
106 | couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
107 | start_pointer += 1
108 | continue
109 | if start_id > end_id:
110 | end_pointer += 1
111 | continue
112 | result = [(couple_dict[end], end) for end in couple_dict]
113 | result = set(result)
114 | return result
115 |
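# For example, get_span([1, 2, 10], [4, 12]) returns {(2, 4), (10, 12)}: each end index is
# paired with the closest preceding start index, and earlier unmatched starts are dropped.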
116 |
117 | def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
118 | """
119 | Get the indices along the last dimension of the probability arrays whose values are greater than the limit.
120 |
121 | Args:
122 | probs (List[List[float]]): The input probability arrays.
123 | limit (float): The probability threshold.
124 | return_prob (bool): Whether to return the probability
125 | Returns:
126 | List[List[int]]: The indices along the last dimension that meet the condition.
127 | """
128 | probs = np.array(probs)
129 | dim_len = len(probs.shape)
130 | if dim_len > 1:
131 | result = []
132 | for p in probs:
133 | result.append(get_bool_ids_greater_than(p, limit, return_prob))
134 | return result
135 | else:
136 | result = []
137 | for i, p in enumerate(probs):
138 | if p > limit:
139 | if return_prob:
140 | result.append((i, p))
141 | else:
142 | result.append(i)
143 | return result
144 |
145 |
146 | class SpanEvaluator:
147 | """
148 | SpanEvaluator computes the precision, recall and F1-score for span detection.
149 | """
150 |
151 | def __init__(self):
152 | super(SpanEvaluator, self).__init__()
153 | self.num_infer_spans = 0
154 | self.num_label_spans = 0
155 | self.num_correct_spans = 0
156 |
157 | def compute(self, start_probs, end_probs, gold_start_ids, gold_end_ids):
158 | """
159 | Computes the numbers of correct, inferred and labeled spans for a batch.
160 | """
161 | pred_start_ids = get_bool_ids_greater_than(start_probs)
162 | pred_end_ids = get_bool_ids_greater_than(end_probs)
163 | gold_start_ids = get_bool_ids_greater_than(gold_start_ids.tolist())
164 | gold_end_ids = get_bool_ids_greater_than(gold_end_ids.tolist())
165 | num_correct_spans = 0
166 | num_infer_spans = 0
167 | num_label_spans = 0
168 | for predict_start_ids, predict_end_ids, label_start_ids, label_end_ids in zip(
169 | pred_start_ids, pred_end_ids, gold_start_ids, gold_end_ids):
170 | [_correct, _infer, _label] = self.eval_span(
171 | predict_start_ids, predict_end_ids, label_start_ids,
172 | label_end_ids)
173 | num_correct_spans += _correct
174 | num_infer_spans += _infer
175 | num_label_spans += _label
176 | return num_correct_spans, num_infer_spans, num_label_spans
177 |
178 | def update(self, num_correct_spans, num_infer_spans, num_label_spans):
179 | """
180 | This function takes (num_correct_spans, num_infer_spans, num_label_spans) as input,
181 | to accumulate and update the corresponding status of the SpanEvaluator object.
182 | """
183 | self.num_infer_spans += num_infer_spans
184 | self.num_label_spans += num_label_spans
185 | self.num_correct_spans += num_correct_spans
186 |
187 | def eval_span(self, predict_start_ids, predict_end_ids, label_start_ids,
188 | label_end_ids):
189 | """
190 | evaluate position extraction (start, end)
191 | return num_correct, num_infer, num_label
192 | input: [1, 2, 10] [4, 12] [2, 10] [4, 11]
193 | output: (1, 2, 2)
194 | """
195 | pred_set = get_span(predict_start_ids, predict_end_ids)
196 | label_set = get_span(label_start_ids, label_end_ids)
197 | num_correct = len(pred_set & label_set)
198 | num_infer = len(pred_set)
199 | num_label = len(label_set)
200 | return (num_correct, num_infer, num_label)
201 |
202 | def accumulate(self):
203 | """
204 | This function returns the mean precision, recall and f1 score for all accumulated minibatches.
205 |
206 | Returns:
207 | tuple: Returns tuple (`precision, recall, f1 score`).
208 | """
209 | precision = float(self.num_correct_spans /
210 | self.num_infer_spans) if self.num_infer_spans else 0.
211 | recall = float(self.num_correct_spans /
212 | self.num_label_spans) if self.num_label_spans else 0.
213 | f1_score = float(2 * precision * recall /
214 | (precision + recall)) if self.num_correct_spans else 0.
215 | return precision, recall, f1_score
216 |
217 | def reset(self):
218 | """
219 | Reset function empties the evaluation memory for previous mini-batches.
220 | """
221 | self.num_infer_spans = 0
222 | self.num_label_spans = 0
223 | self.num_correct_spans = 0
224 |
225 | def name(self):
226 | """
227 | Return name of metric instance.
228 | """
229 | return "precision", "recall", "f1"
230 |
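# Typical use of SpanEvaluator (a sketch; start_probs/end_probs are model outputs and
# gold_start_ids/gold_end_ids the label tensors for one batch):
#   metric = SpanEvaluator()
#   correct, infer, label = metric.compute(start_probs, end_probs, gold_start_ids, gold_end_ids)
#   metric.update(correct, infer, label)
#   precision, recall, f1 = metric.accumulate()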
231 |
232 | class IEDataset(Dataset):
233 | """
234 | Dataset for Information Extraction from a JSONL file.
235 | Each line is a JSON object of the form
236 | {
237 | content
238 | result_list
239 | prompt
240 | }
241 | """
242 |
243 | def __init__(self, file_path, tokenizer, max_seq_len) -> None:
244 | super().__init__()
245 | self.file_path = file_path
246 | self.dataset = list(reader(file_path))
247 | self.tokenizer = tokenizer
248 | self.max_seq_len = max_seq_len
249 |
250 | def __len__(self):
251 | return len(self.dataset)
252 |
253 | def __getitem__(self, index):
254 | return convert_example(self.dataset[index], tokenizer=self.tokenizer, max_seq_len=self.max_seq_len)
255 |
256 |
257 | def reader(data_path, max_seq_len=512):
258 | """
259 | read json
260 | """
261 | with open(data_path, 'r', encoding='utf-8') as f:
262 | for line in f:
263 | json_line = json.loads(line)
264 | content = json_line['content']
265 | prompt = json_line['prompt']
266 | # Model input looks like: [CLS] Prompt [SEP] Content [SEP]
267 | # It includes three summary tokens.
268 | if max_seq_len <= len(prompt) + 3:
269 | raise ValueError(
270 | "The value of max_seq_len is too small, please set a larger value"
271 | )
272 | max_content_len = max_seq_len - len(prompt) - 3
273 | if len(content) <= max_content_len:
274 | yield json_line
275 | else:
276 | result_list = json_line['result_list']
277 | json_lines = []
278 | accumulate = 0
279 | while True:
280 | cur_result_list = []
281 |
282 | for result in result_list:
283 | if result['start'] + 1 <= max_content_len < result[
284 | 'end']:
285 | max_content_len = result['start']
286 | break
287 |
288 | cur_content = content[:max_content_len]
289 | res_content = content[max_content_len:]
290 |
291 | while True:
292 | if len(result_list) == 0:
293 | break
294 | elif result_list[0]['end'] <= max_content_len:
295 | if result_list[0]['end'] > 0:
296 | cur_result = result_list.pop(0)
297 | cur_result_list.append(cur_result)
298 | else:
299 | cur_result_list = [
300 | result for result in result_list
301 | ]
302 | break
303 | else:
304 | break
305 |
306 | json_line = {
307 | 'content': cur_content,
308 | 'result_list': cur_result_list,
309 | 'prompt': prompt
310 | }
311 | json_lines.append(json_line)
312 |
313 | for result in result_list:
314 | if result['end'] <= 0:
315 | break
316 | result['start'] -= max_content_len
317 | result['end'] -= max_content_len
318 | accumulate += max_content_len
319 | max_content_len = max_seq_len - len(prompt) - 3
320 | if len(res_content) == 0:
321 | break
322 | elif len(res_content) < max_content_len:
323 | json_line = {
324 | 'content': res_content,
325 | 'result_list': result_list,
326 | 'prompt': prompt
327 | }
328 | json_lines.append(json_line)
329 | break
330 | else:
331 | content = res_content
332 |
333 | for json_line in json_lines:
334 | yield json_line
335 |
336 |
337 | def convert_example(example, tokenizer, max_seq_len):
338 | """
339 | example: {
340 | title
341 | prompt
342 | content
343 | result_list
344 | }
345 | """
346 | encoded_inputs = tokenizer(
347 | text=[example["prompt"]],
348 | text_pair=[example["content"]],
349 | truncation=True,
350 | max_length=max_seq_len,
351 | add_special_tokens=True,
352 | return_offsets_mapping=True)
353 | # encoded_inputs = encoded_inputs[0]
354 | offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"][0]]
355 | bias = 0
356 | for index in range(len(offset_mapping)):
357 | if index == 0:
358 | continue
359 | mapping = offset_mapping[index]
360 | if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
361 | bias = index
362 | if mapping[0] == 0 and mapping[1] == 0:
363 | continue
364 | offset_mapping[index][0] += bias
365 | offset_mapping[index][1] += bias
366 | start_ids = [0 for x in range(max_seq_len)]
367 | end_ids = [0 for x in range(max_seq_len)]
368 | for item in example["result_list"]:
369 | start = map_offset(item["start"] + bias, offset_mapping)
370 | end = map_offset(item["end"] - 1 + bias, offset_mapping)
371 | start_ids[start] = 1.0
372 | end_ids[end] = 1.0
373 |
374 | tokenized_output = [
375 | encoded_inputs["input_ids"][0], encoded_inputs["token_type_ids"][0],
376 | encoded_inputs["attention_mask"][0],
377 | start_ids, end_ids
378 | ]
379 | tokenized_output = [np.array(x, dtype="int64") for x in tokenized_output]
380 | tokenized_output = [
381 | np.pad(x, (0, max_seq_len-x.shape[-1]), 'constant') for x in tokenized_output]
382 | return tuple(tokenized_output)
383 |
384 |
385 | def map_offset(ori_offset, offset_mapping):
386 | """
387 | map ori offset to token offset
388 | """
389 | for index, span in enumerate(offset_mapping):
390 | if span[0] <= ori_offset < span[1]:
391 | return index
392 | return -1
393 |
394 |
395 | def set_seed(seed):
396 | torch.manual_seed(seed)
397 | torch.cuda.manual_seed_all(seed)
398 | np.random.seed(seed)
399 | random.seed(seed)
400 | np.random.seed(seed)
401 |
402 |
403 | class Logger(object):
404 | '''
405 | Default logger in UIE
406 |
407 | Args:
408 | name(str) : Logger name, default is 'UIE'
409 | '''
410 |
411 | def __init__(self, name: str = None):
412 | name = 'UIE' if not name else name
413 | self.logger = logging.getLogger(name)
414 |
415 | for key, conf in log_config.items():
416 | logging.addLevelName(conf['level'], key)
417 | self.__dict__[key] = functools.partial(
418 | self.__call__, conf['level'])
419 | self.__dict__[key.lower()] = functools.partial(
420 | self.__call__, conf['level'])
421 |
422 | self.format = colorlog.ColoredFormatter(
423 | '%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s',
424 | log_colors={key: conf['color']
425 | for key, conf in log_config.items()})
426 |
427 | self.handler = logging.StreamHandler()
428 | self.handler.setFormatter(self.format)
429 |
430 | self.logger.addHandler(self.handler)
431 | self.logLevel = 'DEBUG'
432 | self.logger.setLevel(logging.DEBUG)
433 | self.logger.propagate = False
434 | self._is_enable = True
435 |
436 | def disable(self):
437 | self._is_enable = False
438 |
439 | def enable(self):
440 | self._is_enable = True
441 |
442 | @property
443 | def is_enable(self) -> bool:
444 | return self._is_enable
445 |
446 | def __call__(self, log_level: str, msg: str):
447 | if not self.is_enable:
448 | return
449 |
450 | self.logger.log(log_level, msg)
451 |
452 | @contextlib.contextmanager
453 | def use_terminator(self, terminator: str):
454 | old_terminator = self.handler.terminator
455 | self.handler.terminator = terminator
456 | yield
457 | self.handler.terminator = old_terminator
458 |
459 | @contextlib.contextmanager
460 | def processing(self, msg: str, interval: float = 0.1):
461 | '''
462 | Continuously print a progress message with a rotating spinner while the context is active.
463 |
464 | Args:
465 | msg(str): Message to be printed.
466 | interval(float): Rotation interval. Default to 0.1.
467 | '''
468 | end = False
469 |
470 | def _printer():
471 | index = 0
472 | flags = ['\\', '|', '/', '-']
473 | while not end:
474 | flag = flags[index % len(flags)]
475 | with self.use_terminator('\r'):
476 | self.info('{}: {}'.format(msg, flag))
477 | time.sleep(interval)
478 | index += 1
479 |
480 | t = threading.Thread(target=_printer)
481 | t.start()
482 | yield
483 | end = True
484 |
485 |
486 | logger = Logger()
487 |
488 |
489 | BAR_FORMAT = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}} {Fore.RED}{{rate_fmt}}{{postfix}}{Fore.RESET} eta {Fore.CYAN}{{remaining}}{Fore.RESET}'
490 | BAR_FORMAT_NO_TIME = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}}{Fore.RESET}'
491 | BAR_TYPE = [
492 | "░▝▗▖▘▚▞▛▙█",
493 | "░▖▘▝▗▚▞█",
494 | " ▖▘▝▗▚▞█",
495 | "░▒█",
496 | " >=",
497 | " ▏▎▍▌▋▊▉█",
498 | "░▏▎▍▌▋▊▉█"
499 | ]
500 |
501 | tqdm = partial(tqdm, bar_format=BAR_FORMAT, ascii=BAR_TYPE[0], leave=False)
502 |
503 |
504 | def get_id_and_prob(spans, offset_map):
505 | prompt_length = 0
506 | for i in range(1, len(offset_map)):
507 | if offset_map[i] != [0, 0]:
508 | prompt_length += 1
509 | else:
510 | break
511 |
512 | for i in range(1, prompt_length + 1):
513 | offset_map[i][0] -= (prompt_length + 1)
514 | offset_map[i][1] -= (prompt_length + 1)
515 |
516 | sentence_id = []
517 | prob = []
518 | for start, end in spans:
519 | prob.append(start[1] * end[1])
520 | sentence_id.append(
521 | (offset_map[start[0]][0], offset_map[end[0]][1]))
522 | return sentence_id, prob
523 |
524 |
525 | def cut_chinese_sent(para):
526 | """
527 | Cut Chinese text into sentences more precisely; reference:
528 | "https://blog.csdn.net/blmoistawinde/article/details/82379256".
529 | """
530 | para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
531 | para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)
532 | para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)
533 | para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
534 | para = para.rstrip()
535 | return para.split("\n")
536 |
537 |
538 | def dbc2sbc(s):
539 | rs = ""
540 | for char in s:
541 | code = ord(char)
542 | if code == 0x3000:
543 | code = 0x0020
544 | else:
545 | code -= 0xfee0
546 | if not (0x0021 <= code and code <= 0x7e):
547 | rs += char
548 | continue
549 | rs += chr(code)
550 | return rs
551 |
552 |
553 | class EarlyStopping:
554 | """Early stops the training if validation loss doesn't improve after a given patience."""
555 |
556 | def __init__(self, patience=7, verbose=False, delta=0, save_dir='checkpoint/early_stopping', trace_func=print):
557 | """
558 | Args:
559 | patience (int): How long to wait after last time validation loss improved.
560 | Default: 7
561 | verbose (bool): If True, prints a message for each validation loss improvement.
562 | Default: False
563 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
564 | Default: 0
565 | save_dir (str): Directory the checkpoint will be saved to.
566 | Default: 'checkpoint/early_stopping'
567 | trace_func (function): trace print function.
568 | Default: print
569 | """
570 | self.patience = patience
571 | self.verbose = verbose
572 | self.counter = 0
573 | self.best_score = None
574 | self.early_stop = False
575 | self.val_loss_min = np.Inf
576 | self.delta = delta
577 | self.save_dir = save_dir
578 | self.trace_func = trace_func
579 |
580 | def __call__(self, val_loss, model):
581 |
582 | score = -val_loss
583 |
584 | if self.best_score is None:
585 | self.best_score = score
586 | self.save_checkpoint(val_loss, model)
587 | elif score < self.best_score + self.delta:
588 | self.counter += 1
589 | self.trace_func(
590 | f'EarlyStopping counter: {self.counter} out of {self.patience}')
591 | if self.counter >= self.patience:
592 | self.early_stop = True
593 | else:
594 | self.best_score = score
595 | self.save_checkpoint(val_loss, model)
596 | self.counter = 0
597 |
598 | def save_checkpoint(self, val_loss, model):
599 | '''Saves the model when the validation loss decreases.'''
600 | if self.verbose:
601 | self.trace_func(
602 | f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
603 | model.save_pretrained(self.save_dir)
604 | self.val_loss_min = val_loss
605 |
606 |
607 | def convert_cls_examples(raw_examples, prompt_prefix, options):
608 | examples = []
609 | logger.info(f"Converting doccano data...")
610 | with tqdm(total=len(raw_examples)) as pbar:
611 | for line in raw_examples:
612 | items = json.loads(line)
613 | # Compatible with doccano >= 1.6.2
614 | if "data" in items.keys():
615 | text, labels = items["data"], items["label"]
616 | else:
617 | text, labels = items["text"], items["label"]
618 | random.shuffle(options)
619 | prompt = ""
620 | sep = ","
621 | for option in options:
622 | prompt += option
623 | prompt += sep
624 | prompt = prompt_prefix + "[" + prompt.rstrip(sep) + "]"
625 |
626 | result_list = []
627 | example = {
628 | "content": text,
629 | "result_list": result_list,
630 | "prompt": prompt
631 | }
632 | for label in labels:
633 | start = prompt.rfind(label) - len(prompt) - 1  # locate the full label inside the prompt
634 | end = start + len(label)
635 | result = {"text": label, "start": start, "end": end}
636 | example["result_list"].append(result)
637 | examples.append(example)
638 | return examples
639 |
640 |
641 | def add_negative_example(examples, texts, prompts, label_set, negative_ratio):
642 | negative_examples = []
643 | positive_examples = []
644 | with tqdm(total=len(prompts)) as pbar:
645 | for i, prompt in enumerate(prompts):
646 | negative_sample = []
647 | redundants_list = list(set(label_set) ^ set(prompt))
648 | redundants_list.sort()
649 |
650 | num_positive = len(examples[i])
651 | if num_positive != 0:
652 | actual_ratio = math.ceil(len(redundants_list) / num_positive)
653 | else:
654 | # Set num_positive to 1 for text without positive example
655 | num_positive, actual_ratio = 1, 0
656 |
657 | if actual_ratio <= negative_ratio or negative_ratio == -1:
658 | idxs = [k for k in range(len(redundants_list))]
659 | else:
660 | idxs = random.sample(
661 | range(0, len(redundants_list)),
662 | negative_ratio * num_positive)
663 |
664 | for idx in idxs:
665 | negative_result = {
666 | "content": texts[i],
667 | "result_list": [],
668 | "prompt": redundants_list[idx]
669 | }
670 | negative_examples.append(negative_result)
671 | positive_examples.extend(examples[i])
672 | pbar.update(1)
673 | return positive_examples, negative_examples
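# Behaviour sketch (illustrative labels): for a text whose positive prompts are
# ["时间"] with label_set = ["时间", "地点", "人物"], the redundant prompts
# {"地点", "人物"} become negatives with an empty result_list, e.g.
#   {"content": text, "result_list": [], "prompt": "地点"}
# With negative_ratio=1 only one of the two is sampled; negative_ratio=-1
# (or a ratio that already covers them all) keeps every redundant prompt.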
674 |
675 |
676 | def add_full_negative_example(examples, texts, relation_prompts, predicate_set,
677 | subject_goldens):
678 | with tqdm(total=len(relation_prompts)) as pbar:
679 | for i, relation_prompt in enumerate(relation_prompts):
680 | negative_sample = []
681 | for subject in subject_goldens[i]:
682 | for predicate in predicate_set:
683 | # The relation prompt is constructed as follows:
684 | # subject + "的" + predicate
685 | prompt = subject + "的" + predicate
686 | if prompt not in relation_prompt:
687 | negative_result = {
688 | "content": texts[i],
689 | "result_list": [],
690 | "prompt": prompt
691 | }
692 | negative_sample.append(negative_result)
693 | examples[i].extend(negative_sample)
694 | pbar.update(1)
695 | return examples
696 |
697 |
698 | def construct_relation_prompt_set(entity_name_set, predicate_set):
699 | relation_prompt_set = set()
700 | for entity_name in entity_name_set:
701 | for predicate in predicate_set:
702 | # The relation prompt is constructed as follows:
703 | # subject + "的" + predicate
704 | relation_prompt = entity_name + "的" + predicate
705 | relation_prompt_set.add(relation_prompt)
706 | return sorted(list(relation_prompt_set))
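# Example (illustrative names): entity_name_set = ["小明", "小红"] and
# predicate_set = ["朋友", "老师"] yield the sorted cross product
#   ["小明的朋友", "小明的老师", "小红的朋友", "小红的老师"]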
707 |
708 |
709 | def convert_ext_examples(raw_examples, negative_ratio, is_train=True):
710 | texts = []
711 | entity_examples = []
712 | relation_examples = []
713 | entity_prompts = []
714 | relation_prompts = []
715 | entity_label_set = []
716 | entity_name_set = []
717 | predicate_set = []
718 | subject_goldens = []
719 |
720 | logger.info(f"Converting doccano data...")
721 | with tqdm(total=len(raw_examples)) as pbar:
722 | for line in raw_examples:
723 | items = json.loads(line)
724 | entity_id = 0
725 | if "data" in items.keys():
726 | relation_mode = False
727 | if isinstance(items["label"],
728 | dict) and "entities" in items["label"].keys():
729 | relation_mode = True
730 | text = items["data"]
731 | entities = []
732 | if not relation_mode:
733 |                 # Export file in JSONL format from doccano < 1.7.0
734 | for item in items["label"]:
735 | entity = {
736 | "id": entity_id,
737 | "start_offset": item[0],
738 | "end_offset": item[1],
739 | "label": item[2]
740 | }
741 | entities.append(entity)
742 | entity_id += 1
743 | else:
744 |                 # Export file in JSONL format for the relation labeling task from doccano < 1.7.0
745 | for item in items["label"]["entities"]:
746 | entity = {
747 | "id": entity_id,
748 | "start_offset": item["start_offset"],
749 | "end_offset": item["end_offset"],
750 | "label": item["label"]
751 | }
752 | entities.append(entity)
753 | entity_id += 1
754 | relations = []
755 | else:
756 |                 # Export file in JSONL format from doccano >= 1.7.0
757 | if "label" in items.keys():
758 | text = items["text"]
759 | entities = []
760 | for item in items["label"]:
761 | entity = {
762 | "id": entity_id,
763 | "start_offset": item[0],
764 | "end_offset": item[1],
765 | "label": item[2]
766 | }
767 | entities.append(entity)
768 | entity_id += 1
769 | relations = []
770 | else:
771 | # Export file in JSONL (relation) format
772 | text, relations, entities = items["text"], items[
773 | "relations"], items["entities"]
774 | texts.append(text)
775 |
776 | entity_example = []
777 | entity_prompt = []
778 | entity_example_map = {}
779 | entity_map = {} # id to entity name
780 | for entity in entities:
781 | entity_name = text[entity["start_offset"]:entity["end_offset"]]
782 | entity_map[entity["id"]] = {
783 | "name": entity_name,
784 | "start": entity["start_offset"],
785 | "end": entity["end_offset"]
786 | }
787 |
788 | entity_label = entity["label"]
789 | result = {
790 | "text": entity_name,
791 | "start": entity["start_offset"],
792 | "end": entity["end_offset"]
793 | }
794 | if entity_label not in entity_example_map.keys():
795 | entity_example_map[entity_label] = {
796 | "content": text,
797 | "result_list": [result],
798 | "prompt": entity_label
799 | }
800 | else:
801 | entity_example_map[entity_label]["result_list"].append(
802 | result)
803 |
804 | if entity_label not in entity_label_set:
805 | entity_label_set.append(entity_label)
806 | if entity_name not in entity_name_set:
807 | entity_name_set.append(entity_name)
808 | entity_prompt.append(entity_label)
809 |
810 | for v in entity_example_map.values():
811 | entity_example.append(v)
812 |
813 | entity_examples.append(entity_example)
814 | entity_prompts.append(entity_prompt)
815 |
816 | subject_golden = []
817 | relation_example = []
818 | relation_prompt = []
819 | relation_example_map = {}
820 | for relation in relations:
821 | predicate = relation["type"]
822 | subject_id = relation["from_id"]
823 | object_id = relation["to_id"]
824 | # The relation prompt is constructed as follows:
825 | # subject + "的" + predicate
826 | prompt = entity_map[subject_id]["name"] + "的" + predicate
827 | if entity_map[subject_id]["name"] not in subject_golden:
828 | subject_golden.append(entity_map[subject_id]["name"])
829 | result = {
830 | "text": entity_map[object_id]["name"],
831 | "start": entity_map[object_id]["start"],
832 | "end": entity_map[object_id]["end"]
833 | }
834 | if prompt not in relation_example_map.keys():
835 | relation_example_map[prompt] = {
836 | "content": text,
837 | "result_list": [result],
838 | "prompt": prompt
839 | }
840 | else:
841 | relation_example_map[prompt]["result_list"].append(result)
842 |
843 | if predicate not in predicate_set:
844 | predicate_set.append(predicate)
845 | relation_prompt.append(prompt)
846 |
847 | for v in relation_example_map.values():
848 | relation_example.append(v)
849 |
850 | relation_examples.append(relation_example)
851 | relation_prompts.append(relation_prompt)
852 | subject_goldens.append(subject_golden)
853 | pbar.update(1)
854 |
855 | def concat_examples(positive_examples, negative_examples, negative_ratio):
856 | examples = []
857 | if math.ceil(len(negative_examples) /
858 | len(positive_examples)) <= negative_ratio:
859 | examples = positive_examples + negative_examples
860 | else:
861 |             # Randomly sample the negative examples to keep the overall negative ratio unchanged.
862 | idxs = random.sample(
863 | range(0, len(negative_examples)),
864 | negative_ratio * len(positive_examples))
865 | negative_examples_sampled = []
866 | for idx in idxs:
867 | negative_examples_sampled.append(negative_examples[idx])
868 | examples = positive_examples + negative_examples_sampled
869 | return examples
870 |
871 | logger.info(f"Adding negative samples for first stage prompt...")
872 | positive_examples, negative_examples = add_negative_example(
873 | entity_examples, texts, entity_prompts, entity_label_set,
874 | negative_ratio)
875 | if len(positive_examples) == 0:
876 | all_entity_examples = []
877 | elif is_train:
878 | all_entity_examples = concat_examples(positive_examples,
879 | negative_examples, negative_ratio)
880 | else:
881 | all_entity_examples = positive_examples + negative_examples
882 |
883 | all_relation_examples = []
884 | if len(predicate_set) != 0:
885 | if is_train:
886 | logger.info(f"Adding negative samples for second stage prompt...")
887 | relation_prompt_set = construct_relation_prompt_set(entity_name_set,
888 | predicate_set)
889 | positive_examples, negative_examples = add_negative_example(
890 | relation_examples, texts, relation_prompts, relation_prompt_set,
891 | negative_ratio)
892 | all_relation_examples = concat_examples(
893 | positive_examples, negative_examples, negative_ratio)
894 | else:
895 | logger.info(f"Adding negative samples for second stage prompt...")
896 | relation_examples = add_full_negative_example(
897 | relation_examples, texts, relation_prompts, predicate_set,
898 | subject_goldens)
899 | all_relation_examples = [
900 | r
901 |                 for relation_example in relation_examples
902 |                 for r in relation_example
903 | ]
904 | return all_entity_examples, all_relation_examples
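# Usage sketch (illustrative; the file name is hypothetical):
#   with open("doccano_ext.jsonl", "r", encoding="utf-8") as f:
#       raw_examples = f.readlines()
#   entity_examples, relation_examples = convert_ext_examples(
#       raw_examples, negative_ratio=5)
# Entity prompts are the entity labels themselves; relation prompts take the
# form "<subject>的<predicate>".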
905 |
906 |
907 | def get_path_from_url(url,
908 | root_dir,
909 | check_exist=True,
910 | decompress=True):
911 | """ Download from given url to root_dir.
912 |     If the file or directory specified by url already exists under
913 |     root_dir, return its path directly; otherwise download it
914 |     from url, decompress it if needed, and return the path.
915 |
916 | Args:
917 | url (str): download url
918 | root_dir (str): root dir for downloading, it should be
919 | WEIGHTS_HOME or DATASET_HOME
920 |         check_exist (bool): reuse an existing local copy if it is found. Default is `True`
921 |         decompress (bool): decompress zip or tar file. Default is `True`
922 |     Returns:
923 |         str: the local path of the downloaded (and decompressed) file or directory.
924 | """
925 |     import os
926 |     import shutil  # used by _get_download below
927 |     import tarfile
928 |     import time  # used by _download's retry loop
929 |     import zipfile
930 |
931 | def is_url(path):
932 | """
933 |         Return whether `path` is a URL.
934 |         Args:
935 |             path (str): the string to check.
936 | """
937 | return path.startswith('http://') or path.startswith('https://')
938 |
939 | def _map_path(url, root_dir):
940 | # parse path after download under root_dir
941 | fname = os.path.split(url)[-1]
942 | fpath = fname
943 | return os.path.join(root_dir, fpath)
944 |
945 | def _get_download(url, fullname):
946 | import requests
947 | # using requests.get method
948 | fname = os.path.basename(fullname)
949 | try:
950 | req = requests.get(url, stream=True)
951 | except Exception as e: # requests.exceptions.ConnectionError
952 | logger.info("Downloading {} from {} failed with exception {}".format(
953 | fname, url, str(e)))
954 | return False
955 |
956 | if req.status_code != 200:
957 | raise RuntimeError("Downloading from {} failed with code "
958 | "{}!".format(url, req.status_code))
959 |
960 |         # To protect against an interrupted download, write to
961 |         # tmp_fullname first, then move tmp_fullname to fullname
962 |         # once the download has finished
963 | tmp_fullname = fullname + "_tmp"
964 | total_size = req.headers.get('content-length')
965 | with open(tmp_fullname, 'wb') as f:
966 | if total_size:
967 | with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
968 | for chunk in req.iter_content(chunk_size=1024):
969 | f.write(chunk)
970 | pbar.update(1)
971 | else:
972 | for chunk in req.iter_content(chunk_size=1024):
973 | if chunk:
974 | f.write(chunk)
975 | shutil.move(tmp_fullname, fullname)
976 |
977 | return fullname
978 |
979 | def _download(url, path):
980 | """
981 | Download from url, save to path.
982 |
983 | url (str): download url
984 | path (str): download to given path
985 | """
986 |
987 | if not os.path.exists(path):
988 | os.makedirs(path)
989 |
990 | fname = os.path.split(url)[-1]
991 | fullname = os.path.join(path, fname)
992 | retry_cnt = 0
993 |
994 | logger.info("Downloading {} from {}".format(fname, url))
995 | DOWNLOAD_RETRY_LIMIT = 3
996 | while not os.path.exists(fullname):
997 | if retry_cnt < DOWNLOAD_RETRY_LIMIT:
998 | retry_cnt += 1
999 | else:
1000 | raise RuntimeError("Download from {} failed. "
1001 | "Retry limit reached".format(url))
1002 |
1003 | if not _get_download(url, fullname):
1004 | time.sleep(1)
1005 | continue
1006 |
1007 | return fullname
1008 |
1009 | def _uncompress_file_zip(filepath):
1010 | with zipfile.ZipFile(filepath, 'r') as files:
1011 | file_list = files.namelist()
1012 |
1013 | file_dir = os.path.dirname(filepath)
1014 |
1015 | if _is_a_single_file(file_list):
1016 | rootpath = file_list[0]
1017 | uncompressed_path = os.path.join(file_dir, rootpath)
1018 | files.extractall(file_dir)
1019 |
1020 | elif _is_a_single_dir(file_list):
1021 | # `strip(os.sep)` to remove `os.sep` in the tail of path
1022 | rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
1023 | os.sep)[-1]
1024 | uncompressed_path = os.path.join(file_dir, rootpath)
1025 |
1026 | files.extractall(file_dir)
1027 | else:
1028 | rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
1029 | uncompressed_path = os.path.join(file_dir, rootpath)
1030 | if not os.path.exists(uncompressed_path):
1031 | os.makedirs(uncompressed_path)
1032 | files.extractall(os.path.join(file_dir, rootpath))
1033 |
1034 | return uncompressed_path
1035 |
1036 | def _is_a_single_file(file_list):
1037 | if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
1038 | return True
1039 | return False
1040 |
1041 | def _is_a_single_dir(file_list):
1042 | new_file_list = []
1043 | for file_path in file_list:
1044 | if '/' in file_path:
1045 | file_path = file_path.replace('/', os.sep)
1046 | elif '\\' in file_path:
1047 | file_path = file_path.replace('\\', os.sep)
1048 | new_file_list.append(file_path)
1049 |
1050 | file_name = new_file_list[0].split(os.sep)[0]
1051 | for i in range(1, len(new_file_list)):
1052 | if file_name != new_file_list[i].split(os.sep)[0]:
1053 | return False
1054 | return True
1055 |
1056 | def _uncompress_file_tar(filepath, mode="r:*"):
1057 | with tarfile.open(filepath, mode) as files:
1058 | file_list = files.getnames()
1059 |
1060 | file_dir = os.path.dirname(filepath)
1061 |
1062 | if _is_a_single_file(file_list):
1063 | rootpath = file_list[0]
1064 | uncompressed_path = os.path.join(file_dir, rootpath)
1065 | files.extractall(file_dir)
1066 | elif _is_a_single_dir(file_list):
1067 | rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
1068 | os.sep)[-1]
1069 | uncompressed_path = os.path.join(file_dir, rootpath)
1070 | files.extractall(file_dir)
1071 | else:
1072 | rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
1073 | uncompressed_path = os.path.join(file_dir, rootpath)
1074 | if not os.path.exists(uncompressed_path):
1075 | os.makedirs(uncompressed_path)
1076 |
1077 | files.extractall(os.path.join(file_dir, rootpath))
1078 |
1079 | return uncompressed_path
1080 |
1081 | def _decompress(fname):
1082 | """
1083 | Decompress for zip and tar file
1084 | """
1085 | logger.info("Decompressing {}...".format(fname))
1086 |
1087 |         # Note: the archive is decompressed directly next to the
1088 |         # downloaded file; extracting into a temporary directory first
1089 |         # (to guard against an interrupted decompression) is not
1090 |         # implemented here.
1091 |
1092 | if tarfile.is_tarfile(fname):
1093 | uncompressed_path = _uncompress_file_tar(fname)
1094 | elif zipfile.is_zipfile(fname):
1095 | uncompressed_path = _uncompress_file_zip(fname)
1096 | else:
1097 |             raise TypeError("Unsupported compressed file type {}".format(fname))
1098 |
1099 | return uncompressed_path
1100 |
1101 | assert is_url(url), "downloading from {} not a url".format(url)
1102 | fullpath = _map_path(url, root_dir)
1103 | if os.path.exists(fullpath) and check_exist:
1104 | logger.info("Found {}".format(fullpath))
1105 | else:
1106 | fullpath = _download(url, root_dir)
1107 |
1108 | if decompress and (tarfile.is_tarfile(fullpath) or
1109 | zipfile.is_zipfile(fullpath)):
1110 | fullpath = _decompress(fullpath)
1111 |
1112 | return fullpath
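# Usage sketch (illustrative; the URL, archive name and directory are
# placeholders, not real endpoints):
#   local_path = get_path_from_url(
#       "https://example.com/uie_base_pytorch.tar.gz", root_dir="./checkpoint")
# With check_exist=True a cached copy is reused; zip/tar archives are
# decompressed next to the download and the extracted path is returned.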
1113 |
--------------------------------------------------------------------------------